import os
import json
import math
import logging
from pathlib import Path
from typing import List
from urllib.parse import urlparse

import torch
from torch import nn
from transformers import BertTokenizer
from timm.models.hub import download_cached_file

# VisionTransformer is used by create_vit below and must be imported alongside
# interpolate_pos_embed
from .vit import VisionTransformer, interpolate_pos_embed
from .swin_transformer import interpolate_relative_pos_embed

# module-level logger; `logger` was referenced in tie_encoder_decoder_weights
# but never defined
logger = logging.getLogger(__name__)

CONFIG_PATH = Path(__file__).resolve().parents[1]
def read_json(rpath):
    with open(rpath, 'r') as f:
        return json.load(f)
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
                                base_model_prefix: str, skip_key: str):
    """Recursively tie the decoder's weights to the encoder's, skipping any
    module whose name contains `skip_key`."""
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + ' is tied')
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            all_encoder_weights = set([
                module_name + "/" + sub_name
                for sub_name in encoder_modules.keys()
            ])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name])) and len(
                                encoder_modules) != len(decoder_modules):
                        # This can happen when the name is a position in a module
                        # list of layers: the decoder has added a cross-attention
                        # layer that the encoder does not have, so skip this step
                        # and subtract one layer position from the encoder.
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                    )
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
                                       uninitialized_encoder_weights, skip_key)
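
# Usage sketch (illustrative, not part of the original file): share a text
# encoder's weights with a matching decoder while skipping cross-attention
# modules. `text_encoder` and `text_decoder` are hypothetical BERT-style modules.
#   tie_encoder_decoder_weights(text_encoder, text_decoder,
#                               base_model_prefix='', skip_key='crossattention')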
class GroupWiseLinear(nn.Module):
    # could be changed to:
    # output = torch.einsum('ijk,zjk->ij', x, self.W)
    # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias
        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(2))
        for i in range(self.num_class):
            self.W[0][i].data.uniform_(-stdv, stdv)
        if self.bias:
            for i in range(self.num_class):
                self.b[0][i].data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: B,K,d
        x = (self.W * x).sum(-1)
        if self.bias:
            x = x + self.b
        return x
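
# Usage sketch (illustrative): GroupWiseLinear applies an independent linear
# readout per class; `feats` below is a hypothetical (B, K, d) feature tensor.
#   head = GroupWiseLinear(num_class=80, hidden_dim=768)
#   feats = torch.randn(4, 80, 768)
#   logits = head(feats)  # shape (4, 80)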
def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer
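
# Usage sketch (downloads bert-base-uncased from the Hugging Face Hub on first use):
#   tokenizer = init_tokenizer()
#   batch = tokenizer(['a photo of a cat'], padding=True, return_tensors='pt')
#   print(tokenizer.enc_token_id)  # id of the added [ENC] token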
def create_vit(vit,
               image_size,
               use_grad_checkpointing=False,
               ckpt_layer=0,
               drop_path_rate=0):
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=12,
            num_heads=12,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=drop_path_rate)  # was `0 or drop_path_rate`, which is just the argument
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=24,
            num_heads=16,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            # was `0.1 or drop_path_rate`, which always evaluated to 0.1;
            # assuming the intent was a 0.1 default for the large model
            drop_path_rate=drop_path_rate or 0.1)
    return visual_encoder, vision_width
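
# Usage sketch (illustrative; assumes `.vit` exports the VisionTransformer used above):
#   visual_encoder, vision_width = create_vit('base', image_size=224)
#   # vision_width == 768 for 'base', 1024 for 'large'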
def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")
def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
    # drop any checkpoint tensors whose shape no longer matches the model
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
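
# Usage sketch (illustrative): `model` is a hypothetical BLIP-style module with a
# `visual_encoder` attribute; the checkpoint may be an http(s) URL or a local path.
#   model, msg = load_checkpoint(model, 'path/to/checkpoint.pth')
#   print(msg.missing_keys)  # keys left uninitialized by strict=False loading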
def load_checkpoint_swinlarge_condition(model, url_or_filename, kwargs):
    if kwargs['image_size'] == 224:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
    elif kwargs['image_size'] == 384:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
    else:
        raise ValueError("kwargs['image_size'] must be 224 or 384 for Swin-L")
    window_size = read_json(vision_config_path)['window_size']
    print('--------------')
    print(url_or_filename)
    print('--------------')
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    # note: this variant reads the weights from the 'params' key, not 'model'
    state_dict = checkpoint['params']

    for k in list(state_dict.keys()):
        if 'relative_position_bias_table' in k:
            # re-interpolate the relative position bias to the target window size
            dst_num_pos = (2 * window_size - 1)**2
            state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
                                                           dst_num_pos,
                                                           param_name=k)
        elif ('relative_position_index' in k) or ('attn_mask' in k):
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi",
                                 "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
def load_checkpoint_swinbase(model, url_or_filename, kwargs):
    if kwargs['image_size'] == 224:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
    elif kwargs['image_size'] == 384:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
    else:
        raise ValueError("kwargs['image_size'] must be 224 or 384 for Swin-B")
    window_size = read_json(vision_config_path)['window_size']
    print('--------------')
    print(url_or_filename)
    print('--------------')
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    for k in list(state_dict.keys()):
        if 'relative_position_bias_table' in k:
            # re-interpolate the relative position bias to the target window size
            dst_num_pos = (2 * window_size - 1)**2
            state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
                                                           dst_num_pos,
                                                           param_name=k)
        elif ('relative_position_index' in k) or ('attn_mask' in k):
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi",
                                 "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
def load_checkpoint_swinlarge(model, url_or_filename, kwargs):
    if kwargs['image_size'] == 224:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
    elif kwargs['image_size'] == 384:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
    else:
        raise ValueError("kwargs['image_size'] must be 224 or 384 for Swin-L")
    window_size = read_json(vision_config_path)['window_size']
    print('--------------')
    print(url_or_filename)
    print('--------------')
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    for k in list(state_dict.keys()):
        if 'relative_position_bias_table' in k:
            # re-interpolate the relative position bias to the target window size
            dst_num_pos = (2 * window_size - 1)**2
            state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
                                                           dst_num_pos,
                                                           param_name=k)
        elif ('relative_position_index' in k) or ('attn_mask' in k):
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi",
                                 "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
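
# Usage sketch for the Swin loaders (illustrative): `kwargs` must carry the
# image size (224 or 384) used to pick the matching window-size config.
#   model, msg = load_checkpoint_swinbase(model, 'path/to/swin_base.pth',
#                                         {'image_size': 384})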
# Tagging loss function
# copied from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()
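
# Usage sketch (illustrative): multi-label tagging loss over raw logits.
#   criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
#   logits = torch.randn(4, 80)                     # (B, num_class)
#   targets = torch.randint(0, 2, (4, 80)).float()  # multi-hot labels
#   loss = criterion(logits, targets)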