YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

Model Structure

class GRUTransformerSimple(nn.Module):
    """GRU followed by a Transformer encoder, producing one scalar per series.

    The input is a 4-D tensor that ``forward`` unpacks as ``[b, t, s, f]``.
    Every ``(b, s)`` pair is treated as an independent sequence of length
    ``t``: it is run through a GRU, then through a Transformer encoder that
    attends over the ``t`` axis, and the encoder output at the last time
    step is mapped to a single value by a small MLP head.

    Args:
        d_feat: Size of the last input axis (``f``); used as the GRU
            ``input_size``.
        hidden_size: Width of the GRU hidden state and of the Transformer
            (``d_model``). The encoder layer uses 4 attention heads, so this
            must be divisible by 4.
        num_layers: Number of stacked GRU layers and, likewise, number of
            Transformer encoder layers.
        dropout: Dropout probability for the Transformer encoder layers and
            between stacked GRU layers (GRU dropout only exists when
            ``num_layers > 1``).
    """

    def __init__(
        self,
        d_feat: int = 8,
        hidden_size: int = 64,
        num_layers: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # NOTE: nn.TransformerEncoder clones the layer it is given, so this
        # attribute holds its own set of parameters that ``forward`` never
        # uses. It is kept as an attribute so the ``state_dict`` keys stay
        # compatible with checkpoints saved from this class.
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_layers
        )
        # nn.GRU applies dropout only between stacked layers; with a single
        # layer a non-zero value has no effect and just triggers a
        # UserWarning, so pass 0.0 in that case.
        gru_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=gru_dropout,
        )
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one scalar per ``(b, s)`` pair.

        Args:
            x: Tensor of shape ``[b, t, s, f]`` where ``f`` equals ``d_feat``.
                Presumably batch, time steps, series (e.g. instruments) and
                features -- the meaning of ``s`` is not established by this
                file; confirm against the caller.

        Returns:
            Tensor of shape ``[b, s]``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input [b, t, s, f], got shape {tuple(x.shape)}"
            )
        b, t, s, f = x.shape
        # Fold the series axis into the batch so each (b, s) pair becomes an
        # independent sequence of length t.
        x = x.permute(0, 2, 1, 3).reshape(b * s, t, f)
        gru_out, _ = self.gru(x)  # [b * s, t, h]
        # The encoder was built with batch_first=False, so move time first.
        gru_out = gru_out.permute(1, 0, 2).contiguous()  # [t, b * s, h]
        tfm_out = self.transformer_encoder(gru_out)  # [t, b * s, h]
        # Keep only the last time step and restore the [b, s] layout.
        tfm_out = tfm_out[-1].reshape(b, s, -1)  # [b, s, h]
        final_out = self.out(tfm_out).squeeze(-1)  # [b, s]

        return final_out

Model Config

d_feat: 8
hidden_size: 64
num_layers: 1
dropout: 0.0
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including Abner0803/GRU_Transformer_simple