| """Feature extraction utilities for video and text processing.""" | |
| import os | |
| import numpy as np | |
| import torch | |
| import av | |
| from PIL import Image | |
| from einops import rearrange | |
| from typing import Any, Dict, List, Union, Tuple | |
| from loguru import logger | |
| from .config_utils import AttributeDict | |
| from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS | |


class FeatureExtractionError(Exception):
    """Exception raised for feature extraction errors."""


def get_frames_av(
    video_path: str,
    fps: float,
    max_length: Optional[float] = None,
) -> Tuple[np.ndarray, float]:
    """Decode a video with PyAV and resample it to `fps`.

    Returns a uint8 array of shape [T, H, W, 3] together with the covered
    duration in seconds. If `max_length` is given, decoding stops after that
    many seconds; otherwise a default cap of 15 s is used.
    """
    end_sec = max_length if max_length is not None else 15
    next_frame_time_for_each_fps = 0.0
    time_delta_for_each_fps = 1 / fps
    output_frames = []
    reached_end = False
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < 0:
                    continue
                if frame_time > end_sec:
                    # Stop demuxing entirely once the cap is reached; a bare
                    # `break` here would only exit the per-packet loop.
                    reached_end = True
                    break
                # Emit the decoded frame once per elapsed sampling tick; this
                # duplicates frames when the source fps is below the target
                # fps and skips frames when it is above.
                frame_np = None
                this_time = frame_time
                while this_time >= next_frame_time_for_each_fps:
                    if frame_np is None:
                        frame_np = frame.to_ndarray(format="rgb24")
                    output_frames.append(frame_np)
                    next_frame_time_for_each_fps += time_delta_for_each_fps
            if reached_end:
                break
    output_frames = np.stack(output_frames)
    vid_len_in_s = len(output_frames) / fps
    if max_length is not None and len(output_frames) > int(max_length * fps):
        output_frames = output_frames[: int(max_length * fps)]
        vid_len_in_s = max_length
    return output_frames, vid_len_in_s
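
# Illustrative usage of get_frames_av (the path is a hypothetical placeholder):
#
#   frames, dur = get_frames_av("clip.mp4", fps=8.0, max_length=10.0)
#   # frames: uint8 ndarray of shape [<= 80, H, W, 3]; dur: duration in seconds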


def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Encode frames with SigLIP2; x is [B, T, C, H, W], returns [B, T, D]."""
    b, t, c, h, w = x.shape
    if batch_size < 0:
        batch_size = b * t
    # Flatten the batch and time axes, encode in chunks, then restore them.
    x = rearrange(x, "b t c h w -> (b t) c h w")
    outputs = []
    for i in range(0, b * t, batch_size):
        outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size]))
    res = torch.cat(outputs, dim=0)
    res = rearrange(res, "(b t) d -> b t d", b=b)
    return res
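
# Illustrative batching note for encode_video_with_siglip2: with a hypothetical
# input of shape [1, 80, 3, 224, 224] and batch_size=32, the loop above runs
# ceil(80 / 32) = 3 forward passes over the flattened frames, and the result is
# rearranged back to [1, 80, D], where D is the SigLIP2 image-embedding width.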


def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Encode frames with Synchformer.

    The input video x should have an fps of 24 or greater.

    Input:
        x: tensor of shape [B, T, C, H, W]
        batch_size: the batch size for Synchformer inference
    """
    b, t, c, h, w = x.shape
    assert c == 3 and h == 224 and w == 224
    # Slice the clip into overlapping 16-frame segments with a stride of 8.
    segment_size = 16
    step_size = 8
    num_segments = (t - segment_size) // step_size + 1
    segments = []
    for i in range(num_segments):
        segments.append(x[:, i * step_size : i * step_size + segment_size])
    x = torch.stack(segments, dim=1).cuda()  # (B, num_segments, segment_size, 3, 224, 224)
    outputs = []
    if batch_size < 0:
        batch_size = b * num_segments
    x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
    for i in range(0, b * num_segments, batch_size):
        # Run Synchformer in half precision under autocast to save memory.
        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
            outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
    x = torch.cat(outputs, dim=0)  # [b * num_segments, 1, 8, 768]
    x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
    return x
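
# Shape walkthrough for encode_video_with_sync (assuming the Synchformer
# stream is sampled at 25 fps): a 10 s clip gives t = 250 frames, so
# num_segments = (250 - 16) // 8 + 1 = 30; each 16-frame segment yields 8
# tokens of width 768, hence an output of [B, 30 * 8, 768] = [B, 240, 768],
# matching the example shape noted in encode_video_features below.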


def encode_video_features(video_path, model_dict):
    """Extract SigLIP2 and Synchformer visual features for a video."""
    visual_features = {}

    # SigLIP2 visual features
    frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
    images = [Image.fromarray(frame).convert("RGB") for frame in frames]
    images = [model_dict.siglip2_preprocess(image) for image in images]  # [T, C, H, W]
    clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
    visual_features["siglip2_feat"] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)

    # Synchformer visual features
    frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
    images = torch.from_numpy(frames).permute(0, 3, 1, 2)  # [T, C, H, W]
    sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0)  # [1, T, 3, 224, 224]
    # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for a 10 s video
    visual_features["syncformer_feat"] = encode_video_with_sync(sync_frames, model_dict)
    vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]

    visual_features = AttributeDict(visual_features)
    return visual_features, vid_len_in_s


def encode_text_feat(text: List[str], model_dict):
    """Encode a batch of prompts with the CLAP text encoder."""
    # inputs: token ids and attention mask of shape (B, L)
    inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
    # output_attentions=True is required for outputs.attentions to be
    # populated; without it the second return value would be None.
    outputs = model_dict.clap_model(**inputs, output_attentions=True, output_hidden_states=True, return_dict=True)
    return outputs.last_hidden_state, outputs.attentions
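
# Illustrative call (the positive prompt is made up): encoding
# ["noisy, harsh", "rain on a tin roof"] returns hidden states of shape
# [2, L, D] for padded token length L; feature_process below takes row 0 as
# the negative prompt and row 1 as the positive prompt.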


def feature_process(video_path, prompt, model_dict, cfg):
    """Build the visual and text conditioning features for one video/prompt pair."""
    visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)

    # Encode the negative (unconditional) and positive prompts in one batch.
    neg_prompt = "noisy, harsh"
    prompts = [neg_prompt, prompt]
    text_feat_res, text_feat_attn = encode_text_feat(prompts, model_dict)
    text_feat = text_feat_res[1:]  # positive prompt
    uncond_text_feat = text_feat_res[:1]  # negative prompt

    # Truncate to the model's maximum text sequence length if necessary.
    if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]:
        text_seq_length = cfg.model_config.model_kwargs.text_length
        text_feat = text_feat[:, :text_seq_length]
        uncond_text_feat = uncond_text_feat[:, :text_seq_length]

    text_feats = AttributeDict({
        "text_feat": text_feat,
        "uncond_text_feat": uncond_text_feat,
    })
    return visual_feats, text_feats, audio_len_in_s
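

if __name__ == "__main__":
    # Minimal smoke test for the decode/resample path only; it needs no model
    # weights. "sample.mp4" is a hypothetical placeholder, not a file shipped
    # with this module.
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "sample.mp4"
    frames, dur = get_frames_av(path, fps=8.0, max_length=4.0)
    print(f"decoded {frames.shape[0]} frames covering {dur:.2f}s, "
          f"frame size {frames.shape[1:3]}")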