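# LaViLa narration demo: loads a VCLM checkpoint (TimeSformer visual encoder
# + GPT-2 text decoder) trained on Ego4D, slides a fixed-size clip window over
# an uploaded video, and prints candidate narrations for each clip via Gradio.
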
import sys
sys.path.insert(0, './')  # make the local `lavila` package importable

import os
from collections import OrderedDict

import decord
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video

from lavila.data.video_transforms import Permute
# Imported for completeness; both helpers are shadowed by the local
# re-definitions below.
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer

def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Sample one frame id per segment, evenly dividing [start_frame, end_frame)."""
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            # Pick a random frame inside the segment (training-style sampling).
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            # Deterministic: take the segment midpoint.
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq

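# A hand-computed sketch of the deterministic path: get_frame_ids(0, 64,
# num_segments=4, jitter=False) divides [0, 64) into four ~15.75-frame
# segments and returns their midpoints, roughly [8, 24, 39, 55].
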
def video_loader_by_frames(root, vid, frame_ids):
    """Load the requested frames as a (T, H, W, C) float32 tensor (0-255 range)."""
    vr = decord.VideoReader(os.path.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        # Fall back to black 240x320 frames so the pipeline keeps running.
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)

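# Minimal usage sketch (hypothetical file name):
#   frames = video_loader_by_frames('./', 'example.mp4', [0, 16, 32, 48])
#   frames.shape  # -> torch.Size([4, H, W, 3])
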
def iter_clips(video_path, num_segments=4, stride_size=16):
    # Each clip is represented by `num_segments` evenly spaced frames drawn
    # from a window of `num_segments * stride_size` consecutive source frames.
    vr = decord.VideoReader(video_path)
    frame_sample_size = num_segments * stride_size
    max_start_frame = len(vr) - frame_sample_size
    curr_frame = 0
    fps = vr.get_avg_fps()
    while curr_frame < max_start_frame:
        stop_frame = min(curr_frame + frame_sample_size, len(vr))
        curr_sec, stop_sec = curr_frame / fps, stop_frame / fps
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_sec, stop_sec, frames
        curr_frame += frame_sample_size

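# With the defaults, each yielded window spans 4 * 16 = 64 frames (about 2.1 s
# of 30-fps video); a trailing window shorter than 64 frames is not yielded.
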
class Pipeline:
    def __init__(self, path=""):
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Strip the 'module.' prefix left by DistributedDataParallel training.
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
        crop_size = 224
        self.val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # (T, H, W, C) -> (C, T, H, W)
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305])
        ])
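        # Note: these normalization statistics are on the raw 0-255 pixel
        # scale, matching the float frames from video_loader_by_frames
        # (nothing in this pipeline divides by 255).
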
    def decode_one(self, generated_ids, tokenizer):
        # Find the index of <EOS>; position 0 holds <BOS> and is skipped.
        if tokenizer.eos_token_id == tokenizer.bos_token_id:
            if tokenizer.eos_token_id in generated_ids[1:].tolist():
                eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1
            else:
                eos_id = len(generated_ids.tolist()) - 1
        elif tokenizer.eos_token_id in generated_ids.tolist():
            eos_id = generated_ids.tolist().index(tokenizer.eos_token_id)
        else:
            eos_id = len(generated_ids.tolist()) - 1
        generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist())
        return generated_text_str
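    # The first branch above exists because GPT-2 reuses a single token
    # (<|endoftext|>) as both BOS and EOS, so the leading BOS must be skipped
    # before searching for the terminator.
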
    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
        text = ""
        with torch.autocast(self.device):
            for start, stop, frames in iter_clips(video_path):
                text_to_add = f"{'-' * 30} Predictions From: {start:2.3f}-{stop:2.3f} seconds {'-' * 30}\n"
                print(text_to_add)
                text += text_to_add
                frames = self.val_transform(frames).unsqueeze(0)  # add batch dim
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()
                with torch.no_grad():
                    image_features = self.model.encode_image(frames)
                    generated_text_ids, ppls = self.model.generate(
                        image_features,
                        self.tokenizer,
                        target=None,  # free-form generation
                        max_text_length=max_text_length,
                        top_k=None,
                        top_p=top_p,  # nucleus sampling
                        num_return_sequences=num_return_sequences,  # number of candidates
                        temperature=temperature,
                        early_stopping=True,
                    )
                for i in range(num_return_sequences):
                    generated_text_str = self.decode_one(generated_text_ids[i], self.tokenizer)
                    text_to_add = '\t{}: {}\n'.format(i, generated_text_str)
                    print(text_to_add)
                    text += text_to_add
        return text

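# The pipeline can also be run without Gradio; a minimal sketch with a
# hypothetical local file:
#   pipe = Pipeline()
#   print(pipe('example.mp4', temperature=0.7, top_p=0.95))
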
interface = gr.Interface(
    Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(0.0, 1.0, 0.7, label='temperature'),
        gr.Slider(0.0, 1.0, 0.95, label='top_p'),
    ],
    outputs='text'
)

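# Gradio passes the uploaded video's temporary file path plus the two slider
# values to Pipeline.__call__; max_text_length and num_return_sequences keep
# their defaults since no widget is exposed for them.
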
if __name__ == '__main__':
    interface.launch(debug=True)