import logging
from typing import Optional, Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

logger = logging.getLogger(__name__)


class ModelManager:
    """Manages loading a causal language model and access to the current model/tokenizer pair."""

    def __init__(self):
        """Initialize the model manager and attempt an initial model load."""
        self.current_model: Optional[PreTrainedModel] = None
        self.current_tokenizer: Optional[PreTrainedTokenizerBase] = None
        self.current_model_name: Optional[str] = None
        self.load_model()

    def load_model(self, model_name: Optional[str] = None) -> bool:
        """
        Load the language model, trying different models in order of preference.

        Args:
            model_name: Optional specific model to load; if omitted, a built-in
                preference list is tried instead.

        Returns:
            bool: True if any model was loaded successfully.
        """
        models_to_try = [model_name] if model_name else [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "microsoft/phi-2",
            "gpt2",
        ]

        # 8-bit quantization requires bitsandbytes and a CUDA device; skip it
        # elsewhere so the smaller fallback models can still load on CPU-only
        # machines instead of failing outright.
        quantization_config = (
            BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None
        )

        for model_id in models_to_try:
            try:
                logger.info(f"Loading {model_id}")
                self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    torch_dtype=torch.float16 if torch.cuda.is_available() else None,
                    # Passing load_in_8bit directly to from_pretrained is
                    # deprecated; quantization options go through
                    # BitsAndBytesConfig instead.
                    quantization_config=quantization_config,
                )
                self.current_model_name = model_id
                self.current_model.eval()
                logger.info(f"Successfully loaded {model_id}")
                return True
            except Exception as e:
                logger.warning(f"Failed to load {model_id}: {e}")
                # Clear any partial state so a failed attempt cannot leave a
                # tokenizer behind without a matching model.
                self.current_model = None
                self.current_tokenizer = None
                self.current_model_name = None
                continue

        return False

    def get_current_model(self) -> Tuple[Optional[PreTrainedModel], Optional[PreTrainedTokenizerBase]]:
        """Get the currently loaded model and tokenizer."""
        return self.current_model, self.current_tokenizer

    def is_model_loaded(self) -> bool:
        """Check if a model is currently loaded."""
        return self.current_model is not None and self.current_tokenizer is not None
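

# Minimal usage sketch (illustrative only, not part of the module above): it
# assumes this file is run directly and that at least one candidate model can
# be downloaded. The prompt text and max_new_tokens value are arbitrary
# choices for the demo, not values taken from the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()
    if manager.is_model_loaded():
        model, tokenizer = manager.get_current_model()
        inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=20)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    else:
        print("No model could be loaded.")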