Spaces:
Build error
Build error
| import torch | |
| import numpy as np | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import os | |
| import math | |
| from timm.models.layers import trunc_normal_ | |
| from .blocks import CBlock_ln, SwinTransformerBlock | |
| from .global_net import Global_pred | |
class Local_pred(nn.Module):
    """Predict a multiplicative and an additive 3-channel map from an image.

    A 3x3 convolution followed by LeakyReLU lifts the 3-channel input to
    ``dim`` feature channels. Two parallel branches then process those
    features; each ends in a 3x3 convolution down to 3 channels, followed by
    ReLU for the ``mul`` branch and Tanh for the ``add`` branch.

    Args:
        dim: Number of feature channels used inside both branches.
        number: Number of SwinTransformerBlock modules per branch when
            ``type='ttt'``. Ignored by the other layouts.
        type: Branch layout. ``'ccc'`` is three CBlock_ln modules, ``'ttt'``
            is ``number`` SwinTransformerBlock modules, ``'cct'`` is two
            CBlock_ln modules followed by one SwinTransformerBlock.

    Raises:
        ValueError: If ``type`` is not ``'ccc'``, ``'ttt'`` or ``'cct'``.
    """

    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks: each branch gets its own, freshly built modules
        blocks1 = self._make_blocks(dim, number, type)
        blocks2 = self._make_blocks(dim, number, type)
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    @staticmethod
    def _make_blocks(dim, number, type):
        """Return a new list of feature blocks for one branch.

        Every call creates new module instances. Reusing one instance at
        several positions of an ``nn.Sequential`` (or in both branches) would
        tie their weights together, which the previous implementation did for
        the ``'ttt'`` and ``'cct'`` layouts.
        """
        if type == 'ccc':
            # drop_path values of the three blocks; the block count is fixed
            # at three regardless of ``number``.
            drop_paths = (0.01, 0.05, 0.1)
            return [CBlock_ln(dim, drop_path=rate) for rate in drop_paths]
        if type == 'ttt':
            return [SwinTransformerBlock(dim) for _ in range(number)]
        if type == 'cct':
            return [CBlock_ln(dim), CBlock_ln(dim), SwinTransformerBlock(dim)]
        raise ValueError(
            "unknown type {!r}; expected 'ccc', 'ttt' or 'cct'".format(type)
        )

    def forward(self, img):
        """Return the ``(mul, add)`` maps computed from ``img``."""
        # Both branches read the same activated features.
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add
# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    """Variant of the local predictor with a shortcut around each branch.

    A 3x3 convolution followed by LeakyReLU lifts the ``in_dim``-channel
    input to ``dim`` feature channels. Each of the two branches processes
    those features, adds the branch input back (shortcut connection) and then
    maps the sum to 3 channels with a 3x3 convolution, followed by ReLU for
    the ``mul`` output and Tanh for the ``add`` output.

    Args:
        in_dim: Number of channels of the input image.
        dim: Number of feature channels used inside both branches.
        number: Number of SwinTransformerBlock modules per branch when
            ``type='ttt'``. Ignored by the other layouts.
        type: Branch layout. ``'ccc'`` is three CBlock_ln modules, ``'ttt'``
            is ``number`` SwinTransformerBlock modules, ``'cct'`` is two
            CBlock_ln modules followed by one SwinTransformerBlock.

    Raises:
        ValueError: If ``type`` is not ``'ccc'``, ``'ttt'`` or ``'cct'``.
    """

    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks: each branch gets its own, freshly built modules
        blocks1 = self._make_blocks(dim, number, type)
        blocks2 = self._make_blocks(dim, number, type)
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)
        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        # Re-initialise every registered submodule once all are in place.
        self.apply(self._init_weights)

    @staticmethod
    def _make_blocks(dim, number, type):
        """Return a new list of feature blocks for one branch.

        Every call creates new module instances. Reusing one instance at
        several positions of an ``nn.Sequential`` (or in both branches) would
        tie their weights together, which the previous implementation did for
        the ``'ttt'`` and ``'cct'`` layouts.
        """
        if type == 'ccc':
            # drop_path values of the three blocks; the block count is fixed
            # at three regardless of ``number``.
            drop_paths = (0.01, 0.05, 0.1)
            return [CBlock_ln(dim, drop_path=rate) for rate in drop_paths]
        if type == 'ttt':
            return [SwinTransformerBlock(dim) for _ in range(number)]
        if type == 'cct':
            return [CBlock_ln(dim), CBlock_ln(dim), SwinTransformerBlock(dim)]
        raise ValueError(
            "unknown type {!r}; expected 'ccc', 'ttt' or 'cct'".format(type)
        )

    def _init_weights(self, m):
        """Initialise one submodule; intended as the callback for ``apply``.

        * ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
        * ``nn.LayerNorm``: weight 1 and bias 0, when those parameters exist.
        * ``nn.Conv2d``: normal weights with std ``sqrt(2 / fan_out)``, where
          ``fan_out = kernel_h * kernel_w * out_channels // groups``; zero
          bias.

        Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, img):
        """Return the ``(mul, add)`` maps computed from ``img``."""
        img1 = self.relu(self.conv1(img))
        # short cut connection: add the branch input back to its output
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
class IAT(nn.Module):
    """Image enhancement model combining a local and an optional global stage.

    The local stage (``Local_pred_S``) yields per-pixel ``mul`` and ``add``
    maps that are applied to the input as ``img_low * mul + add``. When
    ``with_global`` is true, a ``Global_pred`` network additionally yields a
    ``gamma`` and a ``color`` tensor per sample, which are applied to the
    locally adjusted image.

    Args:
        in_dim: Channel count passed to both sub-networks.
        with_global: Whether to build and apply the global stage.
        type: Forwarded unchanged to ``Global_pred``.
    """

    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super().__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Mix the colour channels of ``image`` with the matrix ``ccm``.

        ``image`` is flattened to rows of 3 values, each row is contracted
        with the last axis of ``ccm`` (i.e. ``row @ ccm.T``), and the result
        is restored to the input shape and clamped to ``[1e-8, 1.0]``.
        """
        original_shape = image.shape
        pixels = image.view(-1, 3)
        mixed = torch.tensordot(pixels, ccm, dims=[[-1], [-1]])
        return torch.clamp(mixed.view(original_shape), 1e-8, 1.0)

    def forward(self, img_low):
        """Enhance ``img_low`` and return a tensor in (B,C,H,W) layout."""
        mul, add = self.local_net(img_low)
        img_high = img_low * mul + add
        if not self.with_global:
            return img_high

        # NOTE(review): the indexing below implies gamma is at least 2-D and
        # color at least 3-D, both batched on axis 0 - confirm in Global_pred.
        gamma, color = self.global_net(img_low)
        batch_size = img_high.shape[0]
        channels_last = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -- (B,H,W,C)
        corrected = []
        for idx in range(batch_size):
            sample = self.apply_color(channels_last[idx, :, :, :], color[idx, :, :])
            # Clamping in apply_color keeps the base of the power positive.
            corrected.append(sample ** gamma[idx, :])
        stacked = torch.stack(corrected, dim=0)
        return stacked.permute(0, 3, 1, 2)  # (B,H,W,C) -- (B,C,H,W)
if __name__ == "__main__":
    # Smoke test: build the default model, report its size and run one
    # forward pass on a dummy image.
    # torch.rand gives defined values in [0, 1); torch.Tensor(1, 3, 400, 600)
    # would only allocate uninitialized memory that may hold NaN/inf.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    # No gradients are needed for a shape/sanity check.
    with torch.no_grad():
        high = net(img)