| from typing import Tuple | |
| from transformers import BartConfig | |
class TTCompressedBartConfig(BartConfig):
    """Configuration for a TT-compressed BART model.

    Extends ``BartConfig`` with the parameters of the TT compression. The
    shape is split into separate input and output shapes so that each one is
    serialized to its own field in JSON.

    Args:
        *args: Positional arguments forwarded unchanged to ``BartConfig``.
        shape_in: Input-side shape of the compression. Any sequence of ints
            is accepted and stored as a tuple; ``None`` is stored as ``None``.
        shape_out: Output-side shape of the compression. Same normalization
            as ``shape_in``.
        rank: Rank used by the TT compression (presumably the TT-rank of the
            decomposition — confirm against the model code). Defaults to 128.
        **kwargs: Keyword arguments forwarded unchanged to ``BartConfig``.
    """

    def __init__(self, *args, shape_in: Tuple[int, ...] = (),
                 shape_out: Tuple[int, ...] = (), rank: int = 128, **kwargs):
        super().__init__(*args, **kwargs)
        # JSON has no tuple type, so a config reloaded from disk receives these
        # fields as lists. Normalize them so the stored type is the same no
        # matter how the config was built; keep None unchanged for callers
        # that pass it explicitly.
        self.shape_in = None if shape_in is None else tuple(shape_in)
        self.shape_out = None if shape_out is None else tuple(shape_out)
        self.rank = rank
# Register this config with the AutoConfig auto-class so that code saved
# alongside the checkpoint can be resolved by AutoConfig when loading.
TTCompressedBartConfig.register_for_auto_class(auto_class="AutoConfig")