import numpy as np
import einops
import torch
from scipy.spatial.transform import Rotation as R

########################## 6DoF CaPE ####################################
class CaPE_6DoF:
    def cape_embed(self, f, P):
        """
        Apply CaPE on a feature.
        :param f: feature vector of shape [..., d], d divisible by 4
        :param P: batch of 4x4 transformation matrices, broadcastable against f's leading dims
        :return: feature f rotated by pose P: f @ P, applied to each 4-dim block of f
        """
        f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
        return einops.rearrange(f @ P, '... d k -> ... (d k)', k=4)
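    # Note (added): this is equivalent to multiplying the flat feature f by the
    # block-diagonal matrix diag(P, ..., P) with d/4 copies of P on the diagonal.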
    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Attention dot product with CaPE pose encoding.
        # query = cape_embed(query, p_out_inv)  # query: f_q @ (p_out)^(-T)
        # key = cape_embed(key, p_in)           # key:   f_k @ p_in
        :param f1: [b, t1*l, d]
        :param f2: [b, t2*l, d]
        :param p1: [b, t1, 4, 4]
        :param p2: [b, t2, 4, 4]
        :return: attention score f1 @ f2.T after pose embedding, [b, t1*l, t2*l]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        # repeat each per-view pose over its l tokens
        p1_invT = einops.repeat(torch.inverse(p1).permute(0, 1, 3, 2), 'b t m n -> b (t l) m n', l=l)
        query = self.cape_embed(f1, p1_invT)  # [b, t1*l, d] query: f1 @ (p1)^(-T)
        p2_copy = einops.repeat(p2, 'b t m n -> b (t l) m n', l=l)
        key = self.cape_embed(f2, p2_copy)  # [b, t2*l, d] key: f2 @ p2
        att = query @ key.permute(0, 2, 1)  # [b, t1*l, t2*l] attention: query @ key^T
        return att
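# Why this is invariant to a shared pose jitter (added note): the score expands to
# att = f1 @ P1^(-T) @ (f2 @ P2)^T = f1 @ (P2 P1^(-1))^T @ f2^T, so it depends only
# on the relative pose P2 P1^(-1). Replacing P1 -> P1 @ D and P2 -> P2 @ D gives
# P2 D (P1 D)^(-1) = P2 P1^(-1), unchanged for any rigid transform D.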
################### 6DoF Verification ###################################
def euler_to_matrix(alpha, beta, gamma, x, y, z):
    # angles in degrees (scipy is called with degrees=True)
    r = R.from_euler('xyz', [alpha, beta, gamma], degrees=True)
    t = np.array([[x], [y], [z]])
    rot_matrix = r.as_matrix()
    rot_matrix = np.concatenate([rot_matrix, t], axis=-1)  # [R | t], 3x4
    rot_matrix = np.concatenate([rot_matrix, [[0, 0, 0, 1]]], axis=0)  # homogeneous 4x4
    return rot_matrix
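# Quick sanity check (added sketch): zero angles and zero translation give the identity.
assert np.allclose(euler_to_matrix(0, 0, 0, 0, 0, 0), np.eye(4))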
def random_6dof_pose(B, T):
    pose_euler = torch.rand([B, T, 6]).numpy()  # random euler angles + translation
    pose_matrix = []
    for b in range(B):
        p = []
        for t in range(T):
            p.append(torch.from_numpy(euler_to_matrix(*pose_euler[b, t])))
        pose_matrix.append(torch.stack(p))
    pose_matrix = torch.stack(pose_matrix)
    return pose_matrix.float()
bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be an arbitrary number
t2 = 5  # num of reference views in each batch, can be an arbitrary number
l = 10  # number of tokens per view
d = 16  # dim of token feature, must be divisible by 4 (features are split into 4-vectors)
assert d % 4 == 0
# randomly init query and key
f1 = torch.rand(bs, t1, l, d)  # query
f2 = torch.rand(bs, t2, l, d)  # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# randomly init poses p1, p2, delta_p, [bs, t, 4, 4]
p1 = random_6dof_pose(bs, t1)  # [bs, t1, 4, 4]
p2 = random_6dof_pose(bs, t2)  # [bs, t2, 4, 4]
p_delta = random_6dof_pose(bs, 1)  # [bs, 1, 4, 4]
# the same delta is applied to every view of p1 and p2 in each batch
p1_delta = einops.repeat(p_delta, 'b 1 m n -> b t m n', t=t1)
p2_delta = einops.repeat(p_delta, 'b 1 m n -> b t m n', t=t2)
# run attention with 6DoF CaPE
cape_6dof = CaPE_6DoF()
# att
att = cape_6dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta
att_delta = cape_6dof.attn_with_CaPE(f1, f2, p1 @ p1_delta, p2 @ p2_delta)
# condition: the attention scores should match, i.e. no effect from any delta_p
assert torch.allclose(att, att_delta, rtol=1e-3)
print("6DoF CaPE Verified")
########################## 4DoF CaPE ####################################
class CaPE_4DoF:
    def rotate_every_two(self, x):
        x = einops.rearrange(x, '... (d j) -> ... d j', j=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return einops.rearrange(x, '... d j -> ... (d j)')
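    # Example (added): rotate_every_two([x1, x2, x3, x4]) = [-x2, x1, -x4, x3];
    # combined with the cos/sin mix in cape_embed below, this rotates each
    # consecutive feature pair by its pose angle.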
    def cape(self, x, p):
        d, n = x.shape[-1], p.shape[-1]
        assert d % (2 * n) == 0
        # repeat each pose angle d/n times so consecutive feature pairs share an angle
        return einops.repeat(p, 'b l n -> b l (n k)', k=d // n)
    def cape_embed(self, qq, kk, p1, p2):
        """
        Embed camera position encoding into the query/key features.
        :param qq: query feature map [b, l_q, feature_dim]
        :param kk: key feature map [b, l_k, feature_dim]
        :param p1: query pose [b, l_q, pose_dim]
        :param p2: key pose [b, l_k, pose_dim]
        :return: embedded query and key, ready for the attention map [b, l_q, l_k]
        """
        assert p1.shape[-1] == p2.shape[-1]
        assert qq.shape[-1] == kk.shape[-1]
        assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
        assert p1.shape[1] == qq.shape[1]
        assert p2.shape[1] == kk.shape[1]
        m1 = self.cape(qq, p1)
        m2 = self.cape(kk, p2)
        # rotary-style embedding: rotate each feature pair by its pose angle
        q = (qq * m1.cos()) + (self.rotate_every_two(qq) * m1.sin())
        k = (kk * m2.cos()) + (self.rotate_every_two(kk) * m2.sin())
        return q, k
    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Attention dot product with CaPE pose encoding.
        :param f1: [b, t1*l, d]
        :param f2: [b, t2*l, d]
        :param p1: [b, t1, 4]
        :param p2: [b, t2, 4]
        :return: attention score f1 @ f2.T after pose embedding, [b, t1*l, t2*l]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        p1_reshape = einops.repeat(p1, 'b t m -> b (t l) m', l=l)  # one pose per f1 token
        p2_reshape = einops.repeat(p2, 'b t m -> b (t l) m', l=l)  # one pose per f2 token
        query, key = self.cape_embed(f1, f2, p1_reshape, p2_reshape)
        att = query @ key.permute(0, 2, 1)  # [b, t1*l, t2*l] attention: query @ key^T
        return att
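# Why the 4DoF version is invariant to an additive delta (added note): each feature
# pair is rotated by its pose angles, and (R(m1) q) . (R(m2) k) = q . (R(m2 - m1) k),
# so the score depends only on pose differences; adding the same delta to p1 and p2
# leaves m2 - m1 unchanged.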
################### 4DoF Verification ###################################
def random_4dof_pose(B, T):
    # theta in [0, pi], azimuth in [0, 2*pi], radius in [0, pi], last component fixed to 0
    pose = torch.zeros([B, T, 4])
    pose[:, :, :3] = torch.rand([B, T, 3])  # angles in radians
    pose[:, :, 0] *= torch.pi
    pose[:, :, 1] *= (2 * torch.pi)
    pose[:, :, 2] *= torch.pi
    return pose.float()
def look_at(origin, target, up):
    forward = (target - origin)
    forward = forward / torch.linalg.norm(forward, dim=-1, keepdim=True)
    right = torch.linalg.cross(forward, up)
    right = right / torch.linalg.norm(right, dim=-1, keepdim=True)
    new_up = torch.linalg.cross(forward, right)
    new_up = new_up / torch.linalg.norm(new_up, dim=-1, keepdim=True)
    # [R | t]: translation column is set to target (recovered as the camera
    # position in pose_matrix24dof below)
    rotation_matrix = torch.stack((right, new_up, forward, target), dim=-1)
    bottom_row = torch.tensor([[0., 0., 0., 1.]]).repeat(rotation_matrix.shape[0], rotation_matrix.shape[1], 1, 1)
    matrix = torch.cat([rotation_matrix, bottom_row], dim=-2)
    return matrix
def pose_4dof2matrix(pose_4dof):
    """
    :param pose_4dof: [b, t, 4]
    :return: pose 4x4 matrix: [b, t, 4, 4]
    """
    theta = pose_4dof[:, :, 0]
    azimuth = pose_4dof[:, :, 1]
    radius = pose_4dof[:, :, 2]
    # spherical -> cartesian camera position
    xyz = torch.stack([torch.sin(theta) * torch.cos(azimuth),
                       torch.sin(theta) * torch.sin(azimuth),
                       torch.cos(theta)], dim=-1) * radius.unsqueeze(-1)
    origin = torch.zeros_like(xyz)
    up = torch.zeros_like(xyz)
    up[:, :, 2] = 1
    pose = look_at(origin, xyz, up)
    return pose
def pose_matrix24dof(pose_matrix):
    """
    :param pose_matrix: [b, t, 4, 4]
    :return: pose_4dof: [b, t, 4] theta, azimuth, radius, 0, looking at the origin
    """
    xyz = pose_matrix[..., :3, 3]
    xy = xyz[..., 0] ** 2 + xyz[..., 1] ** 2
    radius = torch.sqrt(xy + xyz[..., 2] ** 2)
    theta = torch.arctan2(torch.sqrt(xy), xyz[..., 2])  # elevation angle measured from the Z-axis
    azimuth = torch.arctan2(xyz[..., 1], xyz[..., 0])
    pose = torch.stack([theta, azimuth, radius, torch.zeros_like(radius)], dim=-1)
    # wrap angles to [0, 2*pi)
    pose %= (2 * torch.pi)
    return pose
bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be an arbitrary number
t2 = 5  # num of reference views in each batch, can be an arbitrary number
l = 10  # number of tokens per view
d = 16  # dim of token feature, must be divisible by 2 * pose_dim = 8 here (see the assert in cape)
assert d % 8 == 0
# randomly init query and key
f1 = torch.rand(bs, t1, l, d)  # query
f2 = torch.rand(bs, t2, l, d)  # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# randomly init 4DoF poses [bs, t, 4]: theta, azimuth, radius, 0
p1 = random_4dof_pose(bs, t1)  # [bs, t1, 4]
p2 = random_4dof_pose(bs, t2)  # [bs, t2, 4]
# roundtrip check: 4dof -> matrix -> 4dof
p1_matrix = pose_4dof2matrix(p1)
p1_4dof = pose_matrix24dof(p1_matrix)
assert torch.allclose(p1, p1_4dof, atol=1e-5)
p_delta_4dof = random_4dof_pose(bs, 1)
# the same delta is applied to every view of p1 and p2 in each batch
p1_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b t m', t=t1)
p2_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b t m', t=t2)
# run attention with 4DoF CaPE
cape_4dof = CaPE_4DoF()
# att
att = cape_4dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta: 4DoF deltas are applied additively to the pose angles
att_delta = cape_4dof.attn_with_CaPE(f1, f2, p1 + p1_delta_4dof, p2 + p2_delta_4dof)
# condition: the attention scores should match, i.e. no effect from any delta_p
assert torch.allclose(att, att_delta, rtol=1e-3)
print("4DoF CaPE Verified")
# print("You should get an assertion error because 4DoF CaPE cannot handle 6DoF jitter")
# # att_delta_6dof: 4DoF CaPE cannot absorb an arbitrary 6DoF delta
# p_delta_6dof = random_6dof_pose(bs, 1)  # [bs, 1, 4, 4] any delta transformation in 6DoF
# # the same delta is applied to every view of p1 and p2 in each batch
# p1_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b t m n', t=t1)
# p2_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b t m n', t=t2)
# # 4dof pose to 4x4 matrix
# p1_matrix = pose_4dof2matrix(p1)
# p2_matrix = pose_4dof2matrix(p2)
# att_delta_6dof = cape_4dof.attn_with_CaPE(f1, f2, pose_matrix24dof(p1_matrix @ p1_delta_6dof), pose_matrix24dof(p2_matrix @ p2_delta_6dof))
# # condition: the attention scores should match, i.e. no effect from any delta_p
# assert torch.allclose(att, att_delta_6dof, rtol=1e-3)