Spaces:
Configuration error
Configuration error
File size: 6,480 Bytes
89c5d90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
""" DexiNed main script
This code is based on DexiNed (Dense Extreme Inception Network for Edge Detection).
Review the function config_model() and set any parameters there before training or
testing the model.
"""
__author__ = "Xavier Soria Poma, CVC-UAB"
__email__ = "[email protected] / [email protected]"
__homepage__="www.cvc.uab.cat/people/xsoria"
__credits__=['DexiNed']
__copyright__ = "MIT License [see LICENSE for details]"#"Copyright 2019, CIMI"
import sys
import argparse
import tensorflow as tf
import utls.dataset_manager as dm
from train import m_trainer
from test import m_tester
import platform
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy.  This converter parses
    the text instead, so boolean flags can actually be switched off.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def config_model(argv=None):
    """Build the command-line parser and return the parsed configuration.

    Args:
        argv: optional list of argument strings.  When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding the dataset, model, image and test options
        declared below.
    """
    in_linux = platform.system() == "Linux"
    base_dir = "/opt/dataset/" if in_linux else "../../dataset/"
    parser = argparse.ArgumentParser(description='Basic details to run HED')

    # dataset config
    parser.add_argument('--train_dataset', default='BIPED', choices=['BIPED', 'BSDS'])
    # 'CLASSIC' is the default, so it must also be accepted when given explicitly.
    parser.add_argument('--test_dataset', default='CLASSIC',
                        choices=['BIPED', 'BSDS', 'MULTICUE', 'NYUD', 'PASCAL',
                                 'CID', 'DCD', 'CLASSIC'])
    parser.add_argument('--dataset_dir', default=base_dir, type=str)  # default:'/opt/dataset/'
    parser.add_argument('--dataset_augmented', default=True, type=_str2bool)
    # BSDS train_pair.lst, SSMIHD train_rgb_pair.lst/train_rgbn_pair.lst
    parser.add_argument('--train_list', default='train_rgb.lst', type=str)
    # for NYUD&BSDS: test_pair.lst, biped msi_test.lst/test_rgb.lst
    parser.add_argument('--test_list', default='test_rgb.lst', type=str)
    parser.add_argument('--trained_model_dir', default='train', type=str)  # 'trainV2_RN'
    # SSMIHD_RGBN msi_valid_list.txt and msi_test_list.txt is for unified test
    parser.add_argument('--use_nir', default=False, type=_str2bool)
    # test: dataset=True single image=FALSE
    parser.add_argument('--use_dataset', default=False, type=_str2bool)

    # model config
    parser.add_argument('--model_state', default='train', choices=['train', 'test', 'None'])
    parser.add_argument('--model_name', default='DXN', choices=['DXN', 'XCP', 'None'])
    parser.add_argument('--use_v1', default=False, type=_str2bool)
    parser.add_argument('--model_purpose', default='edges',
                        choices=['edges', 'restoration', 'None'])
    parser.add_argument('--batch_size_train', default=8, type=int)
    parser.add_argument('--batch_size_val', default=8, type=int)
    parser.add_argument('--batch_size_test', default=1, type=int)
    parser.add_argument('--checkpoint_dir', default='checkpoints', type=str)
    parser.add_argument('--logs_dir', default='logs', type=str)
    parser.add_argument('--learning_rate', default=1e-4, type=float)  # 1e-4=0.0001
    parser.add_argument('--lr_scheduler', default=None, choices=[None, 'asce', 'desc'])
    parser.add_argument('--learning_rate_decay', default=0.1, type=float)
    parser.add_argument('--weight_decay', default=0.0002, type=float)
    parser.add_argument('--model_weights_path', default='vgg16_.npy')
    parser.add_argument('--train_split', default=0.9, type=float)
    parser.add_argument('--max_iterations', default=180000, type=int)
    parser.add_argument('--learning_decay_interval', default=25000, type=int)
    parser.add_argument('--loss_weights', default=1.0, type=float)
    parser.add_argument('--save_interval', default=20000, type=int)
    parser.add_argument('--val_interval', default=30, type=int)
    # None=upsampling with transp conv
    parser.add_argument('--use_subpixel', default=None, type=_str2bool)
    parser.add_argument('--deep_supervision', default=True, type=_str2bool)
    parser.add_argument('--target_regression', default=True, type=_str2bool)
    # nargs='+' keeps these list-valued when overridden on the command line;
    # without it a single scalar would replace the default list.
    # RGB means are [103.939,116.779,123.68]; the 4th value is for NIR pixels.
    parser.add_argument('--mean_pixel_values', nargs='+',
                        default=[103.939, 116.779, 123.68, 137.86], type=float)
    parser.add_argument('--channel_swap', nargs='+', default=[2, 1, 0], type=int)
    # The flag is spelled with a hyphen; argparse stores it as ``gpu_limit``.
    parser.add_argument('--gpu-limit', default=1.0, type=float)
    parser.add_argument('--use_trained_model', default=True, type=_str2bool)  # for vvg16
    parser.add_argument('--use_previous_trained', default=False, type=_str2bool)  # for training

    # image configuration
    parser.add_argument('--image_width', default=512, type=int)   # NYUD=560 BIPED=1280
    parser.add_argument('--image_height', default=512, type=int)  # NYUD=425 BIPED=720
    parser.add_argument('--n_channels', default=3, type=int)

    # test config
    # DexiNedv1=149736, DexiNedv2=149999 (BSDS: 101179)
    parser.add_argument('--test_snapshot', default=149999, type=int)
    parser.add_argument('--testing_threshold', default=0.0, type=float)
    # default: '/opt/results/edges'
    parser.add_argument('--base_dir_results', default='results/edges', type=str)

    return parser.parse_args(argv)
def get_session(gpu_fraction, num_threads=None):
    """Create a TF1-style session limited to a fraction of the GPU memory.

    Args:
        gpu_fraction: value passed to ``per_process_gpu_memory_fraction``.
        num_threads: optional value for ``intra_op_parallelism_threads``.
            When falsy (the default) the setting is left to TensorFlow.

    Returns:
        A ``tf.compat.v1.Session`` built from the resulting ``ConfigProto``.
    """
    # Use the compat.v1 namespace throughout: the bare ``tf.GPUOptions`` name
    # is not available under TensorFlow 2.
    gpu_options = tf.compat.v1.GPUOptions(
        per_process_gpu_memory_fraction=gpu_fraction)
    # gpu_options is always passed on; otherwise the requested memory
    # fraction would be silently ignored.
    config_kwargs = {'gpu_options': gpu_options}
    if num_threads:
        config_kwargs['intra_op_parallelism_threads'] = num_threads
    return tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(**config_kwargs))
def main(args):
    """Run training or testing according to the parsed configuration.

    Args:
        args: namespace returned by ``config_model()``.  Reads
            ``dataset_augmented``, ``model_state``, ``test_dataset``,
            ``image_width`` and ``gpu_limit``, and forwards the whole
            namespace to ``m_trainer`` / ``m_tester``.

    Exits the process via ``sys.exit()`` when the dataset is not augmented,
    when ``model_state`` is neither 'train' nor 'test', or when BIPED testing
    is requested with an image width of 700 or less.
    """
    if not args.dataset_augmented:
        # Only for BIPED dataset: augmentation is not run from this script,
        # so the user is pointed to the dataset page instead.
        print("Please visit the webpage of BIPED in:")
        print("https://xavysp.github.io/MBIPED/")
        print("and run the code")
        sys.exit()

    if args.model_state not in ('train', 'test'):
        print("The model state is None, so it will exit...")
        sys.exit()

    # Validate the test configuration before a session is created, so that a
    # rejected configuration does not leave an open session behind.
    if (args.model_state == 'test' and args.test_dataset == "BIPED"
            and not args.image_width > 700):
        print(' image size is not set in non augmented data')
        sys.exit()

    sess = get_session(args.gpu_limit)
    try:
        if args.model_state == 'train':
            trainer = m_trainer(args)
            trainer.setup()
            trainer.run(sess)
        else:
            tester = m_tester(args)
            tester.setup(sess)
            tester.run(sess)
    finally:
        # Close the session even if training/testing raises.
        sess.close()
if __name__ == '__main__':
    # Script entry point: parse the command-line configuration first,
    # then hand it to main() to start training or testing.
    args = config_model()
    main(args)
|