# CV10/models/fallback/__init__.py
# (Hugging Face page chrome removed; provenance: Raiff1982, "Upload 7 files", commit f1e9fc1)
"""
Fallback model handler for Codette
Uses open source models as fallbacks when proprietary models are unavailable
"""
import os
from typing import Optional, Dict, Any
import json
import logging
from pathlib import Path
import numpy as np
# Configure logging
# NOTE(review): calling basicConfig at import time is unusual for a library
# module -- it installs a root handler and can override the host
# application's logging setup. Confirm this side effect is intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this package (standard practice).
logger = logging.getLogger(__name__)
class FallbackModelManager:
    """Load models from disk, substituting open source fallbacks on failure.

    Lookup order in :meth:`get_model`:
      1. the in-memory cache (``self.model_cache``),
      2. ``<models_dir>/<model_name>.pt`` via ``torch.load``,
      3. an open source fallback chosen by the model-name prefix
         (``nlp_`` -> BERT, ``vision_`` -> ResNet-18,
         ``quantum_`` -> a small classical MLP approximation).

    Returns ``None`` when no fallback is available for the name.
    """

    def __init__(self):
        # Directory containing this package; original .pt checkpoints live here.
        self.models_dir = Path(__file__).parent
        # Subdirectory intended for downloaded fallback weights.
        self.fallback_dir = self.models_dir / 'fallback'
        self.fallback_dir.mkdir(exist_ok=True)
        # Cache of already-loaded models, keyed by model name.
        self.model_cache: Dict[str, Any] = {}

    def get_model(self, model_name: str) -> Optional[Any]:
        """Get a model, falling back to open source alternatives if needed.

        Args:
            model_name: Logical model name; ``<model_name>.pt`` is looked up
                on disk, and the ``nlp_``/``vision_``/``quantum_`` prefix
                selects the fallback family.

        Returns:
            The loaded model (or a ``{'model', 'tokenizer'}`` dict for NLP
            fallbacks), or ``None`` if nothing could be loaded.
        """
        if model_name in self.model_cache:
            return self.model_cache[model_name]
        try:
            original_path = self.models_dir / f"{model_name}.pt"
            if original_path.exists():
                import torch
                # NOTE(review): torch.load unpickles arbitrary objects and can
                # execute code -- only load checkpoints from trusted sources
                # (consider weights_only=True on torch >= 1.13).
                model = torch.load(original_path)
                self.model_cache[model_name] = model
                return model
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
        # Original missing or failed to load -- try the open source fallback.
        model = self._load_fallback(model_name)
        if model is not None:
            # Cache fallbacks too; previously every call re-downloaded or
            # rebuilt the fallback model.
            self.model_cache[model_name] = model
        return model

    def _load_fallback(self, model_name: str) -> Optional[Any]:
        """Load an open source fallback model selected by name prefix."""
        try:
            if model_name.startswith('nlp_'):
                from transformers import AutoModel, AutoTokenizer
                model_id = "bert-base-uncased"  # Default fallback
                model = AutoModel.from_pretrained(model_id)
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                return {'model': model, 'tokenizer': tokenizer}
            elif model_name.startswith('vision_'):
                import torchvision.models as models
                # NOTE(review): pretrained=True is deprecated in newer
                # torchvision in favor of weights=ResNet18_Weights.DEFAULT.
                return models.resnet18(pretrained=True)
            elif model_name.startswith('quantum_'):
                # For quantum models, use a classical approximation.
                return self._create_quantum_approximation()
            else:
                logger.warning(f"No fallback available for {model_name}")
                return None
        except Exception as e:
            logger.error(f"Error loading fallback for {model_name}: {str(e)}")
            return None

    def _create_quantum_approximation(self) -> Any:
        """Create a classical (MLP) approximation of quantum operations.

        Returns a 64 -> 128 -> 64 torch module, or ``None`` on failure.
        """
        try:
            import torch
            import torch.nn as nn

            class QuantumApproximator(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.layers = nn.Sequential(
                        nn.Linear(64, 128),
                        nn.ReLU(),
                        nn.Linear(128, 64)
                    )

                def forward(self, x):
                    return self.layers(x)

            return QuantumApproximator()
        except Exception as e:
            logger.error(f"Error creating quantum approximation: {str(e)}")
            return None

    def download_if_needed(self, model_name: str):
        """Download fallback model files if they don't exist locally.

        Best-effort: errors are logged, never raised.

        NOTE(review): the existence check looks in ``self.fallback_dir`` but
        from_pretrained/resnet18 download into the HF/torchvision caches, so
        the check never becomes true -- confirm whether weights should be
        saved into ``fallback_dir`` after download.
        """
        try:
            if not (self.fallback_dir / f"{model_name}.pt").exists():
                if model_name.startswith('nlp_'):
                    from transformers import AutoModel, AutoTokenizer
                    model_id = "bert-base-uncased"
                    AutoModel.from_pretrained(model_id)
                    AutoTokenizer.from_pretrained(model_id)
                elif model_name.startswith('vision_'):
                    import torchvision.models as models
                    models.resnet18(pretrained=True)
        except Exception as e:
            logger.error(f"Error downloading model {model_name}: {str(e)}")
# Global instance
# Module-level singleton shared by importers. Note that constructing it
# creates the fallback/ directory as an import-time side effect.
fallback_manager = FallbackModelManager()