import os
import time

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from models.dexined import dexined
# from models.dexinedBs import dexined
from utls.utls import *  # print_info / print_warning / print_error helpers
from utls.dataset_manager import (data_parser, get_training_batch,
                                  get_validation_batch, visualize_result)
class m_trainer():
    def __init__(self, args):
        self.init = True
        self.args = args
    def setup(self):
        try:
            if self.args.model_name == 'DXN':
                self.model = dexined(self.args)
            else:
                # unknown model name: flag the failure so run() bails out early
                print_error("Error setting model, {}".format(self.args.model_name))
                self.init = False
                return
            print_info("DL model set")
        except Exception as err:
            print_error("Error setting up DL model, {}".format(err))
            self.init = False
def run(self, sess):
if not self.init:
return
train_data = data_parser(self.args)
self.model.setup_training(sess)
        if self.args.lr_scheduler is None:
            # float32, not float16: half precision here degrades the optimizer updates
            learning_rate = tf.constant(self.args.learning_rate, dtype=tf.float32)
        else:
            raise NotImplementedError(
                'Learning rate scheduler type [{}] is not implemented'.format(
                    self.args.lr_scheduler))
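        # A hedged sketch (not in the original code) of how a decay schedule
        # could replace the constant rate above, using TF1's built-in
        # exponential decay; decay_steps and decay_rate are illustrative
        # assumptions:
        #
        #   global_step = tf.Variable(0, trainable=False, dtype=tf.int64)
        #   learning_rate = tf.compat.v1.train.exponential_decay(
        #       self.args.learning_rate, global_step,
        #       decay_steps=10000, decay_rate=0.96, staircase=True)
        #   trainG = opt.minimize(self.model.loss, global_step=global_step)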
opt = tf.compat.v1.train.AdamOptimizer(learning_rate)
        trainG = opt.minimize(self.model.loss)  # single end-to-end training op, as in HED
saver = tf.compat.v1.train.Saver(max_to_keep=7)
sess.run(tf.compat.v1.global_variables_initializer())
        # restore a previous training run, if requested (e.g. reuse the
        # BIPED-pretrained weights when training on another dataset)
        if self.args.use_previous_trained:
            model_path = os.path.join(self.args.checkpoint_dir,
                                      self.args.model_name + '_' + self.args.train_dataset,
                                      'train')
            if not os.path.exists(model_path) or len(os.listdir(model_path)) == 0:
                ini = 0
                maxi = self.args.max_iterations + 1
                print_warning('There is no previously trained data for the current model...')
                print_warning('*** The training process is starting from scratch ***')
            else:
                # restore from the most recent checkpoint
                last_ckpt = tf.train.latest_checkpoint(model_path)
                saver.restore(sess, last_ckpt)
                ini = self.args.max_iterations
                maxi = ini + self.args.max_iterations + 1  # continue for another max_iterations steps
                print_info('--> Previous model restored successfully: {}'.format(last_ckpt))
else:
print_warning('*** The training process is starting from scratch ***')
ini = 0
maxi = ini + self.args.max_iterations
        prev_loss = 1000.  # lowest training loss seen so far (sentinel start value)
        prev_val = None
# directories for checkpoints
checkpoint_dir = os.path.join(
self.args.checkpoint_dir, self.args.model_name + '_' + self.args.train_dataset,
self.args.model_state)
os.makedirs(checkpoint_dir,exist_ok=True)
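        # e.g. <checkpoint_dir>/DXN_BIPED/train/ when model_state == 'train'
        # (illustrative path; the actual names come from the parsed arguments)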
fig = plt.figure()
for idx in range(ini, maxi):
x_batch, y_batch,_ = get_training_batch(self.args, train_data)
            # request a full trace so run_metadata gets populated for
            # TensorBoard; note this profiling adds per-step overhead
            run_options = tf.compat.v1.RunOptions(
                trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_metadata = tf.compat.v1.RunMetadata()
            _, summary, loss, pred_maps = sess.run(
                [trainG, self.model.merged_summary, self.model.loss, self.model.predictions],
                feed_dict={self.model.images: x_batch, self.model.edgemaps: y_batch},
                options=run_options, run_metadata=run_metadata)
            if idx % 5 == 0:
                self.model.train_writer.add_run_metadata(run_metadata,
                                                         'step{:06}'.format(idx))
                self.model.train_writer.add_summary(summary, idx)
                print(time.ctime(), '[{}/{}]'.format(idx, maxi),
                      'TRAINING loss: %.5f' % loss, 'prev_loss: %.5f' % prev_loss)
            # checkpointing: save whenever the training loss improves,
            # and again at every save_interval step
            if prev_loss > loss:
                saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name), global_step=idx)
                prev_loss = loss
                print("Weights saved at the lowest loss", idx, " Current loss", prev_loss)
            if idx % self.args.save_interval == 0:
                saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name), global_step=idx)
                prev_loss = loss
                print("Weights saved at the interval", idx, " Current loss", prev_loss)
            # ********* validation **********
            if (idx + 1) % self.args.val_interval == 0:
                pause_show = 0.01
                # visualize sample 2 of the batch: input image, ground-truth
                # edge map, and every intermediate prediction map
                imgs_list = []
                img = x_batch[2][:, :, 0:3]
                gt_mp = y_batch[2]
                imgs_list.append(img)
                imgs_list.append(gt_mp)
                for i in range(len(pred_maps)):
                    tmp = pred_maps[i][2, ...]
                    imgs_list.append(tmp)
                vis_imgs = visualize_result(imgs_list, self.args)
                fig.suptitle("Iteration: " + str(idx + 1) + " Loss: " + '%.5f' % loss + " training")
                fig.add_subplot(1, 1, 1)
                plt.imshow(np.uint8(vis_imgs))
                print("Evaluation in progress...")
                plt.draw()
                plt.pause(pause_show)
im, em, _ = get_validation_batch(self.args, train_data)
summary, error, pred_val = sess.run(
[self.model.merged_summary, self.model.error, self.model.fuse_output],
feed_dict={self.model.images: im, self.model.edgemaps: em})
                if error <= 0.08:
                    # empirically chosen validation-error threshold for extra snapshots
                    saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name), global_step=idx)
                    prev_loss = loss
                    print("Parameters saved at the validation stage (error <= 0.08):", error)
                self.model.val_writer.add_summary(summary, idx)
                print_info(('[{}/{}]'.format(idx, self.args.max_iterations),
                            'VALIDATION error: %0.5f' % error, 'pError: %.5f' % prev_loss))
            if (idx + 1) % (self.args.val_interval * 150) == 0:
                # periodically rebuild the figure, which keeps long runs from
                # accumulating matplotlib state
                print('Updating visualisation')
                plt.close()
                fig = plt.figure()
        saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name), global_step=idx)
        print("Final weights saved", idx, " Current loss", loss)
        self.model.train_writer.close()
        self.model.val_writer.close()
        sess.close()
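

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of this file: the repository's
# real entry point handles argument parsing; the flag names below simply
# mirror the attributes read above, and the defaults are illustrative).
#
#   import argparse
#   import tensorflow as tf
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--model_name', default='DXN')
#   parser.add_argument('--learning_rate', type=float, default=1e-4)
#   parser.add_argument('--max_iterations', type=int, default=150000)
#   # ...plus checkpoint_dir, dataset_name, train_dataset, model_state,
#   # save_interval, val_interval, lr_scheduler, use_previous_trained
#   args = parser.parse_args()
#
#   trainer = m_trainer(args)
#   trainer.setup()
#   with tf.compat.v1.Session() as sess:
#       trainer.run(sess)
# ---------------------------------------------------------------------------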