""" Model loading and management service """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from typing import Dict, Any, Optional from ..config import AVAILABLE_MODELS class ModelService: def __init__(self): self.models_cache: Dict[str, Dict[str, Any]] = {} self.current_model_name: Optional[str] = None def load_model(self, model_name: str) -> bool: """Load a model into the cache""" if model_name in self.models_cache: return True if model_name not in AVAILABLE_MODELS: return False try: print(f"Loading model: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, device_map="auto" ) self.models_cache[model_name] = { "model": model, "tokenizer": tokenizer } print(f"Model {model_name} loaded successfully") return True except Exception as e: print(f"Error loading model {model_name}: {e}") return False def unload_model(self, model_name: str) -> bool: """Unload a model from the cache""" if model_name in self.models_cache: del self.models_cache[model_name] if self.current_model_name == model_name: self.current_model_name = None print(f"Model {model_name} unloaded") return True return False def set_current_model(self, model_name: str) -> bool: """Set the current active model""" if model_name in self.models_cache: self.current_model_name = model_name return True return False def get_model_info(self, model_name: str) -> Dict[str, Any]: """Get model configuration info""" return AVAILABLE_MODELS.get(model_name, {}) def is_model_loaded(self, model_name: str) -> bool: """Check if a model is loaded""" return model_name in self.models_cache def get_loaded_models(self) -> list: """Get list of currently loaded models""" return list(self.models_cache.keys()) def get_current_model(self) -> Optional[str]: """Get the current active model""" return self.current_model_name # Global model service instance model_service = ModelService()