# Copyright (2025) Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import Compose
import cv2
from tqdm import tqdm
import numpy as np
import gc

from .dinov2 import DINOv2
from .dpt_temporal import DPTHeadTemporal
from .util.transform import Resize, NormalizeImage, PrepareForNet

from utils.util import compute_scale_and_shift, get_interpolate_frames

# infer settings, do not change
INFER_LEN = 32
OVERLAP = 10
KEYFRAMES = [0, 12, 24, 25, 26, 27, 28, 29, 30, 31]
INTERP_LEN = 8
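
# How these settings fit together: each forward pass covers an INFER_LEN-frame
# window, and consecutive windows advance by INFER_LEN - OVERLAP = 22 frames.
# The first OVERLAP inputs of a new window are copies of the previous window's
# KEYFRAMES, so the model keeps temporal context across windows. During
# post-processing, the first OVERLAP - INTERP_LEN of those shared keyframes
# anchor a per-window scale/shift fit, and the remaining INTERP_LEN overlapping
# frames are blended between the two windows.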


class VideoDepthAnything(nn.Module):
    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False,
        num_frames=32,
        pe='ape'
    ):
        super(VideoDepthAnything, self).__init__()

        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
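
    # x: [B, T, 3, H, W] batch of normalized video clips; H and W are expected
    # to be multiples of the 14-pixel DINOv2 patch size (the preprocessing in
    # infer_video_depth guarantees this). Returns per-frame depth [B, T, H, W].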
    def forward(self, x):
        B, T, C, H, W = x.shape
        patch_h, patch_w = H // 14, W // 14
        features = self.pretrained.get_intermediate_layers(x.flatten(0, 1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
        depth = self.head(features, patch_h, patch_w, T)
        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
        depth = F.relu(depth)
        return depth.squeeze(1).unflatten(0, (B, T))  # return shape [B, T, H, W]
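
    # Sliding-window video depth inference:
    #   1. preprocess frames to the model's input resolution,
    #   2. run 32-frame windows that reuse the previous window's keyframes as
    #      the first OVERLAP inputs,
    #   3. align each window to the previous ones with a scale/shift fit on the
    #      shared keyframes and blend the overlapping frames.
    # `frames` is an array of H x W x 3 frames with values in [0, 255]; the
    # returned depths are at the original frame resolution, together with
    # `target_fps`.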
    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
        frame_height, frame_width = frames[0].shape[:2]
        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
        if ratio > 1.78:
            # For inputs wider than ~16:9, shrink the lower-bound target so the
            # longer side is capped at roughly input_size * 1.78, then snap it
            # back to a multiple of the 14-pixel patch size.
            input_size = int(input_size * 1.78 / ratio)
            input_size = round(input_size / 14) * 14
        print(f'==> infer lower bound input size: {input_size}')

        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])
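
        # Repeat the last frame so the padded length splits exactly into windows
        # of INFER_LEN frames advancing by frame_step; results for the padded
        # frames are dropped at the end.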
        frame_list = [frames[i] for i in range(frames.shape[0])]
        frame_step = INFER_LEN - OVERLAP
        org_video_len = len(frame_list)
        append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
        frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len

        depth_list = []
        pre_input = None
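
        # Window loop: each iteration builds a 32-frame input; from the second
        # window on, its first OVERLAP slots are overwritten with the previous
        # window's KEYFRAME inputs so the two windows share context.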
        for frame_id in tqdm(range(0, org_video_len, frame_step)):
            cur_list = []
            for i in range(INFER_LEN):
                cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id + i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
            cur_input = torch.cat(cur_list, dim=1).to(device)  # [1, INFER_LEN, 3, h, w]
            if pre_input is not None:
                # Reuse the previous window's keyframe inputs as temporal context.
                cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]

            with torch.no_grad():
                with torch.autocast(device_type=device, enabled=True):
                    depth = self.forward(cur_input)  # depth shape: [1, T, H, W]

            # Upsample each predicted depth map back to the original frame size.
            depth = F.interpolate(depth.flatten(0, 1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
            depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]

            pre_input = cur_input

        del frame_list
        gc.collect()
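
        # Alignment pass: windows after the first are mapped onto the reference
        # depth range with a per-window scale and shift fitted on the shared
        # keyframes, then the INTERP_LEN overlapping frames are cross-faded with
        # get_interpolate_frames to hide the window boundary.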
        depth_list_aligned = []
        ref_align = []
        align_len = OVERLAP - INTERP_LEN
        kf_align_list = KEYFRAMES[:align_len]

        for frame_id in range(0, len(depth_list), INFER_LEN):
            if len(depth_list_aligned) == 0:
                # First window: keep its depths as-is and remember its keyframe
                # depths as the alignment reference.
                depth_list_aligned += depth_list[:INFER_LEN]
                for kf_id in kf_align_list:
                    ref_align.append(depth_list[frame_id + kf_id])
            else:
                # Fit a scale/shift mapping this window's re-predictions of the
                # reference keyframes onto the reference depths.
                curr_align = []
                for i in range(len(kf_align_list)):
                    curr_align.append(depth_list[frame_id + i])
                scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
                                                       np.concatenate(ref_align),
                                                       np.concatenate(np.ones_like(ref_align) == 1))

                # Cross-fade the INTERP_LEN frames shared with the previous window.
                pre_depth_list = depth_list_aligned[-INTERP_LEN:]
                post_depth_list = depth_list[frame_id + align_len:frame_id + OVERLAP]
                for i in range(len(post_depth_list)):
                    post_depth_list[i] = post_depth_list[i] * scale + shift
                    post_depth_list[i][post_depth_list[i] < 0] = 0
                depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)

                # Append the aligned non-overlapping frames of this window.
                for i in range(OVERLAP, INFER_LEN):
                    new_depth = depth_list[frame_id + i] * scale + shift
                    new_depth[new_depth < 0] = 0
                    depth_list_aligned.append(new_depth)

                # Refresh the reference: keep the very first keyframe depth and
                # add this window's aligned keyframe depths for the next fit.
                ref_align = ref_align[:1]
                for kf_id in kf_align_list[1:]:
                    new_depth = depth_list[frame_id + kf_id] * scale + shift
                    new_depth[new_depth < 0] = 0
                    ref_align.append(new_depth)

        depth_list = depth_list_aligned

        return np.stack(depth_list[:org_video_len], axis=0), target_fps
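
# A minimal, commented-out usage sketch. The checkpoint path and the frame
# loading below are illustrative assumptions, not fixed by this file; only
# VideoDepthAnything and infer_video_depth come from this module.
#
#     import cv2
#     import numpy as np
#     import torch
#
#     model = VideoDepthAnything(encoder='vitl')
#     model.load_state_dict(torch.load('video_depth_anything_vitl.pth', map_location='cpu'))  # hypothetical path
#     model = model.to('cuda').eval()
#
#     cap = cv2.VideoCapture('input.mp4')
#     fps = cap.get(cv2.CAP_PROP_FPS)
#     frames = []
#     while True:
#         ok, frame = cap.read()
#         if not ok:
#             break
#         frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
#     cap.release()
#     frames = np.stack(frames, axis=0)  # [N, H, W, 3]
#
#     depths, out_fps = model.infer_video_depth(frames, target_fps=fps, device='cuda')
#     # depths: [N, H, W] depth maps at the original frame resolution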