import os
import copy
import torch
import random
import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from safetensors import safe_open
from diffusers import AutoencoderKL
from diffusers import EulerDiscreteScheduler, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from utils.unet import UNet3DConditionModel
from utils.pipeline_magictime import MagicTimePipeline
from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
# import spaces

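# Checkpoint and config locations; the SD 1.5 base model and the MagicTime
# adapter weights are assumed to have been downloaded into ./ckpts beforehand.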
pretrained_model_path = "./ckpts/Base_Model/stable-diffusion-v1-5"
inference_config_path = "./sample_configs/RealisticVision.yaml"
magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
magic_text_encoder_path = "./ckpts/Magic_Weights/magic_text_encoder"

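# Pick the compute device: prefer CUDA, fall back to Apple MPS, then CPU.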
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

| css = """ | |
| .toolbutton { | |
| margin-buttom: 0em 0em 0em 0em; | |
| max-width: 2.5em; | |
| min-width: 2.5em !important; | |
| height: 2.5em; | |
| } | |
| """ | |
examples = [
    # 1-RealisticVision
    [
        "RealisticVisionV60B1_v51VAE.safetensors",
        "motion_module.ckpt",
        "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms.",
        "worst quality, low quality, letterboxed",
        512, 512, "2038801077"
    ],
    # 2-RCNZ
    [
        "RcnzCartoon.safetensors",
        "motion_module.ckpt",
        "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney.",
        "worst quality, low quality, letterboxed",
        512, 512, "1268480012"
    ],
    # 3-ToonYou
    [
        "ToonYou_beta6.safetensors",
        "motion_module.ckpt",
        "Bean sprouts grow and mature from seeds.",
        "worst quality, low quality, letterboxed",
        512, 512, "1496541313"
    ]
]

# Clean the Gradio example cache
print("### Cleaning cached examples ...")
os.system("rm -rf gradio_cached_examples/")

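# The controller owns all model components and is shared by the Gradio callbacks below.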
# @spaces.GPU(duration=300)
class MagicTimeController:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "ckpts", "Base_Model")
        self.motion_module_dir = os.path.join(self.basedir, "ckpts", "Base_Model", "motion_module")
        self.personalized_model_dir = os.path.join(self.basedir, "ckpts", "DreamBooth")
        self.savedir = os.path.join(self.basedir, "outputs")
        os.makedirs(self.savedir, exist_ok=True)

        self.dreambooth_list = []
        self.motion_module_list = []
        self.selected_dreambooth = None
        self.selected_motion_module = None
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.inference_config = OmegaConf.load(inference_config_path)[1]
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
        self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
        self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        # self.tokenizer = tokenizer
        # self.text_encoder = text_encoder
        # self.vae = vae
        # self.unet = unet
        # self.text_model = text_model

        self.update_motion_module(self.motion_module_list[0])
        self.update_dreambooth(self.dreambooth_list[0])

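    # The refresh helpers scan the checkpoint folders so the dropdowns list whatever is on disk.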
    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.dreambooth_list = [os.path.basename(p) for p in dreambooth_list]

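    # Switching the DreamBooth checkpoint reloads the VAE/UNet/text-encoder weights,
    # then re-applies the MagicTime weights: the spatial adapter as a LoRA on the UNet
    # and the temporal/text-encoder adapters via Swift.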
    def update_dreambooth(self, dreambooth_dropdown):
        self.selected_dreambooth = dreambooth_dropdown
        dreambooth_dropdown = os.path.join(self.personalized_model_dir, dreambooth_dropdown)
        dreambooth_state_dict = {}
        with safe_open(dreambooth_dropdown, framework="pt", device="cpu") as f:
            for key in f.keys():
                dreambooth_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet.config)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        text_model = copy.deepcopy(self.text_model)
        self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)

        from swift import Swift
        magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
        self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
        self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
        self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
        return gr.Dropdown()

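    # Loading a motion module merges its temporal weights into the 3D UNet;
    # unexpected keys would indicate a checkpoint/architecture mismatch.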
    def update_motion_module(self, motion_module_dropdown):
        self.selected_motion_module = motion_module_dropdown
        motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
        motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
        _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
        assert len(unexpected) == 0
        return gr.Dropdown()

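    # Main generation callback: re-syncs the dropdown selections if needed, builds a
    # MagicTimePipeline with a DDIM scheduler, and renders a 16-frame video for the prompt.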
    def magictime(
        self,
        dreambooth_dropdown,
        motion_module_dropdown,
        prompt_textbox,
        negative_prompt_textbox,
        width_slider,
        height_slider,
        seed_textbox,
    ):
        if self.selected_motion_module != motion_module_dropdown:
            self.update_motion_module(motion_module_dropdown)
        if self.selected_dreambooth != dreambooth_dropdown:
            self.update_dreambooth(dreambooth_dropdown)
        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = MagicTimePipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to(device)

        # A non-positive seed means "pick one at random" (random.randint needs integer bounds).
        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(1, 10**16)
        torch.manual_seed(int(seed))
        assert seed == torch.initial_seed()
        print(f"### seed: {seed}")

        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=25,
            guidance_scale=8.,
            width=width_slider,
            height=height_slider,
            video_length=16,
            generator=generator,
        ).videos

        save_sample_path = os.path.join(self.savedir, "sample.mp4")
        save_videos_grid(sample, save_sample_path)

        json_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "width": width_slider,
            "height": height_slider,
            "seed": seed,
            "dreambooth": dreambooth_dropdown,
        }
        return gr.Video(value=save_sample_path), gr.Json(value=json_config)

# inference_config = OmegaConf.load(inference_config_path)[1]
# tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
# text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
# vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
# unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# controller = MagicTimeController(tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, unet=unet, text_model=text_model)
controller = MagicTimeController()

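# Build the Gradio Blocks UI; the layout mirrors the `inputs`/`outputs` lists wired to controller.magictime.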
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            <h2 align="center"> <a href="https://github.com/PKU-YuanGroup/MagicTime">MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators</a></h2>
            <h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest update. </h5>

            [GitHub](https://github.com/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing)
            """
        )
        with gr.Row():
            with gr.Column():
                dreambooth_dropdown = gr.Dropdown(label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True)
                motion_module_dropdown = gr.Dropdown(label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True)

                dreambooth_dropdown.change(fn=controller.update_dreambooth, inputs=[dreambooth_dropdown], outputs=[dreambooth_dropdown])
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                prompt_textbox = gr.Textbox(label="Prompt", lines=3)
                negative_prompt_textbox = gr.Textbox(label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")

                with gr.Accordion("Advanced", open=False):
                    with gr.Row():
                        width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1, interactive=True)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 10**16)), inputs=[], outputs=[seed_textbox])

                generate_button = gr.Button(value="Generate", variant="primary")

            with gr.Column():
                result_video = gr.Video(label="Generated Animation", interactive=False)
                json_config = gr.Json(label="Config", value=None)

        inputs = [dreambooth_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
        outputs = [result_video, json_config]

        generate_button.click(fn=controller.magictime, inputs=inputs, outputs=outputs)

        # gr.Examples(fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True)
        gr.Examples(fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs)

    return demo

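# Queue requests so long-running generations do not block the server, then launch the demo.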
if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()