import cv2
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import ConnectionPatch

from controller.controller import AttentionControl
from einops import repeat, rearrange
from typing import Tuple, Callable
from vidtome.patch import PCA_token
from utils.flow_utils import coords_grid


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def mps_gather_workaround(input, dim, index):
    # Workaround for a torch.gather issue on MPS devices when the trailing
    # dim has size 1: add a dummy dim, gather, then squeeze it back out.
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)
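

# A minimal sanity check for the workaround; this is a sketch added here, not
# part of the original pipeline. On CPU both paths should agree exactly.
def _check_mps_gather_workaround():
    x = torch.randn(2, 5, 1)
    idx = torch.randint(0, 5, (2, 3, 1))
    assert torch.equal(mps_gather_workaround(x, 1, idx), torch.gather(x, 1, idx))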


def visualize_flow_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor, flow: torch.Tensor, flow_confid: torch.Tensor,
                                  ratio: float, H: int = 64, out: str = "correspondence.png") -> None:
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H
    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Cosine similarity between src and dst tokens (computed for reference;
        # the selection below ranks tokens by flow confidence instead).
        a = src_img / src_img.norm(dim=-1, keepdim=True)
        b = tar_img / tar_img.norm(dim=-1, keepdim=True)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))
        print(f"[INFO] flow r {r}")

        # Pick the r most confident flow correspondences.
        flow_confid = rearrange(flow_confid, 'b h w -> b (h w)')
        edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens

        src_xy = [(idx.item() % W, idx.item() // W) for idx in src_idx[0]]
        grid = coords_grid(B, H, W).to(flow.device) + flow  # [B, 2, H, W]
        tar_xy = [(grid[0, 0, y, x].clamp(0, W - 1).item(),
                   grid[0, 1, y, x].clamp(0, H - 1).item()) for (x, y) in src_xy]

        fig, ax = plt.subplots(1, 2, figsize=(8, 3))
        # Display the source and target images
        ax[0].imshow(src_PCA_token, cmap='gray')
        ax[1].imshow(tar_PCA_token, cmap='gray')
        ax[0].axis('off')
        ax[1].axis('off')
        colors = cm.Greens(np.linspace(0.5, 1, len(src_xy)))
        # Draw lines connecting corresponding points
        for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
            ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5)  # marker in source image
            ax[1].plot(x2, y2, marker='o', color=color, markersize=1)    # marker in target image
            con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                                  axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
            ax[1].add_artist(con)
        plt.savefig(out, bbox_inches="tight")
        plt.close()
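

# Hedged usage sketch (all shapes below are assumptions, not requirements of
# this file): feed two [B, C, H, W] feature maps plus a dense flow field
# [B, 2, H, W] and a per-pixel confidence map [B, H, W]; the figure is written
# to `out`.
def _demo_visualize_flow_correspondence():
    feat_src = torch.randn(1, 8, 32, 32)
    feat_tar = torch.randn(1, 8, 32, 32)
    flow = torch.zeros(1, 2, 32, 32)  # identity flow
    confid = torch.rand(1, 32, 32)
    visualize_flow_correspondence(feat_src, feat_tar, flow, confid,
                                  ratio=0.1, H=32, out="correspondence.png")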


def visualize_correspondence_score(src_img: torch.Tensor, tar_img: torch.Tensor, score: torch.Tensor,
                                   ratio: float, H: int = 64, out: str = "correspondence_idx.png") -> None:
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
        N = H * W
    else:
        B, N, C = src_img.shape
        W = N // H
    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    with torch.no_grad():
        # Can't reduce more than the # tokens in src
        r = min(N, int(N * ratio))
        node_max, node_idx = score.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[0, :r, 0]  # Merged Tokens
        tar_idx = torch.gather(node_idx[0], dim=0, index=src_idx)

        src_xy = [(idx.item() % W, idx.item() // W) for idx in src_idx]
        tar_xy = [(idx.item() % W, idx.item() // W) for idx in tar_idx]

        fig, ax = plt.subplots(1, 2, figsize=(8, 3))
        # Display the source and target images
        ax[0].imshow(src_PCA_token, cmap='gray')
        ax[1].imshow(tar_PCA_token, cmap='gray')
        colors = cm.cool(np.linspace(0, 1, len(src_xy)))
        # Draw lines connecting corresponding points
        for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
            ax[0].plot(x1, y1, marker='o', color=color, markersize=1)  # marker in source image
            ax[1].plot(x2, y2, marker='o', color=color, markersize=1)  # marker in target image
            con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                                  axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
            ax[1].add_artist(con)
        plt.savefig(out, bbox_inches="tight")
        plt.close()


def visualize_cosine_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor,
                                    ratio: float, H: int = 64, out: str = "correspondence.png",
                                    flow: torch.Tensor = None, flow_confid: torch.Tensor = None,
                                    controller: AttentionControl = None) -> None:
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H
    src_PCA_token = PCA_token(src_img, token_h=H)
    tar_PCA_token = PCA_token(tar_img, token_h=H)

    gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Cosine similarity between src and dst tokens
        a = src_img / src_img.norm(dim=-1, keepdim=True)
        b = tar_img / tar_img.norm(dim=-1, keepdim=True)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))
        print(f"[INFO] cosine r {r}")

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        # Note: visualizes the 4r..5r band of the similarity ranking rather
        # than the top r matches.
        src_idx = edge_idx[..., int(4 * r):int(5 * r), :]  # Merged Tokens
        tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

        src_xy = [(idx.item() % W, idx.item() // W) for idx in src_idx[0]]
        tar_xy = [(idx.item() % W, idx.item() // W) for idx in tar_idx[0]]

        fig, ax = plt.subplots(1, 2, figsize=(8, 3))
        # Display the source and target images
        ax[0].imshow(src_PCA_token, cmap='spring')
        ax[1].imshow(tar_PCA_token, cmap='spring')
        # Hide the axis labels
        ax[0].axis('off')
        ax[1].axis('off')
        colors = cm.cool(np.linspace(0.5, 1, len(src_xy)))
        # Draw lines connecting corresponding points
        for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
            ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5)  # marker in source image
            ax[1].plot(x2, y2, marker='o', color=color, markersize=1)    # marker in target image
            con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                                  axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2)
            ax[1].add_artist(con)
        plt.savefig(out, bbox_inches="tight")
        plt.close()


def visualize_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor,
                             ratio: float, H: int = 64, out: str = "correspondence.png",
                             flow: torch.Tensor = None, flow_confid: torch.Tensor = None,
                             controller: AttentionControl = None) -> None:
    if len(src_img.shape) == 4:
        B, C, H, W = src_img.shape
        src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H)
        tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H)
    else:
        B, N, C = src_img.shape
        W = N // H
    src_PCA_token = PCA_token(src_img, token_h=H, n=1)
    tar_PCA_token = PCA_token(tar_img, token_h=H, n=1)

    # Invert one PCA map when the two differ strongly in brightness, so both
    # panels share a comparable contrast.
    if abs(np.mean(src_PCA_token[:20, :20]) - np.mean(tar_PCA_token[:20, :20])) > 50:
        if np.mean(src_PCA_token[:20, :20]) > np.mean(tar_PCA_token[:20, :20]):
            src_PCA_token = 255 - src_PCA_token
        else:
            tar_PCA_token = 255 - tar_PCA_token
    print(f"[INFO] src_PCA_token mean {np.mean(src_PCA_token[:20, :20])} tar_PCA_token mean {np.mean(tar_PCA_token[:20, :20])}")

    gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Cosine similarity between src and dst tokens
        a = src_img / src_img.norm(dim=-1, keepdim=True)
        b = tar_img / tar_img.norm(dim=-1, keepdim=True)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))

        # Find the most similar greedily
        print(f"[INFO] no distance weighted ...")
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
        src_xy = [(idx.item() % W, idx.item() // W) for idx in src_idx[0]]
        tar_xy = [(idx.item() % W, idx.item() // W) for idx in tar_idx[0]]

        # Rank tokens by flow confidence for the flow-based correspondences.
        flow_confid = rearrange(flow_confid, 'b h w -> b (h w)')
        edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        flow_src_xy = [(idx.item() % W, idx.item() // W) for idx in src_idx[0]]
        grid = coords_grid(B, H, W).to(flow.device) + flow  # [B, 2, H, W]
        flow_tar_xy = [(grid[0, 0, y, x].clamp(0, W - 1).item(),
                        grid[0, 1, y, x].clamp(0, H - 1).item()) for (x, y) in flow_src_xy]

        fig, ax = plt.subplots(2, 2, figsize=(8, 4))
        if len(controller.decoded_imgs):
            step = out.split("/")[-1].split(".")[0]
            _, h_, w_, _ = controller.decoded_imgs[0].shape
            mul = h_ // H
            decoded_img = controller.decoded_imgs[1]
            decoded_img = decoded_img[0, :, :int(W * mul), :]
            if step == "49":
                decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00001.png")
                decoded_img = cv2.resize(decoded_img, (W, H))
            ax[0, 0].imshow(decoded_img, aspect='auto')
            decoded_img = controller.decoded_imgs[2]
            decoded_img = decoded_img[0, :, :int(W * mul), :]
            if step == "49":
                decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00002.png")
                decoded_img = cv2.resize(decoded_img, (W, H))
            ax[0, 1].imshow(decoded_img, aspect='auto')
        else:
            # Display the source and target images
            ax[0, 0].imshow(src_PCA_token, cmap='ocean', aspect='auto')
            ax[0, 1].imshow(tar_PCA_token, cmap='ocean', aspect='auto')
        ax[0, 0].axis('off')
        ax[0, 1].axis('off')
        ax[1, 0].imshow(src_PCA_token, cmap='Blues', aspect='auto')
        ax[1, 1].imshow(tar_PCA_token, cmap='Blues', aspect='auto')
        # Hide the axis labels
        ax[1, 0].axis('off')
        ax[1, 1].axis('off')

        # Flow correspondences in green.
        colors = cm.Greens(np.linspace(0.25, 0.75, len(flow_src_xy)))
        for (x1, y1), (x2, y2), color in zip(flow_src_xy, flow_tar_xy, colors):
            ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1)
            con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                                  axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2)
            ax[1, 1].add_artist(con)

        # Cosine correspondences in red.
        colors = cm.Reds(np.linspace(0.25, 0.75, len(src_xy)))
        for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors):
            ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1)
            con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data",
                                  axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2)
            ax[1, 1].add_artist(con)

        plt.subplots_adjust(wspace=0.05, hspace=0.1)
        plt.savefig(out, bbox_inches="tight")
        plt.close()


# For Local Token Merging
def bipartite_soft_matching_randframe(metric: torch.Tensor,
                                      F: int, ratio: float, unm_pre: int, generator: torch.Generator = None,
                                      target_stride: int = 4, align_batch: bool = False,
                                      merge_mode: str = "replace", H: int = 64,
                                      flow_merge: bool = False,
                                      controller: AttentionControl = None) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the multi-frame tokens into src and dst and merges ratio of src tokens from src to dst.
    Dst tokens are partitioned by choosing one random frame.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - F: frame number.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...]
        - generator: random number generator.
        - target_stride: stride of target frame.
        - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.

    Returns:
        Merge and unmerge operation according to the matching result. Return a dict including other values.
    """
    B, N, _ = metric.shape
    A = N // F
    W = A // H
    # Compute per-frame token number. N = unm_pre + tnum * F.
    tnum = (N - unm_pre) // F

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": tnum}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Prepare idx buffer. Ignore previous unmerged tokens.
        idx_buffer = torch.arange(
            N - unm_pre, device=metric.device, dtype=torch.int64)

        # Select the random target frame.
        target_stride = min(target_stride, F)
        if controller is None:
            randf = torch.randint(0, target_stride, torch.Size(
                [1]), generator=generator, device=generator.device)
        else:
            randf = torch.tensor(target_stride // 2).to(metric.device)
        dst_select = ((torch.div(idx_buffer, tnum, rounding_mode='floor')) %
                      target_stride == randf).to(torch.bool)

        # a_idx: src index. b_idx: dst index
        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Add unmerged tokens to dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[
            None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        # We're finished with these
        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos).
            # Find the most similar greedily among all samples.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            tar_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst  # Map index to (0, num_dst - 1)

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            tar_idx = tar_idx.expand(B, -1, -1)
        else:
            if flow_merge:
                # Use flow correspondences precomputed by the controller.
                unm_idx = controller.flow_correspondence[H][0][:, r:, :]
                src_idx = controller.flow_correspondence[H][0][:, :r, :]
                tar_idx = controller.flow_correspondence[H][1][:, :r, :]
            else:
                # Distance-weighted similarity: down-weight matches between
                # spatially distant tokens.
                if controller is not None:
                    if H not in controller.distances:
                        controller.set_distance(F - 1, H, W, W // 30, metric.device)
                        print(f"[INFO] distance weighted ...")
                    scores *= controller.distances[H]

                # Find the most similar greedily.
                # node_idx maps each src token to its best dst token.
                node_max, node_idx = scores.max(dim=-1)
                edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
                unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
                # Index within all src tokens.
                src_idx = edge_idx[..., :r, :]  # Merged Tokens
                tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        # Merge tokens according to matching result.
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx
        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))

        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            # In other modes such as mean, combine matched src and dst tokens.
            dst = dst.scatter_reduce(-2, t_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)
        # In replace mode, just cat unmerged tokens and dst tokens. Ignore src tokens.
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        # Unmerge tokens to original size according to matching result.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx

        # Restored src tokens take value from dst tokens
        src = gather(dst, dim=-2, index=t_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        # Scatter dst tokens
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        # Scatter unmerged tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        # Scatter src tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)
        return out

    # Return number of tokens not merged.
    ret_dict = {"scores": scores, "unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
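

# A hedged usage sketch (all shapes are assumptions): merge half of the src
# tokens across 4 frames of a 16x16 token grid, then unmerge back.
def _demo_randframe():
    g = torch.Generator(device="cpu").manual_seed(0)
    F_, H_, W_ = 4, 16, 16
    x = torch.randn(1, F_ * H_ * W_, 8)  # [B, N, C] token features
    merge, unmerge, info = bipartite_soft_matching_randframe(
        x, F=F_, ratio=0.5, unm_pre=0, generator=g, H=H_)
    merged = merge(x)            # [1, 640, 8]: unmerged src + dst tokens
    restored = unmerge(merged)   # back to [1, 1024, 8]
    assert restored.shape == x.shape
    print("unmerged src tokens:", info["unm_num"])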


def bipartite_soft_matching_random2d_hier(metric: torch.Tensor, frame_num: int, ratio: float, unm_pre: int,
                                          generator: torch.Generator, target_stride: int = 4,
                                          adhere_src: bool = False, merge_mode: str = "replace",
                                          scores=None, coord=None, rec_field=2) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the multi-frame tokens into src and dst and merges ratio of src tokens from src to dst.
    Dst tokens are partitioned by choosing one random frame among the first target_stride frames.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - frame_num: frame number.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...]
        - generator: random number generator.
        - target_stride: stride of target frame.
        - adhere_src: whether to use the same matching result for all samples in the batch.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - coord [B, N, 2]: token coordinates; when given, matches farther than rec_field are masked out.
        - rec_field: receptive field radius used together with coord.
    """
    B, N, _ = metric.shape
    F = frame_num
    nf = (N - unm_pre) // F

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": nf}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N - unm_pre, device=metric.device, dtype=torch.int64)

        # Select the random target frame.
        max_f = min(target_stride, F)
        randn = torch.randint(0, max_f, torch.Size([1]), generator=generator, device=generator.device)
        dst_select = ((torch.div(idx_buffer, nf, rounding_mode='floor')) % max_f == randn).to(torch.bool)
        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Add unmerged tokens to dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        # We're finished with these
        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            scores[mask] = 0

        # Can't reduce more than the # tokens in src
        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None, **kwarg) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None, **kwarg) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)

        # Restored src tokens take value from dst tokens
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)
        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict


# For Global Token Merging.
def bipartite_soft_matching_2s(metric: torch.Tensor,
                               src_len: int, ratio: float, align_batch: bool,
                               merge_mode: str = "replace", unmerge_chunk: int = 0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges ratio of src tokens from src to dst.
    Src tokens are partitioned as first src_len tokens. Others are dst tokens.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging).
        - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - unmerge_chunk: return which partition in unmerge. 0 for src and 1 for dst.

    Returns:
        Merge and unmerge operation according to the matching result. Return a dict including other values.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # [ src | dst ]: [ src_len | N - src_len ]
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos).
            # Find the most similar greedily among all samples.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst  # Map index to (0, num_dst - 1)

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        # Merge tokens according to matching result.
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))

        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            # In other modes such as mean, combine matched src and dst tokens.
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)
        # In replace mode, just cat unmerged tokens and dst tokens. Discard src tokens.
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        # Unmerge tokens to original size according to matching result.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        # Restored src tokens take value from dst tokens
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        # Scatter dst tokens
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        # Scatter unmerged tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        # Scatter src tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)
        # Return only the requested chunk.
        out = out[:, :src_len, :] if unmerge_chunk == 0 else out[:, src_len:, :]
        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
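

# Hedged usage sketch (shapes are assumptions): globally merge a 256-token src
# chunk into a 256-token dst chunk, then unmerge and recover the src chunk.
def _demo_2s():
    x = torch.randn(1, 512, 8)  # [ src | dst ] = [ 256 | 256 ]
    merge, unmerge, info = bipartite_soft_matching_2s(
        x, src_len=256, ratio=0.5, align_batch=False)
    merged = merge(x)               # [1, 384, 8]: 128 unmerged src + 256 dst
    src_restored = unmerge(merged)  # [1, 256, 8] since unmerge_chunk == 0
    assert src_restored.shape == (1, 256, 8)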


# Original ToMe
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging)
        - no_rand: if true, disable randomness (use top left corner only)
        - generator: if no_rand is false, and if not None, sets the random seed
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(
                hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(
                sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(
            hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(
            dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(
            hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(
                h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
        return out

    return merge, unmerge
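

# Hedged usage sketch (shapes are assumptions): classic ToMe on a 32x32 token
# grid, merging 256 tokens with the deterministic (top-left dst) layout.
def _demo_random2d():
    x = torch.randn(1, 32 * 32, 8)
    merge, unmerge = bipartite_soft_matching_random2d(
        x, w=32, h=32, sx=2, sy=2, r=256, no_rand=True)
    merged = merge(x, mode="mean")  # [1, 768, 8]
    restored = unmerge(merged)      # [1, 1024, 8]
    assert restored.shape == x.shape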


def bipartite_soft_matching_2f(metric: torch.Tensor, src_len: int, ratio: float, adhere_src: bool,
                               merge_mode: str = "replace", scores=None, coord=None, rec_field=2,
                               unmerge_chunk=0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges ratio of src tokens from src to dst.
    Src tokens are partitioned as first src_len tokens. Others are dst tokens.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging).
        - adhere_src: whether to use the same matching result for all samples in the batch.
        - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token.
        - coord [B, N, 2]: token coordinates; when given, matches farther than rec_field are masked out.
        - rec_field: receptive field radius used together with coord.
        - unmerge_chunk: return which partition in unmerge. 0 for src and 1 for dst.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # [ src | dst ]: [ src_len | N - src_len ]
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        # We're finished with these
        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            scores[mask] = 0

        # Can't reduce more than the # tokens in src
        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            # Gather the matched src tokens before reducing them into dst.
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)

        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)

        # Return only the requested chunk.
        if unmerge_chunk == 0:
            out = out[:, :src_len, :]
        else:
            out = out[:, src_len:, :]
        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
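

# Hedged usage sketch (shapes are assumptions): like the 2s variant, but with
# coordinate masking so src tokens can only merge into spatially nearby dst
# tokens (two stacked 16x16 grids here).
def _demo_2f():
    x = torch.randn(1, 512, 8)  # [ src | dst ] = [ 256 | 256 ]
    ys, xs = torch.meshgrid(torch.arange(16), torch.arange(16), indexing="ij")
    grid = torch.stack((ys, xs), dim=-1).reshape(1, 256, 2).float()
    coord = torch.cat([grid, grid], dim=1)  # same layout for both chunks
    merge, unmerge, info = bipartite_soft_matching_2f(
        x, src_len=256, ratio=0.5, adhere_src=False, coord=coord, rec_field=4)
    merged = merge(x)               # [1, 384, 8]
    src_restored = unmerge(merged)  # [1, 256, 8] since unmerge_chunk == 0
    assert src_restored.shape == (1, 256, 8)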