|
|
|
|
|
""" |
|
|
MIMO - Complete HuggingFace Spaces Implementation |
|
|
Controllable Character Video Synthesis with Spatial Decomposed Modeling |
|
|
|
|
|
Complete features matching README_SETUP.md: |
|
|
- Character Image Animation (run_animate.py functionality) |
|
|
- Video Character Editing (run_edit.py functionality) |
|
|
- Real motion templates from assets/video_template/ |
|
|
- Auto GPU detection (T4/A10G/A100) |
|
|
- Auto model downloading |
|
|
- Human segmentation and background processing |
|
|
- Pose-guided video generation with occlusion handling |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
import spaces |
|
|
HAS_SPACES = True |
|
|
print("β
Spaces library loaded - ZeroGPU mode enabled") |
|
|
except ImportError: |
|
|
HAS_SPACES = False |
|
|
print("β οΈ Spaces library not available - running in local mode") |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Dict, Tuple |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
import imageio |
|
|
from omegaconf import OmegaConf |
|
|
from huggingface_hub import snapshot_download, hf_hub_download |
|
|
from diffusers import AutoencoderKL, DDIMScheduler |
|
|
from transformers import CLIPVisionModelWithProjection |
|
|
|
|
|
|
|
|
sys.path.append('./src') |
|
|
|
|
|
from src.models.pose_guider import PoseGuider |
|
|
from src.models.unet_2d_condition import UNet2DConditionModel |
|
|
from src.models.unet_3d_edit_bkfill import UNet3DConditionModel |
|
|
from src.pipelines.pipeline_pose2vid_long_edit_bkfill_roiclip import Pose2VideoPipeline |
|
|
from src.utils.util import get_fps, read_frames |
|
|
|
|
|
|
|
|
try: |
|
|
from tools.human_segmenter import human_segmenter |
|
|
HAS_SEGMENTER = True |
|
|
except ImportError: |
|
|
print("β οΈ TensorFlow not available, human_segmenter disabled (will use fallback)") |
|
|
human_segmenter = None |
|
|
HAS_SEGMENTER = False |
|
|
|
|
|
from tools.util import ( |
|
|
load_mask_list, crop_img, pad_img, crop_human, |
|
|
crop_human_clip_auto_context, get_mask, load_video_fixed_fps, |
|
|
recover_bk, all_file |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEVICE = "cpu" |
|
|
MODEL_CACHE = "./models" |
|
|
ASSETS_CACHE = "./assets" |
|
|
|
|
|
|
|
|
|
|
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' |
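# expandable_segments lets PyTorch's CUDA allocator grow existing memory segments
# instead of fragmenting new ones, which helps long video batches fit in VRAM.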
|
|
|
|
|
class CompleteMIMO: |
|
|
"""Complete MIMO implementation matching README_SETUP.md functionality""" |
|
|
|
|
|
def __init__(self): |
|
|
self.pipe = None |
|
|
self.is_loaded = False |
|
|
self.segmenter = None |
|
|
self.mask_list = None |
|
|
self.weight_dtype = torch.float32 |
|
|
self._model_cache_valid = False |
|
|
|
|
|
|
|
|
os.makedirs(MODEL_CACHE, exist_ok=True) |
|
|
os.makedirs(ASSETS_CACHE, exist_ok=True) |
|
|
os.makedirs("./output", exist_ok=True) |
|
|
|
|
|
print(f"π MIMO initializing on {DEVICE}") |
|
|
if DEVICE == "cuda": |
|
|
print(f"π GPU: {torch.cuda.get_device_name()}") |
|
|
print(f"πΎ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB") |
|
|
|
|
|
|
|
|
self._check_existing_models() |
|
|
|
|
|
def _check_existing_models(self): |
|
|
"""Check if models are already downloaded and show status""" |
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_dirs = [ |
|
|
Path(f"{MODEL_CACHE}/stable-diffusion-v1-5"), |
|
|
Path(f"{MODEL_CACHE}/sd-vae-ft-mse"), |
|
|
Path(f"{MODEL_CACHE}/mimo_weights"), |
|
|
Path(f"{MODEL_CACHE}/image_encoder_full") |
|
|
] |
|
|
|
|
|
|
|
|
cache_patterns = [ |
|
|
"models--runwayml--stable-diffusion-v1-5", |
|
|
"models--stabilityai--sd-vae-ft-mse", |
|
|
"models--menyifang--MIMO", |
|
|
"models--lambdalabs--sd-image-variations-diffusers" |
|
|
] |
|
|
|
|
|
models_found = 0 |
|
|
for pattern in cache_patterns: |
|
|
|
|
|
for cache_dir in Path(MODEL_CACHE).rglob(pattern): |
|
|
if cache_dir.is_dir(): |
|
|
models_found += 1 |
|
|
break |
|
|
|
|
|
|
|
|
for model_dir in model_dirs: |
|
|
if model_dir.exists() and model_dir.is_dir(): |
|
|
models_found += 1 |
|
|
|
|
|
if models_found >= 3: |
|
|
print(f"β
Found {models_found} model components in cache - models persist across restarts!") |
|
|
self._model_cache_valid = True |
|
|
if not self.is_loaded: |
|
|
print("π‘ Models available - click 'Load Model' to activate") |
|
|
return True |
|
|
else: |
|
|
print(f"β οΈ Only found {models_found} model components - click 'Setup Models' to download") |
|
|
self._model_cache_valid = False |
|
|
return False |
|
|
except Exception as e: |
|
|
print(f"β οΈ Could not check existing models: {e}") |
|
|
|
|
traceback.print_exc() |
|
|
self._model_cache_valid = False |
|
|
return False |
|
|
|
|
|
def download_models(self, progress_callback=None): |
|
|
"""Download all required models matching README_SETUP.md requirements""" |
|
|
|
|
|
|
|
|
|
|
|
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0' |
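# hf_transfer (the Rust-based accelerated downloader) can stall on some networks;
# forcing the plain Python download path trades speed for reliability.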
|
|
|
|
|
def update_progress(msg): |
|
|
if progress_callback: |
|
|
progress_callback(msg) |
|
|
print(f"π₯ {msg}") |
|
|
|
|
|
update_progress("π§ Disabled hf_transfer for stable downloads") |
|
|
|
|
|
downloaded_count = 0 |
|
|
total_steps = 7 |
|
|
|
|
|
try: |
|
|
|
|
|
try: |
|
|
update_progress("Downloading MIMO main models...") |
|
|
snapshot_download( |
|
|
repo_id="menyifang/MIMO", |
|
|
cache_dir=f"{MODEL_CACHE}/mimo_weights", |
|
|
allow_patterns=["*.pth", "*.json", "*.md"], |
|
|
token=None |
|
|
) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
MIMO models downloaded ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ MIMO models download failed: {str(e)[:100]}") |
|
|
print(f"Error details: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Downloading Stable Diffusion v1.5...") |
|
|
snapshot_download( |
|
|
repo_id="runwayml/stable-diffusion-v1-5", |
|
|
cache_dir=f"{MODEL_CACHE}/stable-diffusion-v1-5", |
|
|
allow_patterns=["**/*.json", "**/*.bin", "**/*.safetensors", "**/*.txt"], |
|
|
ignore_patterns=["*.msgpack", "*.h5", "*.ot"], |
|
|
token=None |
|
|
) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
SD v1.5 downloaded ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ SD v1.5 download failed: {str(e)[:100]}") |
|
|
print(f"Error details: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Downloading sd-vae-ft-mse...") |
|
|
snapshot_download( |
|
|
repo_id="stabilityai/sd-vae-ft-mse", |
|
|
cache_dir=f"{MODEL_CACHE}/sd-vae-ft-mse", |
|
|
token=None |
|
|
) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
VAE downloaded ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ VAE download failed: {str(e)[:100]}") |
|
|
print(f"Error details: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Downloading image encoder...") |
|
|
snapshot_download( |
|
|
repo_id="lambdalabs/sd-image-variations-diffusers", |
|
|
cache_dir=f"{MODEL_CACHE}/image_encoder_full", |
|
|
allow_patterns=["image_encoder/**"], |
|
|
token=None |
|
|
) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
Image encoder downloaded ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Image encoder download failed: {str(e)[:100]}") |
|
|
print(f"Error details: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Downloading human segmenter...") |
|
|
os.makedirs(ASSETS_CACHE, exist_ok=True) |
|
|
if not os.path.exists(f"{ASSETS_CACHE}/matting_human.pb"): |
|
|
hf_hub_download( |
|
|
repo_id="menyifang/MIMO", |
|
|
filename="matting_human.pb", |
|
|
cache_dir=ASSETS_CACHE, |
|
|
local_dir=ASSETS_CACHE, |
|
|
token=None |
|
|
) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
Human segmenter downloaded ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Human segmenter download failed (optional): {str(e)[:100]}") |
|
|
print(f"Will use fallback segmentation. Error: {e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Setting up video templates...") |
|
|
os.makedirs("./assets/video_template", exist_ok=True) |
|
|
|
|
|
|
|
|
existing_templates = [] |
|
|
try: |
|
|
for item in os.listdir("./assets/video_template"): |
|
|
template_path = os.path.join("./assets/video_template", item) |
|
|
if os.path.isdir(template_path) and os.path.exists(os.path.join(template_path, "sdc.mp4")): |
|
|
existing_templates.append(item) |
|
|
except OSError:
|
|
pass |
|
|
|
|
|
if existing_templates: |
|
|
update_progress(f"β
Found {len(existing_templates)} existing templates") |
|
|
downloaded_count += 1 |
|
|
else: |
|
|
update_progress("βΉοΈ No video templates found (optional - see TEMPLATES_SETUP.md)") |
|
|
print("π‘ Templates are optional. You can:") |
|
|
print(" 1. Use reference image only (no template needed)") |
|
|
print(" 2. Manually upload templates to assets/video_template/") |
|
|
print(" 3. See TEMPLATES_SETUP.md for instructions") |
|
|
|
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Template setup warning: {str(e)[:100]}") |
|
|
print("π‘ Templates are optional - app will work without them") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Setting up directories...") |
|
|
os.makedirs("./assets/masks", exist_ok=True) |
|
|
os.makedirs("./output", exist_ok=True) |
|
|
downloaded_count += 1 |
|
|
update_progress(f"β
Directories created ({downloaded_count}/{total_steps})") |
|
|
except Exception as e: |
|
|
print(f"Directory creation warning: {e}") |
|
|
|
|
|
|
|
|
if downloaded_count >= 4: |
|
|
update_progress(f"β
Setup complete! ({downloaded_count}/{total_steps} steps successful)") |
|
|
|
|
|
self._model_cache_valid = True |
|
|
print("β
Model cache is now valid - 'Load Model' button will work") |
|
|
return True |
|
|
else: |
|
|
update_progress(f"β οΈ Partial download ({downloaded_count}/{total_steps}). Some features may not work.") |
|
|
|
|
|
if downloaded_count > 0: |
|
|
self._model_cache_valid = True |
|
|
return downloaded_count > 0 |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"β Download failed: {str(e)}" |
|
|
update_progress(error_msg) |
|
|
print(f"\n{'='*60}") |
|
|
print("ERROR DETAILS:") |
|
|
traceback.print_exc() |
|
|
print(f"{'='*60}\n") |
|
|
return False |
|
|
|
|
|
def load_model(self, progress_callback=None): |
|
|
"""Load MIMO model with complete functionality""" |
|
|
|
|
|
def update_progress(msg): |
|
|
if progress_callback: |
|
|
progress_callback(msg) |
|
|
print(f"π {msg}") |
|
|
|
|
|
try: |
|
|
if self.is_loaded: |
|
|
update_progress("β
Model already loaded") |
|
|
return True |
|
|
|
|
|
|
|
|
update_progress("Checking model files...") |
|
|
|
|
|
|
|
|
def find_model_path(primary_path, model_name, search_patterns=None): |
|
|
"""Find model in cache, checking multiple possible locations""" |
|
|
|
|
|
if os.path.exists(primary_path): |
|
|
|
|
|
try: |
|
|
has_config = os.path.exists(os.path.join(primary_path, "config.json")) |
|
|
has_model_files = any(f.endswith(('.bin', '.safetensors', '.pth')) for f in os.listdir(primary_path) if os.path.isfile(os.path.join(primary_path, f))) |
|
|
|
|
|
if has_config or has_model_files: |
|
|
update_progress(f"β
Found {model_name} at primary path") |
|
|
return primary_path |
|
|
else: |
|
|
|
|
|
update_progress(f"β οΈ Primary path exists but appears to be a cache directory, searching inside...") |
|
|
|
|
|
if search_patterns: |
|
|
for pattern in search_patterns: |
|
|
|
|
|
cache_dir_name = pattern.split('/')[-1] if '/' in pattern else pattern |
|
|
cache_subdir = os.path.join(primary_path, cache_dir_name) |
|
|
if os.path.exists(cache_subdir): |
|
|
update_progress(f" Found cache subdir: {cache_dir_name}") |
|
|
|
|
|
snap_path = os.path.join(cache_subdir, "snapshots") |
|
|
if os.path.exists(snap_path): |
|
|
try: |
|
|
snapshot_dirs = [d for d in os.listdir(snap_path) if os.path.isdir(os.path.join(snap_path, d))] |
|
|
if snapshot_dirs: |
|
|
full_path = os.path.join(snap_path, snapshot_dirs[0]) |
|
|
update_progress(f" Checking snapshot: {snapshot_dirs[0]}") |
|
|
|
|
|
|
|
|
|
|
|
has_config = os.path.exists(os.path.join(full_path, "config.json")) |
|
|
has_model_index = os.path.exists(os.path.join(full_path, "model_index.json")) |
|
|
has_subdirs = any(os.path.isdir(os.path.join(full_path, d)) for d in os.listdir(full_path)) |
|
|
has_model_files = any(f.endswith(('.bin', '.safetensors', '.pth')) for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))) |
|
|
|
|
|
if has_config or has_model_index or has_model_files or has_subdirs: |
|
|
update_progress(f"β
Found {model_name} in snapshot: {full_path}") |
|
|
return full_path |
|
|
else: |
|
|
update_progress(f" β οΈ Snapshot exists but appears empty or invalid") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Error in snapshot: {e}") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Error checking primary path: {e}") |
|
|
|
|
|
|
|
|
if search_patterns: |
|
|
for pattern in search_patterns: |
|
|
alt_path = os.path.join(MODEL_CACHE, pattern) |
|
|
if os.path.exists(alt_path): |
|
|
update_progress(f" Checking cache: {pattern}") |
|
|
|
|
|
snap_path = os.path.join(alt_path, "snapshots") |
|
|
if os.path.exists(snap_path): |
|
|
try: |
|
|
snapshot_dirs = [d for d in os.listdir(snap_path) if os.path.isdir(os.path.join(snap_path, d))] |
|
|
if snapshot_dirs: |
|
|
full_path = os.path.join(snap_path, snapshot_dirs[0]) |
|
|
|
|
|
has_config = os.path.exists(os.path.join(full_path, "config.json")) |
|
|
has_model_index = os.path.exists(os.path.join(full_path, "model_index.json")) |
|
|
has_subdirs = any(os.path.isdir(os.path.join(full_path, d)) for d in os.listdir(full_path)) |
|
|
has_model_files = any(f.endswith(('.bin', '.safetensors', '.pth')) for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))) |
|
|
|
|
|
if has_config or has_model_index or has_model_files or has_subdirs: |
|
|
update_progress(f"β
Found {model_name} in snapshot: {full_path}") |
|
|
return full_path |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Error searching snapshots: {e}") |
|
|
|
|
|
update_progress(f"β οΈ Could not find {model_name} in any location") |
|
|
return None |
|
|
vae_path = find_model_path( |
|
|
f"{MODEL_CACHE}/sd-vae-ft-mse", |
|
|
"VAE", |
|
|
["models--stabilityai--sd-vae-ft-mse"] |
|
|
) |
|
|
|
|
|
sd_path = find_model_path( |
|
|
f"{MODEL_CACHE}/stable-diffusion-v1-5", |
|
|
"SD v1.5", |
|
|
["models--runwayml--stable-diffusion-v1-5"] |
|
|
) |
|
|
|
|
|
|
|
|
encoder_path = None |
|
|
update_progress(f"π Searching for Image Encoder...") |
|
|
|
|
|
|
|
|
image_encoder_base = f"{MODEL_CACHE}/image_encoder_full" |
|
|
if os.path.exists(image_encoder_base): |
|
|
try: |
|
|
contents = os.listdir(image_encoder_base) |
|
|
update_progress(f" π image_encoder_full contains: {contents}") |
|
|
|
|
|
|
|
|
hf_cache_dir = os.path.join(image_encoder_base, "models--lambdalabs--sd-image-variations-diffusers") |
|
|
if os.path.exists(hf_cache_dir): |
|
|
update_progress(f" β Found HF cache directory") |
|
|
|
|
|
snapshots_dir = os.path.join(hf_cache_dir, "snapshots") |
|
|
if os.path.exists(snapshots_dir): |
|
|
snapshot_dirs = [d for d in os.listdir(snapshots_dir) if os.path.isdir(os.path.join(snapshots_dir, d))] |
|
|
if snapshot_dirs: |
|
|
snapshot_path = os.path.join(snapshots_dir, snapshot_dirs[0]) |
|
|
update_progress(f" β Found snapshot: {snapshot_dirs[0]}") |
|
|
|
|
|
img_enc_path = os.path.join(snapshot_path, "image_encoder") |
|
|
if os.path.exists(img_enc_path) and os.path.exists(os.path.join(img_enc_path, "config.json")): |
|
|
encoder_path = img_enc_path |
|
|
update_progress(f"β
Found Image Encoder at: {img_enc_path}") |
|
|
elif os.path.exists(os.path.join(snapshot_path, "config.json")): |
|
|
encoder_path = snapshot_path |
|
|
update_progress(f"β
Found Image Encoder at: {snapshot_path}") |
|
|
except Exception as e: |
|
|
update_progress(f" β οΈ Error navigating cache: {e}") |
|
|
|
|
|
|
|
|
if not encoder_path: |
|
|
fallback_paths = [ |
|
|
f"{MODEL_CACHE}/image_encoder_full/image_encoder", |
|
|
f"{MODEL_CACHE}/models--lambdalabs--sd-image-variations-diffusers/snapshots/*/image_encoder", |
|
|
] |
|
|
for path_pattern in fallback_paths: |
|
|
if '*' in path_pattern: |
|
|
import glob |
|
|
matches = glob.glob(path_pattern) |
|
|
if matches and os.path.exists(os.path.join(matches[0], "config.json")): |
|
|
encoder_path = matches[0] |
|
|
update_progress(f"β
Found Image Encoder via glob: {encoder_path}") |
|
|
break |
|
|
elif os.path.exists(path_pattern) and os.path.exists(os.path.join(path_pattern, "config.json")): |
|
|
encoder_path = path_pattern |
|
|
update_progress(f"β
Found Image Encoder at: {path_pattern}") |
|
|
break |
|
|
|
|
|
mimo_weights_path = find_model_path( |
|
|
f"{MODEL_CACHE}/mimo_weights", |
|
|
"MIMO Weights", |
|
|
["models--menyifang--MIMO"] |
|
|
) |
|
|
|
|
|
|
|
|
missing = [] |
|
|
if not vae_path:
missing.append("VAE")
update_progress(f"❌ VAE path not found")
if not sd_path:
missing.append("SD v1.5")
update_progress(f"❌ SD v1.5 path not found")
if not encoder_path:
missing.append("Image Encoder")
update_progress(f"❌ Image Encoder path not found")
if not mimo_weights_path:
missing.append("MIMO Weights")
update_progress(f"❌ MIMO Weights path not found")
|
|
|
|
|
if missing: |
|
|
error_msg = f"Missing required models: {', '.join(missing)}. Please run 'Setup Models' first." |
|
|
update_progress(f"β {error_msg}") |
|
|
|
|
|
try: |
|
|
cache_contents = os.listdir(MODEL_CACHE) if os.path.exists(MODEL_CACHE) else [] |
|
|
update_progress(f"π MODEL_CACHE contents: {cache_contents[:15]}") |
|
|
except OSError:
|
|
pass |
|
|
return False |
|
|
|
|
|
update_progress("β
All required models found") |
|
|
|
|
|
|
|
|
if DEVICE == "cuda": |
|
|
try: |
|
|
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 |
|
|
self.weight_dtype = torch.float16 if gpu_memory > 10 else torch.float32 |
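# Heuristic: cards with >10GB VRAM run FP16 (halves memory and bandwidth);
# anything smaller stays FP32 for numerical safety.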
|
|
update_progress(f"Using {'FP16' if self.weight_dtype == torch.float16 else 'FP32'} on GPU ({gpu_memory:.1f}GB)") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ GPU detection failed: {e}, using FP32") |
|
|
self.weight_dtype = torch.float32 |
|
|
else: |
|
|
self.weight_dtype = torch.float32 |
|
|
update_progress("Using FP32 on CPU") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Loading VAE...") |
|
|
vae = AutoencoderKL.from_pretrained( |
|
|
vae_path, |
|
|
torch_dtype=self.weight_dtype |
|
|
) |
|
|
update_progress("β
VAE loaded (on CPU)") |
|
|
except Exception as e: |
|
|
update_progress(f"β VAE loading failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Loading Reference UNet...") |
|
|
reference_unet = UNet2DConditionModel.from_pretrained( |
|
|
sd_path, |
|
|
subfolder="unet", |
|
|
torch_dtype=self.weight_dtype |
|
|
) |
|
|
update_progress("β
Reference UNet loaded (on CPU)") |
|
|
except Exception as e: |
|
|
update_progress(f"β Reference UNet loading failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
config_path = "./configs/inference/inference_v2.yaml" |
|
|
if os.path.exists(config_path): |
|
|
infer_config = OmegaConf.load(config_path) |
|
|
update_progress("β
Loaded inference config") |
|
|
else: |
|
|
|
|
|
update_progress("Creating fallback inference config...") |
|
|
infer_config = OmegaConf.create({ |
|
|
"unet_additional_kwargs": { |
|
|
"use_inflated_groupnorm": True, |
|
|
"unet_use_cross_frame_attention": False, |
|
|
"unet_use_temporal_attention": False, |
|
|
"use_motion_module": True, |
|
|
"motion_module_resolutions": [1, 2, 4, 8], |
|
|
"motion_module_mid_block": True, |
|
|
"motion_module_decoder_only": False, |
|
|
"motion_module_type": "Vanilla", |
|
|
"motion_module_kwargs": { |
|
|
"num_attention_heads": 8, |
|
|
"num_transformer_block": 1, |
|
|
"attention_block_types": ["Temporal_Self", "Temporal_Self"], |
|
|
"temporal_position_encoding": True, |
|
|
"temporal_position_encoding_max_len": 32, |
|
|
"temporal_attention_dim_div": 1 |
|
|
} |
|
|
}, |
|
|
"noise_scheduler_kwargs": { |
|
|
"beta_start": 0.00085, |
|
|
"beta_end": 0.012, |
|
|
"beta_schedule": "scaled_linear", |
|
|
"clip_sample": False, |
|
|
"steps_offset": 1, |
|
|
"prediction_type": "v_prediction", |
|
|
"rescale_betas_zero_snr": True, |
|
|
"timestep_spacing": "trailing" |
|
|
} |
|
|
}) |
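# Fallback mirrors configs/inference/inference_v2.yaml (assumption): v-prediction
# with zero-SNR rescaling and trailing timestep spacing, which the MIMO
# checkpoints appear to expect; diverging from these settings degrades output.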
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Loading Denoising UNet (3D)...") |
|
|
denoising_unet = UNet3DConditionModel.from_pretrained_2d( |
|
|
sd_path, |
|
|
"", |
|
|
subfolder="unet", |
|
|
unet_additional_kwargs=infer_config.unet_additional_kwargs |
|
|
) |
|
|
|
|
|
denoising_unet = denoising_unet.to(dtype=self.weight_dtype) |
|
|
update_progress("β
Denoising UNet loaded (on CPU)") |
|
|
except Exception as e: |
|
|
update_progress(f"β Denoising UNet loading failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Loading Pose Guider...") |
|
|
pose_guider = PoseGuider( |
|
|
320, |
|
|
conditioning_channels=3, |
|
|
block_out_channels=(16, 32, 96, 256) |
|
|
).to(dtype=self.weight_dtype) |
|
|
update_progress("β
Pose Guider initialized (on CPU)") |
|
|
except Exception as e: |
|
|
update_progress(f"β Pose Guider loading failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Loading CLIP Image Encoder...") |
|
|
image_enc = CLIPVisionModelWithProjection.from_pretrained( |
|
|
encoder_path, |
|
|
torch_dtype=self.weight_dtype |
|
|
) |
|
|
update_progress("β
Image Encoder loaded (on CPU)") |
|
|
except Exception as e: |
|
|
update_progress(f"β Image Encoder loading failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
update_progress("Loading Scheduler...") |
|
|
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) |
|
|
scheduler = DDIMScheduler(**sched_kwargs) |
|
|
|
|
|
|
|
|
update_progress("Loading MIMO pretrained weights...") |
|
|
weight_files = list(Path(mimo_weights_path).rglob("*.pth")) |
|
|
|
|
|
if not weight_files: |
|
|
error_msg = f"No MIMO weight files (.pth) found at {mimo_weights_path}. Please run 'Setup Models' to download them." |
|
|
update_progress(f"β {error_msg}") |
|
|
return False |
|
|
|
|
|
update_progress(f"Found {len(weight_files)} weight files") |
|
|
weights_loaded = 0 |
|
|
|
|
|
for weight_file in weight_files: |
|
|
try: |
|
|
weight_name = weight_file.name |
|
|
if "denoising_unet" in weight_name: |
|
|
state_dict = torch.load(weight_file, map_location="cpu") |
|
|
denoising_unet.load_state_dict(state_dict, strict=False) |
|
|
update_progress(f"β
Loaded {weight_name}") |
|
|
weights_loaded += 1 |
|
|
elif "reference_unet" in weight_name: |
|
|
state_dict = torch.load(weight_file, map_location="cpu") |
|
|
reference_unet.load_state_dict(state_dict) |
|
|
update_progress(f"β
Loaded {weight_name}") |
|
|
weights_loaded += 1 |
|
|
elif "pose_guider" in weight_name: |
|
|
state_dict = torch.load(weight_file, map_location="cpu") |
|
|
pose_guider.load_state_dict(state_dict) |
|
|
update_progress(f"β
Loaded {weight_name}") |
|
|
weights_loaded += 1 |
|
|
elif "motion_module" in weight_name: |
|
|
|
|
|
state_dict = torch.load(weight_file, map_location="cpu") |
|
|
denoising_unet.load_state_dict(state_dict, strict=False) |
|
|
update_progress(f"β
Loaded {weight_name}") |
|
|
weights_loaded += 1 |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Failed to load {weight_file.name}: {str(e)[:100]}") |
|
|
print(f"Full error for {weight_file.name}: {e}") |
|
|
|
|
|
if weights_loaded == 0: |
|
|
error_msg = "No MIMO weights were successfully loaded" |
|
|
update_progress(f"β {error_msg}") |
|
|
return False |
|
|
|
|
|
update_progress(f"β
Loaded {weights_loaded}/{len(weight_files)} weight files") |
|
|
|
|
|
|
|
|
try: |
|
|
update_progress("Creating MIMO pipeline...") |
|
|
self.pipe = Pose2VideoPipeline( |
|
|
vae=vae, |
|
|
image_encoder=image_enc, |
|
|
reference_unet=reference_unet, |
|
|
denoising_unet=denoising_unet, |
|
|
pose_guider=pose_guider, |
|
|
scheduler=scheduler, |
|
|
).to(dtype=self.weight_dtype) |
|
|
|
|
|
|
|
|
if HAS_SPACES: |
|
|
try: |
|
|
|
|
|
if hasattr(denoising_unet, 'enable_gradient_checkpointing'): |
|
|
denoising_unet.enable_gradient_checkpointing() |
|
|
if hasattr(reference_unet, 'enable_gradient_checkpointing'): |
|
|
reference_unet.enable_gradient_checkpointing() |
|
|
|
|
|
try: |
|
|
self.pipe.enable_xformers_memory_efficient_attention() |
|
|
update_progress("β
Memory-efficient attention enabled") |
|
|
except Exception:
|
|
update_progress("β οΈ xformers not available, using standard attention") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Could not enable memory optimizations: {str(e)[:50]}") |
|
|
|
|
|
update_progress("β
Pipeline created (on CPU - will use GPU during generation)") |
|
|
except Exception as e: |
|
|
update_progress(f"β Pipeline creation failed: {str(e)[:100]}") |
|
|
raise |
|
|
|
|
|
|
|
|
update_progress("Loading human segmenter...") |
|
|
if HAS_SEGMENTER: |
|
|
seg_path = f"{ASSETS_CACHE}/matting_human.pb" |
|
|
if os.path.exists(seg_path): |
|
|
try: |
|
|
self.segmenter = human_segmenter(model_path=seg_path) |
|
|
update_progress("β
Human segmenter loaded") |
|
|
except Exception as e: |
|
|
update_progress(f"β οΈ Segmenter load failed: {e}, using fallback") |
|
|
self.segmenter = None |
|
|
else: |
|
|
update_progress("β οΈ Segmenter model not found, using fallback") |
|
|
self.segmenter = None |
|
|
else: |
|
|
update_progress("β οΈ TensorFlow not available, using fallback segmentation") |
|
|
self.segmenter = None |
|
|
|
|
|
|
|
|
update_progress("Loading mask templates...") |
|
|
mask_path = f"{ASSETS_CACHE}/masks/alpha2.png" |
|
|
if os.path.exists(mask_path): |
|
|
self.mask_list = load_mask_list(mask_path) |
|
|
update_progress("β
Mask templates loaded") |
|
|
else: |
|
|
|
|
|
update_progress("Creating fallback masks...") |
|
|
os.makedirs(f"{ASSETS_CACHE}/masks", exist_ok=True) |
|
|
fallback_mask = np.ones((512, 512), dtype=np.uint8) * 255 |
|
|
self.mask_list = [fallback_mask] |
|
|
|
|
|
self.is_loaded = True |
|
|
update_progress("π MIMO model loaded successfully!") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
update_progress(f"β Model loading failed: {e}") |
|
|
traceback.print_exc() |
|
|
return False |
|
|
|
|
|
def process_image(self, image): |
|
|
"""Process input image with human segmentation (matching run_edit.py/run_animate.py)""" |
|
|
if self.segmenter is None: |
|
|
|
|
|
image = np.array(image) |
|
|
image = cv2.resize(image, (512, 512)) |
|
|
return Image.fromarray(image), None |
|
|
|
|
|
try: |
|
|
img_array = np.array(image) |
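# The segmenter graph expects BGR (OpenCV convention), so channels are
# reversed going in and flipped back to RGB afterwards.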
|
|
|
|
|
rgba = self.segmenter.run(img_array[..., ::-1]) |
|
|
mask = rgba[:, :, 3] |
|
|
color = rgba[:, :, :3] |
|
|
alpha = mask / 255 |
|
|
bk = np.ones_like(color) * 255 |
|
|
color = color * alpha[:, :, np.newaxis] + bk * (1 - alpha[:, :, np.newaxis]) |
|
|
color = color.astype(np.uint8) |
|
|
|
|
|
color = color[..., ::-1] |
|
|
|
|
|
|
|
|
color = crop_img(color, mask) |
|
|
color, _ = pad_img(color, [255, 255, 255]) |
|
|
|
|
|
return Image.fromarray(color), mask |
|
|
except Exception as e: |
|
|
print(f"β οΈ Segmentation failed, using original image: {e}") |
|
|
return image, None |
|
|
|
|
|
def get_available_templates(self): |
|
|
"""Get list of available video templates""" |
|
|
template_dir = "./assets/video_template" |
|
|
|
|
|
|
|
|
if not os.path.exists(template_dir): |
|
|
os.makedirs(template_dir, exist_ok=True) |
|
|
print(f"β οΈ Video template directory created: {template_dir}") |
|
|
print("π‘ Tip: Download templates from HuggingFace repo or use 'Setup Models' button") |
|
|
return [] |
|
|
|
|
|
templates = [] |
|
|
try: |
|
|
for item in os.listdir(template_dir): |
|
|
template_path = os.path.join(template_dir, item) |
|
|
if os.path.isdir(template_path): |
|
|
|
|
|
sdc_file = os.path.join(template_path, "sdc.mp4") |
|
|
if os.path.exists(sdc_file): |
|
|
templates.append(item) |
|
|
except Exception as e: |
|
|
print(f"β οΈ Error scanning templates: {e}") |
|
|
return [] |
|
|
|
|
|
if not templates: |
|
|
print("β οΈ No video templates found. Click 'Setup Models' to download.") |
|
|
|
|
|
return sorted(templates) |
|
|
|
|
|
def load_template(self, template_path: str) -> Dict: |
|
|
"""Load template metadata (matching run_edit.py logic)""" |
|
|
try: |
|
|
video_path = os.path.join(template_path, 'vid.mp4') |
|
|
pose_video_path = os.path.join(template_path, 'sdc.mp4') |
|
|
bk_video_path = os.path.join(template_path, 'bk.mp4') |
|
|
occ_video_path = os.path.join(template_path, 'occ.mp4') |
|
|
|
|
|
|
|
|
if not os.path.exists(occ_video_path): |
|
|
occ_video_path = None |
|
|
|
|
|
|
|
|
config_file = os.path.join(template_path, 'config.json') |
|
|
if os.path.exists(config_file): |
|
|
with open(config_file) as f: |
|
|
template_data = json.load(f) |
|
|
|
|
|
return { |
|
|
'video_path': video_path, |
|
|
'pose_video_path': pose_video_path, |
|
|
'bk_video_path': bk_video_path if os.path.exists(bk_video_path) else None, |
|
|
'occ_video_path': occ_video_path, |
|
|
'target_fps': template_data.get('fps', 30), |
|
|
'time_crop': template_data.get('time_crop', {'start_idx': 0, 'end_idx': -1}), |
|
|
'frame_crop': template_data.get('frame_crop', {}), |
|
|
'layer_recover': template_data.get('layer_recover', True) |
|
|
} |
|
|
else: |
|
|
|
|
|
return { |
|
|
'video_path': video_path if os.path.exists(video_path) else None, |
|
|
'pose_video_path': pose_video_path, |
|
|
'bk_video_path': bk_video_path if os.path.exists(bk_video_path) else None, |
|
|
'occ_video_path': occ_video_path, |
|
|
'target_fps': 30, |
|
|
'time_crop': {'start_idx': 0, 'end_idx': -1}, |
|
|
'frame_crop': {}, |
|
|
'layer_recover': True |
|
|
} |
|
|
except Exception as e: |
|
|
print(f"β οΈ Failed to load template config: {e}") |
|
|
return None |
|
|
|
|
|
def generate_animation(self, input_image, template_name, mode="edit", progress_callback=None): |
|
|
"""Generate video animation (implementing both run_edit.py and run_animate.py logic)""" |
|
|
|
|
|
def update_progress(msg): |
|
|
if progress_callback: |
|
|
progress_callback(msg) |
|
|
print(f"π¬ {msg}") |
|
|
|
|
|
try: |
|
|
if not self.is_loaded: |
|
|
update_progress("Loading model first...") |
|
|
if not self.load_model(progress_callback): |
|
|
return None, "β Model loading failed" |
|
|
|
|
|
|
|
|
if HAS_SPACES and torch.cuda.is_available(): |
|
|
update_progress("Moving models to GPU...") |
|
|
self.pipe = self.pipe.to("cuda") |
|
|
update_progress("β
Models on GPU") |
|
|
|
|
|
|
|
|
update_progress("Processing input image...") |
|
|
processed_image, mask = self.process_image(input_image) |
|
|
|
|
|
|
|
|
template_path = f"./assets/video_template/{template_name}" |
|
|
if not os.path.exists(template_path): |
|
|
return None, f"β Template '{template_name}' not found" |
|
|
|
|
|
template_info = self.load_template(template_path) |
|
|
if template_info is None: |
|
|
return None, f"β Failed to load template '{template_name}'" |
|
|
|
|
|
update_progress(f"Loaded template: {template_name}") |
|
|
|
|
|
|
|
|
target_fps = template_info['target_fps'] |
|
|
pose_video_path = template_info['pose_video_path'] |
|
|
|
|
|
if not os.path.exists(pose_video_path): |
|
|
return None, f"β Pose video not found: {pose_video_path}" |
|
|
|
|
|
|
|
|
update_progress("Loading motion sequence...") |
|
|
pose_images = load_video_fixed_fps(pose_video_path, target_fps=target_fps) |
|
|
|
|
|
|
|
|
bk_video_path = template_info['bk_video_path'] |
|
|
if bk_video_path and os.path.exists(bk_video_path): |
|
|
bk_images = load_video_fixed_fps(bk_video_path, target_fps=target_fps) |
|
|
update_progress("β
Loaded background video") |
|
|
else: |
|
|
|
|
|
n_frame = len(pose_images) |
|
|
tw, th = pose_images[0].size |
|
|
bk_images = [] |
|
|
for _ in range(n_frame): |
|
|
bk_img = Image.new('RGB', (tw, th), (255, 255, 255)) |
|
|
bk_images.append(bk_img) |
|
|
update_progress("β
Created white background") |
|
|
|
|
|
|
|
|
occ_video_path = template_info['occ_video_path'] |
|
|
if occ_video_path and os.path.exists(occ_video_path) and mode == "edit": |
|
|
occ_mask_images = load_video_fixed_fps(occ_video_path, target_fps=target_fps) |
|
|
update_progress("β
Loaded occlusion masks") |
|
|
else: |
|
|
occ_mask_images = None |
|
|
|
|
|
|
|
|
time_crop = template_info['time_crop'] |
|
|
start_idx = max(0, int(target_fps * time_crop['start_idx'] / 30)) if time_crop['start_idx'] >= 0 else 0 |
|
|
end_idx = min(len(pose_images), int(target_fps * time_crop['end_idx'] / 30)) if time_crop['end_idx'] >= 0 else len(pose_images) |
|
|
|
|
|
pose_images = pose_images[start_idx:end_idx] |
|
|
bk_images = bk_images[start_idx:end_idx] |
|
|
if occ_mask_images: |
|
|
occ_mask_images = occ_mask_images[start_idx:end_idx] |
|
|
|
|
|
|
|
|
|
|
|
MAX_FRAMES = 100 if HAS_SPACES else 150 |
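# Rough frame budget (assumption): ~100 frames keeps a generation inside the
# 120s @spaces.GPU window on ZeroGPU; local runs get a little more headroom.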
|
|
if len(pose_images) > MAX_FRAMES: |
|
|
update_progress(f"β οΈ Limiting to {MAX_FRAMES} frames to fit in GPU memory") |
|
|
pose_images = pose_images[:MAX_FRAMES] |
|
|
bk_images = bk_images[:MAX_FRAMES] |
|
|
if occ_mask_images: |
|
|
occ_mask_images = occ_mask_images[:MAX_FRAMES] |
|
|
|
|
|
num_frames = len(pose_images) |
|
|
update_progress(f"Processing {num_frames} frames...") |
|
|
|
|
|
if mode == "animate": |
|
|
|
|
|
pose_list = [] |
|
|
vid_bk_list = [] |
|
|
|
|
|
|
|
|
pose_images, _, bk_images = crop_human(pose_images, pose_images.copy(), bk_images) |
|
|
|
|
|
for frame_idx in range(len(pose_images)): |
|
|
pose_image = np.array(pose_images[frame_idx]) |
|
|
pose_image, _ = pad_img(pose_image, color=[0, 0, 0]) |
|
|
pose_list.append(Image.fromarray(pose_image)) |
|
|
|
|
|
vid_bk = np.array(bk_images[frame_idx]) |
|
|
vid_bk, _ = pad_img(vid_bk, color=[255, 255, 255]) |
|
|
vid_bk_list.append(Image.fromarray(vid_bk)) |
|
|
|
|
|
|
|
|
update_progress("Generating animation...") |
|
|
width, height = 512, 512 |
|
|
steps = 20 |
|
|
cfg = 3.5 |
|
|
|
|
|
generator = torch.Generator(device=DEVICE).manual_seed(42) |
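# Fixed seed keeps results reproducible for the same inputs; a CPU generator
# works with diffusers even when the pipeline itself runs on CUDA.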
|
|
video = self.pipe( |
|
|
processed_image, |
|
|
pose_list, |
|
|
vid_bk_list, |
|
|
width, |
|
|
height, |
|
|
num_frames, |
|
|
steps, |
|
|
cfg, |
|
|
generator=generator, |
|
|
).videos[0] |
|
|
|
|
|
|
|
|
update_progress("Post-processing video...") |
|
|
res_images = [] |
|
|
for video_idx in range(num_frames): |
|
|
image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy() |
|
|
res_image_pil = Image.fromarray((image * 255).astype(np.uint8)) |
|
|
res_images.append(res_image_pil) |
|
|
|
|
|
else: |
|
|
|
|
|
update_progress("Advanced video editing mode...") |
|
|
|
|
|
|
|
|
video_path = template_info['video_path'] |
|
|
if video_path and os.path.exists(video_path): |
|
|
vid_images = load_video_fixed_fps(video_path, target_fps=target_fps) |
|
|
vid_images = vid_images[start_idx:end_idx][:MAX_FRAMES] |
|
|
else: |
|
|
vid_images = pose_images.copy() |
|
|
|
|
|
|
|
|
overlay = 4 |
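# Clips are cut with a 4-frame overlap; the overlapping frames are cross-faded
# during post-processing (see the `factor` blend below) to hide clip seams.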
|
|
pose_images, vid_images, bk_images, bbox_clip, context_list, bbox_clip_list = crop_human_clip_auto_context( |
|
|
pose_images, vid_images, bk_images, overlay) |
|
|
|
|
|
|
|
|
clip_pad_list_context = [] |
|
|
clip_padv_list_context = [] |
|
|
pose_list_context = [] |
|
|
vid_bk_list_context = [] |
|
|
|
|
|
for frame_idx in range(len(pose_images)): |
|
|
pose_image = np.array(pose_images[frame_idx]) |
|
|
pose_image, _ = pad_img(pose_image, color=[0, 0, 0]) |
|
|
pose_list_context.append(Image.fromarray(pose_image)) |
|
|
|
|
|
vid_bk = np.array(bk_images[frame_idx]) |
|
|
vid_bk, padding_v = pad_img(vid_bk, color=[255, 255, 255]) |
|
|
pad_h, pad_w, _ = vid_bk.shape |
|
|
clip_pad_list_context.append([pad_h, pad_w]) |
|
|
clip_padv_list_context.append(padding_v) |
|
|
vid_bk_list_context.append(Image.fromarray(vid_bk)) |
|
|
|
|
|
|
|
|
width, height = 784, 784 |
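# Edit mode renders the cropped person region at a higher working resolution
# (784px) because the result is pasted back into the full original frame.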
|
|
steps = 25 |
|
|
cfg = 3.5 |
|
|
|
|
|
generator = torch.Generator(device=DEVICE).manual_seed(42) |
|
|
video = self.pipe( |
|
|
processed_image, |
|
|
pose_list_context, |
|
|
vid_bk_list_context, |
|
|
width, |
|
|
height, |
|
|
len(pose_list_context), |
|
|
steps, |
|
|
cfg, |
|
|
generator=generator, |
|
|
).videos[0] |
|
|
|
|
|
|
|
|
update_progress("Advanced post-processing...") |
|
|
vid_images_ori = vid_images.copy() |
|
|
bk_images_ori = bk_images.copy() |
|
|
|
|
|
video_idx = 0 |
|
|
res_images = [None for _ in range(len(pose_images))] |
|
|
|
|
|
for k, context in enumerate(context_list): |
|
|
start_i = context[0] |
|
|
bbox = bbox_clip_list[k] |
|
|
|
|
|
for i in context: |
|
|
bk_image_pil_ori = bk_images_ori[i] |
|
|
vid_image_pil_ori = vid_images_ori[i] |
|
|
occ_mask = occ_mask_images[i] if occ_mask_images else None |
|
|
|
|
|
canvas = Image.new("RGB", bk_image_pil_ori.size, "white") |
|
|
|
|
|
pad_h, pad_w = clip_pad_list_context[video_idx] |
|
|
padding_v = clip_padv_list_context[video_idx] |
|
|
|
|
|
image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy() |
|
|
res_image_pil = Image.fromarray((image * 255).astype(np.uint8)) |
|
|
res_image_pil = res_image_pil.resize((pad_w, pad_h)) |
|
|
|
|
|
top, bottom, left, right = padding_v |
|
|
res_image_pil = res_image_pil.crop((left, top, pad_w - right, pad_h - bottom)) |
|
|
|
|
|
w_min, w_max, h_min, h_max = bbox |
|
|
canvas.paste(res_image_pil, (w_min, h_min)) |
|
|
|
|
|
|
|
|
mask_full = np.zeros((bk_image_pil_ori.size[1], bk_image_pil_ori.size[0]), dtype=np.float32) |
|
|
mask = get_mask(self.mask_list, bbox, bk_image_pil_ori) |
|
|
mask = cv2.resize(mask, res_image_pil.size, interpolation=cv2.INTER_AREA) |
|
|
|
|
|
|
|
|
canvas_h, canvas_w = mask_full.shape |
|
|
mask_h, mask_w = mask.shape |
|
|
|
|
|
|
|
|
h_end = min(h_min + mask_h, canvas_h) |
|
|
w_end = min(w_min + mask_w, canvas_w) |
|
|
|
|
|
|
|
|
actual_h = h_end - h_min |
|
|
actual_w = w_end - w_min |
|
|
|
|
|
mask_full[h_min:h_end, w_min:w_end] = mask[:actual_h, :actual_w] |
|
|
|
|
|
res_image = np.array(canvas) |
|
|
bk_image = np.array(bk_image_pil_ori) |
|
|
res_image = res_image * mask_full[:, :, np.newaxis] + bk_image * (1 - mask_full[:, :, np.newaxis]) |
|
|
|
|
|
|
|
|
if occ_mask is not None: |
|
|
vid_image = np.array(vid_image_pil_ori) |
|
|
occ_mask_array = np.array(occ_mask)[:, :, 0].astype(np.uint8) |
|
|
occ_mask_array = occ_mask_array / 255.0 |
|
|
|
|
|
|
|
|
if occ_mask_array.shape[:2] != res_image.shape[:2]: |
|
|
occ_mask_array = cv2.resize(occ_mask_array, (res_image.shape[1], res_image.shape[0]), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
|
if vid_image.shape[:2] != res_image.shape[:2]: |
|
|
vid_image = cv2.resize(vid_image, (res_image.shape[1], res_image.shape[0]), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
res_image = res_image * (1 - occ_mask_array[:, :, np.newaxis]) + vid_image * occ_mask_array[:, :, np.newaxis] |
|
|
|
|
|
|
|
|
if res_images[i] is None: |
|
|
res_images[i] = res_image |
|
|
else: |
|
|
factor = (i - start_i + 1) / (overlay + 1) |
|
|
res_images[i] = res_images[i] * (1 - factor) + res_image * factor |
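# Linear cross-fade over the clip overlap: frames shared between two clips
# shift weight from the earlier clip to the later one as i advances.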
|
|
|
|
|
res_images[i] = res_images[i].astype(np.uint8) |
|
|
video_idx += 1 |
|
|
|
|
|
|
|
|
update_progress("Finalizing video encoding...") |
|
|
for i, frame in enumerate(res_images): |
|
|
if frame is not None: |
|
|
h, w = frame.shape[:2] |
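# H.264 requires even frame dimensions; trim one row/column instead of
# resizing to avoid resampling artifacts.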
|
|
|
|
|
new_h = h if h % 2 == 0 else h - 1 |
|
|
new_w = w if w % 2 == 0 else w - 1 |
|
|
if new_h != h or new_w != w: |
|
|
res_images[i] = frame[:new_h, :new_w] |
|
|
|
|
|
|
|
|
output_path = f"./output/mimo_output_{int(time.time())}.mp4" |
|
|
try: |
|
|
imageio.mimsave(output_path, res_images, fps=target_fps, quality=8, macro_block_size=1) |
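# macro_block_size=1 stops imageio-ffmpeg from auto-resizing frames to
# multiples of 16, at a small compression-efficiency cost.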
|
|
except (OSError, BrokenPipeError) as e: |
|
|
|
|
|
update_progress("β οΈ Retrying with compatible encoding settings...") |
|
|
try: |
|
|
|
|
|
gif_path = output_path.replace('.mp4', '.gif') |
|
|
imageio.mimsave(gif_path, res_images, fps=target_fps)
|
|
output_path = gif_path |
|
|
update_progress("β
Saved as GIF (FFMPEG encoding failed)") |
|
|
except Exception as gif_error: |
|
|
raise Exception(f"Video encoding failed: {str(e)}. GIF fallback also failed: {str(gif_error)}") |
|
|
|
|
|
|
|
|
if HAS_SPACES and torch.cuda.is_available(): |
|
|
update_progress("Cleaning up GPU memory...") |
|
|
self.pipe = self.pipe.to("cpu") |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
update_progress("β
GPU memory released") |
|
|
|
|
|
update_progress("β
Video generated successfully!") |
|
|
return output_path, f"π Generated {len(res_images)} frames at {target_fps}fps using {mode} mode!" |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
if HAS_SPACES and torch.cuda.is_available(): |
|
|
try: |
|
|
self.pipe = self.pipe.to("cpu") |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
print("β
GPU memory cleaned up after error") |
|
|
except Exception:
|
|
pass |
|
|
|
|
|
error_msg = f"β Generation failed: {e}" |
|
|
update_progress(error_msg) |
|
|
traceback.print_exc() |
|
|
return None, error_msg |
|
|
|
|
|
|
|
|
mimo_model = CompleteMIMO() |
|
|
|
|
|
def gradio_interface(): |
|
|
"""Create complete Gradio interface matching README_SETUP.md functionality""" |
|
|
|
|
|
def setup_models(progress=gr.Progress()): |
|
|
"""Setup models with progress tracking""" |
|
|
try: |
|
|
|
|
|
progress(0.1, desc="Starting download...") |
|
|
download_success = mimo_model.download_models(lambda msg: progress(0.3, desc=msg)) |
|
|
|
|
|
if not download_success: |
|
|
return "β οΈ Some downloads failed. Check logs for details. You may still be able to use the app with partial functionality." |
|
|
|
|
|
|
|
|
progress(0.6, desc="Loading models...") |
|
|
load_success = mimo_model.load_model(lambda msg: progress(0.8, desc=msg)) |
|
|
|
|
|
if not load_success: |
|
|
return "β Model loading failed. Please check the logs and try again." |
|
|
|
|
|
progress(1.0, desc="β
Ready!") |
|
|
return "π MIMO is ready! Models loaded successfully. Upload an image and select a template to start." |
|
|
|
|
|
except Exception as e: |
|
|
error_details = str(e) |
|
|
print(f"Setup error: {error_details}") |
|
|
traceback.print_exc() |
|
|
return f"β Setup failed: {error_details[:200]}" |
|
|
|
|
|
|
|
|
if HAS_SPACES: |
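# ZeroGPU: reserve a GPU for at most 120s per call; hardware is attached only
# while this wrapped function runs, then released.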
|
|
@spaces.GPU(duration=120) |
|
|
def generate_video_gradio(input_image, template_name, mode, progress=gr.Progress()): |
|
|
"""Gradio wrapper for video generation""" |
|
|
if input_image is None: |
|
|
return None, "Please upload an image first" |
|
|
|
|
|
if not template_name: |
|
|
return None, "Please select a motion template" |
|
|
|
|
|
try: |
|
|
progress(0.1, desc="Starting generation...") |
|
|
|
|
|
def progress_callback(msg): |
|
|
progress(0.5, desc=msg) |
|
|
|
|
|
output_path, message = mimo_model.generate_animation( |
|
|
input_image, |
|
|
template_name, |
|
|
mode, |
|
|
progress_callback |
|
|
) |
|
|
|
|
|
progress(1.0, desc="Complete!") |
|
|
return output_path, message |
|
|
|
|
|
except Exception as e: |
|
|
return None, f"β Generation failed: {e}" |
|
|
else: |
|
|
|
|
|
def generate_video_gradio(input_image, template_name, mode, progress=gr.Progress()): |
|
|
"""Gradio wrapper for video generation""" |
|
|
if input_image is None: |
|
|
return None, "Please upload an image first" |
|
|
|
|
|
if not template_name: |
|
|
return None, "Please select a motion template" |
|
|
|
|
|
try: |
|
|
progress(0.1, desc="Starting generation...") |
|
|
|
|
|
def progress_callback(msg): |
|
|
progress(0.5, desc=msg) |
|
|
|
|
|
output_path, message = mimo_model.generate_animation( |
|
|
input_image, |
|
|
template_name, |
|
|
mode, |
|
|
progress_callback |
|
|
) |
|
|
|
|
|
progress(1.0, desc="Complete!") |
|
|
return output_path, message |
|
|
|
|
|
except Exception as e: |
|
|
return None, f"β Generation failed: {e}" |
|
|
|
|
|
def refresh_templates(): |
|
|
"""Refresh available templates""" |
|
|
templates = mimo_model.get_available_templates() |
|
|
return gr.Dropdown(choices=templates, value=templates[0] if templates else None) |
|
|
|
|
|
|
|
|
with gr.Blocks( |
|
|
title="MIMO - Complete Character Video Synthesis", |
|
|
theme=gr.themes.Soft(), |
|
|
css=""" |
|
|
.gradio-container { |
|
|
max-width: 1400px; |
|
|
margin: auto; |
|
|
} |
|
|
.header { |
|
|
text-align: center; |
|
|
margin-bottom: 2rem; |
|
|
color: #1a1a1a !important; |
|
|
} |
|
|
.header h1 { |
|
|
color: #2c3e50 !important; |
|
|
margin-bottom: 0.5rem; |
|
|
font-weight: 700; |
|
|
} |
|
|
.header p { |
|
|
color: #34495e !important; |
|
|
margin: 0.5rem 0; |
|
|
font-weight: 500; |
|
|
} |
|
|
.header a { |
|
|
color: #3498db !important; |
|
|
text-decoration: none; |
|
|
margin: 0 0.5rem; |
|
|
font-weight: 600; |
|
|
} |
|
|
.header a:hover { |
|
|
text-decoration: underline; |
|
|
color: #2980b9 !important; |
|
|
} |
|
|
.mode-info { |
|
|
padding: 1rem; |
|
|
margin: 1rem 0; |
|
|
border-radius: 8px; |
|
|
color: #2c3e50 !important; |
|
|
} |
|
|
.mode-info h4 { |
|
|
margin-top: 0; |
|
|
color: #2c3e50 !important; |
|
|
font-weight: 700; |
|
|
} |
|
|
.mode-info p { |
|
|
margin: 0.5rem 0; |
|
|
color: #34495e !important; |
|
|
font-weight: 500; |
|
|
} |
|
|
.mode-info strong { |
|
|
color: #1a1a1a !important; |
|
|
font-weight: 700; |
|
|
} |
|
|
.mode-animate { |
|
|
background: #e8f5e8; |
|
|
border-left: 4px solid #4caf50; |
|
|
} |
|
|
.mode-edit { |
|
|
background: #e3f2fd; |
|
|
border-left: 4px solid #2196f3; |
|
|
} |
|
|
.warning-box { |
|
|
padding: 1rem; |
|
|
background: #fff3cd; |
|
|
border-left: 4px solid #ffc107; |
|
|
margin: 1rem 0; |
|
|
border-radius: 4px; |
|
|
} |
|
|
.warning-box b { |
|
|
color: #856404 !important; |
|
|
font-weight: 700; |
|
|
} |
|
|
.warning-box br + text, .warning-box { |
|
|
color: #856404 !important; |
|
|
} |
|
|
.warning-box, .warning-box * { |
|
|
color: #856404 !important; |
|
|
} |
|
|
.instructions-box { |
|
|
margin-top: 2rem; |
|
|
padding: 1.5rem; |
|
|
background: #f8f9fa; |
|
|
border-radius: 8px; |
|
|
border: 1px solid #dee2e6; |
|
|
} |
|
|
.instructions-box h4 { |
|
|
color: #2c3e50 !important; |
|
|
margin-top: 1rem; |
|
|
margin-bottom: 0.5rem; |
|
|
font-weight: 700; |
|
|
} |
|
|
.instructions-box h4:first-child { |
|
|
margin-top: 0; |
|
|
} |
|
|
.instructions-box ol { |
|
|
color: #495057 !important; |
|
|
line-height: 1.8; |
|
|
} |
|
|
.instructions-box ol li { |
|
|
margin: 0.5rem 0; |
|
|
color: #495057 !important; |
|
|
} |
|
|
.instructions-box ol li strong { |
|
|
color: #1a1a1a !important; |
|
|
font-weight: 700; |
|
|
} |
|
|
.instructions-box p { |
|
|
color: #495057 !important; |
|
|
margin: 0.3rem 0; |
|
|
line-height: 1.6; |
|
|
} |
|
|
.instructions-box p strong { |
|
|
color: #1a1a1a !important; |
|
|
font-weight: 700; |
|
|
} |
|
|
""" |
|
|
) as demo: |
|
|
|
|
|
gr.HTML(""" |
|
|
<div class="header"> |
|
|
<h1>🎬 MIMO - Complete Character Video Synthesis</h1>
<p>Full implementation matching the original research paper - Character Animation & Video Editing</p>
<p>
<a href="https://menyifang.github.io/projects/MIMO/index.html">🌐 Project Page</a> |
<a href="https://github.com/menyifang/MIMO">💻 GitHub</a> |
<a href="https://arxiv.org/abs/2409.16160">📄 Paper</a>
</p>
|
|
</p> |
|
|
</div> |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h3>πΌοΈ Input Configuration</h3>") |
|
|
|
|
|
input_image = gr.Image( |
|
|
label="Character Image", |
|
|
type="pil", |
|
|
height=400 |
|
|
) |
|
|
|
|
|
mode = gr.Radio( |
|
|
label="Generation Mode", |
|
|
choices=[ |
|
|
("π Character Animation", "animate"), |
|
|
("π¬ Video Character Editing", "edit") |
|
|
], |
|
|
value="edit" |
|
|
) |
|
|
|
|
|
|
|
|
templates = mimo_model.get_available_templates() |
|
|
|
|
|
if not templates: |
|
|
gr.HTML(""" |
|
|
<div class="warning-box"> |
|
|
<b>⚠️ No Motion Templates Found</b><br/>
Click <b>"🔧 Setup Models"</b> button below to download video templates.<br/>
Templates will be downloaded to: <code>./assets/video_template/</code>
|
|
</div> |
|
|
""") |
|
|
|
|
|
motion_template = gr.Dropdown( |
|
|
label="Motion Template", |
|
|
choices=templates if templates else ["No templates - Upload manually or use reference image only"], |
|
|
value=templates[0] if templates else None, |
|
|
info="Templates provide motion guidance. Not required for basic image animation." |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
setup_btn = gr.Button("οΏ½ Setup Models", variant="secondary", scale=1) |
|
|
load_btn = gr.Button("β‘ Load Model", variant="secondary", scale=1) |
|
|
|
|
|
with gr.Row(): |
|
|
refresh_btn = gr.Button("οΏ½ Refresh Templates", variant="secondary", scale=1) |
|
|
generate_btn = gr.Button("π¬ Generate Video", variant="primary", scale=2) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h3>π₯ Output</h3>") |
|
|
|
|
|
output_video = gr.Video( |
|
|
label="Generated Video", |
|
|
height=400 |
|
|
) |
|
|
|
|
|
status_text = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
lines=4 |
|
|
) |
|
|
|
|
|
|
|
|
gr.HTML(""" |
|
|
<div class="mode-info mode-animate"> |
|
|
<h4>🏃 Character Animation Mode</h4>
|
|
<p><strong>Features:</strong> Character image + motion template β animated video</p> |
|
|
<p><strong>Use case:</strong> Animate static characters with predefined motions</p> |
|
|
<p><strong>Based on:</strong> run_animate.py functionality</p> |
|
|
</div> |
|
|
|
|
|
<div class="mode-info mode-edit"> |
|
|
<h4>🎬 Video Character Editing Mode</h4>
|
|
<p><strong>Features:</strong> Advanced editing with background blending, occlusion handling</p> |
|
|
<p><strong>Use case:</strong> Replace characters in existing videos while preserving backgrounds</p> |
|
|
<p><strong>Based on:</strong> run_edit.py functionality</p> |
|
|
</div> |
|
|
""") |
|
|
|
|
|
gr.HTML(""" |
|
|
<div class="instructions-box"> |
|
|
<h4>📋 Instructions:</h4>
<ol>
<li><strong>First Time Setup:</strong> Click "🔧 Setup Models" to download MIMO (~8GB, one-time)</li>
<li><strong>Load Model:</strong> Click "⚡ Load Model" to activate the model (required once per session)</li>
<li><strong>Upload Image:</strong> Upload a character image (clear, front-facing works best)</li>
<li><strong>Select Mode:</strong> Choose between Animation (simpler) or Editing (advanced)</li>
<li><strong>Pick Template:</strong> Select a motion template from the dropdown (or refresh to see new ones)</li>
<li><strong>Generate:</strong> Click "🎬 Generate Video" and wait for processing</li>
</ol>

<h4>🎯 Available Templates (13 total):</h4>
<p><strong>Sports:</strong> basketball_gym, nba_dunk, nba_pass, football</p>
<p><strong>Action:</strong> kungfu_desert, kungfu_match, parkour_climbing, BruceLee</p>
<p><strong>Dance:</strong> dance_indoor, irish_dance</p>
<p><strong>Synthetic:</strong> syn_basketball, syn_dancing, syn_football</p>

<p><strong>💡 Model Persistence:</strong> Downloaded models persist across page refreshes! Just click "Load Model" to reactivate.</p>
<p><strong>⚠️ Timing:</strong> First setup takes 5-10 minutes. Model loading takes 30-60 seconds. Generation takes 2-5 minutes per video.</p>
|
|
</div> |
|
|
""") |
|
|
|
|
|
|
|
|
def load_model_only(progress=gr.Progress()): |
|
|
"""Load models without downloading (if already cached)""" |
|
|
try: |
|
|
|
|
|
if mimo_model.is_loaded: |
|
|
return "β
Model already loaded and ready! You can generate videos now." |
|
|
|
|
|
|
|
|
mimo_model._check_existing_models() |
|
|
|
|
|
if not mimo_model._model_cache_valid: |
|
|
return "β οΈ Models not found in cache. Please click 'π§ Setup Models' first to download (~8GB)." |
|
|
|
|
|
progress(0.3, desc="Loading models from cache...") |
|
|
load_success = mimo_model.load_model(lambda msg: progress(0.7, desc=msg)) |
|
|
|
|
|
if load_success: |
|
|
progress(1.0, desc="β
Ready!") |
|
|
return "β
Model loaded successfully! Ready to generate videos. Upload an image and select a template." |
|
|
else: |
|
|
return "β Model loading failed. Check logs for details or try 'Setup Models' button." |
|
|
except Exception as e: |
|
|
|
|
traceback.print_exc() |
|
|
return f"β Load failed: {str(e)[:200]}" |
|
|
|
|
|
setup_btn.click( |
|
|
fn=setup_models, |
|
|
outputs=[status_text] |
|
|
) |
|
|
|
|
|
load_btn.click( |
|
|
fn=load_model_only, |
|
|
outputs=[status_text] |
|
|
) |
|
|
|
|
|
refresh_btn.click( |
|
|
fn=refresh_templates, |
|
|
outputs=[motion_template] |
|
|
) |
|
|
|
|
|
generate_btn.click( |
|
|
fn=generate_video_gradio, |
|
|
inputs=[input_image, motion_template, mode], |
|
|
outputs=[output_video, status_text] |
|
|
) |
|
|
|
|
|
|
|
|
example_files = [ |
|
|
["./assets/test_image/sugar.jpg", "sports_basketball_gym", "animate"], |
|
|
["./assets/test_image/avatar.jpg", "dance_indoor_1", "animate"], |
|
|
["./assets/test_image/cartoon1.png", "shorts_kungfu_desert1", "edit"], |
|
|
["./assets/test_image/actorhq_A7S1.png", "syn_basketball_06_13", "edit"], |
|
|
] |
|
|
|
|
|
|
|
|
valid_examples = [ex for ex in example_files if os.path.exists(ex[0])] |
|
|
|
|
|
if valid_examples: |
|
|
gr.Examples( |
|
|
examples=valid_examples, |
|
|
inputs=[input_image, motion_template, mode], |
|
|
label="π― Examples" |
|
|
) |
|
|
else: |
|
|
print("β οΈ No example images found, skipping examples section") |
|
|
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
if os.getenv("SPACE_ID"): |
|
|
print("π Running on HuggingFace Spaces") |
|
|
print("π¦ Models will download on first use to prevent build timeout") |
|
|
else: |
|
|
print("π» Running locally") |
|
|
|
|
|
|
|
|
demo = gradio_interface() |
|
|
demo.launch( |
|
|
server_name="0.0.0.0", |
|
|
server_port=7860, |
|
|
share=False, |
|
|
show_error=True |
|
|
) |