# techmind-pro / app.py
# TechMind Pro API: Mistral-7B-Instruct v0.3 with a LoRA adapter, served via FastAPI.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import os
from datetime import datetime
import re
# =========================
# CONFIG
# =========================
# Base checkpoint and the LoRA adapter applied on top of it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
LORA_MODEL = "Delta0723/techmind-pro-v9"
# =========================
# FastAPI Setup
# =========================
app = FastAPI(title="TechMind Pro API")
# NOTE(review): wide-open CORS ("*" origins/methods/headers) — acceptable for a
# public demo, but should be restricted before production use.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"]
)
# =========================
# Load Model
# =========================
# Module-level model load: runs once at import/startup. Downloads weights on
# first run, so this can take a while and requires network + significant RAM/VRAM.
print("πŸš€ Cargando modelo y tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Mistral's tokenizer ships without a pad token; reuse EOS so batched/padded
# generation works.
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,  # half precision to fit the 7B model in memory
device_map="auto"  # let accelerate place layers on available devices
)
# Attach the fine-tuned LoRA adapter on top of the frozen base weights.
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
model.eval()  # inference mode: disables dropout etc.
except Exception as e:
print("❌ Error al cargar el modelo:", e)
# Re-raise so the server fails fast instead of serving without a model.
raise e
print("βœ… Modelo listo")
# =========================
# Data Models
# =========================
class Query(BaseModel):
    """Request body for the /ask endpoint."""
    # The user's question, injected verbatim into the [INST] prompt.
    question: str
    # Upper bound on generated tokens (max_new_tokens).
    max_tokens: Optional[int] = 300
    # Sampling temperature; higher = more random output.
    temperature: Optional[float] = 0.7
# =========================
# Utilidades
# =========================
def generate_answer(question: str, max_tokens=300, temperature=0.7) -> str:
    """Generate a model answer for *question* using the LoRA-adapted model.

    Args:
        question: User question, wrapped in Mistral's [INST] chat template.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; values <= 0 (or None) fall back to
            greedy decoding instead of crashing inside ``generate``.

    Returns:
        The generated answer text, with the echoed prompt stripped.
    """
    prompt = f"<s>[INST] {question} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Fix: transformers raises if do_sample=True with temperature <= 0.
    # Since temperature is client-supplied, guard and degrade to greedy decoding.
    do_sample = temperature is not None and temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature if do_sample else None,
            top_p=0.95 if do_sample else None,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # generate() echoes the prompt; keep only the text after the [/INST] tag.
    return decoded.split("[/INST]")[-1].strip() if "[/INST]" in decoded else decoded
# =========================
# Endpoints
# =========================
@app.get("/")
def root():
    """Health-check endpoint reporting the served model and status."""
    status_payload = {
        "TechMind": "Mistral-7B Instruct + LoRA v9",
        "status": "online",
    }
    return status_payload
@app.post("/ask")
def ask_q(req: Query):
    """Answer a user question; wraps generation errors as HTTP 500."""
    try:
        answer = generate_answer(req.question, req.max_tokens, req.temperature)
        return {"response": answer}
    except Exception as exc:
        # Surface any generation failure to the client as a server error.
        raise HTTPException(status_code=500, detail=str(exc))