MasaTate commited on
Commit
4d4cef7
·
1 Parent(s): 42f3691

add histogram_matching to test.py

Browse files
Files changed (1) hide show
  1. test.py +29 -0
test.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import argparse
 
3
  import torch
4
  import time
5
  import numpy as np
@@ -20,6 +21,7 @@ parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='D
20
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
21
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
22
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
 
23
  args = parser.parse_args()
24
  assert args.content_image or args.content_dir
25
  assert args.style_image or args.style_dir
@@ -54,6 +56,29 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
54
  return decoder(mix_enc)
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def main():
58
  # Read content images and style images
59
  if args.content_image:
@@ -103,6 +128,10 @@ def main():
103
 
104
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
105
 
 
 
 
 
106
  # Start time
107
  tic = time.perf_counter()
108
 
 
1
  import os
2
  import argparse
3
+ from turtle import end_fill
4
  import torch
5
  import time
6
  import numpy as np
 
21
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
22
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
23
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
24
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
25
  args = parser.parse_args()
26
  assert args.content_image or args.content_dir
27
  assert args.style_image or args.style_dir
 
56
  return decoder(mix_enc)
57
 
58
 
59
+ def linear_histogram_matching(content_tensor, style_tensor):
60
+ """
61
+ Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
62
+
63
+ Args:
64
+ content_tensor (torch.FloatTensor): Content image
65
+ style_tensor (torch.FloatTensor): Style Image
66
+
67
+ Return:
68
+ style_tensor (torch.FloatTensor): histogram matched Style Image
69
+ """
70
+ std_ct_1, mean_ct_1 = torch.var_mean(content_tensor[0][0],unbiased = False)
71
+ std_ct_2, mean_ct_2 = torch.var_mean(content_tensor[0][1],unbiased = False)
72
+ std_ct_3, mean_ct_3 = torch.var_mean(content_tensor[0][2],unbiased = False)
73
+ std_st_1, mean_st_1 = torch.var_mean(style_tensor[0][0],unbiased = False)
74
+ std_st_2, mean_st_2 = torch.var_mean(style_tensor[0][1],unbiased = False)
75
+ std_st_3, mean_st_3 = torch.var_mean(style_tensor[0][2],unbiased = False)
76
+ style_tensor[0][0] = (style_tensor[0][0] - mean_st_1) * std_ct_1 / std_st_1 + mean_ct_1
77
+ style_tensor[0][1] = (style_tensor[0][1] - mean_st_2) * std_ct_2 / std_st_2 + mean_ct_2
78
+ style_tensor[0][2] = (style_tensor[0][2] - mean_st_3) * std_ct_3 / std_st_3 + mean_ct_3
79
+ return style_tensor
80
+
81
+
82
  def main():
83
  # Read content images and style images
84
  if args.content_image:
 
128
 
129
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
130
 
131
+ # Linear Histogram Matching if needed
132
+ if args.color_control:
133
+ style_tensor = linear_histogram_matching(content_tensor,style_tensor)
134
+
135
  # Start time
136
  tic = time.perf_counter()
137