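"""Gradio demo: video captioning with Apple's FastVLM-7B.

Frames are sampled from an uploaded video, each frame is captioned by the
model, and the per-frame captions are streamed into a chat-style sidebar.
"""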
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
from typing import Optional
import tempfile
import os
import spaces

MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200  # sentinel index used where the image is spliced into the prompt

# Lazily initialized globals so the model is only loaded once per process.
tok = None
model = None


def load_model():
    global tok, model
    if tok is None or model is None:
        print("Loading FastVLM model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    return tok, model


def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract frames from a video file as PIL images."""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if total_frames == 0:
        cap.release()
        return []

    frames = []

    if sampling_method == "uniform":
        # Evenly spaced frame indices across the whole video.
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:
        # Fallback: take a contiguous block of frames from the middle of the video.
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))

    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB before building the PIL image.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))

    cap.release()
    return frames


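# FastVLM follows the LLaVA-style convention for multimodal prompts: "<image>"
# is not a real vocabulary token. The text on either side of the placeholder is
# tokenized separately, the sentinel IMAGE_TOKEN_INDEX (-200) is spliced in
# between, and the preprocessed pixel values are passed to generate() via the
# `images` argument.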
@spaces.GPU(duration=60)
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate a caption for a single frame."""
    tok, model = load_model()

    # Render the chat template around the <image> placeholder.
    messages = [
        {"role": "user", "content": f"<image>\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)

    # Tokenize the text before and after the placeholder separately.
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice the image token index between the two text segments.
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    # Preprocess the frame with the model's own vision tower image processor.
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )

    caption = tok.decode(out[0], skip_special_tokens=True)

    # The decoded text can echo the prompt; keep only the generated part.
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()

    return caption


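# Non-streaming variant of the analysis pipeline. It mirrors the chat-based
# handler defined inside the UI below but is not wired to any event.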
def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress()
) -> tuple:
    """Process video and generate captions"""
    if not video_path:
        return "Please upload a video first.", None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        return "Failed to extract frames from video.", None

    prompt = "Provide a brief one-sentence description of what's happening in this image."

    captions = []
    frame_previews = []

    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"Frame {i + 1}: {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")

    full_caption = "\n".join(captions)

    if len(frames) > 1:
        video_summary = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
    else:
        video_summary = f"Video Analysis:\n\n{full_caption}"

    return video_summary, frame_previews


class AppleTheme(gr.themes.Base):
    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
            spacing_size=gr.themes.sizes.spacing_md,
            radius_size=gr.themes.sizes.radius_md,
            text_size=gr.themes.sizes.text_md,
            font=[
                gr.themes.GoogleFont("Inter"),
                "-apple-system",
                "BlinkMacSystemFont",
                "SF Pro Display",
                "SF Pro Text",
                "Helvetica Neue",
                "Helvetica",
                "Arial",
                "sans-serif",
            ],
            font_mono=[
                gr.themes.GoogleFont("SF Mono"),
                "ui-monospace",
                "Consolas",
                "monospace",
            ],
        )
        super().set(
            # Backgrounds
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_950",
            # Primary button
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_600",
            button_primary_text_color="white",
            button_primary_border_color="*primary_500",
            # Shadows
            button_shadow="0 2px 8px rgba(0, 0, 0, 0.04)",
            button_shadow_hover="0 4px 12px rgba(0, 0, 0, 0.08)",
            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
            # Corner radii
            button_large_radius="8px",
            button_small_radius="6px",
            block_radius="12px",
            container_radius="12px",
            # Borders
            block_border_width="1px",
            block_border_color="*neutral_200",
            input_border_width="1px",
            input_border_color="*neutral_300",
            input_border_color_focus="*primary_500",
            # Typography
            block_title_text_weight="600",
            block_label_text_weight="500",
            block_label_text_size="13px",
            block_label_text_color="*neutral_600",
            body_text_color="*neutral_900",
            # Spacing
            layout_gap="16px",
            block_padding="20px",
            # Component accents
            chatbot_code_background_color="*neutral_100",
            slider_color="*primary_500",
            # Motion
            button_transition="all 0.2s cubic-bezier(0.4, 0, 0.2, 1)",
        )


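# UI: video player in the main column, plus a sidebar chat that streams
# per-frame captions and a collapsible gallery of the analyzed frames.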
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                autoplay=True,
                loop=True
            )

        with gr.Sidebar(width=400):
            gr.Markdown("## 💬 Video Analysis Chat")

            chatbot = gr.Chatbot(
                value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
                height=400,
                elem_classes=["chatbot"]
            )

            process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

            with gr.Accordion("🖼️ Analyzed Frames", open=False):
                frame_gallery = gr.Gallery(
                    label="Extracted Frames",
                    show_label=False,
                    columns=2,
                    rows=4,
                    object_fit="contain",
                    height="auto"
                )

    # Fixed analysis settings held in State so the handler signature stays uniform.
    num_frames = gr.State(value=8)
    sampling_method = gr.State(value="uniform")
    caption_mode = gr.State(value="Brief Summary")
    custom_prompt = gr.State(value="")

    def handle_upload(video, chat_history):
        if video:
            chat_history.append(["User", "Video uploaded"])
            chat_history.append(["Assistant", "Video loaded! Click 'Analyze Video' to generate captions."])
            return video, chat_history
        return None, chat_history

    video_display.upload(
        handle_upload,
        inputs=[video_display, chatbot],
        outputs=[video_display, chatbot]
    )

    def process_video_with_chat(video_path, num_frames, sampling_method, caption_mode, custom_prompt, chat_history, progress=gr.Progress()):
        if not video_path:
            chat_history.append(["Assistant", "Please upload a video first."])
            yield chat_history, None
            return

        chat_history.append(["User", "Analyzing video..."])
        yield chat_history, None

        progress(0, desc="Extracting frames...")
        frames = extract_frames(video_path, num_frames, sampling_method)

        if not frames:
            chat_history.append(["Assistant", "Failed to extract frames from video."])
            yield chat_history, None
            return

        # Placeholder chat message that is rewritten as each frame is captioned.
        chat_history.append(["Assistant", ""])
        prompt = "Provide a brief one-sentence description of what's happening in this image."

        captions = []
        for i, frame in enumerate(frames):
            progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
            caption = caption_frame(frame, prompt)
            frame_caption = f"Frame {i + 1}: {caption}\n"
            captions.append(frame_caption)

            # Stream partial results: update the last chat message and the gallery.
            current_text = "".join(captions)
            chat_history[-1] = ["Assistant", f"Analyzing {len(frames)} frames:\n\n{current_text}"]
            yield chat_history, frames[:i + 1]

        progress(1.0, desc="Analysis complete!")

        full_caption = "".join(captions)
        final_message = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
        chat_history[-1] = ["Assistant", final_message]
        yield chat_history, frames

    process_btn.click(
        process_video_with_chat,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True
    )


demo.launch()