import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum
from layers.Embed import DataEmbedding
class Model(nn.Module):
"""
Mamba, linear-time sequence modeling with selective state spaces O(L)
Paper link: https://arxiv.org/abs/2312.00752
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
"""
def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.pred_len = configs.pred_len
self.d_inner = configs.d_model * configs.expand
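        # Rank of the low-rank projection that produces the step size delta (the paper's default is d_model / 16)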
self.dt_rank = math.ceil(configs.d_model / 16)
self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)
self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
self.norm = RMSNorm(configs.d_model)
self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)
def forecast(self, x_enc, x_mark_enc):
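        # Non-stationary normalization: standardize each series here, then
        # restore the statistics in the de-normalization step below.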
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
x_enc = x_enc / std_enc
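        # Value + temporal embedding: [B, L, enc_in] -> [B, L, d_model]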
x = self.embedding(x_enc, x_mark_enc)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x_out = self.out_layer(x)
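        # De-normalize: restore the mean and variance removed above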
x_out = x_out * std_enc + mean_enc
return x_out
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
x_out = self.forecast(x_enc, x_mark_enc)
return x_out[:, -self.pred_len:, :]
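        # other tasks not implemented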
class ResidualBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(ResidualBlock, self).__init__()
self.mixer = MambaBlock(configs, d_inner, dt_rank)
self.norm = RMSNorm(configs.d_model)
def forward(self, x):
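        # Pre-norm residual connection: x + Mamba(RMSNorm(x))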
output = self.mixer(self.norm(x)) + x
return output
class MambaBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(MambaBlock, self).__init__()
self.d_inner = d_inner
self.dt_rank = dt_rank
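        # Projects the input into two d_inner-wide streams: the SSM branch and the gate ("res")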
self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)
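        # Depthwise causal convolution; padding = d_conv - 1 together with the
        # truncation in forward() keeps the receptive field strictly causal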
self.conv1d = nn.Conv1d(
in_channels = self.d_inner,
out_channels = self.d_inner,
bias = True,
kernel_size = configs.d_conv,
padding = configs.d_conv - 1,
groups = self.d_inner
)
# takes in x and outputs the input-specific delta, B, C
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)
# projects delta
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
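        # S4D-real initialization: after the -exp(A_log) in ssm(), each channel's
        # A is -[1, 2, ..., n]. Note: configs.d_ff doubles as the state size n here.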
A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)
def forward(self, x):
"""
Figure 3 in Section 3.4 in the paper
"""
(b, l, d) = x.shape
x_and_res = self.in_proj(x) # [B, L, 2 * d_inner]
(x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
x = rearrange(x, "b l d -> b d l")
x = self.conv1d(x)[:, :, :l]
x = rearrange(x, "b d l -> b l d")
x = F.silu(x)
y = self.ssm(x)
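        # Multiplicative gate: modulate the SSM output with the SiLU-activated "res" branch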
y = y * F.silu(res)
output = self.out_proj(y)
return output
def ssm(self, x):
"""
Algorithm 2 in Section 3.2 in the paper
"""
(d_in, n) = self.A_log.shape
A = -torch.exp(self.A_log.float()) # [d_in, n]
D = self.D.float() # [d_in]
        x_dbl = self.x_proj(x) # [B, L, dt_rank + 2 * n]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, dt_rank]; B, C: [B, L, n]
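        # softplus keeps the input-dependent step size delta strictly positive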
delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in]
y = self.selective_scan(x, delta, A, B, C, D)
return y
def selective_scan(self, u, delta, A, B, C, D):
(b, l, d_in) = u.shape
n = A.shape[1]
        # A is discretized using zero-order hold (ZOH) discretization.
        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))
        # B is discretized using a simplified Euler discretization instead of
        # ZOH. From a discussion with the authors: "A is the more important
        # term and the performance doesn't change much with the simplification
        # on B".
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")
# selective scan, sequential instead of parallel
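        # Recurrence: h_t = deltaA_t * h_{t-1} + deltaB_u_t ; y_t = <C_t, h_t>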
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], "b d n, b n -> b d")
ys.append(y)
y = torch.stack(ys, dim=1) # [B, L, d_in]
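        # D acts as a learned skip connection from the input u to the output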
y = y + u * D
return y
class RMSNorm(nn.Module):
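    """Root-mean-square layer normalization (Zhang & Sennrich, 2019)."""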
def __init__(self, d_model, eps=1e-5):
super(RMSNorm, self).__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
return output
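# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# the Time-Series-Library layout so that `layers.Embed.DataEmbedding` resolves,
# and builds a hypothetical SimpleNamespace config carrying only the fields
# referenced above; the values are illustrative defaults, not tuned settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name="long_term_forecast",
        pred_len=96,
        d_model=128,
        expand=2,
        enc_in=7,
        c_out=7,
        embed="timeF",
        freq="h",
        dropout=0.1,
        e_layers=2,
        d_conv=4,
        d_ff=16,  # doubles as the SSM state size n in this implementation
    )
    model = Model(configs)
    x_enc = torch.randn(4, 96, configs.enc_in)  # [B, L, enc_in]
    x_mark_enc = torch.randn(4, 96, 4)          # 4 time features for freq="h"
    out = model(x_enc, x_mark_enc, None, None)
    print(out.shape)  # expected: torch.Size([4, 96, 7])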