""" |
|
|
SmolVLM2 Model Handler |
|
|
Handles loading and inference with SmolVLM2-256M-Instruct model (smallest model for HuggingFace Spaces) |
|
|
""" |
|
|
|
|
|
import os
import tempfile

# Redirect HuggingFace / torch caches to /tmp, which is writable on HF Spaces.
# This must happen before `transformers` is imported so the variables take effect.
if 'HF_HOME' not in os.environ:
    CACHE_DIR = os.path.join("/tmp", ".cache", "huggingface")
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(os.path.join("/tmp", ".cache", "torch"), exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['TORCH_HOME'] = os.path.join("/tmp", ".cache", "torch")
    os.environ['XDG_CACHE_HOME'] = os.path.join("/tmp", ".cache")
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import requests
from typing import List, Union, Optional
import logging
import warnings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=UserWarning)


class SmolVLM2Handler:
    """Handler for SmolVLM2 model operations"""

    def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", device: str = "auto"):
        """
        Initialize the SmolVLM2 model (256M video-instruct variant by default)

        Args:
            model_name: HuggingFace model identifier
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.model = None
        self.processor = None

        logger.info(f"Initializing SmolVLM2 on device: {self.device}")
        self._load_model()

    def _get_device(self, device: str) -> str:
        """Determine the best device to use"""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _load_model(self):
        """Load the model and processor"""
        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            logger.info("Loading model...")
            # Half precision on GPU/MPS, full precision on CPU
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
                trust_remote_code=True,
                device_map=self.device if self.device != "cpu" else None
            )

            if self.device == "cpu":
                self.model = self.model.to(self.device)

            logger.info("✅ Model loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise

    def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
        """
        Process image input into PIL Image

        Args:
            image_input: File path, URL, or PIL Image

        Returns:
            PIL Image object
        """
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Download the image from a URL
                response = requests.get(image_input, stream=True)
                response.raise_for_status()
                image = Image.open(response.raw)
            else:
                # Load from a local file path
                image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Image input must be file path, URL, or PIL Image")

        # The model expects RGB images
        if image.mode != 'RGB':
            image = image.convert('RGB')

        return image
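
    # process_image accepts any of the following (the URL and path below are
    # illustrative placeholders, not tested endpoints):
    #
    #     handler.process_image("https://example.com/cat.jpg")    # remote URL
    #     handler.process_image("local/photo.png")                # local file path
    #     handler.process_image(Image.new("RGB", (64, 64)))       # PIL Image passthrough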
|
|
|
|
|
    def generate_response(
        self,
        image_input: Union[str, Image.Image, List[Image.Image]],
        text_prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True
    ) -> str:
        """
        Generate response from image(s) and text prompt

        Args:
            image_input: Single image or list of images
            text_prompt: Text prompt/question
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling

        Returns:
            Generated text response
        """
        try:
            if isinstance(image_input, list):
                images = [self.process_image(img) for img in image_input]
            else:
                images = [self.process_image(image_input)]

            # Build a chat-style message: images first (in order), then the text prompt
            messages = [
                {
                    "role": "user",
                    "content": (
                        [{"type": "image", "image": img} for img in images]
                        + [{"type": "text", "text": text_prompt}]
                    )
                }
            ]

            # Prefer the processor's chat template; fall back to a plain prompt
            try:
                prompt = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
            except Exception:
                image_tokens = "<image>" * len(images)
                prompt = f"{image_tokens}{text_prompt}"

            inputs = self.processor(
                images=images,
                text=prompt,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                try:
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        do_sample=do_sample,
                        top_p=0.85,
                        top_k=40,
                        repetition_penalty=1.2,
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                        use_cache=True
                    )
                except RuntimeError as e:
                    if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
                        # Retry once with more conservative sampling settings
                        logger.warning("Retrying with conservative parameters due to probability tensor error")
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=min(max_new_tokens, 256),
                            temperature=0.5,
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.processor.tokenizer.eos_token_id,
                            eos_token_id=self.processor.tokenizer.eos_token_id,
                            use_cache=True
                        )
                    else:
                        raise

            # Decode only the newly generated tokens (strip the prompt)
            input_length = inputs['input_ids'].shape[1]
            new_tokens = generated_ids[0][input_length:]

            generated_text = self.processor.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            ).strip()

            if not generated_text:
                return "I can see the image but cannot generate a specific description."

            return generated_text

        except Exception as e:
            logger.error(f"❌ Error during generation: {e}")
            raise
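
    # Example call (sketch; `vlm` is an already-constructed SmolVLM2Handler and
    # the image path is a placeholder):
    #
    #     caption = vlm.generate_response(
    #         "photo.jpg",
    #         "Describe this image in one sentence.",
    #         max_new_tokens=128,
    #         temperature=0.4,
    #     )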
|
|
|
|
|
    def analyze_video_frames(
        self,
        frames: List[Image.Image],
        question: str,
        max_frames: int = 8
    ) -> str:
        """
        Analyze video frames and answer questions

        Args:
            frames: List of PIL Image frames
            question: Question about the video
            max_frames: Maximum number of frames to process

        Returns:
            Analysis result
        """
        # Subsample evenly spaced frames if there are more than max_frames
        if len(frames) > max_frames:
            step = len(frames) // max_frames
            sampled_frames = frames[::step][:max_frames]
        else:
            sampled_frames = frames

        logger.info(f"Analyzing {len(sampled_frames)} frames")

        video_prompt = f"These are frames from a video. {question}"

        return self.generate_response(sampled_frames, video_prompt)
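
    # Example (sketch): frames would typically come from a separate video-decoding
    # step; here `frames` is assumed to be a list of PIL Images in temporal order.
    #
    #     summary = vlm.analyze_video_frames(frames, "What happens in this clip?", max_frames=8)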
|
|
|
|
|
    def get_model_info(self) -> dict:
        """Get information about the loaded model"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "model_type": type(self.model).__name__,
            "processor_type": type(self.processor).__name__,
            "loaded": self.model is not None and self.processor is not None
        }


def test_model():
    """Test the model with a simple example"""
    try:
        vlm = SmolVLM2Handler()

        print("📋 Model Info:")
        info = vlm.get_model_info()
        for key, value in info.items():
            print(f" {key}: {value}")

        test_image = Image.new('RGB', (224, 224), color='blue')
        test_prompt = "What color is this image?"

        print(f"\n🔍 Testing with prompt: '{test_prompt}'")
        response = vlm.generate_response(test_image, test_prompt)
        print(f"📝 Response: {response}")

        print("\n✅ Model test completed successfully!")

    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise


if __name__ == "__main__":
    test_model()