import tensorflow as tf from PIL import Image from models.dexined import dexined # from models.dexinedBs import dexined from utls.utls import * from utls.dataset_manager import data_parser,get_single_image,\ get_testing_batch class m_tester(): def __init__(self, args): self.args = args def setup(self, session): try: if self.args.model_name=='DXN': self.model = dexined(self.args) else: print_error("Error setting model, {}".format(self.args.model_name)) if self.args.trained_model_dir is None: meta_model_file = os.path.join( self.args.checkpoint_dir, os.path.join( self.args.model_name + '_' + self.args.train_dataset, os.path.join('train', '{}-{}'.format(self.args.model_name, self.args.test_snapshot)))) else: meta_model_file = os.path.join( self.args.checkpoint_dir, os.path.join( self.args.model_name + '_' + self.args.train_dataset, os.path.join(self.args.trained_model_dir, '{}-{}'.format(self.args.model_name, self.args.test_snapshot)))) saver = tf.train.Saver() saver.restore(session, meta_model_file) print_info('Done restoring DexiNed model from {}'.format(meta_model_file)) except Exception as err: print_error('Error setting up DexiNed traied model, {}'.format(err)) def run(self, session): self.model.setup_testing(session) if self.args.use_dataset: test_data= data_parser(self.args) n_data = len(test_data[1]) else: test_data=get_single_image(self.args) n_data = len(test_data) print_info('Writing PNGs at {}'.format(self.args.base_dir_results)) if self.args.batch_size_test==1 and self.args.use_dataset: for i in range(n_data): im, em, file_name = get_testing_batch(self.args, [test_data[0][test_data[1][i]], test_data[1][i]], use_batch=False) self.img_info = file_name edgemap = session.run(self.model.predictions, feed_dict={self.model.images: [im]}) self.save_egdemaps(edgemap, single_image=True) print_info('Done testing {}, {}'.format(self.img_info[0], self.img_info[1])) # for individual images elif self.args.batch_size_test==1 and not self.args.use_dataset: for i in range(n_data): im, file_name = get_single_image(self.args,file_path=test_data[i]) self.img_info = file_name edgemap = session.run(self.model.predictions, feed_dict={self.model.images: [im]}) self.save_egdemaps(edgemap, single_image=True) print_info('Done testing {}, {}'.format(self.img_info[0], self.img_info[1])) def save_egdemaps(self, em_maps, single_image=False): """ save_edgemaps descriptios :param em_maps: :param single_image: save predicted edge maps """ result_dir = 'DexiNed_'+self.args.train_dataset+'2'+self.args.test_dataset if self.args.base_dir_results is None: res_dir = os.path.join('../results', result_dir) else: res_dir = os.path.join(self.args.base_dir_results,result_dir) gt_dir = os.path.join(res_dir,'gt') all_dir = os.path.join(res_dir,'pred-h5') resf_dir = os.path.join(res_dir,'pred-f') resa_dir = os.path.join(res_dir,'pred-a') os.makedirs(resf_dir, exist_ok=True) os.makedirs(resa_dir, exist_ok=True) os.makedirs(gt_dir, exist_ok=True) os.makedirs(all_dir, exist_ok=True) if single_image: em_maps = [e[0] for e in em_maps] em_a = np.mean(np.array(em_maps), axis=0) em_maps = em_maps + [em_a ] em = em_maps[len(em_maps)-2] em[em < self.args.testing_threshold] = 0.0 em_a[em_a < self.args.testing_threshold] = 0.0 em = 255.0 * (1.0 - em) em_a = 255.0 * (1.0 - em_a) em = np.tile(em, [1, 1, 3]) em_a = np.tile(em_a, [1, 1, 3]) em = Image.fromarray(np.uint8(em)) em_a = Image.fromarray(np.uint8(em_a)) tmp_name = os.path.basename(self.img_info[0]) tmp_name = tmp_name[:-4] tmp_size = self.img_info[-1][:2] tmp_size = (tmp_size[1],tmp_size[0]) em_f = em.resize(tmp_size) em_a = em_a.resize(tmp_size) em_f.save(os.path.join(resf_dir, tmp_name + '.png')) em_a.save(os.path.join(resa_dir, tmp_name + '.png')) em_maps =tensor_norm_01(em_maps) save_variable_h5(os.path.join(all_dir, tmp_name + '.h5'), np.float16(em_maps)) else: pass