#!/usr/bin/env python3
"""
SmolVLM2 Model Handler
Handles loading and inference with the SmolVLM2-256M-Video-Instruct model (the smallest SmolVLM2 variant, suitable for HuggingFace Spaces)
"""
import os
import tempfile
# Set cache directories to writable locations for HuggingFace Spaces
if 'HF_HOME' not in os.environ:
# Use /tmp which is guaranteed to be writable in containers
CACHE_DIR = os.path.join("/tmp", ".cache", "huggingface")
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(os.path.join("/tmp", ".cache", "torch"), exist_ok=True)
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
os.environ['TORCH_HOME'] = os.path.join("/tmp", ".cache", "torch")
os.environ['XDG_CACHE_HOME'] = os.path.join("/tmp", ".cache")
os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
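# NOTE: the cache settings above must be in place before torch/transformers are
# imported below, so the libraries pick up the writable /tmp cache paths.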
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import requests
from typing import List, Union, Optional
import logging
import warnings
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Suppress some warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
class SmolVLM2Handler:
"""Handler for SmolVLM2 model operations"""
def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", device: str = "auto"):
"""
        Initialize the SmolVLM2 model (256M Video-Instruct variant by default)
Args:
model_name: HuggingFace model identifier
device: Device to use ('auto', 'cpu', 'cuda', 'mps')
"""
self.model_name = model_name
self.device = self._get_device(device)
self.model = None
self.processor = None
logger.info(f"Initializing SmolVLM2 on device: {self.device}")
self._load_model()
def _get_device(self, device: str) -> str:
"""Determine the best device to use"""
if device == "auto":
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps" # Apple Silicon GPU
else:
return "cpu"
return device
def _load_model(self):
"""Load the model and processor"""
try:
logger.info("Loading processor...")
self.processor = AutoProcessor.from_pretrained(
self.model_name,
trust_remote_code=True
)
logger.info("Loading model...")
self.model = AutoModelForImageTextToText.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
trust_remote_code=True,
device_map=self.device if self.device != "cpu" else None
)
if self.device == "cpu":
self.model = self.model.to(self.device)
logger.info("✅ Model loaded successfully!")
except Exception as e:
logger.error(f"❌ Failed to load model: {e}")
raise
def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
"""
Process image input into PIL Image
Args:
image_input: File path, URL, or PIL Image
Returns:
PIL Image object
"""
if isinstance(image_input, str):
if image_input.startswith(('http://', 'https://')):
                # Download from URL with a single streamed request
                response = requests.get(image_input, stream=True, timeout=30)
                response.raise_for_status()
                image = Image.open(response.raw)
else:
# Load from file path
image = Image.open(image_input)
elif isinstance(image_input, Image.Image):
image = image_input
else:
raise ValueError("Image input must be file path, URL, or PIL Image")
# Convert to RGB if necessary
if image.mode != 'RGB':
image = image.convert('RGB')
return image
def generate_response(
self,
image_input: Union[str, Image.Image, List[Image.Image]],
text_prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
do_sample: bool = True
) -> str:
"""
Generate response from image(s) and text prompt
Args:
image_input: Single image or list of images
text_prompt: Text prompt/question
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
do_sample: Whether to use sampling
Returns:
Generated text response
"""
try:
# Process images
if isinstance(image_input, list):
images = [self.process_image(img) for img in image_input]
else:
images = [self.process_image(image_input)]
# Create proper conversation format for SmolVLM2
messages = [
{
"role": "user",
"content": [{"type": "text", "text": text_prompt}]
}
]
# Add image content to the message
for img in images:
messages[0]["content"].insert(0, {"type": "image", "image": img})
# Apply chat template
try:
prompt = self.processor.apply_chat_template(
messages,
add_generation_prompt=True
)
            except Exception as e:
                # Fall back to a plain prompt format if the chat template fails
                logger.warning(f"Chat template failed ({e}); falling back to plain prompt format")
                image_tokens = "<image>" * len(images)
                prompt = f"{image_tokens}{text_prompt}"
# Prepare inputs
inputs = self.processor(
images=images,
text=prompt,
return_tensors="pt"
).to(self.device)
# Generate response with robust parameters optimized for scoring
with torch.no_grad():
try:
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,  # use the caller-supplied temperature
                        do_sample=do_sample,  # sampling enables more varied responses
                        top_p=0.85,  # slightly lower top_p for more focused responses
                        top_k=40,  # top_k for additional control
                        repetition_penalty=1.2,  # discourage repetitive output
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                        use_cache=True
                    )
except RuntimeError as e:
if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
# Retry with more conservative parameters
logger.warning("Retrying with conservative parameters due to probability tensor error")
generated_ids = self.model.generate(
**inputs,
max_new_tokens=min(max_new_tokens, 256),
temperature=0.5, # Still some variety
do_sample=True,
top_p=0.9,
pad_token_id=self.processor.tokenizer.eos_token_id,
eos_token_id=self.processor.tokenizer.eos_token_id,
use_cache=True
)
else:
raise
# Decode only the new tokens (skip input)
input_length = inputs['input_ids'].shape[1]
new_tokens = generated_ids[0][input_length:]
generated_text = self.processor.tokenizer.decode(
new_tokens,
skip_special_tokens=True
).strip()
            # Return a fallback message if generation produced no text
if not generated_text:
return "I can see the image but cannot generate a specific description."
return generated_text
except Exception as e:
logger.error(f"❌ Error during generation: {e}")
raise
def analyze_video_frames(
self,
frames: List[Image.Image],
question: str,
max_frames: int = 8
) -> str:
"""
Analyze video frames and answer questions
Args:
frames: List of PIL Image frames
question: Question about the video
max_frames: Maximum number of frames to process
Returns:
Analysis result
"""
# Sample frames if too many
if len(frames) > max_frames:
step = len(frames) // max_frames
sampled_frames = frames[::step][:max_frames]
else:
sampled_frames = frames
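        # e.g. 30 frames with max_frames=8: step = 30 // 8 = 3, frames[::3][:8] -> 8 frames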
logger.info(f"Analyzing {len(sampled_frames)} frames")
# Create a simple prompt for video analysis (don't add image tokens manually)
video_prompt = f"These are frames from a video. {question}"
return self.generate_response(sampled_frames, video_prompt)
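    # Usage sketch (frames obtained elsewhere, e.g. from a video decoder):
    #   answer = handler.analyze_video_frames(frames, "What happens in this clip?")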
def get_model_info(self) -> dict:
"""Get information about the loaded model"""
return {
"model_name": self.model_name,
"device": self.device,
"model_type": type(self.model).__name__,
"processor_type": type(self.processor).__name__,
"loaded": self.model is not None and self.processor is not None
}
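# Typical usage (sketch; the file path and prompt below are illustrative):
#   handler = SmolVLM2Handler()
#   answer = handler.generate_response("photo.jpg", "What is in this image?")
#   print(handler.get_model_info())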
def test_model():
"""Test the model with a simple example"""
try:
# Initialize model
vlm = SmolVLM2Handler()
print("📋 Model Info:")
info = vlm.get_model_info()
for key, value in info.items():
print(f" {key}: {value}")
# Test with a simple image (create a test image)
test_image = Image.new('RGB', (224, 224), color='blue')
test_prompt = "What color is this image?"
print(f"\n🔍 Testing with prompt: '{test_prompt}'")
response = vlm.generate_response(test_image, test_prompt)
print(f"📝 Response: {response}")
print("\n✅ Model test completed successfully!")
except Exception as e:
print(f"❌ Model test failed: {e}")
raise
if __name__ == "__main__":
test_model()