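# Inference script: generates videos with a CogVideoX-5B transformer plus a ControlNet
# branch, conditioned on source-video latents and (for i2v) the latents of a first frame.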
import os
import argparse

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
from decord import VideoReader
from omegaconf import OmegaConf
from safetensors.torch import load_file

import transformers
from transformers import T5EncoderModel, T5Tokenizer
from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer

from diffusers import AutoencoderKLCogVideoX
from diffusers.schedulers import CogVideoXDDIMScheduler
from diffusers.utils import export_to_video

from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
from pipeline_cogvideox_controlnet_5b_i2v_instruction2 import ControlCogVideoXPipeline
from dataset_demo_videos import VideoDataset

def unwarp_model(state_dict):
    # Strip the "module." prefix that DataParallel/DistributedDataParallel adds to parameter names.
    new_state_dict = {}
    for key in state_dict:
        new_state_dict[key.split('module.')[1]] = state_dict[key]
    return new_state_dict
| """ | |
| def transform_tensor_to_images(images): | |
| images = images.cpu().detach().numpy() | |
| images = np.uint8(images) | |
| images2 = [] | |
| for image in images: | |
| image = Image.fromarray(image) | |
| images2.append(image) | |
| return images2 | |
| """ | |

parser = argparse.ArgumentParser()
parser.add_argument("--pos_prompt", type=str, default="")
parser.add_argument("--neg_prompt", type=str, default="")
parser.add_argument("--training_steps", type=int, default=30001)
parser.add_argument("--root_path", type=str, default="./models_half")
parser.add_argument("--i2v", action="store_true", default=True)
parser.add_argument("--guidance_scale", type=float, default=4.0)
parser.add_argument("--random_seed", type=int, default=0)
args = parser.parse_args()
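# Example invocation (script name assumed):
#   python inference.py --root_path ./models_half --guidance_scale 4.0 --random_seed 0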

# -----------------------------------------------------------------
prefix = args.root_path.replace("/", "_").replace(".", "_") + "_" + args.pos_prompt.replace(" ", "_").replace(".", "_")
# -----------------------------------------------------------------
if args.i2v:
    key = "i2v"
else:
    key = "t2v"
noise_scheduler = CogVideoXDDIMScheduler(
    **OmegaConf.to_container(
        OmegaConf.load(f"./cogvideox-5b-{key}/scheduler/scheduler_config.json")
    )
)
text_encoder = T5EncoderModel.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)  # .to("cuda:0")
vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16).to("cuda:0")
tokenizer = T5Tokenizer.from_pretrained(f"./cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)

config = OmegaConf.to_container(
    OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
)
if args.i2v:
    # i2v doubles the input channels: noise latents are concatenated with first-frame image latents.
    config["in_channels"] = 32
else:
    config["in_channels"] = 16
transformer = CogVideoXTransformer3DModel(**config)

control_config = OmegaConf.to_container(
    OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
)
if args.i2v:
    control_config["in_channels"] = 32
else:
    control_config["in_channels"] = 16
control_config['num_layers'] = 6
control_config['control_in_channels'] = 16
controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)
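
# Load the fine-tuned weights for both branches and move them to the GPU in fp16.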
all_state_dicts = torch.load(f"{args.root_path}/ff_controlnet_half.pth", map_location="cpu", weights_only=True)
transformer_state_dict = unwarp_model(all_state_dicts["transformer_state_dict"])
controlnet_transformer_state_dict = unwarp_model(all_state_dicts["controlnet_transformer_state_dict"])
transformer.load_state_dict(transformer_state_dict, strict=True)
controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
transformer = transformer.half().to("cuda:0")
controlnet_transformer = controlnet_transformer.half().to("cuda:0")

vae = vae.eval()
text_encoder = text_encoder.eval()
transformer = transformer.eval()
controlnet_transformer = controlnet_transformer.eval()
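
# Assemble the ControlNet pipeline; VAE slicing/tiling and model CPU offload reduce peak GPU memory.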
pipe = ControlCogVideoXPipeline(tokenizer,
                                text_encoder,
                                vae,
                                transformer,
                                noise_scheduler,
                                controlnet_transformer,
                                )  # .to("cuda:0")
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
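
# Encode the condition video (and, for i2v, the first target frame) to VAE latents, run the
# pipeline, and export both the generated clip and the source clip as MP4 files.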
def inference(prefix, source_images,
              target_images,
              text_prompt, negative_prompt,
              pipe, vae,
              step, guidance_scale,
              target_path, video_dir,
              h, w, random_seed):
    # Normalize uint8 frames from [0, 255] to [-1, 1] before VAE encoding.
    source_pixel_values = source_images / 127.5 - 1.0
    source_pixel_values = source_pixel_values.to(torch.float16).to("cuda:0")
    if target_images is not None:
        target_pixel_values = target_images / 127.5 - 1.0
        target_pixel_values = target_pixel_values.to(torch.float16).to("cuda:0")
    # h and w are overwritten with the actual frame dimensions.
    bsz, f, h, w, c = source_pixel_values.shape
    with torch.no_grad():
        source_pixel_values = rearrange(source_pixel_values, "b f w h c -> b c f w h")
        source_latents = vae.encode(source_pixel_values).latent_dist.sample()
        source_latents = source_latents.to(torch.float16)
        source_latents = source_latents * vae.config.scaling_factor
        source_latents = rearrange(source_latents, "b c f h w -> b f c h w")
        if target_images is not None:
            target_pixel_values = rearrange(target_pixel_values, "b f w h c -> b c f w h")
            # Encode only the first target frame as the image condition.
            images = target_pixel_values[:, :, :1, ...]
            image_latents = vae.encode(images).latent_dist.sample()
            image_latents = image_latents.to(torch.float16)
            image_latents = image_latents * vae.config.scaling_factor
            image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
            # Zero-pad the first-frame latents over the remaining frames, then
            # concatenate with the source latents along the channel dimension.
            image_latents = torch.cat([image_latents, torch.zeros_like(source_latents)[:, 1:]], dim=1)
            latents = torch.cat([image_latents, source_latents], dim=2)
        else:
            image_latents = None
            latents = source_latents

    video = pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        video_condition=source_latents,    # input to controlnet
        video_condition2=image_latents,    # concat with latents
        height=h,
        width=w,
        num_frames=f,
        num_inference_steps=50,
        interval=6,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cuda:0").manual_seed(random_seed)
    ).frames[0]

    def transform_tensor_to_images(images):
        images = images.cpu().detach().numpy()
        images = np.uint8(images)
        images2 = []
        for image in images:
            image = Image.fromarray(image)
            images2.append(image)
        return images2

    source_images = transform_tensor_to_images(source_images[0])
    os.makedirs(f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}", exist_ok=True)
    export_to_video(video, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}.mp4", fps=8)
    export_to_video(source_images, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}_org.mp4", fps=8)
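
# Standalone helpers for reading and resizing video frames.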
def read_video(video_path, h, w):
    vr = VideoReader(video_path)
    images = vr.get_batch(list(range(min(33, len(vr))))).asnumpy()
    images2 = []
    for image in images:
        image = cv2.resize(image, (h, w))
        images2.append(image)
    images2 = np.array(images2)
    images = images2
    del vr
    images = torch.from_numpy(images)
    return images

def resize(images, h, w):
    images = rearrange(images, "f w h c -> f c w h")
    images = F.interpolate(images, (h, w), mode="bilinear")
    images = rearrange(images, "f c w h -> f w h c")
    images = images[None, ...]
    return images
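
# Output resolution and the per-sample inference loop over the demo video dataset.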
h = 448
w = 768
root_dir = 'additional_videos8'
dataset = VideoDataset(root_dir)
print(len(dataset))

for step, sample in enumerate(dataset):
    image = sample['image']       # w h c
    images = sample['frames']     # f w h c
    pos_prompt = sample['pos_prompt']
    neg_prompt = sample['neg_prompt']
    image_path = sample['image_path']
    prefix = image_path.replace("/", "_")
    source_images = images[None, ...]
    target_images = image[None, None, ...]
    print(pos_prompt, neg_prompt)
    print(source_images.shape, torch.min(source_images), torch.max(source_images))
    print(target_images.shape, torch.min(target_images), torch.max(target_images))
    target_path = f"demo_first_frame_controlnet_33_stride_2_new_videos_8/{prefix}/"
    random_seeds = [args.random_seed]
    for random_seed in random_seeds:
        inference("", source_images,
                  target_images, pos_prompt,
                  neg_prompt, pipe, vae,
                  args.training_steps, args.guidance_scale,
                  target_path, "",
                  h, w, random_seed)