tidalove commited on
Commit
decbd98
·
verified ·
1 Parent(s): 4a3a637

Update test_api.py

Browse files
Files changed (1) hide show
  1. test_api.py +24 -24
test_api.py CHANGED
@@ -40,10 +40,10 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
40
  return decoder(mix_enc)
41
 
42
  def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
43
- content_pths = [Path(f) for f in glob(content_dir+'/*')]
44
- style_pths = [Path(f) for f in glob(style_dir+'/*')]
45
 
46
- assert len(content_pths) > 0, 'Failed to load content image'
47
  assert len(style_pths) > 0, 'Failed to load style image'
48
 
49
  # Load AdaIN model
@@ -59,34 +59,34 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
59
  times = []
60
 
61
  style_ds = load_dataset(style_dataset_pth, split="train")
62
- # do i need to stick a dataloader around this? idk
63
 
64
  for style_idx, style_item in enumerate(style_ds):
65
- style_img = style_item['image']
66
- print(style_img)
67
- style_tensor = t(style_img).unsqueeze(0).to(device)
68
 
69
- for content_pth in content_pths:
70
- content_img = Image.open(content_pth)
71
- content_tensor = t(content_img).unsqueeze(0).to(device)
72
 
73
- # Start time
74
- tic = time.perf_counter()
75
 
76
- # Execute style transfer
77
- with torch.no_grad():
78
- out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
79
 
80
- # End time
81
- toc = time.perf_counter()
82
- print("Content: " + content_pth.stem + ". Style: " \
83
- + str(style_idx) + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
84
- times.append(toc-tic)
85
 
86
- # Save image
87
- out_pth = out_dir + content_pth.stem + '_style_' + str(style_idx) + '_alpha' + str(alpha)
88
- out_pth += content_pth.suffix
89
- save_image(out_tensor, out_pth)
90
 
91
  # Remove runtime of first iteration because it is flawed for some unknown reason
92
  if len(times) > 1:
 
40
  return decoder(mix_enc)
41
 
42
  def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
43
+ content_pths = [Path(f) for f in glob(content_dir+'/*')]
44
+ style_pths = [Path(f) for f in glob(style_dir+'/*')]
45
 
46
+ assert len(content_pths) > 0, 'Failed to load content image'
47
  assert len(style_pths) > 0, 'Failed to load style image'
48
 
49
  # Load AdaIN model
 
59
  times = []
60
 
61
  style_ds = load_dataset(style_dataset_pth, split="train")
62
+ # do i need to stick a dataloader around this? idk
63
 
64
  for style_idx, style_item in enumerate(style_ds):
65
+ style_img = style_item['image']
66
+ print(style_img)
67
+ style_tensor = t(style_img).unsqueeze(0).to(device)
68
 
69
+ for content_pth in content_pths:
70
+ content_img = Image.open(content_pth)
71
+ content_tensor = t(content_img).unsqueeze(0).to(device)
72
 
73
+ # Start time
74
+ tic = time.perf_counter()
75
 
76
+ # Execute style transfer
77
+ with torch.no_grad():
78
+ out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
79
 
80
+ # End time
81
+ toc = time.perf_counter()
82
+ print("Content: " + content_pth.stem + ". Style: " \
83
+ + str(style_idx) + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
84
+ times.append(toc-tic)
85
 
86
+ # Save image
87
+ out_pth = out_dir + content_pth.stem + '_style_' + str(style_idx) + '_alpha' + str(alpha)
88
+ out_pth += content_pth.suffix
89
+ save_image(out_tensor, out_pth)
90
 
91
  # Remove runtime of first iteration because it is flawed for some unknown reason
92
  if len(times) > 1: