# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Sequence

import torch
import torch.nn as nn
from torchvision import transforms


class Permute(nn.Module):
    """
    Permutation as an op
    """

    def __init__(self, ordering):
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames in some ordering, by default (C, T, H, W)
        Returns:
            frames in the ordering that was specified
        """
        return frames.permute(self.ordering)


class TemporalCrop(nn.Module):
    """
    Convert the video into smaller clips temporally.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        assert video.ndim == 4, "Must be (C, T, H, W)"
        res = []
        for start in range(
            0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
        ):
            end = start + self.frames * self.frame_stride
            res.append(video[:, start:end:self.frame_stride, ...])
        return res
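
# Worked example for TemporalCrop (illustrative only, not part of the original
# file): with frames_per_clip=8, stride=8, frame_stride=1 and a video of shape
# (3, 32, 224, 224), the loop starts at t = 0, 8, 16, 24, so forward()
# returns 4 clips, each of shape (3, 8, 224, 224).
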
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
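
# Illustrative example (values are assumptions for the sake of the example):
# with x_offset=48 and y_offset=16, a box [x1, y1, x2, y2] of
# [100, 50, 200, 150] becomes [52, 34, 152, 134] in the cropped frame.
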
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size
            before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset:y_offset + size, x_offset:x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
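
# Worked example for uniform_crop (illustrative, not from the original
# source): cropping a 256 x 320 frame to size=224 keeps y_offset at the
# centered value ceil((256 - 224) / 2) = 16, while spatial_idx 0 / 1 / 2
# selects x_offset = 0 / 48 / 96 (left / center / right), since width > height.
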
class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with
    -2 in the spatial crop at the slowfast augmentation stage (so full
    frames are passed in here). Will return a larger list with the
    3x spatial crops as well. It's useful for 3x4 testing (e.g. in SwinT)
    or 3x10 testing in SlowFast etc.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 6:
            self.crops_to_ext = [0, 1, 2]
            # I guess Swin uses 5 crops without flipping, but that doesn't
            # make sense given they first resize to 224 and take 224 crops.
            # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
            # So I'm assuming we can use flipped crops and that will add something.
            self.flipped_crops_to_ext = [0, 1, 2]
        elif num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError(
                "Nothing else supported yet; "
                "num_crops must be 1, 3, or 6"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
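
# Minimal usage sketch (an illustration, not part of the original file): how
# these transforms are typically chained for multi-clip, multi-crop
# evaluation. The tensor shapes and parameter values below are assumptions
# chosen for the example, not fixed by the code above.
if __name__ == "__main__":
    video = torch.randn(3, 32, 256, 320)  # dummy (C, T, H, W) video
    clips = TemporalCrop(frames_per_clip=8, stride=8)(video)  # 4 temporal clips
    crops = SpatialCrop(crop_size=224, num_crops=3)(clips)  # 3 spatial crops per clip
    print(len(crops), crops[0].shape)  # 12 clips, each (3, 8, 224, 224)

    # Reorder a clip to (T, C, H, W) with Permute, e.g. for per-frame models.
    tchw = Permute((1, 0, 2, 3))(crops[0])
    print(tchw.shape)  # (8, 3, 224, 224)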