Spaces:
Running
on
Zero
Running
on
Zero
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import ( | |
| ResNet50_Weights, | |
| VGG16_BN_Weights, | |
| VGG16_Weights, | |
| resnet50, | |
| vgg16, | |
| vgg16_bn, | |
| ) | |
| from engine.BiRefNet.config import Config | |
| from engine.BiRefNet.models.backbones.pvt_v2 import ( | |
| pvt_v2_b0, | |
| pvt_v2_b1, | |
| pvt_v2_b2, | |
| pvt_v2_b5, | |
| ) | |
| from engine.BiRefNet.models.backbones.swin_v1 import ( | |
| swin_v1_b, | |
| swin_v1_l, | |
| swin_v1_s, | |
| swin_v1_t, | |
| ) | |
# Module-level project configuration. load_weights reads
# config.weights[model_name] from it to locate each backbone's checkpoint file.
config = Config()
def _named_stages(conv1, conv2, conv3, conv4):
    """Return an ``nn.Sequential`` whose children are four encoder stages named conv1..conv4."""
    return nn.Sequential(
        OrderedDict(
            [("conv1", conv1), ("conv2", conv2), ("conv3", conv3), ("conv4", conv4)]
        )
    )


def build_backbone(bb_name, pretrained=True, params_settings=""):
    """Build the encoder backbone named ``bb_name``.

    Args:
        bb_name: ``"vgg16"``, ``"vgg16bn"`` or ``"resnet50"`` for the torchvision
            backbones. Any other value is treated as the name of a constructor
            visible in this module (e.g. the imported ``pvt_v2_*`` /
            ``swin_v1_*`` builders) and is called via ``eval``.
        pretrained: for torchvision backbones, request their ``DEFAULT``
            weights; for the other backbones, load a checkpoint with
            ``load_weights(bb, bb_name)``.
        params_settings: source text spliced verbatim into the constructor call
            for non-torchvision backbones (e.g. ``"a=1"`` yields ``name(a=1)``).

    Returns:
        For torchvision backbones, an ``nn.Sequential`` with children
        ``conv1``..``conv4``. Otherwise the constructed backbone, or ``None``
        when ``pretrained`` is True and ``load_weights`` found no usable weights.
    """
    # Weights enums are passed via ``weights=``; the ``pretrained=`` keyword is
    # deprecated since torchvision 0.13 and only accepted through a legacy shim.
    if bb_name == "vgg16":
        # ``.features`` is the convolutional trunk (the first child of VGG).
        features = vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).features
        return _named_stages(
            features[:4], features[4:9], features[9:16], features[16:23]
        )
    if bb_name == "vgg16bn":
        features = vgg16_bn(
            weights=VGG16_BN_Weights.DEFAULT if pretrained else None
        ).features
        return _named_stages(
            features[:6], features[6:13], features[13:23], features[23:33]
        )
    if bb_name == "resnet50":
        net = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        # NOTE(review): the stem is conv1/bn1/relu only. torchvision's ResNet
        # max-pool (applied before layer1) is omitted, exactly as the former
        # child slicing ([0:3], then [4]) did; layer4 is also unused.
        # Confirm dropping the max-pool is intended.
        return _named_stages(
            nn.Sequential(net.conv1, net.bn1, net.relu),
            net.layer1,
            net.layer2,
            net.layer3,
        )
    # SECURITY: eval executes bb_name / params_settings as Python code. Only
    # pass trusted values (e.g. from the project config), never user input.
    bb = eval(f"{bb_name}({params_settings})")
    if pretrained:
        bb = load_weights(bb, bb_name)
    return bb
def _matching_weights(saved, model_dict):
    """Keep entries of ``saved`` whose key is in ``model_dict``; on size mismatch use the model's own tensor."""
    # Size-mismatched tensors fall back to the model's current values so a
    # modified backbone keeps its own initialization for those parameters.
    return {
        k: v if v.size() == model_dict[k].size() else model_dict[k]
        for k, v in saved.items()
        if k in model_dict
    }


def load_weights(model, model_name):
    """Load the checkpoint ``config.weights[model_name]`` into ``model``.

    The file is loaded on CPU with ``weights_only=True``. Only keys present in
    ``model.state_dict()`` are applied; tensors whose size differs from the
    model's are replaced by the model's current tensor.

    If no top-level key matches and the checkpoint has exactly one top-level
    key (e.g. ``{"model": {...}}``), the mapping stored under that key is
    tried instead.

    Args:
        model: the module to receive the weights.
        model_name: key into ``config.weights`` giving the checkpoint path.

    Returns:
        ``model`` with matched weights loaded, or ``None`` (after printing a
        message) when no usable weights were found.
    """
    save_model = torch.load(
        config.weights[model_name], map_location="cpu", weights_only=True
    )
    model_dict = model.state_dict()
    state_dict = _matching_weights(save_model, model_dict)
    if not state_dict:
        save_model_keys = list(save_model.keys())
        sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
        # Check for a single wrapper key BEFORE indexing: with zero or several
        # top-level keys there is nothing to unwrap (previously this indexed
        # save_model[None] and raised KeyError).
        nested = save_model[sub_item] if sub_item is not None else None
        if isinstance(nested, dict):
            state_dict = _matching_weights(nested, model_dict)
        if not state_dict:
            print(
                "Weights are not successully loaded. Check the state dict of weights file."
            )
            return None
        print(
            'Found correct weights in the "{}" item of loaded state_dict.'.format(
                sub_item
            )
        )
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model