File size: 3,176 Bytes
7930ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms 
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob

def adaptive_instance_normalization(x, y, eps=1e-5):
	"""
	Adaptive Instance Normalization (AdaIN).

	Re-normalizes x so that its statistics over dims 2 and 3 match those
	of y: x is shifted/scaled to zero mean and unit std, then rescaled by
	the std of y and shifted by the mean of y.

	Args:
		x (torch.FloatTensor): Content tensor; statistics are reduced over
			dims 2 and 3 (presumably laid out as (N, C, H, W) -- confirm
			against callers)
		y (torch.FloatTensor): Style tensor, reduced over the same dims
		eps (float, default=1e-5): Small value added to each standard
			deviation to avoid zero division

	Return:
		output (torch.FloatTensor): AdaIN style transferred output
	"""

	def _stats(feat):
		# Reduce over dims 2 and 3, then append two singleton axes so the
		# result broadcasts against the original tensor.
		mean = feat.mean(dim=[2, 3])[..., None, None]
		std = feat.std(dim=[2, 3])[..., None, None] + eps
		return mean, std

	content_mean, content_std = _stats(x)
	style_mean, style_std = _stats(y)

	normalized = (x - content_mean) / content_std
	return normalized * style_std + style_mean

def transform(size):
	"""
	Build the image preprocessing pipeline: resize, then convert to tensor.

	Args:
		size (int): Target size handed to transforms.Resize

	Return:
		output (torchvision.transforms.Compose): Resize followed by ToTensor
	"""

	return transforms.Compose([
		transforms.Resize(size),
		transforms.ToTensor(),
	])

def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
	"""
	Generate and save an image that contains row x col grids of images.

	The figure is created, written to save_pth and then closed, so repeated
	calls do not accumulate open matplotlib figures.

	Args:
		row (int): number of rows
		col (int): number of columns
		images (list of PIL image): list of images, placed left-to-right,
			top-to-bottom; should hold at most row * col entries
		height (int) : height of each image (inch)
		width (int) : width of each image (inch)
		save_pth (str): save file path
	"""

	# Overall figure size is the per-cell size times the grid dimensions.
	fig = plt.figure(figsize=(col * width, row * height))
	try:
		for i, image in enumerate(images):
			# Subplot indices are 1-based.
			ax = fig.add_subplot(row, col, i + 1)
			ax.imshow(image)
			ax.axis('off')
		# Spacing is a figure-level setting; apply it once, not per image.
		fig.subplots_adjust(wspace=0.01, hspace=0.01)
		fig.savefig(save_pth)
	finally:
		# pyplot keeps every figure alive until it is closed explicitly.
		plt.close(fig)


class TrainSet(Dataset):
	"""
	Build Training dataset.

	Pairs the index-th content image with the index-th style image. Every
	entry found directly inside each directory is treated as an image file.
	File lists are sorted so the pairing is reproducible across runs.

	Args:
		content_dir (str or os.PathLike): directory holding content images
		style_dir (str or os.PathLike): directory holding style images
		crop_size (int, default=256): side length passed to RandomCrop
	"""
	def __init__(self, content_dir, style_dir, crop_size = 256):
		super().__init__()

		self.content_files = self._list_files(content_dir)
		self.style_files = self._list_files(style_dir)

		# Resize to 512 (bicubic), take a random crop, convert to a tensor
		# and normalize each of the three channels with mean 0.5 / std 0.5.
		self.transform = transforms.Compose([
			transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
			transforms.RandomCrop(crop_size),
			transforms.ToTensor(),
			transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
			])

		# NOTE(review): both settings below are process-wide Pillow globals.
		# MAX_IMAGE_PIXELS = None removes Pillow's image-size limit, so only
		# point this dataset at trusted image folders.
		Image.MAX_IMAGE_PIXELS = None
		ImageFile.LOAD_TRUNCATED_IMAGES = True

	@staticmethod
	def _list_files(directory):
		# Return the sorted entries of directory as Path objects.
		# os.fspath lets callers pass either str or pathlib.Path; sorting
		# removes the dependence on the filesystem's listing order.
		return sorted(Path(f) for f in glob(os.fspath(directory) + '/*'))

	@staticmethod
	def _load_rgb(path):
		# Open path, decode it to RGB and release the file handle.
		# convert() returns a new in-memory image, so it stays valid after
		# the source file is closed.
		with Image.open(path) as img:
			return img.convert('RGB')

	def __len__(self):
		# Only indices present in both lists can form a pair.
		return min(len(self.style_files), len(self.content_files))

	def __getitem__(self, index):
		"""
		Return the (content, style) tensor pair at index.

		Each image goes through self.transform independently, so the two
		random crops are drawn separately.
		"""
		content_img = self._load_rgb(self.content_files[index])
		style_img = self._load_rgb(self.style_files[index])

		content_sample = self.transform(content_img)
		style_sample = self.transform(style_img)

		return content_sample, style_sample

class Range(object):
	"""
	Helper class for input argument range restriction.

	An instance compares equal to any value v with start <= v <= end
	(both ends inclusive), so a membership test such as `v in [Range(a, b)]`
	succeeds exactly when v lies in the interval. Presumably used as an
	argparse `choices` entry -- confirm against callers.

	Args:
		start: inclusive lower bound
		end: inclusive upper bound
	"""

	# "Equality" here means interval membership, which is not an
	# equivalence relation, so instances are deliberately unhashable.
	# Defining __eq__ without __hash__ already has this effect; it is
	# spelled out here to make the intent visible.
	__hash__ = None

	def __init__(self, start, end):
		self.start = start
		self.end = end

	def __eq__(self, other):
		try:
			return self.start <= other <= self.end
		except TypeError:
			# Operand cannot be ordered against the bounds (e.g. a str):
			# let Python fall back to its default comparison instead of
			# raising from inside an equality test.
			return NotImplemented

	def __repr__(self):
		# Readable form for messages that print the object.
		return '{}({!r}, {!r})'.format(type(self).__name__, self.start, self.end)