# DexiNed-TF2/run_model.py
from __future__ import absolute_import, division, print_function
import time, os
import numpy as np
from os.path import join
import cv2 as cv
import tensorflow as tf
from tensorflow.keras import Model, losses, metrics, optimizers

from model import *
from utls import image_normalization, visualize_result, tensor2image, cv_imshow, h5_writer
from dataset_manager import DataLoader
BUFFER_SIZE = 448
# tf.set_random_seed(1)
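
# Runner for the DexiNed TF2 model: train() fits the network while logging TensorBoard
# summaries and saving H5 checkpoints; test() restores a checkpoint and writes edge-map
# predictions for the chosen test set.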
class run_DexiNed():

    def __init__(self, args):
        self.model_state = args.model_state
        self.args = args
        self.img_width = args.image_width
        self.img_height = args.image_height
        self.epochs = args.max_epochs
        self.bs = args.batch_size
    def train(self):
        # Validation and Train dataset generation
        train_data = DataLoader(data_name=self.args.data4train, arg=self.args)
        n_train = train_data.indices.size  # data_cache["n_files"]
        val_data = DataLoader(data_name=self.args.data4train,
                              arg=self.args, is_val=True)
        val_idcs = np.arange(val_data.indices.size)
        # Summary and checkpoint manager
        model_dir = self.args.model_name + '2' + self.args.data4train
        summary_dir = os.path.join('logs', model_dir)
        train_log_dir = os.path.join(summary_dir, 'train')
        val_log_dir = os.path.join(summary_dir, 'test')
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, model_dir)
        epoch_ckpt_dir = checkpoint_dir + 'epochs'
        os.makedirs(epoch_ckpt_dir, exist_ok=True)
        os.makedirs(train_log_dir, exist_ok=True)
        os.makedirs(val_log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        train_writer = tf.summary.create_file_writer(train_log_dir)
        val_writer = tf.summary.create_file_writer(val_log_dir)
        my_model = DexiNed(rgb_mean=self.args.rgbn_mean)  # rgb_mean=self.args.rgbn_mean
        # accuracy = metrics.SparseCategoricalAccuracy()
        accuracy = metrics.BinaryAccuracy()
        accuracy_val = metrics.BinaryAccuracy()
        loss_bc = losses.BinaryCrossentropy()
        optimizer = optimizers.Adam(
            learning_rate=self.args.lr, beta_1=self.args.beta1)
        iter = 0
        imgs_res_folder = os.path.join(self.args.output_dir, "current_training")
        os.makedirs(imgs_res_folder, exist_ok=True)
        global_loss = 1000.
        t_loss = []
        ckpt_save_mode = "h5"
        tmp_lr = self.args.lr
        for epoch in range(self.args.max_epochs):
            # training
            t_loss = []
            # decay the learning rate only at the epochs listed in args.adjust_lr
            if epoch in self.args.adjust_lr:
                tmp_lr = tmp_lr * 0.1
                optimizer.lr.assign(tmp_lr)
            for step, (x, y) in enumerate(train_data):
                with tf.GradientTape() as tape:
                    pred = my_model(x, training=True)
                    preds, loss = pre_process_binary_cross_entropy(
                        loss_bc, pred, y, self.args, use_tf_loss=False)
                accuracy.update_state(y_true=y, y_pred=preds[-1])
                gradients = tape.gradient(loss, my_model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, my_model.trainable_variables))
                # logging the current accuracy value so far.
                t_loss.append(loss.numpy())
                if step % 10 == 0:
                    print("Epoch:", epoch, "Step:", step, "Loss: %.4f" % loss.numpy(),
                          "Accuracy: %.4f" % accuracy.result(), time.ctime())
                if step % 10 == 0:
                    # visualize preds
                    img_test = 'Epoch: {0} Sample {1}/{2} Loss: {3}' \
                        .format(epoch, step, n_train // self.args.batch_size, loss.numpy())
                    vis_imgs = visualize_result(
                        x=x[2], y=y[2], p=preds, img_title=img_test)
                    cv.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs)
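                # checkpoint whenever the running loss improves (checked every 20 steps)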
                if step % 20 == 0 and loss < global_loss:  # 500
                    if epoch == 0 and step == 0:
                        tmp_loss = np.array(t_loss)
                        with train_writer.as_default():
                            tf.summary.scalar('loss', tmp_loss.mean(), step=epoch)
                            tf.summary.scalar('accuracy', accuracy.result(), step=epoch)
                    save_ckpt_path = os.path.join(checkpoint_dir, "DexiNedL_model.h5")
                    Model.save_weights(my_model, save_ckpt_path, save_format='h5')
                    global_loss = loss
                    print("Model saved in: ", save_ckpt_path, "Current loss:", global_loss.numpy())
                iter += 1  # global iteration
            t_loss = np.array(t_loss)
            # train summary
            if epoch != 0:
                with train_writer.as_default():
                    tf.summary.scalar('loss', t_loss.mean(), step=epoch)
                    tf.summary.scalar('accuracy', accuracy.result(), step=epoch)
            Model.save_weights(my_model, os.path.join(epoch_ckpt_dir, "DexiNed{}_model.h5".format(str(epoch))),
                               save_format=ckpt_save_mode)
            print("Epoch:", epoch, "Model saved. Loss:", t_loss.mean())
            # validation
            t_val_loss = []
            for i, (x_val, y_val) in enumerate(val_data):
                pred_val = my_model(x_val)
                v_logits, V_loss = pre_process_binary_cross_entropy(
                    loss_bc, pred_val, y_val, self.args, use_tf_loss=False)
                accuracy_val.update_state(y_true=y_val, y_pred=v_logits[-1])
                t_val_loss.append(V_loss.numpy())
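                # validate on just the first 8 batches to keep the per-epoch check cheap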
                if i == 7:
                    break
            val_acc = accuracy_val.result()
            t_val_loss = np.array(t_val_loss)
            print("Epoch(validation):", epoch, "Val loss: ", t_val_loss.mean(),
                  "Accuracy: ", val_acc.numpy())
            # validation summary
            with val_writer.as_default():
                tf.summary.scalar('loss', t_val_loss.mean(), step=epoch)
                tf.summary.scalar('accuracy', val_acc.numpy(), step=epoch)
            # Reset metrics every epoch
            accuracy.reset_states()
            accuracy_val.reset_states()
        my_model.summary()
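
    # test(): restore trained weights and write the fused and averaged edge maps as PNGs,
    # plus all output maps as an H5 file, for every image in the test split.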
    def test(self):
        # Test dataset generation
        test_data = DataLoader(data_name=self.args.data4test, arg=self.args)
        n_test = test_data.indices.size  # data_cache["n_files"]
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.args.lr, beta_1=self.args.beta1)
        my_model = DexiNed(rgb_mean=self.args.rgbn_mean)
        input_shape = test_data.input_shape
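        # a subclassed Keras model has no variables until it is built, so build it with
        # the dataset's input shape before loading the H5 weights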
        my_model.build(input_shape=input_shape)
        checkpoint_dir = os.path.join(self.args.checkpoint_dir,
                                      self.args.model_name + "2" + self.args.data4train)
        my_model.load_weights(os.path.join(checkpoint_dir, self.args.checkpoint))
        result_dir = os.path.join(
            self.args.output_dir,
            self.args.model_name + '-' + self.args.data4train + "2" + self.args.data4test)
        os.makedirs(result_dir, exist_ok=True)
        if self.args.scale is not None:
            scl = self.args.scale
            save_dir = ['fuse_' + str(scl), 'avrg_' + str(scl), 'h5_' + str(scl)]
        else:
            save_dir = ['fuse', 'avrg', 'h5']
        save_dirs = []
        for tmp_dir in save_dir:
            os.makedirs(os.path.join(result_dir, tmp_dir), exist_ok=True)
            save_dirs.append(os.path.join(result_dir, tmp_dir))
        total_time = []
        data_names = test_data.imgs_name
        data_shape = test_data.imgs_shape
        k = 0
        for step, (x, y) in enumerate(test_data):
            start_time = time.time()
            preds = my_model(x, training=False)
            tmp_time = time.time() - start_time
            total_time.append(tmp_time)
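            # the network returns logits for each side output; squash them to [0, 1] with a sigmoid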
            preds = [tf.sigmoid(i).numpy() for i in preds]
            all_preds = np.array(preds)
            for i in range(all_preds.shape[1]):
                tmp_name = data_names[k]
                tmp_name, _ = os.path.splitext(tmp_name)
                tmp_shape = data_shape[k]
                tmp_preds = all_preds[:, i, ...]
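                # append the mean over all side outputs as an extra "average" edge map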
                tmp_av = np.expand_dims(tmp_preds.mean(axis=0), axis=0)
                tmp_preds = np.concatenate((tmp_preds, tmp_av), axis=0)
                res_preds = []
                for j in range(tmp_preds.shape[0]):
                    tmp_pred = tmp_preds[j, ...]
                    tmp_pred[tmp_pred < 0.0] = 0.0
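                    # scale to 8-bit and invert so edges appear dark on a white background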
                    tmp_pred = cv.bitwise_not(np.uint8(image_normalization(tmp_pred)))
                    h, w = tmp_pred.shape[:2]
                    if h != tmp_shape[0] or w != tmp_shape[1]:
                        tmp_pred = cv.resize(tmp_pred, (tmp_shape[1], tmp_shape[0]))
                    res_preds.append(tmp_pred)
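                # write the fused map and the average map as PNGs; every map goes into the .h5 file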
                n_save = len(tmp_preds) - 2
                for idx in range(len(save_dirs) - 1):
                    s_dir = save_dirs[idx]
                    tmp = res_preds[n_save + idx]
                    cv.imwrite(join(s_dir, tmp_name + '.png'), tmp)
                h5_writer(path=join(save_dirs[-1], tmp_name + '.h5'),
                          vars=np.squeeze(res_preds))
                print("saved:", join(save_dirs[-1], tmp_name + '.h5'), tmp_preds.shape)
                k += 1
            # tmp_name = data_names[step][:-3]+"png"
            # tmp_shape = data_shape[step]
            # tmp_path = os.path.join(result_dir,tmp_name)
            # tensor2image(preds[-1].numpy(), img_path =tmp_path,img_shape=tmp_shape)
        total_time = np.array(total_time)
        print('-------------------------------------------------')
        print("End testing in: ", self.args.data4test)
        print("Batch size: ", self.args.test_bs)
        print("Time average per image: ", total_time.mean(), "secs")
        print("Total time: ", total_time.sum(), "secs")
        print('-------------------------------------------------')
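
# Example usage (sketch): this class is normally driven by an argparse-based entry point
# (e.g. the repo's main script), which builds an `args` namespace with fields such as
# model_state, data4train, data4test, checkpoint_dir, output_dir, lr, beta1, etc.
#
#   runner = run_DexiNed(args)
#   if args.model_state == 'train':
#       runner.train()
#   else:
#       runner.test()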