File size: 6,480 Bytes
89c5d90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
""" DexiNed main script

This code is based on DexiNed (Dense Extreme Inception Network for Edge Detection).
Set any parameters in the function config_model() before training or
testing the model.
"""
__author__ = "Xavier Soria Poma, CVC-UAB"
__email__ = "[email protected] / [email protected]"
__homepage__="www.cvc.uab.cat/people/xsoria"
__credits__=['DexiNed']
__copyright__   = "MIT License [see LICENSE for details]"#"Copyright 2019, CIMI"

import sys
import argparse
import tensorflow as tf

import utls.dataset_manager as dm
from train import m_trainer
from test import m_tester
import platform

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy. This converter maps the
    usual spellings explicitly so flags can actually be turned off.

    Args:
        value: a bool (returned unchanged) or a string token.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under plain ``bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


def config_model(argv=None):
    """Build the argument parser for training/testing and parse the options.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, as the original function did.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    # Default dataset root depends on the host OS.
    base_dir = "/opt/dataset/" if platform.system() == "Linux" else "../../dataset/"
    parser = argparse.ArgumentParser(description='Basic details to run HED')

    # ---- dataset config ----
    parser.add_argument('--train_dataset', default='BIPED', choices=['BIPED', 'BSDS'])
    # 'CLASSIC' is the default, so it must also be accepted when given explicitly.
    parser.add_argument('--test_dataset', default='CLASSIC',
                        choices=['CLASSIC', 'BIPED', 'BSDS', 'MULTICUE', 'NYUD',
                                 'PASCAL', 'CID', 'DCD'])
    parser.add_argument('--dataset_dir', default=base_dir, type=str)  # default:'/opt/dataset/'
    parser.add_argument('--dataset_augmented', default=True, type=_str2bool)
    # BSDS train_pair.lst, SSMIHD train_rgb_pair.lst/train_rgbn_pair.lst
    parser.add_argument('--train_list', default='train_rgb.lst', type=str)
    # for NYUD&BSDS: test_pair.lst, biped msi_test.lst/test_rgb.lst
    parser.add_argument('--test_list', default='test_rgb.lst', type=str)
    parser.add_argument('--trained_model_dir', default='train', type=str)  # 'trainV2_RN'
    # SSMIHD_RGBN msi_valid_list.txt and msi_test_list.txt is for unified test
    parser.add_argument('--use_nir', default=False, type=_str2bool)
    # test: dataset=True single image=FALSE
    parser.add_argument('--use_dataset', default=False, type=_str2bool)

    # ---- model config ----
    parser.add_argument('--model_state', default='train', choices=['train', 'test', 'None'])
    parser.add_argument('--model_name', default='DXN', choices=['DXN', 'XCP', 'None'])
    parser.add_argument('--use_v1', default=False, type=_str2bool)
    parser.add_argument('--model_purpose', default='edges',
                        choices=['edges', 'restoration', 'None'])
    parser.add_argument('--batch_size_train', default=8, type=int)
    parser.add_argument('--batch_size_val', default=8, type=int)
    parser.add_argument('--batch_size_test', default=1, type=int)
    parser.add_argument('--checkpoint_dir', default='checkpoints', type=str)
    parser.add_argument('--logs_dir', default='logs', type=str)
    parser.add_argument('--learning_rate', default=1e-4, type=float)  # 1e-4=0.0001
    parser.add_argument('--lr_scheduler', default=None, choices=[None, 'asce', 'desc'])
    parser.add_argument('--learning_rate_decay', default=0.1, type=float)
    parser.add_argument('--weight_decay', default=0.0002, type=float)
    parser.add_argument('--model_weights_path', default='vgg16_.npy')
    parser.add_argument('--train_split', default=0.9, type=float)  # default 0.8
    parser.add_argument('--max_iterations', default=180000, type=int)  # 100000
    parser.add_argument('--learning_decay_interval', default=25000, type=int)  # 25000
    parser.add_argument('--loss_weights', default=1.0, type=float)
    parser.add_argument('--save_interval', default=20000, type=int)  # 50000
    parser.add_argument('--val_interval', default=30, type=int)
    # None=upsampling with transp conv
    parser.add_argument('--use_subpixel', default=None, type=_str2bool)
    parser.add_argument('--deep_supervision', default=True, type=_str2bool)
    parser.add_argument('--target_regression', default=True, type=_str2bool)
    # The defaults are lists, so command-line values must be collected into a
    # list too (nargs='+'); without it a single scalar replaced the list.
    # RGB only: [103.939,116.779,123.68]; with NIR: [103.939,116.779,123.68, 137.86]
    parser.add_argument('--mean_pixel_values', nargs='+', type=float,
                        default=[103.939, 116.779, 123.68, 137.86])
    parser.add_argument('--channel_swap', nargs='+', type=int, default=[2, 1, 0])
    # The flag is spelled with a dash; the parsed attribute is args.gpu_limit.
    parser.add_argument('--gpu-limit', dest='gpu_limit', default=1.0, type=float)
    parser.add_argument('--use_trained_model', default=True, type=_str2bool)  # for vvg16
    parser.add_argument('--use_previous_trained', default=False, type=_str2bool)  # for training

    # ---- image configuration ----
    # 480 NYUD=560 BIPED=1280 default 400 other 448
    parser.add_argument('--image_width', default=512, type=int)
    # 480 for NYUD 425 BIPED=720 default 400
    parser.add_argument('--image_height', default=512, type=int)
    parser.add_argument('--n_channels', default=3, type=int)  # last ssmihd_xcp trained in 512

    # ---- test config ----
    # BIPED: 149736 BSDS:101179; DexiNedv1=149736, DexiNedv2=149999
    parser.add_argument('--test_snapshot', default=149999, type=int)
    parser.add_argument('--testing_threshold', default=0.0, type=float)
    # default: '/opt/results/edges'
    parser.add_argument('--base_dir_results', default='results/edges', type=str)

    return parser.parse_args(argv)

def get_session(gpu_fraction, num_threads=None):
    """Create a TF1-style session with the requested GPU memory fraction.

    The previous version built the GPU options but then returned a session
    with a bare ``ConfigProto()``, so ``gpu_fraction`` had no effect. It also
    used ``tf.GPUOptions``, which is only available in TensorFlow 1.x, while
    the rest of the function already goes through ``tf.compat.v1``.

    Args:
        gpu_fraction: value passed to ``per_process_gpu_memory_fraction``.
        num_threads: optional value for ``intra_op_parallelism_threads``;
            when falsy (the default) the TensorFlow default is kept.

    Returns:
        A ``tf.compat.v1.Session``; the caller is responsible for closing it.
    """
    gpu_options = tf.compat.v1.GPUOptions(
        per_process_gpu_memory_fraction=gpu_fraction)
    config_kwargs = {'gpu_options': gpu_options}
    if num_threads:
        config_kwargs['intra_op_parallelism_threads'] = num_threads
    return tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(**config_kwargs))


def main(args):
    """Run training or testing according to ``args.model_state``.

    All option checks happen before the TensorFlow session is created, and
    the session is closed in a ``finally`` block, so it is released even when
    the trainer or tester raises.

    Args:
        args: parsed options as returned by ``config_model()``. This function
            reads ``dataset_augmented``, ``model_state``, ``test_dataset``,
            ``image_width`` and ``gpu_limit``, and passes the whole object on
            to ``m_trainer`` / ``m_tester``.

    Raises:
        SystemExit: when the dataset is not augmented, when ``model_state`` is
            neither 'train' nor 'test', or when testing on BIPED with
            ``image_width`` not above 700.
    """
    if not args.dataset_augmented:
        # Only for BIPED dataset
        print("Please visit the webpage of BIPED in:")
        print("https://xavysp.github.io/MBIPED/")
        print("and run the code")
        sys.exit()

    state = args.model_state
    if state not in ('train', 'test'):
        print("The model state is None, so it will exit...")
        sys.exit()

    # Reject an unsuitable BIPED test size before acquiring the session, so
    # this early exit cannot leave a session open.
    if state == 'test' and args.test_dataset == "BIPED" and args.image_width <= 700:
        print(' image size is not set in non augmented data')
        sys.exit()

    sess = get_session(args.gpu_limit)
    try:
        if state == 'train':
            trainer = m_trainer(args)
            trainer.setup()
            trainer.run(sess)
        else:
            tester = m_tester(args)
            tester.setup(sess)
            tester.run(sess)
    finally:
        sess.close()

if __name__ == "__main__":
    # Script entry point: parse the command-line options, then dispatch on them.
    args = config_model()
    main(args)