Spaces:
Runtime error
Runtime error
| """ | |
| Various positional encodings for the transformer. | |
| """ | |
| import math | |
| import torch | |
| from torch import nn | |
def PE1d_sincos(seq_length, dim):
    """
    Build a 1-D sinusoidal positional-encoding table.

    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)``, with ``w_k = exp(-2k * ln(10000) / dim)``.

    :param seq_length: number of positions to encode
    :param dim: feature dimension of each position; must be even
    :return: tensor of shape (seq_length, 1, dim); the singleton middle
        axis is inserted by the final ``unsqueeze(1)``
    :raises ValueError: if ``dim`` is odd
    """
    if dim % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(dim))

    table = torch.zeros(seq_length, dim)
    positions = torch.arange(0, seq_length).unsqueeze(1).float()

    # One frequency per (sin, cos) pair of feature columns.
    pair_index = torch.arange(0, dim, 2, dtype=torch.float)
    inv_freq = torch.exp(pair_index * -(math.log(10000.0) / dim))

    # (seq_length, 1) * (dim / 2,) broadcasts to (seq_length, dim / 2).
    angles = positions * inv_freq
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table.unsqueeze(1)
class PositionEmbedding(nn.Module):
    """
    Absolute positional embedding added to the input, followed by dropout.

    The table is initialised from ``PE1d_sincos(seq_length, dim)`` and stored
    as an ``nn.Parameter``. It is frozen by default and only becomes trainable
    (i.e. "learned") when ``grad=True``.

    :param seq_length: number of positions in the table (longest input
        sequence that ``forward`` accepts)
    :param dim: feature dimension of the table
    :param dropout: dropout probability applied after the addition
    :param grad: whether the table requires gradients
    """

    def __init__(self, seq_length, dim, dropout, grad=False):
        super().__init__()
        self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """
        Add the first ``x.shape[1]`` table rows to ``x`` and apply dropout.

        :param x: tensor laid out as (bs, seq_len, feat_dim)
        :return: tensor with the same shape as ``x``
        :raises RuntimeError: if ``x.shape[1]`` exceeds the number of
            positions stored in the table
        """
        # x.shape: bs, seq_len, feat_dim
        seq_len = x.shape[1]
        max_len = self.embed.shape[0]
        if seq_len > max_len:
            # Without this check the slice below silently truncates to
            # max_len rows and `expand` fails with an opaque size-mismatch
            # message; the exception type is kept as RuntimeError.
            raise RuntimeError(
                "PositionEmbedding: input sequence length {} exceeds the "
                "maximum supported length {}".format(seq_len, max_len)
            )
        # Work sequence-first so the table's leading axis lines up with the
        # sequence axis of the input.
        seq_first = x.permute(1, 0, 2)
        x = seq_first + self.embed[:seq_len].expand(seq_first.shape)
        # Back to batch-first before dropout.
        x = self.dropout(x.permute(1, 0, 2))
        return x