Spaces:
Running
Running
| import torch.nn as nn | |
| ACTIVATION_FUNCTIONS = { | |
| "elu": nn.ELU(), | |
| "swish": nn.SiLU(), | |
| "silu": nn.SiLU(), | |
| "mish": nn.Mish(), | |
| "gelu": nn.GELU(), | |
| "relu": nn.ReLU(), | |
| } | |
| def get_activation(act_fn: str) -> nn.Module: | |
| """Helper function to get activation function from string. | |
| Args: | |
| act_fn (str): Name of activation function. | |
| Returns: | |
| nn.Module: Activation function. | |
| """ | |
| act_fn = act_fn.lower() | |
| if act_fn in ACTIVATION_FUNCTIONS: | |
| return ACTIVATION_FUNCTIONS[act_fn] | |
| else: | |
| raise ValueError(f"Unsupported activation function: {act_fn}") | |