Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.modules.linear import Linear | |
| from layers.SelfAttention_Family import AttentionLayer, FullAttention | |
| from layers.Embed import DataEmbedding | |
| import math | |
def get_mask(input_size, window_size, inner_size):
    """Get the attention mask of PAM-Naive.

    The pyramid has ``input_size`` nodes on level 0, and level ``k`` has
    ``floor(size[k-1] / window_size[k-1])`` nodes. All levels are laid out
    back to back in one flattened sequence. Two kinds of edges are allowed:

    * intra-scale: a node sees nodes on its own level within
      ``inner_size // 2`` positions on either side;
    * inter-scale: each node on level ``k`` is linked, in both directions,
      to the ``window_size[k-1]`` nodes of level ``k-1`` it covers. The last
      node of a level also covers any leftover nodes of the level below
      that do not fill a whole window.

    Args:
        input_size: number of nodes on the finest level.
        window_size: sequence of per-level downsampling factors.
        inner_size: width of the intra-scale neighbourhood.

    Returns:
        ``(mask, all_size)``. ``mask`` is a boolean ``(S, S)`` tensor, with
        ``S = sum(all_size)``, that is True where attention is NOT allowed.
        ``all_size`` lists the node count of every level, finest first.
    """
    # Node count of every level, finest (the raw sequence) first.
    all_size = [input_size]
    for w in window_size:
        all_size.append(math.floor(all_size[-1] / w))

    seq_length = sum(all_size)
    connected = torch.zeros(seq_length, seq_length)

    # Flattened index of the first node of each level.
    offsets = [sum(all_size[:k]) for k in range(len(all_size))]

    # Intra-scale edges: a sliding neighbourhood clipped to the level's span.
    half = inner_size // 2
    for offset, size in zip(offsets, all_size):
        end = offset + size
        for node in range(offset, end):
            lo = max(node - half, offset)
            hi = min(node + half + 1, end)
            connected[node, lo:hi] = 1

    # Inter-scale edges between each level and the level directly below it.
    for level in range(1, len(all_size)):
        offset = offsets[level]
        child_start = offsets[level - 1]
        stride = window_size[level - 1]
        last = offset + all_size[level] - 1
        for node in range(offset, last + 1):
            rank = node - offset
            lo = child_start + rank * stride
            # The final parent stretches to the end of the child level so no
            # leftover child node is left without a parent.
            hi = offset if node == last else child_start + (rank + 1) * stride
            connected[node, lo:hi] = 1
            connected[lo:hi, node] = 1

    return (1 - connected).bool(), all_size
def refer_points(all_sizes, window_size):
    """Gather features from PAM's pyramid sequences.

    For every node on the finest level, walk up the pyramid and record the
    flattened index of its ancestor on each level. A node whose window
    would fall past the end of the next level is clamped to that level's
    last node.

    Args:
        all_sizes: node count of each level, finest first (as returned by
            ``get_mask``).
        window_size: per-level downsampling factors.

    Returns:
        A long tensor of shape ``(1, all_sizes[0], len(all_sizes), 1)``.
    """
    num_levels = len(all_sizes)
    input_size = all_sizes[0]
    # Flattened index of the first node of each level.
    level_starts = [sum(all_sizes[:k]) for k in range(num_levels)]

    table = torch.zeros(input_size, num_levels)
    for position in range(input_size):
        node = position
        table[position, 0] = node
        for level in range(1, num_levels):
            local = node - level_starts[level - 1]
            parent = min(local // window_size[level - 1], all_sizes[level] - 1)
            node = level_starts[level] + parent
            table[position, level] = node

    return table[None, :, :, None].long()
class RegularMask:
    """Wrap a precomputed boolean attention mask in the mask-object interface.

    The tensor is exposed through the ``mask`` *property*. The attention
    code is expected to read ``attn_mask.mask`` as an attribute.
    NOTE(review): this matches the other mask helpers used with
    ``FullAttention``; confirm against layers.SelfAttention_Family. With a
    plain method, ``masked_fill_`` would receive a bound method instead of
    a tensor and fail.
    """

    def __init__(self, mask):
        # Insert a singleton axis at dim 1. Presumably this is the attention-head
        # axis, so one mask broadcasts over all heads. Assumes mask is
        # (batch, L, L), as built by the encoder; TODO confirm.
        self._mask = mask.unsqueeze(1)

    @property
    def mask(self):
        """The wrapped mask tensor with the extra axis at dim 1."""
        return self._mask
class EncoderLayer(nn.Module):
    """One encoder layer: masked self-attention followed by a position-wise FFN."""

    def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
        super().__init__()
        attention = FullAttention(
            mask_flag=True,
            factor=0,
            attention_dropout=dropout,
            output_attention=False,
        )
        self.slf_attn = AttentionLayer(attention, d_model, n_head)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before
        )

    def forward(self, enc_input, slf_attn_mask=None):
        """Apply self-attention under ``slf_attn_mask``, then the feed-forward block.

        ``slf_attn_mask`` is wrapped in ``RegularMask``, which calls
        ``.unsqueeze`` on it. It therefore has to be a tensor; the ``None``
        default is not usable as-is.
        """
        wrapped_mask = RegularMask(slf_attn_mask)
        attended, _ = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=wrapped_mask
        )
        return self.pos_ffn(attended)
class Encoder(nn.Module):
    """A encoder model with self attention mechanism (pyramidal attention).

    Pipeline:
    1. embed the input;
    2. append coarser pyramid levels with ``Bottleneck_Construct``;
    3. run the stacked ``EncoderLayer``s under the PAM mask from ``get_mask``;
    4. for every input step, gather its node on every pyramid level
       (indices from ``refer_points``).
    """

    def __init__(self, configs, window_size, inner_size):
        super().__init__()
        # The bottleneck width is fixed at a quarter of d_model.
        d_bottleneck = configs.d_model // 4
        self.mask, self.all_size = get_mask(configs.seq_len, window_size, inner_size)
        self.indexes = refer_points(self.all_size, window_size)
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    configs.d_model,
                    configs.d_ff,
                    configs.n_heads,
                    dropout=configs.dropout,
                    normalize_before=False,
                )
                for _ in range(configs.e_layers)
            ]
        )  # naive pyramid attention
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.dropout)
        self.conv_layers = Bottleneck_Construct(configs.d_model, window_size, d_bottleneck)

    def forward(self, x_enc, x_mark_enc):
        """Encode ``x_enc`` and return per-step features from every pyramid level.

        ``x_enc`` and ``x_mark_enc`` go straight to ``DataEmbedding``.
        Presumably they are (batch, seq_len, enc_in) plus time marks; TODO confirm.
        The result has shape ``(batch, all_size[0], len(all_size) * d)``, where
        ``d`` is the last-axis size of the encoded pyramid.
        """
        seq_enc = self.enc_embedding(x_enc, x_mark_enc)
        # The mask is copied once per batch element and moved on every call.
        batch_mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
        seq_enc = self.conv_layers(seq_enc)
        for layer in self.layers:
            seq_enc = layer(seq_enc, batch_mask)

        n_batch = seq_enc.size(0)
        width = seq_enc.size(2)
        # Broadcast the (1, L, levels, 1) index table over batch and features,
        # then flatten it so one gather along dim 1 picks every ancestor.
        gather_idx = (
            self.indexes.repeat(n_batch, 1, 1, width)
            .to(seq_enc.device)
            .view(n_batch, -1, width)
        )
        gathered = torch.gather(seq_enc, 1, gather_idx)
        # Each step's level features end up concatenated along the last axis.
        return gathered.view(n_batch, self.all_size[0], -1)
class ConvLayer(nn.Module):
    """Downsample the last axis by ``window_size``: strided Conv1d, BatchNorm, ELU.

    Input follows the ``Conv1d`` layout ``(batch, c_in, length)``. The channel
    count is unchanged, and the length shrinks by a factor of ``window_size``.
    """

    def __init__(self, c_in, window_size):
        super().__init__()
        # kernel == stride, so the windows tile the sequence without overlap.
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=window_size,
            stride=window_size,
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()

    def forward(self, x):
        return self.activation(self.norm(self.downConv(x)))
class Bottleneck_Construct(nn.Module):
    """Bottleneck convolution CSCM.

    The model-width sequence is projected down to ``d_inner``. A chain of
    ``ConvLayer`` downsamplers then produces successively coarser levels,
    which are projected back to ``d_model`` and appended after the original
    sequence along the time axis. Layer normalization is applied at the end.

    Args:
        d_model: feature width of the input and output.
        window_size: a list or tuple gives one stage per entry, each with its
            own stride. Any other value is used as the stride of a fixed
            three-stage pyramid. Tuples used to be mistaken for a scalar and
            passed whole as a Conv1d kernel size.
        d_inner: bottleneck width used inside the convolutions.
    """

    def __init__(self, d_model, window_size, d_inner):
        super().__init__()
        if isinstance(window_size, (list, tuple)):
            windows = list(window_size)
        else:
            windows = [window_size] * 3
        self.conv_layers = nn.ModuleList([ConvLayer(d_inner, w) for w in windows])
        self.up = Linear(d_inner, d_model)
        self.down = Linear(d_model, d_inner)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, enc_input):
        """Return ``enc_input`` with all coarser levels appended along dim 1.

        Assumes ``enc_input`` is (batch, length, d_model); TODO confirm.
        The output keeps the batch and feature axes. Its dim 1 is the input
        length plus the lengths of every coarser level.
        """
        # Move features to the channel axis, as Conv1d expects.
        x = self.down(enc_input).permute(0, 2, 1)
        scales = []
        for conv in self.conv_layers:
            # Each stage downsamples the previous stage's output.
            x = conv(x)
            scales.append(x)
        coarse = self.up(torch.cat(scales, dim=2).transpose(1, 2))
        return self.norm(torch.cat([enc_input, coarse], dim=1))
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with a residual connection.

    Computes ``w_2(gelu(w_1(.)))``, with dropout after each layer, and adds
    the input back. ``normalize_before`` picks pre-norm (LayerNorm on the way
    in) or post-norm (LayerNorm on the residual sum).
    """

    def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
        super().__init__()
        self.normalize_before = normalize_before
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.layer_norm(x) if self.normalize_before else x
        h = self.dropout(F.gelu(self.w_1(h)))
        # The residual always uses the un-normalized input.
        out = self.dropout(self.w_2(h)) + x
        return out if self.normalize_before else self.layer_norm(out)