|
|
"""
|
|
|
Fallback model handler for Codette
|
|
|
Uses open source models as fallbacks when proprietary models are unavailable
|
|
|
"""
|
|
|
import os
|
|
|
from typing import Optional, Dict, Any
|
|
|
import json
|
|
|
import logging
|
|
|
from pathlib import Path
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FallbackModelManager:
|
|
|
def __init__(self):
|
|
|
self.models_dir = Path(__file__).parent
|
|
|
self.fallback_dir = self.models_dir / 'fallback'
|
|
|
self.fallback_dir.mkdir(exist_ok=True)
|
|
|
self.model_cache: Dict[str, Any] = {}
|
|
|
|
|
|
def get_model(self, model_name: str) -> Optional[Any]:
|
|
|
"""Get a model, falling back to open source alternatives if needed."""
|
|
|
try:
|
|
|
|
|
|
if model_name in self.model_cache:
|
|
|
return self.model_cache[model_name]
|
|
|
|
|
|
original_path = self.models_dir / f"{model_name}.pt"
|
|
|
if original_path.exists():
|
|
|
import torch
|
|
|
model = torch.load(original_path)
|
|
|
self.model_cache[model_name] = model
|
|
|
return model
|
|
|
|
|
|
|
|
|
return self._load_fallback(model_name)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error loading model {model_name}: {str(e)}")
|
|
|
return self._load_fallback(model_name)
|
|
|
|
|
|
def _load_fallback(self, model_name: str) -> Optional[Any]:
|
|
|
"""Load an open source fallback model."""
|
|
|
try:
|
|
|
if model_name.startswith('nlp_'):
|
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
model_id = "bert-base-uncased"
|
|
|
model = AutoModel.from_pretrained(model_id)
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
return {'model': model, 'tokenizer': tokenizer}
|
|
|
|
|
|
elif model_name.startswith('vision_'):
|
|
|
import torchvision.models as models
|
|
|
return models.resnet18(pretrained=True)
|
|
|
|
|
|
elif model_name.startswith('quantum_'):
|
|
|
|
|
|
return self._create_quantum_approximation()
|
|
|
|
|
|
else:
|
|
|
logger.warning(f"No fallback available for {model_name}")
|
|
|
return None
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error loading fallback for {model_name}: {str(e)}")
|
|
|
return None
|
|
|
|
|
|
def _create_quantum_approximation(self) -> Any:
|
|
|
"""Create a classical approximation of quantum operations."""
|
|
|
try:
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class QuantumApproximator(nn.Module):
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
self.layers = nn.Sequential(
|
|
|
nn.Linear(64, 128),
|
|
|
nn.ReLU(),
|
|
|
nn.Linear(128, 64)
|
|
|
)
|
|
|
|
|
|
def forward(self, x):
|
|
|
return self.layers(x)
|
|
|
|
|
|
return QuantumApproximator()
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error creating quantum approximation: {str(e)}")
|
|
|
return None
|
|
|
|
|
|
def download_if_needed(self, model_name: str):
|
|
|
"""Download model files if they don't exist."""
|
|
|
try:
|
|
|
if not (self.fallback_dir / f"{model_name}.pt").exists():
|
|
|
if model_name.startswith('nlp_'):
|
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
model_id = "bert-base-uncased"
|
|
|
AutoModel.from_pretrained(model_id)
|
|
|
AutoTokenizer.from_pretrained(model_id)
|
|
|
elif model_name.startswith('vision_'):
|
|
|
import torchvision.models as models
|
|
|
models.resnet18(pretrained=True)
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error downloading model {model_name}: {str(e)}")
|
|
|
|
|
|
|
|
|
fallback_manager = FallbackModelManager()
|
|
|
|