import argparse
import logging
import os
import random

import torch
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
from fastai.distributed import *
from fastai.vision import *
from torch.backends import cudnn

from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
from dataset import ImageDataset, TextDataset
from losses import MultiLosses
from utils import Config, Logger, MyDataParallel, MyConcatDataset


def _set_random_seed(seed):
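    """Seed Python's and PyTorch's RNGs and make cuDNN deterministic for reproducibility."""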
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will slow down your training!')


def _get_training_phases(config, n):
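    """Build a step-decay LR schedule with one TrainingPhase per period: phase i lasts
    n * periods[i] iterations at lr * gamma**i. For example, periods=[6, 4] with
    gamma=0.1 runs 6n iterations at lr, then 4n iterations at 0.1 * lr."""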
    lr = np.array(config.optimizer_lr)
    periods = config.optimizer_scheduler_periods
    sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
    phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
              for i in range(len(periods))]
    return phases


def _get_dataset(ds_type, paths, is_training, config, **kwargs):
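    """Instantiate one dataset per root path with shared config-derived kwargs,
    concatenating them via MyConcatDataset when several roots are given."""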
    kwargs.update({
        'img_h': config.dataset_image_height,
        'img_w': config.dataset_image_width,
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'data_aug': config.dataset_data_aug,
        'deteriorate_ratio': config.dataset_deteriorate_ratio,
        'is_training': is_training,
        'multiscales': config.dataset_multiscales,
        'one_hot_y': config.dataset_one_hot_y,
    })
    datasets = [ds_type(p, **kwargs) for p in paths]
    if len(datasets) > 1: return MyConcatDataset(datasets)
    else: return datasets[0]


def _get_language_databaunch(config):
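    """Build the text-only DataBunch used for language-model pretraining."""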
    kwargs = {
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'smooth_label': config.dataset_smooth_label,
        'smooth_factor': config.dataset_smooth_factor,
        'one_hot_y': config.dataset_one_hot_y,
        'use_sm': config.dataset_use_sm,
    }
    train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
    valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
    data = DataBunch.create(
        path=train_ds.path,
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data


def _get_databaunch(config):
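    """Build the image DataBunch, normalized with ImageNet statistics; labels are also
    fed back as inputs to support auto-regressive decoding."""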
    # An awkward way to reduce data-loading time during test
    if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
    data = ImageDataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
    ar_tfm = lambda x: ((x[0], x[1]), x[1])  # auto-regression only for dtd
    data.add_tfm(ar_tfm)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data


def _get_model(config):
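    """Resolve config.model_name, a dotted 'package.module.Class' path, via importlib
    and instantiate the class with the config."""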
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model


def _get_learner(config, data, model, local_rank=None):
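    """Assemble the fastai Learner: stage-specific metrics, optimizer and multi-task
    loss, training or prediction callbacks, optional distributed/DataParallel wrapping,
    and checkpoint restoration."""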
    strict = ifnone(config.model_strict, True)
    if config.global_stage == 'pretrain-language':
        metrics = [TopKTextAccuracy(
            k=ifnone(config.model_k, 5),
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    else:
        metrics = [TextAccuracy(
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
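    # The optimizer class is looked up by name on torch.optim (e.g. config.optimizer_type = 'Adam').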
    opt_type = getattr(torch.optim, config.optimizer_type)
    learner = Learner(data, model, silent=True, model_dir='.',
                      true_wd=config.optimizer_true_wd,
                      wd=config.optimizer_wd,
                      bn_wd=config.optimizer_bn_wd,
                      path=config.global_workdir,
                      metrics=metrics,
                      opt_func=partial(opt_type, **config.optimizer_args or dict()),
                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
    learner.split(lambda m: children(m))
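    # Training gets scheduling, clipping, and iteration-based eval/save/logging callbacks;
    # testing gets a prediction-dumping callback instead.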
    if config.global_phase == 'train':
        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
        phases = _get_training_phases(config, len(learner.data.train_dl) // num_replicas)
        learner.callback_fns += [
            partial(GeneralScheduler, phases=phases),
            partial(GradientClipping, clip=config.optimizer_clip_grad),
            partial(IterationCallback, name=config.global_name,
                    show_iters=config.training_show_iters,
                    eval_iters=config.training_eval_iters,
                    save_iters=config.training_save_iters,
                    start_iters=config.training_start_iters,
                    stats_iters=config.training_stats_iters)]
    else:
        learner.callbacks += [
            DumpPrediction(learn=learner,
                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
                           charset_path=config.dataset_charset_path,
                           model_eval=config.model_eval,
                           debug=config.global_debug,
                           image_only=config.global_image_only)]
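    # Multi-GPU setup: DistributedDataParallel (with SyncBatchNorm) when launched with a
    # local_rank, otherwise DataParallel across all visible GPUs.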
    learner.rank = local_rank
    if local_rank is not None:
        logging.info(f'Set model to distributed with rank {local_rank}.')
        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
        learner.model.to(local_rank)
        learner = learner.to_distributed(local_rank)
    if torch.cuda.device_count() > 1 and local_rank is None:
        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
        learner.model = MyDataParallel(learner.model)
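    # Restore weights: an explicit checkpoint takes priority; in test phase, fall back to
    # the best snapshot saved during training.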
    if config.model_checkpoint:
        if Path(config.model_checkpoint).exists():
            with open(config.model_checkpoint, 'rb') as f:
                buffer = io.BytesIO(f.read())
            learner.load(buffer, strict=strict)
        else:
            from distutils.dir_util import copy_tree
            src = Path('/data/fangsc/model')/config.global_name
            trg = Path('/output')/config.global_name
            if src.exists(): copy_tree(str(src), str(trg))
            learner.load(config.model_checkpoint, strict=strict)
        logging.info(f'Read model from {config.model_checkpoint}')
    elif config.global_phase == 'test':
        learner.load(f'best-{config.global_name}', strict=strict)
        logging.info(f'Read model from best-{config.global_name}')
    if learner.opt_func.func.__name__ == 'Adadelta':  # fastai bug, fix after 1.0.60
        learner.fit(epochs=0, lr=config.optimizer_lr)
        learner.opt.mom = 0.
    return learner


def main():
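    """Parse command-line overrides, merge them into the config, set up logging and
    seeding, optionally initialize distributed training, then train or evaluate."""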
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--test_root', type=str, default=None)
    parser.add_argument('--local_rank', type=int, default=None)
    parser.add_argument('--debug', action='store_true', default=None)
    parser.add_argument('--image_only', action='store_true', default=None)
    parser.add_argument('--model_strict', action='store_false', default=None)
    parser.add_argument('--model_eval', type=str, default=None,
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()
    config = Config(args.config)
    if args.name is not None: config.global_name = args.name
    if args.phase is not None: config.global_phase = args.phase
    if args.test_root is not None: config.dataset_test_roots = [args.test_root]
    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
    if args.debug is not None: config.global_debug = args.debug
    if args.image_only is not None: config.global_image_only = args.image_only
    if args.model_eval is not None: config.model_eval = args.model_eval
    if args.model_strict is not None: config.model_strict = args.model_strict
    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    _set_random_seed(config.global_seed)
    logging.info(config)
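    # torch.distributed.launch spawns one process per GPU and passes its index via --local_rank.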
    if args.local_rank is not None:
        logging.info(f'Init distributed training at device {args.local_rank}.')
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    logging.info('Construct dataset.')
    if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
    else: data = _get_databaunch(config)

    logging.info('Construct model.')
    model = _get_model(config)

    logging.info('Construct learner.')
    learner = _get_learner(config, data, model, args.local_rank)
    if config.global_phase == 'train':
        logging.info('Start training.')
        learner.fit(epochs=config.training_epochs,
                    lr=config.optimizer_lr)
    else:
        logging.info('Start validation.')
        last_metrics = learner.validate()
        log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
                  f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
                  f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
                  f'ted/w = {last_metrics[5]:6.3f}, '
        logging.info(log_str)


if __name__ == '__main__':
    main()
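
# Example invocations (script and config names are illustrative, not fixed by this file):
#   python main.py --config configs/train.yaml --phase train
#   python main.py --config configs/train.yaml --phase test --checkpoint workdir/best-model.pth
#   python -m torch.distributed.launch --nproc_per_node=4 main.py --config configs/train.yaml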