falseu
commited on
Commit
·
404f3a4
1
Parent(s):
19a4010
fix
Browse files
train.py
CHANGED
|
@@ -10,7 +10,7 @@ def main():
|
|
| 10 |
parser = argparse.ArgumentParser()
|
| 11 |
parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
|
| 12 |
parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
|
| 13 |
-
parser.add_argument('--epochs', type=int, default=
|
| 14 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
|
| 15 |
parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
|
| 16 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
@@ -31,6 +31,7 @@ def main():
|
|
| 31 |
losses = []
|
| 32 |
iteration = 0
|
| 33 |
|
|
|
|
| 34 |
if args.resume > 0:
|
| 35 |
states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
|
| 36 |
model.decoder.load_state_dict(states['decoder'])
|
|
@@ -44,10 +45,12 @@ def main():
|
|
| 44 |
train_tqdm = tqdm(train_loader)
|
| 45 |
train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
|
| 46 |
losses.append((iteration, total_loss, content_loss, style_loss))
|
| 47 |
-
|
| 48 |
|
| 49 |
for content_batch, style_batch in train_tqdm:
|
| 50 |
|
|
|
|
|
|
|
| 51 |
content_batch = content_batch.to(device)
|
| 52 |
style_batch = style_batch.to(device)
|
| 53 |
|
|
@@ -55,39 +58,40 @@ def main():
|
|
| 55 |
loss_scaled = loss_content + 10 * loss_style
|
| 56 |
loss_scaled.backward()
|
| 57 |
decoder_optimizer.step()
|
| 58 |
-
total_loss
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
if iteration % np.ceil(len(train_loader.dataset)/args.batch_size) == 0 and iteration > 0:
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
iteration += 1
|
| 83 |
|
| 84 |
print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))
|
| 85 |
|
| 86 |
-
states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
|
| 87 |
-
|
| 88 |
-
torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
|
| 89 |
-
torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
|
| 90 |
-
np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
|
| 91 |
|
| 92 |
if __name__ == '__main__':
|
| 93 |
main()
|
|
|
|
| 10 |
parser = argparse.ArgumentParser()
|
| 11 |
parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
|
| 12 |
parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
|
| 13 |
+
parser.add_argument('--epochs', type=int, default=10, help='Number of epoch')
|
| 14 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
|
| 15 |
parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
|
| 16 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
|
| 31 |
losses = []
|
| 32 |
iteration = 0
|
| 33 |
|
| 34 |
+
# If resume
|
| 35 |
if args.resume > 0:
|
| 36 |
states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
|
| 37 |
model.decoder.load_state_dict(states['decoder'])
|
|
|
|
| 45 |
train_tqdm = tqdm(train_loader)
|
| 46 |
train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
|
| 47 |
losses.append((iteration, total_loss, content_loss, style_loss))
|
| 48 |
+
total_num = 0
|
| 49 |
|
| 50 |
for content_batch, style_batch in train_tqdm:
|
| 51 |
|
| 52 |
+
decoder_optimizer.zero_grad()
|
| 53 |
+
|
| 54 |
content_batch = content_batch.to(device)
|
| 55 |
style_batch = style_batch.to(device)
|
| 56 |
|
|
|
|
| 58 |
loss_scaled = loss_content + 10 * loss_style
|
| 59 |
loss_scaled.backward()
|
| 60 |
decoder_optimizer.step()
|
| 61 |
+
total_loss = loss_scaled.item()
|
| 62 |
+
content_loss = loss_content.item()
|
| 63 |
+
style_loss = loss_style.item()
|
| 64 |
+
|
| 65 |
+
train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
|
| 66 |
+
iteration += 1
|
| 67 |
+
|
| 68 |
+
# if iteration % 100 == 0 and iteration > 0:
|
| 69 |
|
| 70 |
+
# total_loss /= total_num
|
| 71 |
+
# content_loss /= total_num
|
| 72 |
+
# style_loss /= total_num
|
| 73 |
+
# print('')
|
| 74 |
+
# train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
|
| 75 |
|
| 76 |
+
# losses.append((iteration, total_loss, content_loss, style_loss))
|
| 77 |
|
| 78 |
+
# total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
|
| 79 |
+
# total_num = 0
|
| 80 |
+
|
| 81 |
+
# if iteration % np.ceil(len(train_loader.dataset)/args.batch_size) == 0 and iteration > 0:
|
| 82 |
+
# total_loss /= total_num
|
| 83 |
+
# content_loss /= total_num
|
| 84 |
+
# style_loss /= total_num
|
| 85 |
+
# total_num = 0
|
| 86 |
+
|
|
|
|
| 87 |
|
| 88 |
print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))
|
| 89 |
|
| 90 |
+
# states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
|
| 91 |
+
# 'losses': losses, 'iteration': iteration}
|
| 92 |
+
# torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
|
| 93 |
+
# torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
|
| 94 |
+
# np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
|
| 95 |
|
| 96 |
if __name__ == '__main__':
|
| 97 |
main()
|