# HuMo_local / app.py
import spaces
import gradio as gr
import sys
import os
import subprocess
import uuid
import shutil
from tqdm import tqdm
from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
import importlib, site
# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
site.addsitedir(sitedir)
# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()
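# Small helper: run a shell command and raise if it exits with a non-zero status.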
def sh(cmd): subprocess.check_call(cmd, shell=True)
flash_attention_installed = False
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    sh(f"pip install {flash_attention_wheel}")
    # Make the freshly installed package importable in this process.
    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
flash_attention_installed = True
except Exception as e:
print(f"⚠️ Could not install FlashAttention: {e}")
print("Continuing without FlashAttention...")
try:
    print("Attempting to download and install Transformer Engine wheel...")
    te_wheel = hf_hub_download(
        repo_id="alexnasa/transformer_engine_wheels",
        repo_type="model",
        filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
    )
    sh(f"pip install {te_wheel}")
    # Refresh import machinery so the newly installed package is visible without a restart.
    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
except Exception as e:
print(f"⚠️ Could not install Transformer Engine : {e}")
print("Continuing without Transformer Engine ...")
import torch
print(f"Torch version: {torch.__version__}")
print(f"FlashAttention available: {flash_attention_installed}")
import tempfile
from pathlib import Path
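# Private Inductor helper used to locate the torch.compile cache directory; this import path may change across torch versions.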
from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
from huggingface_hub import HfApi
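# Download the model weights needed at startup: HuMo checkpoints, the Wan2.1-T2V-1.3B base model, and whisper-large-v3.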
snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo")
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B")
snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3")
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
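# Make the bundled "humo" package importable.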
path_to_insert = "humo"
if path_to_insert not in sys.path:
sys.path.insert(0, path_to_insert)
from common.config import load_config, create_object
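# Load the inference config and build the runner, overriding sequence-parallel size, frame count,
# guidance scales (scale_t / scale_a), conditioning mode, and output resolution.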
config = load_config(
"./humo/configs/inference/generate.yaml",
[
"dit.sp_size=1",
"generation.frames=97",
"generation.scale_t=5.5",
"generation.scale_a=5.0",
"generation.mode=TIA",
"generation.height=480",
"generation.width=832",
],
)
runner = create_object(config)
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space") # or another writable path
def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
path_in_repo: str = "inductor_cache", repo_type: str = "model",
hf_token: str | None = None):
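    """Download a zipped TorchInductor cache from the Hub and unpack it into the local torch.compile cache directory."""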
cache_root = Path(_inductor_cache_dir()).resolve()
cache_root.mkdir(parents=True, exist_ok=True)
zip_path = hf_hub_download(repo_id=repo_id, filename=f"{path_in_repo}/{filename}",
repo_type=repo_type, token=hf_token)
shutil.unpack_archive(zip_path, extract_dir=str(cache_root))
print(f"✓ Restored cache into {cache_root}")
# restore_inductor_cache_from_hub("alexnasa/humo-compiled")
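# Duration estimator for the @spaces.GPU decorator on run_pipeline; it receives the same arguments as the wrapped function.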
def get_duration(prompt_text, steps, image_paths, audio_file_path, max_duration, session_id, progress):
return calculate_required_time(steps, max_duration)
def calculate_required_time(steps, max_duration):
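    # Rough estimate: a fixed warm-up cost plus a per-step cost that grows with the requested frame count.
    # e.g. 95 frames at 10 steps -> 10 * 21 s + 50 s warm-up = 260 s.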
warmup_s = 50
max_duration_duration_mapping = {
20: 3,
45: 7,
70: 13,
95: 21,
}
# Humo 1.7
# max_duration_duration_mapping = {
# 20: 2,
# 45: 2,
# 70: 5,
# 95: 6,
# }
each_step_s = max_duration_duration_mapping[max_duration]
duration_s = (each_step_s * steps) + warmup_s
    print(f'estimated duration: {duration_s}s')
return int(duration_s)
def get_required_time_string(steps, max_duration):
duration_s = calculate_required_time(steps, max_duration)
duration_m = duration_s / 60
return f"<center>⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)</center>"
def update_required_time(steps, max_duration):
return get_required_time_string(steps, max_duration)
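# Wrapper called by the UI: validates inputs before dispatching the GPU-decorated pipeline.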
def generate_scene(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None, progress=gr.Progress(),):
prompt_text_check = (prompt_text or "").strip()
if not prompt_text_check:
raise gr.Error("Please enter a prompt.")
if not audio_file_path and not image_paths:
raise gr.Error("Please provide a reference image or a lipsync audio.")
return run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration, session_id, progress)
def upload_inductor_cache_to_hub(
repo_id: str,
path_in_repo: str = "inductor_cache",
repo_type: str = "model", # or "dataset" if you prefer
hf_token: str | None = None,
):
"""
Zips the current TorchInductor cache and uploads it to the given repo path.
Assumes the model was already run once with torch.compile() so the cache exists.
"""
cache_dir = Path(_inductor_cache_dir()).resolve()
if not cache_dir.exists():
raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
"Run a compiled model once to populate it.")
# Create a zip archive of the entire cache directory
with tempfile.TemporaryDirectory() as tmpdir:
archive_base = Path(tmpdir) / "torch_compile_cache"
archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir))
archive_path = Path(archive_path)
# Upload to Hub
api = HfApi(token=hf_token)
api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
# Put each artifact under path_in_repo, including a tiny metadata stamp for traceability
# Upload the zip
dest_path = f"{path_in_repo}/{archive_path.name}"
api.upload_file(
path_or_fileobj=str(archive_path),
path_in_repo=dest_path,
repo_id=repo_id,
repo_type=repo_type,
)
# Upload a small metadata file (optional but handy)
meta_txt = (
f"pytorch={torch.__version__}\n"
f"inductor_cache_dir={cache_dir}\n"
f"cuda_available={torch.cuda.is_available()}\n"
f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
)
api.upload_file(
path_or_fileobj=meta_txt.encode(),
path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
repo_id=repo_id,
repo_type=repo_type,
)
print("✔ Uploaded TorchInductor cache to the Hub.")
@spaces.GPU(duration=get_duration)
def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None, progress=gr.Progress(),):
if session_id is None:
session_id = uuid.uuid4().hex
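    # Default to TIA (text + image + audio); downgraded to TI or TA below when audio or reference images are missing.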
inference_mode = "TIA"
# Validate inputs
prompt_text = (prompt_text or "").strip()
if not prompt_text:
raise gr.Error("Please enter a prompt.")
if not audio_file_path and not image_paths:
raise gr.Error("Please provide a reference image or a lipsync audio.")
if not audio_file_path:
inference_mode = "TI"
audio_path = None
tmp_audio_path = None
else:
audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path))
if not image_paths:
inference_mode = "TA"
img_paths = None
else:
img_paths = [image_data[0] for image_data in image_paths]
    print(f'{session_id} using inference_mode={inference_mode}, steps={steps}, frames={max_duration}')
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
os.makedirs(output_dir, exist_ok=True)
if audio_path:
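        # Prepend a short block of silence to the audio (presumably to give the model a lead-in before speech starts).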
def add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path, silence_duration_s=0.5):
command = [
'ffmpeg',
'-i', audio_path,
'-f', 'lavfi',
'-t', str(silence_duration_s),
'-i', 'anullsrc=r=16000:cl=stereo',
'-filter_complex', '[1][0]concat=n=2:v=0:a=1[out]',
'-map', '[out]',
'-y', tmp_audio_path,
'-loglevel', 'quiet'
]
subprocess.run(command, check=True)
tmp_audio_path = os.path.join(output_dir, "tmp_audio.wav")
add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path)
# Random filename
filename = f"gen_{uuid.uuid4().hex[:10]}"
width, height = 832, 480
runner.inference_loop(
prompt_text,
img_paths,
tmp_audio_path,
output_dir,
filename,
inference_mode,
width,
height,
steps,
frames = int(max_duration),
tea_cache_l1_thresh = 0.0,
progress_bar_cmd=progress
)
# Return resulting video path
video_path = os.path.join(output_dir, f"{filename}.mp4")
if os.path.exists(video_path):
# upload_inductor_cache_to_hub("alexnasa/humo-compiled")
return video_path
else:
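        # Fallback: the runner may have named its output differently; return the newest .mp4 in the session directory.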
candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")]
if candidates:
return max(candidates, key=lambda p: os.path.getmtime(p))
return None
css = """
#col-container {
margin: 0 auto;
width: 100%;
max-width: 720px;
}
"""
def cleanup(request: gr.Request):
sid = request.session_hash
if sid:
d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
shutil.rmtree(d1, ignore_errors=True)
def start_session(request: gr.Request):
return request.session_hash
with gr.Blocks(css=css) as demo:
session_state = gr.State()
demo.load(start_session, outputs=[session_state])
with gr.Sidebar(width=400):
gr.HTML(
"""
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>HuMo</strong> – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
</p>
<a href="https://github.com/Phantom-video/HuMo" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
[Github]
</a>
</div>
"""
)
gr.Markdown("**REFERENCE IMAGES**")
img_input = gr.Gallery(
value=["./examples/ali.png"],
show_label=False,
label="",
interactive=True,
rows=1, columns=3, object_fit="contain", height="280",
file_types=['image']
)
gr.Markdown("**LIPSYNC AUDIO**")
audio_input = gr.Audio(
value="./examples/life.wav",
sources=["upload"],
show_label=False,
type="filepath",
)
gr.Markdown("**SETTINGS**")
default_steps = 10
default_max_duration = 45
max_duration = gr.Slider(minimum=45, maximum=95, value=default_max_duration, step=25, label="Frames")
steps_input = gr.Slider(minimum=10, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
with gr.Column(elem_id="col-container"):
gr.HTML(
"""
<div style="text-align: center;">
<strong>HF Space by:</strong>
<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="Follow on Twitter">
</a>
</div>
"""
)
video_output = gr.Video(show_label=False)
gr.Markdown("<center><h2>PROMPT</h2></center>")
prompt_tb = gr.Textbox(
value="A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
show_label=False,
lines=5,
placeholder="Describe the scene and the person talking....",
)
gr.Markdown("")
time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
run_btn = gr.Button("🎬 Action", variant="primary")
gr.Examples(
examples=[
[
"A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. She is explaining something to someone standing opposite her",
10,
["./examples/naomi.png"],
"./examples/science.wav",
70,
],
[
"A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
10,
["./examples/ella.png"],
"./examples/dream.mp3",
45,
],
[
"A reddish-brown haired woman sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and her thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
10,
["./examples/art.png"],
"./examples/art.wav",
70,
],
[
"A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead as she grips a blazing torch tightly in her hand. She speaks with intensity.",
10,
["./examples/ella.png"],
"./examples/dream.mp3",
95,
],
[
"A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
40,
["./examples/amber.png", "./examples/jacket.png"],
"./examples/fictional.wav",
70,
],
],
inputs=[prompt_tb, steps_input, img_input, audio_input, max_duration],
outputs=[video_output],
fn=run_pipeline,
cache_examples=True,
)
max_duration.change(update_required_time, [steps_input, max_duration], time_required)
steps_input.change(update_required_time, [steps_input, max_duration], time_required)
run_btn.click(
fn=generate_scene,
inputs=[prompt_tb, steps_input, img_input, audio_input, max_duration, session_state],
outputs=[video_output],
)
if __name__ == "__main__":
demo.unload(cleanup)
demo.queue()
demo.launch(ssr_mode=False)