Delta0723 commited on
Commit
05334b7
·
verified ·
1 Parent(s): 175a666

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +320 -20
app.py CHANGED
@@ -1,35 +1,335 @@
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
- import gradio as gr
 
 
 
 
 
 
 
 
5
 
6
- # Modelo base y LoRA
7
  BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
8
- LORA_MODEL = "Delta0723/techmind-pro-v9"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Carga del modelo
11
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
12
- model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", torch_dtype=torch.float16)
 
 
 
 
 
 
 
13
  model = PeftModel.from_pretrained(model, LORA_MODEL)
14
  model.eval()
15
 
16
- # Función de inferencia
17
- def responder(pregunta):
18
- prompt = f"<s>[INST] {pregunta} [/INST]"
19
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
20
- outputs = model.generate(**inputs, max_new_tokens=100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
22
  if "[/INST]" in respuesta:
23
  respuesta = respuesta.split("[/INST]")[1].strip()
 
24
  return respuesta
25
 
26
- # Interfaz Gradio
27
- demo = gr.Interface(
28
- fn=responder,
29
- inputs=gr.Textbox(label="Pregunta sobre redes Cisco"),
30
- outputs=gr.Textbox(label="Respuesta del modelo"),
31
- title="🧠 TechMind Pro v9",
32
- description="Modelo especializado en configuración de redes Cisco y Packet Tracer."
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TechMind Pro - API Production Ready
3
+ Fine-tuning IA especializada en Redes Cisco
4
+ """
5
+
6
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import FileResponse, JSONResponse
9
+ from pydantic import BaseModel
10
+ from typing import Optional, List
11
  import torch
12
  from transformers import AutoTokenizer, AutoModelForCausalLM
13
  from peft import PeftModel
14
+ import uvicorn
15
+ import os
16
+ import json
17
+ from datetime import datetime
18
+ import re
19
+
20
+ # ============================================
21
+ # CONFIGURACIÓN
22
+ # ============================================
23
 
 
24
  BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
25
+ LORA_MODEL = "/workspace/TechMind/lora_MISTRAL_v9_ULTIMATE/final_model"
26
+ OUTPUT_DIR = "/workspace/TechMind/api_outputs"
27
+
28
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
29
+
30
+ # ============================================
31
+ # INICIALIZAR APP
32
+ # ============================================
33
+
34
+ app = FastAPI(
35
+ title="TechMind Pro API",
36
+ description="Asistente IA especializado en Redes Cisco & Packet Tracer",
37
+ version="1.0.0",
38
+ docs_url="/docs",
39
+ redoc_url="/redoc"
40
+ )
41
+
42
+ # CORS para permitir requests desde cualquier origen
43
+ app.add_middleware(
44
+ CORSMiddleware,
45
+ allow_origins=["*"],
46
+ allow_credentials=True,
47
+ allow_methods=["*"],
48
+ allow_headers=["*"],
49
+ )
50
+
51
+ # ============================================
52
+ # CARGAR MODELO (Al iniciar)
53
+ # ============================================
54
+
55
+ print("🔥 Iniciando TechMind Pro API...")
56
+ print("="*60)
57
+
58
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
59
+ tokenizer.pad_token = tokenizer.eos_token
60
 
61
+ print("📦 Cargando Mistral 7B...")
62
+ model = AutoModelForCausalLM.from_pretrained(
63
+ BASE_MODEL,
64
+ load_in_8bit=True,
65
+ device_map="auto",
66
+ torch_dtype=torch.float16,
67
+ trust_remote_code=True
68
+ )
69
+
70
+ print("🔧 Cargando LoRA v9 ULTIMATE...")
71
  model = PeftModel.from_pretrained(model, LORA_MODEL)
72
  model.eval()
73
 
74
+ print("✅ TechMind Pro listo para producción")
75
+ print("="*60)
76
+
77
+ # ============================================
78
+ # MODELOS DE DATOS
79
+ # ============================================
80
+
81
+ class QueryRequest(BaseModel):
82
+ question: str
83
+ max_tokens: Optional[int] = 500
84
+ temperature: Optional[float] = 0.7
85
+ include_files: Optional[bool] = False
86
+
87
+ class QueryResponse(BaseModel):
88
+ answer: str
89
+ confidence: float
90
+ processing_time: float
91
+ files: Optional[List[dict]] = None
92
+ metadata: dict
93
+
94
+ # ============================================
95
+ # FUNCIONES CORE
96
+ # ============================================
97
+
98
+ def generar_respuesta(question: str, max_tokens: int = 500, temperature: float = 0.7) -> str:
99
+ """
100
+ Genera respuesta del modelo TechMind
101
+ """
102
+ prompt = f"<s>[INST] {question} [/INST]"
103
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
104
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
105
+
106
+ with torch.no_grad():
107
+ outputs = model.generate(
108
+ **inputs,
109
+ max_new_tokens=max_tokens,
110
+ temperature=temperature,
111
+ top_p=0.9,
112
+ do_sample=True,
113
+ pad_token_id=tokenizer.eos_token_id,
114
+ eos_token_id=tokenizer.eos_token_id
115
+ )
116
+
117
  respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
118
+
119
  if "[/INST]" in respuesta:
120
  respuesta = respuesta.split("[/INST]")[1].strip()
121
+
122
  return respuesta
123
 
124
+ def calcular_confianza(respuesta: str, pregunta: str) -> float:
125
+ """
126
+ Calcula score de confianza basado en keywords técnicos
127
+ """
128
+ keywords_cisco = [
129
+ 'interface', 'ip address', 'router', 'switch', 'vlan',
130
+ 'configure', 'enable', 'show', 'no shutdown', 'ospf',
131
+ 'eigrp', 'bgp', 'acl', 'nat', 'trunk'
132
+ ]
133
+
134
+ resp_lower = respuesta.lower()
135
+ encontrados = sum(1 for k in keywords_cisco if k in resp_lower)
136
+
137
+ # Score base por keywords
138
+ score = min(encontrados / 5, 1.0) * 0.7
139
+
140
+ # Bonus si tiene bloques de código
141
+ if '```' in respuesta or 'enable\nconfigure' in respuesta:
142
+ score += 0.2
143
+
144
+ # Bonus si menciona verificación
145
+ if any(v in resp_lower for v in ['show', 'verify', 'debug']):
146
+ score += 0.1
147
+
148
+ return min(score, 1.0)
149
+
150
+ def extraer_bloques_codigo(respuesta: str) -> List[dict]:
151
+ """
152
+ Extrae bloques de código de la respuesta
153
+ """
154
+ bloques = []
155
+
156
+ # Buscar bloques ```
157
+ patron = r'```(?:cisco|bash|text)?\n(.*?)```'
158
+ matches = re.findall(patron, respuesta, re.DOTALL)
159
+
160
+ for i, codigo in enumerate(matches, 1):
161
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
162
+ filename = f"config_{i}_{timestamp}.txt"
163
+ filepath = os.path.join(OUTPUT_DIR, filename)
164
+
165
+ with open(filepath, 'w') as f:
166
+ f.write(codigo)
167
+
168
+ bloques.append({
169
+ "filename": filename,
170
+ "content": codigo,
171
+ "size": len(codigo),
172
+ "download_url": f"/download/{filename}"
173
+ })
174
+
175
+ return bloques
176
+
177
+ # ============================================
178
+ # ENDPOINTS
179
+ # ============================================
180
+
181
+ @app.get("/")
182
+ def root():
183
+ """
184
+ Información de la API
185
+ """
186
+ return {
187
+ "service": "TechMind Pro API",
188
+ "version": "1.0.0",
189
+ "model": "Mistral-7B v9 ULTIMATE",
190
+ "specialization": "Cisco Networking & Packet Tracer",
191
+ "status": "operational",
192
+ "docs": "/docs",
193
+ "endpoints": {
194
+ "ask": "POST /ask",
195
+ "health": "GET /health",
196
+ "stats": "GET /stats"
197
+ }
198
+ }
199
+
200
+ @app.get("/health")
201
+ def health_check():
202
+ """
203
+ Health check del servicio
204
+ """
205
+ return {
206
+ "status": "healthy",
207
+ "model_loaded": model is not None,
208
+ "timestamp": datetime.now().isoformat()
209
+ }
210
+
211
+ @app.post("/ask", response_model=QueryResponse)
212
+ async def ask_techmind(request: QueryRequest):
213
+ """
214
+ Endpoint principal - Consultar a TechMind
215
+
216
+ Ejemplo:
217
+ ```json
218
+ {
219
+ "question": "¿Cómo configuro OSPF área 0?",
220
+ "max_tokens": 500,
221
+ "temperature": 0.7,
222
+ "include_files": true
223
+ }
224
+ ```
225
+ """
226
+ try:
227
+ start_time = datetime.now()
228
+
229
+ # Generar respuesta
230
+ answer = generar_respuesta(
231
+ request.question,
232
+ max_tokens=request.max_tokens,
233
+ temperature=request.temperature
234
+ )
235
+
236
+ # Calcular confianza
237
+ confidence = calcular_confianza(answer, request.question)
238
+
239
+ # Extraer archivos si se solicita
240
+ files = None
241
+ if request.include_files:
242
+ files = extraer_bloques_codigo(answer)
243
+
244
+ # Calcular tiempo
245
+ processing_time = (datetime.now() - start_time).total_seconds()
246
+
247
+ return QueryResponse(
248
+ answer=answer,
249
+ confidence=confidence,
250
+ processing_time=processing_time,
251
+ files=files,
252
+ metadata={
253
+ "model": "Mistral-7B v9 ULTIMATE",
254
+ "timestamp": datetime.now().isoformat(),
255
+ "tokens_generated": len(answer.split())
256
+ }
257
+ )
258
+
259
+ except Exception as e:
260
+ raise HTTPException(status_code=500, detail=str(e))
261
+
262
+ @app.get("/download/{filename}")
263
+ async def download_file(filename: str):
264
+ """
265
+ Descargar archivos de configuración generados
266
+ """
267
+ filepath = os.path.join(OUTPUT_DIR, filename)
268
+
269
+ if not os.path.exists(filepath):
270
+ raise HTTPException(status_code=404, detail="Archivo no encontrado")
271
+
272
+ return FileResponse(
273
+ filepath,
274
+ media_type='application/octet-stream',
275
+ filename=filename
276
+ )
277
+
278
+ @app.get("/stats")
279
+ def get_stats():
280
+ """
281
+ Estadísticas del servicio
282
+ """
283
+ archivos_generados = len([f for f in os.listdir(OUTPUT_DIR) if f.endswith('.txt')])
284
+
285
+ return {
286
+ "archivos_generados": archivos_generados,
287
+ "modelo": "Mistral-7B v9 ULTIMATE",
288
+ "dataset": "1,191 ejemplos",
289
+ "especialización": "Redes Cisco & Packet Tracer",
290
+ "uptime": "N/A"
291
+ }
292
+
293
+ @app.post("/batch")
294
+ async def batch_queries(questions: List[str]):
295
+ """
296
+ Procesar múltiples preguntas
297
+ """
298
+ results = []
299
+
300
+ for q in questions:
301
+ try:
302
+ answer = generar_respuesta(q)
303
+ confidence = calcular_confianza(answer, q)
304
+ results.append({
305
+ "question": q,
306
+ "answer": answer,
307
+ "confidence": confidence
308
+ })
309
+ except Exception as e:
310
+ results.append({
311
+ "question": q,
312
+ "error": str(e)
313
+ })
314
+
315
+ return {"results": results}
316
+
317
+ # ============================================
318
+ # MAIN
319
+ # ============================================
320
 
321
+ if __name__ == "__main__":
322
+ print("\n" + "="*60)
323
+ print("🚀 TechMind Pro API - Production Mode")
324
+ print("="*60)
325
+ print("📍 URL: http://0.0.0.0:8000")
326
+ print("📚 Docs: http://0.0.0.0:8000/docs")
327
+ print("🔥 Listo para recibir consultas")
328
+ print("="*60 + "\n")
329
+
330
+ uvicorn.run(
331
+ app,
332
+ host="0.0.0.0",
333
+ port=8000,
334
+ log_level="info"
335
+ )