"""Various positional encodings for the transformer."""
import math
import torch
from torch import Tensor, nn
class PositionEmbeddingSineHW(nn.Module):
"""A more standard version of the position embedding.
It is very similar to the one used by the Attention is all you need paper,
generalized to work on images.
"""
def __init__(
self,
num_pos_feats: int = 64,
temperatureH: int = 10000,
temperatureW: int = 10000,
normalize: bool = False,
scale: float | None = None,
) -> None:
"""Constructor method for PositionEmbeddingSineHW class."""
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
self.temperatureW = temperatureW
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        """Computes the sine position embedding for a padded feature map.

        Args:
            x (Tensor): Feature map of shape [bs, C, H, W], used only for
                its device.
            mask (Tensor): Boolean padding mask of shape [bs, H, W], True
                at padded positions.

        Returns:
            Tensor: Position embedding of shape
                [bs, 2 * num_pos_feats, H, W].
        """
        assert mask is not None, "mask is required to compute the embedding"
        not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Normalize cumulative coordinates to (0, scale].
            eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Per-channel frequencies: temperature ** (2 * floor(i / 2) / num_pos_feats).
        dim_tx = torch.arange(
self.num_pos_feats, dtype=torch.float32, device=x.device
)
dim_tx = self.temperatureW ** (
2
* (torch.div(dim_tx, 2, rounding_mode="floor"))
/ self.num_pos_feats
)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(
self.num_pos_feats, dtype=torch.float32, device=x.device
)
dim_ty = self.temperatureH ** (
2
* (torch.div(dim_ty, 2, rounding_mode="floor"))
/ self.num_pos_feats
)
pos_y = y_embed[:, :, :, None] / dim_ty
        # Interleave sin/cos over the feature dimension.
        pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
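
# Usage sketch (illustrative only; the shapes and the all-False mask below
# are assumptions, not values from the repository):
# >>> pe = PositionEmbeddingSineHW(num_pos_feats=128, normalize=True)
# >>> feats = torch.rand(2, 256, 32, 40)               # [bs, C, H, W]
# >>> mask = torch.zeros(2, 32, 40, dtype=torch.bool)  # True marks padding
# >>> pe(feats, mask).shape
# torch.Size([2, 256, 32, 40])
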
def get_sine_pos_embed(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    exchange_xy: bool = True,
) -> torch.Tensor:
"""Generate sine position embedding from a position tensor.
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=pos_tensor.device
)
dim_t = temperature ** (
2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
)

    def sine_func(x: torch.Tensor) -> torch.Tensor:
        # Stacking at dim=3 assumes 3-dimensional inputs (e.g. the
        # [bs, n_query, 1] chunks below); the docstring's generic [..., n]
        # shape only holds in that case.
        sin_x = x * scale / dim_t
sin_x = torch.stack(
(sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3
).flatten(2)
return sin_x

    # Embed each scalar coordinate channel independently, then concatenate.
    pos_res = [
sine_func(x)
for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
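
# Usage sketch (illustrative only; the [bs, n_query, 4] reference-point
# layout is an assumption for the example):
# >>> ref_points = torch.rand(2, 300, 4)
# >>> get_sine_pos_embed(ref_points, num_pos_feats=128).shape
# torch.Size([2, 300, 512])
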
def gen_sineembed_for_position(pos_tensor: Tensor) -> Tensor:
    """Generates a sine embedding for 2-d points or 4-d boxes.

    Args:
        pos_tensor (Tensor): Positions of shape [n_query, bs, 2] as (x, y)
            or [n_query, bs, 4] as (x, y, w, h), typically normalized to
            [0, 1].

    Returns:
        Tensor: Embedding of shape [n_query, bs, 256] for points or
            [n_query, bs, 512] for boxes.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
).flatten(2)
pos_y = torch.stack(
(pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack(
(pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack(
(pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(
            f"Unknown pos_tensor shape(-1): {pos_tensor.size(-1)}"
        )
pos = pos.to(pos_tensor.dtype)
return pos
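
# Usage sketch (illustrative only; boxes follow the [n_query, bs, 4] layout
# with normalized (x, y, w, h) coordinates):
# >>> boxes = torch.rand(300, 2, 4)
# >>> gen_sineembed_for_position(boxes).shape
# torch.Size([300, 2, 512])
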
def coordinate_to_encoding(
coord_tensor: Tensor,
num_feats: int = 128,
temperature: int = 10000,
scale: float = 2 * math.pi,
) -> Tensor:
"""Convert coordinate tensor to positional encoding.
Args:
coord_tensor (Tensor): Coordinate tensor to be converted to
positional encoding. With the last dimension as 2 or 4.
num_feats (int, optional): The feature dimension for each position
along x-axis or y-axis. Note the final returned dimension
for each position is 2 times of this value. Defaults to 128.
temperature (int, optional): The temperature used for scaling
the position embedding. Defaults to 10000.
        scale (float, optional): A scale factor applied to the coordinates
            before the sine/cosine encoding. Defaults to 2*pi.
Returns:
Tensor: Returned encoded positional tensor.
"""
dim_t = torch.arange(
num_feats, dtype=torch.float32, device=coord_tensor.device
)
    dim_t = temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / num_feats
    )
x_embed = coord_tensor[..., 0] * scale
y_embed = coord_tensor[..., 1] * scale
pos_x = x_embed[..., None] / dim_t
pos_y = y_embed[..., None] / dim_t
    # Interleave sin/cos; flatten(2) assumes a 3D input like [bs, n_query, c].
    pos_x = torch.stack(
(pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1
).flatten(2)
pos_y = torch.stack(
(pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1
).flatten(2)
if coord_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=-1)
elif coord_tensor.size(-1) == 4:
w_embed = coord_tensor[..., 2] * scale
pos_w = w_embed[..., None] / dim_t
pos_w = torch.stack(
(pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()), dim=-1
).flatten(2)
h_embed = coord_tensor[..., 3] * scale
pos_h = h_embed[..., None] / dim_t
pos_h = torch.stack(
(pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()), dim=-1
).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError(
            f"Unknown coord_tensor shape(-1): {coord_tensor.size(-1)}"
        )
return pos
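
# Usage sketch (illustrative only; coordinates assumed normalized to [0, 1]):
# >>> coords = torch.rand(2, 300, 2)                   # [bs, n_query, (x, y)]
# >>> coordinate_to_encoding(coords, num_feats=128).shape
# torch.Size([2, 300, 256])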