Spaces:
Build error
Build error
init
Browse files- .gitignore +3 -0
- app.py +95 -0
- checkpoint/best_Epoch_exposure.pth +3 -0
- checkpoint/best_Epoch_lol.pth +3 -0
- dark_imgs/1.jpg +0 -0
- dark_imgs/2.jpg +0 -0
- dark_imgs/3.jpg +0 -0
- exposure_imgs/1.jpg +0 -0
- exposure_imgs/2.jpg +0 -0
- exposure_imgs/3.jpeg +0 -0
- model/IAT.py +126 -0
- model/__init__.py +1 -0
- model/blocks.py +281 -0
- model/global_net.py +129 -0
- requirements.txt +5 -0
- test_dark.ipynb +0 -0
- test_exposure.ipynb +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.ipynb_checkpoints
|
| 2 |
+
__pycache__/
|
| 3 |
+
.DS_Store
|
app.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.transforms import Compose, ToTensor, Normalize, ConvertImageDtype
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
|
| 13 |
+
from model import IAT
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def set_example_image(example: list) -> dict:
    """Build the Gradio update that loads a clicked example into the image input.

    Args:
        example: The dataset row that was clicked; only its first entry is
            used, as the new image value.

    Returns:
        The update payload produced by ``gr.Image.update``.
    """
    selected = example[0]
    return gr.Image.update(value=selected)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def dark_inference(img):
    """Run the low-light enhancement model on a single image.

    The model is rebuilt and its weights are re-read from
    ``./checkpoint/best_Epoch_lol.pth`` on every call.

    Args:
        img: Image handed over by the Gradio input component, or ``None``
            when the button is pressed before an image has been provided.

    Returns:
        The enhanced image as an (H, W, C) NumPy array, or ``None`` when
        ``img`` is ``None``.
    """
    # Pressing the button without an image would otherwise crash in ToTensor.
    if img is None:
        return None

    model = IAT()
    checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    transform = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ConvertImageDtype(torch.float)
    ])

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        enhanced_img = model(transform(img).unsqueeze(0))
    # Drop the batch axis and move channels last for display.
    return enhanced_img[0].permute(1, 2, 0).detach().numpy()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def exposure_inference(img):
    """Run the exposure-correction model on a single image.

    The model is rebuilt and its weights are re-read from
    ``./checkpoint/best_Epoch_exposure.pth`` on every call.

    Args:
        img: Image handed over by the Gradio input component, or ``None``
            when the button is pressed before an image has been provided.

    Returns:
        The corrected image as an (H, W, C) NumPy array, or ``None`` when
        ``img`` is ``None``.
    """
    # Pressing the button without an image would otherwise crash in ToTensor.
    if img is None:
        return None

    model = IAT()
    checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Unlike the low-light path, no mean/std normalisation is applied here.
    transform = Compose([
        ToTensor(),
        ConvertImageDtype(torch.float)
    ])

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        enhanced_img = model(transform(img).unsqueeze(0))
    # Drop the batch axis and move channels last for display.
    return enhanced_img[0].permute(1, 2, 0).detach().numpy()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Gradio UI: one input image, two action buttons (low-light / exposure) and a
# shared result image, plus clickable example galleries for both tasks.
# NOTE(review): the container nesting below was reconstructed from a rendering
# that had lost its indentation - confirm the layout against the running app.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # IAT
        Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
        """
    )

    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                # Label typo fixed ("Resutls" -> "Results").
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            # The third exposure example ships as "3.jpeg", not "3.jpg".
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )

    gr.Markdown(
        """
        <p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
        """
    )

    # Both buttons read the same input and write to the same result image.
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    # Clicking an example copies it into the input image component.
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

demo.launch(enable_queue=True)
|
checkpoint/best_Epoch_exposure.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:15a9494582f028bef996d4af7145860eaa5d67799d2b0625ed93ff8c546ea3ee
|
| 3 |
+
size 427160
|
checkpoint/best_Epoch_lol.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9564b7e10882e688ac817ae6fd164544d05b9f74232de56c33ed7f9dabf7bdc4
|
| 3 |
+
size 427160
|
dark_imgs/1.jpg
ADDED
|
dark_imgs/2.jpg
ADDED
|
dark_imgs/3.jpg
ADDED
|
exposure_imgs/1.jpg
ADDED
|
exposure_imgs/2.jpg
ADDED
|
exposure_imgs/3.jpeg
ADDED
|
model/IAT.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
from timm.models.layers import trunc_normal_
|
| 9 |
+
from .blocks import CBlock_ln, SwinTransformerBlock
|
| 10 |
+
from .global_net import Global_pred
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Local_pred(nn.Module):
    """Local branch that predicts a multiplicative and an additive map.

    Args:
        dim: Channel width of the feature maps between the first convolution
            and the output heads.
        number: How many times the transformer block is repeated when
            ``type == 'ttt'``.
        type: Block layout, one of ``'ccc'``, ``'ttt'`` or ``'cct'``.

    Raises:
        ValueError: If ``type`` is not one of the supported layouts.
    """

    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # Fail early with a clear message instead of an unbound-local error.
        if type not in ('ccc', 'ttt', 'cct'):
            raise ValueError(
                "type must be one of 'ccc', 'ttt' or 'cct', got {!r}".format(type))
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        # Both are built eagerly, exactly as before, so the random numbers
        # consumed during construction stay the same for every layout.
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # Use `dim` (previously hard-coded to 16) so a non-default width
            # matches the output of conv1.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE: the same module instance is repeated, so every stage of
            # both branches shares one set of weights.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        else:  # 'cct'
            # NOTE: `block` and `block_t` are likewise shared between stages
            # and between the two branches.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        """Return the ``(mul, add)`` maps predicted from ``img``."""
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Short Cut Connection on Final Layer
|
| 42 |
+
class Local_pred_S(nn.Module):
    """Local branch with a shortcut connection around the main blocks.

    Args:
        in_dim: Number of channels of the input image.
        dim: Channel width of the feature maps between the first convolution
            and the output heads.
        number: How many times the transformer block is repeated when
            ``type == 'ttt'``.
        type: Block layout, one of ``'ccc'``, ``'ttt'`` or ``'cct'``.

    Raises:
        ValueError: If ``type`` is not one of the supported layouts.
    """

    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # Fail early with a clear message instead of an unbound-local error.
        if type not in ('ccc', 'ttt', 'cct'):
            raise ValueError(
                "type must be one of 'ccc', 'ttt' or 'cct', got {!r}".format(type))
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        # Both are built eagerly, exactly as before, so the random numbers
        # consumed during construction stay the same for every layout.
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # Use `dim` (previously hard-coded to 16) so a non-default width
            # matches the output of conv1.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE: the same module instance is repeated, so every stage of
            # both branches shares one set of weights.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        else:  # 'cct'
            # NOTE: `block` and `block_t` are likewise shared between stages
            # and between the two branches.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d sub-modules in place."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the kernel.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        """Return the ``(mul, add)`` maps predicted from ``img``."""
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class IAT(nn.Module):
    """Image enhancement network combining a local and an optional global branch.

    The local branch yields per-pixel multiply/add maps that are applied to
    the input. When ``with_global`` is true, a global branch additionally
    predicts one gamma value and one 3x3 colour matrix per image, which are
    applied to the locally adjusted result.

    Args:
        in_dim: Number of channels of the input image.
        with_global: Whether to build and apply the global branch.
        type: Forwarded to ``Global_pred`` as its ``type`` argument.
    """

    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Apply colour matrix ``ccm`` to the last axis of ``image`` and clamp.

        Args:
            image: Tensor whose last axis has size 3.
            ccm: 3x3 matrix contracted with that axis along its own last axis.

        Returns:
            Tensor of the same shape as ``image``, clamped to ``[1e-8, 1.0]``.
        """
        shape = image.shape
        # reshape (not view) so inputs whose strides cannot be flattened in
        # place are still accepted.
        image = image.reshape(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.reshape(shape)
        # The lower bound keeps values strictly positive before the caller
        # raises them to the gamma power.
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        """Enhance a batch of images laid out as (B, C, H, W)."""
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)

        if not self.with_global:
            return img_high

        gamma, color = self.global_net(img_low)
        b = img_high.shape[0]
        img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -- (B,H,W,C)
        # Each sample gets its own colour matrix and gamma exponent.
        img_high = torch.stack(
            [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :] for i in range(b)],
            dim=0)
        img_high = img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -- (B,C,H,W)
        return img_high
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
    # Smoke test: build the network, report its size and run one forward pass.
    # torch.rand gives a defined input; torch.Tensor(...) left the memory
    # uninitialised, so it could contain NaN or inf.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    high = net(img)
|
model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .IAT import IAT
|
model/blocks.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code copy from uniformer source code:
|
| 3 |
+
https://github.com/Sense-X/UniFormer
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from functools import partial
|
| 9 |
+
import math
|
| 10 |
+
from timm.models.vision_transformer import VisionTransformer, _cfg
|
| 11 |
+
from timm.models.registry import register_model
|
| 12 |
+
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
|
| 13 |
+
|
| 14 |
+
# ResMLP's normalization
|
| 15 |
+
class Aff(nn.Module):
    """Learnable per-channel affine transform: ``x * alpha + beta``.

    Args:
        dim: Size of the last axis of the input; ``alpha`` and ``beta`` have
            shape ``(1, 1, dim)`` and start as ones and zeros.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = [1, 1, dim]
        # learnable scale and shift
        self.alpha = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Scale ``x`` by ``alpha`` and shift it by ``beta``."""
        scaled = x * self.alpha
        return scaled + self.beta
|
| 25 |
+
|
| 26 |
+
# Color Normalization
|
| 27 |
+
class Aff_channel(nn.Module):
    """Affine transform combined with a learnable channel-mixing matrix.

    Args:
        dim: Size of the last axis of the input.
        channel_first: If true, mix channels with ``color`` first and apply
            the affine transform afterwards; otherwise do it the other way
            round.
    """

    def __init__(self, dim, channel_first = True):
        super().__init__()
        # learnable scale / shift, initialised to the identity transform
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        # learnable channel-mixing matrix, initialised to the identity
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        """Apply channel mixing and the affine transform in the configured order."""
        if not self.channel_first:
            affine = x * self.alpha + self.beta
            return torch.tensordot(affine, self.color, dims=[[-1], [-1]])
        mixed = torch.tensordot(x, self.color, dims=[[-1], [-1]])
        return mixed * self.alpha + self.beta
|
| 44 |
+
|
| 45 |
+
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        in_features: Size of the last axis of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output; falls back to ``in_features`` when
            falsy.
        act_layer: Activation class, instantiated without arguments.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both layers; the same dropout module is used twice."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 63 |
+
|
| 64 |
+
class CMlp(nn.Module):
    """Convolutional MLP built from two 1x1 convolutions.

    Adapted from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        in_features: Number of input channels.
        hidden_features: Number of hidden channels; falls back to
            ``in_features`` when falsy.
        out_features: Number of output channels; falls back to
            ``in_features`` when falsy.
        act_layer: Activation class, instantiated without arguments.
        drop: Dropout probability used after both convolutions.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both 1x1 convolutions; dropout follows each."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 82 |
+
|
| 83 |
+
class CBlock_ln(nn.Module):
    """Convolutional block with token-wise normalisation and two residual branches.

    The first branch is 1x1 conv -> 5x5 depthwise conv -> 1x1 conv, the second
    a convolutional MLP. Each branch is scaled by a learnable per-channel
    factor (``gamma_1`` / ``gamma_2``) and passed through drop-path.

    Args:
        dim: Number of channels (unchanged by the block).
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Accepted but not used by this block.
        drop: Dropout probability inside the MLP.
        drop_path: Drop-path rate; ``0`` disables it.
        act_layer: Activation class for the MLP.
        norm_layer: Normalisation class, instantiated as ``norm_layer(dim)``.
        init_values: Initial value of ``gamma_1`` and ``gamma_2``.
    """

    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        # depthwise 3x3 convolution added to the input as a positional term
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        scale_shape = (1, dim, 1, 1)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(scale_shape), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(scale_shape), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def _normalize(self, norm, x):
        """Apply ``norm`` over the channel axis of a (B, C, H, W) tensor."""
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)
        tokens = norm(tokens)
        return tokens.view(B, H, W, C).permute(0, 3, 1, 2)

    def forward(self, x):
        """Transform a (B, C, H, W) tensor, keeping its shape."""
        x = x + self.pos_embed(x)

        normed = self._normalize(self.norm1, x)
        conv_branch = self.conv2(self.attn(self.conv1(normed)))
        x = x + self.drop_path(self.gamma_1 * conv_branch)

        normed = self._normalize(self.norm2, x)
        x = x + self.drop_path(self.gamma_2 * self.mlp(normed))
        return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C)
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def window_reverse(windows, window_size, H, W):
    """Reassemble square windows into a full feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Number of windows per image, then the batch size implied by the input.
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows = H // window_size
    cols = W // window_size
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C)
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module.

    This implementation applies plain scaled dot-product attention inside
    each window; no relative position bias or attention mask is used.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window (stored,
            not used in the forward pass).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Attend within each window; ``x`` is (num_windows*B, N, C)."""
        B_, N, C = x.shape
        heads = self.num_heads
        # (B_, N, 3C) -> (3, B_, heads, N, C // heads)
        packed = self.qkv(x).reshape(B_, N, 3, heads, C // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        query = packed[0] * self.scale
        key = packed[1]
        value = packed[2]

        weights = self.softmax(query @ key.transpose(-2, -1))
        weights = self.attn_drop(weights)

        merged = (weights @ value).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(merged))
|
| 195 |
+
|
| 196 |
+
## Layer_norm, Aff_norm, Aff_channel_norm
|
| 197 |
+
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block operating on (B, C, H, W) feature maps.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size. H and W of the input must be
            divisible by it, because windows are cut with ``view``.
        shift_size (int): Shift size for SW-MSA. No attention mask is applied
            to shifted windows in this implementation.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: Aff_channel
    """

    def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Aff_channel):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        # depthwise 3x3 convolution added to the input as a positional term
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply windowed attention and the MLP; the output keeps the input shape."""
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift: without it the attention output would be
        # added to the shortcut at spatially displaced positions
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.reshape(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, H, W)

        return x
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if __name__ == "__main__":
    # Smoke test: run one convolutional block and print the output shape.
    os.environ['CUDA_VISIBLE_DEVICES']='1'
    cb_block = CBlock_ln(dim = 16)
    # torch.rand gives a defined input; torch.Tensor(...) left the memory
    # uninitialised, so it could contain NaN or inf.
    x = torch.rand(1, 16, 400, 600)
    # Built only to check that construction succeeds; it is not run.
    swin = SwinTransformerBlock(dim=16, num_heads=4)
    x = cb_block(x)
    print(x.shape)
|
model/global_net.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imp
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
|
| 5 |
+
import os
|
| 6 |
+
from .blocks import Mlp
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class query_Attention(nn.Module):
    """Cross-attention from 10 learned query vectors to the input tokens.

    Args:
        dim: Channel size of the input tokens and of the learned queries.
        num_heads: Number of attention heads.
        qkv_bias: Whether the key and value projections have a bias.
        qk_scale: Overrides the default scale ``head_dim ** -0.5`` when truthy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
    """

    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or per_head ** -0.5

        # learned queries, shared by every sample in the batch
        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Map (B, N, C) tokens to (B, 10, C) query outputs."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        keys = self.k(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)
        values = self.v(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)
        queries = self.q.expand(B, -1, -1).view(B, -1, heads, head_dim).permute(0, 2, 1, 3)

        weights = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        out = (weights @ values).transpose(1, 2).reshape(B, 10, C)
        return self.proj_drop(self.proj(out))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class query_SABlock(nn.Module):
    """Block that condenses a feature map into the outputs of learned queries.

    Args:
        dim: Number of channels of the input feature map.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Forwarded to ``query_Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Drop-path rate; ``0`` disables it.
        act_layer: Activation class for the MLP.
        norm_layer: Normalisation class, instantiated as ``norm_layer(dim)``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # depthwise 3x3 convolution added to the input as a positional term
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Turn a (B, C, H, W) map into the token sequence produced by the queries."""
        x = x + self.pos_embed(x)
        tokens = x.flatten(2).transpose(1, 2)
        # No residual here: the attention output replaces the token sequence.
        queried = self.drop_path(self.attn(self.norm1(tokens)))
        return queried + self.drop_path(self.mlp(self.norm2(queried)))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class conv_embedding(nn.Module):
    """Two stride-2 3x3 convolutions that embed an image at 1/4 resolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; the intermediate stage uses
            ``out_channels // 2``.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        half = out_channels // 2
        stages = [
            nn.Conv2d(in_channels, half, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(half),
            nn.GELU(),
            nn.Conv2d(half, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x):
        """Return the embedded feature map."""
        return self.proj(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Global_pred(nn.Module):
    """Global branch predicting one gamma value and one 3x3 colour matrix per image.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Channel width of the embedding and of the query block.
        num_heads: Number of attention heads in the query block.
        type: With ``'exp'`` the gamma base is frozen; any other value makes
            it trainable.
    """

    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        # False in exposure correction
        gamma_trainable = type != 'exp'
        self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=gamma_trainable)
        self.color_base = nn.Parameter(torch.eye((3)), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)

        # The value projection of the query attention starts at zero.
        for param_name, param in self.named_parameters():
            if param_name != 'generator.attn.v.weight':
                continue
            nn.init.constant_(param, 0)

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm sub-modules in place."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(gamma, color)`` for a batch of images."""
        tokens = self.generator(self.conv_large(x))
        # First query output drives gamma, the remaining ones the colour matrix.
        gamma_token = tokens[:, 0].unsqueeze(1)
        color_tokens = tokens[:, 1:]
        gamma = self.gamma_linear(gamma_token).squeeze(-1) + self.gamma_base
        color = self.color_linear(color_tokens).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
    # Smoke test: run the global branch on a batch and print the output shapes.
    os.environ['CUDA_VISIBLE_DEVICES']='3'
    # torch.rand gives a defined input; torch.Tensor(...) left the memory
    # uninitialised, so it could contain NaN or inf.
    img = torch.rand(8, 3, 400, 600)
    global_net = Global_pred()
    gamma, color = global_net(img)
    print(gamma.shape, color.shape)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
timm
|
| 4 |
+
Pillow
|
| 5 |
+
opencv-python
|
test_dark.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_exposure.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|