# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
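"""Fine-tune and evaluate a LaViLa video-text dual encoder.

Two downstream setups are supported (selected via --dataset):
  * 'ek100_mir':    EPIC-KITCHENS-100 multi-instance retrieval, trained with
                    a max-margin ranking loss and evaluated with mAP/nDCG.
  * 'charades_ego': CharadesEgo action classification, trained with a
                    CLIP-style contrastive loss and evaluated with mAP.
"""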
import argparse
import json
import math
import os
import sys
import time
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.cuda.amp as amp
import torch.nn.parallel
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import wandb
from torch.distributed.optim import ZeroRedundancyOptimizer

from lavila.data import datasets
from lavila.data.video_transforms import Permute
from lavila.models import loss, models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils import distributed as dist_utils
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation_ek100mir import (calculate_IDCG, calculate_k_counts,
                                              calculate_mAP, calculate_nDCG)
from lavila.utils.meter import AverageMeter, ProgressMeter
from lavila.utils.preprocess import generate_label_map
from lavila.utils.random import random_seed
from lavila.utils.scheduler import cosine_scheduler


def get_args_parser():
    parser = argparse.ArgumentParser(description='lavila finetune and evaluation', add_help=False)
    # Data
    parser.add_argument('--dataset', default='ek100_mir', type=str,
                        choices=['ek100_mir', 'charades_ego'])
    parser.add_argument('--root', default='datasets/EK100/video_ht256px/',
                        type=str, help='path to dataset root')
    parser.add_argument('--metadata',
                        default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv',
                        type=str, help='path to metadata file (train set)')
    parser.add_argument('--metadata-val',
                        default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv',
                        type=str, help='path to metadata file (val set)')
    parser.add_argument('--relevancy-path',
                        default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl',
                        type=str, help='path to relevancy matrix (val set)')
    parser.add_argument('--output-dir', default='./', type=str, help='output dir')
    parser.add_argument('--clip-length', default=16, type=int, help='clip length')
    parser.add_argument('--clip-stride', default=4, type=int, help='clip stride')
    parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
    # Model
    parser.add_argument('--pretrain-model', default='', type=str, help='path to pretrained model')
    parser.add_argument('--resume', default='', type=str, help='path to resume from')
    parser.add_argument('--find-unused-parameters', action='store_true',
                        help='set this for DDP (useful for models with tied weights)')
    parser.add_argument('--drop-path-rate', default=0.1, type=float, help='drop path ratio')
    # Training
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--warmup-epochs', default=1, type=int)
    parser.add_argument('--start-epoch', default=0, type=int)
    parser.add_argument('--batch-size', default=16, type=int,
                        help='number of samples per device/per GPU')
    parser.add_argument('--freeze-temperature', action='store_true',
                        help='freeze the logit-scale temperature if set')
    parser.add_argument('--lr', default=3e-5, type=float)
    parser.add_argument('--fix-lr', action='store_true', help='disable cosine lr decay if set')
    parser.add_argument('--lr-start', default=1e-6, type=float,
                        help='initial warmup lr')
    parser.add_argument('--lr-end', default=1e-5, type=float,
                        help='minimum final lr')
    parser.add_argument('--clip-grad-type', default='norm', choices=['norm', 'value'])
    parser.add_argument('--clip-grad-value', default=None, type=float,
                        help='gradient clipping threshold (clipping disabled if None)')
    parser.add_argument('--update-freq', default=1, type=int,
                        help='optimizer update frequency (i.e. gradient accumulation steps)')
    parser.add_argument('--wd', default=0.01, type=float)
    parser.add_argument('--betas', default=(0.9, 0.999), nargs=2, type=float)
    parser.add_argument('--eps', default=1e-8, type=float)
    parser.add_argument('--eval-freq', default=5, type=int)
    parser.add_argument('--save-freq', default=5, type=int)
    parser.add_argument('--disable-amp', action='store_true',
                        help='disable mixed-precision training (requires more memory and compute)')
    parser.add_argument('--use-zero', action='store_true',
                        help='use ZeroRedundancyOptimizer to save memory')
    parser.add_argument('--use-checkpoint', action='store_true',
                        help='use gradient checkpointing during training for significantly less GPU memory usage')
    # System
    parser.add_argument('--print-freq', default=100, type=int, help='print frequency')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers per process')
    parser.add_argument('--evaluate', action='store_true', help='eval only')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use')
    parser.add_argument('--wandb', action='store_true', help='enable WandB logging')
    return parser
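

# Example invocation (hypothetical script name, launcher, and paths; the
# argument names come from get_args_parser above):
#   torchrun --nproc_per_node=8 main_finetune_retrieval.py \
#       --dataset ek100_mir \
#       --pretrain-model /path/to/pretrained_checkpoint.pt \
#       --output-dir ./ek100_mir_ft/
# Pass --evaluate to run evaluation only.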
def main(args):
    dist_utils.init_distributed_mode(args)

    global best_acc1
    random_seed(args.seed, dist_utils.get_rank())

    if args.pretrain_model:
        ckpt_path = args.pretrain_model
    else:
        raise Exception('no checkpoint found')
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # strip the 'module.' prefix that DDP adds to parameter names
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    old_args = ckpt['args']
    print("=> creating model: {}".format(old_args.model))
    model = getattr(models, old_args.model)(
        pretrained=old_args.load_visual_pretrained,
        pretrained2d=old_args.load_visual_pretrained is not None,
        text_use_cls_token=old_args.use_cls_token,
        project_embed_dim=old_args.project_embed_dim,
        timesformer_gated_xattn=False,
        timesformer_freeze_space=False,
        num_frames=args.clip_length,
        drop_path_rate=args.drop_path_rate,
    )
    model.logit_scale.requires_grad = False
    model.cuda(args.gpu)

    if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
        # inflate the positional embeddings when the fine-tuning clip length
        # differs from the one used at pretraining time
        print('=> inflating PE in models due to different frame numbers')
        state_dict = inflate_positional_embeds(
            model.state_dict(), state_dict,
            num_frames=args.clip_length,
            load_temporal_fix='bilinear',
        )
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded pretrained checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], bucket_cap_mb=200,
            find_unused_parameters=args.find_unused_parameters
        )
    p_wd, p_non_wd = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen weights
        if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
            p_non_wd.append(p)
        else:
            p_wd.append(p)

    optim_params = [{"params": p_wd, "weight_decay": args.wd},
                    {"params": p_non_wd, "weight_decay": 0}]

    if args.use_zero:
        optimizer = ZeroRedundancyOptimizer(
            optim_params, optimizer_class=torch.optim.AdamW,
            lr=args.lr, betas=args.betas, eps=args.eps, weight_decay=args.wd
        )
    else:
        optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas,
                                      eps=args.eps, weight_decay=args.wd)
    scaler = amp.GradScaler(enabled=not args.disable_amp)
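    # With AMP enabled, train() below follows the usual GradScaler recipe:
    # scaler.scale(loss).backward(), scaler.unscale_(optimizer) only when
    # gradients must be clipped, then scaler.step(optimizer) / scaler.update().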

    # auto-resume takes precedence: if the output directory already contains a
    # checkpoint, ignore --resume and pick up the latest checkpoint instead
    latest = os.path.join(args.output_dir, 'checkpoint.pt')
    if os.path.isfile(latest):
        args.resume = ''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resume checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
            args.start_epoch = epoch
            if not args.distributed:
                # strip the 'module.' prefix saved by DDP
                state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    state_dict[k.replace('module.', '')] = v
                result = model.load_state_dict(state_dict, strict=False)
            else:
                result = model.load_state_dict(checkpoint['state_dict'], strict=False)
            print(result)
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            best_acc1 = checkpoint['best_acc1']
            print("=> loaded resume checkpoint '{}' (epoch {})"
                  .format(args.resume, epoch))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        # auto-resume from the latest checkpoint in the output directory
        latest = os.path.join(args.output_dir, 'checkpoint.pt')
        if os.path.isfile(latest):
            print("=> loading latest checkpoint '{}'".format(latest))
            latest_checkpoint = torch.load(latest, map_location='cpu')
            args.start_epoch = latest_checkpoint['epoch']
            model.load_state_dict(latest_checkpoint['state_dict'])
            optimizer.load_state_dict(latest_checkpoint['optimizer'])
            scaler.load_state_dict(latest_checkpoint['scaler'])
            best_acc1 = latest_checkpoint['best_acc1']
            print("=> loaded latest checkpoint '{}' (epoch {})"
                  .format(latest, latest_checkpoint['epoch']))

    cudnn.benchmark = True
    # Data loading code
    print("=> creating dataset")
    if old_args.model.endswith('DISTILBERT_BASE'):
        tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
    elif old_args.model.endswith('BERT_BASE'):
        tokenizer = MyBertTokenizer('bert-base-uncased')
    elif old_args.model.endswith('BERT_LARGE'):
        tokenizer = MyBertTokenizer('bert-large-uncased')
    elif old_args.model.endswith('GPT2'):
        tokenizer = MyGPT2Tokenizer('gpt2')
    elif old_args.model.endswith('GPT2_MEDIUM'):
        tokenizer = MyGPT2Tokenizer('gpt2-medium')
    elif old_args.model.endswith('GPT2_LARGE'):
        tokenizer = MyGPT2Tokenizer('gpt2-large')
    elif old_args.model.endswith('GPT2_XL'):
        tokenizer = MyGPT2Tokenizer('gpt2-xl')
    else:
        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(old_args.model))
        tokenizer = SimpleTokenizer()

    if args.dataset == 'ek100_mir':
        criterion = loss.MaxMarginRankingLoss(margin=0.2, fix_norm=True).cuda(args.gpu)
    elif args.dataset == 'charades_ego':
        criterion = loss.CLIPLoss(
            use_vissl=True,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size
        )

    crop_size = 224 if '336PX' not in old_args.model else 336
    if 'OPENAI' in old_args.model:
        # normalization stats (0-255 range) for backbones initialized from
        # OpenAI CLIP weights
        normalize = transforms_video.NormalizeVideo(
            mean=[108.3272985, 116.7460125, 104.09373615000001],
            std=[68.5005327, 66.6321579, 70.32316305])
    else:
        # ImageNet mean/std (0-255 range)
        normalize = transforms_video.NormalizeVideo(
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375])
    train_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),  # T H W C -> C T H W
        transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)),
        normalize,
    ])
    val_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),  # T H W C -> C T H W
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        normalize,
    ])
    # build dataset
    args.model = old_args.model
    args.norm_embed = old_args.norm_embed
    if args.dataset == 'ek100_mir':
        train_dataset = datasets.get_dataset(train_transform, tokenizer, args, is_training=True)
        args.metadata = args.metadata.replace('train', 'test')
        val_dataset = datasets.get_dataset(val_transform, tokenizer, args, is_training=False)
        args.metadata = args.metadata.replace('test', 'train')
    elif args.dataset == 'charades_ego':
        train_dataset = datasets.VideoCaptionDatasetCLIP(
            'charades_ego_trimmed', args.root, args.metadata,
            transform=train_transform, is_training=True, tokenizer=tokenizer,
            clip_length=args.clip_length, clip_stride=args.clip_stride
        )
        labels, mapping_vn2act = generate_label_map(args.dataset)
        val_dataset = datasets.VideoClassyDataset(
            args.dataset, args.root, args.metadata_val,
            transform=val_transform, is_training=False,
            label_mapping=mapping_vn2act, is_trimmed=False,
            num_clips=1, clip_length=args.clip_length, clip_stride=args.clip_stride,
            sparse_sample=args.sparse_sample,
        )
    else:
        raise NotImplementedError

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.SequentialSampler(val_dataset)  # disable distributed
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True
    )
    print('len(train_loader) = {}'.format(len(train_loader)))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False
    )
    print('len(val_loader) = {}'.format(len(val_loader)))

    if args.evaluate:
        if args.dataset == 'ek100_mir':
            _ = validate_mir(val_loader, model, criterion, args)
        elif args.dataset == 'charades_ego':
            _ = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args)
        return
    if args.fix_lr:
        lr_schedule = None
    else:
        lr_schedule = cosine_scheduler(
            args.lr, args.lr_end, args.epochs, len(train_loader) // args.update_freq,
            warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start,
        )
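        # lr_schedule holds one learning-rate value per optimizer step: a
        # linear warmup from --lr-start to --lr over --warmup-epochs, followed
        # by a cosine decay from --lr down to --lr-end.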

    if dist_utils.is_main_process() and args.wandb:
        wandb_id = os.path.split(args.output_dir)[-1]
        wandb.init(project='LaViLa', id=wandb_id, config=args, resume='allow')

    print(args)

    print("=> zero-shot testing")
    if args.dataset == 'ek100_mir':
        _ = validate_mir(val_loader, model, criterion, args)
    elif args.dataset == 'charades_ego':
        _ = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args)

    print("=> beginning training")
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args)

        print('=> saving checkpoint')
        is_epoch = ((epoch + 1) % args.save_freq) == 0
        dist_utils.save_on_master({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'best_acc1': 0,
            'args': args,
        }, False, args.output_dir, is_epoch=is_epoch)

        if (epoch + 1) % args.eval_freq != 0:
            continue

        if args.dataset == 'ek100_mir':
            val_stats = validate_mir(val_loader, model, criterion, args)
        elif args.dataset == 'charades_ego':
            val_stats = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in val_stats.items()},
                     'epoch': epoch}

        if dist_utils.is_main_process():
            if args.wandb:
                wandb.log(log_stats)
            with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f:
                f.write(json.dumps(log_stats) + '\n')
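

# One fine-tuning epoch: gradient accumulation every --update-freq batches,
# AMP autocast/GradScaler, optional gradient clipping, a per-step learning-rate
# schedule, and a CLIP-style clamp on the learnable logit scale.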
def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    mem = AverageMeter('Mem (GB)', ':6.1f')
    if args.dataset == 'ek100_mir':
        metric_names = ['loss', 'max_margin_loss']
    elif args.dataset == 'charades_ego':
        metric_names = models.get_metric_names(args.model)
    iters_per_epoch = len(train_loader) // args.update_freq
    metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names])
    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, *metrics.values()],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for data_iter, inputs in enumerate(train_loader):
        optim_iter = data_iter // args.update_freq

        # measure data loading time
        data_time.update(time.time() - end)

        # update the learning rate according to its schedule
        it = iters_per_epoch * epoch + optim_iter  # global training iteration
        if lr_schedule is not None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_schedule[it]

        inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs]
        # the loader always yields a relevancy matrix as the last element
        relevancies = inputs.pop()

        # compute output
        with amp.autocast(enabled=not args.disable_amp):
            outputs = model(
                *inputs,
                use_checkpoint=args.use_checkpoint,
                norm_embed=args.norm_embed
            )
            if args.dataset == 'ek100_mir':
                loss_dict = criterion(outputs, weight=relevancies)
            elif args.dataset == 'charades_ego':
                loss_dict = criterion(outputs)
            loss = loss_dict['loss']
            loss /= args.update_freq

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        scaler.scale(loss).backward()

        if (data_iter + 1) % args.update_freq != 0:
            continue  # keep accumulating gradients

        if args.clip_grad_value is not None:
            scaler.unscale_(optimizer)
            if args.clip_grad_type == 'norm':
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad_value, norm_type=2.
                )
            elif args.clip_grad_type == 'value':
                torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_value)
            else:
                assert False, f"Unknown clip mode ({args.clip_grad_type})."
        # step the optimizer and reset the gradients
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad(set_to_none=True)

        if hasattr(dist_utils.get_model(model), 'logit_scale'):
            # clamp so that exp(logit_scale) stays in [1, 100] (4.6052 ~= ln(100)), as in CLIP
            dist_utils.get_model(model).logit_scale.data.clamp_(0, 4.6052)
            logit_scale = dist_utils.get_model(model).logit_scale.exp().item()
        else:
            logit_scale = torch.nan

        for k in loss_dict:
            metrics[k].update(loss_dict[k].item(), args.batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        mem.update(torch.cuda.max_memory_allocated() // 1e9)

        if optim_iter % args.print_freq == 0:
            if dist_utils.is_main_process() and args.wandb:
                wandb.log({**{k: v.item() for k, v in loss_dict.items()},
                           'scaler': scaler.get_scale(), 'logit': logit_scale})
            progress.display(optim_iter)
    progress.synchronize()
    return {**{k: v.avg for k, v in metrics.items()},
            'lr': optimizer.param_groups[0]['lr'],
            'logit_scale': logit_scale}
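

# EPIC-KITCHENS-100 multi-instance retrieval evaluation: embed every clip and
# caption, build a video-by-text similarity matrix, and score it against the
# annotated relevancy matrix with mAP and nDCG in both directions.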
def validate_mir(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    mem = AverageMeter('Mem (GB)', ':6.1f')
    metric_names = ['loss', 'max_margin_loss']
    iters_per_epoch = len(val_loader) // args.update_freq
    metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names])
    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, *metrics.values()],
        prefix="Test: "
    )

    # switch to eval mode
    model.eval()

    all_video_embed = []
    all_text_embed = []
    with torch.no_grad():
        end = time.time()
        for i, inputs in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs]
            relevancies = inputs.pop()

            # compute output
            outputs = model(
                *inputs,
                use_checkpoint=args.use_checkpoint,
                norm_embed=args.norm_embed
            )
            loss_dict = criterion(outputs, weight=relevancies)

            for k in loss_dict:
                metrics[k].update(loss_dict[k].item(), args.batch_size)

            image_features = outputs['image_embed']
            text_features = outputs['text_embed']
            all_video_embed.append(image_features.cpu().numpy())
            all_text_embed.append(text_features.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            mem.update(torch.cuda.max_memory_allocated() // 1e9)

            if i % args.print_freq == 0:
                if dist_utils.is_main_process() and args.wandb:
                    wandb.log({**{k: v.item() for k, v in loss_dict.items()}})
                progress.display(i)
    progress.synchronize()

    all_text_embed = np.vstack(all_text_embed)
    all_video_embed = np.vstack(all_video_embed)
    similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
    # map cosine similarities from [-1, 1] to [0, 1]
    similarity_matrix = (similarity_matrix + 1) / 2
    # the text embeddings above are one per video segment; keep one column per
    # entry of the test-sentence CSV so the columns line up with the
    # relevancy matrix
    video_id = pd.read_csv(args.metadata.replace('train', 'test')).values[:, 0]
    text_id = pd.read_csv(args.metadata.replace('train', 'test_sentence')).values[:, 0]
    indexes = [video_id.tolist().index(elem) for elem in text_id]
    similarity_matrix = similarity_matrix[:, indexes]
    print(similarity_matrix.shape)
    rel_matrix = pd.read_pickle(args.relevancy_path)
    vis_map = calculate_mAP(similarity_matrix, rel_matrix)
    txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
    print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, (vis_map + txt_map) / 2))
    vis_k_counts = calculate_k_counts(rel_matrix)
    txt_k_counts = calculate_k_counts(rel_matrix.T)
    vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
    txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
    vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
    txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
    print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2))
    return {**{k: v.avg for k, v in metrics.items()}}
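

# CharadesEgo classification: build one embedding per class by averaging the
# normalized embeddings of its templated captions (prompt ensembling), score
# each clip against all classes by cosine similarity, and report Charades mAP.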
def validate_cls(val_loader, templates, labels, model, tokenizer, args):
    # switch to eval mode
    model.eval()

    all_outputs = []
    all_targets = []
    with torch.no_grad():
        text_features = []
        for label in labels:
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in templates]
            texts = tokenizer(texts)
            if isinstance(texts, tuple):
                # BERT-style tokenizers output both token ids and an attention mask
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            # flatten to (num_prompts, context_length); 77 is the CLIP context length
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings = dist_utils.get_model(model).encode_text(texts, attention_mask=masks)
            else:
                class_embeddings = dist_utils.get_model(model).encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # average over prompts, then re-normalize
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)

        print('=> start forwarding')
        end_time = time.time()
        for i, (images, target) in enumerate(val_loader):
            if i % args.print_freq == 0:
                print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
                end_time = time.time()
            if isinstance(images, torch.Tensor):
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

                # encode images
                image_features = dist_utils.get_model(model).encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                # cosine similarity as logits
                logits_per_image = image_features @ text_features.t()
                logits_per_image = torch.softmax(logits_per_image, dim=1)
            else:
                # multiple clips per video: average the logits over clips
                target = target.cuda(non_blocking=True)
                images_list = images
                logits_all_clips = []
                for images in images_list:
                    images = images.cuda(non_blocking=True)
                    image_features = dist_utils.get_model(model).encode_image(images)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    logits_per_image = image_features @ text_features.t()
                    logits_all_clips.append(logits_per_image)

                logits_all_clips = torch.stack(logits_all_clips, dim=0)
                # logits_per_image = logits_all_clips.max(0).values
                logits_per_image = logits_all_clips.mean(0)
                logits_per_image = torch.softmax(logits_per_image, dim=1)

            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())
    all_outputs = torch.cat(all_outputs)
    all_targets = torch.cat(all_targets)
    preds, targets = all_outputs.numpy(), all_targets.numpy()
    m_ap, _, _ = charades_map(preds, targets)
    print('mAP = {:.3f}'.format(m_ap))
    return {'mAP': m_ap}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('lavila finetune and evaluation', parents=[get_args_parser()])
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)