"""Model loading and management service."""

import gc
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..config import AVAILABLE_MODELS


class ModelService:
    def __init__(self):
        # Cache of locally loaded models keyed by name:
        # {"model": ..., "tokenizer": ...}
        self.models_cache: Dict[str, Dict[str, Any]] = {}
        self.current_model_name: Optional[str] = None

    def load_model(self, model_name: str) -> bool:
        """Load a model into memory."""
        if model_name not in AVAILABLE_MODELS:
            print(f"Model {model_name} not available.")
            return False

        model_info = AVAILABLE_MODELS[model_name]

        # API-backed models are served remotely and never loaded locally.
        if model_info["type"] == "api":
            print(f"API model {model_name} is always available")
            return True

        if model_name in self.models_cache:
            print(f"Model {model_name} already loaded.")
            return True

        try:
            print(f"Loading local model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # device_map="auto" spreads layers across available devices
            # and requires the `accelerate` package.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            self.models_cache[model_name] = {"model": model, "tokenizer": tokenizer}
            print(f"Model {model_name} loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return False

    def unload_model(self, model_name: str) -> bool:
        """Unload a model from memory."""
        model_info = AVAILABLE_MODELS.get(model_name, {})

        # API models hold no local resources, so there is nothing to unload.
        if model_info.get("type") == "api":
            print(f"API model {model_name} cannot be unloaded")
            return True

        if model_name in self.models_cache:
            del self.models_cache[model_name]
            if self.current_model_name == model_name:
                self.current_model_name = None
            # Dropping the cache reference alone does not promptly reclaim GPU
            # memory; collect garbage and release cached CUDA allocations too.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Model {model_name} unloaded")
            return True
        return False

    def set_current_model(self, model_name: str) -> bool:
        """Set the current active model, loading it first if necessary."""
        if model_name not in AVAILABLE_MODELS:
            return False

        model_info = AVAILABLE_MODELS[model_name]

        # API models can be activated without a local load.
        if model_info["type"] == "api":
            self.current_model_name = model_name
            return True

        # Local models must be loaded before they can become active.
        if model_name not in self.models_cache:
            if not self.load_model(model_name):
                return False

        self.current_model_name = model_name
        return True

    def is_model_loaded(self, model_name: str) -> bool:
        """Check whether a model is loaded (local) or available (API)."""
        model_info = AVAILABLE_MODELS.get(model_name, {})

        # API models are always considered available.
        if model_info.get("type") == "api":
            return True

        return model_name in self.models_cache

    def get_loaded_models(self) -> List[str]:
        """Get the names of currently loaded/available models."""
        loaded = []
        for model_name, model_info in AVAILABLE_MODELS.items():
            if model_info["type"] == "api" or model_name in self.models_cache:
                loaded.append(model_name)
        return loaded

    def get_current_model(self) -> Optional[str]:
        """Get the name of the current active model, or None if unset."""
        return self.current_model_name


# Module-level singleton shared across the application.
model_service = ModelService()
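
# A minimal usage sketch, kept as comments ("gpt2" and the import path are
# hypothetical; substitute a local entry from your AVAILABLE_MODELS config):
#
#     from app.services.model_service import model_service
#
#     if model_service.set_current_model("gpt2"):
#         entry = model_service.models_cache["gpt2"]
#         model, tokenizer = entry["model"], entry["tokenizer"]
#     model_service.unload_model("gpt2")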