"""
Model loading and management service
"""
import gc
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..config import AVAILABLE_MODELS


class ModelService:
    """Track which configured models are usable and which one is active.

    Models are looked up by name in ``AVAILABLE_MODELS``. An entry whose
    ``"type"`` is ``"api"`` is treated as always available and is never
    placed in the local cache. Every other entry is a local model that must
    be loaded into ``models_cache`` before it counts as loaded.
    """

    def __init__(self):
        # model name -> {"model": <loaded model>, "tokenizer": <tokenizer>}
        self.models_cache: Dict[str, Dict[str, Any]] = {}
        # Name of the active model, or None when no model has been selected.
        self.current_model_name: Optional[str] = None

    @staticmethod
    def _is_api_model(model_name: str) -> bool:
        """Return True if *model_name* is configured with type ``"api"``.

        Uses ``.get`` throughout so that an unknown name, or a config entry
        without a ``"type"`` key, is reported as "not an API model" instead
        of raising ``KeyError``.
        """
        return AVAILABLE_MODELS.get(model_name, {}).get("type") == "api"

    def load_model(self, model_name: str) -> bool:
        """Load a model into memory.

        Args:
            model_name: Key in ``AVAILABLE_MODELS``; for local models it is
                also the identifier passed to ``from_pretrained``.

        Returns:
            True if the model is usable afterwards (API model, already
            cached, or freshly loaded); False if the name is not configured
            or loading failed.
        """
        if model_name not in AVAILABLE_MODELS:
            print(f"Model {model_name} not available.")
            return False

        # API models don't need to be "loaded" - they're always available
        if self._is_api_model(model_name):
            print(f"API model {model_name} is always available")
            return True

        # Handle local models
        if model_name in self.models_cache:
            print(f"Model {model_name} already loaded.")
            return True

        try:
            print(f"Loading local model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            # Deliberately broad: any failure while fetching or building the
            # model is reported to the caller as False rather than raised.
            print(f"Error loading model {model_name}: {e}")
            return False

        # Cache only after both pieces loaded, so no partial entry is stored.
        self.models_cache[model_name] = {"model": model, "tokenizer": tokenizer}
        print(f"Model {model_name} loaded successfully")
        return True

    def unload_model(self, model_name: str) -> bool:
        """Unload a model from memory.

        Args:
            model_name: Name of the model to drop from the cache.

        Returns:
            True if the model was an API model (nothing to unload) or was
            removed from the cache; False if it was not cached.
        """
        # API models can't be "unloaded"
        if self._is_api_model(model_name):
            print(f"API model {model_name} cannot be unloaded")
            return True

        # Handle local models
        if model_name not in self.models_cache:
            return False

        del self.models_cache[model_name]
        if self.current_model_name == model_name:
            self.current_model_name = None

        # Dropping the cache entry alone does not hand memory back: collect
        # unreachable objects first, then release the blocks PyTorch's CUDA
        # caching allocator is still holding.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Model {model_name} unloaded")
        return True

    def set_current_model(self, model_name: str) -> bool:
        """Set the current active model.

        Local models that are not cached yet are loaded on demand.

        Args:
            model_name: Key in ``AVAILABLE_MODELS``.

        Returns:
            True if the model is now the active one; False if the name is
            not configured or a required load failed (the previous active
            model is left unchanged in that case).
        """
        if model_name not in AVAILABLE_MODELS:
            return False

        # API models are always "available"; local models need to be loaded
        # first.
        if not self._is_api_model(model_name):
            if model_name not in self.models_cache:
                if not self.load_model(model_name):
                    return False

        self.current_model_name = model_name
        return True

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is loaded/available.

        Returns:
            True for API models and for local models present in the cache.
        """
        return self._is_api_model(model_name) or model_name in self.models_cache

    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded/available models.

        Returns:
            Names from ``AVAILABLE_MODELS``, in its iteration order, that
            are either API models or present in the local cache.
        """
        return [
            name
            for name, info in AVAILABLE_MODELS.items()
            if info.get("type") == "api" or name in self.models_cache
        ]

    def get_current_model(self) -> Optional[str]:
        """Get the current active model, or None if none has been set."""
        return self.current_model_name


# Global model service instance
model_service = ModelService()