Spaces:
Runtime error
Runtime error
| """ Scheduler Factory | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| from .cosine_lr import CosineLRScheduler | |
| from .tanh_lr import TanhLRScheduler | |
| from .step_lr import StepLRScheduler | |
| from .plateau_lr import PlateauLRScheduler | |
def _noise_range(lr_noise, num_epochs):
    """Scale the ``lr_noise`` setting by ``num_epochs`` to get the noise range.

    ``lr_noise`` values are multiplied by ``num_epochs`` (presumably they are
    fractions of the training run — confirm against the argument parser).

    Returns:
        ``None`` when ``lr_noise`` is ``None`` or an empty sequence; a single
        scaled value for a scalar or a one-element sequence; otherwise a list of
        scaled values.
    """
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        if not noise_range:
            # An empty sequence defines no range; treat it as "no noise" instead
            # of forwarding an empty list to the scheduler.
            return None
        if len(noise_range) == 1:
            return noise_range[0]
        return noise_range
    return lr_noise * num_epochs


def create_scheduler(args, optimizer):
    """Build the learning-rate scheduler selected by ``args.sched``.

    Recognised values of ``args.sched`` are ``'cosine'``, ``'tanh'``, ``'step'``
    and ``'plateau'``. Any other value produces no scheduler (``None``).

    Args:
        args: Namespace-like config object. ``epochs`` and ``sched`` are always
            read. Branch-specific fields (``min_lr``, ``decay_rate``,
            ``warmup_lr``, ``warmup_epochs``, ``cooldown_epochs``,
            ``decay_epochs``, ``patience_epochs``) are read only by the branch
            that uses them. Optional fields (``lr_noise``, ``lr_noise_pct``,
            ``lr_noise_std``, ``seed``, ``lr_cycle_mul``, ``lr_cycle_limit``,
            ``eval_metric``) fall back to defaults when absent.
        optimizer: Passed straight to the scheduler constructor.

    Returns:
        Tuple ``(lr_scheduler, num_epochs)``. For ``'cosine'`` and ``'tanh'``,
        ``num_epochs`` is ``lr_scheduler.get_cycle_length()`` plus
        ``args.cooldown_epochs``; in every other case it is ``args.epochs``.
    """
    num_epochs = args.epochs

    # The noise settings are identical for every scheduler type; build them once.
    # All reads use getattr defaults, so this is safe even for unknown `sched`.
    noise_args = dict(
        noise_range_t=_noise_range(getattr(args, 'lr_noise', None), num_epochs),
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )

    lr_scheduler = None
    if args.sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_args,
        )
        # Cyclic schedule: report the scheduler's own length plus cooldown.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_args,
        )
        # Cyclic schedule: report the scheduler's own length plus cooldown.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif args.sched == 'plateau':
        # Minimise loss-like metrics; maximise anything else (e.g. accuracy).
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_args,
        )
    # NOTE(review): an unrecognised `args.sched` silently yields None here;
    # callers must handle that case.
    return lr_scheduler, num_epochs