rahul7star committed
Commit 920c71d · verified · 1 Parent(s): 4d88196

Update app.py

Files changed (1)
  1. app.py +90 -12
app.py CHANGED
@@ -4,8 +4,6 @@ from PIL import Image
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import cv2
 import numpy as np
-from typing import Optional
-import tempfile
 import os

 MID = "apple/FastVLM-7B"
@@ -15,6 +13,7 @@ IMAGE_TOKEN_INDEX = -200
 tok = None
 model = None

+# ---------------- Load Model ----------------
 def load_model():
     global tok, model
     if tok is None or model is None:
@@ -29,15 +28,16 @@ def load_model():
         print("Model loaded successfully on CPU!")
     return tok, model

+
+# ---------------- Frame Extraction ----------------
 def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
-    """Extract frames from video"""
     cap = cv2.VideoCapture(video_path)
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
+
     if total_frames == 0:
         cap.release()
         return []
-
+
     frames = []
     if sampling_method == "uniform":
         indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
@@ -49,19 +49,20 @@ def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str =
     else: # middle
         start = max(0, (total_frames - num_frames) // 2)
         indices = list(range(start, min(start + num_frames, total_frames)))
-
+
     for idx in indices:
         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
         ret, frame = cap.read()
         if ret:
             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frames.append(Image.fromarray(frame_rgb))
-
+
     cap.release()
     return frames

+
+# ---------------- Caption Frame ----------------
 def caption_frame(image: Image.Image, prompt: str) -> str:
-    """Generate caption for a single frame (CPU only)"""
     tok, model = load_model()

     messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
@@ -75,11 +76,8 @@ def caption_frame(image: Image.Image, prompt: str) -> str:
     input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)

     attention_mask = torch.ones_like(input_ids)
-
-    # Preprocess image
     px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]

-    # Generate on CPU
     with torch.no_grad():
         out = model.generate(
             inputs=input_ids,
@@ -93,5 +91,85 @@ def caption_frame(image: Image.Image, prompt: str) -> str:
     caption = tok.decode(out[0], skip_special_tokens=True)
     if prompt in caption:
         caption = caption.split(prompt)[-1].strip()
-
+
     return caption
+
+
+# ---------------- Process Video ----------------
+def process_video(video_path, num_frames, sampling_method, chat_history, progress=gr.Progress()):
+    if not video_path:
+        chat_history.append(["Assistant", "Please upload a video first."])
+        return chat_history, None
+
+    progress(0, desc="Extracting frames...")
+    frames = extract_frames(video_path, num_frames, sampling_method)
+
+    if not frames:
+        chat_history.append(["Assistant", "Failed to extract frames."])
+        return chat_history, None
+
+    prompt = "Provide a brief one-sentence description of what's happening in this image."
+    captions = []
+
+    chat_history.append(["Assistant", "Analyzing frames..."])
+    for i, frame in enumerate(frames):
+        caption = caption_frame(frame, prompt)
+        captions.append(f"Frame {i+1}: {caption}")
+        chat_history[-1] = ["Assistant", "\n".join(captions)]
+        progress((i + 1) / len(frames))
+
+    progress(1.0, desc="Analysis complete!")
+    return chat_history, frames
+
+
+# ---------------- Custom Apple-like Theme ----------------
+class AppleTheme(gr.themes.Base):
+    def __init__(self):
+        super().__init__(
+            primary_hue=gr.themes.colors.blue,
+            secondary_hue=gr.themes.colors.gray,
+            neutral_hue=gr.themes.colors.gray,
+            spacing_size=gr.themes.sizes.spacing_md,
+            radius_size=gr.themes.sizes.radius_md,
+            text_size=gr.themes.sizes.text_md,
+            font=[gr.themes.GoogleFont("Inter"), "SF Pro Display", "Helvetica Neue", "Arial", "sans-serif"],
+            font_mono=[gr.themes.GoogleFont("SF Mono"), "Consolas", "monospace"]
+        )
+
+
+# ---------------- Gradio UI ----------------
+with gr.Blocks(theme=AppleTheme()) as demo:
+    gr.Markdown("# 🎬 FastVLM Video Captioning (CPU Only)")
+
+    with gr.Row():
+        with gr.Column(scale=7):
+            video_display = gr.Video(label="Video Input", autoplay=True, loop=True)
+
+    with gr.Sidebar(width=400):
+        chatbot = gr.Chatbot(
+            value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
+            height=400
+        )
+        process_btn = gr.Button("🎯 Analyze Video", variant="primary")
+
+    with gr.Accordion("🖼️ Analyzed Frames", open=False):
+        frame_gallery = gr.Gallery(columns=2, rows=4, height="auto")
+
+    num_frames = gr.State(value=4)
+    sampling_method = gr.State(value="uniform")
+
+    process_btn.click(
+        fn=process_video,
+        inputs=[video_display, num_frames, sampling_method, chatbot],
+        outputs=[chatbot, frame_gallery],
+        show_progress=True
+    )
+
+
+# ---------------- Launch ----------------
+demo.launch(
+    server_name="0.0.0.0", # Spaces/containers need this
+    server_port=7860,
+    share=False,
+    show_error=True
+)