Spaces:
Sleeping
Sleeping
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import ( | |
| Distance, VectorParams, PointStruct, | |
| SearchRequest, SearchParams, HnswConfigDiff, | |
| OptimizersConfigDiff, ScalarQuantization, | |
| ScalarQuantizationConfig, ScalarType, | |
| QuantizationSearchParams | |
| ) | |
| from typing import List, Dict, Any, Optional | |
| import numpy as np | |
| import uuid | |
| import os | |
class QdrantVectorService:
    """
    Qdrant Cloud vector database service with a tuned collection configuration.

    - HNSW index configured for high recall (m=64, ef_construct=512).
    - INT8 scalar quantization to cut memory use and speed up search, with
      rescoring against the original vectors at query time.
    - Weighted "hybrid" search that blends a text and an image embedding into
      a single query vector.

    Qdrant point IDs must be UUIDs (or unsigned ints), so arbitrary document IDs
    (e.g. MongoDB ObjectIds) are mapped deterministically to UUIDs and the
    original ID is stored in the point payload under ``original_id``.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        collection_name: str = "event_social_media",
        vector_size: int = 1024,  # Jina CLIP v2 dimension
    ):
        """
        Connect to Qdrant Cloud and make sure the collection exists.

        Args:
            url: Qdrant Cloud URL; falls back to the ``QDRANT_URL`` env var.
            api_key: Qdrant API key; falls back to the ``QDRANT_API_KEY`` env var.
            collection_name: Name of the collection to use (created if missing).
            vector_size: Dimension of stored vectors (1024 for Jina CLIP v2).

        Raises:
            ValueError: If no URL or no API key is available from either the
                arguments or the environment.
        """
        # Explicit arguments win over environment variables.
        self.url = url or os.getenv("QDRANT_URL")
        self.api_key = api_key or os.getenv("QDRANT_API_KEY")
        if not self.url or not self.api_key:
            raise ValueError("Cần cung cấp QDRANT_URL và QDRANT_API_KEY (qua env hoặc params)")

        print("Connecting to Qdrant Cloud...")
        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
        )
        self.collection_name = collection_name
        self.vector_size = vector_size

        self._ensure_collection()
        print(f"✓ Connected to Qdrant collection: {collection_name}")

    def _ensure_collection(self):
        """
        Create the collection with the tuned HNSW/quantization config if it
        does not exist yet; otherwise leave the existing collection untouched.
        """
        existing = self.client.get_collections().collections
        if any(c.name == self.collection_name for c in existing):
            print("✓ Collection already exists")
            return

        print(f"Creating collection {self.collection_name} with optimal HNSW config...")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE,  # cosine similarity for embeddings
                hnsw_config=HnswConfigDiff(
                    m=64,  # edges per node: higher = better recall, more memory
                    ef_construct=512,  # candidate list size while building the index
                    full_scan_threshold=10000,  # below this size a full scan is preferred
                    max_indexing_threads=0,  # 0 = let Qdrant auto-select thread count
                    on_disk=False,  # keep the HNSW graph in RAM for speed
                ),
            ),
            optimizers_config=OptimizersConfigDiff(
                deleted_threshold=0.2,
                vacuum_min_vector_number=1000,
                default_segment_number=2,
                max_segment_size=200000,
                memmap_threshold=50000,
                indexing_threshold=10000,
                flush_interval_sec=5,
                max_optimization_threads=0,  # 0 = auto-select
            ),
            # INT8 scalar quantization shrinks vectors ~4x and speeds up scoring.
            quantization_config=ScalarQuantization(
                scalar=ScalarQuantizationConfig(
                    type=ScalarType.INT8,
                    quantile=0.99,  # clip outliers when choosing the quantization range
                    always_ram=True,  # keep quantized vectors in RAM
                )
            ),
        )
        print("✓ Collection created with optimal configuration")

    def _convert_to_valid_id(self, doc_id: Any) -> str:
        """
        Map any document ID to a UUID string accepted as a Qdrant point ID.

        Args:
            doc_id: Original ID (MongoDB ObjectId, plain string, UUID, ...).
                Non-string values are converted with ``str()`` first.

        Returns:
            ``doc_id`` itself if it already parses as a UUID; otherwise a
            deterministic UUIDv5 derived from it (same input -> same UUID).
            A falsy ID yields a fresh random UUIDv4.
        """
        if not doc_id:
            # Random ID: such a point cannot be addressed again by its ID later.
            return str(uuid.uuid4())
        # Non-str IDs (ints, ObjectId instances) previously crashed inside uuid.UUID.
        doc_id = str(doc_id)
        try:
            uuid.UUID(doc_id)
        except ValueError:
            # Deterministic so re-upserts/deletes of the same ID target the same point.
            return str(uuid.uuid5(uuid.NAMESPACE_DNS, doc_id))
        return doc_id

    @staticmethod
    def _to_vector(embedding: Any) -> np.ndarray:
        """Return *embedding* as a 1-D NumPy array (flattening e.g. a (1, dim) batch)."""
        return np.asarray(embedding).ravel()

    @staticmethod
    def _build_payload(doc_id: Any, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Return a copy of *metadata* with ``original_id`` set to *doc_id*.

        The caller's dict is never modified.
        """
        payload = dict(metadata) if metadata else {}
        payload["original_id"] = doc_id
        return payload

    @staticmethod
    def _format_point(point: Any, score: Optional[float] = None) -> Dict[str, Any]:
        """
        Convert a Qdrant point/hit into the dict shape returned by this service.

        ``id`` is the original document ID from the payload (falling back to
        the Qdrant ID); ``confidence`` is included only when *score* is given.
        """
        payload = point.payload or {}
        formatted: Dict[str, Any] = {
            "id": payload.get("original_id", point.id),
            "qdrant_id": point.id,
        }
        if score is not None:
            formatted["confidence"] = float(score)  # cosine similarity score
        formatted["metadata"] = point.payload
        return formatted

    def index_data(
        self,
        doc_id: str,
        embedding: np.ndarray,
        metadata: Dict[str, Any]
    ) -> Dict[str, str]:
        """
        Upsert a single document into Qdrant.

        Args:
            doc_id: Document ID (MongoDB ObjectId, string, ...).
            embedding: Embedding vector (e.g. from Jina CLIP); 2-D inputs are flattened.
            metadata: Payload to store (text, image_url, event info, ...). It is
                copied, not modified; ``original_id`` is added to the stored copy.

        Returns:
            Dict with ``original_id`` and the ``qdrant_id`` actually used.
        """
        qdrant_id = self._convert_to_valid_id(doc_id)
        point = PointStruct(
            id=qdrant_id,
            vector=self._to_vector(embedding).tolist(),
            payload=self._build_payload(doc_id, metadata),
        )
        self.client.upsert(
            collection_name=self.collection_name,
            points=[point]
        )
        return {
            "original_id": doc_id,
            "qdrant_id": qdrant_id
        }

    def batch_index(
        self,
        doc_ids: List[str],
        embeddings: np.ndarray,
        metadata_list: List[Dict[str, Any]]
    ) -> List[Dict[str, str]]:
        """
        Upsert many documents in a single request.

        Args:
            doc_ids: Document IDs (MongoDB ObjectId, string, ...).
            embeddings: Embeddings, one per document (n_samples, embedding_dim).
            metadata_list: Payload dicts, one per document (copied, not modified).

        Returns:
            One ``{"original_id", "qdrant_id"}`` dict per document, in input order.

        Raises:
            ValueError: If the three inputs do not have the same length
                (``zip`` would otherwise silently drop documents).
        """
        if not (len(doc_ids) == len(embeddings) == len(metadata_list)):
            raise ValueError(
                f"batch_index length mismatch: {len(doc_ids)} ids, "
                f"{len(embeddings)} embeddings, {len(metadata_list)} metadata dicts"
            )
        if len(doc_ids) == 0:
            return []

        points = []
        id_mappings = []
        for doc_id, embedding, metadata in zip(doc_ids, embeddings, metadata_list):
            qdrant_id = self._convert_to_valid_id(doc_id)
            points.append(PointStruct(
                id=qdrant_id,
                vector=self._to_vector(embedding).tolist(),
                payload=self._build_payload(doc_id, metadata),
            ))
            id_mappings.append({
                "original_id": doc_id,
                "qdrant_id": qdrant_id
            })

        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
            wait=True  # block until the write is applied
        )
        return id_mappings

    def search(
        self,
        query_embedding: np.ndarray,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter_conditions: Optional[Dict] = None,
        ef: int = 256  # HNSW search breadth: higher = more accurate, slower
    ) -> List[Dict[str, Any]]:
        """
        Find the stored vectors most similar to *query_embedding*.

        Args:
            query_embedding: Query embedding; 2-D inputs are flattened.
            limit: Maximum number of results.
            score_threshold: Minimum similarity score to include a hit.
            filter_conditions: Qdrant filter, passed through as ``query_filter``.
            ef: HNSW ``hnsw_ef`` search parameter (typically 128-512).

        Returns:
            Hits as dicts with ``id`` (original document ID), ``qdrant_id``,
            ``confidence`` (similarity score) and ``metadata`` (payload).
        """
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=self._to_vector(query_embedding).tolist(),
            limit=limit,
            score_threshold=score_threshold,
            query_filter=filter_conditions,
            search_params=SearchParams(
                hnsw_ef=ef,
                exact=False,  # approximate HNSW search, not brute force
                quantization=QuantizationSearchParams(
                    ignore=False,  # use the quantized vectors for candidate search
                    rescore=True,  # re-rank candidates with the original vectors
                    oversampling=2.0  # fetch 2x candidates before rescoring
                )
            ),
            with_payload=True,
            with_vectors=False  # vectors are not needed in the response
        )
        return [self._format_point(hit, hit.score) for hit in search_result]

    def hybrid_search(
        self,
        text_embedding: Optional[np.ndarray] = None,
        image_embedding: Optional[np.ndarray] = None,
        text_weight: float = 0.5,
        image_weight: float = 0.5,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        ef: int = 256
    ) -> List[Dict[str, Any]]:
        """
        Search with a weighted blend of a text and an image query embedding.

        The weighted sum of the given embeddings is L2-normalized and passed to
        :meth:`search`.

        Args:
            text_embedding: Text query embedding (optional).
            image_embedding: Image query embedding (optional).
            text_weight: Weight applied to the text embedding.
            image_weight: Weight applied to the image embedding.
            limit: Maximum number of results.
            score_threshold: Minimum similarity score.
            ef: HNSW search parameter.

        Returns:
            Same result format as :meth:`search`.

        Raises:
            ValueError: If both embeddings are ``None`` or an embedding's size
                differs from ``vector_size``.
        """
        if text_embedding is None and image_embedding is None:
            raise ValueError("hybrid_search requires text_embedding and/or image_embedding")

        # NOTE(review): inputs are not normalized individually, so the weights are
        # only meaningful if both embeddings have comparable norms — confirm upstream.
        combined_embedding = np.zeros(self.vector_size)
        parts = (
            ("text_embedding", text_embedding, text_weight),
            ("image_embedding", image_embedding, image_weight),
        )
        for name, embedding, weight in parts:
            if embedding is None:
                continue
            vector = self._to_vector(embedding)
            # Explicit check: a size-1 vector would otherwise broadcast silently.
            if vector.shape[0] != self.vector_size:
                raise ValueError(
                    f"{name} has dimension {vector.shape[0]}, expected {self.vector_size}"
                )
            combined_embedding += weight * vector

        norm = np.linalg.norm(combined_embedding)
        if norm > 0:
            combined_embedding = combined_embedding / norm

        return self.search(
            query_embedding=combined_embedding,
            limit=limit,
            score_threshold=score_threshold,
            ef=ef
        )

    def delete_by_id(self, doc_id: str) -> bool:
        """
        Delete a document by its original ID (MongoDB ObjectId or UUID).

        Args:
            doc_id: Document ID; mapped to the Qdrant point ID the same way as
                during indexing.

        Returns:
            ``True`` once the delete request has been sent (errors propagate).
        """
        qdrant_id = self._convert_to_valid_id(doc_id)
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=[qdrant_id]
        )
        return True

    def get_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a single document by its original ID (MongoDB ObjectId or UUID).

        Args:
            doc_id: Document ID.

        Returns:
            ``{"id", "qdrant_id", "metadata"}`` for the point, or ``None`` if it
            does not exist or the request fails (errors are printed, not raised).
        """
        qdrant_id = self._convert_to_valid_id(doc_id)
        try:
            result = self.client.retrieve(
                collection_name=self.collection_name,
                ids=[qdrant_id],
                with_payload=True,
                with_vectors=False
            )
            if not result:
                return None
            return self._format_point(result[0])
        except Exception as e:
            # Best-effort lookup: report and treat failures as "not found".
            print(f"Error retrieving document: {e}")
            return None

    def search_by_metadata(
        self,
        filter_conditions: Dict,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        List documents matching a payload filter (no embedding needed).

        Only the first page of ``scroll`` results (up to *limit*) is returned.

        Args:
            filter_conditions: Qdrant filter, passed through as ``scroll_filter``.
            limit: Maximum number of documents.

        Returns:
            Matching documents as ``{"id", "qdrant_id", "metadata"}`` dicts, or
            an empty list if the request fails (errors are printed, not raised).
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=filter_conditions,
                limit=limit,
                with_payload=True,
                with_vectors=False
            )
            return [self._format_point(point) for point in points]
        except Exception as e:
            # Best-effort query: report and return no matches.
            print(f"Error searching by metadata: {e}")
            return []

    def get_collection_info(self) -> Dict[str, Any]:
        """
        Return basic statistics and vector config of the collection.

        Returns:
            Dict with ``vectors_count`` (``None`` if the installed client no
            longer exposes it), ``points_count``, ``status`` and
            ``config`` (``distance`` and ``size``).
        """
        info = self.client.get_collection(collection_name=self.collection_name)
        vectors_config = info.config.params.vectors
        return {
            # vectors_count is deprecated in newer qdrant-client releases.
            "vectors_count": getattr(info, "vectors_count", None),
            "points_count": info.points_count,
            "status": info.status,
            "config": {
                "distance": vectors_config.distance,
                "size": vectors_config.size,
            }
        }