import os
import random
from pathlib import Path
from glob import glob

import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image
from datasets import load_dataset

from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, transform

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given a content image and a style image, generate feature maps with the
    encoder, apply neural style transfer with adaptive instance normalization,
    and generate the output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image features

    Returns:
        output_tensor (torch.FloatTensor): Style transfer output image
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    # Interpolate between stylized and original content features:
    # alpha=1.0 is full stylization, alpha=0.0 reproduces the content.
    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)
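
# A minimal reference sketch of what utils.adaptive_instance_normalization is
# expected to compute, following Huang & Belongie (2017): re-normalize the
# content features to match the channel-wise mean and standard deviation of
# the style features. This is an illustrative assumption about the imported
# helper, not the actual utils implementation, and is not called below.
def _adain_reference(content_enc, style_enc, eps=1e-5):
    # Per-sample, per-channel statistics over the spatial dimensions (H, W)
    c_mean = content_enc.mean(dim=(2, 3), keepdim=True)
    c_std = content_enc.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_enc.mean(dim=(2, 3), keepdim=True)
    s_std = style_enc.std(dim=(2, 3), keepdim=True) + eps
    # AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
    return s_std * (content_enc - c_mean) / c_std + s_mean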

def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100,
              vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
    content_pths = [Path(f) for f in glob(content_dir + '/*')]
    num_content_imgs = len(content_pths)
    assert num_content_imgs > 0, 'Failed to load content images'

    # Load the AdaIN model: pretrained VGG encoder plus trained decoder weights
    vgg = torch.load(vgg_pth)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_pth))
    model.eval()

    # Prepare the image transform (512 px)
    t = transform(512)

    os.makedirs(out_dir, exist_ok=True)

    style_ds = load_dataset(style_dataset_pth, split="train")
    # Cap the total number of generated images at dataset_size by sampling a
    # subset of styles per content image; otherwise use every style.
    if num_content_imgs * len(style_ds) > dataset_size:
        num_style_per_content = int(np.ceil(dataset_size / num_content_imgs))
    else:
        num_style_per_content = len(style_ds)

    for content_pth in content_pths:
        content_img = Image.open(content_pth).convert("RGB")
        content_tensor = t(content_img).unsqueeze(0).to(device)
        indices = random.sample(range(len(style_ds)), num_style_per_content)
        for idx in indices:
            style_img = style_ds[idx]['image']
            if style_img.mode != "RGB":
                style_img = style_img.convert("RGB")
            style_tensor = t(style_img).unsqueeze(0).to(device)

            # Execute style transfer
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor,
                                            model.encoder, model.decoder, alpha).cpu()

            # Save image
            out_pth = os.path.join(out_dir, content_pth.stem + '_style_' + str(idx)
                                   + '_alpha' + str(alpha) + content_pth.suffix)
            save_image(out_tensor, out_pth)
            print(f"Style transferred image saved to {out_pth}")