File size: 2,844 Bytes
a62af7c
 
 
 
 
a6e34e8
a62af7c
 
 
 
 
 
 
02582df
a62af7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68ca7bf
decbd98
68ca7bf
a62af7c
68ca7bf
a62af7c
 
 
 
 
 
 
 
 
 
 
 
 
667d357
a62af7c
68ca7bf
 
 
 
a62af7c
68ca7bf
 
 
 
 
 
 
aa184cd
68ca7bf
 
02582df
decbd98
 
 
a62af7c
decbd98
36754a2
decbd98
36754a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import tempfile
import torch
import time
import numpy as np
import random
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
from glob import glob
from datasets import load_dataset

# Run on the first CUDA device when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
	device = torch.device('cuda')
else:
	device = torch.device('cpu')

def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
	"""
	Stylize a content image with Adaptive Instance Normalization (AdaIN).

	Both images are encoded, the content features are re-normalized with the
	style features' statistics via ``adaptive_instance_normalization``, the
	result is blended with the raw content features, and the blend is decoded
	back into image space.

	Args:
		content_tensor (torch.FloatTensor): Content image batch fed to the encoder
		style_tensor (torch.FloatTensor): Style image batch fed to the encoder
		encoder: Feature encoder (VGG19 in this project)
		decoder: Network mapping features back to an image
		alpha (float, default=1.0): Blend weight; 1.0 decodes the fully stylized
			features, 0.0 decodes the unmodified content features

	Return:
		output_tensor (torch.FloatTensor): Decoder output for the blended features
	"""
	content_feat = encoder(content_tensor)
	style_feat = encoder(style_tensor)

	stylized_feat = adaptive_instance_normalization(content_feat, style_feat)

	# Linear interpolation in feature space between stylized and original content.
	blended_feat = alpha * stylized_feat + (1-alpha) * content_feat
	output_tensor = decoder(blended_feat)
	return output_tensor

def _num_styles_per_content(num_content_imgs, num_style_imgs, dataset_size):
	"""
	Decide how many style images to pair with each content image.

	When pairing every content image with every style image would exceed
	``dataset_size``, the budget is split evenly across content images and
	rounded up. The total can therefore overshoot ``dataset_size`` by fewer
	than ``num_content_imgs`` images. Otherwise every style image is used.

	Args:
		num_content_imgs (int): Number of content images
		num_style_imgs (int): Number of images in the style dataset
		dataset_size (int): Target number of generated images

	Return:
		int: Styles per content image. For non-negative ``dataset_size`` it never
		exceeds ``num_style_imgs``, so it is a valid ``random.sample`` size.
	"""
	if num_content_imgs * num_style_imgs > dataset_size:
		return int(np.ceil(dataset_size / num_content_imgs))
	return num_style_imgs

def _load_model(vgg_pth, decoder_pth):
	"""Build an AdaINNet from the given checkpoints, moved to ``device`` and set to eval mode."""
	# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
	vgg = torch.load(vgg_pth, map_location=device)
	model = AdaINNet(vgg).to(device)
	model.decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
	model.eval()
	return model

def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
	"""
	Stylize every file in ``content_dir`` with randomly chosen images from a
	style dataset and save the results to ``out_dir``.

	Args:
		content_dir (str or Path): Directory whose files are used as content images
		style_dataset_pth (str): Name/path passed to ``datasets.load_dataset``.
			Records of its "train" split must have an ``'image'`` field.
		out_dir (str): Output directory (created if missing)
		alpha (float, default=1.0): Style weight forwarded to ``style_transfer``
		dataset_size (int, default=100): Approximate total number of images to generate
		vgg_pth (str): Path to the encoder checkpoint
		decoder_pth (str): Path to the decoder state dict

	Output files are named
	``<content stem>_style_<style idx>_alpha<alpha><content suffix>``.
	"""
	# Sorted for a deterministic processing order. Directories are skipped
	# because Image.open cannot read them.
	content_pths = [Path(f) for f in sorted(glob(os.path.join(content_dir, '*'))) if os.path.isfile(f)]
	num_content_imgs = len(content_pths)

	assert num_content_imgs > 0, 'Failed to load content image'

	# Load AdaIN model
	model = _load_model(vgg_pth, decoder_pth)

	# Image transform from utils (512 is presumably the target size; confirm in utils.transform)
	t = transform(512)

	style_ds = load_dataset(style_dataset_pth, split="train")
	num_style_per_content = _num_styles_per_content(num_content_imgs, len(style_ds), dataset_size)

	# The output directory must exist before any image is written into it.
	os.makedirs(out_dir, exist_ok=True)

	for content_pth in content_pths:
		# Convert to RGB like the style images below, so grayscale/RGBA files do not
		# produce a tensor with the wrong channel count. `with` closes the file handle.
		with Image.open(content_pth) as img:
			content_img = img.convert("RGB")
		content_tensor = t(content_img).unsqueeze(0).to(device)
		indices = random.sample(range(len(style_ds)), num_style_per_content)

		for idx in indices:
			style_img = style_ds[idx]['image']
			if style_img.mode != "RGB":
				style_img = style_img.convert("RGB")
			style_tensor = t(style_img).unsqueeze(0).to(device)

			# Execute style transfer (inference only, no autograd graph)
			with torch.no_grad():
				out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()

			# Save image
			out_pth = os.path.join(out_dir, f"{content_pth.stem}_style_{idx}_alpha{alpha}{content_pth.suffix}")
			save_image(out_tensor, out_pth)
			print(f"Style transferred image saved to {out_pth}")