YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

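Model Structure

The snippet below assumes `import torch`, `from torch import nn`, and a `Mamba` block (presumably `from mamba_ssm import Mamba`) are already in scope.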

class MambaTransformerSimple(nn.Module):
    """Mamba block followed by a Transformer encoder, producing one scalar per series.

    ``forward`` takes ``x`` of shape ``[b, t, s, f]``. Each of the ``s`` series is
    folded into the batch and encoded along the ``t`` axis. The representation at
    the last step is mapped to a scalar, so the output has shape ``[b, s]``.

    Args:
        d_feat: Number of input features ``f`` per step.
        hidden_size: Model width ``h``. The encoder uses ``nhead=4``, so this
            should be divisible by 4.
        num_layers: Number of Transformer encoder layers.
        dropout: Dropout probability inside the Transformer encoder layers.
        noise_level: Stored as ``self.noise_level``; not used by this class.
        d_state: Passed to ``Mamba``.
        d_conv: Passed to ``Mamba``.
        expand: Passed to ``Mamba``.
        mask_type: ``"none"`` for unmasked self-attention, or ``"causal"`` to
            stop each step from attending to later steps.

    Raises:
        ValueError: If ``mask_type`` is not ``"none"`` or ``"causal"``.
    """

    _MASK_TYPES = ("none", "causal")

    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
        noise_level: float = 0.0,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        # Reject unknown values: previously anything other than "causal"
        # (e.g. a typo like "Causal") silently meant "no mask".
        if mask_type not in self._MASK_TYPES:
            raise ValueError(
                f"mask_type must be one of {self._MASK_TYPES}, got {mask_type!r}"
            )
        self.mask_type = mask_type
        # Kept so the value is inspectable; nothing in this class reads it.
        self.noise_level = noise_level

        # NOTE: submodule names and creation order are kept exactly as before so
        # that existing state_dicts load and seeded initialisation is unchanged.
        # NOTE(review): nn.TransformerEncoder appears to clone this template layer,
        # so the template's own parameters are registered (and saved) but seem
        # unused in forward; kept for checkpoint compatibility — confirm before removing.
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,  # forward feeds [t, batch, h]
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        self.input_proj = nn.Linear(d_feat, hidden_size)
        self.mamba = Mamba(
            d_model=hidden_size, d_state=d_state, d_conv=d_conv, expand=expand
        )
        self.mid_norm = nn.LayerNorm(hidden_size)
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def _generate_causal_mask(
        self,
        seq_len: int,
        device: torch.device,
        dtype: "torch.dtype | None" = None,
    ) -> torch.Tensor:
        """Return a ``[seq_len, seq_len]`` additive causal attention mask.

        Entries on and below the diagonal are ``0``; entries above it are
        ``-inf``, so position ``i`` cannot attend to any ``j > i``.

        Args:
            seq_len: Sequence length.
            device: Device on which to allocate the mask.
            dtype: Floating dtype of the mask. ``None`` uses torch's default
                floating dtype, which matches the previous behaviour.
        """
        return torch.triu(
            torch.full(
                (seq_len, seq_len), float("-inf"), device=device, dtype=dtype
            ),
            diagonal=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode every series and return one prediction per series.

        Args:
            x: Tensor of shape ``[b, t, s, f]`` with ``f == d_feat``.

        Returns:
            Tensor of shape ``[b, s]``.
        """
        b, t, s, f = x.shape
        # Fold the s axis into the batch so each series is encoded independently.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        x = self.input_proj(x)  # [b * s, t, h]
        mamba_out = self.mamba(x)  # [b * s, t, h]
        # The encoder is sequence-first (batch_first=False): move time to axis 0.
        mamba_out = mamba_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        mamba_out = self.mid_norm(mamba_out)

        if self.mask_type == "causal":
            # Build the mask in the activation dtype so reduced-precision runs
            # get a same-typed additive mask.
            # Note: only the last step is read below, and its mask row is all
            # zeros, so with num_layers == 1 the causal mask has no effect on
            # the output; it matters only with two or more layers.
            mask = self._generate_causal_mask(t, x.device, dtype=mamba_out.dtype)
        else:
            mask = None

        tfm_out = self.transformer_encoder(mamba_out, mask=mask)  # [t, b * s, h]
        # The last time step summarises each series.
        tfm_out = tfm_out[-1].reshape(b, s, -1)  # [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]

        return final_out

Model Config

Keyword arguments passed to `MambaTransformerSimple`:

num_layers: 1
d_feat: 8
hidden_size: 64
d_state: 16
d_conv: 4
expand: 2
dropout: 0.1
noise_level: 0.0
mask_type: "none"
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including Abner0803/Mamba_Transformer_simple