Spaces:
Runtime error
Runtime error
| import argparse | |
| import datetime | |
| import json | |
| import random | |
| import time | |
| import multiprocessing | |
| from pathlib import Path | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| import hotr.data.datasets as datasets | |
| import hotr.util.misc as utils | |
| from hotr.engine.arg_parser import get_args_parser | |
| from hotr.data.datasets import build_dataset, get_coco_api_from_dataset | |
| from hotr.data.datasets.vcoco import make_hoi_transforms | |
| from PIL import Image | |
| from hotr.util.logger import print_params, print_args | |
| import copy | |
| from hotr.data.datasets import builtin_meta | |
| from PIL import Image | |
| import requests | |
| # import mmcv | |
| from matplotlib import pyplot as plt | |
| import imageio | |
| from tools.vis_tool import * | |
| from hotr.models.detr import build | |
def change_format(results, valid_ids):
    """Convert one image's post-processed HOI output into plain dict form.

    Args:
        results: Mapping with tensor entries ``'boxes'``, ``'labels'`` and
            ``'pair_score'``; each is moved to CPU and converted to NumPy.
        valid_ids: Accepted for interface compatibility; not read here.

    Returns:
        A dict with two lists:
          * ``'predictions'``: one ``{'bbox', 'category_id'}`` entry per box.
          * ``'hoi_prediction'``: one ``{'subject_id', 'object_id',
            'category_id', 'score'}`` entry per strictly positive pair score,
            where ``category_id`` is the action index plus 2.
    """
    # Action indices that are never emitted as HOI predictions.
    skipped_actions = (1, 4, 10, 23, 26, 5, 18)

    boxes = results['boxes'].cpu().numpy()
    labels = results['labels'].cpu().numpy()
    pair_score = results['pair_score'].cpu().numpy()

    predictions = [
        {'bbox': box.tolist(), 'category_id': label}
        for box, label in zip(boxes, labels)
    ]

    # Boxes labelled 1 act as the subjects of every interaction.
    subject_indices = np.where(labels == 1)[0]
    num_boxes = len(boxes)

    hoi_predictions = []
    for action_idx, action_scores in enumerate(pair_score):
        if action_idx in skipped_actions:
            continue
        # NOTE(review): the score row is selected by the subject's position in
        # ``subject_indices``, not by its box index -- confirm this matches
        # the layout produced by the post-processor.
        for row, subject_idx in enumerate(subject_indices):
            for object_idx in range(num_boxes):
                score = action_scores[row][object_idx]
                if score > 0:
                    hoi_predictions.append({
                        'subject_id': subject_idx,
                        'object_id': object_idx,
                        'category_id': action_idx + 2,
                        'score': score,
                    })

    return {'predictions': predictions, 'hoi_prediction': hoi_predictions}
def vis(args, input_img=None, id=294, return_img=False):
    """Run HOI detection on a single image and draw the predictions.

    The function builds the training dataset to read its category/action
    metadata, builds the model, loads the checkpoint at ``args.resume``,
    runs one forward pass and renders the result with ``draw_img_vcoco``.

    Args:
        args: Parsed argument namespace. It is mutated in place (dataset
            statistics and several evaluation flags are written onto it).
            ``args.threshold`` and ``args.topk`` must be set by the caller.
        input_img: Image object exposing the PIL interface (``size``,
            ``copy``, ``convert``). When ``None`` the image is opened from
            the path in ``args.image_dir``.
        id: Unused by the current implementation; kept so existing callers
            that pass it keep working.
        return_img: When true, return the rendered image instead of writing
            it to disk.

    Returns:
        A ``PIL.Image.Image`` in RGB order when ``return_img`` is true,
        otherwise ``None`` (the rendering is written to
        ``./vis_res/vis1.jpg``).
    """
    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    device = torch.device(args.device)

    # Fix the seed for reproducibility.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Data setup: the training split is only used for its metadata.
    dataset_train = build_dataset(image_set='train', args=args)
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc:
        args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec:
        args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics.
        object_label_idx = np.array(dataset_train.get_object_label_idx())
        args.valid_ids = object_label_idx.nonzero()[0]
        args.invalid_ids = np.argwhere(object_label_idx == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)

    # Force the evaluation-time HOI configuration regardless of CLI values.
    args.HOIDet = True
    args.eval = True
    args.pretrained_dec = True
    args.share_enc = True
    args.share_dec_param = True
    if args.dataset_file == 'hico-det':
        args.valid_ids = args.valid_obj_ids

    # Model setup.
    model, criterion, postprocessors = build(args)
    model.to(device)
    print_params(model)

    # Modified: load with strict=False, so checkpoint keys that do not match
    # the model are ignored instead of raising.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)

    if input_img is None:
        img = Image.open(args.image_dir).convert('RGB')
    else:
        # Normalise caller-supplied images the same way as file-loaded ones,
        # so the RGB -> BGR channel flip below is valid for any input mode.
        img = input_img.convert('RGB')

    w, h = img.size
    orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)

    transform = make_hoi_transforms('val')
    sample, _ = transform(img.copy(), None)
    sample = sample.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        out = model(sample)
        results = postprocessors['hoi'](out, orig_size, dataset=args.dataset_file, args=args)
    output_i = change_format(results[0], args.valid_ids)

    # Reverse the channel axis: PIL gives RGB, the drawing code works in BGR.
    image = np.asarray(img, dtype=np.uint8)[:, :, ::-1]
    vis_img = draw_img_vcoco(
        image,
        output_i,
        top_k=args.topk,
        threshold=args.threshold,
        color=builtin_meta.COCO_CATEGORIES,
    )
    vis_rgb = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
    plt.imshow(vis_rgb)

    if return_img:
        # PIL expects RGB; wrapping the BGR array directly swaps red and blue.
        return Image.fromarray(vis_rgb)

    # cv2.imwrite does not create missing directories and only signals
    # failure through its return value, so make sure the target exists.
    Path('./vis_res').mkdir(parents=True, exist_ok=True)
    cv2.imwrite('./vis_res/vis1.jpg', vis_img)
    return None
| # else: | |
| # frames=[] | |
| # video_file=id | |
| # video_reader = mmcv.VideoReader('./vid/'+video_file+'.mp4') | |
| # fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| # video_writer = cv2.VideoWriter( | |
| # './vid/'+video_file+'_vis.mp4', fourcc, video_reader.fps, | |
| # (video_reader.width, video_reader.height)) | |
| # orig_size = torch.as_tensor([int(video_reader.height), int(video_reader.width)]).unsqueeze(0).to(device) | |
| # transform=make_hoi_transforms('val') | |
| # for frame in mmcv.track_iter_progress(video_reader): | |
| # frame=mmcv.imread(frame) | |
| # frame=frame.copy() | |
| # frame=Image.fromarray(frame,'RGB') | |
| # sample,_=transform(frame,None) | |
| # sample=sample.unsqueeze(0).to(device) | |
| # with torch.no_grad(): | |
| # model.eval() | |
| # out=model(sample) | |
| # results = postprocessors['hoi'](out, orig_size,dataset='vcoco',args=args) | |
| # output_i=change_format(results[0],args.valid_ids) | |
| # vis_img=draw_img_vcoco(np.array(frame),output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES) | |
| # frames.append(vis_img) | |
| # video_writer.write(vis_img) | |
| # with imageio.get_writer("smiling.gif", mode="I") as writer: | |
| # for idx, frame in enumerate(frames): | |
| # # print("Adding frame to GIF file: ", idx + 1) | |
| # writer.append_data(frame) | |
| # if video_writer: | |
| # video_writer.release() | |
| # cv2.destroyAllWindows() | |
| # def visualization(id, video_vis=False, dataset_file='vcoco', path_id = 0 ,data_path='v-coco', threshold=0.4, topk=10,aug_path = '[]'): | |
| # parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) | |
| # checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth' | |
| # with open('./v-coco/data/vcoco_test.ids') as file: | |
| # test_idxs = [line.rstrip('\n') for line in file] | |
| # if not video_vis: | |
| # id = test_idxs[id] | |
| # args = parser.parse_args(args=['--dataset_file',dataset_file,'--data_path',data_path,'--resume',checkpoint_dir,'--num_hoi_queries' ,'16','--temperature' ,'0.05', '--augpath_name',aug_path ,'--path_id','{}'.format(path_id)]) | |
| # args.video_vis=video_vis | |
| # args.threshold=threshold | |
| # args.topk=topk | |
| # if args.output_dir: | |
| # Path(args.output_dir).mkdir(parents=True, exist_ok=True) | |
| # vis(args,id) | |
| # 230727 for huggingface | |
def visualization(input_img, threshold, topk):
    """Run the V-COCO HOI visualization on an in-memory image.

    Builds the default argument namespace from ``get_args_parser`` (no
    command-line arguments are read), overrides the settings needed for
    the V-COCO checkpoint and delegates to ``vis``.

    Args:
        input_img: Image forwarded to ``vis`` as ``input_img``.
        threshold: Score threshold stored on ``args.threshold``.
        topk: Number of predictions to keep; coerced with ``int``.

    Returns:
        Whatever ``vis`` returns for ``return_img=True`` (the rendered
        image). Previously the result was discarded and ``None`` was
        always returned.
    """
    parser = argparse.ArgumentParser(
        'DETR training and evaluation script', parents=[get_args_parser()]
    )
    # Parse an empty argv so only parser defaults are used.
    args = parser.parse_args(args=[])

    args.threshold = threshold
    args.topk = int(topk)
    args.resume = './checkpoints/vcoco/checkpoint.pth'
    args.dataset_file = 'vcoco'
    args.data_path = 'v-coco'
    args.num_hoi_queries = 16
    args.temperature = 0.05
    args.augpath_name = ['p2', 'p3', 'p4']

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    return vis(args, input_img=input_img, return_img=True)
if __name__ == '__main__':
    # Command-line entry point: visualize HOI predictions for one image.
    parser = argparse.ArgumentParser(
        'DETR training and evaluation script', parents=[get_args_parser()]
    )
    parser.add_argument('--threshold', help='score threshold for visualization', default=0.4, type=float)
    parser.add_argument('--topk', help='topk prediction', default=5, type=int)
    parser.add_argument('--video_vis', action='store_true')
    parser.add_argument('--image_dir', default='', type=str)
    args = parser.parse_args()

    # The checkpoint path is fixed and overrides any --resume value.
    args.resume = './checkpoints/vcoco/checkpoint.pth'

    with open('./v-coco/data/splits/vcoco_test.ids') as file:
        test_idxs = [line.rstrip('\n') for line in file]
    image_id = test_idxs[309]

    args.augpath_name = ['p2', 'p3', 'p4']

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Pass the id by keyword: the second positional parameter of ``vis`` is
    # ``input_img``, so a positional id string would be treated as the image.
    vis(args, id=image_id)