MasaTate commited on
Commit
42f3691
·
1 Parent(s): 4f6c34a

Add histogram_matching.py

Browse files
__pycache__/AdaIN.cpython-38.pyc ADDED
Binary file (1.81 kB). View file
 
__pycache__/Network.cpython-38.pyc ADDED
Binary file (2.01 kB). View file
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.26 kB). View file
 
histogram_matching.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision.utils import save_image
4
+ from utils import transform
5
+ content_path = './images/content/brad_pitt.jpg'
6
+ style_path = './images/art/flower_of_life.jpg'
7
+ output_path = './results/output_1.jpg'
8
+
9
+ #Prepare image transform
10
+ t = transform(512)
11
+
12
+ content_image = Image.open(content_path)
13
+ style_image = Image.open(style_path)
14
+ content_tensor = t(content_image).unsqueeze(0)
15
+ style_tensor = t(style_image).unsqueeze(0)
16
+ #content_tensor = torch.tensor([[[[2.0,2.0,2.0],[4.0,4.0,4.0],[3.0,3.0,3.0]],[[5.0,5.0,5.0],[3.0,3.0,3.0],[1.0,1.0,1.0]],[[3.0,3.0,3.0],[1.0,1.0,1.0],[2.0,2.0,2.0]]]])
17
+ print(content_tensor)
18
+ print(content_tensor.shape)
19
+ std_ct_1, mean_ct_1 = torch.var_mean(content_tensor[0][0],unbiased = False)
20
+ std_ct_2, mean_ct_2 = torch.var_mean(content_tensor[0][1],unbiased = False)
21
+ std_ct_3, mean_ct_3 = torch.var_mean(content_tensor[0][2],unbiased = False)
22
+ std_st_1, mean_st_1 = torch.var_mean(style_tensor[0][0],unbiased = False)
23
+ std_st_2, mean_st_2 = torch.var_mean(style_tensor[0][1],unbiased = False)
24
+ std_st_3, mean_st_3 = torch.var_mean(style_tensor[0][2],unbiased = False)
25
+ style_tensor[0][0] = (style_tensor[0][0] - mean_st_1) * std_ct_1 / std_st_1 + mean_ct_1
26
+ style_tensor[0][1] = (style_tensor[0][1] - mean_st_2) * std_ct_2 / std_st_2 + mean_ct_2
27
+ style_tensor[0][2] = (style_tensor[0][2] - mean_st_3) * std_ct_3 / std_st_3 + mean_ct_3
28
+ #print(content_tensor)
29
+ output_tensor = style_tensor
30
+ save_image(output_tensor,output_path)