Spaces:
Runtime error
Runtime error
| from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration | |
| from typing import List | |
| import os | |
| import supervision as sv | |
| import uuid | |
| from tqdm import tqdm | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import spaces | |
| import flax.linen as nn | |
| import jax | |
| import string | |
| import functools | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import re | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint shared by the model and its processor.
model_id = "google/paligemma-3b-mix-448"
processor = PaliGemmaProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model = model.eval().to(device)

# Annotators are created once at import time and reused for every frame.
MASK_ANNOTATOR = sv.MaskAnnotator()
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
def calculate_end_frame_index(source_video_path):
    """Return how many leading frames of the video should be processed.

    The result is the smaller of the video's total frame count and twice its
    frames-per-second value, i.e. at most the first two seconds of footage.
    """
    info = sv.VideoInfo.from_video_path(source_video_path)
    two_seconds_of_frames = info.fps * 2
    return min(info.total_frames, two_seconds_of_frames)
def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    """Draw masks, bounding boxes and text labels for `detections`.

    The three module-level annotators are applied in that order, each one
    receiving the image returned by the previous step.

    Args:
        input_image: Image to draw on.
        detections: Detections to visualise.
        labels: One caption per detection, forwarded to the label annotator.

    Returns:
        The image returned by the final (label) annotator.
    """
    with_masks = MASK_ANNOTATOR.annotate(input_image, detections)
    with_boxes = BOUNDING_BOX_ANNOTATOR.annotate(with_masks, detections)
    return LABEL_ANNOTATOR.annotate(with_boxes, detections, labels=labels)
def process_video(
    input_video,
    labels,
    progress=gr.Progress(track_tqdm=True)
):
    """Detect the requested labels on the first frames of a video and save the result.

    Only the frames selected by ``calculate_end_frame_index`` are processed.
    Each frame is sent through ``parse_detection``, annotated, and appended to
    a freshly named ``.mp4`` in the current working directory.

    Args:
        input_video: Path of the source video file.
        labels: Candidate labels as one ";"-separated string; forwarded
            unchanged to ``parse_detection``.
        progress: Gradio progress tracker (created with ``track_tqdm=True``).

    Returns:
        Path of the annotated output video.
    """
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )
    result_file_name = f"{uuid.uuid4()}.mp4"
    result_file_path = os.path.join("./", result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself rather than calling next() `total`
        # times: if the file yields fewer frames than its metadata reports,
        # the loop simply ends instead of raising StopIteration.
        for frame in tqdm(frame_generator, total=total, desc="Processing video.."):
            # `results` is a one-element list holding a dict with the keys
            # "boxes", "scores" and "labels"; "labels" holds indices into
            # `input_list`.
            results, input_list = parse_detection(frame, labels)
            detections = sv.Detections.from_transformers(results[0])
            final_labels = [
                input_list[class_index]
                for class_index in results[0]["labels"].tolist()
            ]
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(frame)
    return result_file_path
def infer(
    image: Image.Image,
    text: str,
    max_new_tokens: int
) -> str:
    """Generate a greedy completion for `text` conditioned on `image`.

    Args:
        image: Image passed to the processor together with the prompt.
        text: Prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first decoded sequence with its leading ``len(text)`` characters
        (the echoed prompt) and any leading newlines removed.
    """
    model_inputs = processor(text=text, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    with torch.inference_mode():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    completion = decoded[0][len(text):]
    return completion.lstrip("\n")
def parse_detection(input_image, input_text):
    """Ask the model to detect the candidate labels in one frame.

    Args:
        input_image: Frame to analyse. Assumed to be an array whose first two
            dimensions are (height, width), as produced by the video frame
            generator -- confirm if other callers are added.
        input_text: Candidate labels as one ";"-separated string.

    Returns:
        A tuple ``(results, input_list)``:

        * ``results`` is a one-element list holding a dict with ``"boxes"``
          (N x 4, xyxy pixel coordinates), ``"scores"`` (N values, fixed at
          0.99 because the model output carries no confidence) and
          ``"labels"`` (N indices into ``input_list``).
        * ``input_list`` is the cleaned list of candidate labels.

        Detections without a name, or whose name is not one of the candidate
        labels, are dropped so that boxes, scores and labels stay aligned.
    """
    prompt = f"detect {input_text}"
    out = infer(input_image, prompt, max_new_tokens=100)

    # shape is (height, width, ...), while extract_objs expects width first.
    frame_height, frame_width = input_image.shape[:2]
    objs = extract_objs(
        out.lstrip("\n"), frame_width, frame_height, unique_labels=True
    )
    labels = [obj.get('name') for obj in objs if obj.get('name')]
    print("labels", labels)

    input_list = [
        remove_special_characters(candidate).lstrip("\n").rstrip("\n")
        for candidate in input_text.split(";")
    ]

    boxes = []
    label_indices = []
    for obj in objs:
        name = obj.get("name")
        # extract_objs also emits text-only entries that have no box.
        if not name or "xyxy" not in obj:
            continue
        # Cleaning also strips the "'" suffixes added by unique_labels=True.
        cleaned = remove_special_characters(name)
        if cleaned not in input_list:
            # The model produced a label that was never requested; skip it
            # rather than failing the whole frame.
            continue
        boxes.append(list(obj["xyxy"]))
        label_indices.append(input_list.index(cleaned))

    # reshape keeps an (N, 4) layout even when nothing was detected.
    boxes_tensor = torch.tensor(boxes).reshape(-1, 4)
    scores_tensor = torch.full((len(boxes),), 0.99).to(device)
    labels_tensor = torch.tensor(label_indices, dtype=torch.long).to(device)
    return [
        {"boxes": boxes_tensor, "scores": scores_tensor, "labels": labels_tensor}
    ], input_list
# File name of the mask-decoder checkpoint; not referenced by the code visible here.
_MODEL_PATH = 'vae-oid.npz'

# Matches one detection/segmentation entry at the start of the model output:
#   group 1      - any text preceding the entry (non-greedy)
#   groups 2-5   - the four 4-digit <loc####> values
#   groups 6-21  - sixteen optional 3-digit <seg###> values
#   group 22     - the optional label, up to the next ';' or '<'
_SEGMENT_DETECT_RE = re.compile(
    ''.join((
        r'(.*?)',
        r'<loc(\d{4})>' * 4,
        r'\s*',
        '(?:%s)?' % (r'<seg(\d{3})>' * 16),
        r'\s*([^;<>]+)? ?(?:; )?',
    ))
)
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Look up codebook embeddings and arrange them on a 4x4 grid.

    Args:
        codebook_indices: 2-D array of shape (batch_size, 16) holding row
            indices into `embeddings`.
        embeddings: 2-D array of shape (num_embeddings, embedding_dim).

    Returns:
        Array of shape (batch_size, 4, 4, embedding_dim).
    """
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    _, embedding_dim = embeddings.shape

    flat_indices = codebook_indices.reshape((-1))
    gathered = jnp.take(embeddings, flat_indices, axis=0)
    return gathered.reshape((batch_size, 4, 4, embedding_dim))
def remove_special_characters(word):
    """Strip leading and trailing runs of characters other than ASCII letters and digits.

    Characters in the middle of `word` are left untouched.
    """
    edge_junk = r'^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$'
    return re.sub(edge_junk, '', word)
def extract_objs(text, width, height, unique_labels=False):
    """Parse a string containing "<loc>" and "<seg>" tokens into a list of dicts.

    Every matched entry yields a dict with ``content``, ``xyxy`` (pixel
    coordinates scaled from the 0-1023 location values), ``mask`` (always
    ``None`` here) and ``name``. Text preceding an entry, and any unparsed
    remainder, are emitted as dicts that only carry ``content``.

    With ``unique_labels=True`` a repeated name gets "'" appended until it
    differs from every name seen so far.
    """
    parsed = []
    used_names = set()
    remaining = text
    while remaining:
        match = _SEGMENT_DETECT_RE.match(remaining)
        if match is None:
            break

        groups = match.groups()
        prefix = groups[0]
        label = groups[-1]

        # Location tokens arrive in y1, x1, y2, x2 order on a 1024-step grid.
        y1, x1, y2, x2 = [int(value) / 1024 for value in groups[1:5]]
        y1 = round(y1 * height)
        x1 = round(x1 * width)
        y2 = round(y2 * height)
        x2 = round(x2 * width)

        entry_text = match.group()
        if prefix:
            parsed.append({"content": prefix})
            entry_text = entry_text[len(prefix):]

        while unique_labels and label in used_names:
            label = (label or '') + "'"
        used_names.add(label)

        parsed.append({
            "content": entry_text,
            "xyxy": (x1, y1, x2, y2),
            "mask": None,
            "name": label,
        })
        # The match is anchored at position 0, so its end is prefix + entry.
        remaining = remaining[match.end():]

    if remaining:
        parsed.append({"content": remaining})
    return parsed
# Gradio front end: one tab with an input video, an output video and a text
# box for the candidate labels, all wired to process_video.
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Object Tracking with PaliGemma")
    gr.Markdown("This is a demo for zero-shot object tracking using [PaliGemma](https://huggingface.co/google/paligemma-3b-mix-448) vision language model by Google.")
    gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. Text input should be ; separated. 👇")
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video = gr.Video(
                label='Input Video'
            )
            output_video = gr.Video(
                label='Output Video'
            )
        with gr.Row():
            candidate_labels = gr.Textbox(
                label='Labels',
                # Labels are split on ";" downstream, so the hint must say
                # semicolon (it previously said comma).
                placeholder='Labels separated by a semicolon',
            )
            submit = gr.Button()
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "bird ; cat"]],
            inputs=[
                input_video,
                candidate_labels,
            ],
            outputs=output_video
        )
        submit.click(
            fn=process_video,
            inputs=[input_video, candidate_labels],
            outputs=output_video
        )

demo.launch(debug=False, show_error=True)