# Spaces: Running on Zero
from typing import Union

import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer
class GradientMonitor(pl.Callback):
    """Log gradient-norm statistics right before each optimizer step.

    On every optimizer step two values are logged via ``pl_module.log_dict``:

    * ``train/grad/max``   -- the largest of the per-parameter gradient norms
      returned by Lightning's ``grad_norm``.
    * ``train/grad/total`` -- the overall gradient norm that ``grad_norm``
      reports under its ``grad_<norm_type>_norm_total`` entry.

    Args:
        norm_type: Order of the norm. Any positive number, or ``"inf"`` for
            the infinity norm. The value is converted with ``float()``.

    Raises:
        ValueError: If ``norm_type`` converts to a non-positive number (or
            cannot be converted to ``float`` at all).
    """

    def __init__(self, norm_type: Union[float, int, str] = 2):
        super().__init__()
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer,
    ) -> None:
        """Compute gradient norms of ``pl_module`` and log their max and total.

        Nothing is logged when no parameter norms are available (e.g. every
        parameter is frozen). Reducing an empty tensor with ``max()`` would
        raise, and there is no meaningful statistic to report.

        Args:
            trainer: The running trainer (unused; required by the hook API).
            pl_module: The module whose parameter gradients are inspected.
            optimizer: The optimizer about to step (unused; required by the
                hook API).
        """
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        # self.norm_type is already a float, so this matches the key under
        # which the overall norm is stored in ``norms``.
        total_key = f"grad_{self.norm_type}_norm_total"
        per_param = [value for key, value in norms.items() if key != total_key]

        # Guard the empty case: torch.tensor([]).max() raises, and the total
        # entry may be absent when no gradients exist.
        if not per_param or total_key not in norms:
            return

        max_grad = torch.tensor(per_param).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[total_key]})