Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import contextlib | |
| from omegaconf import open_dict, OmegaConf | |
| from fairseq.tasks import register_task | |
| from fairseq.tasks.sentence_prediction import ( | |
| SentencePredictionTask, | |
| SentencePredictionConfig, | |
| ) | |
| # Module-level logger, named after this module. | |
| logger = logging.getLogger(__name__) | |
# NOTE(review): this file imports `register_task` and `SentencePredictionConfig`,
# but no registration decorator is visible on this class -- confirm whether the
# task is registered somewhere else.
class SentencePredictionAdapterTask(SentencePredictionTask):
    """Sentence-prediction task that freezes embeddings and adapters after building.

    Only ``build_model`` is overridden; everything else comes from
    ``SentencePredictionTask``.
    """

    def build_model(self, cfg):
        """Build the model, register the classification head, then freeze parameters.

        ``cfg.max_positions`` is overwritten with the task's own
        ``self.cfg.max_positions`` before the model is constructed.

        After construction, ``requires_grad`` is set to ``False`` on every
        parameter of:

        * the sentence encoder's ``embed_positions``, ``layernorm_embedding``
          and ``embed_tokens`` (in that order);
        * each encoder layer's ``adapter_layer_norm`` (only when the layer has
          that attribute) and ``adapter_modules``.

        Args:
            cfg: model configuration passed on to ``models.build_model``.

        Returns:
            The built model with the parameters above frozen.
        """
        from fairseq import models

        def _freeze(module):
            # Turn off gradient tracking for every parameter of ``module``.
            for weight in module.parameters():
                weight.requires_grad = False

        # OmegaConf configs are wrapped in ``open_dict`` for the assignment;
        # any other cfg object gets a do-nothing context (an empty ExitStack).
        if OmegaConf.is_config(cfg):
            cfg_guard = open_dict(cfg)
        else:
            cfg_guard = contextlib.ExitStack()
        with cfg_guard:
            cfg.max_positions = self.cfg.max_positions

        model = models.build_model(cfg, self)
        model.register_classification_head(
            self.cfg.classification_head_name,
            num_classes=self.cfg.num_classes,
        )

        logger.info("Freezing Embedding Parameters")
        sentence_encoder = model.encoder.sentence_encoder
        _freeze(sentence_encoder.embed_positions)
        _freeze(sentence_encoder.layernorm_embedding)
        _freeze(sentence_encoder.embed_tokens)

        logger.info("Freezing Adapters")
        for layer_name, layer in sentence_encoder.layers._modules.items():
            logger.info("Freezing Adapters in Layer " + str(layer_name))
            if hasattr(layer, "adapter_layer_norm"):
                logger.info("Freezing Adapter LN")
                _freeze(layer.adapter_layer_norm)
            # NOTE(review): adapter_modules is accessed for every layer, not
            # only those that have adapter_layer_norm -- confirm this nesting.
            _freeze(layer.adapter_modules)

        return model