Spaces:
Sleeping
Sleeping
| # coding=utf-8 | |
| # author=maziqing | |
| # [email protected] | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): | |
| """ | |
| get modes on frequency domain: | |
| 'random' means sampling randomly; | |
| 'else' means sampling the lowest modes; | |
| """ | |
| modes = min(modes, seq_len // 2) | |
| if mode_select_method == 'random': | |
| index = list(range(0, seq_len // 2)) | |
| np.random.shuffle(index) | |
| index = index[:modes] | |
| else: | |
| index = list(range(0, modes)) | |
| index.sort() | |
| return index | |
| # ########## fourier layer ############# | |
| class FourierBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): | |
| super(FourierBlock, self).__init__() | |
| print('fourier enhanced block used!') | |
| """ | |
| 1D Fourier block. It performs representation learning on frequency domain, | |
| it does FFT, linear transform, and Inverse FFT. | |
| """ | |
| # get modes on frequency domain | |
| self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) | |
| print('modes={}, index={}'.format(modes, self.index)) | |
| self.scale = (1 / (in_channels * out_channels)) | |
| self.weights1 = nn.Parameter( | |
| self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) | |
| self.weights2 = nn.Parameter( | |
| self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) | |
| # Complex multiplication | |
| def compl_mul1d(self, order, x, weights): | |
| x_flag = True | |
| w_flag = True | |
| if not torch.is_complex(x): | |
| x_flag = False | |
| x = torch.complex(x, torch.zeros_like(x).to(x.device)) | |
| if not torch.is_complex(weights): | |
| w_flag = False | |
| weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) | |
| if x_flag or w_flag: | |
| return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), | |
| torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) | |
| else: | |
| return torch.einsum(order, x.real, weights.real) | |
| def forward(self, q, k, v, mask): | |
| # size = [B, L, H, E] | |
| B, L, H, E = q.shape | |
| x = q.permute(0, 2, 3, 1) | |
| # Compute Fourier coefficients | |
| x_ft = torch.fft.rfft(x, dim=-1) | |
| # Perform Fourier neural operations | |
| out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) | |
| for wi, i in enumerate(self.index): | |
| if i >= x_ft.shape[3] or wi >= out_ft.shape[3]: | |
| continue | |
| out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i], | |
| torch.complex(self.weights1, self.weights2)[:, :, :, wi]) | |
| # Return to time domain | |
| x = torch.fft.irfft(out_ft, n=x.size(-1)) | |
| return (x, None) | |
| # ########## Fourier Cross Former #################### | |
| class FourierCrossAttention(nn.Module): | |
| def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', | |
| activation='tanh', policy=0, num_heads=8): | |
| super(FourierCrossAttention, self).__init__() | |
| print(' fourier enhanced cross attention used!') | |
| """ | |
| 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. | |
| """ | |
| self.activation = activation | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| # get modes for queries and keys (& values) on frequency domain | |
| self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) | |
| self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) | |
| print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) | |
| print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) | |
| self.scale = (1 / (in_channels * out_channels)) | |
| self.weights1 = nn.Parameter( | |
| self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) | |
| self.weights2 = nn.Parameter( | |
| self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) | |
| # Complex multiplication | |
| def compl_mul1d(self, order, x, weights): | |
| x_flag = True | |
| w_flag = True | |
| if not torch.is_complex(x): | |
| x_flag = False | |
| x = torch.complex(x, torch.zeros_like(x).to(x.device)) | |
| if not torch.is_complex(weights): | |
| w_flag = False | |
| weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) | |
| if x_flag or w_flag: | |
| return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), | |
| torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) | |
| else: | |
| return torch.einsum(order, x.real, weights.real) | |
| def forward(self, q, k, v, mask): | |
| # size = [B, L, H, E] | |
| B, L, H, E = q.shape | |
| xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] | |
| xk = k.permute(0, 2, 3, 1) | |
| xv = v.permute(0, 2, 3, 1) | |
| # Compute Fourier coefficients | |
| xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) | |
| xq_ft = torch.fft.rfft(xq, dim=-1) | |
| for i, j in enumerate(self.index_q): | |
| if j >= xq_ft.shape[3]: | |
| continue | |
| xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] | |
| xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) | |
| xk_ft = torch.fft.rfft(xk, dim=-1) | |
| for i, j in enumerate(self.index_kv): | |
| if j >= xk_ft.shape[3]: | |
| continue | |
| xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] | |
| # perform attention mechanism on frequency domain | |
| xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)) | |
| if self.activation == 'tanh': | |
| xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) | |
| elif self.activation == 'softmax': | |
| xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) | |
| xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) | |
| else: | |
| raise Exception('{} actiation function is not implemented'.format(self.activation)) | |
| xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) | |
| xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)) | |
| out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) | |
| for i, j in enumerate(self.index_q): | |
| if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: | |
| continue | |
| out_ft[:, :, :, j] = xqkvw[:, :, :, i] | |
| # Return to time domain | |
| out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) | |
| return (out, None) | |