Spaces:
Configuration error
Configuration error
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # -*- encoding: utf-8 -*- | |
| import torch | |
| import torch.nn as nn | |
| from asteroid_filterbanks import Encoder, ParamSincFB | |
| from .RawNetBasicBlock import Bottle2neck, PreEmphasis | |
class RawNet3(nn.Module):
    """Speaker-embedding network that consumes raw waveforms.

    The pipeline built here is: pre-emphasis + instance norm, a
    ``ParamSincFB`` filterbank encoder, three ``block`` stages whose outputs
    are concatenated and fused by a 1x1 convolution, attention-weighted
    mean / standard-deviation pooling over time, and a final linear
    projection to ``nOut`` dimensions.
    """

    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        """Build the network.

        Args:
            block: Callable used to build the three backbone stages. It is
                called as ``block(in_ch, out_ch, kernel_size=..., dilation=...,
                scale=..., [pool=...])``.
            model_scale: Forwarded to every ``block`` as ``scale``.
            context: If true, the attention module additionally receives the
                per-channel temporal mean and standard deviation of the fused
                feature map (tripling its input channels).
            summed: If true, the third stage receives ``mp3(x1) + x2`` instead
                of ``x2`` alone.
            C: Channel width of the backbone stages; the filterbank uses
                ``C // 4`` filters.
            **kwargs: Must contain ``nOut``, ``encoder_type`` (``"ECA"`` or
                ``"ASP"``), ``log_sinc``, ``norm_sinc``, ``out_bn`` and
                ``sinc_stride``.

        Raises:
            KeyError: If a required key is missing from ``kwargs``.
            ValueError: If ``encoder_type`` is neither ``"ECA"`` nor ``"ASP"``.
        """
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        # NOTE: submodule attribute names and construction order are kept
        # stable so that existing checkpoints and seeded initialisation
        # continue to match.
        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        # bn1 is registered but not used by forward(); it is kept so that
        # state_dict keys stay compatible.
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3)
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)

        # Channel count of the fused feature map that is pooled over time.
        fused_channels = 1536
        self.layer4 = nn.Conv1d(3 * C, fused_channels, kernel_size=1)

        # With context, the attention input is [x, mean(x), std(x)].
        if self.context:
            attn_input = fused_channels * 3
        else:
            attn_input = fused_channels
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            # One attention weight per channel and frame.
            attn_output = fused_channels
        elif self.encoder_type == "ASP":
            # A single attention weight per frame, broadcast over channels.
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),  # normalise the weights over the time axis
        )

        # Pooled statistics are the concatenation of mean and std.
        self.bn5 = nn.BatchNorm1d(2 * fused_channels)
        self.fc6 = nn.Linear(2 * fused_channels, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """Compute embeddings for a mini-batch of waveforms.

        :param x: input mini-batch (bs, samp)
        :return: tensor produced by ``fc6`` (followed by ``bn6`` when
            ``out_bn`` is set), i.e. one ``nOut``-dimensional row per item.
        """
        # The front-end is always run with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                # Floor the std out-of-place: writing into the result of
                # torch.std in place invalidates the tensor autograd saved
                # for its backward pass and makes backward() fail.
                s = torch.std(x, dim=-1, keepdim=True).clamp(min=0.001)
                x = (x - m) / s

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        if self.summed:
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x3 = self.layer3(x2)

        # mp3 brings x1 to the same length as x2 / x3 before concatenation.
        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self.attention(global_x)

        # Attention-weighted mean and standard deviation over time; the clamp
        # keeps the sqrt argument strictly positive and bounded.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x
def MainModel(**kwargs):
    """Create a RawNet3 built from Bottle2neck blocks.

    The model is configured with ``model_scale=8``, ``context=True`` and
    ``summed=True``; every keyword argument is forwarded unchanged to the
    ``RawNet3`` constructor.
    """
    return RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        **kwargs,
    )