Spaces:
Running
on
Zero
Running
on
Zero
| """Various positional encodings for the transformer.""" | |
| import math | |
| import torch | |
| from torch import Tensor, nn | |
| class PositionEmbeddingSineHW(nn.Module): | |
| """A more standard version of the position embedding. | |
| It is very similar to the one used by the Attention is all you need paper, | |
| generalized to work on images. | |
| """ | |
| def __init__( | |
| self, | |
| num_pos_feats: int = 64, | |
| temperatureH: int = 10000, | |
| temperatureW: int = 10000, | |
| normalize: bool = False, | |
| scale: float | None = None, | |
| ) -> None: | |
| """Constructor method for PositionEmbeddingSineHW class.""" | |
| super().__init__() | |
| self.num_pos_feats = num_pos_feats | |
| self.temperatureH = temperatureH | |
| self.temperatureW = temperatureW | |
| self.normalize = normalize | |
| if scale is not None and normalize is False: | |
| raise ValueError("normalize should be True if scale is passed") | |
| if scale is None: | |
| scale = 2 * math.pi | |
| self.scale = scale | |
| def forward(self, x: Tensor, mask: Tensor | None = None): | |
| assert mask is not None | |
| not_mask = ~mask | |
| y_embed = not_mask.cumsum(1, dtype=torch.float32) | |
| x_embed = not_mask.cumsum(2, dtype=torch.float32) | |
| if self.normalize: | |
| eps = 1e-6 | |
| y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale | |
| x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale | |
| dim_tx = torch.arange( | |
| self.num_pos_feats, dtype=torch.float32, device=x.device | |
| ) | |
| dim_tx = self.temperatureW ** ( | |
| 2 | |
| * (torch.div(dim_tx, 2, rounding_mode="floor")) | |
| / self.num_pos_feats | |
| ) | |
| pos_x = x_embed[:, :, :, None] / dim_tx | |
| dim_ty = torch.arange( | |
| self.num_pos_feats, dtype=torch.float32, device=x.device | |
| ) | |
| dim_ty = self.temperatureH ** ( | |
| 2 | |
| * (torch.div(dim_ty, 2, rounding_mode="floor")) | |
| / self.num_pos_feats | |
| ) | |
| pos_y = y_embed[:, :, :, None] / dim_ty | |
| pos_x = torch.stack( | |
| (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 | |
| ).flatten(3) | |
| pos_y = torch.stack( | |
| (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 | |
| ).flatten(3) | |
| pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) | |
| return pos | |
def get_sine_pos_embed(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    exchange_xy: bool = True,
):
    """Generate sine position embedding from a position tensor.

    Each of the ``n`` scalars along the last axis is multiplied by ``2 * pi``
    and encoded into ``num_pos_feats`` interleaved sin/cos features. The
    per-scalar encodings are concatenated along the last axis.

    Args:
        pos_tensor (torch.Tensor): shape: [..., n]. Any number of leading
            dimensions is supported.
        num_pos_feats (int): projected shape for each float in the tensor.
            Must be even, because sin and cos channels are interleaved in pairs.
        temperature (int): temperature in the sine/cosine function.
        exchange_xy (bool, optional): exchange pos x and pos y. For example,
            input tensor is [x, y], the results will be [pos(y), pos(x)].
            Has no effect when n < 2. Defaults to True.

    Returns:
        pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        num_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
    )

    def sine_func(x: torch.Tensor) -> torch.Tensor:
        # [..., 1] -> [..., num_pos_feats]. Stacking/flattening the trailing
        # axes (instead of fixed dims 3/2) keeps this valid for any rank.
        sin_x = x * scale / dim_t
        return torch.stack(
            (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    pos_res = [sine_func(x) for x in pos_tensor.split(1, dim=-1)]
    if exchange_xy and len(pos_res) >= 2:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    return torch.cat(pos_res, dim=-1)
def gen_sineembed_for_position(
    pos_tensor: Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
) -> Tensor:
    """Encode point or box coordinates with interleaved sine/cosine features.

    Args:
        pos_tensor: Tensor of shape [..., 2] holding (x, y) or [..., 4]
            holding (x, y, w, h). Every coordinate is multiplied by ``2 * pi``
            before encoding. The values are presumably normalized to [0, 1];
            confirm against callers.
        num_pos_feats: Features produced per coordinate. Must be even.
            Defaults to 128 (the previously hard-coded value).
        temperature: Base of the geometric frequency progression.
            Defaults to 10000 (the previously hard-coded value).

    Returns:
        Tensor of shape [..., 2 * num_pos_feats] ordered (y, x), or
        [..., 4 * num_pos_feats] ordered (y, x, w, h). The result is cast
        back to ``pos_tensor.dtype``.

    Raises:
        ValueError: If the last dimension is neither 2 nor 4.
    """
    num_coords = pos_tensor.size(-1)
    if num_coords not in (2, 4):
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))

    scale = 2 * math.pi
    dim_t = torch.arange(
        num_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = temperature ** (
        2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_pos_feats
    )

    def encode(index: int) -> Tensor:
        # Coordinate `index` -> [..., num_pos_feats] interleaved sin/cos.
        embed = pos_tensor[..., index] * scale
        pos = embed[..., None] / dim_t
        return torch.stack(
            (pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    # Output order puts y before x: (y, x) or (y, x, w, h).
    order = (1, 0) if num_coords == 2 else (1, 0, 2, 3)
    pos = torch.cat([encode(i) for i in order], dim=-1)
    return pos.to(pos_tensor.dtype)
def coordinate_to_encoding(
    coord_tensor: Tensor,
    num_feats: int = 128,
    temperature: int = 10000,
    scale: float = 2 * math.pi,
) -> Tensor:
    """Convert coordinate tensor to positional encoding.

    Args:
        coord_tensor (Tensor): Coordinate tensor to be converted to
            positional encoding, of shape [..., 2] ordered (x, y) or
            [..., 4] ordered (x, y, w, h). Any number of leading
            dimensions is supported.
        num_feats (int, optional): The feature dimension for each
            coordinate. Must be even, because sin and cos channels are
            interleaved in pairs. The returned feature dimension is 2 times
            this value for 2 coordinates, or 4 times for 4. Defaults to 128.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        scale (float, optional): Factor every coordinate is multiplied by
            before encoding. Defaults to 2*pi.

    Returns:
        Tensor: Encoded positions of shape [..., 2 * num_feats] ordered
        (y, x), or [..., 4 * num_feats] ordered (y, x, w, h).

    Raises:
        ValueError: If the last dimension is neither 2 nor 4.
    """
    num_coords = coord_tensor.size(-1)
    if num_coords not in (2, 4):
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))

    dim_t = torch.arange(
        num_feats, dtype=torch.float32, device=coord_tensor.device
    )
    dim_t = temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / num_feats
    )

    def encode(index: int) -> Tensor:
        # Coordinate `index` -> [..., num_feats] interleaved sin/cos. Use
        # flatten(-2) rather than flatten(2) so non-3D inputs keep
        # their leading dims.
        embed = coord_tensor[..., index] * scale
        pos = embed[..., None] / dim_t
        return torch.stack(
            (pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    # Output order puts y before x: (y, x) or (y, x, w, h).
    order = (1, 0) if num_coords == 2 else (1, 0, 2, 3)
    return torch.cat([encode(i) for i in order], dim=-1)