# edgellm/backend/services/model_service.py
"""
Model loading and management service.

Manages local Hugging Face models (loaded into an in-memory cache) and
remote API models (always treated as available).
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Any, Dict, List, Optional
from ..config import AVAILABLE_MODELS
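
# For reference: AVAILABLE_MODELS (defined in backend/config.py) maps model
# names to metadata dicts. Only the "type" key is read in this module, so a
# hypothetical sketch of the expected shape (entry names are illustrative,
# not taken from the actual config) would be:
#
#     AVAILABLE_MODELS = {
#         "some-api-model": {"type": "api"},      # served remotely, never loaded here
#         "some-local-model": {"type": "local"},  # loaded via transformers below
#     }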
class ModelService:
def __init__(self):
self.models_cache: Dict[str, Dict[str, Any]] = {}
        self.current_model_name: Optional[str] = None
def load_model(self, model_name: str) -> bool:
"""Load a model into memory"""
if model_name not in AVAILABLE_MODELS:
print(f"Model {model_name} not available.")
return False
model_info = AVAILABLE_MODELS[model_name]
# API models don't need to be "loaded" - they're always available
if model_info["type"] == "api":
print(f"API model {model_name} is always available")
return True
# Handle local models
if model_name in self.models_cache:
print(f"Model {model_name} already loaded.")
return True
try:
print(f"Loading local model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
self.models_cache[model_name] = {"model": model, "tokenizer": tokenizer}
print(f"Model {model_name} loaded successfully")
return True
except Exception as e:
print(f"Error loading model {model_name}: {e}")
return False
def unload_model(self, model_name: str) -> bool:
"""Unload a model from memory"""
model_info = AVAILABLE_MODELS.get(model_name, {})
# API models can't be "unloaded"
if model_info.get("type") == "api":
print(f"API model {model_name} cannot be unloaded")
return True
# Handle local models
        if model_name in self.models_cache:
            del self.models_cache[model_name]
            if self.current_model_name == model_name:
                self.current_model_name = None
            # Dropping the cache entry releases the Python references; also
            # ask CUDA to return the freed blocks so GPU memory actually shrinks.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Model {model_name} unloaded")
            return True
        return False
def set_current_model(self, model_name: str) -> bool:
"""Set the current active model"""
if model_name not in AVAILABLE_MODELS:
return False
model_info = AVAILABLE_MODELS[model_name]
# API models are always "available"
if model_info["type"] == "api":
self.current_model_name = model_name
return True
# Local models need to be loaded first
if model_name not in self.models_cache:
if not self.load_model(model_name):
return False
self.current_model_name = model_name
return True
def is_model_loaded(self, model_name: str) -> bool:
"""Check if a model is loaded/available"""
model_info = AVAILABLE_MODELS.get(model_name, {})
# API models are always available
if model_info.get("type") == "api":
return True
# Local models need to be in cache
return model_name in self.models_cache
    def get_loaded_models(self) -> List[str]:
"""Get list of currently loaded/available models"""
loaded = []
for model_name, model_info in AVAILABLE_MODELS.items():
if model_info["type"] == "api" or model_name in self.models_cache:
loaded.append(model_name)
return loaded
    def get_current_model(self) -> Optional[str]:
"""Get the current active model"""
return self.current_model_name
# Global model service instance
model_service = ModelService()
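
# Minimal usage sketch. Assumptions: the module is run as part of the package
# so the relative import resolves (e.g. `python -m backend.services.model_service`
# from the repo root), and AVAILABLE_MODELS is non-empty; the demo simply picks
# the first configured model rather than assuming any particular name.
if __name__ == "__main__":
    demo_name = next(iter(AVAILABLE_MODELS), None)  # first configured model, if any
    if demo_name and model_service.load_model(demo_name):
        model_service.set_current_model(demo_name)
        print(f"Current model: {model_service.get_current_model()}")
        print(f"Loaded models: {model_service.get_loaded_models()}")
        model_service.unload_model(demo_name)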