import functools
import itertools
import json
import logging
import multiprocessing as mp
from argparse import ArgumentParser
from multiprocessing import Pool

import numpy as np
import torch
import torchvision
import transformers
from decord import VideoReader, cpu
from PIL import Image
from tqdm import tqdm

from tasks.eval.model_utils import load_pllava, pllava_answer
from tasks.eval.eval_utils import conv_templates
from tasks.eval.recaption import (
    RecaptionDataset,
    load_results,
    save_results,
)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

IMAGE_TOKEN = '<image>'
RESOLUTION = 672
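# Sampled frames are resized to 672px before inference; my reading (an assumption,
# not stated in this file) is that this is 2x the 336px input of LLaVA's CLIP-ViT tower.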
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        default='llava-hf/llava-1.5-7b-hf',
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        default='./test_results/test_llava_mvbench',
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        required=True,
        default=4,
    )
    parser.add_argument(
        "--use_lora",
        action='store_true',
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        required=False,
        default=32,
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--eval_model",
        type=str,
        required=False,
        default="gpt-3.5-turbo-0125",
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        required=False,
        default='eval_videoqabench',
    )
    args = parser.parse_args()
    return args
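
# Example invocation (a sketch; the script path and weight directory are illustrative):
#   python tasks/eval/recaption/pllava_recaption.py \
#       --pretrained_model_name_or_path llava-hf/llava-1.5-7b-hf \
#       --save_path ./test_results/test_pllava_recaption \
#       --num_frames 16 \
#       --use_lora --lora_alpha 32 --weight_dir MODELS/pllava-7b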
def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path,
                           num_frames, use_lora, lora_alpha, weight_dir, test_ratio):
    # Reminder: once the model gets larger (30B+), it can use up GPU memory heavily
    # and may even take down nodes.
    model, processor = load_pllava(
        pretrained_model_name_or_path,
        num_frames=num_frames,
        use_lora=use_lora,
        lora_alpha=lora_alpha,
        weight_dir=weight_dir,
    )
    logger.info('done loading llava')

    # position embedding
    model = model.to(torch.device(rank))
    model = model.eval()

    dataset = RecaptionDataset(test_ratio=test_ratio, num_segments=num_frames)
    dataset.set_rank_and_world_size(rank, world_size)
    return model, processor, dataset
def infer_recaption(
        model,
        processor,
        data_sample,
        conv_mode,
        pre_query_prompt=None,   # added at the head of the question
        post_query_prompt=None,  # added at the end of the question
        answer_prompt=None,      # added at the beginning of the answer
        return_prompt=None,      # added at the beginning of the returned message
        print_res=False,
):
    video_list = data_sample["video_pils"]
    conv = conv_templates[conv_mode].copy()
    # info = data_sample['info']
    query = (
        "You are to assist me in accomplishing a task about the input video. Reply to me with a precise yet detailed response. For how you would succeed in the recaptioning task, read the following Instructions section, then make your response an elaborate paragraph.\n"
        "# Instructions\n"
        "1. Avoid providing overly detailed information such as color or counts of any objects, as you are terrible at observing these details.\n"
        "2. Instead, carefully go over the provided video and reason about the key information of the overall video.\n"
        "3. If you are not sure about something, do not include it in your response.\n"
        "# Task\n"
        "Describe the background, characters and the actions in the provided video.\n"
    )
    conv.user_query(query, pre_query_prompt, post_query_prompt, is_mm=True)
    if answer_prompt is not None:
        conv.assistant_response(answer_prompt)

    llm_message, conv = pllava_answer(
        conv=conv,
        model=model,
        processor=processor,
        img_list=video_list,
        max_new_tokens=400,
        num_beams=1,
        do_sample=False,
        print_res=print_res,
    )

    if answer_prompt is not None:
        # drop the seeded answer prefix from the generated text
        llm_message = ''.join(llm_message.split(answer_prompt)[1:])
    if return_prompt is not None:
        llm_message = return_prompt + llm_message
    return llm_message, query
def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
    def get_index(num_frames, num_segments):
        # uniformly sample num_segments frame indices from [0, num_frames)
        seg_size = float(num_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets
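    # Worked example (not from the source): num_frames=100, num_segments=4 gives
    # seg_size=24.75, start=12, offsets=[12, 37, 62, 86], i.e. roughly the
    # midpoints of four equal segments.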
    def load_video(video_path, num_segments=8, return_msg=False, resolution=336):
        transforms = torchvision.transforms.Resize(size=resolution)
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_num_frames = len(vr)
        frame_indices = get_index(total_num_frames, num_segments)
        images_group = list()
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy())
            images_group.append(transforms(img))
        if return_msg:
            fps = float(vr.get_avg_fps())
            sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
            # a " " should be added at the start and the end of msg
            msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
            return images_group, msg
        else:
            return images_group
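    # With return_msg=True, msg reads like (timestamps illustrative):
    # "The video contains 4 frames sampled at 0.5, 2.1, 3.8, 5.4 seconds."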
    if num_frames != 0:
        vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
    else:
        vid, msg = None, 'num_frames is 0, not inputting video frames'
    img_list = vid
    conv = conv_templates[conv_mode].copy()
    conv.user_query("Describe the video in detail.", is_mm=True)
    llm_response, conv = pllava_answer(
        conv=conv, model=model, processor=processor, do_sample=False,
        img_list=img_list, max_new_tokens=256, print_res=True,
    )
def run(rank, args, world_size):
    if rank != 0:
        transformers.utils.logging.set_verbosity_error()
        logger.setLevel(transformers.logging.ERROR)

    print_res = True
    conv_mode = args.conv_mode
    pre_query_prompt = None
    post_query_prompt = None
    # pre_query_prompt = ("""Assist me in detailing the background, characters, and actions depicted in the provided video.\n""")
    # post_query_prompt = ("""My apologies for any lack of precision; there may be errors in the supplementary information provided.\n"""
    #                      """You are encouraged to be discerning and perceptive, paying attention to the minutest details, """
    #                      """and to furnish a detailed yet precise description using eloquent language.""")

    logger.info(f'loading model and constructing dataset on gpu {rank}...')
    model, processor, dataset = load_model_and_dataset(
        rank,
        world_size,
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        num_frames=args.num_frames,
        use_lora=args.use_lora,
        lora_alpha=args.lora_alpha,
        weight_dir=args.weight_dir,
        test_ratio=args.test_ratio,
    )
    logger.info('done loading model and constructing dataset')

    vid_path = "./example/yoga.mp4"
    # vid_path = "./example/jesse_dance.mp4"
    if rank == 0:
        logger.info('single test...')
        single_test(model, processor, vid_path, num_frames=args.num_frames)
        logger.info('single test done...')
        tbar = tqdm(total=len(dataset))

    result_list = []
    done_count = 0
    for example in dataset:
        task_type = example['task_type']
        if task_type in dataset.data_list_info:
            pred, query = infer_recaption(
                model,
                processor,
                example,
                conv_mode=conv_mode,
                pre_query_prompt=pre_query_prompt,
                post_query_prompt=post_query_prompt,
                print_res=print_res,
            )
            # keep only scalar (JSON-serializable) fields from the sample metadata
            infos = {k: v for k, v in example['sample'].items() if isinstance(v, (str, float, int))}
            res = {
                'pred': pred,
                'task_type': task_type,
                'video_path': example['video_path'],
                'query': query,
                **infos,
            }
        else:
            raise NotImplementedError(f'not implemented task type {task_type}')
        # res = chatgpt_eval(res)

        result_list.append(res)
        if rank == 0:
            tbar.update(len(result_list) - done_count)
            tbar.set_description_str(
                f"One Chunk--Task Type: {task_type}-"
                f"pred: {pred[:15]}......"
            )
            done_count = len(result_list)
    return result_list
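
# One pool worker per visible GPU: `rank` doubles as the CUDA device index passed to
# load_model_and_dataset, and the per-rank result lists are chained back together in main().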
def main():
    multiprocess = True
    mp.set_start_method('spawn')
    args = parse_args()
    save_path = args.save_path
    eval_model = args.eval_model
    logger.info(f'trying to load results from {save_path}')
    result_list = load_results(save_path, model=args.eval_model)

    if result_list is None:
        if multiprocess:
            logger.info(f'started benchmarking, saving to: {save_path}')
            n_gpus = torch.cuda.device_count()
            # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
            world_size = n_gpus
            with Pool(world_size) as pool:
                func = functools.partial(run, args=args, world_size=world_size)
                # func = functools.partial(run, world_size=world_size, model=model, dataset=dataset, result_list=[], acc_dict={})
                result_lists = pool.map(func, range(world_size))
            logger.info('finished running')
            result_list = [res for res in itertools.chain(*result_lists)]
        else:
            result_list = run(0, world_size=1, args=args)  # debug
    else:
        logger.info(f'loaded results from {save_path}')

    save_results(result_list, save_path, model=eval_model)


if __name__ == "__main__":
    main()