import os
import sys
import torch
import random
import numpy as np
from tap import Tap
from typing import Optional, Union
from collections import OrderedDict

from unitok import dist


class Args(Tap):
    model: str = 'vitamin_large'  # 'vitamin_base', 'vitamin_large', xxx
    exp_name: str = 'unitok_large'
    output_dir: str = 'local_output'
    resume_from: str = ''  # if specified, load this checkpoint; if not, load the latest checkpoint in output_dir (if exists)
    lpips_path: str = 'external/lpips_with_vgg.pth'
    dino_path: str = 'external/dinov2_vits14_pretrain.pth'
    fid_eval_src: str = ''
    fid_eval_dst: str = ''
    vis_img_dir: str = 'asset/vis_imgs/'
    fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
    clip_pretrain_path: str = ''

    # speed-up
    fp16: bool = False  # whether to use FP16
    bf16: bool = True  # whether to use BF16
    tf32: bool = True  # whether to use TensorFloat32
    compile_model: bool = False  # whether to use torch.compile()
    ddp_static: bool = False  # whether to use static graph in DDP
    grad_ckpt: bool = True  # gradient checkpointing
    grad_accu: int = 1  # gradient accumulation
    device: str = 'cpu'  # will be set automatically
    dtype: torch.dtype = torch.float32  # will be set automatically

    # data
    train_data: str = None
    val_data: str = None
    dataset_type: str = 'webdataset'
    imagenet_val: str = None
    imagenet_v2: str = None
    subset_ratio: float = 1.0
    img_size: int = 256
    resize_ratio: float = 1.125  # only applicable to 'img' dataset_type
    hflip: bool = False
    workers: int = 8  # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    train_num_samples: int = 1_280_000_000
    train_data_upsampling_factors: str = None
    dataset_resampled: bool = False
    use_aug: bool = False
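
    # Note (added illustration, an assumption rather than part of the original config):
    # for the 'img' dataset_type, resize_ratio presumably scales the pre-crop resize,
    # e.g. img_size=256 with resize_ratio=1.125 would mean resizing to 288 before a
    # 256x256 crop (256 * 1.125 = 288).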

    # quantizer
    vocab_size: int = 32768
    vocab_width: int = 64
    vocab_norm: bool = True
    vq_beta: float = 0.25  # commitment loss weight
    num_codebooks: int = 8
    quant_proj: str = 'attn'

    # model
    embed_dim: int = 768
    num_query: int = 0
    use_clip_pretrain: bool = False
    patch_size: int = 16
    drop_path: float = 0.1
    text_width: int = 768
    text_heads: int = 12
    text_layers: int = 12
    text_vocab_size: int = 49408
    text_context_length: int = 77

    # CLIP
    local_loss: bool = True
    gather_with_grad: bool = True
    pretrained_clip: str = None
    pretrained_clip_text: str = None
    lock_text: bool = False
    lock_text_unlocked_layers: int = 0
    lock_text_freeze_layer_norm: bool = False
    force_custom_text: bool = False
    force_custom_vision: bool = False
    zeroshot_eval_freq: int = 1

    # discriminator
    dino_depth: int = 12
    dino_kernel_size: int = 9
    disc_norm: str = 'gn'  # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
    disc_aug_prob: float = 1.0
    disc_specnorm: bool = False
    step_disc_every: int = 1

    # initialization
    vae_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
    vocab_init: float = -1  # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
    disc_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
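
    # Sketch (illustrative only; the helper below is hypothetical): the sign convention
    # in the vae_init/disc_init comments above maps to torch initializers roughly as
    #   def apply_signed_init(w, init):
    #       if init < 0:
    #           torch.nn.init.xavier_normal_(w, gain=abs(init))
    #       else:
    #           torch.nn.init.trunc_normal_(w, std=init)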

    # optimization
    epoch: int = 1  # number of epochs
    local_bs: int = 64  # batch size per device; if this is specified, --global_bs will be ignored
    vae_local_bs: int = 64  # sub-batch size for vae loss calculation
    global_bs: int = 0  # global batch size (exclusive to --local_bs)
    lr: float = 5e-4  # learning rate
    wd: float = 0.02  # weight decay
    disc_lr: float = 2e-5  # disc lr
    disc_wd: float = 0.2
    grad_clip: float = 10  # <=0 for not using grad clip
    ema: float = 0.9999  # ema ratio
    warmup_iter: int = None
    warmup_ep: float = 0.01  # lr warmup: epochs
    disc_start_ep: float = 0.375  # start using disc loss for VAE after this many epochs
    disc_warmup_ep: float = 0.03  # disc loss warmup epochs
    schedule: str = 'cos'  # lr schedule type
    lr_start_ratio: float = 0.  # lr warmup: initial lr ratio
    lr_end_ratio: float = 0.1  # lr schedule: final lr ratio
    disc_lr_end_ratio: float = 0.1
    custom_lr_multiplier: float = None
    optimizer: str = 'adamw'
    optim_eps: float = 1e-6
    fuse_opt: bool = False  # whether to use fused optimizer
    optim_beta: str = '0.9_0.95'  # beta1, beta2 of optimizer
    disc_optim_beta: str = '0.5_0.9'  # beta1, beta2 of disc optimizer
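
    # Illustration (an assumption, not from the original file): the underscore-separated
    # beta strings above would typically be parsed into float tuples before being passed
    # to the optimizer, e.g.
    #   betas = tuple(map(float, args.optim_beta.split('_')))  # '0.9_0.95' -> (0.9, 0.95)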

    # loss
    l1: float = 0.2  # L1 rec loss weight
    l2: float = 1.0  # L2 rec loss weight
    lp: float = 1.0  # lpips loss weight
    lpr: int = 48  # only calculate lpips at image resolution >= this
    ld: float = 0.4  # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT
    le: float = 0.0  # VQ entropy loss weight
    lq: float = 1.0
    lc: float = 1.0  # CLIP loss weight
    e_temp: float = 0.01
    gada: int = 1
    bcr: float = 4.  # balanced Consistency Regularization, used on small datasets with low resolution; StyleSwin uses 10.0
    bcr_cut: float = 0.2  # cutout ratio (0.5: 50% width)
    dcrit: str = 'hg'  # hg: hinge, sp: softplus, ln: linear

    # wandb log
    report_wandb: bool = True
    wandb_notes: str = None
    run_id: str = None

    # debug
    eval_per_epoch: int = 8
    dbg_unused_param: bool = False
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ
    seed: int = None
    deterministic: bool = False
    same_seed_for_all_ranks: int = 0  # this is only for distributed sampler

    def seed_everything(self):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        if self.seed is not None:
            if self.deterministic:
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
                torch.use_deterministic_algorithms(True)
                os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
            seed = self.seed + dist.get_rank() * 10000
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g
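
    # Usage sketch (hypothetical, not part of the original file): the per-rank generator
    # would typically be handed to a DataLoader so random augmentations differ across
    # ranks while remaining reproducible, e.g.
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=args.local_bs, shuffle=True,
    #                                        generator=args.get_different_generator_for_each_rank())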

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            try:
                setattr(self, k, v)
            except Exception as e:
                print(f'k={k}, v={v}')
                raise e
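
    # Round-trip sketch (hypothetical usage, not from the original file): state_dict()
    # can be stored alongside model weights and restored on resume, e.g.
    #   ckpt = {'args': args.state_dict(), 'model': model.state_dict()}
    #   args.load_state_dict(ckpt['args'])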

    @staticmethod  # no `self` is used, and init_dist_and_get_args() calls this as args.set_tf32(args.tf32)
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)

    # warn args.extra_args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    # init torch distributed
    os.makedirs(args.output_dir, exist_ok=True)
    dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)

    # set env
    args.set_tf32(args.tf32)
    args.seed_everything()
    args.device = dist.get_device()

    # update args
    if args.local_bs == 0:
        args.local_bs = max(1, round(args.global_bs / args.grad_accu / dist.get_world_size()))
    args.global_bs = args.local_bs * dist.get_world_size()
    if args.fp16 or args.bf16:
        args.dtype = torch.float16 if args.fp16 else torch.bfloat16

    return args
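

# Example entry point (added for illustration; not part of the original file). It assumes
# the environment provides the usual torch.distributed launch variables (RANK, WORLD_SIZE,
# MASTER_ADDR, ...) expected by dist.init_distributed_mode.
if __name__ == '__main__':
    args = init_dist_and_get_args()
    print(f'global_bs={args.global_bs}, local_bs={args.local_bs}, dtype={args.dtype}')
    print(args)  # __str__ dumps every serializable field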