# FastVLM Video Captioning (Hugging Face Space, running on ZeroGPU)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np

MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200
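# -200 is the placeholder id used by LLaVA-style checkpoints (FastVLM follows this
# convention): it is spliced into input_ids where the <image> tag appears, and the
# model substitutes the projected image features at that position during generation.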
# Load model and tokenizer
print("Loading FastVLM model...")
tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
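# With CUDA available the weights load in float16 and device_map="auto" (which
# requires accelerate) places them on the GPU; otherwise they fall back to float32
# on CPU, which is much slower.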
| print("Model loaded successfully!") | |
| def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"): | |
| """Extract frames from video""" | |
| cap = cv2.VideoCapture(video_path) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| if total_frames == 0: | |
| cap.release() | |
| return [] | |
| frames = [] | |
| if sampling_method == "uniform": | |
| # Uniform sampling | |
| indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) | |
| elif sampling_method == "first": | |
| # Take first N frames | |
| indices = list(range(min(num_frames, total_frames))) | |
| elif sampling_method == "last": | |
| # Take last N frames | |
| start = max(0, total_frames - num_frames) | |
| indices = list(range(start, total_frames)) | |
| else: # middle | |
| # Take frames from the middle | |
| start = max(0, (total_frames - num_frames) // 2) | |
| indices = list(range(start, min(start + num_frames, total_frames))) | |
| for idx in indices: | |
| cap.set(cv2.CAP_PROP_POS_FRAMES, idx) | |
| ret, frame = cap.read() | |
| if ret: | |
| # Convert BGR to RGB | |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frames.append(Image.fromarray(frame_rgb)) | |
| cap.release() | |
| return frames | |
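
# Illustrative usage (assumes a local file "clip.mp4" exists):
#   frames = extract_frames("clip.mp4", num_frames=4, sampling_method="uniform")
#   frames[0].save("first_frame.png")
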

def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate caption for a single frame"""
    # Build chat with custom prompt
    messages = [
        {"role": "user", "content": f"<image>\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)

    # Tokenize the text around the image token
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice in the IMAGE token id
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    # Preprocess image
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    # Generate
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
        )
    caption = tok.decode(out[0], skip_special_tokens=True)

    # Extract only the generated part
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()
    return caption
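
# Illustrative usage (assumes "frame.png" exists on disk):
#   img = Image.open("frame.png").convert("RGB")
#   print(caption_frame(img, "Describe this image."))
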

def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress()
) -> tuple:
    """Process video and generate captions"""
    if not video_path:
        return "Please upload a video first.", None, None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)
    if not frames:
        return "Failed to extract frames from video.", None, None

    # Prepare prompt based on mode
    if caption_mode == "Detailed Description":
        prompt = "Describe this image in detail, including all visible objects, actions, and the overall scene."
    elif caption_mode == "Brief Summary":
        prompt = "Provide a brief one-sentence description of what's happening in this image."
    elif caption_mode == "Action Recognition":
        prompt = "What action or activity is taking place in this image? Focus on the main action."
    else:  # Custom
        prompt = custom_prompt if custom_prompt else "Describe this image."

    captions = []
    frame_previews = []
    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"**Frame {i + 1}:** {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")
    # Combine captions into a narrative
    full_caption = "\n\n".join(captions)
    # Add a heading; with multiple frames, note how many were analyzed.
    # (A second generation pass could condense the frame captions into one summary,
    # but for simplicity the per-frame captions are returned as-is.)
    if len(frames) > 1:
        video_summary = f"## Video Analysis ({len(frames)} frames analyzed)\n\n{full_caption}"
    else:
        video_summary = f"## Video Analysis\n\n{full_caption}"
    return video_summary, frame_previews, video_path
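
# process_video returns (markdown_summary, frame_previews, video_path); these map to
# the three outputs wired to the "Analyze Video" button in the interface below.
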

# Create the Gradio interface
with gr.Blocks(css="""
.video-container {
    height: calc(100vh - 100px) !important;
}
.sidebar {
    height: calc(100vh - 100px) !important;
    overflow-y: auto;
}
""") as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                height=600,
                elem_classes=["video-container"],
                autoplay=True,
                loop=True
            )

    # Sidebar with controls
    with gr.Sidebar(width=400, elem_classes=["sidebar"]):
        gr.Markdown("## ⚙️ Settings")

        with gr.Group():
            gr.Markdown("### Frame Sampling")
            num_frames = gr.Slider(
                minimum=1,
                maximum=16,
                value=8,
                step=1,
                label="Number of Frames to Analyze",
                info="More frames = better understanding but slower processing"
            )
            sampling_method = gr.Radio(
                choices=["uniform", "first", "last", "middle"],
                value="uniform",
                label="Sampling Method",
                info="How to select frames from the video"
            )

        with gr.Group():
            gr.Markdown("### Caption Settings")
            caption_mode = gr.Radio(
                choices=["Detailed Description", "Brief Summary", "Action Recognition", "Custom"],
                value="Detailed Description",
                label="Caption Mode"
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt",
                placeholder="Enter your custom prompt here...",
                visible=False,
                lines=3
            )

        process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

        gr.Markdown("### 📝 Results")
        output_text = gr.Markdown(
            value="Upload a video and click 'Analyze Video' to begin.",
            elem_classes=["output-text"]
        )

        with gr.Accordion("🖼️ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(
                label="Extracted Frames",
                show_label=False,
                columns=2,
                rows=4,
                object_fit="contain",
                height="auto"
            )

    # Show/hide custom prompt based on mode selection
    def toggle_custom_prompt(mode):
        return gr.Textbox(visible=(mode == "Custom"))

    caption_mode.change(
        toggle_custom_prompt,
        inputs=[caption_mode],
        outputs=[custom_prompt]
    )

    # Upload handler
    def handle_upload(video):
        if video:
            return video, "Video loaded! Click 'Analyze Video' to generate captions."
        return None, "Upload a video to begin."

    video_display.upload(
        handle_upload,
        inputs=[video_display],
        outputs=[video_display, output_text]
    )

    # Process button
    process_btn.click(
        process_video,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt],
        outputs=[output_text, frame_gallery, video_display]
    )

demo.launch()
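
# To run locally (assumed setup, not specified in this file): install gradio, torch,
# transformers, accelerate, opencv-python, pillow and numpy, then run the script
# (e.g. `python app.py` if saved under that name).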