minhvtt committed
Commit 6c982a7 · verified · 1 Parent(s): adec8cd

Upload 6 files

Files changed (6)
  1. chatbot_rag.py +351 -0
  2. chatbot_rag_api.py +467 -0
  3. embedding_service.py +173 -0
  4. main.py +352 -0
  5. qdrant_service.py +447 -0
  6. requirements.txt +27 -0
chatbot_rag.py ADDED
@@ -0,0 +1,351 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+ from pymongo import MongoClient
+ from datetime import datetime
+ from typing import List, Dict
+ import numpy as np
+
+ from embedding_service import JinaClipEmbeddingService
+ from qdrant_service import QdrantVectorService
+
+
+ class ChatbotRAG:
+     """
+     RAG chatbot with:
+     - LLM: GPT-OSS-20B (Hugging Face)
+     - Embeddings: Jina CLIP v2
+     - Vector DB: Qdrant
+     - Document Store: MongoDB
+     """
+
+     def __init__(
+         self,
+         mongodb_uri: str = "mongodb+srv://truongtn7122003:[email protected]/",
+         db_name: str = "chatbot_rag",
+         collection_name: str = "documents"
+     ):
+         """
+         Initialize ChatbotRAG
+
+         Args:
+             mongodb_uri: MongoDB connection string
+             db_name: Database name
+             collection_name: Collection name for documents
+         """
+         print("Initializing ChatbotRAG...")
+
+         # MongoDB client
+         self.mongo_client = MongoClient(mongodb_uri)
+         self.db = self.mongo_client[db_name]
+         self.documents_collection = self.db[collection_name]
+         self.chat_history_collection = self.db["chat_history"]
+
+         # Embedding service (Jina CLIP v2)
+         self.embedding_service = JinaClipEmbeddingService(
+             model_path="jinaai/jina-clip-v2"
+         )
+
+         # Qdrant vector service
+         self.qdrant_service = QdrantVectorService(
+             collection_name="chatbot_rag_vectors",
+             vector_size=self.embedding_service.get_embedding_dimension()
+         )
+
+         print("✓ ChatbotRAG initialized successfully")
+
+     def add_document(self, text: str, metadata: Dict = None) -> str:
+         """
+         Add a document to MongoDB and Qdrant
+
+         Args:
+             text: Document text
+             metadata: Additional metadata
+
+         Returns:
+             Document ID
+         """
+         # Save to MongoDB
+         doc_data = {
+             "text": text,
+             "metadata": metadata or {},
+             "created_at": datetime.utcnow()
+         }
+         result = self.documents_collection.insert_one(doc_data)
+         doc_id = str(result.inserted_id)
+
+         # Generate embedding
+         embedding = self.embedding_service.encode_text(text)
+
+         # Index into Qdrant
+         self.qdrant_service.index_data(
+             doc_id=doc_id,
+             embedding=embedding,
+             metadata={
+                 "text": text,
+                 "source": "user_upload",
+                 **(metadata or {})
+             }
+         )
+
+         return doc_id
+
+     def retrieve_context(self, query: str, top_k: int = 3) -> List[Dict]:
+         """
+         Retrieve relevant context from the vector DB
+
+         Args:
+             query: User query
+             top_k: Number of results to retrieve
+
+         Returns:
+             List of relevant documents
+         """
+         # Generate query embedding
+         query_embedding = self.embedding_service.encode_text(query)
+
+         # Search in Qdrant
+         results = self.qdrant_service.search(
+             query_embedding=query_embedding,
+             limit=top_k,
+             score_threshold=0.5  # Only keep sufficiently relevant results
+         )
+
+         return results
+
+     def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
+         """
+         Save a chat interaction to MongoDB
+
+         Args:
+             user_message: User's message
+             assistant_response: Assistant's response
+             context_used: Context retrieved via RAG
+         """
+         chat_data = {
+             "user_message": user_message,
+             "assistant_response": assistant_response,
+             "context_used": context_used,
+             "timestamp": datetime.utcnow()
+         }
+         self.chat_history_collection.insert_one(chat_data)
+
+     def respond(
+         self,
+         message: str,
+         history: List[Dict[str, str]],
+         system_message: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         use_rag: bool,
+         hf_token: gr.OAuthToken,
+     ):
+         """
+         Generate a response with RAG
+
+         Args:
+             message: User message
+             history: Chat history
+             system_message: System prompt
+             max_tokens: Max tokens to generate
+             temperature: Temperature for generation
+             top_p: Top-p sampling
+             use_rag: Whether to use RAG retrieval
+             hf_token: Hugging Face token
+
+         Yields:
+             Generated response
+         """
+         # Initialize the LLM client
+         client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
+
+         # Prepare context from RAG
+         context_text = ""
+         context_used = []
+
+         if use_rag:
+             # Retrieve relevant context
+             retrieved_docs = self.retrieve_context(message, top_k=3)
+             context_used = retrieved_docs
+
+             if retrieved_docs:
+                 context_text = "\n\n**Relevant Context:**\n"
+                 for i, doc in enumerate(retrieved_docs, 1):
+                     doc_text = doc["metadata"].get("text", "")
+                     confidence = doc["confidence"]
+                     context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
+
+                 # Add the context to the system message
+                 system_message = f"{system_message}\n\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."
+
+         # Build messages for the LLM
+         messages = [{"role": "system", "content": system_message}]
+         messages.extend(history)
+         messages.append({"role": "user", "content": message})
+
+         # Generate the response
+         response = ""
+
+         try:
+             for msg in client.chat_completion(
+                 messages,
+                 max_tokens=max_tokens,
+                 stream=True,
+                 temperature=temperature,
+                 top_p=top_p,
+             ):
+                 choices = msg.choices
+                 token = ""
+                 if len(choices) and choices[0].delta.content:
+                     token = choices[0].delta.content
+
+                 response += token
+                 yield response
+
+             # Save to chat history
+             self.save_chat_history(message, response, context_used)
+
+         except Exception as e:
+             error_msg = f"Error generating response: {str(e)}"
+             yield error_msg
+
+
+ # Initialize ChatbotRAG
+ chatbot_rag = ChatbotRAG()
+
+
+ def respond_wrapper(
+     message,
+     history,
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+     use_rag,
+     hf_token,
+ ):
+     """Wrapper for Gradio ChatInterface"""
+     yield from chatbot_rag.respond(
+         message=message,
+         history=history,
+         system_message=system_message,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         use_rag=use_rag,
+         hf_token=hf_token,
+     )
+
+
+ def add_document_to_rag(text: str) -> str:
+     """
+     Add a document to the RAG knowledge base
+
+     Args:
+         text: Document text
+
+     Returns:
+         Status message
+     """
+     try:
+         doc_id = chatbot_rag.add_document(text)
+         return f"✓ Document added successfully! ID: {doc_id}"
+     except Exception as e:
+         return f"✗ Error adding document: {str(e)}"
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="ChatbotRAG - GPT-OSS-20B + Jina CLIP v2 + MongoDB") as demo:
+     gr.Markdown("""
+     # 🤖 ChatbotRAG
+
+     **Features:**
+     - 💬 LLM: GPT-OSS-20B
+     - 🔍 Embeddings: Jina CLIP v2 (Vietnamese support)
+     - 📊 Vector DB: Qdrant Cloud
+     - 🗄️ Document Store: MongoDB
+
+     **How to use:**
+     1. Add documents to the knowledge base (optional)
+     2. Toggle "Use RAG" to enable context retrieval
+     3. Chat with the bot!
+     """)
+
+     with gr.Sidebar():
+         gr.LoginButton()
+
+         gr.Markdown("### ⚙️ Settings")
+
+         use_rag = gr.Checkbox(
+             label="Use RAG",
+             value=True,
+             info="Enable RAG to retrieve relevant context from the knowledge base"
+         )
+
+         system_message = gr.Textbox(
+             value="You are a helpful AI assistant. Answer questions based on the provided context when available.",
+             label="System message",
+             lines=3
+         )
+
+         max_tokens = gr.Slider(
+             minimum=1,
+             maximum=2048,
+             value=512,
+             step=1,
+             label="Max new tokens"
+         )
+
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=4.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature"
+         )
+
+         top_p = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)"
+         )
+
+     # Chat interface
+     chatbot = gr.ChatInterface(
+         respond_wrapper,
+         type="messages",
+         additional_inputs=[
+             system_message,
+             max_tokens,
+             temperature,
+             top_p,
+             use_rag,
+         ],
+     )
+
+     # Document management
+     with gr.Accordion("📚 Knowledge Base Management", open=False):
+         gr.Markdown("### Add Documents to Knowledge Base")
+
+         doc_text = gr.Textbox(
+             label="Document Text",
+             placeholder="Enter document text here...",
+             lines=5
+         )
+
+         add_btn = gr.Button("Add Document", variant="primary")
+         output_msg = gr.Textbox(label="Status", interactive=False)
+
+         add_btn.click(
+             fn=add_document_to_rag,
+             inputs=[doc_text],
+             outputs=[output_msg]
+         )
+
+     chatbot.render()
+
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
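
For reference, a minimal sketch (not part of the commit) of driving the class outside the Gradio UI. It assumes the hardcoded MongoDB URI and the QDRANT_URL / QDRANT_API_KEY environment variables are valid, and it reuses the module-level chatbot_rag instance created on import:

    from chatbot_rag import chatbot_rag as bot

    # Index a document, then retrieve it by semantic similarity
    doc_id = bot.add_document(
        "Qdrant is a vector database optimized for similarity search.",
        metadata={"topic": "databases"},
    )
    for hit in bot.retrieve_context("What is Qdrant?", top_k=3):
        print(hit["id"], hit["confidence"], hit["metadata"].get("text", "")[:60])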
chatbot_rag_api.py ADDED
@@ -0,0 +1,467 @@
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import Optional, List, Dict
+ from pymongo import MongoClient
+ from bson import ObjectId
+ from datetime import datetime
+ import numpy as np
+ import os
+ from huggingface_hub import InferenceClient
+
+ from embedding_service import JinaClipEmbeddingService
+ from qdrant_service import QdrantVectorService
+
+
+ # Pydantic models
+ class ChatRequest(BaseModel):
+     message: str
+     use_rag: bool = True
+     top_k: int = 3
+     system_message: Optional[str] = "You are a helpful AI assistant."
+     max_tokens: int = 512
+     temperature: float = 0.7
+     top_p: float = 0.95
+     hf_token: Optional[str] = None  # Hugging Face token (optional; falls back to env if omitted)
+
+
+ class ChatResponse(BaseModel):
+     response: str
+     context_used: List[Dict]
+     timestamp: str
+
+
+ class AddDocumentRequest(BaseModel):
+     text: str
+     metadata: Optional[Dict] = None
+
+
+ class AddDocumentResponse(BaseModel):
+     success: bool
+     doc_id: str
+     message: str
+
+
+ class SearchRequest(BaseModel):
+     query: str
+     top_k: int = 5
+     score_threshold: Optional[float] = 0.5
+
+
+ class SearchResponse(BaseModel):
+     results: List[Dict]
+
+
+ # Initialize FastAPI
+ app = FastAPI(
+     title="ChatbotRAG API",
+     description="API for a RAG chatbot with GPT-OSS-20B + Jina CLIP v2 + MongoDB + Qdrant",
+     version="1.0.0"
+ )
+
+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow all origins (restrict this in production)
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ # ChatbotRAG Service
+ class ChatbotRAGService:
+     """
+     ChatbotRAG service backing the API
+     """
+
+     def __init__(
+         self,
+         mongodb_uri: str = "mongodb+srv://truongtn7122003:[email protected]/",
+         db_name: str = "chatbot_rag",
+         collection_name: str = "documents",
+         hf_token: Optional[str] = None
+     ):
+         print("Initializing ChatbotRAG Service...")
+
+         # MongoDB
+         self.mongo_client = MongoClient(mongodb_uri)
+         self.db = self.mongo_client[db_name]
+         self.documents_collection = self.db[collection_name]
+         self.chat_history_collection = self.db["chat_history"]
+
+         # Embedding service
+         self.embedding_service = JinaClipEmbeddingService(
+             model_path="jinaai/jina-clip-v2"
+         )
+
+         # Qdrant
+         self.qdrant_service = QdrantVectorService(
+             collection_name="chatbot_rag_vectors",
+             vector_size=self.embedding_service.get_embedding_dimension()
+         )
+
+         # Hugging Face token (from env or passed in)
+         self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN")
+         if self.hf_token:
+             print("✓ Hugging Face token configured")
+         else:
+             print("⚠ No Hugging Face token - LLM generation will use a placeholder")
+
+         print("✓ ChatbotRAG Service initialized")
+
+     def add_document(self, text: str, metadata: Dict = None) -> str:
+         """Add a document to the knowledge base"""
+         # Save to MongoDB
+         doc_data = {
+             "text": text,
+             "metadata": metadata or {},
+             "created_at": datetime.utcnow()
+         }
+         result = self.documents_collection.insert_one(doc_data)
+         doc_id = str(result.inserted_id)
+
+         # Generate embedding
+         embedding = self.embedding_service.encode_text(text)
+
+         # Index into Qdrant
+         self.qdrant_service.index_data(
+             doc_id=doc_id,
+             embedding=embedding,
+             metadata={
+                 "text": text,
+                 "source": "api",
+                 **(metadata or {})
+             }
+         )
+
+         return doc_id
+
+     def retrieve_context(self, query: str, top_k: int = 3, score_threshold: float = 0.5) -> List[Dict]:
+         """Retrieve relevant context from the vector DB"""
+         # Generate the query embedding
+         query_embedding = self.embedding_service.encode_text(query)
+
+         # Search in Qdrant
+         results = self.qdrant_service.search(
+             query_embedding=query_embedding,
+             limit=top_k,
+             score_threshold=score_threshold
+         )
+
+         return results
+
+     def generate_response(
+         self,
+         message: str,
+         context: List[Dict],
+         system_message: str,
+         max_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.95,
+         hf_token: Optional[str] = None
+     ) -> str:
+         """
+         Generate a response using the Hugging Face LLM
+         """
+         # Build the context text
+         context_text = ""
+         if context:
+             context_text = "\n\nRelevant Context:\n"
+             for i, doc in enumerate(context, 1):
+                 doc_text = doc["metadata"].get("text", "")
+                 confidence = doc["confidence"]
+                 context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
+
+         # Add the context to the system message
+         system_message = f"{system_message}\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."
+
+         # Use the token from the request, falling back to the service token
+         token = hf_token or self.hf_token
+
+         # If no token is available, return a placeholder
+         if not token:
+             return f"""[LLM Response Placeholder]
+
+ Context retrieved: {len(context)} documents
+ User question: {message}
+
+ To enable actual LLM generation:
+ 1. Set the HUGGINGFACE_TOKEN environment variable, OR
+ 2. Pass hf_token in the request body
+
+ Example:
+ {{
+     "message": "Your question",
+     "hf_token": "hf_xxxxxxxxxxxxx"
+ }}
+ """
+
+         # Initialize the HF Inference Client
+         try:
+             client = InferenceClient(
+                 token=token,
+                 model="openai/gpt-oss-20b"
+             )
+
+             # Build messages
+             messages = [
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": message}
+             ]
+
+             # Generate the response (streamed internally, returned as one string)
+             response = ""
+             for msg in client.chat_completion(
+                 messages,
+                 max_tokens=max_tokens,
+                 stream=True,
+                 temperature=temperature,
+                 top_p=top_p,
+             ):
+                 choices = msg.choices
+                 if len(choices) and choices[0].delta.content:
+                     response += choices[0].delta.content
+
+             return response
+
+         except Exception as e:
+             return f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."
+
+     def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
+         """Save a chat interaction to MongoDB"""
+         chat_data = {
+             "user_message": user_message,
+             "assistant_response": assistant_response,
+             "context_used": context_used,
+             "timestamp": datetime.utcnow()
+         }
+         self.chat_history_collection.insert_one(chat_data)
+
+     def get_stats(self) -> Dict:
+         """Get statistics"""
+         return {
+             "documents_count": self.documents_collection.count_documents({}),
+             "chat_history_count": self.chat_history_collection.count_documents({}),
+             "qdrant_info": self.qdrant_service.get_collection_info()
+         }
+
+
+ # Initialize the service
+ rag_service = ChatbotRAGService()
+
+
+ # API Endpoints
+
+ @app.get("/")
+ async def root():
+     """Health check"""
+     return {
+         "status": "running",
+         "service": "ChatbotRAG API",
+         "version": "1.0.0",
+         "endpoints": {
+             "POST /chat": "Chat with RAG",
+             "POST /documents": "Add document to knowledge base",
+             "POST /search": "Search in knowledge base",
+             "GET /stats": "Get statistics",
+             "GET /history": "Get chat history"
+         }
+     }
+
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
+     """
+     Chat endpoint with RAG
+
+     Body:
+     - message: User message
+     - use_rag: Enable RAG retrieval (default: true)
+     - top_k: Number of documents to retrieve (default: 3)
+     - system_message: System prompt (optional)
+     - max_tokens: Max tokens for the response (default: 512)
+     - temperature: Temperature for generation (default: 0.7)
+
+     Returns:
+     - response: Generated response
+     - context_used: Retrieved context documents
+     - timestamp: Response timestamp
+     """
+     try:
+         # Retrieve context if RAG is enabled
+         context_used = []
+         if request.use_rag:
+             context_used = rag_service.retrieve_context(
+                 query=request.message,
+                 top_k=request.top_k
+             )
+
+         # Generate the response
+         response = rag_service.generate_response(
+             message=request.message,
+             context=context_used,
+             system_message=request.system_message,
+             max_tokens=request.max_tokens,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             hf_token=request.hf_token
+         )
+
+         # Save to history
+         rag_service.save_chat_history(
+             user_message=request.message,
+             assistant_response=response,
+             context_used=context_used
+         )
+
+         return ChatResponse(
+             response=response,
+             context_used=context_used,
+             timestamp=datetime.utcnow().isoformat()
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ @app.post("/documents", response_model=AddDocumentResponse)
+ async def add_document(request: AddDocumentRequest):
+     """
+     Add a document to the knowledge base
+
+     Body:
+     - text: Document text
+     - metadata: Additional metadata (optional)
+
+     Returns:
+     - success: True/False
+     - doc_id: MongoDB document ID
+     - message: Status message
+     """
+     try:
+         doc_id = rag_service.add_document(
+             text=request.text,
+             metadata=request.metadata
+         )
+
+         return AddDocumentResponse(
+             success=True,
+             doc_id=doc_id,
+             message=f"Document added successfully with ID: {doc_id}"
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ @app.post("/search", response_model=SearchResponse)
+ async def search(request: SearchRequest):
+     """
+     Search the knowledge base
+
+     Body:
+     - query: Search query
+     - top_k: Number of results (default: 5)
+     - score_threshold: Minimum score (default: 0.5)
+
+     Returns:
+     - results: List of matching documents
+     """
+     try:
+         results = rag_service.retrieve_context(
+             query=request.query,
+             top_k=request.top_k,
+             score_threshold=request.score_threshold
+         )
+
+         return SearchResponse(results=results)
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ @app.get("/stats")
+ async def get_stats():
+     """
+     Get statistics
+
+     Returns:
+     - documents_count: Number of documents in MongoDB
+     - chat_history_count: Number of chat messages
+     - qdrant_info: Qdrant collection info
+     """
+     try:
+         return rag_service.get_stats()
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ @app.get("/history")
+ async def get_history(limit: int = 10, skip: int = 0):
+     """
+     Get chat history
+
+     Query params:
+     - limit: Number of messages to return (default: 10)
+     - skip: Number of messages to skip (default: 0)
+
+     Returns:
+     - history: List of chat messages
+     """
+     try:
+         history = list(
+             rag_service.chat_history_collection
+             .find({}, {"_id": 0})
+             .sort("timestamp", -1)
+             .skip(skip)
+             .limit(limit)
+         )
+
+         # Convert datetimes to strings
+         for msg in history:
+             if "timestamp" in msg:
+                 msg["timestamp"] = msg["timestamp"].isoformat()
+
+         return {"history": history, "total": rag_service.chat_history_collection.count_documents({})}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ @app.delete("/documents/{doc_id}")
+ async def delete_document(doc_id: str):
+     """
+     Delete a document from the knowledge base
+
+     Args:
+     - doc_id: Document ID (MongoDB ObjectId)
+
+     Returns:
+     - success: True/False
+     - message: Status message
+     """
+     try:
+         # Delete from MongoDB (convert the string to an ObjectId so _id matches)
+         result = rag_service.documents_collection.delete_one({"_id": ObjectId(doc_id)})
+
+         # Delete from Qdrant
+         if result.deleted_count > 0:
+             rag_service.qdrant_service.delete_by_id(doc_id)
+             return {"success": True, "message": f"Document {doc_id} deleted"}
+         else:
+             raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=8000,
+         log_level="info"
+     )
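
A hedged client sketch for this API (not part of the commit). It assumes the server is running locally on port 8000 and that the requests package is installed (it is not pinned in requirements.txt):

    import requests  # assumed installed; not listed in requirements.txt

    BASE = "http://localhost:8000"

    # Add a document to the knowledge base
    print(requests.post(f"{BASE}/documents",
                        json={"text": "Jina CLIP v2 produces 1024-dimensional embeddings."}).json())

    # Chat with RAG enabled; hf_token may be omitted if HUGGINGFACE_TOKEN is set server-side
    reply = requests.post(f"{BASE}/chat",
                          json={"message": "What dimension are the embeddings?", "use_rag": True}).json()
    print(reply["response"])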
embedding_service.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import AutoModel
+ from typing import Union, List
+ import io
+
+
+ class JinaClipEmbeddingService:
+     """
+     Jina CLIP v2 embedding service with Vietnamese support.
+     Uses AutoModel with trust_remote_code.
+     """
+
+     def __init__(self, model_path: str = "jinaai/jina-clip-v2"):
+         """
+         Initialize the Jina CLIP v2 model
+
+         Args:
+             model_path: Local model path or Hugging Face model name
+         """
+         print(f"Loading Jina CLIP v2 model from {model_path}...")
+
+         # Load the model with trust_remote_code
+         self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+
+         # Switch to eval mode
+         self.model.eval()
+
+         # Use the GPU if available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+         print(f"✓ Loaded Jina CLIP v2 model on: {self.device}")
+
+     def encode_text(
+         self,
+         text: Union[str, List[str]],
+         truncate_dim: int = None,
+         normalize: bool = True
+     ) -> np.ndarray:
+         """
+         Encode text into vector embeddings (Vietnamese supported)
+
+         Args:
+             text: Text or list of texts (Vietnamese supported)
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
+             normalize: Whether to L2-normalize the embeddings
+
+         Returns:
+             numpy array of embeddings
+         """
+         if isinstance(text, str):
+             text = [text]
+
+         # Jina CLIP v2 encode_text method
+         # Handles tokenization internally
+         embeddings = self.model.encode_text(
+             text,
+             truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
+         )
+
+         # Convert to numpy
+         if isinstance(embeddings, torch.Tensor):
+             embeddings = embeddings.cpu().detach().numpy()
+
+         # Normalize if requested
+         if normalize:
+             embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+         return embeddings
+
+     def encode_image(
+         self,
+         image: Union[Image.Image, bytes, List, str],
+         truncate_dim: int = None,
+         normalize: bool = True
+     ) -> np.ndarray:
+         """
+         Encode an image into vector embeddings
+
+         Args:
+             image: PIL Image, bytes, URL string, or list of images
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
+             normalize: Whether to L2-normalize the embeddings
+
+         Returns:
+             numpy array of embeddings
+         """
+         # Convert bytes to a PIL Image if needed
+         if isinstance(image, bytes):
+             image = Image.open(io.BytesIO(image)).convert('RGB')
+         elif isinstance(image, list):
+             processed_images = []
+             for img in image:
+                 if isinstance(img, bytes):
+                     processed_images.append(Image.open(io.BytesIO(img)).convert('RGB'))
+                 elif isinstance(img, str):
+                     # URL string - keep as-is, Jina CLIP can handle URLs
+                     processed_images.append(img)
+                 else:
+                     processed_images.append(img)
+             image = processed_images
+         elif not isinstance(image, str):
+             # Single PIL Image - wrap it in a list
+             image = [image]
+
+         # Jina CLIP v2 encode_image method
+         # Supports PIL Images, file paths, or URLs
+         embeddings = self.model.encode_image(
+             image,
+             truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
+         )
+
+         # Convert to numpy
+         if isinstance(embeddings, torch.Tensor):
+             embeddings = embeddings.cpu().detach().numpy()
+
+         # Normalize if requested
+         if normalize:
+             embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+         return embeddings
+
+     def encode_multimodal(
+         self,
+         text: Union[str, List[str]] = None,
+         image: Union[Image.Image, bytes, List] = None,
+         truncate_dim: int = None,
+         normalize: bool = True
+     ) -> np.ndarray:
+         """
+         Encode both text and image and return a combined embedding
+
+         Args:
+             text: Text or list of texts (Vietnamese supported)
+             image: PIL Image, bytes, or list of images
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
+             normalize: Whether to L2-normalize the embeddings
+
+         Returns:
+             numpy array of embeddings
+         """
+         embeddings = []
+
+         if text is not None:
+             text_emb = self.encode_text(text, truncate_dim=truncate_dim, normalize=False)
+             embeddings.append(text_emb)
+
+         if image is not None:
+             image_emb = self.encode_image(image, truncate_dim=truncate_dim, normalize=False)
+             embeddings.append(image_emb)
+
+         # Combine the embeddings (average)
+         if len(embeddings) == 2:
+             # Average of the text and image embeddings
+             combined = np.mean(embeddings, axis=0)
+         elif len(embeddings) == 1:
+             combined = embeddings[0]
+         else:
+             raise ValueError("Must provide at least text or an image")
+
+         # Normalize if requested
+         if normalize:
+             combined = combined / np.linalg.norm(combined, axis=1, keepdims=True)
+
+         return combined
+
+     def get_embedding_dimension(self) -> int:
+         """
+         Return the embedding dimension (1024 for Jina CLIP v2)
+         """
+         return 1024
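
Since encode_text returns L2-normalized vectors by default, the dot product of two embeddings is their cosine similarity. A minimal sketch (not part of the commit; downloads the model weights on first run):

    from embedding_service import JinaClipEmbeddingService

    svc = JinaClipEmbeddingService()

    emb = svc.encode_text([
        "Hanoi is the capital of Vietnam.",
        "Hà Nội là thủ đô của Việt Nam.",
    ])
    # Rows are unit-length, so the dot product equals the cosine similarity
    print(float(emb[0] @ emb[1]))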
main.py ADDED
@@ -0,0 +1,352 @@
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from typing import Optional, List
+ from PIL import Image
+ import io
+ import numpy as np
+
+ from embedding_service import JinaClipEmbeddingService
+ from qdrant_service import QdrantVectorService
+
+ # Initialize the FastAPI app
+ app = FastAPI(
+     title="Event Social Media Embeddings API",
+     description="API for embedding and searching text + images from events & social media with Jina CLIP v2 + Qdrant",
+     version="1.0.0"
+ )
+
+ # Initialize services
+ print("Initializing services...")
+ embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
+ qdrant_service = QdrantVectorService(
+     # URL and API key are read from environment variables
+     collection_name="event_social_media",
+     vector_size=embedding_service.get_embedding_dimension()
+ )
+ print("✓ Services initialized successfully")
+
+
+ # Pydantic models
+ class SearchRequest(BaseModel):
+     text: Optional[str] = None
+     limit: int = 10
+     score_threshold: Optional[float] = None
+     text_weight: float = 0.5
+     image_weight: float = 0.5
+
+
+ class SearchResponse(BaseModel):
+     id: str
+     confidence: float
+     metadata: dict
+
+
+ class IndexResponse(BaseModel):
+     success: bool
+     id: str
+     message: str
+
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint"""
+     return {
+         "status": "running",
+         "service": "Event Social Media Embeddings API",
+         "embedding_model": "Jina CLIP v2",
+         "vector_db": "Qdrant",
+         "language_support": "Vietnamese + 88 other languages"
+     }
+
+
+ @app.post("/index", response_model=IndexResponse)
+ async def index_data(
+     id: str = Form(...),
+     text: str = Form(...),
+     image: Optional[UploadFile] = File(None)
+ ):
+     """
+     Index data into the vector database
+
+     Body:
+     - id: Document ID (event ID, post ID, etc.)
+     - text: Text content (Vietnamese supported)
+     - image: Image file (optional)
+
+     Returns:
+     - success: True/False
+     - id: Document ID
+     - message: Status message
+     """
+     try:
+         # Prepare embeddings
+         text_embedding = None
+         image_embedding = None
+
+         # Encode the text (Vietnamese supported)
+         if text and text.strip():
+             text_embedding = embedding_service.encode_text(text)
+
+         # Encode the image if provided
+         if image:
+             image_bytes = await image.read()
+             pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+             image_embedding = embedding_service.encode_image(pil_image)
+
+         # Combine the embeddings
+         if text_embedding is not None and image_embedding is not None:
+             # Average of the text and image embeddings
+             combined_embedding = np.mean([text_embedding, image_embedding], axis=0)
+         elif text_embedding is not None:
+             combined_embedding = text_embedding
+         elif image_embedding is not None:
+             combined_embedding = image_embedding
+         else:
+             raise HTTPException(status_code=400, detail="Must provide at least text or an image")
+
+         # Normalize
+         combined_embedding = combined_embedding / np.linalg.norm(combined_embedding, axis=1, keepdims=True)
+
+         # Index into Qdrant
+         metadata = {
+             "text": text,
+             "has_image": image is not None,
+             "image_filename": image.filename if image else None
+         }
+
+         result = qdrant_service.index_data(
+             doc_id=id,
+             embedding=combined_embedding,
+             metadata=metadata
+         )
+
+         return IndexResponse(
+             success=True,
+             id=result["original_id"],  # Return the original ID (e.g., MongoDB ObjectId)
+             message=f"Successfully indexed document {result['original_id']} (Qdrant UUID: {result['qdrant_id']})"
+         )
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Indexing error: {str(e)}")
+
+
+ @app.post("/search", response_model=List[SearchResponse])
+ async def search(
+     text: Optional[str] = Form(None),
+     image: Optional[UploadFile] = File(None),
+     limit: int = Form(10),
+     score_threshold: Optional[float] = Form(None),
+     text_weight: float = Form(0.5),
+     image_weight: float = Form(0.5)
+ ):
+     """
+     Search for similar documents by text and/or image
+
+     Body:
+     - text: Query text (Vietnamese supported)
+     - image: Query image (optional)
+     - limit: Number of results (default: 10)
+     - score_threshold: Minimum confidence score (0-1)
+     - text_weight: Weight for the text query (default: 0.5)
+     - image_weight: Weight for the image query (default: 0.5)
+
+     Returns:
+     - List of results with id, confidence, and metadata
+     """
+     try:
+         # Prepare the query embeddings
+         text_embedding = None
+         image_embedding = None
+
+         # Encode the text query
+         if text and text.strip():
+             text_embedding = embedding_service.encode_text(text)
+
+         # Encode the image query
+         if image:
+             image_bytes = await image.read()
+             pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+             image_embedding = embedding_service.encode_image(pil_image)
+
+         # Validate input
+         if text_embedding is None and image_embedding is None:
+             raise HTTPException(status_code=400, detail="Must provide at least text or an image to search")
+
+         # Hybrid search with Qdrant
+         results = qdrant_service.hybrid_search(
+             text_embedding=text_embedding,
+             image_embedding=image_embedding,
+             text_weight=text_weight,
+             image_weight=image_weight,
+             limit=limit,
+             score_threshold=score_threshold,
+             ef=256  # High-accuracy search
+         )
+
+         # Format the response
+         return [
+             SearchResponse(
+                 id=result["id"],
+                 confidence=result["confidence"],
+                 metadata=result["metadata"]
+             )
+             for result in results
+         ]
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")
+
+
+ @app.post("/search/text", response_model=List[SearchResponse])
+ async def search_by_text(
+     text: str = Form(...),
+     limit: int = Form(10),
+     score_threshold: Optional[float] = Form(None)
+ ):
+     """
+     Search by text only (Vietnamese supported)
+
+     Body:
+     - text: Query text (Vietnamese supported)
+     - limit: Number of results
+     - score_threshold: Minimum confidence score
+
+     Returns:
+     - List of results
+     """
+     try:
+         # Encode the text
+         text_embedding = embedding_service.encode_text(text)
+
+         # Search
+         results = qdrant_service.search(
+             query_embedding=text_embedding,
+             limit=limit,
+             score_threshold=score_threshold,
+             ef=256
+         )
+
+         return [
+             SearchResponse(
+                 id=result["id"],
+                 confidence=result["confidence"],
+                 metadata=result["metadata"]
+             )
+             for result in results
+         ]
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")
+
+
+ @app.post("/search/image", response_model=List[SearchResponse])
+ async def search_by_image(
+     image: UploadFile = File(...),
+     limit: int = Form(10),
+     score_threshold: Optional[float] = Form(None)
+ ):
+     """
+     Search by image only
+
+     Body:
+     - image: Query image
+     - limit: Number of results
+     - score_threshold: Minimum confidence score
+
+     Returns:
+     - List of results
+     """
+     try:
+         # Encode the image
+         image_bytes = await image.read()
+         pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+         image_embedding = embedding_service.encode_image(pil_image)
+
+         # Search
+         results = qdrant_service.search(
+             query_embedding=image_embedding,
+             limit=limit,
+             score_threshold=score_threshold,
+             ef=256
+         )
+
+         return [
+             SearchResponse(
+                 id=result["id"],
+                 confidence=result["confidence"],
+                 metadata=result["metadata"]
+             )
+             for result in results
+         ]
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")
+
+
+ @app.delete("/delete/{doc_id}")
+ async def delete_document(doc_id: str):
+     """
+     Delete a document by ID (MongoDB ObjectId or UUID)
+
+     Args:
+     - doc_id: Document ID to delete
+
+     Returns:
+     - Success message
+     """
+     try:
+         qdrant_service.delete_by_id(doc_id)
+         return {"success": True, "message": f"Deleted document {doc_id}"}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Delete error: {str(e)}")
+
+
+ @app.get("/document/{doc_id}")
+ async def get_document(doc_id: str):
+     """
+     Get a document by ID (MongoDB ObjectId or UUID)
+
+     Args:
+     - doc_id: Document ID (MongoDB ObjectId)
+
+     Returns:
+     - Document data
+     """
+     try:
+         doc = qdrant_service.get_by_id(doc_id)
+         if doc:
+             return {
+                 "success": True,
+                 "data": doc
+             }
+         raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error getting document: {str(e)}")
+
+
+ @app.get("/stats")
+ async def get_stats():
+     """
+     Get collection statistics
+
+     Returns:
+     - Collection statistics
+     """
+     try:
+         info = qdrant_service.get_collection_info()
+         return info
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error getting stats: {str(e)}")
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=8000,
+         log_level="info"
+     )
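
A hedged client sketch for the multipart endpoints above (not part of the commit). It assumes a local server on port 8000, the requests package, and a hypothetical poster.jpg on disk:

    import requests  # assumed installed; not listed in requirements.txt

    BASE = "http://localhost:8000"

    # Index a text + image document (the image part is optional)
    with open("poster.jpg", "rb") as f:  # hypothetical example file
        r = requests.post(
            f"{BASE}/index",
            data={"id": "event-123", "text": "Music festival in Da Nang"},
            files={"image": ("poster.jpg", f, "image/jpeg")},
        )
    print(r.json())

    # Text-only search
    for hit in requests.post(f"{BASE}/search/text",
                             data={"text": "festival", "limit": "5"}).json():
        print(hit["id"], hit["confidence"])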
qdrant_service.py ADDED
@@ -0,0 +1,447 @@
+ from qdrant_client import QdrantClient
+ from qdrant_client.models import (
+     Distance, VectorParams, PointStruct,
+     SearchRequest, SearchParams, HnswConfigDiff,
+     OptimizersConfigDiff, ScalarQuantization,
+     ScalarQuantizationConfig, ScalarType,
+     QuantizationSearchParams
+ )
+ from typing import List, Dict, Any, Optional
+ import numpy as np
+ import uuid
+ import os
+
+
+ class QdrantVectorService:
+     """
+     Qdrant Cloud vector database service with a tuned configuration
+     - HNSW with accuracy-oriented parameters
+     - Scalar quantization to cut memory use and speed up search
+     - Supports hybrid search (text + image)
+     """
+
+     def __init__(
+         self,
+         url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         collection_name: str = "event_social_media",
+         vector_size: int = 1024,  # Jina CLIP v2 dimension
+     ):
+         """
+         Initialize the Qdrant Cloud client
+
+         Args:
+             url: Qdrant Cloud URL (from env or passed in)
+             api_key: Qdrant API key (from env or passed in)
+             collection_name: Collection name
+             vector_size: Vector dimension (1024 for Jina CLIP v2)
+         """
+         # Read credentials from env if not passed in
+         self.url = url or os.getenv("QDRANT_URL")
+         self.api_key = api_key or os.getenv("QDRANT_API_KEY")
+
+         if not self.url or not self.api_key:
+             raise ValueError("QDRANT_URL and QDRANT_API_KEY are required (via env or parameters)")
+
+         print("Connecting to Qdrant Cloud...")
+
+         # Initialize the Qdrant Cloud client
+         self.client = QdrantClient(
+             url=self.url,
+             api_key=self.api_key,
+         )
+
+         self.collection_name = collection_name
+         self.vector_size = vector_size
+
+         # Create the collection if it does not exist yet
+         self._ensure_collection()
+
+         print(f"✓ Connected to Qdrant collection: {collection_name}")
+
+     def _ensure_collection(self):
+         """
+         Create the collection with a tuned HNSW config
+         """
+         # Check whether the collection already exists
+         collections = self.client.get_collections().collections
+         collection_exists = any(c.name == self.collection_name for c in collections)
+
+         if not collection_exists:
+             print(f"Creating collection {self.collection_name} with tuned HNSW config...")
+
+             self.client.create_collection(
+                 collection_name=self.collection_name,
+                 vectors_config=VectorParams(
+                     size=self.vector_size,
+                     distance=Distance.COSINE,  # Cosine similarity for embeddings
+                     hnsw_config=HnswConfigDiff(
+                         m=64,  # Edges per node - higher favors accuracy
+                         ef_construct=512,  # Search breadth during index build - higher favors quality
+                         full_scan_threshold=10000,  # Threshold for switching to a full scan
+                         max_indexing_threads=0,  # 0 = auto-detect thread count
+                         on_disk=False,  # Keep the index in RAM for speed (if memory allows)
+                     )
+                 ),
+                 optimizers_config=OptimizersConfigDiff(
+                     deleted_threshold=0.2,
+                     vacuum_min_vector_number=1000,
+                     default_segment_number=2,
+                     max_segment_size=200000,
+                     memmap_threshold=50000,
+                     indexing_threshold=10000,
+                     flush_interval_sec=5,
+                     max_optimization_threads=0,  # 0 = auto-detect
+                 ),
+                 # Use scalar quantization to cut memory use and speed up search
+                 quantization_config=ScalarQuantization(
+                     scalar=ScalarQuantizationConfig(
+                         type=ScalarType.INT8,
+                         quantile=0.99,
+                         always_ram=True,  # Keep quantized vectors in RAM
+                     )
+                 )
+             )
+             print("✓ Collection created with tuned configuration")
+         else:
+             print("✓ Collection already exists")
+
+     def _convert_to_valid_id(self, doc_id: str) -> str:
+         """
+         Convert any string ID into a valid Qdrant UUID
+
+         Args:
+             doc_id: Original ID (may be a MongoDB ObjectId, arbitrary string, etc.)
+
+         Returns:
+             A valid UUID string
+         """
+         if not doc_id:
+             return str(uuid.uuid4())
+
+         # If it is already a valid UUID, keep it
+         try:
+             uuid.UUID(doc_id)
+             return doc_id
+         except ValueError:
+             pass
+
+         # Deterministically convert the string to a UUID (same input -> same UUID)
+         # using UUID v5 with the DNS namespace
+         return str(uuid.uuid5(uuid.NAMESPACE_DNS, doc_id))
+
+     def index_data(
+         self,
+         doc_id: str,
+         embedding: np.ndarray,
+         metadata: Dict[str, Any]
+     ) -> Dict[str, str]:
+         """
+         Index data into Qdrant
+
+         Args:
+             doc_id: Document ID (MongoDB ObjectId, string, etc.)
+             embedding: Vector embedding from Jina CLIP
+             metadata: Metadata (text, image_url, event_info, etc.)
+
+         Returns:
+             Dict with original_id and qdrant_id
+         """
+         # Convert the ID to a valid UUID
+         qdrant_id = self._convert_to_valid_id(doc_id)
+
+         # Store the original ID in the metadata
+         metadata['original_id'] = doc_id
+
+         # Ensure the embedding is a 1D array
+         if len(embedding.shape) > 1:
+             embedding = embedding.flatten()
+
+         # Create the point
+         point = PointStruct(
+             id=qdrant_id,
+             vector=embedding.tolist(),
+             payload=metadata
+         )
+
+         # Upsert into the collection
+         self.client.upsert(
+             collection_name=self.collection_name,
+             points=[point]
+         )
+
+         return {
+             "original_id": doc_id,
+             "qdrant_id": qdrant_id
+         }
+
+     def batch_index(
+         self,
+         doc_ids: List[str],
+         embeddings: np.ndarray,
+         metadata_list: List[Dict[str, Any]]
+     ) -> List[Dict[str, str]]:
+         """
+         Batch-index multiple documents at once
+
+         Args:
+             doc_ids: List of document IDs (MongoDB ObjectId, string, etc.)
+             embeddings: Numpy array of embeddings (n_samples, embedding_dim)
+             metadata_list: List of metadata dicts
+
+         Returns:
+             List of dicts with original_id and qdrant_id
+         """
+         points = []
+         id_mappings = []
+
+         for doc_id, embedding, metadata in zip(doc_ids, embeddings, metadata_list):
+             # Convert to a valid UUID
+             qdrant_id = self._convert_to_valid_id(doc_id)
+
+             # Store the original ID in the metadata
+             metadata['original_id'] = doc_id
+
+             # Ensure the embedding is 1D
+             if len(embedding.shape) > 1:
+                 embedding = embedding.flatten()
+
+             points.append(PointStruct(
+                 id=qdrant_id,
+                 vector=embedding.tolist(),
+                 payload=metadata
+             ))
+
+             id_mappings.append({
+                 "original_id": doc_id,
+                 "qdrant_id": qdrant_id
+             })
+
+         # Batch upsert
+         self.client.upsert(
+             collection_name=self.collection_name,
+             points=points,
+             wait=True  # Wait for indexing to complete
+         )
+
+         return id_mappings
+
+     def search(
+         self,
+         query_embedding: np.ndarray,
+         limit: int = 10,
+         score_threshold: Optional[float] = None,
+         filter_conditions: Optional[Dict] = None,
+         ef: int = 256  # Search quality parameter - higher = more accurate
+     ) -> List[Dict[str, Any]]:
+         """
+         Search for similar vectors in Qdrant
+
+         Args:
+             query_embedding: Query embedding from Jina CLIP
+             limit: Number of results to return
+             score_threshold: Minimum similarity score (0-1)
+             filter_conditions: Qdrant filter conditions
+             ef: HNSW search parameter (128-512; higher = more accurate)
+
+         Returns:
+             List of search results with id, score, and metadata
+         """
+         # Ensure the query embedding is 1D
+         if len(query_embedding.shape) > 1:
+             query_embedding = query_embedding.flatten()
+
+         # Search with tuned HNSW parameters
+         search_result = self.client.search(
+             collection_name=self.collection_name,
+             query_vector=query_embedding.tolist(),
+             limit=limit,
+             score_threshold=score_threshold,
+             query_filter=filter_conditions,
+             search_params=SearchParams(
+                 hnsw_ef=ef,  # Higher ef = more accurate search
+                 exact=False,  # Use HNSW (not exact search)
+                 quantization=QuantizationSearchParams(
+                     ignore=False,  # Use quantization
+                     rescore=True,  # Rescore with the original vectors
+                     oversampling=2.0  # Oversampling factor
+                 )
+             ),
+             with_payload=True,
+             with_vectors=False  # No need to return vectors
+         )
+
+         # Format results - return original_id instead of the UUID
+         results = []
+         for hit in search_result:
+             # Get original_id from the metadata (e.g., MongoDB ObjectId)
+             original_id = hit.payload.get('original_id', hit.id)
+
+             results.append({
+                 "id": original_id,  # Original (e.g., MongoDB) ID
+                 "qdrant_id": hit.id,  # UUID inside Qdrant
+                 "confidence": float(hit.score),  # Cosine similarity score
+                 "metadata": hit.payload
+             })
+
+         return results
+
+     def hybrid_search(
+         self,
+         text_embedding: Optional[np.ndarray] = None,
+         image_embedding: Optional[np.ndarray] = None,
+         text_weight: float = 0.5,
+         image_weight: float = 0.5,
+         limit: int = 10,
+         score_threshold: Optional[float] = None,
+         ef: int = 256
+     ) -> List[Dict[str, Any]]:
+         """
+         Hybrid search with both text and image embeddings
+
+         Args:
+             text_embedding: Text query embedding
+             image_embedding: Image query embedding
+             text_weight: Weight for the text query (0-1)
+             image_weight: Weight for the image query (0-1)
+             limit: Number of results
+             score_threshold: Minimum score
+             ef: HNSW search parameter
+
+         Returns:
+             Combined search results
+         """
+         # Combine the embeddings with weights
+         combined_embedding = np.zeros(self.vector_size)
+
+         if text_embedding is not None:
+             if len(text_embedding.shape) > 1:
+                 text_embedding = text_embedding.flatten()
+             combined_embedding += text_weight * text_embedding
+
+         if image_embedding is not None:
+             if len(image_embedding.shape) > 1:
+                 image_embedding = image_embedding.flatten()
+             combined_embedding += image_weight * image_embedding
+
+         # Normalize the combined embedding
+         norm = np.linalg.norm(combined_embedding)
+         if norm > 0:
+             combined_embedding = combined_embedding / norm
+
+         # Search with the combined embedding
+         return self.search(
+             query_embedding=combined_embedding,
+             limit=limit,
+             score_threshold=score_threshold,
+             ef=ef
+         )
+
+     def delete_by_id(self, doc_id: str) -> bool:
+         """
+         Delete a document by ID (supports both MongoDB ObjectId and UUID)
+
+         Args:
+             doc_id: Document ID to delete (MongoDB ObjectId or UUID)
+
+         Returns:
+             Success status
+         """
+         # Convert to a UUID if it is a MongoDB ObjectId
+         qdrant_id = self._convert_to_valid_id(doc_id)
+
+         self.client.delete(
+             collection_name=self.collection_name,
+             points_selector=[qdrant_id]
+         )
+         return True
+
+     def get_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
+         """
+         Get a document by ID (supports both MongoDB ObjectId and UUID)
+
+         Args:
+             doc_id: Document ID (MongoDB ObjectId or UUID)
+
+         Returns:
+             Document data, or None if not found
+         """
+         # Convert to a UUID if it is a MongoDB ObjectId
+         qdrant_id = self._convert_to_valid_id(doc_id)
+
+         try:
+             result = self.client.retrieve(
+                 collection_name=self.collection_name,
+                 ids=[qdrant_id],
+                 with_payload=True,
+                 with_vectors=False
+             )
+
+             if result:
+                 point = result[0]
+                 original_id = point.payload.get('original_id', point.id)
+                 return {
+                     "id": original_id,  # Original (e.g., MongoDB) ID
+                     "qdrant_id": point.id,  # UUID inside Qdrant
+                     "metadata": point.payload
+                 }
+             return None
+         except Exception as e:
+             print(f"Error retrieving document: {e}")
+             return None
+
+     def search_by_metadata(
+         self,
+         filter_conditions: Dict,
+         limit: int = 100
+     ) -> List[Dict[str, Any]]:
+         """
+         Search documents by metadata conditions (no embedding needed)
+
+         Args:
+             filter_conditions: Qdrant filter conditions
+             limit: Maximum number of results
+
+         Returns:
+             List of matching documents
+         """
+         try:
+             result = self.client.scroll(
+                 collection_name=self.collection_name,
+                 scroll_filter=filter_conditions,
+                 limit=limit,
+                 with_payload=True,
+                 with_vectors=False
+             )
+
+             documents = []
+             for point in result[0]:  # result is a tuple (points, next_page_offset)
+                 original_id = point.payload.get('original_id', point.id)
+                 documents.append({
+                     "id": original_id,  # Original (e.g., MongoDB) ID
+                     "qdrant_id": point.id,  # UUID inside Qdrant
+                     "metadata": point.payload
+                 })
+
+             return documents
+         except Exception as e:
+             print(f"Error searching by metadata: {e}")
+             return []
+
+     def get_collection_info(self) -> Dict[str, Any]:
+         """
+         Get collection info
+
+         Returns:
+             Collection info
+         """
+         info = self.client.get_collection(collection_name=self.collection_name)
+         return {
+             "vectors_count": info.vectors_count,
+             "points_count": info.points_count,
+             "status": info.status,
+             "config": {
+                 "distance": info.config.params.vectors.distance,
+                 "size": info.config.params.vectors.size,
+             }
+         }
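
A minimal end-to-end sketch of this service (not part of the commit). It assumes QDRANT_URL and QDRANT_API_KEY are exported and uses a small throwaway collection so it runs quickly:

    import numpy as np
    from qdrant_service import QdrantVectorService

    svc = QdrantVectorService(collection_name="demo_vectors", vector_size=4)

    # Two unit-norm vectors standing in for real Jina CLIP embeddings
    vecs = np.random.rand(2, 4).astype(np.float32)
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)

    mappings = svc.batch_index(
        doc_ids=["demo-id-1", "demo-id-2"],  # non-UUID IDs are mapped via uuid5
        embeddings=vecs,
        metadata_list=[{"text": "doc one"}, {"text": "doc two"}],
    )
    print(mappings)
    print(svc.search(query_embedding=vecs[0], limit=1))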
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ # FastAPI and web server
+ fastapi==0.115.5
+ uvicorn[standard]==0.32.1
+ python-multipart==0.0.20
+
+ # Gradio for Hugging Face Spaces
+ gradio>=4.0.0
+
+ # Machine learning & embeddings
+ torch>=2.0.0
+ transformers>=4.50.0
+ onnxruntime==1.20.1
+ torchvision>=0.15.0
+ pillow>=10.0.0
+ numpy>=1.24.0
+
+ # Vector database
+ qdrant-client>=1.12.1
+ grpcio>=1.60.0
+
+ # Utilities
+ pydantic>=2.0.0
+ python-dotenv==1.0.0
+
+ # MongoDB
+ pymongo>=4.6.0
+ huggingface-hub>=0.20.0
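
With the pins above, pip install -r requirements.txt should set up either entry point: python chatbot_rag.py for the Gradio UI, or python chatbot_rag_api.py / python main.py for the FastAPI servers, provided the MongoDB, Qdrant, and Hugging Face credentials are configured.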