tidalove committed on
Commit 02582df · verified · 1 Parent(s): bbbd3d2

Update test_api.py

Files changed (1)
  1. test_api.py +28 -28
test_api.py CHANGED
@@ -10,10 +10,10 @@ from torchvision.utils import save_image
 from torchvision.transforms import ToPILImage
 from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
 from glob import glob
+from datasets import load_dataset
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-
 def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
     """
     Given content image and style image, generate feature maps with encoder, apply
@@ -39,17 +39,13 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
     mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
     return decoder(mix_enc)
 
-def run_adain(content_dir, style_dir, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth', alpha=1.0):
+def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
     content_pths = [Path(f) for f in glob(content_dir+'/*')]
     style_pths = [Path(f) for f in glob(style_dir+'/*')]
 
     assert len(content_pths) > 0, 'Failed to load content image'
     assert len(style_pths) > 0, 'Failed to load style image'
 
-    # Prepare directory for saving results
-    out_dir = tempfile.mkdtemp()
-    os.makedirs(out_dir, exist_ok=True)
-
     # Load AdaIN model
     vgg = torch.load(vgg_pth)
     model = AdaINNet(vgg).to(device)
@@ -62,31 +58,35 @@ def run_adain(content_dir, style_dir, vgg_pth='vgg_normalized.pth', decoder_pth=
     # Timer
     times = []
 
-    for content_pth in content_pths:
-        content_img = Image.open(content_pth)
-        content_tensor = t(content_img).unsqueeze(0).to(device)
+    style_ds = load_dataset(style_dataset_pth, split="train")
+    # do i need to stick a dataloader around this? idk
 
-        for style_pth in style_pths:
-
-            style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
+    for style_item in style_ds:
+        style_img = style_item['image']
+        print(style_img)
+        style_tensor = t(style_img).unsqueeze(0).to(device)
 
-            # Start time
-            tic = time.perf_counter()
-
-            # Execute style transfer
-            with torch.no_grad():
-                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
-
-            # End time
-            toc = time.perf_counter()
-            print("Content: " + content_pth.stem + ". Style: " \
-                + style_pth.stem + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
-            times.append(toc-tic)
+        for content_pth in content_pths:
+            content_img = Image.open(content_pth)
+            content_tensor = t(content_img).unsqueeze(0).to(device)
+
+            # Start time
+            tic = time.perf_counter()
+
+            # Execute style transfer
+            with torch.no_grad():
+                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
+
+            # End time
+            toc = time.perf_counter()
+            print("Content: " + content_pth.stem + ". Style: " \
+                + style_pth.stem + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
+            times.append(toc-tic)
 
-            # Save image
-            out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(alpha)
-            out_pth += content_pth.suffix
-            save_image(out_tensor, out_pth)
+            # Save image
+            out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(alpha)
+            out_pth += content_pth.suffix
+            save_image(out_tensor, out_pth)
 
     # Remove runtime of first iteration because it is flawed for some unknown reason
     if len(times) > 1:
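
One wrinkle in the new run_adain: the glob and asserts over style_pths still read style_dir, and the print/save lines still use style_pth.stem, but neither style_dir nor style_pth exists after this change. A minimal sketch of how the new loop could resolve those references, assuming styles are simply numbered by their dataset index, that the dataset's 'image' column yields PIL images, and that os is still imported at the top of the file (all assumptions, not part of this commit); the leftover style_pths glob/asserts would be dropped or repointed at the dataset:

    # Sketch only: index-based style naming is an assumption, not from the commit
    style_ds = load_dataset(style_dataset_pth, split="train")

    for style_idx, style_item in enumerate(style_ds):
        style_tensor = t(style_item['image']).unsqueeze(0).to(device)

        for content_pth in content_pths:
            content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)

            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()

            # Use the dataset index in place of the removed style_pth.stem when naming the output
            out_name = content_pth.stem + '_style_' + str(style_idx) + '_alpha' + str(alpha) + content_pth.suffix
            save_image(out_tensor, os.path.join(out_dir, out_name))

On the in-code question: iterating a datasets.Dataset directly yields one example dict per step, so a DataLoader is not required for this batch-size-1 loop.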
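
Since out_dir is now a required argument rather than a tempfile.mkdtemp() directory, the caller has to create it, and because out_pth is built by plain string concatenation, out_dir needs a trailing separator. A hypothetical call site (the directory names and dataset id below are placeholders, not from the commit):

    os.makedirs('outputs/', exist_ok=True)           # out_dir is no longer created inside run_adain
    run_adain(content_dir='content_images',          # placeholder folder of content images
              style_dataset_pth='huggan/wikiart',    # placeholder Hub dataset with an 'image' column
              out_dir='outputs/',                    # trailing slash needed by the concatenated out_pth
              alpha=1.0)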