from __future__ import absolute_import, division, print_function import time, os import numpy as np from os.path import join import cv2 as cv from model import * from utls import image_normalization,visualize_result, tensor2image, cv_imshow,h5_writer from dataset_manager import DataLoader BUFFER_SIZE = 448 # tf.set_random_seed(1) class run_DexiNed(): def __init__(self, args): self.model_state= args.model_state self.args = args self.img_width=args.image_width self.img_height = args.image_height self.epochs = args.max_epochs self.bs = args.batch_size def train(self): # Validation and Train dataset generation train_data = DataLoader(data_name=self.args.data4train, arg=self.args) n_train =train_data.indices.size #data_cache["n_files"] val_data = DataLoader(data_name=self.args.data4train, arg=self.args, is_val=True) val_idcs = np.arange(val_data.indices.size) # Summary and checkpoint manager model_dir =self.args.model_name+'2'+self.args.data4train summary_dir = os.path.join('logs',model_dir) train_log_dir=os.path.join(summary_dir,'train') val_log_dir =os.path.join(summary_dir,'test') checkpoint_dir = os.path.join(self.args.checkpoint_dir,model_dir) epoch_ckpt_dir = checkpoint_dir + 'epochs' os.makedirs(epoch_ckpt_dir, exist_ok=True) os.makedirs(train_log_dir,exist_ok=True) os.makedirs(val_log_dir,exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) train_writer = tf.summary.create_file_writer(train_log_dir) val_writer = tf.summary.create_file_writer(val_log_dir) my_model = DexiNed(rgb_mean=self.args.rgbn_mean)#rgb_mean=self.args.rgbn_mean # accuracy = metrics.SparseCategoricalAccuracy() accuracy = metrics.BinaryAccuracy() accuracy_val = metrics.BinaryAccuracy() loss_bc = losses.BinaryCrossentropy() optimizer = optimizers.Adam( learning_rate=self.args.lr, beta_1=self.args.beta1) iter = 0 imgs_res_folder = os.path.join(self.args.output_dir, "current_training") os.makedirs(imgs_res_folder, exist_ok=True) global_loss = 1000. t_loss = [] ckpt_save_mode = "h5" tmp_lr = self.args.lr for epoch in range(self.args.max_epochs): # training t_loss = [] # if epoch in self.args.adjust_lr: tmp_lr=tmp_lr*0.1 optimizer.lr.assign(tmp_lr) for step, (x, y) in enumerate(train_data): with tf.GradientTape() as tape: pred = my_model(x, training=True) preds, loss = pre_process_binary_cross_entropy( loss_bc, pred, y, self.args, use_tf_loss=False) accuracy.update_state(y_true=y, y_pred=preds[-1]) gradients = tape.gradient(loss, my_model.trainable_variables) optimizer.apply_gradients(zip(gradients, my_model.trainable_variables)) # logging the current accuracy value so far. t_loss.append(loss.numpy()) if step % 10 == 0: print("Epoch:", epoch, "Step:", step, "Loss: %.4f" % loss.numpy(), "Accuracy: %.4f" % accuracy.result(), time.ctime()) if step % 10 == 0: # visualize preds img_test = 'Epoch: {0} Sample {1}/{2} Loss: {3}' \ .format(epoch, step, n_train // self.args.batch_size, loss.numpy()) vis_imgs = visualize_result( x=x[2], y=y[2], p=preds, img_title=img_test) cv.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs) if step % 20 == 0 and loss < global_loss: # 500 if epoch==0 and step==0: tmp_loss = np.array(t_loss) with train_writer.as_default(): tf.summary.scalar('loss', tmp_loss.mean(), step=epoch) tf.summary.scalar('accuracy', accuracy.result(), step=epoch) save_ckpt_path = os.path.join(checkpoint_dir, "DexiNedL_model.h5") Model.save_weights(my_model, save_ckpt_path, save_format='h5') global_loss = loss print("Model saved in: ", save_ckpt_path, "Current loss:", global_loss.numpy()) iter += 1 # global iteration t_loss = np.array(t_loss) # train summary if epoch!=0: with train_writer.as_default(): tf.summary.scalar('loss', t_loss.mean(), step=epoch) tf.summary.scalar('accuracy', accuracy.result(), step=epoch) Model.save_weights(my_model, os.path.join(epoch_ckpt_dir, "DexiNed{}_model.h5".format(str(epoch))), save_format=ckpt_save_mode) print("Epoch:", epoch, "Model saved in Loss: ", t_loss.mean()) # validation t_val_loss = [] for i, (x_val, y_val) in enumerate(val_data): pred_val = my_model(x_val) v_logits, V_loss = pre_process_binary_cross_entropy( loss_bc, pred_val, y_val, self.args, use_tf_loss=False) accuracy_val.update_state(y_true=y_val, y_pred=v_logits[-1]) t_val_loss.append(V_loss.numpy()) if i == 7: break val_acc = accuracy_val.result() t_val_loss = np.array(t_val_loss) print("Epoch(validation):", epoch, "Val loss: ", t_val_loss.mean(), "Accuracy: ", val_acc.numpy()) # validation summary with val_writer.as_default(): tf.summary.scalar('loss', t_val_loss.mean(), step=epoch) tf.summary.scalar('accuracy', val_acc.numpy(), step=epoch) # Reset metrics every epoch accuracy.reset_states() accuracy_val.reset_states() my_model.summary() def test(self): # Test dataset generation test_data = DataLoader(data_name=self.args.data4test, arg=self.args) n_test = test_data.indices.size # data_cache["n_files"] optimizer = tf.keras.optimizers.Adam( learning_rate=self.args.lr, beta_1=self.args.beta1) my_model = DexiNed(rgb_mean=self.args.rgbn_mean) input_shape = test_data.input_shape my_model.build(input_shape=input_shape) # rgb_mean=self.args.rgbn_mean checkpoit_dir = os.path.join(self.args.checkpoint_dir, self.args.model_name + "2" + self.args.data4train) my_model.load_weights(os.path.join(checkpoit_dir, self.args.checkpoint)) result_dir = os.path.join( self.args.output_dir, self.args.model_name + '-' + self.args.data4train + "2" + self.args.data4test) os.makedirs(result_dir, exist_ok=True) if self.args.scale is not None: scl = self.args.scale save_dir = ['fuse_'+str(scl), 'avrg_'+str(scl), 'h5_'+str(scl)] else: save_dir = ['fuse', 'avrg', 'h5'] save_dirs = [] for tmp_dir in save_dir: os.makedirs(os.path.join(result_dir, tmp_dir), exist_ok=True) save_dirs.append(os.path.join(result_dir, tmp_dir)) total_time = [] data_names = test_data.imgs_name data_shape = test_data.imgs_shape k = 0 for step, (x, y) in enumerate(test_data): start_time = time.time() preds = my_model(x, training=False) tmp_time = time.time() - start_time total_time.append(tmp_time) preds = [tf.sigmoid(i).numpy() for i in preds] all_preds = np.array(preds) for i in range(all_preds.shape[1]): tmp_name = data_names[k] tmp_name, _ = os.path.splitext(tmp_name) tmp_shape = data_shape[k] tmp_preds = all_preds[:, i, ...] tmp_av = np.expand_dims(tmp_preds.mean(axis=0), axis=0) tmp_preds = np.concatenate((tmp_preds, tmp_av), axis=0) res_preds = [] for j in range(tmp_preds.shape[0]): tmp_pred = tmp_preds[j, ...] tmp_pred[tmp_pred < 0.0] = 0.0 tmp_pred = cv.bitwise_not(np.uint8(image_normalization(tmp_pred))) h, w = tmp_pred.shape[:2] if h != tmp_shape[0] or w != tmp_shape[1]: tmp_pred = cv.resize(tmp_pred, (tmp_shape[1], tmp_shape[0])) res_preds.append(tmp_pred) n_save =len(tmp_preds)-2 for idx in range(len(save_dirs) - 1): s_dir = save_dirs[idx] tmp = res_preds[n_save + idx] cv.imwrite(join(s_dir, tmp_name + '.png'), tmp) h5_writer(path=join(save_dirs[-1], tmp_name + '.h5'), vars=np.squeeze(res_preds)) print("saved:", join(save_dirs[-1], tmp_name + '.h5'), tmp_preds.shape) k += 1 # tmp_name = data_names[step][:-3]+"png" # tmp_shape = data_shape[step] # tmp_path = os.path.join(result_dir,tmp_name) # tensor2image(preds[-1].numpy(), img_path =tmp_path,img_shape=tmp_shape) total_time = np.array(total_time) print('-------------------------------------------------') print("End testing in: ", self.args.data4test) print("Batch size: ", self.args.test_bs) print("Time average per image: ", total_time.mean(), "secs") print("Total time: ", total_time.sum(), "secs") print('-------------------------------------------------')