add histogram_matching to test.py
Browse files
test.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import argparse
|
|
|
|
| 3 |
import torch
|
| 4 |
import time
|
| 5 |
import numpy as np
|
|
@@ -20,6 +21,7 @@ parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='D
|
|
| 20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
| 22 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
|
|
|
| 23 |
args = parser.parse_args()
|
| 24 |
assert args.content_image or args.content_dir
|
| 25 |
assert args.style_image or args.style_dir
|
|
@@ -54,6 +56,29 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
|
|
| 54 |
return decoder(mix_enc)
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def main():
|
| 58 |
# Read content images and style images
|
| 59 |
if args.content_image:
|
|
@@ -103,6 +128,10 @@ def main():
|
|
| 103 |
|
| 104 |
style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# Start time
|
| 107 |
tic = time.perf_counter()
|
| 108 |
|
|
|
|
| 1 |
import os
|
| 2 |
import argparse
|
| 3 |
+
from turtle import end_fill
|
| 4 |
import torch
|
| 5 |
import time
|
| 6 |
import numpy as np
|
|
|
|
| 21 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 22 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
| 23 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
| 24 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
| 25 |
args = parser.parse_args()
|
| 26 |
assert args.content_image or args.content_dir
|
| 27 |
assert args.style_image or args.style_dir
|
|
|
|
| 56 |
return decoder(mix_enc)
|
| 57 |
|
| 58 |
|
| 59 |
+
def linear_histogram_matching(content_tensor, style_tensor):
|
| 60 |
+
"""
|
| 61 |
+
Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
content_tensor (torch.FloatTensor): Content image
|
| 65 |
+
style_tensor (torch.FloatTensor): Style Image
|
| 66 |
+
|
| 67 |
+
Return:
|
| 68 |
+
style_tensor (torch.FloatTensor): histogram matched Style Image
|
| 69 |
+
"""
|
| 70 |
+
std_ct_1, mean_ct_1 = torch.var_mean(content_tensor[0][0],unbiased = False)
|
| 71 |
+
std_ct_2, mean_ct_2 = torch.var_mean(content_tensor[0][1],unbiased = False)
|
| 72 |
+
std_ct_3, mean_ct_3 = torch.var_mean(content_tensor[0][2],unbiased = False)
|
| 73 |
+
std_st_1, mean_st_1 = torch.var_mean(style_tensor[0][0],unbiased = False)
|
| 74 |
+
std_st_2, mean_st_2 = torch.var_mean(style_tensor[0][1],unbiased = False)
|
| 75 |
+
std_st_3, mean_st_3 = torch.var_mean(style_tensor[0][2],unbiased = False)
|
| 76 |
+
style_tensor[0][0] = (style_tensor[0][0] - mean_st_1) * std_ct_1 / std_st_1 + mean_ct_1
|
| 77 |
+
style_tensor[0][1] = (style_tensor[0][1] - mean_st_2) * std_ct_2 / std_st_2 + mean_ct_2
|
| 78 |
+
style_tensor[0][2] = (style_tensor[0][2] - mean_st_3) * std_ct_3 / std_st_3 + mean_ct_3
|
| 79 |
+
return style_tensor
|
| 80 |
+
|
| 81 |
+
|
| 82 |
def main():
|
| 83 |
# Read content images and style images
|
| 84 |
if args.content_image:
|
|
|
|
| 128 |
|
| 129 |
style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
|
| 130 |
|
| 131 |
+
# Linear Histogram Matching if needed
|
| 132 |
+
if args.color_control:
|
| 133 |
+
style_tensor = linear_histogram_matching(content_tensor,style_tensor)
|
| 134 |
+
|
| 135 |
# Start time
|
| 136 |
tic = time.perf_counter()
|
| 137 |
|