Spaces:
Runtime error
Runtime error
| import os | |
| import PIL.Image | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from torchvision.transforms import Resize, InterpolationMode | |
| import imageio | |
| from einops import rearrange | |
| import cv2 | |
| from PIL import Image | |
| import decord | |
| from controlnet_aux import OpenposeDetector | |
# Pose estimator used by pre_process_pose. It is built by a module-level
# statement, so the checkpoint is loaded once, when this module is imported.
_OPENPOSE_CHECKPOINT = "lllyasviel/ControlNet"
apply_openpose = OpenposeDetector.from_pretrained(_OPENPOSE_CHECKPOINT)
def prepare_video(video_path: str, resolution: int, normalize=True, start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    """Load a clip from a video file, resample it in time and resize it.

    Frames in the window ``[start_t, end_t)`` (seconds) are sampled uniformly
    at ``output_fps``. They are then resized so the longer side is roughly
    ``resolution`` pixels, with both sides snapped to a multiple of 64.

    Args:
        video_path: Path to a video readable by ``decord.VideoReader``.
        resolution: Target length in pixels of the longer side, before
            snapping to a multiple of 64.
        normalize: If True, map pixel values from [0, 255] to [-1, 1].
        start_t: Window start in seconds (inclusive).
        end_t: Window end in seconds (exclusive). -1 means the end of the
            video; larger values are clamped to the video duration.
        output_fps: Sampling rate of the returned clip. -1 uses the source's
            average fps, truncated to an int.

    Returns:
        ``(video, output_fps)``. ``video`` is a float tensor laid out as
        ``(frames, channels, height, width)``.

    Raises:
        AssertionError: If the time window is empty or invalid, or if
            ``output_fps`` is not positive.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    duration = len(vr) / initial_fps
    end_t = duration if end_t == -1 else min(duration, end_t)
    assert 0 <= start_t < end_t, f"invalid time window [{start_t}, {end_t})"
    assert output_fps > 0, f"output_fps must be positive, got {output_fps}"

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    # Always sample at least one frame. A window shorter than one output frame
    # would otherwise give an empty index list, and get_batch would fail on it.
    num_f = max(1, int((end_t - start_t) * output_fps))
    # endpoint=False keeps every index strictly below end_f_ind (<= len(vr)).
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)

    video = vr.get_batch(sample_idx)
    # get_batch returns a torch tensor or a decord NDArray, depending on the
    # configured decord bridge.
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()
    _, h, w, _ = video.shape
    video = torch.Tensor(rearrange(video, "f h w c -> f c h w"))

    # Scale so that the LONGER side is ~resolution.
    # To scale the shorter side instead, divide by min(h, w).
    k = float(resolution) / max(h, w)
    # Snap both sides to multiples of 64. This is presumably required by the
    # downstream model (TODO confirm). Never go below 64: very elongated frames
    # or small resolutions would otherwise round a side to 0 and Resize would fail.
    h = max(64, int(np.round(h * k / 64.0)) * 64)
    w = max(64, int(np.round(w * k / 64.0)) * 64)
    video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Convert each frame into a pose map, or pass frames through unchanged.

    Args:
        input_video: Iterable of frames laid out as ``(c, h, w)`` torch
            tensors. Values are cast directly to uint8 with no rescaling, so
            pixels are assumed to already be in [0, 255]
            (NOTE(review): normalized [-1, 1] input would be garbled).
        apply_pose_detect: If True, run ``apply_openpose`` on each frame.
            If False, use the frame itself as the map.

    Returns:
        A list of ``PIL.Image`` maps, each resized to its frame's height and width.
    """
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        if apply_pose_detect:
            result = apply_openpose(img)
            # Earlier code unpacked a (map, extra) tuple here. controlnet_aux's
            # OpenposeDetector returns a single image instead (PIL by default),
            # and unpacking that raised an error. Accept both forms, and convert
            # to an ndarray so cv2.resize can consume it.
            if isinstance(result, tuple):
                result = result[0]
            detected_map = np.array(result)
        else:
            detected_map = img
        H, W = img.shape[:2]
        # The detector may return its map at its own working resolution, so
        # resize back to the frame size. INTER_NEAREST does not blend pixel values.
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        detected_maps.append(Image.fromarray(detected_map))
    return detected_maps
def create_gif(frames, fps, rescale=False, path=None, watermark=None):
    """Write ``frames`` to an animated GIF and return the output path.

    Args:
        frames: Iterable of per-frame arrays accepted by ``torch.Tensor``, with
            values in [0, 1] (or [-1, 1] when ``rescale`` is True). Each frame
            goes through ``torchvision.utils.make_grid(nrow=4)``, so a 4-D batch
            is tiled into one image. Results are handed straight to imageio, so
            frames are presumably ``(h, w, c)``; TODO confirm with callers.
        fps: Playback rate. Each frame's duration is passed as ``1000 / fps``.
        rescale: If True, map values from [-1, 1] to [0, 1] before quantising.
        path: Output file. Defaults to ``temporal/canny_db.gif``.
        watermark: Not used by this function; kept for API compatibility.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        path = os.path.join("temporal", 'canny_db.gif')
    # Create the parent directory for caller-supplied paths too, not only the default.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    outputs = []
    for frame in frames:
        grid = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before quantising. Out-of-range floats have no well-defined
        # uint8 conversion and typically wrap around (dark pixels turn bright).
        grid = grid.clamp(0.0, 1.0)
        outputs.append((grid * 255).numpy().astype(np.uint8))
    # NOTE(review): whether imageio reads `duration` as milliseconds or seconds
    # depends on the installed imageio version/plugin. Confirm 1000/fps is right.
    imageio.mimsave(path, outputs, loop=0, duration=1000/fps)
    return path