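"""Train the AdaIN style-transfer decoder (Huang & Belongie, 2017).

Only the decoder is trained; the VGG-19 encoder is loaded from pretrained
weights and kept frozen. A sketch of a typical invocation (the dataset
paths below are placeholders, not files shipped with this repo):

    python train.py --content_dir ./data/content --style_dir ./data/style --cuda
"""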
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from AdaIN import AdaINNet
from utils import TrainSet


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--resume', type=int, default=0, help='continue training from this epoch')
    parser.add_argument('--cuda', action='store_true', help='use CUDA')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'
    # Create the output directories up front so checkpoint saving cannot fail
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    # Prepare the training dataset: each sample pairs a content image with a style image
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    # Load the pretrained (normalized) VGG-19 weights for the frozen encoder
    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)
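
    # AdaINNet is assumed to implement the standard AdaIN transform from
    # Huang & Belongie (2017): for encoder features x (content) and y (style),
    #     AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
    # where mu and sigma are channel-wise mean and standard deviation; the
    # decoder learns to map the re-normalized features back to an image.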

    # Only the decoder is optimized; the VGG encoder stays fixed
    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)

    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    losses = []
    iteration = 0

    # If resuming, restore the decoder, optimizer, and bookkeeping state
    if args.resume > 0:
        states = torch.load(check_point_dir + 'epoch_' + str(args.resume) + '.pth')
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']

    for epoch in range(args.resume + 1, args.epochs + 1):
        print('Begin epoch: %i/%i' % (epoch, args.epochs))
        train_tqdm = tqdm(train_loader)
        train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
        # Log the running loss at each epoch boundary
        losses.append((iteration, total_loss, content_loss, style_loss))

        for content_batch, style_batch in train_tqdm:
            decoder_optimizer.zero_grad()
            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            # Forward pass: the network returns the content and style losses
            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + 10 * loss_style
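            # This matches the objective in the AdaIN paper, L = L_c + lambda * L_s,
            # with the style weight lambda fixed at 10 here.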

            # Backpropagate and update the decoder weights
            loss_scaled.backward()
            decoder_optimizer.step()

            total_loss = loss_scaled.item()
            content_loss = loss_content.item()
            style_loss = loss_style.item()
            train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
            iteration += 1

        print('Finished epoch: %i/%i' % (epoch, args.epochs))

        # Save a resumable checkpoint, the decoder weights, and the loss log after each epoch
        states = {'decoder': model.decoder.state_dict(),
                  'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % epoch)
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % epoch)
        np.savetxt('losses', losses, fmt='%i,%.4f,%.4f,%.4f')


if __name__ == '__main__':
    main()