# techmind-pro / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os
# =========================
# CONFIG
# =========================
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
LORA_MODEL = "Delta0723/techmind-pro-v9"
# Create the offload folder if it does not exist
os.makedirs("offload", exist_ok=True)
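# NOTE: the Mistral base model is gated on the Hugging Face Hub; downloading it
# typically requires accepting the license and providing a token (e.g. via the
# standard HF_TOKEN environment variable). Exact setup depends on the deployment.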
# =========================
# FastAPI Setup
# =========================
app = FastAPI(title="TechMind Pro API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
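# NOTE: allow_origins=["*"] accepts requests from any origin, which is
# convenient for a public demo but worth restricting in production.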
# =========================
# Load Model
# =========================
print("πŸš€ Cargando modelo y tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
quant_config = BitsAndBytesConfig(load_in_4bit=True)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
trust_remote_code=True,
offload_folder="offload",
quantization_config=quant_config
)
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
model.eval()
print("βœ… Modelo listo para usar")
except Exception as e:
print("❌ Error al cargar el modelo:", e)
raise e
# =========================
# Data Models
# =========================
class Query(BaseModel):
    question: str
    max_tokens: Optional[int] = 300
    temperature: Optional[float] = 0.7
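# Example /ask request body; max_tokens and temperature fall back to the
# defaults above when omitted:
# {"question": "Why won't my PC boot?", "max_tokens": 200, "temperature": 0.7}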
# =========================
# Utilities
# =========================
def generate_answer(question: str, max_tokens=300, temperature=0.7) -> str:
    # Wrap the question in Mistral's [INST] instruction format
    prompt = f"<s>[INST] {question} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Return only the completion; fall back to the full text if the tag was
    # stripped during decoding
    return decoded.split("[/INST]")[-1].strip() if "[/INST]" in decoded else decoded
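# The hand-built prompt above matches Mistral's instruction format. A sketch of
# an equivalent approach (assuming the tokenizer ships a chat template, as the
# v0.3 instruct tokenizer does) would delegate the formatting to the tokenizer:
def build_prompt(question: str) -> str:
    # apply_chat_template renders the [INST] ... [/INST] wrapper for us
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )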
# =========================
# Endpoints
# =========================
@app.get("/")
def root():
    return {"TechMind": "Mistral-7B Instruct + LoRA v9", "status": "online"}
@app.post("/ask")
def ask_q(req: Query):
    try:
        result = generate_answer(req.question, req.max_tokens, req.temperature)
        return {"response": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
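# Local entry point: a minimal sketch, assuming uvicorn is installed and that
# port 7860 (the usual Hugging Face Spaces port) is the intended one.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call once the server is up:
# curl -X POST http://localhost:7860/ask \
#   -H "Content-Type: application/json" \
#   -d '{"question": "What does SMART status mean?"}'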