"""
Model loading and management service
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Any, Dict, List, Optional
from ..config import AVAILABLE_MODELS


class ModelService:
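    """Caches locally loaded Hugging Face models and tracks the active one.

    API-backed models are never cached here; they are treated as always
    available and only tracked via current_model_name.
    """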
    def __init__(self):
        self.models_cache: Dict[str, Dict[str, Any]] = {}
        self.current_model_name: Optional[str] = None
    
    def load_model(self, model_name: str) -> bool:
        """Load a model into memory"""
        if model_name not in AVAILABLE_MODELS:
            print(f"Model {model_name} not available.")
            return False
        
        model_info = AVAILABLE_MODELS[model_name]
        
        # API models don't need to be "loaded" - they're always available
        if model_info["type"] == "api":
            print(f"API model {model_name} is always available")
            return True
        
        # Handle local models
        if model_name in self.models_cache:
            print(f"Model {model_name} already loaded.")
            return True
        
        try:
            print(f"Loading local model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # half precision roughly halves memory use
                device_map="auto",  # place weights across available GPUs/CPU automatically
            )
            self.models_cache[model_name] = {"model": model, "tokenizer": tokenizer}
            print(f"Model {model_name} loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return False
    
    def unload_model(self, model_name: str) -> bool:
        """Unload a model from memory"""
        model_info = AVAILABLE_MODELS.get(model_name, {})
        
        # API models can't be "unloaded"
        if model_info.get("type") == "api":
            print(f"API model {model_name} cannot be unloaded")
            return True
        
        # Handle local models
        if model_name in self.models_cache:
            del self.models_cache[model_name]
            if self.current_model_name == model_name:
                self.current_model_name = None
            # Deleting the cache entry drops the Python references; also ask
            # PyTorch to return the freed GPU memory to the driver.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Model {model_name} unloaded")
            return True
        return False
    
    def set_current_model(self, model_name: str) -> bool:
        """Set the current active model"""
        if model_name not in AVAILABLE_MODELS:
            return False
        
        model_info = AVAILABLE_MODELS[model_name]
        
        # API models are always "available"
        if model_info["type"] == "api":
            self.current_model_name = model_name
            return True
        
        # Local models need to be loaded first
        if model_name not in self.models_cache:
            if not self.load_model(model_name):
                return False
        
        self.current_model_name = model_name
        return True
    
    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is loaded/available"""
        model_info = AVAILABLE_MODELS.get(model_name, {})
        
        # API models are always available
        if model_info.get("type") == "api":
            return True
        
        # Local models need to be in cache
        return model_name in self.models_cache
    
    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded/available models"""
        loaded = []
        for model_name, model_info in AVAILABLE_MODELS.items():
            if model_info["type"] == "api" or model_name in self.models_cache:
                loaded.append(model_name)
        return loaded
    
    def get_current_model(self) -> Optional[str]:
        """Get the current active model"""
        return self.current_model_name


# Global model service instance
model_service = ModelService()
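

# Minimal usage sketch. Assumptions (not defined in this file): AVAILABLE_MODELS
# in ..config maps "gpt2" to a {"type": "local"} entry, and this module is
# importable as services.model_service.
#
#     from services.model_service import model_service
#
#     if model_service.set_current_model("gpt2"):   # loads the model on first use
#         print(model_service.get_current_model())  # -> "gpt2"
#         print(model_service.get_loaded_models())  # names ready for inference
#     model_service.unload_model("gpt2")            # releases weights and GPU cache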