Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the sav_dataset directory of this source tree. | |
| import json | |
| import os | |
| from typing import Dict, List, Optional, Tuple | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pycocotools.mask as mask_util | |
def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode the video and return the RGB frames.

    Args:
        video_path: path of the video file handed to cv2.VideoCapture.

    Returns:
        The decoded frames in stream order, each converted from BGR to RGB.
        The list is empty when the capture cannot be opened or yields no
        frames.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                # read() signals failure once no further frame can be grabbed
                break
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # always free the capture, even if a frame conversion raises
        video.release()
    return video_frames
def show_anns(masks, colors: List, borders=True) -> None:
    """
    Draw the masks as a semi-transparent overlay on the current matplotlib axes.

    Args:
        masks: the masks to draw; each one is used to index an (H, W, 4)
            canvas, so boolean H x W arrays are expected (assumption; confirm
            against the caller).
        colors: one color per mask, paired with ``masks`` by position; an
            alpha of 0.55 is appended to each color.
        borders: when True, the contours of every mask are outlined as well.
    """
    # nothing to draw without masks
    if len(masks) == 0:
        return

    # largest masks are painted first, so smaller ones are drawn over them
    mask_color_pairs = list(zip(masks, colors))
    mask_color_pairs.sort(key=lambda pair: pair[0].sum(), reverse=True)

    reference_mask = mask_color_pairs[0][0]
    height = reference_mask.shape[0]
    width = reference_mask.shape[1]

    # RGBA overlay that starts out fully transparent
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0

    # outline thickness: 1% of the shorter side, kept within [1, 5]
    scaled_thickness = 0.01 * min(height, width)
    outline_thickness = max(1, int(min(5, scaled_thickness)))

    for region, rgb in mask_color_pairs:
        rgba = np.concatenate([rgb, [0.55]])
        overlay[region] = rgba
        if not borders:
            continue
        region_u8 = np.array(region, dtype=np.uint8)
        outlines, _ = cv2.findContours(
            region_u8, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
        )
        cv2.drawContours(
            overlay, outlines, -1, (0.05, 0.05, 0.05, 1), thickness=outline_thickness
        )

    plt.gca().imshow(overlay)
class SAVDataset:
    """
    SAVDataset is a class to load the SAV dataset and visualize the annotations.
    """

    def __init__(self, sav_dir, annot_sample_rate=4):
        """
        Args:
            sav_dir: the directory of the SAV dataset
            annot_sample_rate: the sampling rate of the annotations.
                The annotations are aligned with the videos at 6 fps.
        """
        self.sav_dir = sav_dir
        self.annot_sample_rate = annot_sample_rate
        # random RGB palettes (256 entries each), drawn anew for every instance
        self.manual_mask_colors = np.random.random((256, 3))
        self.auto_mask_colors = np.random.random((256, 3))

    @staticmethod
    def _load_annotation(annot_path: str, missing_msg: str) -> Optional[Dict]:
        """
        Load one JSON annotation file, or print missing_msg and return None
        when the file does not exist.
        """
        if not os.path.exists(annot_path):
            print(missing_msg)
            return None
        # the context manager closes the file handle right after parsing
        with open(annot_path, encoding="utf-8") as annot_file:
            return json.load(annot_file)

    def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]:
        """
        Read the frames and downsample them to align with the annotations.

        Args:
            mp4_path: path of the video file to decode.

        Returns:
            Every ``annot_sample_rate``-th decoded frame, or None when
            mp4_path does not exist.
        """
        if not os.path.exists(mp4_path):
            print(f"{mp4_path} doesn't exist.")
            return None

        # decode the video
        frames = decode_video(mp4_path)
        print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")

        # downsample the frames to align with the annotations
        frames = frames[:: self.annot_sample_rate]
        print(
            f"Videos are annotated every {self.annot_sample_rate} frames. "
            "To align with the annotations, "
            f"downsample the video to {len(frames)} frames."
        )
        return frames

    def get_frames_and_annotations(
        self, video_id: str
    ) -> Tuple[Optional[List[np.ndarray]], Optional[Dict], Optional[Dict]]:
        """
        Get the frames and annotations for video.

        Args:
            video_id: file stem of the video; "<video_id>.mp4",
                "<video_id>_manual.json" and "<video_id>_auto.json" are
                looked up inside ``sav_dir``.

        Returns:
            A tuple (frames, manual_annot, auto_annot). All three are None
            when the video is missing; an individual annotation is None when
            its JSON file is missing.
        """
        # load the video
        mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
        frames = self.read_frames(mp4_path)
        if frames is None:
            return None, None, None

        # load the manual annotations
        manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
        manual_annot = self._load_annotation(
            manual_annot_path,
            f"{manual_annot_path} doesn't exist. Something might be wrong.",
        )

        # load the auto annotations
        auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
        auto_annot = self._load_annotation(
            auto_annot_path, f"{auto_annot_path} doesn't exist."
        )

        return frames, manual_annot, auto_annot

    def visualize_annotation(
        self,
        frames: List[np.ndarray],
        auto_annot: Optional[Dict],
        manual_annot: Optional[Dict],
        annotated_frame_id: int,
        show_auto=True,
        show_manual=True,
    ) -> None:
        """
        Visualize the annotations on the annotated_frame_id.
        If show_manual is True, show the manual annotations.
        If show_auto is True, show the auto annotations.
        By default, show both auto and manual annotations.

        NOTE: auto_annot comes BEFORE manual_annot in this signature, whereas
        get_frames_and_annotations returns (frames, manual_annot, auto_annot).
        Pass the annotations by keyword to avoid swapping them.

        Args:
            frames: the (downsampled) video frames.
            auto_annot: auto annotation dict, or None to skip it.
            manual_annot: manual annotation dict, or None to skip it.
            annotated_frame_id: index into ``frames`` and into each
                annotation's "masklet" list.
            show_auto: whether to draw the auto annotations.
            show_manual: whether to draw the manual annotations.
        """
        if annotated_frame_id >= len(frames):
            print("invalid annotated_frame_id")
            return

        rles = []
        colors = []
        if show_manual and manual_annot is not None:
            manual_rles = manual_annot["masklet"][annotated_frame_id]
            rles.extend(manual_rles)
            colors.extend(self.manual_mask_colors[: len(manual_rles)])
        if show_auto and auto_annot is not None:
            auto_rles = auto_annot["masklet"][annotated_frame_id]
            rles.extend(auto_rles)
            colors.extend(self.auto_mask_colors[: len(auto_rles)])

        plt.imshow(frames[annotated_frame_id])

        if len(rles) > 0:
            # decode each RLE and binarize it before drawing
            masks = [mask_util.decode(rle) > 0 for rle in rles]
            show_anns(masks, colors)
        else:
            print("No annotation will be shown")

        plt.axis("off")
        plt.show()