# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

logger = logging.getLogger(__name__)


class NanDetector:
    """
    Detects the first NaN or Inf in the forward and/or backward pass and logs it,
    together with the name of the module that produced it.
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()

        # Tag each submodule with its name so the hooks can report which module
        # produced the NaN/Inf, then register the requested hooks on it.
        for name, mod in model.named_modules():
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Dump out all model gnorms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data.float(), p=2)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data

        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")

        self.close()

    def add_hooks(self, module):
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            # register_backward_hook is deprecated in recent PyTorch releases in
            # favor of register_full_backward_hook.
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )
                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, list) or isinstance(x, tuple):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        for hook in self.fhooks + self.bhooks:
            hook.remove()
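

A minimal usage sketch (not from the original file): the toy model, the deliberately injected Inf weight, and the `__main__` guard are assumptions for illustration only. It shows the intended pattern of wrapping a forward/backward pass in the NanDetector context manager: the hooks warn about the first module whose output contains NaN/Inf, and `__exit__` dumps gradient norms on the way out.

if __name__ == "__main__":
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    # Toy model with one weight corrupted to Inf so the hooks have something to report.
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
    with torch.no_grad():
        model[0].weight[0, 0] = float("inf")

    x = torch.randn(8, 4)
    with NanDetector(model):
        # The forward hook on the first Linear should warn about the Inf in its output;
        # the backward hooks and __exit__ then report any non-finite gradients.
        loss = model(x).sum()
        loss.backward()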