| from flax import linen as nn | |
| import jax | |
| import jax.numpy as jnp | |
class LocalResponseNorm(nn.Module):
    """Normalize each feature vector by its root-mean-square over the last axis.

    Computes ``value / sqrt(mean(value**2, axis=-1) + epsilon)`` and broadcasts
    the per-vector denominator back across the last axis.

    Despite the class name, the statistic is taken over the *entire* last axis,
    not over a local window of neighbouring channels. This is the same formula
    as the "pixelwise feature vector normalization" of Karras et al. (ProGAN).

    The module has no learnable parameters.

    Attributes:
        epsilon: Small constant added to the mean square before the square
            root. It keeps the denominator non-zero, so an all-zero vector maps
            to zeros instead of NaN. Defaults to ``1e-8``.
    """

    epsilon: float = 1e-8

    def __call__(self, value: jax.Array) -> jax.Array:
        """Return ``value`` rescaled to (approximately) unit RMS along axis -1.

        Args:
            value: Array of rank >= 1. Normalization is over its last axis.

        Returns:
            Array with the same shape as ``value``.
        """
        # keepdims=True leaves a size-1 trailing axis that broadcasts against
        # ``value``, so no full-size copy of the denominator is materialized.
        mean_square = jnp.mean(jnp.square(value), axis=-1, keepdims=True)
        return value / jnp.sqrt(mean_square + self.epsilon)