falseu committed on
Commit
404f3a4
·
1 Parent(s): 19a4010
Files changed (1) hide show
  1. train.py +33 -29
train.py CHANGED
@@ -10,7 +10,7 @@ def main():
10
  parser = argparse.ArgumentParser()
11
  parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
12
  parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
13
- parser.add_argument('--epochs', type=int, default=1, help='Number of epoch')
14
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
15
  parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
16
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
@@ -31,6 +31,7 @@ def main():
31
  losses = []
32
  iteration = 0
33
 
 
34
  if args.resume > 0:
35
  states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
36
  model.decoder.load_state_dict(states['decoder'])
@@ -44,10 +45,12 @@ def main():
44
  train_tqdm = tqdm(train_loader)
45
  train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
46
  losses.append((iteration, total_loss, content_loss, style_loss))
47
- total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
48
 
49
  for content_batch, style_batch in train_tqdm:
50
 
 
 
51
  content_batch = content_batch.to(device)
52
  style_batch = style_batch.to(device)
53
 
@@ -55,39 +58,40 @@ def main():
55
  loss_scaled = loss_content + 10 * loss_style
56
  loss_scaled.backward()
57
  decoder_optimizer.step()
58
- total_loss += loss_scaled.item() * style_batch.size(0)
59
- decoder_optimizer.zero_grad()
60
-
61
- total_num += style_batch.size(0)
62
-
63
- if iteration % 100 == 0 and iteration > 0:
 
 
64
 
65
- total_loss /= total_num
66
- content_loss /= total_num
67
- style_loss /= total_num
68
- print('')
69
- train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
70
 
71
- losses.append((iteration, total_loss, content_loss, style_loss))
72
 
73
- total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
74
- total_num = 0
75
-
76
- if iteration % np.ceil(len(train_loader.dataset)/args.batch_size) == 0 and iteration > 0:
77
- total_loss /= total_num
78
- content_loss /= total_num
79
- style_loss /= total_num
80
- total_num = 0
81
-
82
- iteration += 1
83
 
84
  print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))
85
 
86
- states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
87
- 'losses': losses, 'iteration': iteration}
88
- torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
89
- torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
90
- np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
91
 
92
  if __name__ == '__main__':
93
  main()
 
10
  parser = argparse.ArgumentParser()
11
  parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
12
  parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
13
+ parser.add_argument('--epochs', type=int, default=10, help='Number of epoch')
14
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
15
  parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
16
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
 
31
  losses = []
32
  iteration = 0
33
 
34
+ # If resume
35
  if args.resume > 0:
36
  states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
37
  model.decoder.load_state_dict(states['decoder'])
 
45
  train_tqdm = tqdm(train_loader)
46
  train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
47
  losses.append((iteration, total_loss, content_loss, style_loss))
48
+ total_num = 0
49
 
50
  for content_batch, style_batch in train_tqdm:
51
 
52
+ decoder_optimizer.zero_grad()
53
+
54
  content_batch = content_batch.to(device)
55
  style_batch = style_batch.to(device)
56
 
 
58
  loss_scaled = loss_content + 10 * loss_style
59
  loss_scaled.backward()
60
  decoder_optimizer.step()
61
+ total_loss = loss_scaled.item()
62
+ content_loss = loss_content.item()
63
+ style_loss = loss_style.item()
64
+
65
+ train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
66
+ iteration += 1
67
+
68
+ # if iteration % 100 == 0 and iteration > 0:
69
 
70
+ # total_loss /= total_num
71
+ # content_loss /= total_num
72
+ # style_loss /= total_num
73
+ # print('')
74
+ # train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
75
 
76
+ # losses.append((iteration, total_loss, content_loss, style_loss))
77
 
78
+ # total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
79
+ # total_num = 0
80
+
81
+ # if iteration % np.ceil(len(train_loader.dataset)/args.batch_size) == 0 and iteration > 0:
82
+ # total_loss /= total_num
83
+ # content_loss /= total_num
84
+ # style_loss /= total_num
85
+ # total_num = 0
86
+
 
87
 
88
  print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))
89
 
90
+ # states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
91
+ # 'losses': losses, 'iteration': iteration}
92
+ # torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
93
+ # torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
94
+ # np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
95
 
96
  if __name__ == '__main__':
97
  main()