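"""Training and testing script for the DexiNed edge detection model.

Trains DexiNed on the dataset selected in parse_args() and/or restores a saved
checkpoint to run inference, writing edge-map predictions and visualizations to disk.
"""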
from __future__ import print_function

import argparse
import os
import time, platform

import cv2
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import DATASET_NAMES, BipedDataset, TestDataset, dataset_info
from losses import *
from model import DexiNed
from utils import (image_normalization, save_image_batch_to_disk,
                   visualize_result, count_parameters)

IS_LINUX = True if platform.system() == "Linux" else False

def train_one_epoch(epoch, dataloader, model, criterion, optimizer, device,
                    log_interval_vis, tb_writer, args=None):
    imgs_res_folder = os.path.join(args.output_dir, 'current_res')
    os.makedirs(imgs_res_folder, exist_ok=True)

    # Put model in training mode
    model.train()

    # Weights for the deep-supervision outputs (one per side output plus the fused map)
    # l_weight = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.1]  # for bdcn ori loss
    # before [0.6,0.6,1.1,1.1,0.4,0.4,1.3] [0.4,0.4,1.1,1.1,0.6,0.6,1.3], [0.4,0.4,1.1,1.1,0.8,0.8,1.3]
    l_weight = [0.7, 0.7, 1.1, 1.1, 0.3, 0.3, 1.3]  # New BDCN loss
    # l_weight = [[0.05, 2.], [0.05, 2.], [0.05, 2.],
    #             [0.1, 1.], [0.1, 1.], [0.1, 1.],
    #             [0.01, 4.]]  # for cats loss
    loss_avg = []
    for batch_id, sample_batched in enumerate(dataloader):
        images = sample_batched['images'].to(device)  # BxCxHxW
        labels = sample_batched['labels'].to(device)  # BxHxW
        preds_list = model(images)
        # loss = sum([criterion(preds, labels, l_w, device) for preds, l_w in zip(preds_list, l_weight)])  # cats_loss
        loss = sum([criterion(preds, labels, l_w) for preds, l_w in zip(preds_list, l_weight)])  # bdcn_loss
        # loss = sum([criterion(preds, labels) for preds in preds_list])  # hed_loss, rcf_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_avg.append(loss.item())
        if epoch == 0 and (batch_id == 100 and tb_writer is not None):
            tmp_loss = np.array(loss_avg).mean()
            tb_writer.add_scalar('loss', tmp_loss, epoch)

        if batch_id % 5 == 0:
            print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}'
                  .format(epoch, batch_id, len(dataloader), loss.item()))
        if batch_id % log_interval_vis == 0:
            # Visualize the third sample of the current batch (requires batch_size >= 3)
            res_data = []

            img = images.cpu().numpy()
            res_data.append(img[2])

            ed_gt = labels.cpu().numpy()
            res_data.append(ed_gt[2])

            # tmp_pred = tmp_preds[2,...]
            for i in range(len(preds_list)):
                tmp = preds_list[i]
                tmp = tmp[2]
                # print(tmp.shape)
                tmp = torch.sigmoid(tmp).unsqueeze(dim=0)
                tmp = tmp.cpu().detach().numpy()
                res_data.append(tmp)

            vis_imgs = visualize_result(res_data, arg=args)
            del tmp, res_data

            vis_imgs = cv2.resize(vis_imgs,
                                  (int(vis_imgs.shape[1] * 0.8), int(vis_imgs.shape[0] * 0.8)))
            img_test = 'Epoch: {0} Sample {1}/{2} Loss: {3}' \
                .format(epoch, batch_id, len(dataloader), loss.item())

            BLACK = (0, 0, 255)  # note: (0, 0, 255) is red in OpenCV's BGR order
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_size = 1.1
            font_color = BLACK
            font_thickness = 2
            x, y = 30, 30
            vis_imgs = cv2.putText(vis_imgs,
                                   img_test,
                                   (x, y),
                                   font, font_size, font_color, font_thickness, cv2.LINE_AA)
            cv2.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs)
    loss_avg = np.array(loss_avg).mean()
    return loss_avg

def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):
    # XXX This is not really validation, but testing
    # Put model in eval mode
    model.eval()
    with torch.no_grad():
        for _, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(device)
            # labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            preds = model(images)
            # print('pred shape', preds[0].shape)
            save_image_batch_to_disk(preds[-1],
                                     output_dir,
                                     file_names, img_shape=image_shape,
                                     arg=arg)

def test(checkpoint_path, dataloader, model, device, output_dir, args):
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))

    # Put model in evaluation mode
    model.eval()

    with torch.no_grad():
        total_duration = []
        for batch_id, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(device)
            if not args.test_data == "CLASSIC":
                labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            print(f"input tensor shape: {images.shape}")
            # images = images[:, [2, 1, 0], :, :]

            # Time the forward pass (synchronize on GPU for accurate timing)
            end = time.perf_counter()
            if device.type == 'cuda':
                torch.cuda.synchronize()
            preds = model(images)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            tmp_duration = time.perf_counter() - end
            total_duration.append(tmp_duration)

            save_image_batch_to_disk(preds,
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args)
            torch.cuda.empty_cache()

    total_duration = np.sum(np.array(total_duration))
    print("******** Testing finished in", args.test_data, "dataset. *****")
    print("FPS: %.4f" % (len(dataloader) / total_duration))

def testPich(checkpoint_path, dataloader, model, device, output_dir, args):
    # Test the model on each image plus a copy with interchanged channels
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))

    # Put model in evaluation mode
    model.eval()

    with torch.no_grad():
        total_duration = []
        for batch_id, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(device)
            if not args.test_data == "CLASSIC":
                labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            print(f"input tensor shape: {images.shape}")
            start_time = time.time()
            # images2 = images[:, [1, 0, 2], :, :]  # GBR
            images2 = images[:, [2, 1, 0], :, :]  # RGB
            preds = model(images)
            preds2 = model(images2)
            tmp_duration = time.time() - start_time
            total_duration.append(tmp_duration)
            save_image_batch_to_disk([preds, preds2],
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args, is_inchannel=True)
            torch.cuda.empty_cache()

    total_duration = np.array(total_duration)
    print("******** Testing finished in", args.test_data, "dataset. *****")
    print("Average time per image: %.4f" % total_duration.mean(), "seconds")
    print("Time spent on the dataset: %.4f" % total_duration.sum(), "seconds")

def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description='DexiNed trainer.')
    parser.add_argument('--choose_test_data',
                        type=int,
                        default=-1,
                        help='Index of the dataset used for testing (choice: 0 - 8).')
    # ----------- test -------0--
    TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data]  # max 8
    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
    test_dir = test_inf['data_dir']
    is_testing = True  # current test -352-SM-NewGT-2AugmenPublish

    # Training settings
    TRAIN_DATA = DATASET_NAMES[0]  # BIPED=0, MDBD=6
    train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
    train_dir = train_inf['data_dir']

    # Data parameters
    parser.add_argument('--input_dir',
                        type=str,
                        default=train_dir,
                        help='the path to the directory with the input data.')
    parser.add_argument('--input_val_dir',
                        type=str,
                        default=test_inf['data_dir'],
                        help='the path to the directory with the input data for validation.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='checkpoints',
                        help='the path to output the results.')
    parser.add_argument('--train_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TRAIN_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TEST_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_list',
                        type=str,
                        default=test_inf['test_list'],
                        help='Dataset sample indices list.')
    parser.add_argument('--train_list',
                        type=str,
                        default=train_inf['train_list'],
                        help='Dataset sample indices list.')
    parser.add_argument('--is_testing', type=bool,
                        default=is_testing,
                        help='Script in testing mode.')
    parser.add_argument('--double_img',
                        type=bool,
                        default=False,
                        help='True: use same 2 imgs changing channels')  # Just for test
    parser.add_argument('--resume',
                        type=bool,
                        default=False,
                        help='use previous trained data')  # Just for test
    parser.add_argument('--checkpoint_data',
                        type=str,
                        default='10/10_model.pth',  # 4 6 7 9 14
                        help='Checkpoint path from which to restore model weights.')
    parser.add_argument('--test_img_width',
                        type=int,
                        default=test_inf['img_width'],
                        help='Image width for testing.')
    parser.add_argument('--test_img_height',
                        type=int,
                        default=test_inf['img_height'],
                        help='Image height for testing.')
    parser.add_argument('--res_dir',
                        type=str,
                        default='result',
                        help='Result directory')
    parser.add_argument('--log_interval_vis',
                        type=int,
                        default=50,
                        help='The number of batches to wait before printing test predictions.')
    parser.add_argument('--epochs',
                        type=int,
                        default=17,
                        metavar='N',
                        help='Number of training epochs (default: 17).')
    parser.add_argument('--lr',
                        default=1e-4,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-8,
                        metavar='WD',
                        help='weight decay (Good 1e-8) in TF1=0')  # 1e-8 -> BRIND/MDBD, 0.0 -> BIPED
    parser.add_argument('--adjust_lr',
                        default=[10, 15],
                        type=int,
                        help='Learning rate step size.')  # [5,10] BRIND, [10,15] BIPED/BRIND
    parser.add_argument('--batch_size',
                        type=int,
                        default=8,
                        metavar='B',
                        help='the mini-batch size (default: 8)')
    parser.add_argument('--workers',
                        default=16,
                        type=int,
                        help='The number of workers for the dataloaders.')
    parser.add_argument('--tensorboard', type=bool,
                        default=True,
                        help='Use Tensorboard for logging.')
    parser.add_argument('--img_width',
                        type=int,
                        default=352,
                        help='Image width for training.')  # BIPED 400 BSDS 352/320 MDBD 480
    parser.add_argument('--img_height',
                        type=int,
                        default=352,
                        help='Image height for training.')  # BIPED 480 BSDS 352/320
    parser.add_argument('--channel_swap',
                        default=[2, 1, 0],
                        type=int)
    parser.add_argument('--crop_img',
                        default=True,
                        type=bool,
                        help='If true crop training images, else resize images to match image width and height.')
    parser.add_argument('--mean_pixel_values',
                        default=[103.939, 116.779, 123.68, 137.86],
                        type=float)  # [103.939,116.779,123.68] [104.00699, 116.66877, 122.67892]
    args = parser.parse_args()
    return args

def main(args):
    """Main function."""
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    print(f"Pytorch version: {torch.__version__}")

    # Tensorboard summary writer
    tb_writer = None
    training_dir = os.path.join(args.output_dir, args.train_data)
    os.makedirs(training_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.output_dir, args.train_data, args.checkpoint_data)
    if args.tensorboard and not args.is_testing:
        from torch.utils.tensorboard import SummaryWriter  # for torch 1.4 or greater
        tb_writer = SummaryWriter(log_dir=training_dir)

    # Save the model training settings
    training_notes = ['DexiNed, Xavier Normal Init, LR= ' + str(args.lr) + ' WD= '
                      + str(args.wd) + ' image size = ' + str(args.img_width)
                      + ' adjust LR=' + str(args.adjust_lr) + ' Loss Function= BDCNloss2. '
                      + 'Trained on> ' + args.train_data + ' Tested on> '
                      + args.test_data + ' Batch size= ' + str(args.batch_size) + ' ' + str(time.asctime())]
    info_txt = open(os.path.join(training_dir, 'training_settings.txt'), 'w')
    info_txt.write(str(training_notes))
    info_txt.close()

    # Get computing device
    device = torch.device('cpu' if torch.cuda.device_count() == 0
                          else 'cuda')

    # Instantiate model and move it to the computing device
    model = DexiNed().to(device)
    # model = nn.DataParallel(model)
    ini_epoch = 0
    if not args.is_testing:
        if args.resume:
            ini_epoch = 11
            model.load_state_dict(torch.load(checkpoint_path,
                                             map_location=device))
            print('Training restarted from> ', checkpoint_path)
        dataset_train = BipedDataset(args.input_dir,
                                     img_width=args.img_width,
                                     img_height=args.img_height,
                                     mean_bgr=args.mean_pixel_values[0:3] if len(
                                         args.mean_pixel_values) == 4 else args.mean_pixel_values,
                                     train_mode='train',
                                     arg=args
                                     )
        dataloader_train = DataLoader(dataset_train,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)

    dataset_val = TestDataset(args.input_val_dir,
                              test_data=args.test_data,
                              img_width=args.test_img_width,
                              img_height=args.test_img_height,
                              mean_bgr=args.mean_pixel_values[0:3] if len(
                                  args.mean_pixel_values) == 4 else args.mean_pixel_values,
                              test_list=args.test_list, arg=args
                              )
    dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=args.workers)

    # Testing
    if args.is_testing:
        output_dir = os.path.join(args.res_dir, args.train_data + "2" + args.test_data)
        print(f"output_dir: {output_dir}")
        if args.double_img:
            # predict twice on each image with swapped channels, then merge the results
            testPich(checkpoint_path, dataloader_val, model, device, output_dir, args)
        else:
            test(checkpoint_path, dataloader_val, model, device, output_dir, args)

        num_param = count_parameters(model)
        print('-------------------------------------------------------')
        print('DexiNed # of Parameters:')
        print(num_param)
        print('-------------------------------------------------------')
        return
    criterion = bdcn_loss2  # hed_loss2 # bdcn_loss2
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)

    # Main training loop
    seed = 1021
    adjust_lr = args.adjust_lr
    lr2 = args.lr
    for epoch in range(ini_epoch, args.epochs):
        if epoch % 7 == 0:
            seed = seed + 1000
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print("------ Random seed applied-------------")
        # Decay the learning rate at the scheduled epochs
        if adjust_lr is not None:
            if epoch in adjust_lr:
                lr2 = lr2 * 0.1
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr2

        # Create output directories
        output_dir_epoch = os.path.join(args.output_dir, args.train_data, str(epoch))
        img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
        os.makedirs(output_dir_epoch, exist_ok=True)
        os.makedirs(img_test_dir, exist_ok=True)

        # Evaluate with the current weights before training this epoch
        validate_one_epoch(epoch,
                           dataloader_val,
                           model,
                           device,
                           img_test_dir,
                           arg=args)
        avg_loss = train_one_epoch(epoch,
                                   dataloader_train,
                                   model,
                                   criterion,
                                   optimizer,
                                   device,
                                   args.log_interval_vis,
                                   tb_writer,
                                   args=args)
        # Evaluate again after training this epoch
        validate_one_epoch(epoch,
                           dataloader_val,
                           model,
                           device,
                           img_test_dir,
                           arg=args)

        # Save model at the end of every epoch
        torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                   os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
        if tb_writer is not None:
            tb_writer.add_scalar('loss',
                                 avg_loss,
                                 epoch + 1)
        print('Current learning rate> ', optimizer.param_groups[0]['lr'])

    num_param = count_parameters(model)
    print('-------------------------------------------------------')
    print('DexiNed, # of Parameters:')
    print(num_param)
    print('-------------------------------------------------------')


if __name__ == '__main__':
    args = parse_args()
    main(args)
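
# Usage sketch (assuming this script is saved as main.py; data paths and dataset
# entries come from datasets.py via dataset_info()):
#   python main.py --choose_test_data -1
# With the defaults above the script runs in testing mode (is_testing=True) on
# DATASET_NAMES[choose_test_data]. To train instead, set is_testing to False in
# parse_args(); note that argparse's type=bool does not turn the string "False"
# into False, so boolean options are best changed in their defaults.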