import argparse
import os
import traceback
from datetime import datetime
from pathlib import Path
from typing import List
import av
import numpy as np
import torch
import torchvision
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from transformers import CLIPVisionModelWithProjection
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d_edit_bkfill import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long_edit_bkfill_roiclip import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames
import cv2
from tools.human_segmenter import human_segmenter
import imageio
from tools.util import all_file, load_mask_list, crop_img, pad_img, crop_human_clip_auto_context, get_mask, \
    refine_img_prepross
import gradio as gr
import json
from huggingface_hub import snapshot_download
import spaces

# Motion templates offered in the UI. Each key is also the name of a template folder under
# <assets_dir>/video_template; the values are not read anywhere in this file.
MOTION_TRIGGER_WORD = {
    'sports_basketball_gym': [],
    'sports_nba_pass': [],
    'sports_nba_dunk': [],
    'movie_BruceLee1': [],
    'shorts_kungfu_match1': [],
    'shorts_kungfu_desert1': [],
    'parkour_climbing': [],
    'dance_indoor_1': [],
}

# Applied to every component created with elem_id="fixed_size_img".
css_style = "#fixed_size_img {height: 500px;}"


def _snapshot_first_available(label, repo_ids, local_dir, **kwargs):
    """Download the first repo in *repo_ids* that succeeds into *local_dir*.

    Args:
        label: human-readable name used in log messages.
        repo_ids: Hugging Face repo ids, tried in order.
        local_dir: destination directory for the downloaded files.
        **kwargs: extra arguments forwarded to ``snapshot_download`` (e.g. ``allow_patterns``).

    Returns:
        True if one of the repos downloaded successfully, False if all of them failed.
        Failures are logged, never raised (downloads are best-effort).
    """
    for repo_id in repo_ids:
        try:
            snapshot_download(
                repo_id=repo_id,
                cache_dir='./pretrained_weights',
                local_dir=local_dir,
                local_dir_use_symlinks=False,
                **kwargs,
            )
            return True
        except Exception as e:
            print(f"Error downloading {label} from {repo_id}: {e}")
    return False


def download_models():
    """Download the weights and asset folders the demo needs, skipping anything already present.

    Every step is best-effort: failures are logged and the function continues, so a missing
    file only surfaces later when MIMO tries to load it.
    """
    print("Checking and downloading models...")

    # Main MIMO weights; fall back to ModelScope if the Hugging Face download fails.
    if not os.path.exists('./pretrained_weights/denoising_unet.pth'):
        print("Downloading MIMO model weights...")
        if not _snapshot_first_available('MIMO weights', ['menyifang/MIMO'], './pretrained_weights'):
            try:
                from modelscope import snapshot_download as ms_snapshot_download
                ms_snapshot_download(
                    model_id='iic/MIMO',
                    cache_dir='./pretrained_weights',
                    local_dir='./pretrained_weights'
                )
            except Exception as e2:
                print(f"Error downloading from ModelScope: {e2}")

    # Download base models if not present.
    if not os.path.exists('./pretrained_weights/stable-diffusion-v1-5'):
        print("Downloading Stable Diffusion v1.5...")
        # The original runwayml repo id may no longer be available on the Hub, so the
        # community mirror is tried as a fallback.
        _snapshot_first_available(
            'SD v1.5',
            ['runwayml/stable-diffusion-v1-5', 'stable-diffusion-v1-5/stable-diffusion-v1-5'],
            './pretrained_weights/stable-diffusion-v1-5',
        )

    if not os.path.exists('./pretrained_weights/sd-vae-ft-mse'):
        print("Downloading VAE...")
        _snapshot_first_available(
            'VAE',
            ['stabilityai/sd-vae-ft-mse'],
            './pretrained_weights/sd-vae-ft-mse',
        )

    if not os.path.exists('./pretrained_weights/image_encoder'):
        print("Downloading Image Encoder...")
        # snapshot_download has no `subfolder` argument (passing one raises TypeError).
        # Instead, fetch only the repo's image_encoder/ folder into ./pretrained_weights,
        # which yields ./pretrained_weights/image_encoder/.
        _snapshot_first_available(
            'image encoder',
            ['lambdalabs/sd-image-variations-diffusers'],
            './pretrained_weights',
            allow_patterns=['image_encoder/*'],
        )

    # Download assets if not present.
    if not os.path.exists('./assets'):
        print("Downloading assets...")
        # This would need to be uploaded to HF or provided another way.
        # For now, create the minimal required directory structure.
        os.makedirs('./assets/masks', exist_ok=True)
        os.makedirs('./assets/test_image', exist_ok=True)
        os.makedirs('./assets/video_template', exist_ok=True)


def init_bk(n_frame, tw, th):
    """Return *n_frame* plain white RGB PIL images of size (tw, th) to use as background frames."""
    return [Image.new('RGB', (tw, th), color='white') for _ in range(n_frame)]


# Initialize the human segmenter; fall back to None (no segmentation) if the model file is
# missing or fails to load.
seg_path = './assets/matting_human.pb'
try:
    segmenter = human_segmenter(model_path=seg_path) if os.path.exists(seg_path) else None
except Exception as e:
    print(f"Warning: Could not initialize segmenter: {e}")
    segmenter = None


def _passthrough_seg(img):
    """Return *img* as an array together with a fully opaque (255) uint8 mask of the same H x W."""
    img_array = np.array(img) if isinstance(img, Image.Image) else img
    mask = np.ones((img_array.shape[0], img_array.shape[1]), dtype=np.uint8) * 255
    return img_array, mask


def process_seg(img):
    """Segment the person in *img* and composite them onto a white background.

    Args:
        img: image array (or PIL image) passed straight to ``segmenter.run``.

    Returns:
        ``(color, mask)`` where ``color`` is a uint8 image with the background replaced
        by white and ``mask`` is the alpha channel produced by the segmenter. If no
        segmenter is available or segmentation fails, the input is returned unchanged
        with an all-255 mask.
    """
    if segmenter is None:
        return _passthrough_seg(img)
    try:
        rgba = segmenter.run(img)
        mask = rgba[:, :, 3]
        color = rgba[:, :, :3]
        alpha = mask / 255
        bk = np.ones_like(color) * 255
        # Alpha-blend the foreground over white.
        color = color * alpha[:, :, np.newaxis] + bk * (1 - alpha[:, :, np.newaxis])
        color = color.astype(np.uint8)
        return color, mask
    except Exception as e:
        print(f"Error in segmentation: {e}")
        return _passthrough_seg(img)


def parse_args():
    """Parse MIMO's command-line options.

    Unrecognised arguments (e.g. flags added by a hosting launcher) are ignored with a
    warning instead of aborting the process with SystemExit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='./configs/prompts/animation_edit.yaml')
    parser.add_argument("-W", type=int, default=784)
    parser.add_argument("-H", type=int, default=784)
    parser.add_argument("-L", type=int, default=64)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cfg", type=float, default=3.5)
    parser.add_argument("--steps", type=int, default=25)
    parser.add_argument("--fps", type=int)
    parser.add_argument("--assets_dir", type=str, default='./assets')
    parser.add_argument("--ref_pad", type=int, default=1)
    parser.add_argument("--use_bk", type=int, default=1)
    parser.add_argument("--clip_length", type=int, default=32)
    parser.add_argument("--MAX_FRAME_NUM", type=int, default=150)
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized command-line arguments: {unknown}")
    return args


class MIMO():
    """Owns the MIMO pose-to-video pipeline: model loading plus per-request inference."""

    def __init__(self, debug_mode=False):
        """Download missing weights, build every sub-model and assemble the pipeline.

        Args:
            debug_mode: accepted for interface compatibility; not used.

        Raises:
            Exception: any failure while loading configs or weights is logged and re-raised.
        """
        try:
            # Download models first.
            download_models()

            args = parse_args()
            config = OmegaConf.load(args.config)

            weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

            # Check CUDA availability.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            if device == "cpu":
                # fp16 is not used on CPU.
                weight_dtype = torch.float32
                print("Warning: Running on CPU, performance may be slow")

            vae = AutoencoderKL.from_pretrained(
                config.pretrained_vae_path,
            ).to(device, dtype=weight_dtype)

            reference_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained_base_model_path,
                subfolder="unet",
            ).to(dtype=weight_dtype, device=device)

            inference_config_path = config.inference_config
            infer_config = OmegaConf.load(inference_config_path)
            denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                config.pretrained_base_model_path,
                config.motion_module_path,
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=weight_dtype, device=device)

            pose_guider = PoseGuider(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(
                dtype=weight_dtype, device=device
            )

            image_enc = CLIPVisionModelWithProjection.from_pretrained(
                config.image_encoder_path
            ).to(dtype=weight_dtype, device=device)

            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)

            # torch.manual_seed seeds the global CPU RNG and returns its generator; it is
            # reused (not re-seeded) across requests.
            self.generator = torch.manual_seed(args.seed)
            self.width, self.height = args.W, args.H
            self.device = device

            # Load the fine-tuned MIMO weights on top of the base models.
            try:
                denoising_unet.load_state_dict(
                    torch.load(config.denoising_unet_path, map_location="cpu"),
                    strict=False,
                )
                reference_unet.load_state_dict(
                    torch.load(config.reference_unet_path, map_location="cpu"),
                )
                pose_guider.load_state_dict(
                    torch.load(config.pose_guider_path, map_location="cpu"),
                )
                print("Successfully loaded all model weights")
            except Exception as e:
                print(f"Error loading model weights: {e}")
                raise

            self.pipe = Pose2VideoPipeline(
                vae=vae,
                image_encoder=image_enc,
                reference_unet=reference_unet,
                denoising_unet=denoising_unet,
                pose_guider=pose_guider,
                scheduler=scheduler,
            )
            self.pipe = self.pipe.to(device, dtype=weight_dtype)

            self.args = args

            # Optional blending mask for compositing; None falls back to a rectangular mask.
            mask_path = os.path.join(self.args.assets_dir, 'masks', 'alpha2.png')
            try:
                self.mask_list = load_mask_list(mask_path) if os.path.exists(mask_path) else None
            except Exception as e:
                print(f"Warning: Could not load mask: {e}")
                self.mask_list = None

        except Exception as e:
            print(f"Error initializing MIMO: {e}")
            raise

    def load_template(self, template_path):
        """Collect the video paths and config.json settings of a motion-template folder.

        Returns:
            dict with ``video_path`` (vid.mp4), ``pose_video_path`` (sdc.mp4),
            ``bk_video_path`` (bk.mp4), ``occ_video_path`` (occ.mp4, or None when the
            file is absent) and the ``fps`` / ``time_crop`` / ``frame_crop`` /
            ``layer_recover`` entries of config.json (``fps`` stored as ``target_fps``).

        Raises:
            OSError / KeyError: if config.json is missing, unreadable or lacks a key.
        """
        config_file = os.path.join(template_path, 'config.json')
        with open(config_file, encoding='utf-8') as f:
            template_data = json.load(f)

        occ_video_path = os.path.join(template_path, 'occ.mp4')
        if not os.path.exists(occ_video_path):
            occ_video_path = None

        return {
            'video_path': os.path.join(template_path, 'vid.mp4'),
            'pose_video_path': os.path.join(template_path, 'sdc.mp4'),
            'bk_video_path': os.path.join(template_path, 'bk.mp4'),
            'occ_video_path': occ_video_path,
            'target_fps': template_data['fps'],
            'time_crop': template_data['time_crop'],
            'frame_crop': template_data['frame_crop'],
            'layer_recover': template_data['layer_recover'],
        }

    # NOTE(review): the UI text promises 1-2 minutes of processing; confirm 60 s of GPU time is
    # enough for MAX_FRAME_NUM frames.
    @spaces.GPU(duration=60)  # Allocate GPU for 60 seconds
    def run(self, ref_image_pil, template_name):
        """Animate *ref_image_pil* with the motion and scene of template *template_name*.

        Args:
            ref_image_pil: reference character image (a PIL image from the Gradio input).
            template_name: name of a folder under ``<assets_dir>/video_template``.

        Returns:
            A list of uint8 frames (numpy arrays), one per template frame kept after
            cropping, or ``None`` if the template is missing or any step fails (the error
            and its traceback are printed).
        """
        try:
            template_dir = os.path.join(self.args.assets_dir, 'video_template')
            template_path = os.path.join(template_dir, template_name)
            if not os.path.exists(template_path):
                # Return None like every other failure path: callers only check `is None`.
                print(f"Template {template_name} not found")
                return None

            template_info = self.load_template(template_path)
            video_path = template_info['video_path']
            pose_video_path = template_info['pose_video_path']
            bk_video_path = template_info['bk_video_path']
            occ_video_path = template_info['occ_video_path']

            # Process the reference image: flip channels to BGR for segmentation and back,
            # crop to the segmented character, then pad with white.
            if isinstance(ref_image_pil, Image.Image):
                # Normalise RGBA / grayscale uploads so the channel flip below sees 3 channels.
                ref_image_pil = ref_image_pil.convert('RGB')
            source_image = np.array(ref_image_pil)
            source_image, mask = process_seg(source_image[..., ::-1])
            source_image = source_image[..., ::-1]
            source_image = crop_img(source_image, mask)
            source_image, _ = pad_img(source_image, [255, 255, 255])
            ref_image_pil = Image.fromarray(source_image)

            # Load template videos; synthesize a white background if bk.mp4 is missing.
            vid_images = read_frames(video_path)
            if bk_video_path is None or not os.path.exists(bk_video_path):
                n_frame = len(vid_images)
                tw, th = vid_images[0].size
                bk_images = init_bk(n_frame, tw, th)
            else:
                bk_images = read_frames(bk_video_path)

            if occ_video_path is not None and os.path.exists(occ_video_path):
                occ_mask_images = read_frames(occ_video_path)
                print('load occ from %s' % occ_video_path)
            else:
                occ_mask_images = None
                print('no occ masks')

            pose_images = read_frames(pose_video_path)

            # Apply the template's time crop, clamped to the available pose frames.
            start_idx = max(0, template_info['time_crop']['start_idx'])
            end_idx = min(len(pose_images), template_info['time_crop']['end_idx'])

            pose_images = pose_images[start_idx:end_idx]
            vid_images = vid_images[start_idx:end_idx]
            bk_images = bk_images[start_idx:end_idx]
            if occ_mask_images is not None:
                occ_mask_images = occ_mask_images[start_idx:end_idx]

            # Cap the clip length at MAX_FRAME_NUM frames.
            self.args.L = len(pose_images)
            max_n_frames = self.args.MAX_FRAME_NUM
            if self.args.L > max_n_frames:
                pose_images = pose_images[:max_n_frames]
                vid_images = vid_images[:max_n_frames]
                bk_images = bk_images[:max_n_frames]
                if occ_mask_images is not None:
                    occ_mask_images = occ_mask_images[:max_n_frames]
                self.args.L = len(pose_images)

            # Keep full-frame copies for compositing the generated crops back later.
            bk_images_ori = bk_images.copy()
            vid_images_ori = vid_images.copy()

            # Passed to crop_human_clip_auto_context and used as the blend ramp length below;
            # presumably the number of frames shared by consecutive context windows.
            overlay = 4
            pose_images, vid_images, bk_images, _, context_list, bbox_clip_list = crop_human_clip_auto_context(
                pose_images, vid_images, bk_images, overlay)

            # Pad each cropped pose/background frame and remember the padding so the output
            # can be un-padded afterwards.
            clip_pad_list_context = []
            clip_padv_list_context = []
            pose_list_context = []
            vid_bk_list_context = []
            for frame_idx, pose_image_pil in enumerate(pose_images):
                pose_image, _ = pad_img(np.array(pose_image_pil), color=[0, 0, 0])
                pose_list_context.append(Image.fromarray(pose_image))

                vid_bk, padding_v = pad_img(np.array(bk_images[frame_idx]), color=[255, 255, 255])
                pad_h, pad_w, _ = vid_bk.shape
                clip_pad_list_context.append([pad_h, pad_w])
                clip_padv_list_context.append(padding_v)
                vid_bk_list_context.append(Image.fromarray(vid_bk))

            print('Starting inference...')
            with torch.no_grad():
                video = self.pipe(
                    ref_image_pil,
                    pose_list_context,
                    vid_bk_list_context,
                    self.width,
                    self.height,
                    len(pose_list_context),
                    self.args.steps,
                    self.args.cfg,
                    generator=self.generator,
                ).videos[0]

            # Post-process: paste each generated crop back into its full background frame and
            # cross-fade frames that belong to more than one context window.
            video_idx = 0
            res_images = [None for _ in range(self.args.L)]
            for k, context in enumerate(context_list):
                start_i = context[0]
                bbox = bbox_clip_list[k]
                for i in context:
                    bk_image_pil_ori = bk_images_ori[i]
                    vid_image_pil_ori = vid_images_ori[i]
                    occ_mask = occ_mask_images[i] if occ_mask_images is not None else None

                    canvas = Image.new("RGB", bk_image_pil_ori.size, "white")

                    pad_h, pad_w = clip_pad_list_context[video_idx]
                    padding_v = clip_padv_list_context[video_idx]

                    # `video` is indexed as (channels, frames, H, W); take one frame as HWC.
                    image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy()
                    res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
                    res_image_pil = res_image_pil.resize((pad_w, pad_h))

                    # Undo the padding added before inference.
                    top, bottom, left, right = padding_v
                    res_image_pil = res_image_pil.crop((left, top, pad_w - right, pad_h - bottom))

                    w_min, w_max, h_min, h_max = bbox
                    canvas.paste(res_image_pil, (w_min, h_min))

                    mask_full = np.zeros((bk_image_pil_ori.size[1], bk_image_pil_ori.size[0]), dtype=np.float32)
                    res_image = np.array(canvas)
                    bk_image = np.array(bk_image_pil_ori)

                    if self.mask_list is not None:
                        mask = get_mask(self.mask_list, bbox, bk_image_pil_ori)
                        mask = cv2.resize(mask, res_image_pil.size, interpolation=cv2.INTER_AREA)
                        mask_full[h_min:h_min + mask.shape[0], w_min:w_min + mask.shape[1]] = mask
                    else:
                        # Use a simple rectangle mask if no mask list is available.
                        mask_full[h_min:h_max, w_min:w_max] = 1.0

                    res_image = res_image * mask_full[:, :, np.newaxis] + bk_image * (1 - mask_full[:, :, np.newaxis])

                    # Restore occluding foreground pixels from the original video.
                    if occ_mask is not None:
                        vid_image = np.array(vid_image_pil_ori)
                        occ_mask = np.array(occ_mask)[:, :, 0].astype(np.uint8)
                        occ_mask = occ_mask / 255.0
                        res_image = res_image * (1 - occ_mask[:, :, np.newaxis]) + vid_image * occ_mask[:, :, np.newaxis]

                    if res_images[i] is None:
                        res_images[i] = res_image
                    else:
                        # Linear cross-fade from the earlier window to this one.
                        factor = (i - start_i + 1) / (overlay + 1)
                        res_images[i] = res_images[i] * (1 - factor) + res_image * factor
                    res_images[i] = res_images[i].astype(np.uint8)

                    video_idx = video_idx + 1

            return res_images

        except Exception as e:
            print(f"Error during inference: {e}")
            traceback.print_exc()
            return None


class WebApp():
    """Gradio front-end: builds the UI and forwards requests to a MIMO instance."""

    def __init__(self, debug_mode=False):
        """Set up default request arguments and load the MIMO model (None if loading fails)."""
        self.args_base = {
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "output_dir": "output_demo",
            "img": None,
            "pos_prompt": '',
            "motion": "sports_basketball_gym",
            "motion_dir": "./assets/test_video_trunc",
        }

        # Maps argument name -> Gradio component; filled while building the UI.
        self.args_input = {}
        self.gr_motion = list(MOTION_TRIGGER_WORD.keys())
        # Motion name for each item actually shown in the gallery (rebuilt by get_template).
        self.gallery_motions = list(self.gr_motion)
        self.debug_mode = debug_mode

        # Initialize model with error handling.
        try:
            self.model = MIMO()
            print("MIMO model loaded successfully")
        except Exception as e:
            print(f"Error loading MIMO model: {e}")
            self.model = None

    def title(self):
        """Render the page header."""
        gr.HTML(
            """
            <div style="text-align: center;">
                <h1>🎭 MIMO Demo - Controllable Character Video Synthesis</h1>
                <p>Transform character images into animated videos with controllable motion and scenes</p>
                <p>
                    <a href="https://menyifang.github.io/projects/MIMO/index.html" target="_blank">Project Page</a> |
                    <a href="https://arxiv.org/abs/2409.16160" target="_blank">Paper</a> |
                    <a href="https://github.com/menyifang/MIMO" target="_blank">GitHub</a>
                </p>
            </div>
            """
        )

    def get_template(self, num_cols=3):
        """Build the motion-template gallery and the State holding the selected motion.

        Only templates whose preview video ``<motion_dir>/<motion>.mp4`` exists are shown:
        gr.Gallery cannot render ``None`` items. ``self.gallery_motions`` records the motion
        of each displayed item so gallery indices map back to the right template.

        Args:
            num_cols: number of gallery columns.
        """
        template_examples = []
        self.gallery_motions = []
        for motion in self.gr_motion:
            example_path = os.path.join(self.args_base['motion_dir'], f"{motion}.mp4")
            if os.path.exists(example_path):
                template_examples.append((example_path, motion))
                self.gallery_motions.append(motion)
            else:
                print(f"Preview for template {motion} not found at {example_path}; not shown in gallery")

        # Keep the initial selection consistent with the first displayed gallery item.
        default_motion = self.gallery_motions[0] if self.gallery_motions else self.args_base['motion']
        self.args_input['motion'] = gr.State(default_motion)

        lora_gallery = gr.Gallery(
            label='Motion Templates',
            columns=num_cols,
            height=400,
            value=template_examples,
            show_label=True,
            selected_index=0 if template_examples else None,
        )

        lora_gallery.select(self._update_selection, inputs=[], outputs=[self.args_input['motion']])

    def _update_selection(self, selected_state: gr.SelectData):
        """Return the motion name of the clicked gallery item."""
        motions = self.gallery_motions or self.gr_motion
        return motions[selected_state.index]

    def run_process(self, *values):
        """Gradio callback: generate a video from the uploaded image and selected motion.

        Args:
            *values: component values in the same order as ``self.args_input``.

        Returns:
            ``(video_path_or_None, status_message)``.
        """
        if self.model is None:
            return None, "❌ Model not loaded. Please refresh the page."

        try:
            gr_args = self.args_base.copy()
            for k, v in zip(list(self.args_input.keys()), values):
                gr_args[k] = v

            ref_image_pil = gr_args['img']
            template_name = gr_args['motion']

            if ref_image_pil is None:
                return None, "⚠️ Please upload an image first."

            print(f'Processing with template: {template_name}')

            save_dir = 'output'
            os.makedirs(save_dir, exist_ok=True)
            # Microseconds keep back-to-back requests from overwriting each other's output.
            case = datetime.now().strftime("%Y%m%d%H%M%S%f")
            outpath = os.path.join(save_dir, f"{case}.mp4")

            res = self.model.run(ref_image_pil, template_name)

            if res is None:
                return None, "❌ Failed to generate video. Please try again or select a different template."

            imageio.mimsave(outpath, res, fps=30, quality=8, macro_block_size=1)
            print(f'Video saved to: {outpath}')

            return outpath, "✅ Video generated successfully!"

        except Exception as e:
            print(f"Error in processing: {e}")
            return None, f"❌ Error: {str(e)}"

    def preset_library(self):
        """Build the instructions, input/gallery/output columns and example images."""
        with gr.Blocks() as demo:
            with gr.Accordion(label="🧭 Instructions", open=True):
                gr.Markdown("""
                ### How to use:
                1. **Upload a character image**: Use a full-body, front-facing image with clear visibility (images without occlusion or handheld objects work best)
                2. **Select motion template**: Choose from the available motion templates in the gallery
                3. **Generate**: Click "🎬 Generate Animation" to create your character animation

                ### Tips:
                - Best results with clear, well-lit character images
                - Processing may take 1-2 minutes depending on video length
                - GPU acceleration is automatically used when available
                """)

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(label='Upload Character Image', type="pil", elem_id="fixed_size_img")
                    self.args_input['img'] = img_input

                    submit_btn = gr.Button("🎬 Generate Animation", variant='primary', size="lg")

                    status_text = gr.Textbox(label="Status", interactive=False, value="Ready to generate...")

                with gr.Column():
                    self.get_template(num_cols=2)

                with gr.Column():
                    res_vid = gr.Video(format="mp4", label="Generated Animation", autoplay=True, elem_id="fixed_size_img")

            submit_btn.click(
                self.run_process,
                inputs=list(self.args_input.values()),
                outputs=[res_vid, status_text],
                scroll_to_output=True,
            )

            # Add examples if available.
            example_images = []
            example_dir = './assets/test_image'
            if os.path.exists(example_dir):
                for img_name in ['sugar.jpg', 'ouwen1.png', 'actorhq_A1S1.png', 'cartoon1.png', 'avatar.jpg']:
                    img_path = os.path.join(example_dir, img_name)
                    if os.path.exists(img_path):
                        example_images.append([img_path])

            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=[img_input],
                    examples_per_page=5,
                    label="Example Images"
                )

    def ui(self):
        """Assemble and return the top-level Gradio Blocks app."""
        with gr.Blocks(css=css_style, title="MIMO - Controllable Character Video Synthesis") as demo:
            self.title()
            self.preset_library()
        return demo


# Initialize and run (at import time, so hosting platforms that import this module get `demo`).
print("Initializing MIMO demo...")
app = WebApp(debug_mode=False)
demo = app.ui()

if __name__ == "__main__":
    demo.queue(max_size=10)
    # For Hugging Face Spaces
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)