#!/usr/bin/env python3
"""
SmolVLM2 Model Handler
Handles loading and inference with the SmolVLM2-2.2B-Instruct model
"""

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import requests
from io import BytesIO
from typing import List, Union, Optional
import logging
import warnings

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress some warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

class SmolVLM2Handler:
    """Handler for SmolVLM2 model operations"""
    
    def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-2.2B-Instruct", device: str = "auto"):
        """
        Initialize SmolVLM2 model
        
        Args:
            model_name: HuggingFace model identifier
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.model = None
        self.processor = None
        
        logger.info(f"Initializing SmolVLM2 on device: {self.device}")
        self._load_model()
    
    def _get_device(self, device: str) -> str:
        """Determine the best device to use"""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"  # Apple Silicon GPU
            else:
                return "cpu"
        return device
    
    def _load_model(self):
        """Load the model and processor"""
        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            
            logger.info("Loading model...")
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
                trust_remote_code=True,
                device_map=self.device if self.device != "cpu" else None
            )
            
            if self.device == "cpu":
                self.model = self.model.to(self.device)
            
            logger.info("✅ Model loaded successfully!")
            
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise
    
    def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
        """
        Process image input into PIL Image
        
        Args:
            image_input: File path, URL, or PIL Image
            
        Returns:
            PIL Image object
        """
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Download from URL (single request, with basic error handling)
                response = requests.get(image_input, timeout=30)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content))
            else:
                # Load from file path
                image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Image input must be file path, URL, or PIL Image")
        
        # Convert to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
            
        return image
    
    def generate_response(
        self,
        image_input: Union[str, Image.Image, List[Image.Image]], 
        text_prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True
    ) -> str:
        """
        Generate response from image(s) and text prompt
        
        Args:
            image_input: Single image or list of images
            text_prompt: Text prompt/question
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling
            
        Returns:
            Generated text response
        """
        try:
            # Process images
            if isinstance(image_input, list):
                images = [self.process_image(img) for img in image_input]
            else:
                images = [self.process_image(image_input)]
            
            # Create proper conversation format for SmolVLM2
            messages = [
                {
                    "role": "user", 
                    "content": [{"type": "text", "text": text_prompt}]
                }
            ]
            
            # Prepend image entries ahead of the text, preserving frame order
            for img in reversed(images):
                messages[0]["content"].insert(0, {"type": "image", "image": img})
            
            # Apply chat template
            try:
                prompt = self.processor.apply_chat_template(
                    messages, 
                    add_generation_prompt=True
                )
            except Exception as e:
                # Fall back to a simple format if the chat template fails
                logger.warning(f"Chat template failed ({e}); using plain <image> prompt")
                image_tokens = "<image>" * len(images)
                prompt = f"{image_tokens}{text_prompt}"
            
            # Prepare inputs
            inputs = self.processor(
                images=images,
                text=prompt,
                return_tensors="pt"
            ).to(self.device)
            
            # Generate response with robust parameters
            with torch.no_grad():
                try:
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=max(0.1, min(temperature, 1.0)),  # Clamp temperature
                        do_sample=do_sample,
                        top_p=0.9,
                        repetition_penalty=1.1,
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                        use_cache=True
                    )
                except RuntimeError as e:
                    if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
                        # Retry with more conservative parameters
                        logger.warning("Retrying with conservative parameters due to probability tensor error")
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=min(max_new_tokens, 256),
                            temperature=0.3,
                            do_sample=False,  # Use greedy decoding
                            pad_token_id=self.processor.tokenizer.eos_token_id,
                            eos_token_id=self.processor.tokenizer.eos_token_id,
                            use_cache=True
                        )
                    else:
                        raise
            
            # Decode only the new tokens (skip input)
            input_length = inputs['input_ids'].shape[1]
            new_tokens = generated_ids[0][input_length:]
            
            generated_text = self.processor.tokenizer.decode(
                new_tokens, 
                skip_special_tokens=True
            ).strip()
            
            # Return meaningful response even if empty
            if not generated_text:
                return "I can see the image but cannot generate a specific description."
            
            return generated_text
            
        except Exception as e:
            logger.error(f"❌ Error during generation: {e}")
            raise
    
    def analyze_video_frames(
        self,
        frames: List[Image.Image],
        question: str,
        max_frames: int = 8
    ) -> str:
        """
        Analyze video frames and answer questions
        
        Args:
            frames: List of PIL Image frames
            question: Question about the video
            max_frames: Maximum number of frames to process
            
        Returns:
            Analysis result
        """
        # Sample frames if too many
        if len(frames) > max_frames:
            # Pick evenly spaced indices so sampling covers the whole clip
            step = len(frames) / max_frames
            sampled_frames = [frames[int(i * step)] for i in range(max_frames)]
        else:
            sampled_frames = frames
        
        logger.info(f"Analyzing {len(sampled_frames)} frames")
        
        # Create a simple prompt for video analysis (don't add image tokens manually)
        video_prompt = f"These are frames from a video. {question}"
        
        return self.generate_response(sampled_frames, video_prompt)
    
    def get_model_info(self) -> dict:
        """Get information about the loaded model"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "model_type": type(self.model).__name__,
            "processor_type": type(self.processor).__name__,
            "loaded": self.model is not None and self.processor is not None
        }
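
# The helper below is a minimal sketch, not part of the handler itself: it shows
# one way to produce the List[Image.Image] that analyze_video_frames() expects.
# The function name, the every_n parameter, and the OpenCV dependency
# (pip install opencv-python) are illustrative assumptions, not project requirements.
def extract_video_frames(video_path: str, every_n: int = 30) -> List[Image.Image]:
    """Decode a video file into PIL frames, keeping every n-th frame."""
    import cv2  # local import so OpenCV stays an optional dependency

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if idx % every_n == 0:
                # OpenCV decodes to BGR; PIL expects RGB
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            idx += 1
    finally:
        cap.release()
    return frames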

def test_model():
    """Test the model with a simple example"""
    try:
        # Initialize model
        vlm = SmolVLM2Handler()
        
        print("📋 Model Info:")
        info = vlm.get_model_info()
        for key, value in info.items():
            print(f"  {key}: {value}")
        
        # Test with a simple image (create a test image)
        test_image = Image.new('RGB', (224, 224), color='blue')
        test_prompt = "What color is this image?"
        
        print(f"\nπŸ” Testing with prompt: '{test_prompt}'")
        response = vlm.generate_response(test_image, test_prompt)
        print(f"πŸ“ Response: {response}")
        
        print("\nβœ… Model test completed successfully!")
        
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise

if __name__ == "__main__":
    test_model()