from pathlib import Path

import torch
import torch.nn as nn

from utils.torch_utilities import load_pretrained_model, merge_matched_keys

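# For reference, a minimal sketch of what `merge_matched_keys` (imported from
# utils.torch_utilities, not shown in this file) is assumed to do: copy values
# from `state_dict` into `model_dict` for keys present in both with identical
# shapes, leaving all other entries of `model_dict` untouched.
#
#     def merge_matched_keys(model_dict, state_dict):
#         for key, value in state_dict.items():
#             if key in model_dict and model_dict[key].shape == value.shape:
#                 model_dict[key] = value
#         return model_dict
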
class LoadPretrainedBase(nn.Module):
    def process_state_dict(
        self,
        model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Custom processing function of each model that transforms `state_dict`
        loaded from checkpoints into a form accepted by `load_state_dict`.
        By default, `merge_matched_keys` is used to update parameters whose
        names and shapes match.

        Args:
            model_dict:
                The state dict of the current model, which is going to load
                pretrained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys
                and shapes are updated with values in `state_dict`.
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )

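# Usage sketch (illustrative only, not part of this module): a subclass can
# override `process_state_dict` to remap checkpoint keys before the default
# merge. The `encoder.` prefix and checkpoint path below are hypothetical.
#
#     class PrefixedEncoder(LoadPretrainedBase):
#         def process_state_dict(self, model_dict, state_dict):
#             state_dict = {f"encoder.{k}": v for k, v in state_dict.items()}
#             return super().process_state_dict(model_dict, state_dict)
#
#     model = PrefixedEncoder()
#     model.load_pretrained("checkpoints/pretrained.pt")
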
class CountParamsBase(nn.Module):
    def count_params(self) -> tuple[int, int]:
        """Return the total and trainable parameter counts of this module."""
        num_params = 0
        trainable_params = 0
        for param in self.parameters():
            # `numel()` gives the number of elements in a parameter tensor.
            num_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return num_params, trainable_params

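# Usage sketch (with a hypothetical `model` that mixes in CountParamsBase):
#
#     total, trainable = model.count_params()
#     print(f"{trainable}/{total} parameters are trainable")
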
class SaveTrainableParamsBase(nn.Module):
    @property
    def param_names_to_save(self) -> list[str]:
        """Names of all trainable parameters and buffers, which must be saved."""
        names = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                names.append(name)
        # Buffers (e.g. running statistics) are not trainable but still carry
        # state, so they are always included.
        for name, _ in self.named_buffers():
            names.append(name)
        return names

    def load_state_dict(self, state_dict, strict=True):
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise Exception(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        return super().load_state_dict(state_dict, strict)
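
# Saving sketch (illustrative; the actual trainer code lives elsewhere): only
# the entries listed in `param_names_to_save` need to be written to disk,
# since frozen weights can be restored from their original pre-trained
# checkpoints. `model` and the output path are hypothetical.
#
#     names = set(model.param_names_to_save)
#     to_save = {k: v for k, v in model.state_dict().items() if k in names}
#     torch.save(to_save, "epoch_40/model.pt")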