# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
import itertools
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from omegaconf import II, MISSING
from sklearn import metrics as sklearn_metrics

from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
from fairseq.tasks.audio_finetuning import label_len_fn, LabelEncoder

from .. import utils
from ..logging import metrics
from . import FairseqTask, register_task

logger = logging.getLogger(__name__)


@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
    target_dictionary: Optional[str] = field(
        default=None, metadata={"help": "override default dictionary location"}
    )
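
    # All other options read below (labels, data, multi_corpus_keys,
    # multi_corpus_sampling_weights, text_compression_level, ...) are inherited
    # from AudioPretrainingConfig; this config only adds the dictionary override.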


@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
    """Task for audio classification tasks."""

    cfg: AudioClassificationConfig

    def __init__(
        self,
        cfg: AudioClassificationConfig,
    ):
        super().__init__(cfg)
        self.state.add_factory("target_dictionary", self.load_target_dictionary)
        logging.info(f"=== Number of labels = {len(self.target_dictionary)}")
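
    # The label dictionary is loaded lazily through the task state. It is read from
    # "<data or target_dictionary>/dict.<labels>.txt"; no special symbols are added
    # (add_special_symbols=False) since the entries are plain class labels.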
    def load_target_dictionary(self):
        if self.cfg.labels:
            target_dictionary = self.cfg.data
            if self.cfg.target_dictionary:  # override dict
                target_dictionary = self.cfg.target_dictionary
            dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
            logger.info("Using dict_path : {}".format(dict_path))
            return Dictionary.load(dict_path, add_special_symbols=False)
        return None
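
    # load_dataset() lets the parent class build the raw audio dataset(s) first and
    # then wraps each one in an AddTargetDataset that attaches per-utterance labels.
    # When multi_corpus_keys is set, one labelled dataset is built per corpus and the
    # corpora are combined with MultiCorpusDataset using the configured weights.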
    def load_dataset(
        self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
    ):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data

        if task_cfg.multi_corpus_keys is None:
            label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
            skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
            text_compressor = TextCompressor(level=text_compression_level)
            with open(label_path, "r") as f:
                labels = [
                    text_compressor.compress(l)
                    for i, l in enumerate(f)
                    if i not in skipped_indices
                ]

            assert len(labels) == len(self.datasets[split]), (
                f"labels length ({len(labels)}) and dataset length "
                f"({len(self.datasets[split])}) do not match"
            )

            process_label = LabelEncoder(self.target_dictionary)

            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                label_len_fn=label_len_fn,
                add_to_input=False,
                # text_compression_level=text_compression_level,
            )
        else:
            target_dataset_map = OrderedDict()

            multi_corpus_keys = [
                k.strip() for k in task_cfg.multi_corpus_keys.split(",")
            ]
            corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
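
            # For multi-corpus training the split string is expected to look like
            # "key1:file1,key2:file2", pairing each corpus key with the file prefix
            # used to locate its label file.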
            data_keys = [k.split(":") for k in split.split(",")]

            multi_corpus_sampling_weights = [
                float(val.strip())
                for val in task_cfg.multi_corpus_sampling_weights.split(",")
            ]
            data_weights = []

            for key, file_name in data_keys:
                k = key.strip()
                label_path = os.path.join(
                    data_path, f"{file_name.strip()}.{task_cfg.labels}"
                )
                skipped_indices = getattr(
                    self.dataset_map[split][k], "skipped_indices", set()
                )
                text_compressor = TextCompressor(level=text_compression_level)
                with open(label_path, "r") as f:
                    labels = [
                        text_compressor.compress(l)
                        for i, l in enumerate(f)
                        if i not in skipped_indices
                    ]

                assert len(labels) == len(self.dataset_map[split][k]), (
                    f"labels length ({len(labels)}) and dataset length "
                    f"({len(self.dataset_map[split][k])}) do not match"
                )

                process_label = LabelEncoder(self.target_dictionary)

                # TODO: Remove duplication of code from the if block above
                target_dataset_map[k] = AddTargetDataset(
                    self.dataset_map[split][k],
                    labels,
                    pad=self.target_dictionary.pad(),
                    eos=self.target_dictionary.eos(),
                    batch_targets=True,
                    process_label=process_label,
                    label_len_fn=label_len_fn,
                    add_to_input=False,
                    # text_compression_level=text_compression_level,
                )

                data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])

            if len(target_dataset_map) == 1:
                self.datasets[split] = list(target_dataset_map.values())[0]
            else:
                self.datasets[split] = MultiCorpusDataset(
                    target_dataset_map,
                    distribution=data_weights,
                    seed=0,
                    sort_indices=True,
                )

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the labels."""
        return self.state.target_dictionary
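
    # Targets are cast to torch.long before running the criterion so that standard
    # classification losses (e.g. cross-entropy) accept them as class indices.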
    def train_step(self, sample, model, *args, **kwargs):
        sample["target"] = sample["target"].to(dtype=torch.long)
        loss, sample_size, logging_output = super().train_step(
            sample, model, *args, **kwargs
        )
        self._log_metrics(sample, model, logging_output)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        sample["target"] = sample["target"].to(dtype=torch.long)
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        self._log_metrics(sample, model, logging_output)
        return loss, sample_size, logging_output

    def _log_metrics(self, sample, model, logging_output):
        metrics = self._inference_with_metrics(
            sample,
            model,
        )
        """
        logging_output["_precision"] = metrics["precision"]
        logging_output["_recall"] = metrics["recall"]
        logging_output["_f1"] = metrics["f1"]
        logging_output["_eer"] = metrics["eer"]
        logging_output["_accuracy"] = metrics["accuracy"]
        """
        logging_output["_correct"] = metrics["correct"]
        logging_output["_total"] = metrics["total"]

    def _inference_with_metrics(self, sample, model):
        def _compute_eer(target_list, lprobs):
            # from scipy.optimize import brentq
            # from scipy.interpolate import interp1d
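
            # One-hot targets and the flattened score matrix yield a micro-averaged
            # ROC curve; the EER is approximated by the FPR at the threshold where
            # FNR and FPR are closest (rather than interpolating the crossing point).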
            y_one_hot = np.eye(len(self.state.target_dictionary))[target_list]
            fpr, tpr, thresholds = sklearn_metrics.roc_curve(
                y_one_hot.ravel(), lprobs.ravel()
            )
            # Revisit the interpolation approach.
            # eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
            fnr = 1 - tpr
            eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
            return eer
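
        # Run a forward pass without gradients, take the argmax over classes as the
        # prediction, and compare it against the class index stored in the first
        # target position of each utterance.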
        with torch.no_grad():
            net_output = model(**sample["net_input"])
            lprobs = (
                model.get_normalized_probs(net_output, log_probs=True).cpu().detach()
            )
            target_list = sample["target"][:, 0].detach().cpu()
            predicted_list = torch.argmax(lprobs, 1).detach().cpu()  # B,C -> B

            metrics = {
                "correct": torch.sum(target_list == predicted_list).item(),
                "total": len(target_list),
            }
            return metrics

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
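
        # Sum raw correct/total counts across all logging outputs; accuracy is then
        # derived from the aggregated meters at display time rather than averaging
        # per-batch accuracies.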
        zero = torch.scalar_tensor(0.0)
        correct, total = 0, 0
        for log in logging_outputs:
            correct += log.get("_correct", zero)
            total += log.get("_total", zero)

        metrics.log_scalar("_correct", correct)
        metrics.log_scalar("_total", total)

        if total > 0:

            def _fn_accuracy(meters):
                if meters["_total"].sum > 0:
                    return utils.item(meters["_correct"].sum / meters["_total"].sum)
                return float("nan")

            metrics.log_derived("accuracy", _fn_accuracy)
| """ | |
| prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0 | |
| for log in logging_outputs: | |
| prec_sum += log.get("_precision", zero).item() | |
| recall_sum += log.get("_recall", zero).item() | |
| f1_sum += log.get("_f1", zero).item() | |
| acc_sum += log.get("_accuracy", zero).item() | |
| eer_sum += log.get("_eer", zero).item() | |
| metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs)) | |
| metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs)) | |
| metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs)) | |
| metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs)) | |
| metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs)) | |
| """ |