#!/usr/bin/env python3
"""
SmolVLM2 Model Handler
Handles loading and inference with the SmolVLM2-256M-Video-Instruct model (the smallest variant, suitable for HuggingFace Spaces)
"""

import os
import tempfile

# Set cache directories to writable locations for HuggingFace Spaces
if 'HF_HOME' not in os.environ:
    # Use /tmp which is guaranteed to be writable in containers
    CACHE_DIR = os.path.join("/tmp", ".cache", "huggingface")
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(os.path.join("/tmp", ".cache", "torch"), exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['TORCH_HOME'] = os.path.join("/tmp", ".cache", "torch")
    os.environ['XDG_CACHE_HOME'] = os.path.join("/tmp", ".cache")
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
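    # Note: recent transformers releases prefer HF_HOME and may warn that
    # TRANSFORMERS_CACHE is deprecated; both are set here for broad compatibility.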

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import requests
from typing import List, Union, Optional
import logging
import warnings

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress some warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

class SmolVLM2Handler:
    """Handler for SmolVLM2 model operations"""
    
    def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", device: str = "auto"):
        """
        Initialize the SmolVLM2 model (defaults to the 256M video-instruct variant)
        
        Args:
            model_name: HuggingFace model identifier
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.model = None
        self.processor = None
        
        logger.info(f"Initializing SmolVLM2 on device: {self.device}")
        self._load_model()
    
    def _get_device(self, device: str) -> str:
        """Determine the best device to use"""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"  # Apple Silicon GPU
            else:
                return "cpu"
        return device
    
    def _load_model(self):
        """Load the model and processor"""
        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            
            logger.info("Loading model...")
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
                trust_remote_code=True,
                device_map=self.device if self.device != "cpu" else None
            )
            
            if self.device == "cpu":
                self.model = self.model.to(self.device)
            
            logger.info("βœ… Model loaded successfully!")
            
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise
    
    def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
        """
        Process image input into PIL Image
        
        Args:
            image_input: File path, URL, or PIL Image
            
        Returns:
            PIL Image object
        """
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Download from URL with a single streamed request
                response = requests.get(image_input, stream=True, timeout=30)
                response.raise_for_status()
                image = Image.open(response.raw)
            else:
                # Load from file path
                image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Image input must be file path, URL, or PIL Image")
        
        # Convert to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
            
        return image
    
    def generate_response(
        self,
        image_input: Union[str, Image.Image, List[Image.Image]], 
        text_prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True
    ) -> str:
        """
        Generate response from image(s) and text prompt
        
        Args:
            image_input: Single image or list of images
            text_prompt: Text prompt/question
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling
            
        Returns:
            Generated text response
        """
        try:
            # Process images
            if isinstance(image_input, list):
                images = [self.process_image(img) for img in image_input]
            else:
                images = [self.process_image(image_input)]
            
            # Create proper conversation format for SmolVLM2
            messages = [
                {
                    "role": "user", 
                    "content": [{"type": "text", "text": text_prompt}]
                }
            ]
            
            # Add image content to the message
            for img in images:
                messages[0]["content"].insert(0, {"type": "image", "image": img})
            
            # Apply chat template
            try:
                prompt = self.processor.apply_chat_template(
                    messages, 
                    add_generation_prompt=True
                )
            except Exception:
                # Fallback to simple format if chat template fails
                image_tokens = "<image>" * len(images)
                prompt = f"{image_tokens}{text_prompt}"
            
            # Prepare inputs
            inputs = self.processor(
                images=images,
                text=prompt,
                return_tensors="pt"
            ).to(self.device)
            
            # Generate with sampling parameters tuned for varied but stable output
            with torch.no_grad():
                try:
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        do_sample=do_sample,
                        top_p=0.85,              # slightly lower top_p keeps responses focused
                        top_k=40,                # top_k for additional control
                        repetition_penalty=1.2,  # discourage repeated phrases
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                        use_cache=True
                    )
                except RuntimeError as e:
                    if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
                        # Retry with more conservative parameters
                        logger.warning("Retrying with conservative parameters due to probability tensor error")
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=min(max_new_tokens, 256),
                            temperature=0.5,  # Still some variety
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.processor.tokenizer.eos_token_id,
                            eos_token_id=self.processor.tokenizer.eos_token_id,
                            use_cache=True
                        )
                    else:
                        raise
            
            # Decode only the new tokens (skip input)
            input_length = inputs['input_ids'].shape[1]
            new_tokens = generated_ids[0][input_length:]
            
            generated_text = self.processor.tokenizer.decode(
                new_tokens, 
                skip_special_tokens=True
            ).strip()
            
            # Fall back to a generic message if generation produced no text
            if not generated_text:
                return "I can see the image but cannot generate a specific description."
            
            return generated_text
            
        except Exception as e:
            logger.error(f"❌ Error during generation: {e}")
            raise
    
    def analyze_video_frames(
        self,
        frames: List[Image.Image],
        question: str,
        max_frames: int = 8
    ) -> str:
        """
        Analyze video frames and answer questions
        
        Args:
            frames: List of PIL Image frames
            question: Question about the video
            max_frames: Maximum number of frames to process
            
        Returns:
            Analysis result
        """
        # Sample frames if too many
        if len(frames) > max_frames:
            step = len(frames) // max_frames
            sampled_frames = frames[::step][:max_frames]
        else:
            sampled_frames = frames
        
        logger.info(f"Analyzing {len(sampled_frames)} frames")
        
        # Create a simple prompt for video analysis (don't add image tokens manually)
        video_prompt = f"These are frames from a video. {question}"
        
        return self.generate_response(sampled_frames, video_prompt)
    
    def get_model_info(self) -> dict:
        """Get information about the loaded model"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "model_type": type(self.model).__name__,
            "processor_type": type(self.processor).__name__,
            "loaded": self.model is not None and self.processor is not None
        }

def test_model():
    """Test the model with a simple example"""
    try:
        # Initialize model
        vlm = SmolVLM2Handler()
        
        print("πŸ“‹ Model Info:")
        info = vlm.get_model_info()
        for key, value in info.items():
            print(f"  {key}: {value}")
        
        # Test with a simple image (create a test image)
        test_image = Image.new('RGB', (224, 224), color='blue')
        test_prompt = "What color is this image?"
        
        print(f"\nπŸ” Testing with prompt: '{test_prompt}'")
        response = vlm.generate_response(test_image, test_prompt)
        print(f"πŸ“ Response: {response}")
        
        print("\nβœ… Model test completed successfully!")
        
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise
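
# Usage sketch (assumption: opencv-python is installed; it is not required by the
# rest of this module). Shows one way to turn a video file into PIL frames and
# feed them to SmolVLM2Handler.analyze_video_frames(); the video path and default
# question are illustrative placeholders.
def example_video_usage(video_path: str, question: str = "What happens in this video?") -> str:
    """Sample frames from a video with OpenCV and ask SmolVLM2 about them."""
    import cv2  # assumed optional dependency

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV yields BGR ndarrays; convert to RGB PIL Images for the handler
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()

    # analyze_video_frames() subsamples to max_frames internally
    vlm = SmolVLM2Handler()
    return vlm.analyze_video_frames(frames, question)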

if __name__ == "__main__":
    test_model()