# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
import multiprocessing
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import hotr.data.datasets as datasets
import hotr.util.misc as utils
from hotr.engine.arg_parser import get_args_parser
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
from hotr.engine.trainer import train_one_epoch
from hotr.engine import hoi_evaluator, hoi_accumulator
from hotr.models import build_model
import wandb

from hotr.util.logger import print_params, print_args


def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename):
    # save_ckpt: function for saving checkpoints
    output_dir = Path(args.output_dir)
    if args.output_dir:
        checkpoint_path = output_dir / f'{filename}.pth'
        utils.save_on_master({
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'args': args,
        }, checkpoint_path)
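

# main(): builds the datasets, dataloaders and model, optionally restores weights
# (a frozen detector or a resume checkpoint), then either runs evaluation only
# (args.eval) or trains with per-epoch validation and checkpointing.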
def main(args):
    utils.init_distributed_mode(args)

    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
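    # The seed is offset by the process rank above, so each distributed worker
    # draws a different random stream (e.g. for shuffling and augmentation).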

    # Data Setup
    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args)
    assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits"
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc: args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics
        args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
        args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
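    # Training batches come from BatchSampler with drop_last=True, so every step
    # sees a full batch; the validation loader keeps all samples. utils.collate_fn
    # handles batching of variable-sized images (a DETR-style padded batch).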

    # Model Setup
    model, criterion, postprocessors = build_model(args)
    # import pdb;pdb.set_trace()
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = print_params(model)

    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
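    # Two parameter groups: backbone parameters use their own learning rate
    # (args.lr_backbone); everything else falls back to args.lr given below.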
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [1,100])

    # Weight Setup
    if args.frozen_weights is not None:
        if args.frozen_weights.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.frozen_weights, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
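    # When resuming, the optimizer state and start epoch are restored; the
    # lr_scheduler restore is left commented out above, so StepLR restarts from
    # its initial schedule.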
    # import pdb;pdb.set_trace()
    if args.eval:
        # test only mode
        if args.HOIDet:
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                sc1, sc2 = hoi_accumulator(args, total_res, True, False)
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}')
                print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}')
                print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}')
            else: raise ValueError(f'dataset {args.dataset_file} is not supported.')
            return
        else:
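            # Detection-only (COCO-style) evaluation path inherited from DETR;
            # evaluate_coco and base_ds are assumed to be provided by the DETR
            # engine and are not defined in this file.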
            test_stats, coco_evaluator = evaluate_coco(model, criterion, postprocessors,
                                                       data_loader_val, base_ds, device, args.output_dir)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
            return

    # stats
    scenario1, scenario2 = 0, 0
    best_mAP, best_rare, best_non_rare = 0, 0, 0

    # add argparse
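    # W&B logging is initialised on the rank-0 process only (when args.wandb is
    # set); wandb.watch hooks the model so its gradients and parameters are tracked.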
    if args.wandb and utils.get_rank() == 0:
        wandb.init(
            project=args.project_name,
            group=args.group_name,
            name=args.run_name,
            config=args
        )
        wandb.watch(model)

    # Training starts here!
    # lr_scheduler.step()
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch, args.epochs,
            args.ramp_up_epoch, args.ramp_down_epoch, args.hoi_consistency_loss_coef,
            args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb)
        lr_scheduler.step()
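        # StepLR is stepped once per epoch: the learning rate decays by the
        # default gamma (0.1) every args.lr_drop epochs.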

        # Validation
        if args.validate:
            print('-'*100)
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb)
                    if sc1 > scenario1:
                        scenario1 = sc1
                        scenario2 = sc2
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})')
                    print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})')
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    if test_stats['mAP'] > best_mAP:
                        best_mAP = test_stats['mAP']
                        best_rare = test_stats['mAP rare']
                        best_non_rare = test_stats['mAP non-rare']
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})')
                    print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})')
                    print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})')
                    if args.wandb and utils.get_rank() == 0:
                        wandb.log({
                            'mAP': test_stats['mAP'],
                            'mAP rare': test_stats['mAP rare'],
                            'mAP non-rare': test_stats['mAP non-rare']
                        })
            print('-'*100)
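
        # Save a rolling 'checkpoint.pth' every epoch, plus an epoch-numbered
        # snapshot at each LR-drop boundary; 'best.pth' is written in the
        # validation blocks above whenever the tracked mAP improves.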
        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint')
        if (epoch + 1) % args.lr_drop == 0:
            save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_'+str(epoch))
        # if (epoch + 1) % args.pseudo_epoch == 0:
        #     save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_pseudo_'+str(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if args.dataset_file == 'vcoco':
        print(f'| Scenario #1 mAP : {scenario1:.2f}')
        print(f'| Scenario #2 mAP : {scenario2:.2f}')
    elif args.dataset_file == 'hico-det':
        print(f'| mAP (full)\t\t: {best_mAP:.2f}')
        print(f'| mAP (rare)\t\t: {best_rare:.2f}')
        print(f'| mAP (non-rare)\t: {best_non_rare:.2f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'End-to-End Human Object Interaction training and evaluation script',
        parents=[get_args_parser()]
    )
    args = parser.parse_args()
    if args.output_dir:
        args.output_dir += f"/{args.group_name}/{args.run_name}/"
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
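
# Illustrative single-GPU launch (a sketch only -- the exact flag names are
# defined in hotr.engine.arg_parser.get_args_parser and may differ):
#
#   python main.py --validate --dataset_file vcoco \
#       --group_name my_group --run_name my_run \
#       --output_dir checkpoints --epochs 100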