'''
Utilities for instrumenting a torch model.

InstrumentedModel will wrap a pytorch model and allow hooking
arbitrary layers to monitor or modify their output directly.

Modified by Erik Härkönen:
- 29.11.2019: Unhooking bugfix
- 25.01.2020: Offset edits, removed old API
'''

import torch, numpy, types
from collections import OrderedDict

class InstrumentedModel(torch.nn.Module):
    '''
    A wrapper for hooking, probing and intervening in pytorch Modules.
    Example usage:

    ```
    model = load_my_model()
    with InstrumentedModel(model) as inst:
        inst.retain_layer(layername)
        inst.edit_layer(layername, 0.5, target_features)
        inst.edit_layer(layername, offset=offset_tensor)
        inst(inputs)
        original_features = inst.retained_layer(layername)
    ```
    '''

    def __init__(self, model):
        super(InstrumentedModel, self).__init__()
        self.model = model
        self._retained = OrderedDict()
        self._ablation = {}
        self._replacement = {}
        self._offset = {}
        self._hooked_layer = {}
        self._old_forward = {}

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def forward(self, *inputs, **kwargs):
        return self.model(*inputs, **kwargs)

    def retain_layer(self, layername):
        '''
        Pass a fully-qualified layer name (e.g., module.submodule.conv3)
        to hook that layer and retain its output each time the model is run.
        A pair (layername, aka) can be provided, and the aka will be used
        as the key for the retained value instead of the layername.
        '''
        self.retain_layers([layername])

    def retain_layers(self, layernames):
        '''
        Retains a list of layers at once.
        '''
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            if aka not in self._retained:
                self._retained[aka] = None
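
    # A minimal usage sketch (the layer name 'layer4.conv2' and the input
    # 'batch' are assumed for illustration, not part of this file):
    #
    #   inst = InstrumentedModel(load_my_model())
    #   inst.retain_layer(('layer4.conv2', 'feat'))  # retained under key 'feat'
    #   inst(batch)
    #   acts = inst.retained_layer('feat')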

    def retained_features(self):
        '''
        Returns a dict of all currently retained features.
        '''
        return OrderedDict(self._retained)

    def retained_layer(self, aka=None, clear=False):
        '''
        Retrieve retained data that was previously hooked by retain_layer.
        Call this after the model is run. If clear is set, the retained
        value will be returned and also cleared.
        '''
        if aka is None:
            # Default to the first retained layer.
            aka = next(iter(self._retained))
        result = self._retained[aka]
        if clear:
            self._retained[aka] = None
        return result
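
    # Since a retained slot keeps its last value until overwritten, a sketch
    # of per-batch consumption (loader and key assumed) might look like:
    #
    #   for batch in loader:
    #       inst(batch)
    #       feats = inst.retained_layer('feat', clear=True)  # no stale reuse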

    def edit_layer(self, layername, ablation=None, replacement=None, offset=None):
        '''
        Pass a fully-qualified layer name (e.g., module.submodule.conv3)
        to hook that layer and modify its output each time the model is run.
        The output of the layer is replaced by a convex combination of the
        original output x and the replacement r, weighted by the ablation a:
        `output = x * (1 - a) + (r * a)`.
        Additionally or independently, an offset can be added to the output.
        '''
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        # The default ablation if a replacement is specified is 1.0.
        if ablation is None and replacement is not None:
            ablation = 1.0
        self.add_hooks([(layername, aka)])
        if ablation is not None:
            self._ablation[aka] = ablation
        if replacement is not None:
            self._replacement[aka] = replacement
        if offset is not None:
            self._offset[aka] = offset
        # If needed, an arbitrary postprocessing lambda could be added here.
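
    # A sketch of the two edit modes (layer name and tensors assumed):
    #
    #   inst.edit_layer('layer4.conv2', 0.5, target_features)
    #   # -> output = x * 0.5 + target_features * 0.5
    #   inst.edit_layer('layer4.conv2', offset=direction * strength)
    #   # -> adds direction * strength on top of any interpolation edit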

    def remove_edits(self, layername=None, remove_offset=True, remove_replacement=True):
        '''
        Removes edits at the specified layer, or removes edits at all layers
        if no layer name is specified.
        '''
        if layername is None:
            if remove_replacement:
                self._ablation.clear()
                self._replacement.clear()
            if remove_offset:
                self._offset.clear()
            return
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        if remove_replacement and aka in self._ablation:
            del self._ablation[aka]
        if remove_replacement and aka in self._replacement:
            del self._replacement[aka]
        if remove_offset and aka in self._offset:
            del self._offset[aka]
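
    # For example (layer name assumed), clearing only the additive offset
    # while keeping an interpolation edit in place:
    #
    #   inst.remove_edits('layer4.conv2', remove_replacement=False)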

    def add_hooks(self, layernames):
        '''
        Sets up a set of layers to be hooked.
        Usually not called directly: use edit_layer or retain_layer instead.
        '''
        needed = set()
        aka_map = {}
        for name in layernames:
            aka = name
            if not isinstance(aka, str):
                name, aka = name
            if self._hooked_layer.get(aka, None) != name:
                aka_map[name] = aka
                needed.add(name)
        if not needed:
            return
        for name, layer in self.model.named_modules():
            if name in aka_map:
                needed.remove(name)
                aka = aka_map[name]
                self._hook_layer(layer, name, aka)
        for name in needed:
            raise ValueError('Layer %s not found in model' % name)

    def _hook_layer(self, layer, layername, aka):
        '''
        Internal method to replace a forward method with a closure that
        intercepts the call, and tracks the hook so that it can be reverted.
        '''
        if aka in self._hooked_layer:
            raise ValueError('Layer %s already hooked' % aka)
        if layername in self._old_forward:
            raise ValueError('Layer %s already hooked' % layername)
        self._hooked_layer[aka] = layername
        self._old_forward[layername] = (layer, aka,
                layer.__dict__.get('forward', None))
        editor = self
        original_forward = layer.forward
        def new_forward(self, *inputs, **kwargs):
            original_x = original_forward(*inputs, **kwargs)
            x = editor._postprocess_forward(original_x, aka)
            return x
        layer.forward = types.MethodType(new_forward, layer)

    def _unhook_layer(self, aka):
        '''
        Internal method to remove a hook, restoring the original forward method.
        '''
        if aka not in self._hooked_layer:
            return
        layername = self._hooked_layer[aka]
        layer, check, old_forward = self._old_forward[layername]
        assert check == aka
        if old_forward is None:
            if 'forward' in layer.__dict__:
                del layer.__dict__['forward']
        else:
            layer.forward = old_forward
        del self._old_forward[layername]
        del self._hooked_layer[aka]
        if aka in self._ablation:
            del self._ablation[aka]
        if aka in self._replacement:
            del self._replacement[aka]
        if aka in self._offset:
            del self._offset[aka]
        if aka in self._retained:
            del self._retained[aka]

    def _postprocess_forward(self, x, aka):
        '''
        The internal method called by the hooked layers after they are run.
        '''
        # Retain output before edits, if desired.
        if aka in self._retained:
            self._retained[aka] = x.detach()
        # Apply replacement edit: x = x * (1 - a) + r * a.
        a = make_matching_tensor(self._ablation, aka, x)
        if a is not None:
            x = x * (1 - a)
            v = make_matching_tensor(self._replacement, aka, x)
            if v is not None:
                x += (v * a)
        # Apply offset edit: x = x + b.
        b = make_matching_tensor(self._offset, aka, x)
        if b is not None:
            x = x + b
        return x
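
    # Tiny worked example of the combined edit (scalar values assumed):
    # with x = 2.0, a = 0.5, r = 4.0 and b = 1.0, the hooked layer returns
    # 2.0 * (1 - 0.5) + 4.0 * 0.5 + 1.0 = 4.0.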

    def close(self):
        '''
        Unhooks all hooked layers in the model.
        '''
        # Iterate over aka keys: _unhook_layer expects the aka, which may
        # differ from the layername keys of _old_forward.
        for aka in list(self._hooked_layer.keys()):
            self._unhook_layer(aka)
        assert len(self._old_forward) == 0

def make_matching_tensor(valuedict, name, data):
    '''
    Converts `valuedict[name]` to be a tensor with the same dtype, device,
    and dimension count as `data`, and caches the converted tensor.
    '''
    v = valuedict.get(name, None)
    if v is None:
        return None
    if not isinstance(v, torch.Tensor):
        # Accept non-torch data.
        v = torch.from_numpy(numpy.array(v))
        valuedict[name] = v
    if v.device != data.device or v.dtype != data.dtype:
        # Ensure device and type match.
        assert not v.requires_grad, '%s wrong device or type' % (name)
        v = v.to(device=data.device, dtype=data.dtype)
        valuedict[name] = v
    if len(v.shape) < len(data.shape):
        # Ensure dimensions are unsqueezed as needed.
        assert not v.requires_grad, '%s wrong dimensions' % (name)
        v = v.view((1,) + tuple(v.shape) +
                (1,) * (len(data.shape) - len(v.shape) - 1))
        valuedict[name] = v
    return v
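
# Shape sketch (shapes assumed for illustration): a 1-D vector of shape (C,)
# matched against a 4-D conv activation (N, C, H, W) is reshaped to
# (1, C, 1, 1) so it broadcasts across batch and spatial dimensions:
#
#   d = {'feat': torch.ones(8)}
#   v = make_matching_tensor(d, 'feat', torch.zeros(2, 8, 4, 4))
#   assert v.shape == (1, 8, 1, 1)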