"""API routes for Edge LLM."""

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse

from ..models import (
    PromptRequest, PromptResponse, ModelInfo, ModelsResponse,
    ModelLoadRequest, ModelUnloadRequest
)
from ..services.model_service import model_service
from ..services.chat_service import chat_service
from ..config import AVAILABLE_MODELS

router = APIRouter()
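# Usage sketch (assumption, not shown in this module): the application is
# expected to mount this router on a FastAPI app, e.g.:
#
#     from fastapi import FastAPI
#     from .routes import router  # hypothetical import path
#
#     app = FastAPI()
#     app.include_router(router)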


@router.get("/")
async def read_index():
    """Serve the React app."""
    # Imported inside the handler, as in the original (presumably to avoid a
    # circular import with the config module).
    from ..config import FRONTEND_DIST_DIR
    return FileResponse(f'{FRONTEND_DIST_DIR}/index.html')


@router.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "message": "Edge LLM API is running"}


@router.get("/models", response_model=ModelsResponse)
async def get_models():
    """Get the available models and their load status."""
    models = []
    for model_name, info in AVAILABLE_MODELS.items():
        models.append(ModelInfo(
            model_name=model_name,
            name=info["name"],
            supports_thinking=info["supports_thinking"],
            description=info["description"],
            size_gb=info["size_gb"],
            is_loaded=model_service.is_model_loaded(model_name),
            type=info["type"]
        ))

    return ModelsResponse(
        models=models,
        current_model=model_service.get_current_model() or ""
    )
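# Response shape (field names taken from the models above; values are
# illustrative, not real entries):
#
#     {
#         "models": [
#             {
#                 "model_name": "...",
#                 "name": "...",
#                 "supports_thinking": true,
#                 "description": "...",
#                 "size_gb": 1.2,
#                 "is_loaded": false,
#                 "type": "..."
#             }
#         ],
#         "current_model": ""
#     }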


@router.post("/load-model")
async def load_model(request: ModelLoadRequest):
    """Load a specific model."""
    if request.model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model_name} not available"
        )

    success = model_service.load_model(request.model_name)
    if success:
        model_service.set_current_model(request.model_name)
        return {
            "message": f"Model {request.model_name} loaded successfully",
            "current_model": model_service.get_current_model()
        }
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {request.model_name}"
        )
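# Example call (sketch; host and port are assumptions):
#
#     curl -X POST http://localhost:8000/load-model \
#          -H "Content-Type: application/json" \
#          -d '{"model_name": "<one of AVAILABLE_MODELS>"}'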


@router.post("/unload-model")
async def unload_model(request: ModelUnloadRequest):
    """Unload a specific model."""
    success = model_service.unload_model(request.model_name)
    if success:
        return {
            "message": f"Model {request.model_name} unloaded successfully",
            "current_model": model_service.get_current_model() or ""
        }
    else:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_name} not found in cache"
        )


@router.post("/set-current-model")
async def set_current_model(request: ModelLoadRequest):
    """Set the current active model."""
    if not model_service.is_model_loaded(request.model_name):
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model_name} is not loaded. Please load it first."
        )

    model_service.set_current_model(request.model_name)
    return {
        "message": f"Current model set to {request.model_name}",
        "current_model": model_service.get_current_model()
    }


@router.post("/generate", response_model=PromptResponse)
async def generate_text(request: PromptRequest):
    """Generate text using the loaded model."""
    # Fall back to the current model when the request does not name one.
    model_to_use = request.model_name or model_service.get_current_model()

    if not model_to_use:
        raise HTTPException(
            status_code=400,
            detail="No model specified. Please load a model first."
        )

    if not model_service.is_model_loaded(model_to_use):
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_to_use} is not loaded. Please load it first."
        )

    try:
        thinking_content, final_content, model_used, supports_thinking = chat_service.generate_response(
            prompt=request.prompt,
            model_name=model_to_use,
            # .dict() is the Pydantic v1 API; on Pydantic v2 it still works
            # but is deprecated in favor of .model_dump().
            messages=[msg.dict() for msg in request.messages] if request.messages else [],
            system_prompt=request.system_prompt,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )

        return PromptResponse(
            thinking_content=thinking_content,
            content=final_content,
            model_used=model_used,
            supports_thinking=supports_thinking
        )

    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
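# Example request body (sketch; field names taken from the PromptRequest
# usage above, values are illustrative):
#
#     POST /generate
#     {
#         "prompt": "Hello there",
#         "model_name": null,        # optional; defaults to the current model
#         "messages": [],
#         "system_prompt": "You are a helpful assistant.",
#         "temperature": 0.7,
#         "max_new_tokens": 512
#     }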


# NOTE: Starlette matches routes in registration order, so this catch-all must
# remain the last route in the module or it would shadow the API routes above.
@router.get("/{full_path:path}")
async def catch_all(request: Request, full_path: str):
    """
    Catch-all route to serve index.html for any unmatched paths.
    This enables client-side routing for the React SPA.
    """
    from ..config import FRONTEND_DIST_DIR
    return FileResponse(f'{FRONTEND_DIST_DIR}/index.html')