avinashHuggingface108 committed on
Commit a4bd75a Β· 1 Parent(s): 8a9a9e9

πŸš€ Deploy SmolVLM2 Video Highlights API


- FastAPI server with background processing
- SmolVLM2 + Whisper AI integration
- Docker deployment configuration
- Complete video highlights generation system
- REST API for mobile app integration

Dockerfile ADDED
@@ -0,0 +1,33 @@
+ # Use Python 3.9 slim image
+ FROM python:3.9-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     ffmpeg \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better caching
+ COPY requirements.txt fastapi_requirements.txt ./
+ RUN pip install --no-cache-dir -r requirements.txt && \
+     pip install --no-cache-dir -r fastapi_requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create necessary directories
+ RUN mkdir -p outputs temp samples src
+
+ # Expose port
+ EXPOSE 7860
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+ ENV GRADIO_SERVER_NAME=0.0.0.0
+ ENV GRADIO_SERVER_PORT=7860
+
+ # Run the FastAPI app
+ CMD ["python", "-m", "uvicorn", "highlights_api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,81 @@
 ---
- title: Smolvlm2 Video Highlights
- emoji: πŸ‘€
- colorFrom: red
- colorTo: gray
+ title: SmolVLM2 Video Highlights
+ emoji: 🎬
+ colorFrom: blue
+ colorTo: purple
 sdk: docker
 pinned: false
- short_description: video highlights
+ license: apache-2.0
+ app_port: 7860
 ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🎬 SmolVLM2 Video Highlights API
+
+ **Generate intelligent video highlights using SmolVLM2 + Whisper AI**
+
+ This is a FastAPI service that combines visual analysis (SmolVLM2) with audio transcription (Whisper) to automatically create highlight videos from longer content.
+
+ ## πŸš€ Features
+
+ - **Visual Analysis**: SmolVLM2-2.2B-Instruct analyzes video frames for interesting content
+ - **Audio Processing**: Whisper transcribes speech in 99+ languages
+ - **Smart Scoring**: Combines visual and audio analysis for intelligent highlights
+ - **REST API**: Upload videos and download processed highlights
+ - **Background Processing**: Non-blocking video processing with job tracking
+
+ ## πŸ”— API Endpoints
+
+ - `POST /upload-video` - Upload a video for processing
+ - `GET /job-status/{job_id}` - Check processing status
+ - `GET /download/{filename}` - Download generated highlights
+ - `GET /docs` - Interactive API documentation
+
+ ## πŸ“± Usage
+
+ ### Via API
+ ```bash
+ # Upload video
+ curl -X POST -F "video=@your_video.mp4" https://avinashhuggingface108-smolvlm2-video-highlights.hf.space/upload-video
+
+ # Check status
+ curl https://avinashhuggingface108-smolvlm2-video-highlights.hf.space/job-status/YOUR_JOB_ID
+
+ # Download highlights
+ curl -O https://avinashhuggingface108-smolvlm2-video-highlights.hf.space/download/FILENAME.mp4
+ ```
+
+ ### Via Android App
+ Use the provided Android client code to integrate with your mobile app.
+
+ ## βš™οΈ Configuration
+
+ Default settings:
+ - **Interval**: 20 seconds (analyze every 20s)
+ - **Min Score**: 6.5 (quality threshold)
+ - **Max Highlights**: 3 (maximum highlight segments)
+ - **Whisper Model**: base (balances accuracy and speed)
+ - **Timeout**: 35 seconds per segment
+
+ ## πŸ› οΈ Technology Stack
+
+ - **SmolVLM2-2.2B-Instruct**: Vision-language model for visual content analysis
+ - **OpenAI Whisper**: Speech-to-text in 99+ languages
+ - **FastAPI**: Modern web framework for APIs
+ - **FFmpeg**: Video processing and manipulation
+ - **PyTorch**: Deep learning framework with MPS acceleration
+
+ ## 🎯 Perfect For
+
+ - Social media content creators
+ - Educational video processing
+ - Meeting/lecture summarization
+ - Sports highlight generation
+ - Entertainment content curation
+
+ ## License
+
+ Apache 2.0 - Free for commercial and personal use
+
+ ## 🀝 Contributing
+
+ Built with ❀️ using Hugging Face Transformers and open-source AI models.
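The upload β†’ poll β†’ download flow shown with curl above can also be scripted. The following is a minimal Python client sketch, not part of the commit; it assumes the `requests` package, a local `your_video.mp4`, and the response fields (`job_id`, `status`, `progress`, `highlights_url`) defined in `highlights_api.py` below.

```python
import time
import requests

BASE_URL = "https://avinashhuggingface108-smolvlm2-video-highlights.hf.space"

# Upload a local video; the server replies with a job_id (AnalysisResponse)
with open("your_video.mp4", "rb") as f:
    resp = requests.post(f"{BASE_URL}/upload-video",
                         files={"video": ("your_video.mp4", f, "video/mp4")})
resp.raise_for_status()
job_id = resp.json()["job_id"]

# Poll the job until it reaches a terminal state
while True:
    status = requests.get(f"{BASE_URL}/job-status/{job_id}").json()
    print(f"{status['status']}: {status.get('progress', 0)}% - {status.get('message', '')}")
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(10)

# Download the generated highlights via the URL returned by the API
if status["status"] == "completed" and status.get("highlights_url"):
    video_bytes = requests.get(BASE_URL + status["highlights_url"]).content
    with open("highlights.mp4", "wb") as out:
        out.write(video_bytes)
```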
audio_enhanced_highlights_final.py ADDED
@@ -0,0 +1,564 @@
+ #!/usr/bin/env python3
+ """
+ Audio-Enhanced Video Highlights Generator
+ Combines SmolVLM2 visual analysis with Whisper audio transcription
+ Supports 99+ languages including Telugu, Hindi, English
+ """
+
+ import os
+ import sys
+ import cv2
+ import argparse
+ import json
+ import subprocess
+ import threading
+ import time
+ import tempfile
+ from pathlib import Path
+ from PIL import Image
+ from typing import List, Dict, Optional
+ import logging
+
+ # Add src directory to path for imports
+ sys.path.append(str(Path(__file__).parent / "src"))
+
+ try:
+     from src.smolvlm2_handler import SmolVLM2Handler
+ except ImportError:
+     print("❌ SmolVLM2Handler not found. Make sure to install dependencies first.")
+     sys.exit(1)
+
+ try:
+     import whisper
+     WHISPER_AVAILABLE = True
+     print("βœ… Whisper available for audio transcription")
+ except ImportError:
+     WHISPER_AVAILABLE = False
+     print("❌ Whisper not available. Install with: pip install openai-whisper")
+     sys.exit(1)
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class AudioVisualAnalyzer:
+     """Comprehensive analyzer combining visual and audio analysis"""
+
+     def __init__(self, whisper_model_size="base", timeout_seconds=30):
+         """Initialize with SmolVLM2 and Whisper models"""
+         print("πŸ”§ Initializing Audio-Visual Analyzer...")
+
+         # Initialize SmolVLM2 for visual analysis
+         self.vlm_handler = SmolVLM2Handler()
+         self.timeout_seconds = timeout_seconds
+
+         # Initialize Whisper for audio analysis
+         if WHISPER_AVAILABLE:
+             print(f"πŸ“₯ Loading Whisper model ({whisper_model_size})...")
+             self.whisper_model = whisper.load_model(whisper_model_size)
+             print("βœ… Whisper model loaded successfully")
+         else:
+             self.whisper_model = None
+             print("⚠️ Whisper not available - audio analysis disabled")
+
+     def extract_audio_segments(self, video_path: str, segments: List[Dict]) -> List[str]:
+         """Extract audio for specific video segments"""
+         audio_files = []
+         temp_dir = tempfile.mkdtemp()
+
+         for i, segment in enumerate(segments):
+             start_time = segment['start_time']
+             duration = segment['duration']
+
+             audio_path = os.path.join(temp_dir, f"segment_{i}.wav")
+
+             # Extract audio segment using FFmpeg
+             cmd = [
+                 'ffmpeg', '-i', video_path,
+                 '-ss', str(start_time),
+                 '-t', str(duration),
+                 '-vn',                    # No video
+                 '-acodec', 'pcm_s16le',   # Uncompressed audio
+                 '-ar', '16000',           # 16kHz sample rate for Whisper
+                 '-ac', '1',               # Mono
+                 '-f', 'wav',              # Force WAV format
+                 '-y',                     # Overwrite
+                 audio_path
+             ]
+
+             try:
+                 result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+                 if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
+                     audio_files.append(audio_path)
+                     logger.info(f"πŸ“„ Extracted audio segment {i+1}: {duration:.1f}s")
+                 else:
+                     logger.warning(f"⚠️ Audio segment {i+1} is empty or missing")
+                     audio_files.append(None)
+             except subprocess.CalledProcessError as e:
+                 logger.warning(f"⚠️ No audio stream in segment {i+1} (this is normal for silent videos)")
+                 audio_files.append(None)
+
+         return audio_files
+
+     def transcribe_audio_segment(self, audio_path: str) -> Dict:
+         """Transcribe audio segment with Whisper"""
+         if not WHISPER_AVAILABLE or not audio_path or not os.path.exists(audio_path):
+             return {"text": "", "language": "unknown", "confidence": 0.0}
+
+         try:
+             result = self.whisper_model.transcribe(
+                 audio_path,
+                 language=None,  # Auto-detect language
+                 task="transcribe"
+             )
+
+             return {
+                 "text": result.get("text", "").strip(),
+                 "language": result.get("language", "unknown"),
+                 "confidence": 1.0  # Whisper doesn't provide confidence scores
+             }
+         except Exception as e:
+             logger.error(f"❌ Audio transcription failed: {e}")
+             return {"text": "", "language": "unknown", "confidence": 0.0}
+
+     def analyze_visual_content(self, frame_path: str) -> Dict:
+         """Analyze visual content using SmolVLM2 with robust error handling"""
+         max_retries = 2
+         retry_count = 0
+
+         while retry_count < max_retries:
+             try:
+                 def generate_with_timeout():
+                     prompt = ("Analyze this video frame for interesting, engaging, or highlight-worthy content. "
+                               "Rate the excitement/interest level from 1-10 and explain what makes it noteworthy. "
+                               "Focus on action, emotion, important moments, or visually striking elements.")
+                     return self.vlm_handler.generate_response(frame_path, prompt)
+
+                 # Run with timeout protection
+                 thread_result = [None]
+                 exception_result = [None]
+
+                 def target():
+                     try:
+                         thread_result[0] = generate_with_timeout()
+                     except Exception as e:
+                         exception_result[0] = e
+
+                 thread = threading.Thread(target=target)
+                 thread.daemon = True
+                 thread.start()
+                 thread.join(self.timeout_seconds)
+
+                 if thread.is_alive():
+                     logger.warning(f"⏰ Visual analysis timed out after {self.timeout_seconds}s (attempt {retry_count + 1})")
+                     retry_count += 1
+                     if retry_count >= max_retries:
+                         return {"description": "Analysis timed out after multiple attempts", "score": 6.0}
+                     continue
+
+                 if exception_result[0]:
+                     error_msg = str(exception_result[0])
+                     if "probability tensor" in error_msg or "inf" in error_msg or "nan" in error_msg:
+                         logger.warning(f"⚠️ Model inference error, retrying (attempt {retry_count + 1}): {error_msg}")
+                         retry_count += 1
+                         if retry_count >= max_retries:
+                             return {"description": "Model inference failed after retries", "score": 6.0}
+                         continue
+                     else:
+                         raise exception_result[0]
+
+                 response = thread_result[0]
+                 if not response or len(response.strip()) == 0:
+                     logger.warning(f"⚠️ Empty response, retrying (attempt {retry_count + 1})")
+                     retry_count += 1
+                     if retry_count >= max_retries:
+                         return {"description": "No meaningful response after retries", "score": 6.0}
+                     continue
+
+                 # Extract score from response
+                 score = self.extract_score_from_text(response)
+                 return {"description": response, "score": score}
+
+             except Exception as e:
+                 error_msg = str(e)
+                 logger.warning(f"⚠️ Visual analysis error (attempt {retry_count + 1}): {error_msg}")
+                 retry_count += 1
+                 if retry_count >= max_retries:
+                     return {"description": f"Analysis failed after {max_retries} attempts: {error_msg}", "score": 6.0}
+
+         # Fallback if all retries failed
+         return {"description": "Analysis failed after all retry attempts", "score": 6.0}
+
+     def extract_score_from_text(self, text: str) -> float:
+         """Extract numeric score from analysis text"""
+         import re
+
+         # Look for patterns like "8/10", "score: 7", "rate: 6.5", etc.
+         patterns = [
+             r'(\d+(?:\.\d+)?)\s*/\s*10',                              # "8/10" or "7.5/10"
+             r'(?:score|rating|rate)(?:\s*[:=]\s*)(\d+(?:\.\d+)?)',    # "score: 8" or "rating=7.5"
+             r'(\d+(?:\.\d+)?)\s*(?:out of|/)\s*10',                   # "8 out of 10"
+             r'(?:^|\s)(\d+(?:\.\d+)?)(?:\s*[/]\s*10)?(?:\s|$)',       # Just numbers
+         ]
+
+         for pattern in patterns:
+             matches = re.findall(pattern, text.lower())
+             if matches:
+                 try:
+                     score = float(matches[0])
+                     return min(max(score, 1.0), 10.0)  # Clamp between 1-10
+                 except ValueError:
+                     continue
+
+         return 6.0  # Default score if no pattern found
+
+     def calculate_combined_score(self, visual_score: float, audio_text: str, audio_lang: str) -> float:
+         """Calculate combined score from visual and audio analysis"""
+         # Start with visual score
+         combined_score = visual_score
+
+         # Audio content scoring
+         if audio_text:
+             audio_bonus = 0.0
+             text_lower = audio_text.lower()
+
+             # Positive indicators
+             excitement_words = ['amazing', 'incredible', 'wow', 'fantastic', 'awesome', 'perfect', 'excellent']
+             action_words = ['goal', 'win', 'victory', 'success', 'breakthrough', 'achievement']
+             emotion_words = ['happy', 'excited', 'thrilled', 'surprised', 'shocked', 'love']
+
+             # Telugu positive indicators (basic)
+             telugu_positive = ['అద్భుతం', 'చాలా బాగుంది', 'డాడ్', 'సూపర్']
+
+             # Count positive indicators
+             for word_list in [excitement_words, action_words, emotion_words, telugu_positive]:
+                 for word in word_list:
+                     if word in text_lower:
+                         audio_bonus += 0.5
+
+             # Length bonus for substantial content
+             if len(audio_text) > 50:
+                 audio_bonus += 0.3
+             elif len(audio_text) > 20:
+                 audio_bonus += 0.1
+
+             # Language diversity bonus
+             if audio_lang in ['te', 'telugu']:    # Telugu content
+                 audio_bonus += 0.2
+             elif audio_lang in ['hi', 'hindi']:   # Hindi content
+                 audio_bonus += 0.2
+
+             combined_score += audio_bonus
+
+         # Clamp final score
+         return min(max(combined_score, 1.0), 10.0)
+
+     def analyze_segment(self, video_path: str, segment: Dict, temp_frame_path: str) -> Dict:
+         """Analyze a single video segment with both visual and audio"""
+         start_time = segment['start_time']
+         duration = segment['duration']
+
+         logger.info(f"πŸ” Analyzing segment at {start_time:.1f}s ({duration:.1f}s duration)")
+
+         # Visual analysis
+         visual_analysis = self.analyze_visual_content(temp_frame_path)
+
+         # Audio analysis
+         audio_files = self.extract_audio_segments(video_path, [segment])
+         audio_analysis = {"text": "", "language": "unknown", "confidence": 0.0}
+
+         if audio_files and audio_files[0]:
+             audio_analysis = self.transcribe_audio_segment(audio_files[0])
+             # Cleanup temporary audio file
+             try:
+                 os.unlink(audio_files[0])
+             except:
+                 pass
+
+         # Combined scoring
+         combined_score = self.calculate_combined_score(
+             visual_analysis['score'],
+             audio_analysis['text'],
+             audio_analysis['language']
+         )
+
+         return {
+             'start_time': start_time,
+             'duration': duration,
+             'visual_score': visual_analysis['score'],
+             'visual_description': visual_analysis['description'],
+             'audio_text': audio_analysis['text'],
+             'audio_language': audio_analysis['language'],
+             'combined_score': combined_score,
+             'selected': False
+         }
+
+ def extract_frames_at_intervals(video_path: str, interval_seconds: float = 10.0) -> List[Dict]:
+     """Extract frames at regular intervals from video"""
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Cannot open video file: {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     duration = total_frames / fps
+
+     logger.info(f"πŸ“Ή Video: {duration:.1f}s, {fps:.1f} FPS, {total_frames} frames")
+
+     segments = []
+     current_time = 0
+
+     while current_time < duration:
+         segment_duration = min(interval_seconds, duration - current_time)
+         segments.append({
+             'start_time': current_time,
+             'duration': segment_duration,
+             'frame_number': int(current_time * fps)
+         })
+         current_time += interval_seconds
+
+     cap.release()
+     return segments
+
+ def save_frame_at_time(video_path: str, time_seconds: float, output_path: str) -> bool:
+     """Save a frame at specific time with robust frame extraction"""
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return False
+
+     try:
+         fps = cap.get(cv2.CAP_PROP_FPS)
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         frame_number = int(time_seconds * fps)
+
+         # Ensure frame number is within valid range
+         frame_number = min(frame_number, total_frames - 1)
+         frame_number = max(frame_number, 0)
+
+         # Try to extract frame with fallback options
+         for attempt in range(3):
+             try:
+                 # Try exact frame first
+                 test_frame = frame_number + attempt
+                 if test_frame >= total_frames:
+                     test_frame = frame_number - attempt
+                 if test_frame < 0:
+                     test_frame = frame_number
+
+                 cap.set(cv2.CAP_PROP_POS_FRAMES, test_frame)
+                 ret, frame = cap.read()
+
+                 if ret and frame is not None and frame.size > 0:
+                     # Validate frame data
+                     if len(frame.shape) == 3 and frame.shape[2] == 3:  # Valid color frame
+                         success = cv2.imwrite(output_path, frame)
+                         if success:
+                             cap.release()
+                             return True
+
+             except Exception as e:
+                 logger.warning(f"Frame extraction attempt {attempt + 1} failed: {e}")
+                 continue
+
+         cap.release()
+         return False
+
+     except Exception as e:
+         logger.error(f"Critical error in frame extraction: {e}")
+         cap.release()
+         return False
+
+ def create_highlights_video(video_path: str, selected_segments: List[Dict], output_path: str):
+     """Create highlights video from selected segments"""
+     if not selected_segments:
+         logger.error("❌ No segments selected for highlights")
+         return False
+
+     # Create temporary files for each segment
+     temp_files = []
+     temp_dir = tempfile.mkdtemp()
+
+     for i, segment in enumerate(selected_segments):
+         temp_file = os.path.join(temp_dir, f"segment_{i}.mp4")
+
+         cmd = [
+             'ffmpeg', '-i', video_path,
+             '-ss', str(segment['start_time']),
+             '-t', str(segment['duration']),
+             '-c', 'copy',  # Copy streams without re-encoding
+             '-y', temp_file
+         ]
+
+         try:
+             subprocess.run(cmd, check=True, capture_output=True)
+             temp_files.append(temp_file)
+             logger.info(f"βœ… Created segment {i+1}/{len(selected_segments)}")
+         except subprocess.CalledProcessError as e:
+             logger.error(f"❌ Failed to create segment {i+1}: {e}")
+             continue
+
+     if not temp_files:
+         logger.error("❌ No valid segments created")
+         return False
+
+     # Create concat file
+     concat_file = os.path.join(temp_dir, "concat.txt")
+     with open(concat_file, 'w') as f:
+         for temp_file in temp_files:
+             f.write(f"file '{temp_file}'\n")
+
+     # Concatenate segments
+     cmd = [
+         'ffmpeg', '-f', 'concat', '-safe', '0',
+         '-i', concat_file,
+         '-c', 'copy',
+         '-y', output_path
+     ]
+
+     try:
+         subprocess.run(cmd, check=True, capture_output=True)
+         logger.info(f"βœ… Highlights video created: {output_path}")
+
+         # Cleanup
+         for temp_file in temp_files:
+             try:
+                 os.unlink(temp_file)
+             except:
+                 pass
+         try:
+             os.unlink(concat_file)
+             os.rmdir(temp_dir)
+         except:
+             pass
+
+         return True
+     except subprocess.CalledProcessError as e:
+         logger.error(f"❌ Failed to create highlights video: {e}")
+         return False
+
+ def main():
+     parser = argparse.ArgumentParser(description="Audio-Enhanced Video Highlights Generator")
+     parser.add_argument("video_path", help="Path to input video file")
+     parser.add_argument("--output", "-o", default="audio_enhanced_highlights.mp4",
+                         help="Output highlights video path")
+     parser.add_argument("--interval", "-i", type=float, default=10.0,
+                         help="Analysis interval in seconds (default: 10.0)")
+     parser.add_argument("--min-score", "-s", type=float, default=7.0,
+                         help="Minimum score for highlights (default: 7.0)")
+     parser.add_argument("--max-highlights", "-m", type=int, default=5,
+                         help="Maximum number of highlights (default: 5)")
+     parser.add_argument("--whisper-model", "-w", default="base",
+                         choices=["tiny", "base", "small", "medium", "large"],
+                         help="Whisper model size (default: base)")
+     parser.add_argument("--timeout", "-t", type=int, default=30,
+                         help="Timeout for each analysis in seconds (default: 30)")
+     parser.add_argument("--save-analysis", action="store_true",
+                         help="Save detailed analysis to JSON file")
+
+     args = parser.parse_args()
+
+     # Validate input
+     if not os.path.exists(args.video_path):
+         print(f"❌ Video file not found: {args.video_path}")
+         sys.exit(1)
+
+     print("🎬 Audio-Enhanced Video Highlights Generator")
+     print(f"πŸ“ Input: {args.video_path}")
+     print(f"πŸ“ Output: {args.output}")
+     print(f"⏱️ Analysis interval: {args.interval}s")
+     print(f"🎯 Minimum score: {args.min_score}")
+     print(f"πŸ† Max highlights: {args.max_highlights}")
+     print(f"πŸŽ™οΈ Whisper model: {args.whisper_model}")
+     print()
+
+     try:
+         # Initialize analyzer
+         analyzer = AudioVisualAnalyzer(
+             whisper_model_size=args.whisper_model,
+             timeout_seconds=args.timeout
+         )
+
+         # Extract segments for analysis
+         segments = extract_frames_at_intervals(args.video_path, args.interval)
+         print(f"πŸ“Š Analyzing {len(segments)} segments...")
+
+         analyzed_segments = []
+         temp_frame_path = "temp_frame.jpg"
+
+         for i, segment in enumerate(segments):
+             print(f"\nπŸ” Segment {i+1}/{len(segments)} (t={segment['start_time']:.1f}s)")
+
+             # Save frame for visual analysis
+             if save_frame_at_time(args.video_path, segment['start_time'], temp_frame_path):
+                 # Analyze segment
+                 analysis = analyzer.analyze_segment(args.video_path, segment, temp_frame_path)
+                 analyzed_segments.append(analysis)
+
+                 print(f"   πŸ‘οΈ Visual: {analysis['visual_score']:.1f}/10")
+                 print(f"   πŸŽ™οΈ Audio: '{analysis['audio_text'][:50]}...' ({analysis['audio_language']})")
+                 print(f"   🎯 Combined: {analysis['combined_score']:.1f}/10")
+             else:
+                 print(f"   ❌ Failed to extract frame")
+
+         # Cleanup temp frame
+         try:
+             os.unlink(temp_frame_path)
+         except:
+             pass
+
+         if not analyzed_segments:
+             print("❌ No segments analyzed successfully")
+             sys.exit(1)
+
+         # Select best segments
+         analyzed_segments.sort(key=lambda x: x['combined_score'], reverse=True)
+         selected_segments = [s for s in analyzed_segments if s['combined_score'] >= args.min_score]
+         selected_segments = selected_segments[:args.max_highlights]
+
+         print(f"\nπŸ† Selected {len(selected_segments)} highlights:")
+         for i, segment in enumerate(selected_segments):
+             print(f"{i+1}. t={segment['start_time']:.1f}s, score={segment['combined_score']:.1f}")
+             if segment['audio_text']:
+                 print(f"   Audio: \"{segment['audio_text'][:100]}...\"")
+
+         if not selected_segments:
+             print(f"❌ No segments met minimum score of {args.min_score}")
+             sys.exit(1)
+
+         # Create highlights video
+         print(f"\n🎬 Creating highlights video...")
+         success = create_highlights_video(args.video_path, selected_segments, args.output)
+
+         if success:
+             print(f"βœ… Audio-enhanced highlights created: {args.output}")
+
+             # Save analysis if requested
+             if args.save_analysis:
+                 analysis_file = args.output.replace('.mp4', '_analysis.json')
+                 with open(analysis_file, 'w') as f:
+                     json.dump({
+                         'input_video': args.video_path,
+                         'output_video': args.output,
+                         'settings': {
+                             'interval': args.interval,
+                             'min_score': args.min_score,
+                             'max_highlights': args.max_highlights,
+                             'whisper_model': args.whisper_model,
+                             'timeout': args.timeout
+                         },
+                         'segments': analyzed_segments,
+                         'selected_segments': selected_segments
+                     }, f, indent=2)
+                 print(f"πŸ“Š Analysis saved: {analysis_file}")
+         else:
+             print("❌ Failed to create highlights video")
+             sys.exit(1)
+
+     except KeyboardInterrupt:
+         print("\n⏹️ Operation cancelled by user")
+         sys.exit(1)
+     except Exception as e:
+         print(f"❌ Error: {e}")
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     main()
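The selection step in `main()` above is easy to see in isolation: segments are sorted by `combined_score`, anything below `--min-score` is dropped, and at most `--max-highlights` survive. A self-contained sketch with dummy scores (the dict keys mirror what `AudioVisualAnalyzer.analyze_segment()` returns):

```python
# Standalone sketch of the highlight-selection step, using dummy analysis results.
analyzed_segments = [
    {"start_time": 0.0,  "duration": 10.0, "combined_score": 5.8},
    {"start_time": 10.0, "duration": 10.0, "combined_score": 8.4},
    {"start_time": 20.0, "duration": 10.0, "combined_score": 7.1},
    {"start_time": 30.0, "duration": 10.0, "combined_score": 9.0},
]
min_score = 7.0      # --min-score
max_highlights = 2   # --max-highlights

# Sort best-first, drop anything below the threshold, keep at most max_highlights
analyzed_segments.sort(key=lambda s: s["combined_score"], reverse=True)
selected = [s for s in analyzed_segments if s["combined_score"] >= min_score][:max_highlights]

for s in selected:
    print(f"t={s['start_time']:.1f}s score={s['combined_score']:.1f}")
# -> t=30.0s score=9.0
# -> t=10.0s score=8.4
```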
fastapi_requirements.txt ADDED
@@ -0,0 +1,9 @@
+ # FastAPI Dependencies for SmolVLM2 Video Highlights API
+ # Add these to your existing requirements.txt
+
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ python-multipart==0.0.6
+ pydantic==2.5.0
+ python-jose[cryptography]==3.3.0
+ passlib[bcrypt]==1.7.4
highlights_api.py ADDED
@@ -0,0 +1,345 @@
+ #!/usr/bin/env python3
+ """
+ FastAPI Wrapper for Audio-Enhanced Video Highlights
+ Converts your SmolVLM2 + Whisper system into a web API for Android apps
+ """
+
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
+ from fastapi.responses import FileResponse, JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import os
+ import sys
+ import tempfile
+ import uuid
+ import json
+ import asyncio
+ from pathlib import Path
+ from typing import Optional
+ import logging
+
+ # Add src directory to path for imports
+ sys.path.append(str(Path(__file__).parent / "src"))
+
+ try:
+     from audio_enhanced_highlights_final import AudioVisualAnalyzer, extract_frames_at_intervals, save_frame_at_time, create_highlights_video
+ except ImportError:
+     print("❌ Cannot import audio_enhanced_highlights_final.py")
+     sys.exit(1)
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # FastAPI app
+ app = FastAPI(
+     title="SmolVLM2 Video Highlights API",
+     description="Generate intelligent video highlights using SmolVLM2 + Whisper",
+     version="1.0.0"
+ )
+
+ # Enable CORS for Android apps
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, specify your Android app's domain
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Request/Response models
+ class AnalysisRequest(BaseModel):
+     interval: float = 20.0
+     min_score: float = 6.5
+     max_highlights: int = 3
+     whisper_model: str = "base"
+     timeout: int = 35
+
+ class AnalysisResponse(BaseModel):
+     job_id: str
+     status: str
+     message: str
+
+ class JobStatus(BaseModel):
+     job_id: str
+     status: str    # "processing", "completed", "failed"
+     progress: int  # 0-100
+     message: str
+     highlights_url: Optional[str] = None
+     analysis_url: Optional[str] = None
+
+ # Global storage for jobs (in production, use Redis/database)
+ active_jobs = {}
+ completed_jobs = {}
+
+ # Create output directories
+ os.makedirs("outputs", exist_ok=True)
+ os.makedirs("temp", exist_ok=True)
+
+ @app.get("/")
+ async def root():
+     return {
+         "message": "SmolVLM2 Video Highlights API",
+         "version": "1.0.0",
+         "endpoints": {
+             "upload": "/upload-video",
+             "status": "/job-status/{job_id}",
+             "download": "/download/{filename}"
+         }
+     }
+
+ @app.post("/upload-video", response_model=AnalysisResponse)
+ async def upload_video(
+     background_tasks: BackgroundTasks,
+     video: UploadFile = File(...),
+     interval: float = 20.0,
+     min_score: float = 6.5,
+     max_highlights: int = 3,
+     whisper_model: str = "base",
+     timeout: int = 35
+ ):
+     """
+     Upload a video and start processing highlights
+     """
+     # Validate file
+     if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
+         raise HTTPException(status_code=400, detail="Only video files are supported")
+
+     # Generate unique job ID
+     job_id = str(uuid.uuid4())
+
+     try:
+         # Save uploaded video
+         temp_video_path = f"temp/{job_id}_{video.filename}"
+         with open(temp_video_path, "wb") as f:
+             content = await video.read()
+             f.write(content)
+
+         # Store job info
+         active_jobs[job_id] = {
+             "status": "processing",
+             "progress": 0,
+             "message": "Video uploaded, starting analysis...",
+             "video_path": temp_video_path,
+             "settings": {
+                 "interval": interval,
+                 "min_score": min_score,
+                 "max_highlights": max_highlights,
+                 "whisper_model": whisper_model,
+                 "timeout": timeout
+             }
+         }
+
+         # Start processing in background
+         background_tasks.add_task(
+             process_video_highlights,
+             job_id,
+             temp_video_path,
+             interval,
+             min_score,
+             max_highlights,
+             whisper_model,
+             timeout
+         )
+
+         return AnalysisResponse(
+             job_id=job_id,
+             status="processing",
+             message="Video uploaded successfully. Processing started."
+         )
+
+     except Exception as e:
+         logger.error(f"Upload failed: {e}")
+         raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
+
+ @app.get("/job-status/{job_id}", response_model=JobStatus)
+ async def get_job_status(job_id: str):
+     """
+     Get the status of a processing job
+     """
+     # Check active jobs
+     if job_id in active_jobs:
+         job = active_jobs[job_id]
+         return JobStatus(
+             job_id=job_id,
+             status=job["status"],
+             progress=job["progress"],
+             message=job["message"]
+         )
+
+     # Check completed jobs
+     if job_id in completed_jobs:
+         job = completed_jobs[job_id]
+         return JobStatus(
+             job_id=job_id,
+             status=job["status"],
+             progress=100,
+             message=job["message"],
+             highlights_url=job.get("highlights_url"),
+             analysis_url=job.get("analysis_url")
+         )
+
+     raise HTTPException(status_code=404, detail="Job not found")
+
+ @app.get("/download/{filename}")
+ async def download_file(filename: str):
+     """
+     Download generated files
+     """
+     file_path = f"outputs/{filename}"
+     if not os.path.exists(file_path):
+         raise HTTPException(status_code=404, detail="File not found")
+
+     return FileResponse(
+         file_path,
+         media_type='application/octet-stream',
+         filename=filename
+     )
+
+ async def process_video_highlights(
+     job_id: str,
+     video_path: str,
+     interval: float,
+     min_score: float,
+     max_highlights: int,
+     whisper_model: str,
+     timeout: int
+ ):
+     """
+     Background task to process video highlights
+     """
+     try:
+         # Update status
+         active_jobs[job_id]["progress"] = 10
+         active_jobs[job_id]["message"] = "Initializing AI models..."
+
+         # Initialize analyzer
+         analyzer = AudioVisualAnalyzer(
+             whisper_model_size=whisper_model,
+             timeout_seconds=timeout
+         )
+
+         active_jobs[job_id]["progress"] = 20
+         active_jobs[job_id]["message"] = "Extracting video segments..."
+
+         # Extract segments
+         segments = extract_frames_at_intervals(video_path, interval)
+         total_segments = len(segments)
+
+         active_jobs[job_id]["progress"] = 30
+         active_jobs[job_id]["message"] = f"Analyzing {total_segments} segments..."
+
+         # Analyze segments
+         analyzed_segments = []
+         temp_frame_path = f"temp/{job_id}_frame.jpg"
+
+         for i, segment in enumerate(segments):
+             # Update progress
+             progress = 30 + int((i / total_segments) * 50)  # 30-80%
+             active_jobs[job_id]["progress"] = progress
+             active_jobs[job_id]["message"] = f"Analyzing segment {i+1}/{total_segments}"
+
+             # Save frame for visual analysis
+             if save_frame_at_time(video_path, segment['start_time'], temp_frame_path):
+                 # Analyze segment
+                 analysis = analyzer.analyze_segment(video_path, segment, temp_frame_path)
+                 analyzed_segments.append(analysis)
+
+         # Cleanup temp frame
+         try:
+             os.unlink(temp_frame_path)
+         except:
+             pass
+
+         active_jobs[job_id]["progress"] = 85
+         active_jobs[job_id]["message"] = "Selecting best highlights..."
+
+         # Select best segments
+         analyzed_segments.sort(key=lambda x: x['combined_score'], reverse=True)
+         selected_segments = [s for s in analyzed_segments if s['combined_score'] >= min_score]
+         selected_segments = selected_segments[:max_highlights]
+
+         if not selected_segments:
+             raise Exception(f"No segments met minimum score of {min_score}")
+
+         active_jobs[job_id]["progress"] = 90
+         active_jobs[job_id]["message"] = f"Creating highlights video with {len(selected_segments)} segments..."
+
+         # Create output filenames
+         highlights_filename = f"{job_id}_highlights.mp4"
+         analysis_filename = f"{job_id}_analysis.json"
+         highlights_path = f"outputs/{highlights_filename}"
+         analysis_path = f"outputs/{analysis_filename}"
+
+         # Create highlights video
+         success = create_highlights_video(video_path, selected_segments, highlights_path)
+
+         if not success:
+             raise Exception("Failed to create highlights video")
+
+         # Save analysis
+         analysis_data = {
+             'job_id': job_id,
+             'input_video': video_path,
+             'output_video': highlights_path,
+             'settings': {
+                 'interval': interval,
+                 'min_score': min_score,
+                 'max_highlights': max_highlights,
+                 'whisper_model': whisper_model,
+                 'timeout': timeout
+             },
+             'segments': analyzed_segments,
+             'selected_segments': selected_segments,
+             'summary': {
+                 'total_segments': len(analyzed_segments),
+                 'selected_segments': len(selected_segments),
+                 'processing_time': "Completed successfully"
+             }
+         }
+
+         with open(analysis_path, 'w') as f:
+             json.dump(analysis_data, f, indent=2)
+
+         # Mark as completed
+         completed_jobs[job_id] = {
+             "status": "completed",
+             "message": f"Successfully created highlights with {len(selected_segments)} segments",
+             "highlights_url": f"/download/{highlights_filename}",
+             "analysis_url": f"/download/{analysis_filename}",
+             "summary": analysis_data['summary']
+         }
+
+         # Remove from active jobs
+         del active_jobs[job_id]
+
+         # Cleanup temp video
+         try:
+             os.unlink(video_path)
+         except:
+             pass
+
+     except Exception as e:
+         logger.error(f"Processing failed for job {job_id}: {e}")
+
+         # Mark as failed
+         completed_jobs[job_id] = {
+             "status": "failed",
+             "message": f"Processing failed: {str(e)}",
+             "highlights_url": None,
+             "analysis_url": None
+         }
+
+         # Remove from active jobs
+         if job_id in active_jobs:
+             del active_jobs[job_id]
+
+         # Cleanup
+         try:
+             os.unlink(video_path)
+         except:
+             pass
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
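A quick way to smoke-test these routes locally is FastAPI's `TestClient`. The sketch below is illustrative only and assumes the full dependency set is installed, since importing `highlights_api` also imports the analyzer module (Whisper, SmolVLM2 handler) at load time.

```python
# Local smoke test for the API routes defined above.
from fastapi.testclient import TestClient

from highlights_api import app

client = TestClient(app)

# The root endpoint lists the available routes
resp = client.get("/")
assert resp.status_code == 200
print(resp.json()["endpoints"])

# Unknown job IDs return 404 from /job-status/{job_id}
resp = client.get("/job-status/does-not-exist")
assert resp.status_code == 404
```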
requirements.txt ADDED
@@ -0,0 +1,32 @@
+ # SmolVLM2 Video Testing Requirements
+
+ # Core ML and Vision Libraries
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.40.0
+ accelerate>=0.27.0
+ pillow>=10.0.0
+
+ # Video Processing
+ opencv-python>=4.8.0
+ imageio>=2.31.0
+ imageio-ffmpeg>=0.4.9
+
+ # Audio Transcription (required by audio_enhanced_highlights_final.py)
+ openai-whisper
+
+ # Hugging Face Integration
+ huggingface-hub>=0.20.0
+ datasets>=2.16.0
+
+ # Utilities
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ tqdm>=4.65.0
+ requests>=2.31.0
+
+ # Development Tools
+ jupyter>=1.0.0
+ ipykernel>=6.25.0
+ black>=23.0.0
+ flake8>=6.0.0
+
+ # Optional: For better performance on Apple Silicon
+ # Install with: pip install --upgrade --force-reinstall --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
src/smolvlm2_handler.py ADDED
@@ -0,0 +1,281 @@
+ #!/usr/bin/env python3
+ """
+ SmolVLM2 Model Handler
+ Handles loading and inference with the SmolVLM2-2.2B-Instruct model
+ """
+
+ import torch
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+ from PIL import Image
+ import requests
+ from typing import List, Union, Optional
+ import logging
+ import warnings
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Suppress some warnings for cleaner output
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ class SmolVLM2Handler:
+     """Handler for SmolVLM2 model operations"""
+
+     def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-2.2B-Instruct", device: str = "auto"):
+         """
+         Initialize SmolVLM2 model
+
+         Args:
+             model_name: HuggingFace model identifier
+             device: Device to use ('auto', 'cpu', 'cuda', 'mps')
+         """
+         self.model_name = model_name
+         self.device = self._get_device(device)
+         self.model = None
+         self.processor = None
+
+         logger.info(f"Initializing SmolVLM2 on device: {self.device}")
+         self._load_model()
+
+     def _get_device(self, device: str) -> str:
+         """Determine the best device to use"""
+         if device == "auto":
+             if torch.cuda.is_available():
+                 return "cuda"
+             elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                 return "mps"  # Apple Silicon GPU
+             else:
+                 return "cpu"
+         return device
+
+     def _load_model(self):
+         """Load the model and processor"""
+         try:
+             logger.info("Loading processor...")
+             self.processor = AutoProcessor.from_pretrained(
+                 self.model_name,
+                 trust_remote_code=True
+             )
+
+             logger.info("Loading model...")
+             self.model = AutoModelForImageTextToText.from_pretrained(
+                 self.model_name,
+                 torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+                 trust_remote_code=True,
+                 device_map=self.device if self.device != "cpu" else None
+             )
+
+             if self.device == "cpu":
+                 self.model = self.model.to(self.device)
+
+             logger.info("βœ… Model loaded successfully!")
+
+         except Exception as e:
+             logger.error(f"❌ Failed to load model: {e}")
+             raise
+
+     def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
+         """
+         Process image input into PIL Image
+
+         Args:
+             image_input: File path, URL, or PIL Image
+
+         Returns:
+             PIL Image object
+         """
+         if isinstance(image_input, str):
+             if image_input.startswith(('http://', 'https://')):
+                 # Download from URL
+                 image = Image.open(requests.get(image_input, stream=True).raw)
+             else:
+                 # Load from file path
+                 image = Image.open(image_input)
+         elif isinstance(image_input, Image.Image):
+             image = image_input
+         else:
+             raise ValueError("Image input must be file path, URL, or PIL Image")
+
+         # Convert to RGB if necessary
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         return image
+
+     def generate_response(
+         self,
+         image_input: Union[str, Image.Image, List[Image.Image]],
+         text_prompt: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         do_sample: bool = True
+     ) -> str:
+         """
+         Generate response from image(s) and text prompt
+
+         Args:
+             image_input: Single image or list of images
+             text_prompt: Text prompt/question
+             max_new_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             do_sample: Whether to use sampling
+
+         Returns:
+             Generated text response
+         """
+         try:
+             # Process images
+             if isinstance(image_input, list):
+                 images = [self.process_image(img) for img in image_input]
+             else:
+                 images = [self.process_image(image_input)]
+
+             # Create proper conversation format for SmolVLM2
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [{"type": "text", "text": text_prompt}]
+                 }
+             ]
+
+             # Add image content to the message
+             for img in images:
+                 messages[0]["content"].insert(0, {"type": "image", "image": img})
+
+             # Apply chat template
+             try:
+                 prompt = self.processor.apply_chat_template(
+                     messages,
+                     add_generation_prompt=True
+                 )
+             except Exception:
+                 # Fallback to simple format if chat template fails
+                 image_tokens = "<image>" * len(images)
+                 prompt = f"{image_tokens}{text_prompt}"
+
+             # Prepare inputs
+             inputs = self.processor(
+                 images=images,
+                 text=prompt,
+                 return_tensors="pt"
+             ).to(self.device)
+
+             # Generate response with robust parameters
+             with torch.no_grad():
+                 try:
+                     generated_ids = self.model.generate(
+                         **inputs,
+                         max_new_tokens=max_new_tokens,
+                         temperature=max(0.1, min(temperature, 1.0)),  # Clamp temperature
+                         do_sample=do_sample,
+                         top_p=0.9,
+                         repetition_penalty=1.1,
+                         pad_token_id=self.processor.tokenizer.eos_token_id,
+                         eos_token_id=self.processor.tokenizer.eos_token_id,
+                         use_cache=True
+                     )
+                 except RuntimeError as e:
+                     if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
+                         # Retry with more conservative parameters
+                         logger.warning("Retrying with conservative parameters due to probability tensor error")
+                         generated_ids = self.model.generate(
+                             **inputs,
+                             max_new_tokens=min(max_new_tokens, 256),
+                             temperature=0.3,
+                             do_sample=False,  # Use greedy decoding
+                             pad_token_id=self.processor.tokenizer.eos_token_id,
+                             eos_token_id=self.processor.tokenizer.eos_token_id,
+                             use_cache=True
+                         )
+                     else:
+                         raise
+
+             # Decode only the new tokens (skip input)
+             input_length = inputs['input_ids'].shape[1]
+             new_tokens = generated_ids[0][input_length:]
+
+             generated_text = self.processor.tokenizer.decode(
+                 new_tokens,
+                 skip_special_tokens=True
+             ).strip()
+
+             # Return meaningful response even if empty
+             if not generated_text:
+                 return "I can see the image but cannot generate a specific description."
+
+             return generated_text
+
+         except Exception as e:
+             logger.error(f"❌ Error during generation: {e}")
+             raise
+
+     def analyze_video_frames(
+         self,
+         frames: List[Image.Image],
+         question: str,
+         max_frames: int = 8
+     ) -> str:
+         """
+         Analyze video frames and answer questions
+
+         Args:
+             frames: List of PIL Image frames
+             question: Question about the video
+             max_frames: Maximum number of frames to process
+
+         Returns:
+             Analysis result
+         """
+         # Sample frames if too many
+         if len(frames) > max_frames:
+             step = len(frames) // max_frames
+             sampled_frames = frames[::step][:max_frames]
+         else:
+             sampled_frames = frames
+
+         logger.info(f"Analyzing {len(sampled_frames)} frames")
+
+         # Create a simple prompt for video analysis (don't add image tokens manually)
+         video_prompt = f"These are frames from a video. {question}"
+
+         return self.generate_response(sampled_frames, video_prompt)
+
+     def get_model_info(self) -> dict:
+         """Get information about the loaded model"""
+         return {
+             "model_name": self.model_name,
+             "device": self.device,
+             "model_type": type(self.model).__name__,
+             "processor_type": type(self.processor).__name__,
+             "loaded": self.model is not None and self.processor is not None
+         }
+
+ def test_model():
+     """Test the model with a simple example"""
+     try:
+         # Initialize model
+         vlm = SmolVLM2Handler()
+
+         print("πŸ“‹ Model Info:")
+         info = vlm.get_model_info()
+         for key, value in info.items():
+             print(f"   {key}: {value}")
+
+         # Test with a simple image (create a test image)
+         test_image = Image.new('RGB', (224, 224), color='blue')
+         test_prompt = "What color is this image?"
+
+         print(f"\nπŸ” Testing with prompt: '{test_prompt}'")
+         response = vlm.generate_response(test_image, test_prompt)
+         print(f"πŸ“ Response: {response}")
+
+         print("\nβœ… Model test completed successfully!")
+
+     except Exception as e:
+         print(f"❌ Model test failed: {e}")
+         raise
+
+ if __name__ == "__main__":
+     test_model()
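For reference, this is roughly how the pipeline above drives the handler on a single frame. The sketch is illustrative, not part of the commit; it assumes a local `sample.mp4` (placeholder path) and that the model weights can be downloaded from the Hugging Face Hub on first use.

```python
# Sketch: analyze one video frame with SmolVLM2Handler, mirroring the highlights pipeline.
import cv2
from PIL import Image

from src.smolvlm2_handler import SmolVLM2Handler

handler = SmolVLM2Handler()  # defaults to HuggingFaceTB/SmolVLM2-2.2B-Instruct

cap = cv2.VideoCapture("sample.mp4")
ret, frame = cap.read()
cap.release()

if ret:
    # OpenCV returns BGR arrays; convert to an RGB PIL image for the processor
    pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    answer = handler.generate_response(pil_frame, "Describe this frame in one sentence.")
    print(answer)
```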