# File size: 3,528 Bytes | 7930ce0
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import TrainSet
from AdaIN import AdaINNet
def _parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            the process arguments.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=1, help='Number of epoch')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
    parser.add_argument('--cuda', action='store_true', help='Use CUDA')
    return parser.parse_args(argv)


def main():
    """Train the AdaIN decoder and checkpoint it after every epoch.

    Options come from the command line (see ``_parse_args``). Only the
    parameters of ``model.decoder`` are optimised (Adam, lr=1e-6) on
    ``content loss + 10 * style loss``.

    Sample-weighted mean losses are recorded every 100 iterations and once
    more at the end of each epoch for the batches left over since the last
    record. After each epoch the function writes
    ``./check_point/epoch_<n>.pth`` (decoder, optimiser, loss history and
    iteration counter), ``./weights/decoder_epoch_<n>.pth`` and the
    ``losses`` text file.

    With ``--resume N`` the state is restored from
    ``./check_point/epoch_N.pth`` and training continues at epoch ``N + 1``.
    """
    args = _parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'
    # torch.save does not create missing parent directories.
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    # NOTE(review): torch.load unpickles the file - only load trusted files.
    # map_location lets files saved on a GPU be opened on a CPU-only machine.
    vgg_model = torch.load('vgg_normalized.pth', map_location=device)
    model = AdaINNet(vgg_model).to(device)

    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)

    log_interval = 100  # iterations between two recorded loss averages
    style_weight = 10   # weight of the style term in the optimised loss
    description = 'Loss: %.4f, Content loss: %.4f, Style loss: %.4f'

    # Most recently reported averages; shown on the progress bar.
    avg_total, avg_content, avg_style = 0.0, 0.0, 0.0
    losses = []
    iteration = 0  # number of optimiser steps completed so far

    if args.resume > 0:
        states = torch.load(check_point_dir + "epoch_" + str(args.resume) + '.pth',
                            map_location=device)
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']

    for epoch in range(args.resume + 1, args.epochs + 1):
        print("Begin epoch: %i/%i" % (epoch, int(args.epochs)))
        train_tqdm = tqdm(train_loader)
        train_tqdm.set_description(description % (avg_total, avg_content, avg_style))

        # Sample-weighted running sums for the current logging window.
        sum_total, sum_content, sum_style = 0.0, 0.0, 0.0
        sample_count = 0

        for content_batch, style_batch in train_tqdm:
            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + style_weight * loss_style
            loss_scaled.backward()
            decoder_optimizer.step()
            decoder_optimizer.zero_grad()

            # Weight by batch size so a smaller final batch does not skew the mean.
            batch_samples = style_batch.size(0)
            sum_total += loss_scaled.item() * batch_samples
            sum_content += loss_content.item() * batch_samples
            sum_style += loss_style.item() * batch_samples
            sample_count += batch_samples
            iteration += 1

            if iteration % log_interval == 0:
                avg_total = sum_total / sample_count
                avg_content = sum_content / sample_count
                avg_style = sum_style / sample_count
                print('')
                train_tqdm.set_description(description % (avg_total, avg_content, avg_style))
                losses.append((iteration, avg_total, avg_content, avg_style))
                sum_total, sum_content, sum_style = 0.0, 0.0, 0.0
                sample_count = 0

        # Record the batches seen since the last log step. Skipped when the
        # window is empty, so there is never a division by zero.
        if sample_count > 0:
            avg_total = sum_total / sample_count
            avg_content = sum_content / sample_count
            avg_style = sum_style / sample_count
            losses.append((iteration, avg_total, avg_content, avg_style))

        print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))

        states = {'decoder': model.decoder.state_dict(),
                  'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
        # np.savetxt rejects the 4-field format on an empty history.
        if losses:
            np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|