"""ChatbotRAG HTTP API.

FastAPI application exposing a retrieval-augmented chatbot: documents are
stored in MongoDB, embedded with ``JinaClipEmbeddingService``, indexed in
Qdrant through ``QdrantVectorService``, and answers are generated with the
Hugging Face ``InferenceClient``.
"""

import os
from datetime import datetime
from typing import Optional, List, Dict

import numpy as np
from bson import ObjectId
from bson.errors import InvalidId
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from pymongo import MongoClient

from embedding_service import JinaClipEmbeddingService
from qdrant_service import QdrantVectorService


# System prompt used when the caller does not supply one (or sends null).
DEFAULT_SYSTEM_MESSAGE = "You are a helpful AI assistant."


# Pydantic models
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str
    use_rag: bool = True
    top_k: int = 3
    system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    # Hugging Face token (optional; the service-level token is used when omitted)
    hf_token: Optional[str] = None


class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    response: str
    context_used: List[Dict]
    timestamp: str


class AddDocumentRequest(BaseModel):
    """Request body for ``POST /documents``."""

    text: str
    metadata: Optional[Dict] = None


class AddDocumentResponse(BaseModel):
    """Response body for ``POST /documents``."""

    success: bool
    doc_id: str
    message: str


class SearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    query: str
    top_k: int = 5
    score_threshold: Optional[float] = 0.5


class SearchResponse(BaseModel):
    """Response body for ``POST /search``."""

    results: List[Dict]


# Initialize FastAPI
app = FastAPI(
    title="ChatbotRAG API",
    description="API for RAG Chatbot with GPT-OSS-20B + Jina CLIP v2 + MongoDB + Qdrant",
    version="1.0.0"
)

# CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (can be restricted in production)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _to_mongo_id(doc_id: str):
    """Convert *doc_id* to the key type used for ``_id`` in MongoDB.

    ``add_document`` lets MongoDB generate ``_id``, which is an ``ObjectId``,
    and hands callers its string form. A lookup by that string therefore has
    to be converted back. Strings that are not valid ObjectIds are returned
    unchanged so documents stored with plain string keys can still be matched.
    """
    try:
        return ObjectId(doc_id)
    except InvalidId:
        return doc_id


# ChatbotRAG Service
class ChatbotRAGService:
    """
    ChatbotRAG service backing the API.

    Bundles the MongoDB collections, the embedding service, the Qdrant
    vector index, and the Hugging Face token used for generation.
    """

    def __init__(
        self,
        # SECURITY(review): this default embeds database credentials in source
        # control. Kept for backward compatibility; rotate the password and
        # load the URI from configuration instead.
        mongodb_uri: str = "mongodb+srv://truongtn7122003:7KaI9OT5KTUxWjVI@truongtn7122003.xogin4q.mongodb.net/",
        db_name: str = "chatbot_rag",
        collection_name: str = "documents",
        hf_token: Optional[str] = None
    ):
        """Connect to MongoDB, load the embedding model, and set up Qdrant.

        Args:
            mongodb_uri: MongoDB connection string.
            db_name: Name of the MongoDB database.
            collection_name: MongoDB collection that stores documents.
            hf_token: Hugging Face token; falls back to the
                ``HUGGINGFACE_TOKEN`` environment variable.
        """
        print("Initializing ChatbotRAG Service...")

        # MongoDB
        self.mongo_client = MongoClient(mongodb_uri)
        self.db = self.mongo_client[db_name]
        self.documents_collection = self.db[collection_name]
        self.chat_history_collection = self.db["chat_history"]

        # Embedding service
        self.embedding_service = JinaClipEmbeddingService(
            model_path="jinaai/jina-clip-v2"
        )

        # Qdrant: its collection name comes from the environment and is
        # independent of the MongoDB collection name passed in above.
        qdrant_collection = os.getenv("COLLECTION_NAME", "event_social_media")
        self.qdrant_service = QdrantVectorService(
            collection_name=qdrant_collection,
            vector_size=self.embedding_service.get_embedding_dimension()
        )

        # Hugging Face token (from the argument or the environment)
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN")
        if self.hf_token:
            print("✓ Hugging Face token configured")
        else:
            print("⚠ No Hugging Face token - LLM generation will use placeholder")

        print("✓ ChatbotRAG Service initialized")

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> str:
        """Add a document to the knowledge base.

        The document is inserted into MongoDB, embedded, and indexed in
        Qdrant under the string form of its MongoDB ``_id``.

        Args:
            text: Document text.
            metadata: Optional extra fields stored with the document and
                merged into the Qdrant payload.

        Returns:
            The string form of the inserted MongoDB ``_id``.
        """
        # Save to MongoDB
        doc_data = {
            "text": text,
            "metadata": metadata or {},
            "created_at": datetime.utcnow()
        }
        result = self.documents_collection.insert_one(doc_data)
        doc_id = str(result.inserted_id)

        # Generate embedding
        embedding = self.embedding_service.encode_text(text)

        # Index to Qdrant
        self.qdrant_service.index_data(
            doc_id=doc_id,
            embedding=embedding,
            metadata={
                "text": text,
                "source": "api",
                **(metadata or {})
            }
        )

        return doc_id

    def retrieve_context(self, query: str, top_k: int = 3, score_threshold: float = 0.5) -> List[Dict]:
        """Retrieve relevant context from the vector DB.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of results requested from Qdrant.
            score_threshold: Minimum score forwarded to the Qdrant search.

        Returns:
            Whatever ``QdrantVectorService.search`` returns.
        """
        # Generate query embedding
        query_embedding = self.embedding_service.encode_text(query)

        # Search in Qdrant
        results = self.qdrant_service.search(
            query_embedding=query_embedding,
            limit=top_k,
            score_threshold=score_threshold
        )

        return results

    def generate_response(
        self,
        message: str,
        context: List[Dict],
        system_message: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        hf_token: Optional[str] = None
    ) -> str:
        """
        Generate a response using the Hugging Face LLM.

        Args:
            message: The user's message.
            context: Retrieved documents; each entry is read as
                ``doc["metadata"]["text"]`` and ``doc["confidence"]``.
            system_message: System prompt; ``None`` falls back to
                ``DEFAULT_SYSTEM_MESSAGE``.
            max_tokens: Forwarded to ``chat_completion``.
            temperature: Forwarded to ``chat_completion``.
            top_p: Forwarded to ``chat_completion``.
            hf_token: Per-request token; overrides the service token.

        Returns:
            The generated text, a placeholder message when no token is
            available, or an error description when generation fails.
        """
        # The request model allows an explicit null system prompt; without
        # this guard the literal text "None" would end up in the prompt.
        if system_message is None:
            system_message = DEFAULT_SYSTEM_MESSAGE

        # Build context text
        if context:
            context_text = "\n\nRelevant Context:\n"
            for i, doc in enumerate(context, 1):
                doc_text = doc["metadata"].get("text", "")
                confidence = doc["confidence"]
                context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"

            # Add context to system message
            system_message = f"{system_message}\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."

        # Use token from request or fallback to service token
        token = hf_token or self.hf_token

        # If no token available, return placeholder
        if not token:
            return (
                f"[LLM Response Placeholder] Context retrieved: {len(context)} documents "
                f"User question: {message} "
                "To enable actual LLM generation: "
                "1. Set HUGGINGFACE_TOKEN environment variable, OR "
                "2. Pass hf_token in request body "
                'Example: { "message": "Your question", "hf_token": "hf_xxxxxxxxxxxxx" } '
            )

        # Initialize HF Inference Client
        try:
            client = InferenceClient(
                token=token,
                model="openai/gpt-oss-20b"
            )

            # Build messages
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": message}
            ]

            # The completion is requested as a stream and the chunks are
            # joined, so the API still returns one complete string.
            parts = []
            for msg in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                choices = msg.choices
                if choices and choices[0].delta.content:
                    parts.append(choices[0].delta.content)

            return "".join(parts)

        except Exception as e:
            # Deliberate best-effort: failures are reported in the response
            # text rather than raised.
            return f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."

    def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
        """Save one chat exchange to MongoDB with a UTC timestamp."""
        chat_data = {
            "user_message": user_message,
            "assistant_response": assistant_response,
            "context_used": context_used,
            "timestamp": datetime.utcnow()
        }
        self.chat_history_collection.insert_one(chat_data)

    def get_stats(self) -> Dict:
        """Return document/chat counts from MongoDB and the Qdrant collection info."""
        return {
            "documents_count": self.documents_collection.count_documents({}),
            "chat_history_count": self.chat_history_collection.count_documents({}),
            "qdrant_info": self.qdrant_service.get_collection_info()
        }


# Initialize service
rag_service = ChatbotRAGService()


# API Endpoints
@app.get("/")
async def root():
    """Health check"""
    return {
        "status": "running",
        "service": "ChatbotRAG API",
        "version": "1.0.0",
        "endpoints": {
            "POST /chat": "Chat with RAG",
            "POST /documents": "Add document to knowledge base",
            "POST /search": "Search in knowledge base",
            "GET /stats": "Get statistics",
            "GET /history": "Get chat history"
        }
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """
    Chat endpoint with RAG

    Body:
    - message: User message
    - use_rag: Enable RAG retrieval (default: true)
    - top_k: Number of documents to retrieve (default: 3)
    - system_message: System prompt (optional)
    - max_tokens: Max tokens for response (default: 512)
    - temperature: Temperature for generation (default: 0.7)

    Returns:
    - response: Generated response
    - context_used: Retrieved context documents
    - timestamp: Response timestamp
    """
    try:
        # Retrieve context if RAG enabled
        context_used = []
        if request.use_rag:
            context_used = rag_service.retrieve_context(
                query=request.message,
                top_k=request.top_k
            )

        # Generate response
        response = rag_service.generate_response(
            message=request.message,
            context=context_used,
            system_message=request.system_message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            hf_token=request.hf_token
        )

        # Save to history
        rag_service.save_chat_history(
            user_message=request.message,
            assistant_response=response,
            context_used=context_used
        )

        return ChatResponse(
            response=response,
            context_used=context_used,
            timestamp=datetime.utcnow().isoformat()
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.post("/documents", response_model=AddDocumentResponse)
async def add_document(request: AddDocumentRequest):
    """
    Add document to knowledge base

    Body:
    - text: Document text
    - metadata: Additional metadata (optional)

    Returns:
    - success: True/False
    - doc_id: MongoDB document ID
    - message: Status message
    """
    try:
        doc_id = rag_service.add_document(
            text=request.text,
            metadata=request.metadata
        )

        return AddDocumentResponse(
            success=True,
            doc_id=doc_id,
            message=f"Document added successfully with ID: {doc_id}"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """
    Search in knowledge base

    Body:
    - query: Search query
    - top_k: Number of results (default: 5)
    - score_threshold: Minimum score (default: 0.5)

    Returns:
    - results: List of matching documents
    """
    try:
        results = rag_service.retrieve_context(
            query=request.query,
            top_k=request.top_k,
            score_threshold=request.score_threshold
        )

        return SearchResponse(results=results)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.get("/stats")
async def get_stats():
    """
    Get statistics

    Returns:
    - documents_count: Number of documents in MongoDB
    - chat_history_count: Number of chat messages
    - qdrant_info: Qdrant collection info
    """
    try:
        return rag_service.get_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.get("/history")
async def get_history(limit: int = 10, skip: int = 0):
    """
    Get chat history

    Query params:
    - limit: Number of messages to return (default: 10)
    - skip: Number of messages to skip (default: 0)

    Returns:
    - history: List of chat messages, newest first
    - total: Total number of stored chat messages
    """
    try:
        history = list(
            rag_service.chat_history_collection
            .find({}, {"_id": 0})
            .sort("timestamp", -1)
            .skip(skip)
            .limit(limit)
        )

        # Convert datetime to string
        for msg in history:
            if "timestamp" in msg:
                msg["timestamp"] = msg["timestamp"].isoformat()

        return {"history": history, "total": rag_service.chat_history_collection.count_documents({})}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
    """
    Delete document from knowledge base

    Args:
    - doc_id: Document ID (MongoDB ObjectId)

    Returns:
    - success: True/False
    - message: Status message
    """
    try:
        # Delete from MongoDB. The path parameter is a string, while stored
        # documents are keyed by ObjectId, so convert before filtering.
        result = rag_service.documents_collection.delete_one(
            {"_id": _to_mongo_id(doc_id)}
        )

        # Guard clause: nothing was deleted, so there is nothing to remove
        # from Qdrant either.
        if result.deleted_count == 0:
            raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")

        # Delete from Qdrant, which was indexed under the string id.
        rag_service.qdrant_service.delete_by_id(doc_id)
        return {"success": True, "message": f"Document {doc_id} deleted"}

    except HTTPException:
        # Let the 404 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )