Spaces:
Runtime error
Runtime error
| import datetime | |
| import functools | |
| import os | |
| import sys | |
| from typing import List | |
| from typing import Union | |
| import pytz | |
| import torch | |
| import torch.distributed as tdist | |
| import torch.multiprocessing as mp | |
# Process-wide distributed state. The defaults describe a single,
# non-distributed process; __initialize() overwrites them once
# torch.distributed has been set up.
__rank = 0
__local_rank = 0
__world_size = 1
__device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Global rank as a string, zero-padded to the number of digits of the world size.
__rank_str_zfill = '0'
# Becomes True only after __initialize() has created the process group.
__initialized = False
def initialized():
    """Report whether this module has set up the torch.distributed process group."""
    state = __initialized
    return state
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    """Initialise torch.distributed (when launched with RANK set) and the module state.

    Three cases are handled:
      * CUDA unavailable: keep the CPU defaults and return.
      * ``RANK`` not in the environment: single-process run; pin the process to
        ``gpu_id_if_not_distibuted`` and return without creating a process group.
      * Otherwise: pick this process's GPU, set the multiprocessing start method
        if none is set yet, create the default process group and record
        rank / local rank / world size / device.

    Args:
        fork: use the 'fork' start method instead of 'spawn' (only applied when
            no start method has been chosen yet).
        backend: backend passed to ``torch.distributed.init_process_group``.
        gpu_id_if_not_distibuted: GPU index used when ``RANK`` is not set
            (the misspelled name is kept for backward compatibility).
        timeout_minutes: process-group timeout in minutes.

    Raises:
        RuntimeError: if torch.distributed reports it is not initialised after
            ``init_process_group`` (errors from ``init_process_group`` itself propagate).
    """
    global __rank, __local_rank, __world_size, __device, __initialized, __rank_str_zfill
    if not torch.cuda.is_available():
        print('[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    if 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return

    # From here on 'RANK' exists, i.e. a distributed launcher started us.
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    # Prefer the launcher-provided LOCAL_RANK (set e.g. by torchrun): it is right
    # even when nodes have different GPU counts, where `global_rank % num_gpus`
    # is not. Fall back to the global rank when it is absent. The modulo keeps the
    # index valid when each process only sees a subset of the GPUs
    # (e.g. CUDA_VISIBLE_DEVICES restricted per process).
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank)) % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(minutes=timeout_minutes))
    # Validate before publishing any state (an assert would vanish under -O).
    if not tdist.is_initialized():
        raise RuntimeError('torch.distributed is not initialized!')

    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    __device = torch.empty(1).cuda().device
    __initialized = True
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
def get_rank():
    """Return this process's global rank (0 until the process group is set up)."""
    current = __rank
    return current


def get_rank_str_zfill():
    """Return the global rank as a string zero-padded to the width of the world size."""
    padded = __rank_str_zfill
    return padded


def get_local_rank():
    """Return this process's rank within its machine (0 when not distributed)."""
    current = __local_rank
    return current


def get_world_size():
    """Return the total number of processes (1 when not distributed)."""
    size = __world_size
    return size


def get_device():
    """Return the device this process uses ('cuda'/'cpu' string or a torch.device)."""
    device = __device
    return device
def set_gpu_id(gpu_id: int):
    """Pin this process to CUDA device ``gpu_id`` and refresh the cached device.

    ``None`` is a no-op. ``int`` values and ``str`` values convertible with
    ``int()`` are accepted; any other type raises NotImplementedError.
    """
    global __device
    if gpu_id is None:
        return
    if not isinstance(gpu_id, (str, int)):
        raise NotImplementedError
    torch.cuda.set_device(int(gpu_id))
    # Allocate a throwaway tensor to read back the device that is now current.
    __device = torch.empty(1).cuda().device
def is_master():
    """Return True on the process with global rank 0 (True when not distributed)."""
    return not __rank


def is_local_master():
    """Return True on the process whose local (per-machine) rank is 0."""
    return not __local_rank
def new_group(ranks: List[int]):
    """Create a torch.distributed group over ``ranks``; None when not initialised."""
    if not __initialized:
        return None
    return tdist.new_group(ranks=ranks)


def new_local_machine_group():
    """Split the world with ``tdist.new_subgroups()`` and return this rank's subgroup.

    With default arguments ``new_subgroups`` sizes groups by the number of CUDA
    devices per machine, so this is presumably the group of processes on the
    current machine (assumes one process per GPU -- confirm for your launcher).
    Returns None when not initialised.
    """
    if not __initialized:
        return None
    own_group, _all_groups = tdist.new_subgroups()
    return own_group
def barrier():
    """Block until every process reaches this point; no-op when not initialised."""
    if not __initialized:
        return
    tdist.barrier()
def allreduce(t: torch.Tensor, async_op=False):
    """All-reduce ``t`` in place across all processes (torch default op: SUM).

    CPU tensors are copied to the current CUDA device, reduced there and copied
    back into ``t``. For CPU tensors the reduction is always waited on before
    the copy-back, even when ``async_op`` is True. Copying earlier would read a
    buffer the asynchronous collective may still be writing.

    Args:
        t: tensor to reduce in place.
        async_op: forwarded to ``torch.distributed.all_reduce``.

    Returns:
        The value returned by ``torch.distributed.all_reduce`` (a work handle
        when ``async_op`` is True), or None when the process group is not
        initialised, in which case ``t`` is left untouched.
    """
    if not __initialized:
        return None
    if t.is_cuda:
        return tdist.all_reduce(t, async_op=async_op)
    cu = t.detach().cuda()
    ret = tdist.all_reduce(cu, async_op=async_op)
    if async_op:
        # The collective must finish before its result is copied to the host tensor.
        ret.wait()
    t.copy_(cu.cpu())
    return ret
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather an equally-shaped tensor from every rank.

    CPU inputs are moved to CUDA before gathering. Without an initialised
    process group the gathered list contains only ``t`` itself.

    Returns:
        The per-rank tensors concatenated along dim 0 when ``cat`` is True,
        otherwise the list of per-rank tensors.
    """
    if not __initialized:
        gathered = [t]
    else:
        src = t if t.is_cuda else t.cuda()
        gathered = [torch.empty_like(src) for _ in range(__world_size)]
        tdist.all_gather(gathered, src)
    return torch.cat(gathered, dim=0) if cat else gathered
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather tensors whose dim-0 length may differ between ranks.

    The ranks first exchange their shapes. Each rank then pads its tensor along
    dim 0 to the largest length; the padding is uninitialised memory and is
    sliced off again. The padded tensors are gathered and each is trimmed back
    to its original length. The receive buffers take the local padded shape,
    so all dimensions other than dim 0 must match across ranks.

    Returns:
        The concatenation along dim 0 when ``cat`` is True, otherwise the list
        of per-rank tensors. Without an initialised process group the result
        contains only ``t``.
    """
    if not __initialized:
        pieces = [t]
    else:
        local = t if t.is_cuda else t.cuda()
        my_shape = torch.tensor(local.size(), device=local.device)
        all_shapes = [torch.empty_like(my_shape) for _ in range(__world_size)]
        tdist.all_gather(all_shapes, my_shape)
        lengths = [shape[0].item() for shape in all_shapes]
        shortfall = max(lengths) - my_shape[0].item()
        if shortfall:
            filler = local.new_empty((shortfall, *local.size()[1:]))
            local = torch.cat((local, filler), dim=0)
        padded = [torch.empty_like(local) for _ in range(__world_size)]
        tdist.all_gather(padded, local)
        pieces = [chunk[:length] for chunk, length in zip(padded, lengths)]
    if cat:
        return torch.cat(pieces, dim=0)
    return pieces
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Overwrite ``t`` in place with the tensor held by rank ``src_rank``.

    CPU tensors round-trip through a CUDA copy. This is a no-op when the
    process group is not initialised.
    """
    if not __initialized:
        return
    if t.is_cuda:
        tdist.broadcast(t, src=src_rank)
        return
    staged = t.detach().cuda()
    tdist.broadcast(staged, src=src_rank)
    t.copy_(staged.cpu())
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect one scalar per rank and optionally format each as a string.

    Each rank writes ``val`` into its own slot of a zero vector whose length is
    the world size. The vector is then passed through ``allreduce``, so with a
    sum reduction every rank ends up holding all values.

    Returns:
        The per-rank tensor when ``fmt`` is None, otherwise a list of
        ``fmt % value`` strings in rank order. Without an initialised process
        group only this process's ``val`` is returned, as a 1-element tensor
        or a 1-element list.
    """
    if not initialized():
        if fmt is None:
            return torch.tensor([val])
        return [fmt % val]
    per_rank = torch.zeros(__world_size)
    per_rank[__rank] = val
    allreduce(per_rank)
    if fmt is None:
        return per_rank
    return [fmt % v for v in per_rank.cpu().numpy().tolist()]
def master_only(func):
    """Decorator: run ``func`` only on the global master (rank 0), then barrier.

    Non-master ranks receive None. Passing ``force=True`` to the decorated
    function runs it on every rank; the keyword is consumed and not forwarded.
    Every rank calls ``barrier()`` afterwards, so every rank should call the
    decorated function. ``functools.wraps`` preserves the wrapped function's
    name, docstring and ``__wrapped__``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        ret = func(*args, **kwargs) if (force or is_master()) else None
        # Keep ranks in step: the others wait here until the master has finished.
        barrier()
        return ret
    return wrapper
def local_master_only(func):
    """Decorator: run ``func`` only on each machine's local master, then barrier.

    Processes whose local rank is not 0 receive None. Passing ``force=True``
    runs the function everywhere; the keyword is consumed and not forwarded.
    Every rank calls ``barrier()`` afterwards, so every rank should call the
    decorated function. ``functools.wraps`` preserves the wrapped function's
    metadata.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        ret = func(*args, **kwargs) if (force or is_local_master()) else None
        # Keep ranks in step: the others wait here until the local masters finish.
        barrier()
        return ret
    return wrapper
def for_visualize(func):
    """Decorator: run ``func`` only on the global master; other ranks get None.

    There is no barrier and no ``force`` keyword, so ranks do not wait for each
    other. ``functools.wraps`` preserves the wrapped function's metadata.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_master():
            return None
        return func(*args, **kwargs)
    return wrapper
def finalize():
    """Destroy the default process group created by this module, if any.

    Afterwards the module is marked as not initialised. The helpers that check
    that flag (barrier, allreduce, ...) then skip the destroyed group, and a
    second call to ``finalize`` is harmless instead of raising.
    """
    global __initialized
    if __initialized:
        tdist.destroy_process_group()
        __initialized = False
def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
    """Set up distributed training, the output directory, print filtering and log backup.

    Args:
        local_out_path: output directory, created if missing. ``None`` or an
            empty path disables both the directory creation and the
            stdout/stderr backup files.
        only_sync_master: if True only the global master backs up its streams to
            files; otherwise every local master (local rank 0) does.
        timeout_minutes: process-group timeout forwarded to ``__initialize``.

    Raises:
        RuntimeError: if process-group initialisation or the first barrier fails.
    """
    try:
        __initialize(fork=False, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise  # bare raise keeps the original traceback untouched

    # Truthiness treats None and '' alike; os.makedirs('') would raise, and
    # len() would fail on pathlib.Path inputs.
    has_out_dir = bool(local_out_path)
    if has_out_dir:
        os.makedirs(local_out_path, exist_ok=True)
    # From now on only the local master prints unless print(..., force=True) is used.
    _change_builtin_print(is_local_master())
    should_backup = is_master() if only_sync_master else is_local_master()
    if has_out_dir and should_backup:
        sys.stdout, sys.stderr = (
            BackupStreamToFile(local_out_path, for_stdout=True),
            BackupStreamToFile(local_out_path, for_stdout=False),
        )
def _change_builtin_print(is_master):
    """Replace ``builtins.print`` with a rank-aware, timestamped version.

    The replacement accepts three extra keyword arguments:
      * ``force``  -- print even when ``is_master`` is False;
      * ``clean``  -- print the arguments as-is, without the header;
      * ``deeper`` -- report the caller's caller in the header.
    Otherwise each line is prefixed with an Asia/Shanghai timestamp and the
    calling file name (padded/truncated to its last 24 characters) and line.

    Nothing is changed when ``builtins.print`` is no longer a builtin function
    (for example because this function already ran).
    """
    import builtins
    real_print = builtins.print
    if type(real_print) is not type(open):
        return

    def rank_aware_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if not (is_master or force):
            return
        if clean:
            real_print(*args, **kwargs)
            return
        # Frame of whoever called print (one level above this function).
        caller = sys._getframe().f_back
        if deeper and caller.f_back is not None:
            caller = caller.f_back
        location = f'{caller.f_code.co_filename:24s}'[-24:]
        stamp = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
        real_print(f'{stamp} ({location}, line{caller.f_lineno:-4d})=>', *args, **kwargs)

    builtins.print = rank_aware_print
class BackupStreamToFile(object):
    """Tee a standard stream: writes go to the terminal stream and to a log file.

    The log file is ``backup1_stdout.txt`` (stdout) or ``backup2_stderr.txt``
    (stderr) inside ``local_output_dir``, opened in append mode as UTF-8. If the
    file already exists, a RESTART banner with an Asia/Shanghai timestamp is
    appended first. ``close()`` (also called from ``__del__``) closes the file
    and puts the wrapped stream back into ``sys.stdout`` / ``sys.stderr``.
    Attributes not defined here (``isatty``, ``fileno``, ``encoding``, ...)
    are looked up on the wrapped terminal stream.
    """

    def __init__(self, local_output_dir, for_stdout=True):
        # Set first so close()/__del__ stay safe even if open() below raises.
        self.enabled = False
        self.for_stdout = for_stdout
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
        existing = os.path.exists(fname)
        # Explicit encoding: log text may contain non-ASCII characters.
        self.file_stream = open(fname, 'a', encoding='utf-8')
        if existing:
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
            self.file_stream.flush()
        self.enabled = True

    def __getattr__(self, name):
        # Only reached when normal lookup fails. terminal_stream is read from
        # __dict__ directly to avoid infinite recursion before __init__ sets it.
        terminal = self.__dict__.get('terminal_stream')
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

    def write(self, message):
        self.terminal_stream.write(message)
        # After close() the file is closed; keep echoing to the terminal only.
        if self.enabled:
            self.file_stream.write(message)

    def flush(self):
        self.terminal_stream.flush()
        if self.enabled:
            self.file_stream.flush()

    def close(self):
        # getattr default covers instances whose __init__ never ran.
        if not getattr(self, 'enabled', False):
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()

    def __del__(self):
        self.close()