Spaces:
Build error
Build error
| import os | |
| import sys | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import einops | |
| from PIL import Image | |
| parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| gmflow_dir = os.path.join(parent_dir, 'deps/gmflow') | |
| sys.path.insert(0, gmflow_dir) | |
| from GMFlow.gmflow.gmflow import GMFlow # noqa: E702 E402 F401 | |
| from GMFlow.utils.utils import InputPadder # noqa: E702 E402 | |
| def coords_grid(b, h, w, homogeneous=False, device=None): | |
| y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] | |
| stacks = [x, y] | |
| if homogeneous: | |
| ones = torch.ones_like(x) # [H, W] | |
| stacks.append(ones) | |
| grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] | |
| grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] | |
| if device is not None: | |
| grid = grid.to(device) | |
| return grid | |
| def bilinear_sample(img, | |
| sample_coords, | |
| mode='bilinear', | |
| padding_mode='zeros', | |
| return_mask=False): | |
| # img: [B, C, H, W] | |
| # sample_coords: [B, 2, H, W] in image scale | |
| if sample_coords.size(1) != 2: # [B, H, W, 2] | |
| sample_coords = sample_coords.permute(0, 3, 1, 2) | |
| b, _, h, w = sample_coords.shape | |
| # Normalize to [-1, 1] | |
| x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 | |
| y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 | |
| grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] | |
| img = F.grid_sample(img, | |
| grid, | |
| mode=mode, | |
| padding_mode=padding_mode, | |
| align_corners=True) | |
| if return_mask: | |
| mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & ( | |
| y_grid <= 1) # [B, H, W] | |
| return img, mask | |
| return img | |
| def flow_warp(feature, | |
| flow, | |
| mask=False, | |
| mode='bilinear', | |
| padding_mode='zeros'): | |
| b, c, h, w = feature.size() | |
| assert flow.size(1) == 2 | |
| grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] | |
| return bilinear_sample(feature, | |
| grid, | |
| mode=mode, | |
| padding_mode=padding_mode, | |
| return_mask=mask) | |
| def forward_backward_consistency_check(fwd_flow, | |
| bwd_flow, | |
| alpha=0.01, | |
| beta=0.5, | |
| return_confidence=False): | |
| # fwd_flow, bwd_flow: [B, 2, H, W] | |
| # alpha and beta values are following UnFlow | |
| # (https://arxiv.org/abs/1711.07837) | |
| assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 | |
| assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 | |
| flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, | |
| dim=1) # [B, H, W] | |
| warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] | |
| warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] | |
| diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] | |
| diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) | |
| threshold = alpha * flow_mag + beta | |
| if return_confidence: | |
| # fwd_occ = diff_fwd | |
| # bwd_occ = diff_bwd | |
| fwd_occ = torch.exp(-diff_fwd) | |
| bwd_occ = torch.exp(-diff_bwd) | |
| # import ipdb; ipdb.set_trace() | |
| # Image.fromarray((bwd_occ * 255)[0,:,:].cpu().numpy().clip(0, 255).astype(np.uint8)).save("mask.png") | |
| else: | |
| fwd_occ = (diff_fwd > threshold).float() # [B, H, W] | |
| bwd_occ = (diff_bwd > threshold).float() | |
| return fwd_occ, bwd_occ | |
| def get_warped_and_mask(flow_model, | |
| image1, | |
| image2, | |
| image3=None, | |
| pixel_consistency=False, | |
| return_confidence=False): | |
| if image3 is None: | |
| image3 = image1[None] | |
| padder = InputPadder(image1.shape, padding_factor=16) | |
| # image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) | |
| image1, image2 = padder.pad(image1[None], image2[None]) | |
| results_dict = flow_model(image1, | |
| image2, | |
| attn_splits_list=[2], | |
| corr_radius_list=[-1], | |
| prop_radius_list=[-1], | |
| pred_bidir_flow=True) | |
| flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] | |
| fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] | |
| bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] | |
| # results_dict_ = flow_model(image2, | |
| # image1, | |
| # attn_splits_list=[2], | |
| # corr_radius_list=[-1], | |
| # prop_radius_list=[-1], | |
| # pred_bidir_flow=True) | |
| # flow_pr = results_dict_['flow_preds'][-1] # [B, 2, H, W] | |
| # fwd_flow_ = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] | |
| # bwd_flow_ = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] | |
| # fwd_occ_, bwd_occ_ = forward_backward_consistency_check( | |
| # fwd_flow_, bwd_flow_, return_error=True) # [1, H, W] float | |
| fwd_occ, bwd_occ = forward_backward_consistency_check( | |
| fwd_flow, bwd_flow) # [1, H, W] float | |
| if pixel_consistency: | |
| warped_image1 = flow_warp(image1, padder.pad(bwd_flow)[0]) | |
| bwd_occ = torch.clamp( | |
| padder.pad(bwd_occ)[0] + | |
| (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, | |
| 1) | |
| warped_results = flow_warp(image3, bwd_flow) | |
| if return_confidence: | |
| fwd_err, bwd_err = forward_backward_consistency_check( | |
| fwd_flow, bwd_flow, return_confidence=return_confidence) # [1, H, W] float | |
| return warped_results, bwd_occ, bwd_flow, bwd_err | |
| return warped_results, bwd_occ, bwd_flow | |
| class FlowCalc(): | |
| def __init__(self, model_path='./weights/gmflow_sintel-0c07dcb3.pth'): | |
| flow_model = GMFlow( | |
| feature_channels=128, | |
| num_scales=1, | |
| upsample_factor=8, | |
| num_head=1, | |
| attention_type='swin', | |
| ffn_dim_expansion=4, | |
| num_transformer_layers=6, | |
| ).to('cuda') | |
| checkpoint = torch.load(model_path, | |
| map_location=lambda storage, loc: storage) | |
| weights = checkpoint['model'] if 'model' in checkpoint else checkpoint | |
| flow_model.load_state_dict(weights, strict=False) | |
| flow_model.eval() | |
| self.model = flow_model | |
| def get_flow(self, image1, image2, save_path=None): | |
| if save_path is not None and os.path.exists(save_path): | |
| bwd_flow = read_flow(save_path) | |
| return bwd_flow | |
| image1 = torch.from_numpy(image1).permute(2, 0, 1).float() | |
| image2 = torch.from_numpy(image2).permute(2, 0, 1).float() | |
| padder = InputPadder(image1.shape, padding_factor=8) | |
| image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) | |
| results_dict = self.model(image1, | |
| image2, | |
| attn_splits_list=[2], | |
| corr_radius_list=[-1], | |
| prop_radius_list=[-1], | |
| pred_bidir_flow=True) | |
| flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] | |
| fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] | |
| bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] | |
| fwd_occ, bwd_occ = forward_backward_consistency_check( | |
| fwd_flow, bwd_flow) # [1, H, W] float | |
| if save_path is not None: | |
| flow_np = bwd_flow.cpu().numpy() | |
| np.save(save_path, flow_np) | |
| mask_path = os.path.splitext(save_path)[0] + '.png' | |
| bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( | |
| torch.long).numpy() * 255 | |
| cv2.imwrite(mask_path, bwd_occ) | |
| return bwd_flow | |
| def get_mask(self, image1, image2, save_path=None): | |
| if save_path is not None: | |
| mask_path = os.path.splitext(save_path)[0] + '.png' | |
| if os.path.exists(mask_path): | |
| return read_mask(mask_path) | |
| image1 = torch.from_numpy(image1).permute(2, 0, 1).float() | |
| image2 = torch.from_numpy(image2).permute(2, 0, 1).float() | |
| padder = InputPadder(image1.shape, padding_factor=8) | |
| image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) | |
| results_dict = self.model(image1, | |
| image2, | |
| attn_splits_list=[2], | |
| corr_radius_list=[-1], | |
| prop_radius_list=[-1], | |
| pred_bidir_flow=True) | |
| flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] | |
| fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] | |
| bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] | |
| fwd_occ, bwd_occ = forward_backward_consistency_check( | |
| fwd_flow, bwd_flow) # [1, H, W] float | |
| if save_path is not None: | |
| flow_np = bwd_flow.cpu().numpy() | |
| np.save(save_path, flow_np) | |
| mask_path = os.path.splitext(save_path)[0] + '.png' | |
| bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( | |
| torch.long).numpy() * 255 | |
| cv2.imwrite(mask_path, bwd_occ) | |
| return bwd_occ | |
| def warp(self, img, flow, mode='bilinear'): | |
| expand = False | |
| if len(img.shape) == 2: | |
| expand = True | |
| img = np.expand_dims(img, 2) | |
| img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) | |
| dtype = img.dtype | |
| img = img.to(torch.float) | |
| res = flow_warp(img, flow, mode=mode) | |
| res = res.to(dtype) | |
| res = res[0].cpu().permute(1, 2, 0).numpy() | |
| if expand: | |
| res = res[:, :, 0] | |
| return res | |
| def read_flow(save_path): | |
| flow_np = np.load(save_path) | |
| bwd_flow = torch.from_numpy(flow_np) | |
| return bwd_flow | |
| def read_mask(save_path): | |
| mask_path = os.path.splitext(save_path)[0] + '.png' | |
| mask = cv2.imread(mask_path) | |
| mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| return mask | |
| flow_calc = FlowCalc() | |