tidalove commited on
Commit
36754a2
·
verified ·
1 Parent(s): d69b2fc

update output location

Browse files
Files changed (1) hide show
  1. test_api.py +2 -17
test_api.py CHANGED
@@ -76,26 +76,11 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=1
76
  style_img = style_img.convert("RGB")
77
  style_tensor = t(style_img).unsqueeze(0).to(device)
78
 
79
- # Start time
80
- tic = time.perf_counter()
81
-
82
  # Execute style transfer
83
  with torch.no_grad():
84
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
85
-
86
- # End time
87
- toc = time.perf_counter()
88
- print("Content: " + content_pth.stem + ". Style: " \
89
- + str(idx) + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
90
- times.append(toc-tic)
91
 
92
  # Save image
93
- out_pth = out_dir + content_pth.stem + '_style_' + str(idx) + '_alpha' + str(alpha)
94
- out_pth += content_pth.suffix
95
  save_image(out_tensor, out_pth)
96
-
97
- # Remove runtime of first iteration because it is flawed for some unknown reason
98
- if len(times) > 1:
99
- times.pop(0)
100
- avg = sum(times)/len(times)
101
- print("Average style transfer time: %.4f seconds" % (avg))
 
76
  style_img = style_img.convert("RGB")
77
  style_tensor = t(style_img).unsqueeze(0).to(device)
78
 
 
 
 
79
  # Execute style transfer
80
  with torch.no_grad():
81
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
 
 
 
 
 
 
82
 
83
  # Save image
84
+ out_pth = os.path.join(out_dir, content_pth.stem + '_style_' + str(idx) + '_alpha' + str(alpha) + content_pth.suffix)
 
85
  save_image(out_tensor, out_pth)
86
+ print(f"Style transferred image saved to {out_pth}")