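"""Training and testing entry point for DexiNed (Dense Extreme Inception
Network for Edge Detection). Summary docstring added for readability."""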
from __future__ import print_function
import argparse
import os
import time, platform
import cv2
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import DATASET_NAMES, BipedDataset, TestDataset, dataset_info
from losses import *
from model import DexiNed
from utils import (image_normalization, save_image_batch_to_disk,
                   visualize_result, count_parameters)
IS_LINUX = platform.system() == "Linux"
def train_one_epoch(epoch, dataloader, model, criterion, optimizer, device,
log_interval_vis, tb_writer, args=None):
imgs_res_folder = os.path.join(args.output_dir, 'current_res')
os.makedirs(imgs_res_folder,exist_ok=True)
# Put model in training mode
model.train()
# l_weight = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.1] # for bdcn ori loss
# before [0.6,0.6,1.1,1.1,0.4,0.4,1.3] [0.4,0.4,1.1,1.1,0.6,0.6,1.3],[0.4,0.4,1.1,1.1,0.8,0.8,1.3]
l_weight = [0.7,0.7,1.1,1.1,0.3,0.3,1.3] # New BDCN loss
# l_weight = [[0.05, 2.], [0.05, 2.], [0.05, 2.],
# [0.1, 1.], [0.1, 1.], [0.1, 1.],
# [0.01, 4.]] # for cats loss
    loss_avg = []
for batch_id, sample_batched in enumerate(dataloader):
images = sample_batched['images'].to(device) # BxCxHxW
labels = sample_batched['labels'].to(device) # BxHxW
preds_list = model(images)
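        # DexiNed returns several side-output edge maps plus a fused map;
        # the training loss is a weighted sum over all of them (deep
        # supervision), with one l_weight entry per output.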
# loss = sum([criterion(preds, labels, l_w, device) for preds, l_w in zip(preds_list, l_weight)]) # cats_loss
loss = sum([criterion(preds, labels,l_w) for preds, l_w in zip(preds_list,l_weight)]) # bdcn_loss
# loss = sum([criterion(preds, labels) for preds in preds_list]) #HED loss, rcf_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_avg.append(loss.item())
        if epoch == 0 and batch_id == 100 and tb_writer is not None:
            tmp_loss = np.array(loss_avg).mean()
            tb_writer.add_scalar('loss', tmp_loss, epoch)
if batch_id % 5 == 0:
print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}'
.format(epoch, batch_id, len(dataloader), loss.item()))
if batch_id % log_interval_vis == 0:
res_data = []
img = images.cpu().numpy()
res_data.append(img[2])
ed_gt = labels.cpu().numpy()
res_data.append(ed_gt[2])
# tmp_pred = tmp_preds[2,...]
for i in range(len(preds_list)):
tmp = preds_list[i]
tmp = tmp[2]
# print(tmp.shape)
tmp = torch.sigmoid(tmp).unsqueeze(dim=0)
tmp = tmp.cpu().detach().numpy()
res_data.append(tmp)
vis_imgs = visualize_result(res_data, arg=args)
del tmp, res_data
vis_imgs = cv2.resize(vis_imgs,
(int(vis_imgs.shape[1]*0.8), int(vis_imgs.shape[0]*0.8)))
img_test = 'Epoch: {0} Sample {1}/{2} Loss: {3}' \
.format(epoch, batch_id, len(dataloader), loss.item())
            RED = (0, 0, 255)  # BGR: this is red, not black as previously named
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_size = 1.1
            font_color = RED
            font_thickness = 2
            x, y = 30, 30
vis_imgs = cv2.putText(vis_imgs,
img_test,
(x, y),
font, font_size, font_color, font_thickness, cv2.LINE_AA)
cv2.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs)
loss_avg = np.array(loss_avg).mean()
return loss_avg
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):
# XXX This is not really validation, but testing
# Put model in eval mode
model.eval()
with torch.no_grad():
for _, sample_batched in enumerate(dataloader):
images = sample_batched['images'].to(device)
# labels = sample_batched['labels'].to(device)
file_names = sample_batched['file_names']
image_shape = sample_batched['image_shape']
preds = model(images)
# print('pred shape', preds[0].shape)
save_image_batch_to_disk(preds[-1],
output_dir,
file_names,img_shape=image_shape,
arg=arg)
def test(checkpoint_path, dataloader, model, device, output_dir, args):
if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
print(f"Restoring weights from: {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path,
map_location=device))
# Put model in evaluation mode
model.eval()
with torch.no_grad():
total_duration = []
for batch_id, sample_batched in enumerate(dataloader):
images = sample_batched['images'].to(device)
            if args.test_data != "CLASSIC":
labels = sample_batched['labels'].to(device)
file_names = sample_batched['file_names']
image_shape = sample_batched['image_shape']
print(f"input tensor shape: {images.shape}")
# images = images[:, [2, 1, 0], :, :]
            # Synchronize before starting the timer so pending GPU work is
            # not counted, then time the forward pass alone.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.perf_counter()
            preds = model(images)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            tmp_duration = time.perf_counter() - start
total_duration.append(tmp_duration)
save_image_batch_to_disk(preds,
output_dir,
file_names,
image_shape,
arg=args)
torch.cuda.empty_cache()
total_duration = np.sum(np.array(total_duration))
print("******** Testing finished in", args.test_data, "dataset. *****")
        print("FPS: %.4f" % (len(dataloader) / total_duration))
def testPich(checkpoint_path, dataloader, model, device, output_dir, args):
    # Test the model twice: once as-is and once with interchanged channels.
if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
print(f"Restoring weights from: {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path,
map_location=device))
# Put model in evaluation mode
model.eval()
with torch.no_grad():
total_duration = []
for batch_id, sample_batched in enumerate(dataloader):
images = sample_batched['images'].to(device)
            if args.test_data != "CLASSIC":
labels = sample_batched['labels'].to(device)
file_names = sample_batched['file_names']
image_shape = sample_batched['image_shape']
print(f"input tensor shape: {images.shape}")
start_time = time.time()
# images2 = images[:, [1, 0, 2], :, :] #GBR
images2 = images[:, [2, 1, 0], :, :] # RGB
preds = model(images)
preds2 = model(images2)
tmp_duration = time.time() - start_time
total_duration.append(tmp_duration)
save_image_batch_to_disk([preds,preds2],
output_dir,
file_names,
image_shape,
arg=args, is_inchannel=True)
torch.cuda.empty_cache()
total_duration = np.array(total_duration)
print("******** Testing finished in", args.test_data, "dataset. *****")
        print("Average time per image: %.4f" % total_duration.mean(), "seconds")
        print("Time spent on the dataset: %.4f" % total_duration.sum(), "seconds")
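
# Small helper (an addition, not in the upstream DexiNed script): argparse's
# `type=bool` converts any non-empty string, including "False", to True, so
# the boolean flags below are parsed through this converter instead.
def str2bool(v):
    """Convert common truthy/falsy strings to a bool for argparse."""
    if isinstance(v, bool):
        return v
    return str(v).lower() in ('true', '1', 'yes', 'y')
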
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description='DexiNed trainer.')
parser.add_argument('--choose_test_data',
type=int,
default=-1,
help='Already set the dataset for testing choice: 0 - 8')
    # ----------- test -----------
    # parse_known_args is used because only --choose_test_data exists at
    # this point; parse_args() here would reject any other CLI flags.
    TEST_DATA = DATASET_NAMES[parser.parse_known_args()[0].choose_test_data]  # max 8
    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
    test_dir = test_inf['data_dir']
    is_testing = True  # current test -352-SM-NewGT-2AugmenPublish
# Training settings
TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, MDBD=6
train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
train_dir = train_inf['data_dir']
# Data parameters
parser.add_argument('--input_dir',
type=str,
default=train_dir,
help='the path to the directory with the input data.')
parser.add_argument('--input_val_dir',
type=str,
default=test_inf['data_dir'],
help='the path to the directory with the input data for validation.')
parser.add_argument('--output_dir',
type=str,
default='checkpoints',
help='the path to output the results.')
parser.add_argument('--train_data',
type=str,
choices=DATASET_NAMES,
default=TRAIN_DATA,
help='Name of the dataset.')
parser.add_argument('--test_data',
type=str,
choices=DATASET_NAMES,
default=TEST_DATA,
help='Name of the dataset.')
parser.add_argument('--test_list',
type=str,
default=test_inf['test_list'],
help='Dataset sample indices list.')
parser.add_argument('--train_list',
type=str,
default=train_inf['train_list'],
help='Dataset sample indices list.')
    parser.add_argument('--is_testing', type=str2bool,
                        default=is_testing,
                        help='Script in testing mode.')
    parser.add_argument('--double_img',
                        type=str2bool,
                        default=False,
                        help='True: use the same image twice with swapped channels.')  # Just for test
    parser.add_argument('--resume',
                        type=str2bool,
                        default=False,
                        help='Resume from previously trained weights.')  # Just for test
parser.add_argument('--checkpoint_data',
type=str,
default='10/10_model.pth',# 4 6 7 9 14
help='Checkpoint path from which to restore model weights from.')
parser.add_argument('--test_img_width',
type=int,
default=test_inf['img_width'],
help='Image width for testing.')
parser.add_argument('--test_img_height',
type=int,
default=test_inf['img_height'],
help='Image height for testing.')
parser.add_argument('--res_dir',
type=str,
default='result',
help='Result directory')
parser.add_argument('--log_interval_vis',
type=int,
default=50,
help='The number of batches to wait before printing test predictions.')
parser.add_argument('--epochs',
type=int,
default=17,
metavar='N',
                        help='Number of training epochs (default: 17).')
parser.add_argument('--lr',
default=1e-4,
type=float,
help='Initial learning rate.')
parser.add_argument('--wd',
type=float,
default=1e-8,
metavar='WD',
help='weight decay (Good 1e-8) in TF1=0') # 1e-8 -> BIRND/MDBD, 0.0 -> BIPED
    parser.add_argument('--adjust_lr',
                        default=[10, 15],
                        type=int,
                        nargs='+',
                        help='Epochs at which to decay the learning rate.')  # [5,10]BIRND [10,15]BIPED/BRIND
parser.add_argument('--batch_size',
type=int,
default=8,
metavar='B',
help='the mini-batch size (default: 8)')
parser.add_argument('--workers',
default=16,
type=int,
help='The number of workers for the dataloaders.')
    parser.add_argument('--tensorboard', type=str2bool,
                        default=True,
                        help='Use Tensorboard for logging.')
parser.add_argument('--img_width',
type=int,
default=352,
help='Image width for training.') # BIPED 400 BSDS 352/320 MDBD 480
parser.add_argument('--img_height',
type=int,
default=352,
help='Image height for training.') # BIPED 480 BSDS 352/320
    parser.add_argument('--channel_swap',
                        default=[2, 1, 0],
                        nargs=3,
                        type=int)
    parser.add_argument('--crop_img',
                        default=True,
                        type=str2bool,
                        help='If true, crop training images; otherwise resize them to img_width x img_height.')
    parser.add_argument('--mean_pixel_values',
                        default=[103.939, 116.779, 123.68, 137.86],
                        nargs='+',
                        type=float)  # [103.939,116.779,123.68] [104.00699, 116.66877, 122.67892]
args = parser.parse_args()
return args
def main(args):
"""Main function."""
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    print(f"PyTorch version: {torch.__version__}")
# Tensorboard summary writer
tb_writer = None
training_dir = os.path.join(args.output_dir,args.train_data)
os.makedirs(training_dir,exist_ok=True)
checkpoint_path = os.path.join(args.output_dir, args.train_data, args.checkpoint_data)
if args.tensorboard and not args.is_testing:
        from torch.utils.tensorboard import SummaryWriter  # for torch 1.4 or greater
tb_writer = SummaryWriter(log_dir=training_dir)
# saving Model training settings
training_notes = ['DexiNed, Xavier Normal Init, LR= ' + str(args.lr) + ' WD= '
+ str(args.wd) + ' image size = ' + str(args.img_width)
+ ' adjust LR='+ str(args.adjust_lr) + ' Loss Function= BDCNloss2. '
+'Trained on> '+args.train_data+' Tested on> '
+args.test_data+' Batch size= '+str(args.batch_size)+' '+str(time.asctime())]
    with open(os.path.join(training_dir, 'training_settings.txt'), 'w') as info_txt:
        info_txt.write(str(training_notes))
# Get computing device
device = torch.device('cpu' if torch.cuda.device_count() == 0
else 'cuda')
# Instantiate model and move it to the computing device
model = DexiNed().to(device)
# model = nn.DataParallel(model)
    ini_epoch = 0
    if not args.is_testing:
        if args.resume:
            ini_epoch = 11
            model.load_state_dict(torch.load(checkpoint_path,
                                             map_location=device))
            print('Training restarted from> ', checkpoint_path)
dataset_train = BipedDataset(args.input_dir,
img_width=args.img_width,
img_height=args.img_height,
mean_bgr=args.mean_pixel_values[0:3] if len(
args.mean_pixel_values) == 4 else args.mean_pixel_values,
train_mode='train',
arg=args
)
dataloader_train = DataLoader(dataset_train,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers)
dataset_val = TestDataset(args.input_val_dir,
test_data=args.test_data,
img_width=args.test_img_width,
img_height=args.test_img_height,
mean_bgr=args.mean_pixel_values[0:3] if len(
args.mean_pixel_values) == 4 else args.mean_pixel_values,
test_list=args.test_list, arg=args
)
dataloader_val = DataLoader(dataset_val,
batch_size=1,
shuffle=False,
num_workers=args.workers)
# Testing
if args.is_testing:
output_dir = os.path.join(args.res_dir, args.train_data+"2"+ args.test_data)
print(f"output_dir: {output_dir}")
if args.double_img:
# predict twice an image changing channels, then mix those results
testPich(checkpoint_path, dataloader_val, model, device, output_dir, args)
else:
test(checkpoint_path, dataloader_val, model, device, output_dir, args)
num_param = count_parameters(model)
print('-------------------------------------------------------')
print('DexiNed # of Parameters:')
print(num_param)
print('-------------------------------------------------------')
return
criterion = bdcn_loss2 # hed_loss2 #bdcn_loss2
optimizer = optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.wd)
# Main training loop
    seed = 1021
    adjust_lr = args.adjust_lr
    lr2 = args.lr
    for epoch in range(ini_epoch, args.epochs):
        if epoch % 7 == 0:
            seed = seed + 1000
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
print("------ Random seed applied-------------")
        # Decay the learning rate at the scheduled epochs
        if adjust_lr is not None:
            if epoch in adjust_lr:
                lr2 = lr2 * 0.1
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr2
        # Create output directories
output_dir_epoch = os.path.join(args.output_dir,args.train_data, str(epoch))
img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
os.makedirs(output_dir_epoch,exist_ok=True)
os.makedirs(img_test_dir,exist_ok=True)
        avg_loss = train_one_epoch(epoch,
                                   dataloader_train,
                                   model,
                                   criterion,
                                   optimizer,
                                   device,
                                   args.log_interval_vis,
                                   tb_writer,
                                   args=args)
validate_one_epoch(epoch,
dataloader_val,
model,
device,
img_test_dir,
arg=args)
# Save model after end of every epoch
torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
if tb_writer is not None:
tb_writer.add_scalar('loss',
avg_loss,
epoch+1)
print('Current learning rate> ', optimizer.param_groups[0]['lr'])
num_param = count_parameters(model)
print('-------------------------------------------------------')
print('DexiNed, # of Parameters:')
print(num_param)
print('-------------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
main(args)
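
# Example invocations (illustrative; dataset indices and checkpoint paths
# depend on your local datasets.py and downloaded weights):
#   python main.py --choose_test_data 0                # test on DATASET_NAMES[0]
#   python main.py --is_testing False --epochs 17      # train with the defaults above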