File size: 3,804 Bytes
7930ce0
 
 
 
 
 
 
 
 
 
 
 
404f3a4
7930ce0
 
 
 
 
 
4f6c34a
7930ce0
 
4f6c34a
 
7930ce0
 
4f6c34a
 
7930ce0
 
 
4f6c34a
7930ce0
 
 
 
4f6c34a
7930ce0
 
 
 
 
 
 
 
 
 
 
 
 
404f3a4
7930ce0
 
 
404f3a4
 
7930ce0
 
 
4f6c34a
7930ce0
 
4f6c34a
 
7930ce0
 
4f6c34a
404f3a4
 
 
 
 
 
 
 
7930ce0
404f3a4
 
 
 
 
7930ce0
404f3a4
7930ce0
404f3a4
 
 
 
 
 
 
 
 
7930ce0
 
 
404f3a4
 
 
 
 
7930ce0
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import TrainSet
from AdaIN import AdaINNet

def _parse_args():
	"""Build and parse the command-line options for decoder training."""
	parser = argparse.ArgumentParser()
	parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
	parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
	parser.add_argument('--epochs', type=int, default=10, help='Number of epoch')
	parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
	parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
	parser.add_argument('--cuda', action='store_true', help='Use CUDA')
	# Previously hard-coded values; defaults keep the old behavior.
	parser.add_argument('--lr', type=float, default=1e-6, help='Decoder learning rate')
	parser.add_argument('--style_weight', type=float, default=10.0, help='Weight of the style loss in the total loss')
	return parser.parse_args()


def _train_one_epoch(model, train_loader, optimizer, device, style_weight, iteration):
	"""Run one pass over train_loader, stepping the optimizer after every batch.

	Args:
		model: callable returning (content_loss, style_loss) tensors for a
			(content_batch, style_batch) pair.
		train_loader: iterable of (content_batch, style_batch) pairs.
		optimizer: optimizer over the parameters being trained.
		device: torch device the batches are moved to.
		style_weight: multiplier applied to the style loss.
		iteration: global step counter before this epoch.

	Returns:
		(iteration, mean_total, mean_content, mean_style): the updated global
		step counter and the losses averaged over this epoch's batches
		(all 0.0 if the loader yielded nothing).
	"""
	sum_total, sum_content, sum_style = 0.0, 0.0, 0.0
	num_batches = 0

	train_tqdm = tqdm(train_loader)
	for content_batch, style_batch in train_tqdm:
		optimizer.zero_grad()

		content_batch = content_batch.to(device)
		style_batch = style_batch.to(device)

		# Feed forward and compute loss
		loss_content, loss_style = model(content_batch, style_batch)
		loss_scaled = loss_content + style_weight * loss_style

		# Gradient descent
		loss_scaled.backward()
		optimizer.step()

		total_loss = loss_scaled.item()
		content_loss = loss_content.item()
		style_loss = loss_style.item()
		sum_total += total_loss
		sum_content += content_loss
		sum_style += style_loss
		num_batches += 1
		iteration += 1

		train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))

	if num_batches == 0:
		return iteration, 0.0, 0.0, 0.0
	return iteration, sum_total / num_batches, sum_content / num_batches, sum_style / num_batches


def _save_checkpoint(model, optimizer, losses, iteration, epoch, check_point_dir, weights_dir):
	"""Persist a resumable checkpoint, the bare decoder weights and the loss log.

	Writes <check_point_dir>/epoch_<epoch>.pth (read back by --resume),
	<weights_dir>/decoder_epoch_<epoch>.pth, and a text file named 'losses'
	in the working directory with one 'iteration,total,content,style' row per
	recorded epoch.
	"""
	states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': optimizer.state_dict(),
		'losses': losses, 'iteration': iteration}
	torch.save(states, os.path.join(check_point_dir, 'epoch_%i.pth' % epoch))
	torch.save(model.decoder.state_dict(), os.path.join(weights_dir, 'decoder_epoch_%i.pth' % epoch))
	np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')


def main():
	"""Train the AdaIN decoder on content/style image folders.

	Only the decoder's parameters are handed to the optimizer. After every
	epoch a checkpoint (decoder + optimizer state, loss history, global
	iteration) is written so training can be continued with --resume N.
	"""
	args = _parse_args()

	device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

	check_point_dir = './check_point/'
	weights_dir = './weights/'
	os.makedirs(check_point_dir, exist_ok=True)
	os.makedirs(weights_dir, exist_ok=True)

	# Prepare Training dataset
	train_set = TrainSet(args.content_dir, args.style_dir)
	train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

	# load vgg19 weights; map_location lets GPU-saved files load on CPU-only hosts
	vgg_model = torch.load('vgg_normalized.pth', map_location=device)
	model = AdaINNet(vgg_model).to(device)
	decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=args.lr)

	# Per-epoch history of (iteration, mean_total, mean_content, mean_style)
	losses = []
	iteration = 0

	# If resume training, load states
	if args.resume > 0:
		states = torch.load(os.path.join(check_point_dir, 'epoch_' + str(args.resume) + '.pth'), map_location=device)
		model.decoder.load_state_dict(states['decoder'])
		decoder_optimizer.load_state_dict(states['decoder_optimizer'])
		losses = states['losses']
		iteration = states['iteration']

	for epoch in range(args.resume + 1, args.epochs + 1):
		print("Begin epoch: %i/%i" % (epoch, int(args.epochs)))

		iteration, total_loss, content_loss, style_loss = _train_one_epoch(
			model, train_loader, decoder_optimizer, device, args.style_weight, iteration)
		# Record this epoch's averages once it has finished (not stale values at the next epoch start)
		losses.append((iteration, total_loss, content_loss, style_loss))

		print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))

		_save_checkpoint(model, decoder_optimizer, losses, iteration, epoch, check_point_dir, weights_dir)


if __name__ == '__main__':
	main()