Spaces:
Running
Running
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import os | |
| import sys | |
| import json | |
| from itertools import chain | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from fairseq import distributed_utils, options, tasks, utils | |
| from fairseq.dataclass.utils import convert_namespace_to_omegaconf | |
| from fairseq.logging import progress_bar | |
| from fairseq.utils import reset_logging | |
| from omegaconf import DictConfig | |
| from utils import checkpoint_utils | |
| from utils.eval_utils import eval_step | |
# Route log records to stdout using a timestamped, pipe-delimited layout.
# The threshold is read from the LOGLEVEL environment variable and falls
# back to INFO when that variable is not set.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

# Logger shared by every function in this script.
logger = logging.getLogger("ofa.evaluate")
def apply_half(t):
    """Return ``t`` converted to half precision when it is a float32 tensor.

    Tensors of any other dtype are handed back unchanged (the very same
    object, not a copy).
    """
    if t.dtype is not torch.float32:
        return t
    return t.to(dtype=torch.half)
def main(cfg: DictConfig, **kwargs):
    """Run evaluation over ``cfg.dataset.gen_subset`` and dump predictions.

    The function loads the checkpoint ensemble named by
    ``cfg.common_eval.path``, iterates over the (possibly sharded) dataset,
    calls ``eval_step`` on every batch, and finally writes the collected
    results to ``<results_path>/<gen_subset>_predict.json``. When more than
    one distributed worker is configured, per-worker results are gathered
    and the score accumulators are all-reduced; only rank 0 writes the file.

    Args:
        cfg: Evaluation configuration (fairseq-style ``DictConfig``).
        **kwargs: Optional flags. ``ema_eval`` (bool, default ``False``)
            makes each model load its EMA weights from the checkpoint
            before evaluation.

    Raises:
        AssertionError: If neither ``cfg.dataset.max_tokens`` nor
            ``cfg.dataset.batch_size`` is set.
    """
    utils.import_user_module(cfg.common)

    reset_logging()
    logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    # Missing flag means "do not use EMA weights" instead of a KeyError.
    ema_eval = kwargs.get("ema_eval", False)

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    # NOTE(review): eval() executes arbitrary code contained in the
    # --model-overrides string. If overrides are always plain literals,
    # ast.literal_eval would be the safer choice -- confirm before switching.
    overrides = eval(cfg.common_eval.model_overrides)
    checkpoint_paths = utils.split_paths(cfg.common_eval.path)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths,
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Loading the dataset should happen after the checkpoint has been loaded
    # so we can give it the saved task config.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Prepare every model for inference (optionally EMA weights, fp16, GPU).
    for model, ckpt_path in zip(models, checkpoint_paths):
        if ema_eval:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(
                checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model']
            )
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    # Score accumulators are kept as tensors so they can be all-reduced
    # below. They are moved to the GPU only when CUDA is actually in use;
    # an unconditional .cuda() would crash CPU-only evaluation.
    score_sum = torch.FloatTensor([0])
    score_cnt = torch.FloatTensor([0])
    if use_cuda:
        score_sum = score_sum.cuda()
        score_cnt = score_cnt.cuda()

    for sample in progress:
        # Batches without "net_input" carry nothing to evaluate.
        if "net_input" not in sample:
            continue
        if use_cuda:
            sample = utils.move_to_cuda(sample)
        if use_fp16:
            sample = utils.apply_to_sample(apply_half, sample)
        with torch.no_grad():
            result, scores = eval_step(task, generator, models, sample)
        results += result
        if scores is not None:
            score_sum += sum(scores)
            score_cnt += len(scores)
        progress.log({"sentences": sample["nsentences"]})

    world_size = cfg.distributed_training.distributed_world_size
    gather_results = None
    if world_size > 1:
        # Collect every worker's predictions and sum the score statistics.
        gather_results = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(gather_results, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)
    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

    # Only one process (the sole worker, or rank 0) writes the output file.
    if world_size == 1 or dist.get_rank() == 0:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "{}_predict.json".format(cfg.dataset.gen_subset),
        )
        gather_results = (
            list(chain(*gather_results)) if gather_results is not None else results
        )
        with open(output_path, 'w') as fw:
            json.dump(gather_results, fw)
def cli_main():
    """Command-line entry point.

    Builds the generation argument parser, registers the extra
    ``--ema-eval`` switch, converts the parsed namespace into an omegaconf
    config and hands it to ``distributed_utils.call_main`` together with
    the ``ema_eval`` flag.
    """
    gen_parser = options.get_generation_parser()
    gen_parser.add_argument(
        "--ema-eval",
        action='store_true',
        help="Use EMA weights to make evaluation.",
    )
    parsed = options.parse_args_and_arch(gen_parser)
    eval_cfg = convert_namespace_to_omegaconf(parsed)
    distributed_utils.call_main(eval_cfg, main, ema_eval=parsed.ema_eval)


if __name__ == "__main__":
    cli_main()