# main.py (đã sửa đổi với Lazy Loading) # =============================================================== # 1. IMPORT THƯ VIỆN & CÁC HẰNG SỐ (KHÔNG TẢI MODEL Ở ĐÂY) # =============================================================== import os import torch import gc import re import logging from fastapi import FastAPI, HTTPException from pydantic import BaseModel import uvicorn # ... (Thêm lại các import khác của bạn ở đây: transformers, datasets, langchain, etc.) from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification from datasets import load_dataset from langchain_community.vectorstores import FAISS from langchain_community.embeddings import HuggingFaceEmbeddings # ... và các import còn lại # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Tạo một "kho chứa" toàn cục để lưu các model sau khi được tải # Ban đầu nó sẽ trống model_cache = {} # =============================================================== # 2. TẠO MỘT HÀM ĐỂ TẢI TẤT CẢ MODEL VÀ RAG # =============================================================== def load_all_models(): """ Hàm này sẽ tải tất cả các model và thiết lập RAG. Nó chỉ thực sự chạy một lần duy nhất khi có request đầu tiên. """ # Kiểm tra xem model đã được tải chưa để tránh tải lại if "is_loaded" in model_cache: logger.info("Models đã được tải, bỏ qua.") return logger.info("Lần đầu khởi chạy, bắt đầu quá trình tải model (có thể mất vài phút)...") is_gpu_available = torch.cuda.is_available() device = "cuda" if is_gpu_available else "cpu" logger.info(f"Thiết bị được sử dụng: {device}") # Tải tất cả các model và lưu vào cache logger.info("Đang tải model kiểm duyệt (moderation)...") model_cache["moderation_tokenizer"] = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target") model_cache["moderation_model"] = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target").to("cpu") logger.info("Đang tải model Llama-2-7b-chat-hf...") hf_token = os.environ.get("HF_TOKEN") model_id = "meta-llama/Llama-2-7b-chat-hf" tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token) model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto", torch_dtype=torch.float16) model_cache["llama_pipe"] = pipeline("text-generation", model=model, tokenizer=tokenizer, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=512) logger.info("Đang tải model phân tích cảm xúc (sentiment)...") model_cache["sentiment_analyzer"] = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=0 if is_gpu_available else -1) logger.info("Đang tải model phân tích cảm xúc chi tiết (emotion)...") model_cache["emotion_analyzer"] = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", top_k=None, device=0 if is_gpu_available else -1) # Tải và xử lý RAG # LƯU Ý: Bạn cần copy lại hàm load_and_process_datasets và các hàm helper khác # (sanitize_input, combined_sentiment_analysis, etc.) vào file này. # Để ví dụ ngắn gọn, mình sẽ giả định chúng đã tồn tại. # documents = load_and_process_datasets() # Hàm này của bạn # if documents: # vector_store = FAISS.from_texts(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')) # model_cache["retriever"] = vector_store.as_retriever(search_kwargs={'k': 2}) # else: # model_cache["retriever"] = None logger.info("Tất cả model đã được tải và thiết lập thành công!") model_cache["is_loaded"] = True # =============================================================== # 3. ĐỊNH NGHĨA APP VÀ ENDPOINT # Server sẽ khởi động ngay lập tức vì không có gì nặng ở đây. # =============================================================== app = FastAPI(title="Athena AI Therapist API") class PredictRequest(BaseModel): user_input: str @app.on_event("startup") def startup_event(): """Sự kiện này chỉ chạy 1 lần khi server bắt đầu.""" logger.info("Server FastAPI đã khởi động. Sẵn sàng nhận yêu cầu.") logger.info("Các model sẽ được tải 'lười biếng' khi có yêu cầu /predict đầu tiên.") @app.get("/", tags=["Health Check"]) def health_check(): """Endpoint siêu nhẹ để Hugging Face kiểm tra sức khỏe.""" return {"status": "healthy", "models_loaded": model_cache.get("is_loaded", False)} @app.post("/predict", tags=["Core Logic"]) async def predict(request: PredictRequest): """ Endpoint chính. Nó sẽ kích hoạt việc tải model nếu đây là lần chạy đầu tiên. """ # Bước quan trọng: Gọi hàm tải model. # Nếu model đã được tải, nó sẽ bỏ qua ngay lập tức. # Nếu chưa, nó sẽ chặn và tải ở đây. load_all_models() try: # Bây giờ, sử dụng các model từ cache # response_text = generate_safe_response(request.user_input, model_cache) # Lưu ý: bạn sẽ cần sửa lại hàm generate_safe_response và các hàm khác # để chúng nhận `model_cache` làm tham số thay vì dùng biến toàn cục. # ---- VÍ DỤ TẠM THỜI ĐỂ TEST ---- prompt = f"User: {request.user_input}\nAthena:" llama_pipe = model_cache["llama_pipe"] result = llama_pipe(prompt) response_text = result[0]['generated_text'] # -------------------------------- return {"response": response_text} except Exception as e: logger.error(f"Lỗi tại endpoint /predict: {str(e)}") raise HTTPException(status_code=500, detail="Đã xảy ra lỗi máy chủ nội bộ.")