Spaces:
Runtime error
Runtime error
| import torch | |
| from mamba_ssm import Mamba | |
| batch, length, dim = 2, 64, 768 | |
| x = torch.randn(batch, length, dim).to("cuda") | |
| model = Mamba( | |
| # This module uses roughly 3 * expand * d_model^2 parameters | |
| d_model=dim, # Model dimension d_model | |
| d_state=16, # SSM state expansion factor # 64 | |
| d_conv=4, # Local convolution width | |
| expand=2, # Block expansion factor | |
| use_fast_path=False, | |
| ).to("cuda") | |
| y = model(x) | |
| assert y.shape == x.shape | |