Spaces:
Running
on
Zero
Running
on
Zero
| import os.path | |
| from typing import Optional, Dict, Any | |
| import lightning.pytorch as pl | |
| from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint | |
| from soupsieve.util import lower | |
| class CheckpointHook(ModelCheckpoint): | |
| """Save checkpoint with only the incremental part of the model""" | |
| def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: | |
| self.dirpath = trainer.default_root_dir | |
| self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt") | |
| pl_module.strict_loading = False | |
| def on_save_checkpoint( | |
| self, trainer: "pl.Trainer", | |
| pl_module: "pl.LightningModule", | |
| checkpoint: Dict[str, Any] | |
| ) -> None: | |
| del checkpoint["callbacks"] |