Spaces:
Runtime error
Runtime error
| ''' | |
| Utilities for dealing with simple state dicts as npz files instead of pth files. | |
| ''' | |
| import torch | |
| from collections.abc import MutableMapping, Mapping | |
def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None):
    '''
    Loads a model from numpy_dict using load_state_dict.
    Converts numpy types to torch types using the current state_dict
    of the model to determine types and devices for the tensors.
    Supports loading a subdict by prepending the given prefix to all keys.

    Args:
        model: object providing state_dict() and load_state_dict().
        numpy_dict: mapping from key to numpy value.
        prefix: optional key prefix selecting a subdict of numpy_dict;
            a trailing '.' is appended when it is missing.
        examples: optional mapping from key to an example value whose
            type (and, for tensors, dtype and device) the loaded value
            should match.  Defaults to model.state_dict().
    '''
    if prefix:
        if not prefix.endswith('.'):
            prefix = prefix + '.'
        # Expose only the keys under the prefix, with the prefix stripped.
        numpy_dict = PrefixSubDict(numpy_dict, prefix)
    if examples is None:
        # Use the model's own tensors as the dtype/device templates.
        examples = model.state_dict()
    torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples)
    model.load_state_dict(torch_state_dict)
def save_to_numpy_dict(model, numpy_dict, prefix=''):
    '''
    Saves a model by copying tensors to numpy_dict.
    Converts torch types to numpy types using `t.detach().cpu().numpy()`.
    Supports saving a subdict by prepending the given prefix to all keys.

    Args:
        model: object providing state_dict().
        numpy_dict: mutable mapping that receives one entry per
            state_dict key; it is modified in place.
        prefix: optional key prefix; a trailing '.' is appended when
            it is missing.
    '''
    if prefix and not prefix.endswith('.'):
        prefix = prefix + '.'
    for k, v in model.state_dict().items():
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
        # Non-tensor entries are stored unchanged.
        numpy_dict[prefix + k] = v
class TorchTypeMatchingDict(Mapping):
    '''
    Provides a view of a dict of numpy values as torch tensors, where the
    types are converted to match the types and devices in the given
    dict of examples.

    Keys that have no entry in examples are returned unconverted.
    Converted values are cached, so each key is converted at most once.
    '''

    def __init__(self, data, examples):
        self.data = data
        self.examples = examples
        # Converted values, keyed like data.
        self.cached_data = {}

    def __getitem__(self, key):
        if key in self.cached_data:
            return self.cached_data[key]
        val = self.data[key]
        if key not in self.examples:
            # No template available: hand back the raw value (not cached).
            return val
        example = self.examples[key]
        if isinstance(example, torch.Tensor):
            # Only non-tensors need wrapping; tensors (and tensor
            # subclasses) go straight to the dtype/device cast.
            if not isinstance(val, torch.Tensor):
                val = torch.from_numpy(val)
            val = val.to(dtype=example.dtype, device=example.device)
        elif example is not None and type(val) is not type(example):
            # Non-tensor example: coerce through the example's constructor.
            val = type(example)(val)
        self.cached_data[key] = val
        return val

    def __iter__(self):
        # Must return an iterator; a keys() view is not one.
        return iter(self.data)

    def __len__(self):
        return len(self.data)
class PrefixSubDict(MutableMapping):
    '''
    Provides a view of the subset of a dict where string keys begin with
    the given prefix. The prefix is stripped from all keys of the view.

    Writes and deletes go through to the underlying dict using the
    prefixed key.  The list of visible keys is cached and is invalidated
    only by changes made through this view; changes made directly to the
    underlying dict after the keys were first listed are not noticed.
    '''

    def __init__(self, data, prefix=''):
        self.data = data
        self.prefix = prefix
        # Lazily built list of stripped keys; None means "not computed".
        self._cached_keys = None

    def __getitem__(self, key):
        return self.data[self.prefix + key]

    def __setitem__(self, key, value):
        pkey = self.prefix + key
        # Adding a new key changes the key list; overwriting does not.
        if self._cached_keys is not None and pkey not in self.data:
            self._cached_keys = None
        self.data[pkey] = value

    def __delitem__(self, key):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey in self.data:
            self._cached_keys = None
        del self.data[pkey]

    def _visible_keys(self):
        '''Returns the cached list of stripped keys, building it if needed.'''
        if self._cached_keys is None:
            plen = len(self.prefix)
            # Non-string keys can never match a string prefix; skip them
            # rather than failing on a missing startswith method.
            self._cached_keys = [
                k[plen:] for k in self.data
                if isinstance(k, str) and k.startswith(self.prefix)]
        return self._cached_keys

    def __iter__(self):
        return iter(self._visible_keys())

    def __len__(self):
        return len(self._visible_keys())