|
|
""" |
|
|
API routes for Edge LLM |
|
|
""" |
|
|
from fastapi import APIRouter, HTTPException, Request, UploadFile, File |
|
|
from fastapi.responses import FileResponse |
|
|
from typing import List |
|
|
from ..models import ( |
|
|
PromptRequest, PromptResponse, ModelInfo, ModelsResponse, |
|
|
ModelLoadRequest, ModelUnloadRequest |
|
|
) |
|
|
from ..services.model_service import model_service |
|
|
from ..services.chat_service import chat_service |
|
|
from ..config import AVAILABLE_MODELS |
|
|
|
|
|
|
|
|
# Optional RAG support. The directory two levels above this file is appended
# to sys.path before importing ``rag_system``; when the import fails the
# module still loads and RAG_AVAILABLE records that the feature is missing.
try:
    import sys
    import os

    sys.path.append(
        os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
    )
    from rag_system import get_rag_system
except ImportError as e:
    print(f"RAG system not available: {e}")
    RAG_AVAILABLE = False
else:
    # Reached only when every import above succeeded.
    RAG_AVAILABLE = True
|
|
|
|
|
|
|
|
# Router on which every endpoint in this module is registered through the
# @router.get / @router.post / @router.delete decorators below.
router = APIRouter()
|
|
|
|
|
|
|
|
@router.get("/")
async def read_index():
    """Return the React app's ``index.html`` for the root URL."""
    # The config value is imported at call time, as in the original code.
    from ..config import FRONTEND_DIST_DIR

    index_file = f'{FRONTEND_DIST_DIR}/index.html'
    return FileResponse(index_file)
|
|
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """Return a fixed payload indicating that the API is running."""
    payload = dict(status="healthy", message="Edge LLM API is running")
    return payload
|
|
|
|
|
|
|
|
@router.get("/models", response_model=ModelsResponse)
async def get_models():
    """List every entry of ``AVAILABLE_MODELS`` together with its load status.

    Returns:
        A ``ModelsResponse`` holding one ``ModelInfo`` per configured model and
        the current model name (an empty string when no model is current).
    """
    catalog = [
        ModelInfo(
            model_name=key,
            name=meta["name"],
            supports_thinking=meta["supports_thinking"],
            description=meta["description"],
            size_gb=meta["size_gb"],
            is_loaded=model_service.is_model_loaded(key),
            type=meta["type"],
        )
        for key, meta in AVAILABLE_MODELS.items()
    ]
    active = model_service.get_current_model() or ""
    return ModelsResponse(models=catalog, current_model=active)
|
|
|
|
|
|
|
|
@router.post("/load-model")
async def load_model(request: ModelLoadRequest):
    """Load the requested model and make it the current one.

    Raises:
        HTTPException: 400 when the name is not a key of ``AVAILABLE_MODELS``;
            500 when the model service reports that loading failed.
    """
    # Guard: only models listed in the configuration may be loaded.
    if request.model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model_name} not available"
        )

    loaded_ok = model_service.load_model(request.model_name)
    if not loaded_ok:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {request.model_name}"
        )

    # A freshly loaded model immediately becomes the current model.
    model_service.set_current_model(request.model_name)
    response_body = {
        "message": f"Model {request.model_name} loaded successfully",
        "current_model": model_service.get_current_model()
    }
    return response_body
|
|
|
|
|
|
|
|
@router.post("/unload-model")
async def unload_model(request: ModelUnloadRequest):
    """Unload the named model via the model service.

    Raises:
        HTTPException: 404 when the model service reports the unload failed.
    """
    removed = model_service.unload_model(request.model_name)
    if not removed:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_name} not found in cache"
        )

    # Report the current model after the unload; "" when there is none.
    current = model_service.get_current_model() or ""
    return {
        "message": f"Model {request.model_name} unloaded successfully",
        "current_model": current
    }
|
|
|
|
|
|
|
|
@router.post("/set-current-model")
async def set_current_model(request: ModelLoadRequest):
    """Make an already-loaded model the current model.

    Raises:
        HTTPException: 400 when the requested model is not loaded.
    """
    if model_service.is_model_loaded(request.model_name):
        model_service.set_current_model(request.model_name)
        confirmation = {
            "message": f"Current model set to {request.model_name}",
            "current_model": model_service.get_current_model()
        }
        return confirmation

    # Only reached for models that have not been loaded yet.
    raise HTTPException(
        status_code=400,
        detail=f"Model {request.model_name} is not loaded. Please load it first."
    )
|
|
|
|
|
|
|
|
|
|
|
@router.post("/rag/upload")
async def upload_document(files: List[UploadFile] = File(...)):
    """Add each uploaded file to the RAG system and report per-file outcomes.

    A failure on one file does not abort the request: the exception text is
    recorded for that file and processing continues with the next one.

    Returns:
        ``{"results": [...]}`` with one entry per uploaded file.

    Raises:
        HTTPException: 503 when the RAG system could not be imported.
    """
    if not RAG_AVAILABLE:
        raise HTTPException(status_code=503, detail="RAG system not available")

    store = get_rag_system()
    outcomes = []

    for upload in files:
        try:
            payload = await upload.read()
            report = store.add_document(
                file_content=payload,
                filename=upload.filename,
                file_type=upload.content_type
            )
            entry = {
                "filename": upload.filename,
                "success": report["success"],
                "doc_id": report.get("doc_id"),
                "chunks": report.get("chunks"),
                "message": report.get("message"),
                "error": report.get("error")
            }
        except Exception as exc:
            # Best-effort per file: record the error and keep going.
            entry = {
                "filename": upload.filename,
                "success": False,
                "error": str(exc)
            }
        outcomes.append(entry)

    return {"results": outcomes}
|
|
|
|
|
|
|
|
@router.delete("/rag/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by ``doc_id`` from the RAG system.

    Returns:
        The result object produced by ``remove_document`` when it succeeded.

    Raises:
        HTTPException: 503 when the RAG system could not be imported; 404
            (carrying the result's ``error`` value) when removal failed.
    """
    if not RAG_AVAILABLE:
        raise HTTPException(status_code=503, detail="RAG system not available")

    outcome = get_rag_system().remove_document(doc_id)
    if not outcome["success"]:
        raise HTTPException(status_code=404, detail=outcome["error"])
    return outcome
|
|
|
|
|
|
|
|
@router.get("/rag/documents")
async def get_documents():
    """Return whatever the RAG system reports about its uploaded documents.

    Raises:
        HTTPException: 503 when the RAG system could not be imported.
    """
    if not RAG_AVAILABLE:
        raise HTTPException(status_code=503, detail="RAG system not available")

    documents_info = get_rag_system().get_documents_info()
    return documents_info
|
|
|
|
|
|
|
|
@router.post("/rag/search")
async def search_documents(query: str, max_results: int = 3):
    """Run a similarity search over the uploaded documents.

    Args:
        query: Text to search for.
        max_results: Passed to the RAG system as ``k``; defaults to 3.

    Returns:
        A dict echoing ``query`` alongside the search ``results``.

    Raises:
        HTTPException: 503 when the RAG system could not be imported.
    """
    if not RAG_AVAILABLE:
        raise HTTPException(status_code=503, detail="RAG system not available")

    hits = get_rag_system().search_similar(query, k=max_results)
    return {"query": query, "results": hits}
|
|
|
|
|
|
|
|
@router.post("/generate", response_model=PromptResponse)
async def generate_text(request: PromptRequest):
    """Generate text with a loaded model, optionally enhanced with RAG context.

    The model is ``request.model_name`` when given, otherwise the model
    service's current model. When RAG is importable and ``request.use_rag`` is
    truthy, context retrieved for the prompt is appended to the system prompt
    before generation.

    Returns:
        A ``PromptResponse`` built from the four values returned by
        ``chat_service.generate_response``.

    Raises:
        HTTPException: 400 when no model is selected or the selected model is
            not loaded; 500 when retrieval or generation raises an unexpected
            exception. An ``HTTPException`` raised by downstream code is
            propagated unchanged instead of being rewritten as a 500.
    """
    model_to_use = request.model_name or model_service.get_current_model()

    if not model_to_use:
        raise HTTPException(
            status_code=400,
            detail="No model specified. Please load a model first."
        )

    if not model_service.is_model_loaded(model_to_use):
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_to_use} is not loaded. Please load it first."
        )

    try:
        # Start from the caller's system prompt; RAG may extend it below.
        enhanced_system_prompt = request.system_prompt

        # Missing/falsy request values fall back to "RAG off" and 3 chunks.
        use_rag = request.use_rag or False
        retrieval_count = request.retrieval_count or 3

        if RAG_AVAILABLE and use_rag:
            rag_system = get_rag_system()
            context = rag_system.get_context_for_query(
                request.prompt, max_chunks=retrieval_count
            )

            # Only touch the system prompt when retrieval returned something.
            if context:
                context_instruction = (
                    "\n\nAdditional Context from Documents:\n"
                    "Use the following information to help answer the user's question. "
                    "If the context is relevant, incorporate it into your response. "
                    "If the context is not relevant, you can ignore it.\n\n"
                    f"{context}\n"
                    "---\n"
                )
                enhanced_system_prompt = (request.system_prompt or "") + context_instruction

        thinking_content, final_content, model_used, supports_thinking = chat_service.generate_response(
            prompt=request.prompt,
            model_name=model_to_use,
            messages=[msg.dict() for msg in request.messages] if request.messages else [],
            system_prompt=enhanced_system_prompt,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )

        return PromptResponse(
            thinking_content=thinking_content,
            content=final_content,
            model_used=model_used,
            supports_thinking=supports_thinking
        )

    except HTTPException:
        # A deliberate HTTP error from downstream code keeps its own status
        # code and detail rather than being masked as a generic 500.
        raise
    except Exception as e:
        print(f"Generation error: {e}")
        # Chain the cause so the original traceback stays available.
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{full_path:path}")
async def catch_all(request: Request, full_path: str):
    """Serve ``index.html`` for any GET path not matched by another route.

    This enables client-side routing for the React SPA. Paths that start with
    ``assets/``, ``images/`` or ``static/`` are treated as static-file requests
    and answered with 404 instead of the HTML page.

    Args:
        request: The incoming request; not used by this handler.
        full_path: The unmatched path captured by the route.

    Raises:
        HTTPException: 404 for paths under the static-file prefixes.
    """
    # HTTPException comes from the module-level fastapi import; the former
    # function-local re-import was redundant.
    if full_path.startswith(('assets/', 'images/', 'static/')):
        raise HTTPException(status_code=404, detail="File not found")

    from ..config import FRONTEND_DIST_DIR
    return FileResponse(f'{FRONTEND_DIST_DIR}/index.html')
|
|
|