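"""Model configuration and loading utilities for an LLM demo on Hugging Face Spaces.

Defines ModelConfig (a catalogue of small, Spaces-friendly models) and
ModelManager (loading, inference, and cleanup around a transformers pipeline).
"""
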
import os
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

logger = logging.getLogger(__name__)

class ModelConfig:
    """Configuration for different LLM models optimized for Hugging Face Spaces"""

    MODELS = {
        "dialogpt-medium": {
            "name": "microsoft/DialoGPT-medium",
            "description": "Conversational AI model, good for chat",
            "max_length": 512,
            "memory_usage": "medium",
            "recommended_for": "chat, conversation",
        },
        "dialogpt-small": {
            "name": "microsoft/DialoGPT-small",
            "description": "Smaller conversational model, faster inference",
            "max_length": 256,
            "memory_usage": "low",
            "recommended_for": "quick responses, limited resources",
        },
        "gpt2": {
            "name": "gpt2",
            "description": "General purpose text generation",
            "max_length": 1024,
            "memory_usage": "medium",
            "recommended_for": "text generation, creative writing",
        },
        "distilgpt2": {
            "name": "distilgpt2",
            "description": "Distilled GPT-2, faster and smaller",
            "max_length": 512,
            "memory_usage": "low",
            "recommended_for": "fast inference, resource constrained",
        },
        "flan-t5-small": {
            "name": "google/flan-t5-small",
            "description": "Instruction-tuned T5 model",
            "max_length": 512,
            "memory_usage": "low",
            "recommended_for": "instruction following, Q&A",
        },
    }
    @classmethod
    def get_model_info(cls, model_key: str = None):
        """Get information about available models"""
        if model_key:
            return cls.MODELS.get(model_key)
        return cls.MODELS

    @classmethod
    def get_recommended_model(cls, use_case: str = "general"):
        """Get recommended model based on use case"""
        recommendations = {
            "chat": "dialogpt-medium",
            "fast": "distilgpt2",
            "general": "gpt2",
            "qa": "flan-t5-small",
            "low_memory": "dialogpt-small",
        }
        return recommendations.get(use_case, "dialogpt-medium")
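
# Illustrative usage of ModelConfig (values follow the mappings above):
#   ModelConfig.get_recommended_model("fast")   -> "distilgpt2"
#   ModelConfig.get_model_info("flan-t5-small") -> {"name": "google/flan-t5-small", ...}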

class ModelManager:
    """Manages model loading and inference"""

    def __init__(self, model_name: str = None):
        self.model_name = model_name or os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded = False
    def load_model(self):
        """Load the specified model"""
        try:
            logger.info(f"Loading model: {self.model_name}")
            logger.info(f"Using device: {self.device}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                padding_side="left"
            )

            # Add padding token if it doesn't exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with optimizations. Note: this manager assumes a causal LM;
            # seq2seq models such as flan-t5-small would need AutoModelForSeq2SeqLM
            # and a "text2text-generation" pipeline instead.
            model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
            }
            if self.device == "cuda":
                model_kwargs["device_map"] = "auto"

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **model_kwargs
            )

            # Move to device if not using device_map
            if self.device == "cpu":
                self.model = self.model.to(self.device)

            # Create pipeline. No explicit device is passed: with device_map="auto",
            # accelerate has already placed the model (recent transformers versions
            # reject an explicit device in that case), and on CPU the default is used.
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False
            )

            self.loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
    def generate_response(self,
                          prompt: str,
                          max_length: int = 100,
                          temperature: float = 0.7,
                          top_p: float = 0.9,
                          do_sample: bool = True) -> str:
        """Generate a response using the loaded model"""
        if not self.loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Generate response (max_length is passed through as max_new_tokens,
            # i.e. it limits the newly generated tokens, not the total length)
            outputs = self.pipeline(
                prompt,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                truncation=True
            )

            # Extract generated text
            if outputs and len(outputs) > 0:
                generated_text = outputs[0]["generated_text"]
                return generated_text.strip()
            else:
                return "Sorry, I couldn't generate a response."

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
    def get_model_info(self):
        """Get information about the loaded model"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "loaded": self.loaded,
            "tokenizer_vocab_size": len(self.tokenizer) if self.tokenizer else None,
            "model_parameters": sum(p.numel() for p in self.model.parameters()) if self.model else None,
        }
    def unload_model(self):
        """Unload the model to free memory"""
        if self.model:
            del self.model
            self.model = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        if self.pipeline:
            del self.pipeline
            self.pipeline = None

        # Clear CUDA cache if using GPU
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.loaded = False
        logger.info("Model unloaded successfully")

# Global model manager instance
model_manager = None


def get_model_manager(model_name: str = None) -> ModelManager:
    """Get or create the global model manager instance"""
    global model_manager
    if model_manager is None:
        model_manager = ModelManager(model_name)
    return model_manager


def initialize_model(model_name: str = None):
    """Initialize and load the model"""
    manager = get_model_manager(model_name)
    if not manager.loaded:
        manager.load_model()
    return manager
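
if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the Space's app code): loads the
    # default model, or MODEL_NAME from the environment, and prints one
    # completion. The prompt and generation settings below are illustrative.
    logging.basicConfig(level=logging.INFO)
    manager = initialize_model()
    print(manager.get_model_info())
    print(manager.generate_response("Hello, how are you?", max_length=50))
    manager.unload_model()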