import logging
from typing import Optional, Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self):
        """Initialize the model manager and eagerly load a model."""
        self.current_model: Optional[PreTrainedModel] = None
        self.current_tokenizer: Optional[PreTrainedTokenizerBase] = None
        self.current_model_name: Optional[str] = None
        self.load_model()

    def load_model(self, model_name: Optional[str] = None) -> bool:
        """
        Load a language model, trying candidates in order of preference.

        Args:
            model_name: Optional specific model to load instead of the defaults.

        Returns:
            bool: True if any model was loaded successfully.
        """
        models_to_try = [model_name] if model_name else [
            "mistralai/Mistral-7B-Instruct-v0.2",  # Best balance of capability and size
            "microsoft/phi-2",                     # Fallback
            "gpt2",                                # Last resort
        ]

        for model_id in models_to_try:
            try:
                logger.info(f"Loading {model_id}")
                self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    torch_dtype=torch.float16,  # Half precision for non-quantized modules
                    # 8-bit quantization to cut memory use; requires the bitsandbytes package
                    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                )
                self.current_model_name = model_id
                self.current_model.eval()
                logger.info(f"Successfully loaded {model_id}")
                return True
            except Exception as e:
                logger.warning(f"Failed to load {model_id}: {e}")
                # Clear any partially initialized state before trying the next candidate
                self.current_model = None
                self.current_tokenizer = None
                self.current_model_name = None

        return False

    def get_current_model(self) -> Tuple[Optional[PreTrainedModel], Optional[PreTrainedTokenizerBase]]:
        """Return the currently loaded model and tokenizer (None if nothing is loaded)."""
        return self.current_model, self.current_tokenizer

    def is_model_loaded(self) -> bool:
        """Check whether both a model and its tokenizer are currently loaded."""
        return self.current_model is not None and self.current_tokenizer is not None
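

# A minimal usage sketch showing how ModelManager might be driven end to end.
# Everything below is illustrative: the prompt text and generation parameters
# (e.g. max_new_tokens) are assumptions, not part of the module above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()  # __init__ eagerly attempts to load a model
    if manager.is_model_loaded():
        model, tokenizer = manager.get_current_model()
        # Move inputs to the model's device (device_map="auto" may place it on GPU)
        inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    else:
        logger.error("No model could be loaded; check network access and dependencies.")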