File size: 3,804 Bytes
7930ce0 404f3a4 7930ce0 4f6c34a 7930ce0 4f6c34a 7930ce0 4f6c34a 7930ce0 4f6c34a 7930ce0 4f6c34a 7930ce0 404f3a4 7930ce0 404f3a4 7930ce0 4f6c34a 7930ce0 4f6c34a 7930ce0 4f6c34a 404f3a4 7930ce0 404f3a4 7930ce0 404f3a4 7930ce0 404f3a4 7930ce0 404f3a4 7930ce0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils import TrainSet
from AdaIN import AdaINNet
from tqdm import tqdm
def main():
    """Train the AdaIN decoder on batches of content and style images.

    Command-line arguments:
        --content_dir: content images folder path (required).
        --style_dir: style images folder path (required).
        --epochs: number of the last epoch to run (default 10).
        --batch_size: batch size passed to the DataLoader (default 8).
        --resume: epoch number of the checkpoint to continue from;
            0 starts from scratch (default 0).
        --cuda: use CUDA when it is available, otherwise run on CPU.

    Only ``model.decoder``'s parameters are given to the optimizer. The
    training objective is ``loss_content + 10 * loss_style`` where both
    terms are returned by the model's forward pass.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epoch')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
    parser.add_argument('--cuda', action='store_true', help='Use CUDA')
    args = parser.parse_args()

    # Fall back to CPU when CUDA was not requested or is not available.
    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'  # only referenced by the disabled export below

    # Prepare the training dataset.
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    # Load the pre-trained VGG object used to build the network (the original
    # comment calls it "vgg19 weights"). map_location makes a file that was
    # saved from a CUDA run loadable on a CPU-only machine.
    vgg_model = torch.load('vgg_normalized.pth', map_location=device)
    model = AdaINNet(vgg_model).to(device)

    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)

    progress_format = 'Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
    style_weight = 10  # multiplier applied to the style loss in the objective

    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    losses = []
    iteration = 0

    # When resuming, restore decoder, optimizer and bookkeeping state from the
    # checkpoint of the given epoch.
    if args.resume > 0:
        states = torch.load(
            check_point_dir + 'epoch_%i.pth' % args.resume,
            map_location=device,
        )
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']

    for epoch in range(args.resume + 1, args.epochs + 1):
        print("Begin epoch: %i/%i" % (epoch, args.epochs))

        train_tqdm = tqdm(train_loader)
        train_tqdm.set_description(progress_format % (total_loss, content_loss, style_loss))
        # Records the most recent batch losses at the start of each epoch
        # (all zeros before any training step has run).
        losses.append((iteration, total_loss, content_loss, style_loss))

        for content_batch, style_batch in train_tqdm:
            decoder_optimizer.zero_grad()

            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            # Feed forward and compute the weighted loss.
            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + style_weight * loss_style

            # Gradient descent step on the decoder parameters.
            loss_scaled.backward()
            decoder_optimizer.step()

            # Keep plain Python floats of the latest batch for reporting.
            total_loss = loss_scaled.item()
            content_loss = loss_content.item()
            style_loss = loss_style.item()
            train_tqdm.set_description(progress_format % (total_loss, content_loss, style_loss))

            iteration += 1

        print('Finished epoch: %i/%i' % (epoch, args.epochs))

        # NOTE(review): checkpoint and weight export are disabled, so nothing
        # is written to disk by this script and --resume only works with a
        # checkpoint produced elsewhere using the schema below. Re-enable
        # (after making sure both directories exist) to persist training.
        # states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
        #           'losses': losses, 'iteration': iteration}
        # torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
        # torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
        # np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|