# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# An implementation of SM3 from:
#
#   Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
#   Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
#
# Author: Rohan Anil (rohananil at google dot com)

"""SM3 Implementation."""

import functools
from typing import Any, NamedTuple

import chex
import jax
import jax.numpy as jnp
import optax

from .quantization_utils import QuantizedValue


class SM3State(NamedTuple):
    count: chex.Array
    stats: Any


# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: chex.Array  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner

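# For a parameter of shape [d_0, ..., d_{k-1}], `diagonal_statistics` is a list
# of k vectors of shapes [d_0], ..., [d_{k-1}] rather than one full-size tensor,
# which is the memory saving at the heart of SM3; `diagonal_momentum` is kept
# int8-quantized to further reduce optimizer memory.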
def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
    Yoram Singer

    https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters; either a float
        or a callable schedule mapping the step count to a step size.
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon added to the accumulated statistics for
        numerical stability when computing the preconditioner.
      normalize_grads: whether to normalize gradients. The author finds it
        useful when gradients have high variance.

    Returns:
      an optax.GradientTransformation.
    """
    def _quantize_momentum(momentum_statistics):
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
        # For eg: i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)
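    # SM3 keeps one accumulator per tensor dimension. Broadcasting each
    # accumulator to the full parameter shape and taking the elementwise minimum
    # recovers the per-entry second-moment estimate, which is then updated with
    # an exponential moving average of the squared gradient (a plain sum when
    # beta2 == 1.0).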
    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2

    def _moving_averages_momentum(grad, momentum):
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad
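    # Folds the full-size statistics back into one vector per dimension by
    # taking the max over all other axes, so only the per-dimension "covers"
    # need to be stored between steps.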
    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics
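    # One optimizer step: (optionally) normalize gradients, expand the stored
    # per-dimension accumulators, update the second-moment statistics,
    # precondition the gradients by 1/sqrt(statistics), apply momentum,
    # re-sketch the statistics back to per-dimension vectors, and scale by the
    # (possibly scheduled) learning rate.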
    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
        # Reshape all vectors into N-d tensors to compute min over them.
        # [n], [m] -> [n, 1], [1, m]
        expanded_diagonal_statistics = jax.tree_multimap(
            lambda grad, state: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics.
        new_diagonal_statistics = jax.tree_multimap(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners (1/sqrt(s)) where s is the statistics.
        new_preconditioners = jax.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_multimap(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Compute updated momentum (also handle quantization).
        updated_momentum = jax.tree_multimap(
            lambda preconditioned_grad, state: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, state.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Update diagonal statistics.
        updated_diagonal_statistics = jax.tree_multimap(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Pack the updated per-parameter state (re-quantizing the momentum).
        new_sm3_stats = jax.tree_multimap(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)

        new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
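
# A minimal usage sketch (illustrative only, not part of the original module):
# `sm3(...)` returns a standard optax GradientTransformation, so it is used
# like any other optax optimizer. The parameter tree and loss function below
# are made-up placeholders.
#
#   tx = sm3(learning_rate=1e-3, beta1=0.9, beta2=0.999)
#   params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros((3,))}
#   opt_state = tx.init(params)
#
#   def loss_fn(p):
#       return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)
#
#   grads = jax.grad(loss_fn)(params)
#   updates, opt_state = tx.update(grads, opt_state)
#   params = optax.apply_updates(params, updates)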