STAR / models /common.py
Yixuan Li
first commit
4853fdc
from pathlib import Path
import torch
import torch.nn as nn
from utils.torch_utilities import load_pretrained_model, merge_matched_keys
class LoadPretrainedBase(nn.Module):
    """Mixin that loads pretrained checkpoints through an overridable hook.

    Subclasses can override :meth:`process_state_dict` to customise how a
    checkpoint's parameters are mapped onto the current model.
    """

    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Custom processing function of each model that transforms `state_dict`
        loaded from checkpoints into a state that can be used in
        `load_state_dict`.

        By default, uses `merge_matched_keys` to update parameters with
        matched names and shapes.

        Args:
            model_dict:
                The state dict of the current model, which is going to load
                pretrained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys and
                shapes are updated with values in `state_dict`.
        """
        return merge_matched_keys(model_dict, state_dict)

    def load_pretrained(self, ckpt_path: str | Path) -> None:
        """Load pretrained weights from `ckpt_path` into this module.

        Delegates to `load_pretrained_model`, passing
        :meth:`process_state_dict` as the state-dict processing hook so
        subclasses' overrides take effect.

        Args:
            ckpt_path: Path to the checkpoint file.
        """
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
class CountParamsBase(nn.Module):
    """Mixin that reports how many parameter elements a module holds."""

    def count_params(self):
        """Count the scalar elements in this module's parameters.

        Returns:
            tuple[int, int]: ``(total, trainable)``. ``total`` sums ``numel()``
            over every parameter yielded by ``self.parameters()``.
            ``trainable`` sums only those with ``requires_grad`` set.
        """
        params = list(self.parameters())
        total = sum(p.numel() for p in params)
        trainable = sum(p.numel() for p in params if p.requires_grad)
        return total, trainable
class SaveTrainableParamsBase(nn.Module):
    """Mixin for modules whose checkpoints hold only the state worth saving.

    ``param_names_to_save`` lists trainable parameters and persistent
    buffers. ``load_state_dict`` verifies that every one of those names is
    present before delegating to ``nn.Module.load_state_dict``.
    """

    @property
    def param_names_to_save(self):
        """Names of trainable parameters followed by persistent buffers.

        Parameters with ``requires_grad=False`` are skipped. Buffers
        registered with ``persistent=False`` are skipped too: PyTorch
        excludes them from ``state_dict()`` and never restores them in
        ``load_state_dict``, so requiring them would make loading fail.
        """
        names = [
            name for name, param in self.named_parameters()
            if param.requires_grad
        ]
        # Fully-qualified names of non-persistent buffers. PyTorch records
        # them per module in `_non_persistent_buffers_set`; the getattr
        # default keeps this working if that attribute is absent.
        non_persistent = {
            f"{prefix}.{buf_name}" if prefix else buf_name
            for prefix, module in self.named_modules()
            for buf_name in getattr(module, "_non_persistent_buffers_set", ())
        }
        names.extend(
            name for name, _ in self.named_buffers()
            if name not in non_persistent
        )
        return names

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load ``state_dict`` after checking it contains every saved entry.

        Args:
            state_dict: Mapping from parameter/buffer names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``. The presence
                check below runs regardless of this flag. Only names outside
                ``param_names_to_save`` (e.g. frozen parameters) may be
                absent when ``strict=False``.
            **kwargs: Forwarded unchanged to ``nn.Module.load_state_dict``
                (e.g. ``assign`` on newer PyTorch versions).

        Returns:
            The value returned by ``nn.Module.load_state_dict``.

        Raises:
            RuntimeError: If any name from ``param_names_to_save`` is missing
                from ``state_dict``.
        """
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise RuntimeError(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        return super().load_state_dict(state_dict, strict, **kwargs)