Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import uuid | |
| from typing import Dict, Optional | |
| from torch import Tensor | |
class FairseqIncrementalState:
    """Namespaces entries of a shared incremental-state dict per instance.

    Every instance is given a random UUID string on construction. The
    helper methods prefix caller-supplied keys with that id, so several
    instances can store values in one dictionary without overwriting
    each other's entries.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: forward all arguments to the next class in the MRO
        # before assigning this instance its id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random id (``uuid.uuid4``) to this instance."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return *key* prefixed with this instance's id and a dot."""
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up the entry stored for *key* by this instance.

        Returns ``None`` when *incremental_state* is ``None`` or holds no
        entry under this instance's namespaced key; otherwise returns the
        stored value.
        """
        # The namespaced key is built before the None check, as the lookup
        # below needs it anyway.
        namespaced = self._get_full_incremental_state_key(key)
        if incremental_state is None:
            return None
        if namespaced not in incremental_state:
            return None
        return incremental_state[namespaced]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store *value* under this instance's namespaced *key*.

        The dictionary is updated in place and the same object is
        returned. When *incremental_state* is ``None`` nothing is stored
        and ``None`` is returned.
        """
        if incremental_state is None:
            return None
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
def with_incremental_state(cls):
    """Make ``FairseqIncrementalState`` the first base class of *cls*.

    Takes a class and returns that same class, so it can be applied as a
    class decorator. ``cls.__bases__`` is rewritten in place: any existing
    ``FairseqIncrementalState`` entry is removed from the direct bases and
    a single one is placed at the front, with the remaining bases kept in
    their original order.
    """
    other_bases = [
        base for base in cls.__bases__ if base != FairseqIncrementalState
    ]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls