Delta0723 committed
Commit 5ce85fd · verified · 1 Parent(s): da930c8

Upload app.py

Files changed (1)
  1. app.py +336 -0
app.py ADDED
@@ -0,0 +1,336 @@
+ #!/usr/bin/env python3
+ """
+ TechMind Pro - Production-Ready API
+ Fine-tuned AI assistant specialized in Cisco networking
+ """
+
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import FileResponse, JSONResponse
+ from pydantic import BaseModel
+ from typing import Optional, List
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ import uvicorn
+ import os
+ import json
+ from datetime import datetime
+ import re
+
+ # ============================================
+ # CONFIGURATION
+ # ============================================
+
+ BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+ LORA_MODEL = "/workspace/TechMind/lora_MISTRAL_v9_ULTIMATE/final_model"
+ OUTPUT_DIR = "/workspace/TechMind/api_outputs"
+
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # ============================================
+ # INITIALIZE APP
+ # ============================================
+
+ app = FastAPI(
+     title="TechMind Pro API",
+     description="AI assistant specialized in Cisco networking & Packet Tracer",
+     version="1.0.0",
+     docs_url="/docs",
+     redoc_url="/redoc"
+ )
+
+ # CORS: allow requests from any origin
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # ============================================
+ # LOAD MODEL (at startup)
+ # ============================================
+
+ print("🔥 Starting TechMind Pro API...")
+ print("=" * 60)
+
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ print("📦 Loading Mistral 7B...")
+ # NOTE: recent transformers releases deprecate load_in_8bit= here in favor of
+ # quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     load_in_8bit=True,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     trust_remote_code=True
+ )
+
+ print("🔧 Loading LoRA v9 ULTIMATE...")
+ model = PeftModel.from_pretrained(model, LORA_MODEL)
+ model.eval()
+
+ print("✅ TechMind Pro ready for production")
+ print("=" * 60)
+
+ # ============================================
+ # DATA MODELS
+ # ============================================
+
+ class QueryRequest(BaseModel):
+     question: str
+     max_tokens: Optional[int] = 500
+     temperature: Optional[float] = 0.7
+     include_files: Optional[bool] = False
+
+ class QueryResponse(BaseModel):
+     answer: str
+     confidence: float
+     processing_time: float
+     files: Optional[List[dict]] = None
+     metadata: dict
+
+ # ============================================
+ # CORE FUNCTIONS
+ # ============================================
+
+ def generar_respuesta(question: str, max_tokens: int = 500, temperature: float = 0.7) -> str:
+     """
+     Generate a response from the TechMind model
+     """
+     prompt = f"<s>[INST] {question} [/INST]"
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+     respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Keep only the text generated after the instruction tag
+     if "[/INST]" in respuesta:
+         respuesta = respuesta.split("[/INST]", 1)[1].strip()
+
+     return respuesta
+
+ def calcular_confianza(respuesta: str, pregunta: str) -> float:
+     """
+     Compute a confidence score based on technical keywords
+     """
+     keywords_cisco = [
+         'interface', 'ip address', 'router', 'switch', 'vlan',
+         'configure', 'enable', 'show', 'no shutdown', 'ospf',
+         'eigrp', 'bgp', 'acl', 'nat', 'trunk'
+     ]
+
+     resp_lower = respuesta.lower()
+     encontrados = sum(1 for k in keywords_cisco if k in resp_lower)
+
+     # Base score from keyword matches
+     score = min(encontrados / 5, 1.0) * 0.7
+
+     # Bonus for code blocks
+     if '```' in respuesta or 'enable\nconfigure' in respuesta:
+         score += 0.2
+
+     # Bonus for mentioning verification commands
+     if any(v in resp_lower for v in ['show', 'verify', 'debug']):
+         score += 0.1
+
+     return min(score, 1.0)
+
+ def extraer_bloques_codigo(respuesta: str) -> List[dict]:
+     """
+     Extract code blocks from the response and save them to disk
+     """
+     bloques = []
+
+     # Look for fenced ``` blocks
+     patron = r'```(?:cisco|bash|text)?\n(.*?)```'
+     matches = re.findall(patron, respuesta, re.DOTALL)
+
+     for i, codigo in enumerate(matches, 1):
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = f"config_{i}_{timestamp}.txt"
+         filepath = os.path.join(OUTPUT_DIR, filename)
+
+         with open(filepath, 'w') as f:
+             f.write(codigo)
+
+         bloques.append({
+             "filename": filename,
+             "content": codigo,
+             "size": len(codigo),
+             "download_url": f"/download/{filename}"
+         })
+
+     return bloques
+
+ # ============================================
+ # ENDPOINTS
+ # ============================================
+
+ @app.get("/")
+ def root():
+     """
+     API information
+     """
+     return {
+         "service": "TechMind Pro API",
+         "version": "1.0.0",
+         "model": "Mistral-7B v9 ULTIMATE",
+         "specialization": "Cisco Networking & Packet Tracer",
+         "status": "operational",
+         "docs": "/docs",
+         "endpoints": {
+             "ask": "POST /ask",
+             "health": "GET /health",
+             "stats": "GET /stats"
+         }
+     }
+
+ @app.get("/health")
+ def health_check():
+     """
+     Service health check
+     """
+     return {
+         "status": "healthy",
+         "model_loaded": model is not None,
+         "timestamp": datetime.now().isoformat()
+     }
+
+ @app.post("/ask", response_model=QueryResponse)
+ async def ask_techmind(request: QueryRequest):
+     """
+     Main endpoint - query TechMind
+
+     Example:
+     ```json
+     {
+         "question": "How do I configure OSPF area 0?",
+         "max_tokens": 500,
+         "temperature": 0.7,
+         "include_files": true
+     }
+     ```
+     """
+     try:
+         start_time = datetime.now()
+
+         # Generate the answer
+         answer = generar_respuesta(
+             request.question,
+             max_tokens=request.max_tokens,
+             temperature=request.temperature
+         )
+
+         # Compute confidence
+         confidence = calcular_confianza(answer, request.question)
+
+         # Extract code-block files if requested
+         files = None
+         if request.include_files:
+             files = extraer_bloques_codigo(answer)
+
+         # Measure elapsed time
+         processing_time = (datetime.now() - start_time).total_seconds()
+
+         return QueryResponse(
+             answer=answer,
+             confidence=confidence,
+             processing_time=processing_time,
+             files=files,
+             metadata={
+                 "model": "Mistral-7B v9 ULTIMATE",
+                 "timestamp": datetime.now().isoformat(),
+                 "tokens_generated": len(answer.split())  # approximate: word count, not true tokens
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/download/{filename}")
+ async def download_file(filename: str):
+     """
+     Download generated configuration files
+     """
+     # Sanitize the name to prevent path traversal outside OUTPUT_DIR
+     filename = os.path.basename(filename)
+     filepath = os.path.join(OUTPUT_DIR, filename)
+
+     if not os.path.exists(filepath):
+         raise HTTPException(status_code=404, detail="File not found")
+
+     return FileResponse(
+         filepath,
+         media_type='application/octet-stream',
+         filename=filename
+     )
+
+ @app.get("/stats")
+ def get_stats():
+     """
+     Service statistics
+     """
+     archivos_generados = len([f for f in os.listdir(OUTPUT_DIR) if f.endswith('.txt')])
+
+     return {
+         "files_generated": archivos_generados,
+         "model": "Mistral-7B v9 ULTIMATE",
+         "dataset": "1,191 examples",
+         "specialization": "Cisco Networking & Packet Tracer",
+         "uptime": "N/A"
+     }
+
+ @app.post("/batch")
+ async def batch_queries(questions: List[str]):
+     """
+     Process multiple questions
+     """
+     results = []
+
+     for q in questions:
+         try:
+             answer = generar_respuesta(q)
+             confidence = calcular_confianza(answer, q)
+             results.append({
+                 "question": q,
+                 "answer": answer,
+                 "confidence": confidence
+             })
+         except Exception as e:
+             results.append({
+                 "question": q,
+                 "error": str(e)
+             })
+
+     return {"results": results}
+
+ # ============================================
+ # MAIN
+ # ============================================
+
+ if __name__ == "__main__":
+     print("\n" + "=" * 60)
+     print("🚀 TechMind Pro API - Production Mode")
+     print("=" * 60)
+     print("📍 URL: http://0.0.0.0:8000")
+     print("📚 Docs: http://0.0.0.0:8000/docs")
+     print("🔥 Ready to receive queries")
+     print("=" * 60 + "\n")
+
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=8000,
+         log_level="info"
+     )
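
For reference, a minimal client sketch for the `/ask` endpoint above, assuming the service is running on the host and port configured in `uvicorn.run` and that the third-party `requests` package is installed; the question text and output handling are illustrative, not part of the upload.

```python
import requests

# Illustrative query; any Cisco networking question works here.
payload = {
    "question": "How do I configure OSPF area 0 on a Cisco router?",
    "max_tokens": 500,
    "temperature": 0.7,
    "include_files": True,  # ask the API to extract config blocks as files
}

resp = requests.post("http://localhost:8000/ask", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

print(f"confidence: {data['confidence']:.2f}  ({data['processing_time']:.1f}s)")
print(data["answer"])

# Each extracted block is also exposed via GET /download/{filename}.
for f in data.get("files") or []:
    print(f"{f['filename']} -> {f['download_url']}")
```

Note that generation runs synchronously inside the request handler, so a generous client timeout is advisable for long answers.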