Spaces:
Build error
Build error
| import gc | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import repeat, rearrange | |
| from vidtome import merge | |
| from utils.flow_utils import flow_warp, coords_grid | |
| # AdaIn | |
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    ``feat`` must be a 4-D tensor; the first two axes are kept and all
    remaining elements are reduced.  ``eps`` is added to the variance
    before the square root so the result is never exactly zero.

    Returns:
        ``(mean, std)``, each reshaped to ``(N, C, 1, 1)``.
    """
    shape = feat.size()
    assert len(shape) == 4
    batch, channels = shape[0], shape[1]
    # Flatten everything after the first two axes once and reuse the view.
    flat = feat.view(batch, channels, -1)
    feat_mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    return feat_mean, feat_std
class AttentionControl():
    """Stateful hook that caches/re-injects attention context and edits latents.

    The object is called with an attention ``context`` tensor (see
    ``forward``) and, depending on the task flags set by ``set_task``,
    records it, replaces it with a previously recorded one, or both.
    It also offers latent-space helpers (``update_x0``, ``merge_x0``,
    ``merge_x0_scores``) that use optical-flow data registered through
    ``set_warp`` / ``set_flow_correspondence``.

    All ``*_period`` arguments are ``(start, end)`` pairs that are
    multiplied by ``total_step`` to obtain a step window.
    """

    def __init__(self,
                 warp_period=(0.0, 0.0),
                 merge_period=(0.0, 0.0),
                 merge_ratio=(0.3, 0.3),
                 ToMe_period=(0.0, 1.0),
                 mask_period=(0.0, 0.0),
                 cross_period=(0.0, 0.0),
                 ada_period=(0.0, 0.0),
                 inner_strength=1.0,
                 loose_cfatnn=False,
                 flow_merge=True,
                 ):
        """Store the configuration and reset all runtime state.

        ``loose_cfatnn`` only changes ``up_resolution`` (1280 vs 1281),
        the exclusive upper bound on ``context.shape[2]`` used in
        ``forward``.
        """
        # Runtime counters / task flags.
        self.cur_frame_idx = 0
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0
        self.init_store = False
        self.restore = False
        self.update = False
        self.restorex0 = True
        self.updatex0 = False
        # Previously only created by set_task(); initialise it here so the
        # attribute always exists.
        self.restore_step = 1.0

        self.flow = None
        self.mask = None
        self.cldm = None
        self.decoded_imgs = []

        # Configuration.
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period
        self.ToMe_period = ToMe_period
        self.merge_period = merge_period
        self.merge_ratio = merge_ratio
        self.keyframe_idx = 0
        self.flow_merge = flow_merge

        # Caches keyed by latent height H.
        self.distances = {}
        self.flow_correspondence = {}
        self.non_pad_ratio = (1.0, 1.0)
        self.up_resolution = 1280 if loose_cfatnn else 1281

    @staticmethod
    def get_empty_store():
        """Return a fresh, empty state dictionary.

        Declared as a static method: it takes no ``self`` but is invoked
        as ``self.get_empty_store()``, which would otherwise raise
        ``TypeError``.
        """
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': [],
            'pre_x0': [],
            "pre_keyframe_lq": None,
            "flows": None,
            "occ_masks": None,
            "flow_confids": None,
            "merge": None,
            "unmerge": None,
            "corres_scores": None,
            "flows2": None,
            "flow_confids2": None,
        }

    def _in_period(self, period):
        """Return True when ``cur_step`` lies in ``[start, int(end))`` of ``period``.

        ``period`` is a ``(start, end)`` pair scaled by ``total_step``;
        only the upper bound is truncated to an integer.
        """
        return (self.cur_step >= self.total_step * period[0]
                and self.cur_step < int(self.total_step * period[1]))

    def _non_pad_index(self, H, W, device):
        """Build flat indices of the non-padded pixels of an ``H x W`` grid.

        Padding is taken from the right and bottom edges according to
        ``self.non_pad_ratio``.

        Returns:
            ``(index, pad_h, pad_w)`` where ``index`` has shape
            ``(1, A, 1)`` and holds row-major pixel indices.
        """
        ratio_h, ratio_w = self.non_pad_ratio
        pad_w = W - int(W * ratio_w)
        pad_h = H - int(H * ratio_h)
        padding_mask = torch.zeros((H, W), device=device, dtype=torch.bool)
        # Guard against a zero pad size: ``[-0:]`` would select everything.
        if pad_w:
            padding_mask[:, -pad_w:] = 1
        if pad_h:
            padding_mask[-pad_h:, :] = 1
        flat_mask = padding_mask.reshape(-1)
        idx_buffer = torch.arange(H * W, device=device, dtype=torch.int64)
        return idx_buffer[None, ~flat_mask, None], pad_h, pad_w

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Record and/or replace a self-attention context tensor.

        Only acts when ``is_cross`` is false, ``place_in_unet == 'up'``
        and ``context.shape[2] < self.up_resolution``; every other call
        returns ``context`` untouched and does not advance ``cur_index``.
        """
        # NOTE(review): nesting reconstructed from a flattened listing --
        # all bookkeeping below is assumed to live inside this condition.
        if (not is_cross and place_in_unet == 'up'
                and context.shape[2] < self.up_resolution):
            cross_start = self.total_step * self.cross_period[0]
            cross_end = self.total_step * self.cross_period[1]

            if self.init_store:
                self.step_store['first'].append(context.detach())
                self.step_store['previous'].append(context.detach())
            if self.update:
                # Snapshot the incoming context before it may be replaced.
                tmp = context.clone().detach()
            if self.restore and cross_start <= self.cur_step <= cross_end:
                context = self.step_store['previous'][self.cur_index].clone()
            if self.update:
                self.step_store['previous'][self.cur_index] = tmp
            self.cur_index += 1
        return context

    def update_x0(self, x0, cur_frame=0):
        """Warp latents with the stored flows during the warp window.

        Active only when ``restorex0`` is set and ``cur_step`` is inside
        ``warp_period``.  The middle element of ``x0`` (index
        ``x0.shape[0] // 2``) is optionally blended with the latent kept
        from the previous call for the same step, then every other
        element is blended with the warped middle element using
        ``step_store['occ_masks']``.  ``x0`` is modified in place and
        returned.  ``cur_frame`` is unused.
        """
        if not (self.restorex0 and self._in_period(self.warp_period)):
            return x0

        flows = self.step_store["flows"]
        occ_masks = self.step_store["occ_masks"]
        num_warp_steps = int(self.total_step * self.warp_period[1])
        mid = x0.shape[0] // 2

        # A full 'pre_x0' list means one latent per warp step was already
        # stored, so it can be warped onto the current middle element.
        if len(self.step_store["pre_x0"]) == num_warp_steps:
            print(f"[INFO] keyframe latent warping @ step {self.cur_step}...")
            x0[mid] = (1 - occ_masks[mid]) * x0[mid] + \
                flow_warp(self.step_store["pre_x0"][self.cur_step][None],
                          flows[mid], mode='nearest')[0] * occ_masks[mid]

        # NOTE(review): assumed to run on every call inside the warp
        # window (not only when the keyframe branch above fires).
        print(f"[INFO] local latent warping @ step {self.cur_step}...")
        for i in range(x0.shape[0]):
            if i == mid:
                continue
            x0[i] = (1 - occ_masks[i]) * x0[i] + \
                flow_warp(x0[mid][None], flows[i],
                          mode='nearest')[0] * occ_masks[i]

        # Fill 'pre_x0' on the first pass, overwrite per step afterwards.
        # NOTE(review): must sit inside the warp window because ``mid`` is
        # only defined there.
        if len(self.step_store["pre_x0"]) < num_warp_steps:
            self.step_store['pre_x0'].append(x0[mid])
        else:
            self.step_store['pre_x0'][self.cur_step] = x0[mid]
        return x0

    def merge_x0(self, x0, merge_ratio):
        """Merge and unmerge latent tokens across the batch.

        Outside ``merge_period`` the input is returned unchanged.
        Otherwise the non-padded pixels of ``x0`` (shape ``B, C, H, W``)
        are flattened into one token sequence, passed through the
        merge/unmerge pair returned by
        ``merge.bipartite_soft_matching_randframe`` and written back.
        """
        import copy

        if not self._in_period(self.merge_period):
            return x0

        print(f"[INFO] latent merging @ step {self.cur_step}...")
        B, C, H, W = x0.shape
        non_pad_idx, pad_h, pad_w = self._non_pad_index(H, W, x0.device)
        non_pad_h = H - pad_h
        non_pad_w = W - pad_w
        gather_idx = non_pad_idx.expand(B, -1, C)

        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=gather_idx)

        # Work on a copy so the stored flows keep their full size.
        flows = copy.deepcopy(self.step_store["flows"])
        for i in range(B):
            if flows[i] is not None:
                flows[i] = flows[i][:, :, :non_pad_h, :non_pad_w]

        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')
        m, u, _ = merge.bipartite_soft_matching_randframe(
            x_non_pad, B, merge_ratio, 0, target_stride=B,
            H=H,
            flow=flows,
            flow_confid=self.step_store["flow_confids"],
        )
        x_non_pad = u(m(x_non_pad))
        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', b=B)

        x0.scatter_(dim=1, index=gather_idx, src=x_non_pad)
        x0 = rearrange(x0, 'b (h w) c -> b c h w ', h=H)
        return x0

    def merge_x0_scores(self, x0, merge_ratio, merge_mode="replace"):
        """Replace source-frame tokens by their best match in the middle frame.

        Outside ``merge_period`` the input is returned unchanged.
        Tokens of the middle batch element (``B // 2``) are the targets,
        all other tokens are sources.  Matching uses
        ``step_store['corres_scores']`` (L2-normalised), boosted at the
        flow correspondences stored in ``self.flow_correspondence[H]``.
        The ``int(n_src * merge_ratio)`` best-scoring source tokens are
        overwritten with their matched target token.  With a
        ``merge_mode`` other than ``"replace"`` the targets are first
        reduced with the matched sources via ``scatter_reduce``.
        """
        if not self._in_period(self.merge_period):
            return x0

        print(f"[INFO] latent merging @ step {self.cur_step}...")
        B, C, H, W = x0.shape
        non_pad_idx, _, _ = self._non_pad_index(H, W, x0.device)
        gather_idx = non_pad_idx.expand(B, -1, C)

        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=gather_idx)
        tokens_per_frame = x_non_pad.shape[1]
        total_tokens = tokens_per_frame * B
        mid = B // 2
        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')

        # Split the flat token axis into target (middle frame) and source.
        idx_buffer = torch.arange(total_tokens, device=x0.device,
                                  dtype=torch.int64)
        frame_of_token = torch.div(idx_buffer, tokens_per_frame,
                                   rounding_mode='floor') % B
        dst_select = frame_of_token == mid
        a_idx = idx_buffer[None, ~dst_select, None]  # source indices
        b_idx = idx_buffer[None, dst_select, None]   # target indices
        num_dst = b_idx.shape[1]

        b = 1
        src = torch.gather(
            x_non_pad, dim=1,
            index=a_idx.expand(b, total_tokens - num_dst, C))
        tar = torch.gather(
            x_non_pad, dim=1, index=b_idx.expand(b, num_dst, C))

        flow_src_idx = self.flow_correspondence[H][0]
        flow_tar_idx = self.flow_correspondence[H][1]
        confid_list = self.step_store["flow_confids"][:mid] + \
            self.step_store["flow_confids"][mid + 1:]
        flow_confid = torch.cat(confid_list, dim=0)
        flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)')

        scores = F.normalize(self.step_store["corres_scores"], p=2, dim=-1)
        # Shift the confidences so their maximum equals the maximum score.
        flow_confid = flow_confid - (torch.max(flow_confid) - torch.max(scores))
        # Boost the score of each flow-predicted (source, target) pair.
        flow_src = flow_src_idx[0, :, 0]
        flow_tar = flow_tar_idx[0, :, 0]
        scores[:, flow_src, flow_tar] += (flow_confid[:, flow_src] * 0.3)

        # Rank sources by their best score and keep the top ``r``.
        r = min(src.shape[1], int(src.shape[1] * merge_ratio))
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        tar_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx)

        unm = torch.gather(src, dim=-2, index=unm_idx.expand(-1, -1, C))
        if merge_mode != "replace":
            src = torch.gather(src, dim=-2, index=src_idx.expand(-1, -1, C))
            # In other modes such as mean, combine matched src and dst tokens.
            tar = tar.scatter_reduce(-2, tar_idx.expand(-1, -1, C),
                                     src, reduce=merge_mode,
                                     include_self=True)
        # Every merged source token takes the value of its matched target.
        src = torch.gather(tar, dim=-2, index=tar_idx.expand(-1, -1, C))

        # Write target, unmerged and merged tokens back to their slots.
        x_non_pad.scatter_(dim=-2, index=b_idx.expand(b, -1, C), src=tar)
        src_positions = a_idx.expand(b, -1, 1)
        x_non_pad.scatter_(
            dim=-2,
            index=torch.gather(src_positions, dim=1,
                               index=unm_idx).expand(-1, -1, C),
            src=unm)
        x_non_pad.scatter_(
            dim=-2,
            index=torch.gather(src_positions, dim=1,
                               index=src_idx).expand(-1, -1, C),
            src=src)

        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c',
                              a=tokens_per_frame)
        x0.scatter_(dim=1, index=gather_idx, src=x_non_pad)
        x0 = rearrange(x0, 'b (h w) c -> b c h w ', h=H)
        return x0

    def set_distance(self, B, H, W, radius, device):
        """Cache an exponential spatial-distance weight table under key ``H``.

        Pairwise Euclidean distances between all ``H * W`` pixel
        coordinates are floor-divided by ``radius`` (``0`` is treated as
        ``1``), mapped through ``exp(-d)`` and tiled ``B`` times along the
        row axis, giving shape ``(1, B * H * W, H * W)``.
        """
        # 'ij' is the historical default; passing it explicitly keeps the
        # result identical and avoids the deprecation warning.
        y, x = torch.meshgrid(torch.arange(H), torch.arange(W),
                              indexing='ij')
        coords = torch.stack((y, x), dim=-1).float().to(device)
        coords = rearrange(coords, 'h w c -> (h w) c')
        # Calculate the Euclidean distance between all pixels
        distances = torch.cdist(coords, coords)
        radius = 1 if radius == 0 else radius
        distances //= radius
        distances = torch.exp(-distances)
        distances = repeat(distances, 'h a -> 1 (b h) a', b=B)
        self.distances[H] = distances

    def set_flow_correspondence(self, B, H, W, key_idx, flow_confid, flow):
        """Cache flow-derived source->target pixel pairs under key ``H``.

        ``flow_confid`` and ``flow`` are lists of tensors; when ``flow``
        does not already have ``B - 1`` entries the ``key_idx`` entry is
        dropped from both.  Source pixels are ordered by descending flow
        confidence; each target is the source coordinate displaced by the
        flow, clamped to the image and truncated to an integer.

        Stores ``(src_idx, tar_idx)``, both of shape ``(1, N, 1)``.
        """
        if len(flow) != B - 1:
            flow_confid = flow_confid[:key_idx] + flow_confid[key_idx + 1:]
            flow = flow[:key_idx] + flow[key_idx + 1:]
        # NOTE(review): concatenation is assumed unconditional -- the code
        # below needs tensors, not lists.
        flow_confid = torch.cat(flow_confid, dim=0)
        flow = torch.cat(flow, dim=0)
        flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)')

        src_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]

        # Decompose each flat index into (frame, y, x).
        pixels_per_frame = H * W
        flat = src_idx[0, :, 0]
        frame = torch.div(flat, pixels_per_frame, rounding_mode='floor')
        pixel = flat % pixels_per_frame
        src_x = pixel % W
        src_y = torch.div(pixel, W, rounding_mode='floor')

        grid = coords_grid(B - 1, H, W).to(flow.device) + flow  # [F-1, 2, H, W]
        tar_x = grid[frame, 0, src_y, src_x].clamp(0, W - 1).long()
        tar_y = grid[frame, 1, src_y, src_x].clamp(0, H - 1).long()
        tar_idx = rearrange(tar_y * W + tar_x, ' d -> 1 d 1')
        self.flow_correspondence[H] = (src_idx, tar_idx)

    def set_merge(self, merge, unmerge):
        """Store a merge/unmerge callable pair in the step store."""
        self.step_store["merge"] = merge
        self.step_store["unmerge"] = unmerge

    def set_warp(self, flows, masks, flow_confids=None):
        """Store flows and occlusion masks; confidences only when given."""
        self.step_store["flows"] = flows
        self.step_store["occ_masks"] = masks
        if flow_confids is not None:
            self.step_store["flow_confids"] = flow_confids

    def set_warp2(self, flows, flow_confids):
        """Store the secondary set of flows and flow confidences."""
        self.step_store["flows2"] = flows
        self.step_store["flow_confids2"] = flow_confids

    def set_pre_keyframe_lq(self, pre_keyframe_lq):
        """Store ``pre_keyframe_lq`` in the step store."""
        self.step_store["pre_keyframe_lq"] = pre_keyframe_lq

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Delegate to ``forward``."""
        return self.forward(context, is_cross, place_in_unet)

    def set_cur_frame_idx(self, frame_idx):
        """Set the current frame index."""
        self.cur_frame_idx = frame_idx

    def set_step(self, step):
        """Set the current step."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total number of steps and rewind ``cur_index``."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all cached state and start from an empty store."""
        del self.step_store
        # Collect first so unreachable tensors are freed before the CUDA
        # caching allocator is asked to release its blocks.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Reset all task flags, then enable those named in ``task``.

        Recognised names (tested with ``in``): ``initfirst`` (also clears
        the store), ``updatestyle``, ``keepstyle``, ``updatex0``,
        ``keepx0``.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True