# edgellm/backend/services/model_service.py
"""
Model loading and management service
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any, Dict, List, Optional

from ..config import AVAILABLE_MODELS


class ModelService:
    """Caches locally loaded models and tracks the currently active one."""

    def __init__(self):
        # Maps model name -> {"model": ..., "tokenizer": ...}
        self.models_cache: Dict[str, Dict[str, Any]] = {}
        self.current_model_name: Optional[str] = None

    def load_model(self, model_name: str) -> bool:
        """Load a model into memory."""
        if model_name not in AVAILABLE_MODELS:
            print(f"Model {model_name} not available.")
            return False

        model_info = AVAILABLE_MODELS[model_name]

        # API models don't need to be "loaded" - they're always available
        if model_info["type"] == "api":
            print(f"API model {model_name} is always available")
            return True

        # Handle local models
        if model_name in self.models_cache:
            print(f"Model {model_name} already loaded.")
            return True

        try:
            print(f"Loading local model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            self.models_cache[model_name] = {"model": model, "tokenizer": tokenizer}
            print(f"Model {model_name} loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return False
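
    # A minimal sketch of how a caller might consume a cached entry; the
    # model name, prompt, and generate() kwargs below are illustrative and
    # not part of this service's API:
    #
    #   entry = model_service.models_cache["some-local-model"]
    #   inputs = entry["tokenizer"]("Hello", return_tensors="pt").to(entry["model"].device)
    #   output_ids = entry["model"].generate(**inputs, max_new_tokens=32)
    #   print(entry["tokenizer"].decode(output_ids[0], skip_special_tokens=True))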

    def unload_model(self, model_name: str) -> bool:
        """Unload a model from memory."""
        model_info = AVAILABLE_MODELS.get(model_name, {})

        # API models can't be "unloaded"
        if model_info.get("type") == "api":
            print(f"API model {model_name} cannot be unloaded")
            return True

        # Handle local models
        if model_name in self.models_cache:
            del self.models_cache[model_name]
            if self.current_model_name == model_name:
                self.current_model_name = None
            # Dropping the reference alone doesn't return GPU memory to the
            # driver; ask the CUDA caching allocator to release it as well.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Model {model_name} unloaded")
            return True
        return False

    def set_current_model(self, model_name: str) -> bool:
        """Set the current active model."""
        if model_name not in AVAILABLE_MODELS:
            return False

        model_info = AVAILABLE_MODELS[model_name]

        # API models are always "available"
        if model_info["type"] == "api":
            self.current_model_name = model_name
            return True

        # Local models need to be loaded first
        if model_name not in self.models_cache:
            if not self.load_model(model_name):
                return False

        self.current_model_name = model_name
        return True

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is loaded/available."""
        model_info = AVAILABLE_MODELS.get(model_name, {})

        # API models are always available
        if model_info.get("type") == "api":
            return True

        # Local models need to be in the cache
        return model_name in self.models_cache

    def get_loaded_models(self) -> List[str]:
        """Get the list of currently loaded/available models."""
        loaded = []
        for model_name, model_info in AVAILABLE_MODELS.items():
            if model_info["type"] == "api" or model_name in self.models_cache:
                loaded.append(model_name)
        return loaded

    def get_current_model(self) -> Optional[str]:
        """Get the current active model, or None if no model is active."""
        return self.current_model_name


# Global model service instance
model_service = ModelService()
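

# A minimal usage sketch, assuming AVAILABLE_MODELS contains the hypothetical
# entries named below. Run it as a module (e.g.
# `python -m edgellm.backend.services.model_service`) so the relative
# `..config` import resolves:
if __name__ == "__main__":
    # Local model: set_current_model() loads it lazily if it isn't cached yet
    if model_service.set_current_model("some-local-model"):
        print("Current model:", model_service.get_current_model())

    # API model: always reported as available, nothing is actually loaded
    model_service.set_current_model("some-api-model")
    print("Loaded/available:", model_service.get_loaded_models())

    # Free the local model's memory when done
    model_service.unload_model("some-local-model")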