Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from fairseq.models import BaseFairseqModel, register_model | |
| from fairseq.models.wav2vec.wav2vec2_asr import ( | |
| Wav2Vec2CtcConfig, | |
| Wav2VecCtc, | |
| Wav2VecEncoder, | |
| ) | |
| from fairseq.tasks import FairseqTask | |
class Wav2VecLaser(Wav2VecCtc):
    """Wav2VecCtc variant whose forward pass returns a max-pooled encoder output.

    The parent's ``forward`` output is scaled, padded positions are masked
    out, and the result is reduced with a max over the first dimension.
    """

    def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel):
        """Initialise the model from a config and an already-built encoder.

        Args:
            cfg: Model configuration; ``freeze_finetune_updates`` is read
                from it and stored on the instance.
            w2v_encoder: Encoder forwarded unchanged to ``Wav2VecCtc``.
        """
        super().__init__(cfg, w2v_encoder)
        self.num_updates = 0
        self.freeze_finetune_updates = cfg.freeze_finetune_updates

    # ``build_model`` receives the class itself and instantiates it via
    # ``cls(...)``, so it must be a classmethod; without the decorator a call
    # such as ``Wav2VecLaser.build_model(cfg, task)`` mis-binds its arguments.
    @classmethod
    def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask):
        """Build a new model instance.

        Args:
            cfg: Model configuration used for both the encoder and the model.
            task: Accepted for interface compatibility; not used here.

        Returns:
            A new instance of ``cls`` wrapping a freshly built encoder.
        """
        # NOTE(review): 1024 is presumably the encoder output size --
        # confirm against the Wav2VecEncoder signature.
        w2v_encoder = Wav2VecEncoder(cfg, 1024)
        return cls(cfg, w2v_encoder)

    def forward(self, **kwargs):
        """Run the parent forward pass and max-pool its encoder output.

        Args:
            **kwargs: Passed through unchanged to ``Wav2VecCtc.forward``.

        Returns:
            The element-wise maximum of the scaled encoder output taken over
            dimension 0, with padded positions excluded.
        """
        output = super().forward(**kwargs)
        x_out = output["encoder_out"] * 0.01
        out_pad_mask = output["padding_mask"]

        # Set padded outputs to -inf so they are not selected by max-pooling.
        # NOTE(review): assumes encoder_out is (time, batch, dim) and
        # padding_mask is (batch, time), so the transposed mask broadcasts
        # over the last dimension -- confirm against Wav2VecEncoder.
        if out_pad_mask is not None and out_pad_mask.any():
            x_out = (
                x_out.float()
                .masked_fill_(out_pad_mask.T.unsqueeze(-1), float("-inf"))
                .type_as(x_out)
            )

        # ``max`` returns (values, indices); only the values are needed.
        return x_out.max(dim=0)[0]