import os
import tempfile
import time
from glob import glob
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from datasets import load_dataset
from torchvision.transforms import ToPILImage
from torchvision.utils import save_image

from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, grid_image, transform, linear_histogram_matching, Range

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given content image and style image, generate feature maps with encoder,
    apply neural style transfer with adaptive instance normalization, and
    generate the output image with decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature. 1.0 uses
            only the normalized (stylized) features, 0.0 uses only the
            content features.

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    # Linear blend between the stylized features and the raw content features.
    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0,
              vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
    """
    Stylize every content image with every style image of a dataset and save
    the results, printing per-pair and average style transfer times.

    Args:
        content_dir (str): Directory whose entries (``content_dir/*``) are
            opened as content images.
        style_dataset_pth (str): Path or name passed to
            ``datasets.load_dataset(..., split="train")``. Each item must
            provide the style image under the ``'image'`` key.
        out_dir (str): Directory the output images are written to. It is
            created if it does not exist.
        alpha (float, default=1.0): Weight of style image feature, forwarded
            to ``style_transfer``.
        vgg_pth (str): File loaded with ``torch.load`` and passed to
            ``AdaINNet``.
        decoder_pth (str): File loaded with ``torch.load`` and passed to
            ``model.decoder.load_state_dict``.

    Raises:
        AssertionError: If no content image is found or the style dataset is
            empty.
    """
    # Sorted so that the processing order is deterministic across runs.
    content_pths = sorted(Path(f) for f in glob(os.path.join(content_dir, '*')))
    assert len(content_pths) > 0, 'Failed to load content image'

    # NOTE: the dataset is iterated directly, one item at a time. No
    # DataLoader is wrapped around it because every style image is processed
    # individually (batch of one).
    style_ds = load_dataset(style_dataset_pth, split="train")
    assert len(style_ds) > 0, 'Failed to load style image'

    # Make sure the output directory exists before any image is written.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    vgg = torch.load(vgg_pth, map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    model.eval()

    # Prepare image transform
    t = transform(512)

    # Timer
    times = []

    for style_idx, style_item in enumerate(style_ds):
        style_img = style_item['image']
        print(style_img)
        style_tensor = t(style_img).unsqueeze(0).to(device)

        for content_pth in content_pths:
            # Close the file handle right away; convert() forces the pixel
            # data to be read and yields a 3-channel image.
            # NOTE(review): assumes the encoder expects RGB input - confirm.
            with Image.open(content_pth) as img:
                content_img = img.convert('RGB')
            content_tensor = t(content_img).unsqueeze(0).to(device)

            # Start time
            tic = time.perf_counter()

            # Execute style transfer
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor,
                                            model.encoder, model.decoder, alpha).cpu()

            # End time
            toc = time.perf_counter()
            print("Content: %s. Style: %s. Alpha: %s. Style Transfer time: %.4f seconds"
                  % (content_pth.stem, style_idx, alpha, toc - tic))
            times.append(toc - tic)

            # Save image
            out_name = (content_pth.stem + '_style_' + str(style_idx)
                        + '_alpha' + str(alpha) + content_pth.suffix)
            out_pth = os.path.join(out_dir, out_name)
            save_image(out_tensor, out_pth)

    # Remove runtime of first iteration because it is flawed for some unknown
    # reason (presumably one-time warm-up cost - not verified).
    if len(times) > 1:
        times.pop(0)
    avg = sum(times) / len(times)
    print("Average style transfer time: %.4f seconds" % (avg))