import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import mteb
from sqlitedict import SqliteDict

from pylate import indexes, models, retrieve

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


class IndexType(Enum):
    """Supported index types."""

    PREBUILT = "prebuilt"
    LOCAL = "local"


@dataclass
class IndexConfig:
    """Configuration for a search index."""

    name: str
    type: IndexType
    path: str
    description: Optional[str] = None


class MCPyLate:
    """Main server class that manages PyLate indexes and search operations."""

    def __init__(self, override: bool = False):
        self.logger = logging.getLogger(__name__)
        dataset_name = "leetcode"
        model_name = "lightonai/Reason-ModernColBERT"
        # Rebuild the index if it does not already exist on disk.
        override = override or not os.path.exists(
            f"indexes/{dataset_name}_{model_name.split('/')[-1]}"
        )
        self.model = models.ColBERT(
            model_name_or_path=model_name,
        )
        self.index = indexes.PLAID(
            override=override,
            index_name=f"{dataset_name}_{model_name.split('/')[-1]}",
        )
        # Persistent docid -> document text mapping, stored alongside the index.
        self.id_to_doc = SqliteDict(
            f"./indexes/{dataset_name}_{model_name.split('/')[-1]}/id_to_doc.sqlite",
            outer_stack=False,
        )
        self.retriever = retrieve.ColBERT(index=self.index)
        if override:
            # First run: load the BRIGHT "leetcode" corpus via MTEB, store the
            # raw texts, then encode and index the document embeddings.
            tasks = mteb.get_tasks(tasks=["BrightRetrieval"])
            tasks[0].load_data()
            for doc, doc_id in zip(
                list(tasks[0].corpus[dataset_name]["standard"].values()),
                list(tasks[0].corpus[dataset_name]["standard"].keys()),
            ):
                self.id_to_doc[doc_id] = doc
            self.id_to_doc.commit()  # Don't forget to commit to save changes!
            documents_embeddings = self.model.encode(
                sentences=list(tasks[0].corpus[dataset_name]["standard"].values()),
                batch_size=100,
                is_query=False,
                show_progress_bar=True,
            )
            self.index.add_documents(
                documents_ids=list(tasks[0].corpus[dataset_name]["standard"].keys()),
                documents_embeddings=documents_embeddings,
            )
        self.logger.info("Created PyLate MCP Server")

    def get_document(
        self,
        docid: str,
    ) -> Optional[Dict[str, Any]]:
        """Retrieve the full document by document ID, or None if it is unknown."""
        text = self.id_to_doc.get(docid)
        if text is None:
            return None
        return {"docid": docid, "text": text}

    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector (ColBERT) search on the loaded index."""
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=k)
            results = []
            for score in scores[0]:
                results.append(
                    {
                        "docid": score["id"],
                        "score": round(score["score"], 5),
                        "text": self.id_to_doc[score["id"]],
                        # "text": self.id_to_doc[score["id"]][:200] + "…"
                        # if len(self.id_to_doc[score["id"]]) > 200
                        # else self.id_to_doc[score["id"]],
                    }
                )
            return results
        except Exception as e:
            self.logger.error(f"Search failed: {e}")
            raise RuntimeError(f"Search operation failed: {e}") from e
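

# --- Minimal usage sketch (assumption: not part of the original file) ---
# Assumes pylate, mteb, and sqlitedict are installed and the
# lightonai/Reason-ModernColBERT weights can be downloaded; the first run
# builds the PLAID index for the BRIGHT "leetcode" corpus, which takes a while.
# The example query string is illustrative only.
if __name__ == "__main__":
    server = MCPyLate()
    hits = server.search("two pointers technique on a sorted array", k=5)
    for hit in hits:
        print(hit["docid"], hit["score"])
    if hits:
        doc = server.get_document(hits[0]["docid"])
        print(doc["text"][:200] if doc else "document not found")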