rahul7star committed
Commit 4d88196 · verified · 1 Parent(s): 35d853d

Update app.py

Files changed (1):
  1. app.py +17 -248

app.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
 from typing import Optional
 import tempfile
 import os
-import spaces
 
 MID = "apple/FastVLM-7B"
 IMAGE_TOKEN_INDEX = -200
@@ -19,15 +18,15 @@ model = None
 def load_model():
     global tok, model
     if tok is None or model is None:
-        print("Loading FastVLM model...")
+        print("Loading FastVLM model (CPU only)...")
         tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
             MID,
-            torch_dtype=torch.float16,
-            device_map="cuda",
+            torch_dtype=torch.float32,  # ✅ CPU-friendly dtype
+            device_map="cpu",  # ✅ Force CPU
             trust_remote_code=True,
         )
-        print("Model loaded successfully!")
+        print("Model loaded successfully on CPU!")
     return tok, model
 
 def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
@@ -40,19 +39,14 @@ def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str =
         return []
 
     frames = []
-
     if sampling_method == "uniform":
-        # Uniform sampling
         indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
     elif sampling_method == "first":
-        # Take first N frames
         indices = list(range(min(num_frames, total_frames)))
     elif sampling_method == "last":
-        # Take last N frames
         start = max(0, total_frames - num_frames)
         indices = list(range(start, total_frames))
     else:  # middle
-        # Take frames from the middle
         start = max(0, (total_frames - num_frames) // 2)
         indices = list(range(start, min(start + num_frames, total_frames)))
 
@@ -60,41 +54,32 @@ def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str =
         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
         ret, frame = cap.read()
         if ret:
-            # Convert BGR to RGB
             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frames.append(Image.fromarray(frame_rgb))
 
     cap.release()
     return frames
 
-@spaces.GPU(duration=60)
 def caption_frame(image: Image.Image, prompt: str) -> str:
-    """Generate caption for a single frame"""
-    # Load model on GPU
+    """Generate caption for a single frame (CPU only)"""
    tok, model = load_model()
-    # Build chat with custom prompt
-    messages = [
-        {"role": "user", "content": f"<image>\n{prompt}"}
-    ]
-    rendered = tok.apply_chat_template(
-        messages, add_generation_prompt=True, tokenize=False
-    )
+
+    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
+    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
     pre, post = rendered.split("<image>", 1)
-
-    # Tokenize the text around the image token
+
     pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
     post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
-
-    # Splice in the IMAGE token id
+
     img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
-    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
-    attention_mask = torch.ones_like(input_ids, device=model.device)
-
+    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)
+
+    attention_mask = torch.ones_like(input_ids)
+
     # Preprocess image
     px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
-    px = px.to(model.device, dtype=model.dtype)
-
-    # Generate
+
+    # Generate on CPU
     with torch.no_grad():
         out = model.generate(
             inputs=input_ids,
@@ -104,225 +89,9 @@ def caption_frame(image: Image.Image, prompt: str) -> str:
             temperature=0.7,
             do_sample=True,
         )
-
+
     caption = tok.decode(out[0], skip_special_tokens=True)
-    # Extract only the generated part
     if prompt in caption:
         caption = caption.split(prompt)[-1].strip()
 
     return caption
-
-def process_video(
-    video_path: str,
-    num_frames: int,
-    sampling_method: str,
-    caption_mode: str,
-    custom_prompt: str,
-    progress=gr.Progress()
-) -> tuple:
-    """Process video and generate captions"""
-
-    if not video_path:
-        return "Please upload a video first.", None
-
-    progress(0, desc="Extracting frames...")
-    frames = extract_frames(video_path, num_frames, sampling_method)
-
-    if not frames:
-        return "Failed to extract frames from video.", None
-
-    # Use brief one-sentence prompt for faster processing
-    prompt = "Provide a brief one-sentence description of what's happening in this image."
-
-    captions = []
-    frame_previews = []
-
-    for i, frame in enumerate(frames):
-        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
-        caption = caption_frame(frame, prompt)
-        captions.append(f"Frame {i + 1}: {caption}")
-        frame_previews.append(frame)
-
-    progress(1.0, desc="Generating summary...")
-
-    # Combine captions into a simple narrative
-    full_caption = "\n".join(captions)
-
-    # Generate overall summary if multiple frames
-    if len(frames) > 1:
-        video_summary = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
-    else:
-        video_summary = f"Video Analysis:\n\n{full_caption}"
-
-    return video_summary, frame_previews
-
-# Create the Gradio interface
-# Create custom Apple-inspired theme
-class AppleTheme(gr.themes.Base):
-    def __init__(self):
-        super().__init__(
-            primary_hue=gr.themes.colors.blue,
-            secondary_hue=gr.themes.colors.gray,
-            neutral_hue=gr.themes.colors.gray,
-            spacing_size=gr.themes.sizes.spacing_md,
-            radius_size=gr.themes.sizes.radius_md,
-            text_size=gr.themes.sizes.text_md,
-            font=[
-                gr.themes.GoogleFont("Inter"),
-                "-apple-system",
-                "BlinkMacSystemFont",
-                "SF Pro Display",
-                "SF Pro Text",
-                "Helvetica Neue",
-                "Helvetica",
-                "Arial",
-                "sans-serif"
-            ],
-            font_mono=[
-                gr.themes.GoogleFont("SF Mono"),
-                "ui-monospace",
-                "Consolas",
-                "monospace"
-            ]
-        )
-        super().set(
-            # Core colors
-            body_background_fill="*neutral_50",
-            body_background_fill_dark="*neutral_950",
-            button_primary_background_fill="*primary_500",
-            button_primary_background_fill_hover="*primary_600",
-            button_primary_text_color="white",
-            button_primary_border_color="*primary_500",
-
-            # Shadows
-            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
-
-            # Borders
-            block_border_width="1px",
-            block_border_color="*neutral_200",
-            input_border_width="1px",
-            input_border_color="*neutral_300",
-            input_border_color_focus="*primary_500",
-
-            # Text
-            block_title_text_weight="600",
-            block_label_text_weight="500",
-            block_label_text_size="13px",
-            block_label_text_color="*neutral_600",
-            body_text_color="*neutral_900",
-
-            # Spacing
-            layout_gap="16px",
-            block_padding="20px",
-
-            # Specific components
-            slider_color="*primary_500",
-        )
-
-# Create the Gradio interface with the custom theme
-with gr.Blocks(theme=AppleTheme()) as demo:
-    gr.Markdown("# 🎬 FastVLM Video Captioning")
-
-    with gr.Row():
-        # Main video display
-        with gr.Column(scale=7):
-            video_display = gr.Video(
-                label="Video Input",
-                autoplay=True,
-                loop=True
-            )
-
-        # Sidebar with chat interface
-        with gr.Sidebar(width=400):
-            gr.Markdown("## 💬 Video Analysis Chat")
-
-            chatbot = gr.Chatbot(
-                value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
-                height=400,
-                elem_classes=["chatbot"]
-            )
-
-            process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")
-
-            with gr.Accordion("🖼️ Analyzed Frames", open=False):
-                frame_gallery = gr.Gallery(
-                    label="Extracted Frames",
-                    show_label=False,
-                    columns=2,
-                    rows=4,
-                    object_fit="contain",
-                    height="auto"
-                )
-
-    # Hidden parameters with default values
-    num_frames = gr.State(value=8)
-    sampling_method = gr.State(value="uniform")
-    caption_mode = gr.State(value="Brief Summary")
-    custom_prompt = gr.State(value="")
-
-    # Upload handler
-    def handle_upload(video, chat_history):
-        if video:
-            chat_history.append(["User", "Video uploaded"])
-            chat_history.append(["Assistant", "Video loaded! Click 'Analyze Video' to generate captions."])
-            return video, chat_history
-        return None, chat_history
-
-    video_display.upload(
-        handle_upload,
-        inputs=[video_display, chatbot],
-        outputs=[video_display, chatbot]
-    )
-
-    # Modified process function to update chatbot with streaming
-    def process_video_with_chat(video_path, num_frames, sampling_method, caption_mode, custom_prompt, chat_history, progress=gr.Progress()):
-        if not video_path:
-            chat_history.append(["Assistant", "Please upload a video first."])
-            yield chat_history, None
-            return
-
-        chat_history.append(["User", "Analyzing video..."])
-        yield chat_history, None
-
-        # Extract frames
-        progress(0, desc="Extracting frames...")
-        frames = extract_frames(video_path, num_frames, sampling_method)
-
-        if not frames:
-            chat_history.append(["Assistant", "Failed to extract frames from video."])
-            yield chat_history, None
-            return
-
-        # Start streaming response
-        chat_history.append(["Assistant", ""])
-        prompt = "Provide a brief one-sentence description of what's happening in this image."
-
-        captions = []
-        for i, frame in enumerate(frames):
-            progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
-            caption = caption_frame(frame, prompt)
-            frame_caption = f"Frame {i + 1}: {caption}\n"
-            captions.append(frame_caption)
-
-            # Update the last message with accumulated captions
-            current_text = "".join(captions)
-            chat_history[-1] = ["Assistant", f"Analyzing {len(frames)} frames:\n\n{current_text}"]
-            yield chat_history, frames[:i+1]  # Also update frame gallery progressively
-
-        progress(1.0, desc="Analysis complete!")
-
-        # Final update with complete message
-        full_caption = "".join(captions)
-        final_message = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
-        chat_history[-1] = ["Assistant", final_message]
-        yield chat_history, frames
-
-    # Process button with streaming
-    process_btn.click(
-        process_video_with_chat,
-        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt, chatbot],
-        outputs=[chatbot, frame_gallery],
-        show_progress=True
-    )
-
-demo.launch()
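The net effect of the load_model() change is that the checkpoint is forced onto the CPU in float32 instead of float16 on CUDA. A minimal sanity check, as a sketch only (it assumes the module is importable as app, the 7B checkpoint has already been downloaded, and there is enough RAM to hold it in float32):

from app import load_model  # import path is an assumption, not part of the commit

tok, model = load_model()
param = next(model.parameters())
print(param.device)  # expected: cpu
print(param.dtype)   # expected: torch.float32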
 
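With process_video, the AppleTheme class, the gr.Blocks layout, and demo.launch() all removed in this commit, app.py now exposes only load_model, extract_frames, and caption_frame. A minimal driver sketch for the remaining CPU-only path (the video filename is a placeholder and CPU generation with a 7B model will be slow; this is an illustration, not part of the commit):

from app import extract_frames, caption_frame  # import path is an assumption

video_path = "sample.mp4"  # placeholder; point this at any local video file
frames = extract_frames(video_path, num_frames=4, sampling_method="uniform")

# Same brief prompt the removed Gradio handlers used
prompt = "Provide a brief one-sentence description of what's happening in this image."
for i, frame in enumerate(frames, start=1):
    print(f"Frame {i}: {caption_frame(frame, prompt)}")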