# Author : Justin
# Program : Vectorizer for Hybrid Search
# Instructions : Check README.md

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from qdrant_client import models
import logging
# --- Setup Logging ---
# Configure logging to be more descriptive.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Local models for vector generation.
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries.
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Global Variables for Models ---
# These will be loaded once when the application starts.
dense_model = None
splade_tokenizer = None
splade_model = None
# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.2.0"
)
# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    query_text: str

class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    indices: list[int]
    values: list[float]

class VectorResponse(BaseModel):
    """Final JSON response model containing both vectors."""
    dense_vector: list[float]
    sparse_vector: SparseVectorResponse
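# An illustrative request/response pair for the models above (the numbers are
# placeholders, not real model output; only the shapes are guaranteed):
#
#   request:  {"query_text": "hybrid search"}
#   response: {
#       "dense_vector": [0.0123, -0.0456, ...],           # 384 floats for all-MiniLM-L6-v2
#       "sparse_vector": {"indices": [2054, 4827, ...],   # vocabulary ids
#                         "values": [1.92, 1.31, ...]}    # SPLADE term weights
#   }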
@app.on_event("startup")
async def load_models():
    """
    Asynchronous event to load ML models on application startup.
    This ensures models are loaded only once.
    """
    global dense_model, splade_tokenizer, splade_model
    logger.info("Server is starting up... Time to load the ML models.")
    logger.info(f"Using device '{DEVICE}' for processing.")
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        logger.info("Yay! All models have been loaded successfully.")
    except Exception as e:
        logger.critical(f"Oh no, a critical error occurred while loading models: {e}", exc_info=True)
        # Re-raise so startup fails fast: the API is unusable without the models.
        raise
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.

    Args:
        text: The input text string.

    Returns:
        A Qdrant SparseVector object.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()}  # Move tensors to the correct device
    with torch.no_grad():
        output = splade_model(**tokens)
    logits, attention_mask = output.logits, tokens['attention_mask']
    # SPLADE term weights: log(1 + ReLU(logits)), masked to real (non-padding)
    # tokens, then max-pooled over the sequence to get one weight per vocab term.
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()
    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()
    # Ensure indices and values are always lists, even for a single-element tensor.
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]
    return models.SparseVector(indices=indices, values=values)
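# --- Illustrative debugging helper (not part of the original app) ---
# A minimal sketch, assuming the models above are already loaded, that maps
# the sparse indices back to vocabulary tokens so you can see which terms
# SPLADE expanded the query into. The helper name is our own addition.
def inspect_splade_terms(text: str, top_k: int = 10) -> list[tuple[str, float]]:
    """Return the top_k (token, weight) pairs of the SPLADE vector for `text`."""
    sparse = compute_splade_vector(text)
    tokens = splade_tokenizer.convert_ids_to_tokens(sparse.indices)
    # Sort by weight so the strongest expansion terms come first.
    return sorted(zip(tokens, sparse.values), key=lambda p: p[1], reverse=True)[:top_k]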
# The route path below is an assumption; check README.md for the path this
# Space actually exposes.
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.

    Args:
        request: A QueryRequest object containing the 'query_text'.

    Returns:
        A JSON response containing the dense and sparse vectors.
    """
    # --- n8n Logging ---
    logger.info("=========================================================")
    logger.info("A new request just arrived! Let's see what we've got.")
    logger.info(f"The incoming search query from n8n is: '{request.query_text}'")

    # 1. Generate the dense vector.
    logger.info("First, generating the dense vector for semantic meaning...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Done with the dense vector. It has %d dimensions.", len(dense_query_vector))
    logger.info("Here's a small sample of the dense vector: %s...", str(dense_query_vector[:4]))

    # 2. Generate the sparse vector.
    logger.info("Next up, creating the sparse vector for keyword matching...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector is ready. It contains %d important terms.", len(sparse_query_vector.indices))
    logger.info("Here's a sample of the sparse vector indices: %s...", str(sparse_query_vector.indices[:4]))

    # 3. Construct and return the response.
    logger.info("Everything looks good. Packaging up the vectors to send back.")
    logger.info("-----------------------------------------------------------------")
    final_response = VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values
        )
    )
    return final_response
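# --- Example client call (illustrative; URL, port, and path are assumptions) ---
# A minimal sketch of how a caller such as an n8n HTTP Request node might hit
# the endpoint above once the app is served locally:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/vectorize",
#       json={"query_text": "hybrid search with qdrant"},
#   )
#   resp.raise_for_status()
#   payload = resp.json()
#   payload["dense_vector"]               # 384 floats
#   payload["sparse_vector"]["indices"]   # SPLADE vocabulary ids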
@app.get("/")
async def root():
    return {"message": "Vector Generation API is running. -- VERSION 2 --"}