Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from nncore.nn import Parameter | |
class Permute(nn.Module):
    """Swap the last two axes of a tensor (e.g. channels-last <-> channels-first)."""

    def forward(self, x):
        # Swapping (-2, -1) is the same operation as swapping (-1, -2).
        swapped = torch.transpose(x, -2, -1)
        return swapped
class LearnableEmbedding(nn.Module):
    """Add one learned vector to every position of the input.

    Args:
        dims (int): Size of the learned vector; must match the last axis of
            the input passed to ``forward``.
    """

    def __init__(self, dims):
        super().__init__()
        # nncore's Parameter built with sizes (1, 1, dims); its initialization
        # scheme is not visible here -- NOTE(review): confirm in nncore.
        self.weights = Parameter(1, 1, dims)

    def forward(self, x):
        # expand_as returns a broadcast view matching x's shape exactly, so x
        # must be expandable-to from (1, 1, dims); presumably x is
        # (batch, time, dims) -- TODO confirm with callers.
        bias = self.weights.expand_as(x)
        return x + bias
class ConvPyramid(nn.Module):
    """Build a multi-scale temporal feature pyramid from one feature sequence.

    Each entry ``s`` of ``strides`` yields one pyramid level:

    * ``s == 1`` -- the input passed through ``act_cls()`` only;
    * ``s > 1``  -- ``log2(s)`` rounds of kernel-2 / stride-2 ``nn.Conv1d``
      (each round halves the temporal length);
    * ``s < 1``  -- ``log2(1 / s)`` rounds of kernel-2 / stride-2
      ``nn.ConvTranspose1d`` (each round doubles the temporal length).

    Every round is ``Permute -> conv -> Permute -> LayerNorm -> act``, so the
    blocks consume and produce tensors whose last axis is the feature axis.

    Args:
        dims (int): Feature size, unchanged at every level.
        strides (Sequence[float]): Temporal stride of each level. Every value
            must be a positive power of two (e.g. ``0.5, 1, 2, 4``); ints and
            floats are both accepted.
        act_cls (type[nn.Module]): Activation class, instantiated once per
            round (and once for a stride-1 level).

    Raises:
        ValueError: If a stride is not a positive power of two. Previously a
            stride such as 3 was silently truncated to a factor of 2 while the
            mask was pooled by 3, misaligning features and masks.
    """

    def __init__(self, dims, strides, act_cls=nn.ReLU):
        super().__init__()

        self.blocks = nn.ModuleList()
        for s in strides:
            p = self._exact_log2(s)
            if p == 0:
                # Native-resolution level: activation only.
                layers = act_cls()
            else:
                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
                layers = nn.Sequential()
                for _ in range(abs(p)):
                    # Permute moves the feature axis to dim -2 where Conv1d
                    # expects channels (assumes channels-last input), then back.
                    module = [Permute(), conv_cls(dims, dims, 2, stride=2), Permute(), nn.LayerNorm(dims), act_cls()]
                    layers.extend(module)
            self.blocks.append(layers)

        self.strides = strides

    @staticmethod
    def _exact_log2(s):
        """Return the integer ``p`` with ``2 ** p == s``; raise ``ValueError`` otherwise."""
        if s <= 0:
            raise ValueError(f'stride must be positive, got {s}')
        p = round(math.log2(s))
        if 2.0 ** p != s:
            raise ValueError(f'stride must be a power of two, got {s}')
        return p

    def forward(self, x, mask, return_mask=False):
        """Apply every pyramid level to ``x``.

        Args:
            x (Tensor): Input features; ``x.size(1)`` is read as the temporal
                length (presumably ``(batch, time, dims)`` -- confirm).
            mask (Tensor): Validity mask along the temporal axis of ``x``; only
                read when ``return_mask`` is True.
            return_mask (bool): Also return a mask resampled to each level.

        Returns:
            ``pymid`` (list of level outputs), or ``(pymid, pymid_msk)`` when
            ``return_mask`` is True. Levels whose stride exceeds the input
            length are skipped, so the lists may be shorter than
            ``self.strides``.
        """
        pymid, pymid_msk = [], []
        for s, blk in zip(self.strides, self.blocks):
            # Input too short to downsample by s: drop this level.
            if x.size(1) < s:
                continue

            pymid.append(blk(x))

            if return_mask:
                if s > 1:
                    # A pooled step is valid if any covered input step is valid.
                    # int() is exact because __init__ validated powers of two.
                    # NOTE: this branch yields a long mask; the others keep
                    # mask's original dtype.
                    k = int(s)
                    msk = F.max_pool1d(mask.float(), k, stride=k).long()
                elif s < 1:
                    msk = mask.repeat_interleave(round(1 / s), dim=1)
                else:
                    msk = mask
                pymid_msk.append(msk)

        return (pymid, pymid_msk) if return_mask else pymid
class Scale(nn.Module):
    """Hold one learnable scalar multiplier per pyramid level.

    Args:
        strides (Sequence): Only its length is used -- one factor per entry,
            all initialized to 1.
    """

    def __init__(self, strides):
        super().__init__()
        num_levels = len(strides)
        self.scale = nn.Parameter(torch.ones(num_levels))

    def forward(self, x, i):
        # Multiply x by the learned factor of level ``i``.
        factor = self.scale[i]
        return x * factor
class ConvHead(nn.Module):
    """Two-layer 1D convolutional prediction head on channels-last features.

    Pipeline: ``Permute -> Conv1d(dims, dims) -> act -> Conv1d(dims, out_dims)
    -> Permute``, so the feature axis is last on input and output.

    Args:
        dims (int): Input feature size.
        out_dims (int): Output feature size.
        kernal_size (int): Kernel size of both convolutions (the spelling is
            part of the public keyword interface). Padding is
            ``kernal_size // 2``, which preserves the temporal length only for
            odd sizes; an even size lengthens the output by one step per conv.
        act_cls (type[nn.Module]): Activation between the two convolutions.
    """

    def __init__(self, dims, out_dims, kernal_size=3, act_cls=nn.ReLU):
        super().__init__()

        pad = kernal_size // 2
        # Construction order (and therefore Sequential indices / state_dict
        # keys and RNG consumption) matches the original layout.
        # yapf:disable
        layers = [
            Permute(),
            nn.Conv1d(dims, dims, kernal_size, padding=pad),
            act_cls(),
            nn.Conv1d(dims, out_dims, kernal_size, padding=pad),
            Permute(),
        ]
        # yapf:enable
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        out = self.module(x)
        return out