Spaces:
Running
Running
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import os | |
| from fairseq.dataclass.initialize import add_defaults, hydra_init | |
| from fairseq_cli.train import main as pre_main | |
| from fairseq import distributed_utils, metrics | |
| from fairseq.dataclass.configs import FairseqConfig | |
| from fairseq.dataclass.utils import omegaconf_no_object_check | |
| from fairseq.utils import reset_logging | |
| import hydra | |
| from hydra.core.hydra_config import HydraConfig | |
| import torch | |
| from omegaconf import OmegaConf, open_dict | |
# Module-level logger. The name is spelled out explicitly (rather than using
# __name__) so records land under "fairseq_cli.hydra_train" even when this file
# is executed directly as __main__.
logger = logging.getLogger(name="fairseq_cli.hydra_train")
# NOTE(review): config_path is resolved by Hydra relative to this file and
# assumes the upstream fairseq_cli/ layout (../fairseq/config) — confirm for
# this copy of the script.
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Hydra task function: train with the config Hydra composed.

    The ``@hydra.main`` decorator parses the command-line overrides, composes
    ``cfg`` and then calls this function. That is why ``cli_main`` invokes
    ``hydra_main()`` without arguments; without the decorator that call would
    fail with a missing-argument ``TypeError``.

    Args:
        cfg: the composed fairseq configuration (supplied by Hydra).

    Returns:
        The value returned by ``_hydra_main`` (the best validation metric, or
        ``inf`` when unavailable). Returning it lets Hydra sweepers optimize it.
    """
    return _hydra_main(cfg)
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    """Train with ``cfg`` and return the best smoothed validation metric.

    Steps:
      1. Fill ``cfg`` with defaults (``add_defaults``).
      2. Fix logging: either reset it, or, when running under Hydra, stash
         Hydra's job logging config on ``cfg`` for distributed workers.
      3. Rebuild ``cfg`` as a fully resolved OmegaConf object with enums
         turned into strings, and lock it in struct mode.
      4. Run training through ``distributed_utils.call_main``, optionally
         inside the CUDA profiler with NVTX ranges.

    Args:
        cfg: the fairseq configuration; it may be called directly or via
            ``hydra_main``.
        **kwargs: forwarded unchanged to ``distributed_utils.call_main``
            together with ``pre_main`` (``fairseq_cli.train.main``).

    Returns:
        The smoothed ``"valid"`` value of
        ``cfg.checkpoint.best_checkpoint_metric``, or ``float("inf")`` when it
        cannot be obtained (e.g. after a suppressed crash).

    Raises:
        BaseException: anything raised during training, unless
            ``cfg.common.suppress_crashes`` is set, in which case the crash is
            logged with its traceback and swallowed.
    """
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    elif HydraConfig.initialized():
        # HydraConfig is only initialized when we were called through
        # hydra_main. Keep Hydra's job logging config on cfg so logging also
        # works under DDP (see https://github.com/facebookresearch/hydra/issues/1126).
        with open_dict(cfg):
            cfg.job_logging_cfg = OmegaConf.to_container(
                HydraConfig.get().job_logging, resolve=True
            )

    # Resolve all interpolations and convert enums to strings, then rebuild a
    # fresh config. omegaconf_no_object_check is a fairseq helper (presumably
    # relaxing OmegaConf's object validation during the rebuild — verify).
    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    # Struct mode: accessing or adding unknown keys raises instead of passing silently.
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            # Wrap training in the CUDA profiler and emit an NVTX range per autograd op.
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        # Same message as before, but keep the traceback: str(e) alone is
        # often empty or ambiguous.
        logger.error("Crashed! %s", e, exc_info=True)

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Best effort: a missing metric (presumably no validation was logged)
        # must not fail the run. Unlike a bare except, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
def cli_main():
    """Command-line entry point: pick the primary config name, then train.

    The config name comes from Hydra's own command-line arguments
    (``get_args().config_name``), falling back to ``"config"``. It is passed
    to ``hydra_init`` before the Hydra-decorated ``hydra_main`` runs.
    """
    try:
        # Private Hydra API; it may move between Hydra versions.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Best effort only. Catching Exception (not everything) lets
        # SystemExit from Hydra's argparse (e.g. --version, bad arguments) and
        # KeyboardInterrupt propagate instead of being swallowed here and then
        # hit a second time when hydra_main parses the arguments again.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    # Called without arguments: the @hydra.main decorator supplies cfg.
    hydra_main()


if __name__ == "__main__":
    cli_main()