File size: 3,528 Bytes
7930ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import TrainSet
from AdaIN import AdaINNet

def main():
	"""Train the AdaIN decoder from the command line.

	Command-line arguments (see ``--help``): content/style image folders,
	number of epochs, batch size, an optional epoch to resume from, and
	``--cuda`` to use the GPU when one is available.

	Only the decoder's parameters are given to the optimizer (Adam, lr=1e-6).
	The objective is ``loss_content + 10 * loss_style``, built from the two
	values returned by ``AdaINNet``. A report is made every 100 iterations,
	plus once at the end of each epoch for any leftover batches. Each report
	appends the per-sample mean of the total, content and style losses since
	the previous report to ``losses`` as ``(iteration, total, content, style)``.

	After every epoch this writes:
	- a resumable checkpoint to ``./check_point/epoch_<n>.pth``;
	- the decoder weights to ``./weights/decoder_epoch_<n>.pth``;
	- the loss history to the text file ``losses``.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
	parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
	parser.add_argument('--epochs', type=int, default=1, help='Number of epoch')
	parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
	parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
	parser.add_argument('--cuda', action='store_true', help='Use CUDA')
	args = parser.parse_args()

	device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

	check_point_dir = './check_point/'
	weights_dir = './weights/'
	# torch.save does not create missing directories, so make sure they exist up front.
	os.makedirs(check_point_dir, exist_ok=True)
	os.makedirs(weights_dir, exist_ok=True)

	train_set = TrainSet(args.content_dir, args.style_dir)
	train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

	vgg_model = torch.load('vgg_normalized.pth')
	model = AdaINNet(vgg_model).to(device)

	decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
	losses = []
	iteration = 0

	if args.resume > 0:
		# map_location lets a checkpoint written on one device be resumed on another.
		states = torch.load(os.path.join(check_point_dir, 'epoch_%i.pth' % args.resume), map_location=device)
		model.decoder.load_state_dict(states['decoder'])
		decoder_optimizer.load_state_dict(states['decoder_optimizer'])
		losses = states['losses']
		iteration = states['iteration']

	desc_fmt = 'Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
	# Most recently reported averages; shown in the progress bar.
	last_avg = tuple(losses[-1][1:]) if losses else (0.0, 0.0, 0.0)
	# Running sums of per-sample losses since the last report, and how many samples they cover.
	total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
	total_num = 0

	for epoch in range(args.resume + 1, args.epochs + 1):
		print("Begin epoch: %i/%i" % (epoch, args.epochs))
		train_tqdm = tqdm(train_loader)
		train_tqdm.set_description(desc_fmt % last_avg)

		for content_batch, style_batch in train_tqdm:
			content_batch = content_batch.to(device)
			style_batch = style_batch.to(device)

			loss_content, loss_style = model(content_batch, style_batch)
			loss_scaled = loss_content + 10 * loss_style

			decoder_optimizer.zero_grad()
			loss_scaled.backward()
			decoder_optimizer.step()

			# Weight by batch size so reported values are per-sample means,
			# even when an epoch's last batch is smaller.
			batch_num = style_batch.size(0)
			total_loss += loss_scaled.item() * batch_num
			content_loss += loss_content.item() * batch_num
			style_loss += loss_style.item() * batch_num
			total_num += batch_num

			if iteration % 100 == 0 and iteration > 0:
				last_avg = (total_loss / total_num, content_loss / total_num, style_loss / total_num)
				train_tqdm.set_description(desc_fmt % last_avg)
				losses.append((iteration,) + last_avg)
				total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
				total_num = 0

			iteration += 1

		# Report any batches seen since the last 100-iteration report so the end of
		# every epoch is recorded; `iteration - 1` is the last iteration that ran.
		if total_num > 0:
			last_avg = (total_loss / total_num, content_loss / total_num, style_loss / total_num)
			losses.append((iteration - 1,) + last_avg)
			total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
			total_num = 0

		print('Finished epoch: %i/%i' % (epoch, args.epochs))

		states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
			'losses': losses, 'iteration': iteration}
		torch.save(states, os.path.join(check_point_dir, 'epoch_%i.pth' % epoch))
		torch.save(model.decoder.state_dict(), os.path.join(weights_dir, 'decoder_epoch_%i.pth' % epoch))
		# np.savetxt rejects an empty history with a 4-field format string.
		if losses:
			np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')

if __name__ == '__main__':
	# Start training only when run as a script, not when this module is imported.
	main()