Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -2,6 +2,8 @@
 import os
 import sys
 
+print("Starting NAG Video Demo application...")
+
 # Add current directory to Python path
 try:
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -46,6 +48,7 @@ class NagWanTransformer3DModel(nn.Module):
         self.out_channels = out_channels
         self.hidden_size = hidden_size
         self.training = False
+        self._dtype = torch.float32  # Add dtype attribute
 
         # Dummy config for compatibility
         self.config = type('Config', (), {
@@ -67,6 +70,27 @@ class NagWanTransformer3DModel(nn.Module):
             nn.SiLU(),
             nn.Linear(hidden_size, hidden_size),
         )
+
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
 
     @staticmethod
     def attn_processors():
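Why the model grows a `dtype` property: diffusers-style pipelines read `model.dtype` when casting latents and inputs, but a plain `nn.Module` exposes no such attribute, and a hand-tracked `_dtype` goes stale the moment `.to(torch.float16)` converts the parameters. A minimal sketch of that pitfall, which the `to()` override above closes (the `Tiny` class is hypothetical, not from app.py):

import torch
import torch.nn as nn

class Tiny(nn.Module):  # hypothetical stand-in for the demo transformer
    def __init__(self):
        super().__init__()
        self._dtype = torch.float32
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self):
        return self._dtype  # manual tracking, as in the diff

m = Tiny().to(torch.float16)  # parameters become float16...
print(m.proj.weight.dtype)    # torch.float16
print(m.dtype)                # torch.float32: stale without a to() override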
@@ -423,8 +447,8 @@ from mmaudio.model.utils.features_utils import FeaturesUtils
 
 # Constants
 MOD_VALUE = 32
-DEFAULT_DURATION_SECONDS =
-DEFAULT_STEPS =
+DEFAULT_DURATION_SECONDS = 1
+DEFAULT_STEPS = 1
 DEFAULT_SEED = 2025
 DEFAULT_H_SLIDER_VALUE = 128
 DEFAULT_W_SLIDER_VALUE = 128
@@ -434,9 +458,9 @@ SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 256
 MAX_SEED = np.iinfo(np.int32).max
 
-FIXED_FPS =
+FIXED_FPS = 8  # Reduced FPS for demo
 MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL =
+MAX_FRAMES_MODEL = 32  # Reduced max frames for demo
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
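Together with the 2-second duration cap added further down, these constants bound any request to at most 2 × 8 = 16 frames, comfortably inside the 32-frame model limit. A sketch of the clamping arithmetic, assuming the demo derives frame count as duration × fps (the helper name is hypothetical):

FIXED_FPS = 8
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 32

def frames_for_duration(duration_seconds: float) -> int:
    """Clamp duration * fps into the model's supported frame range."""
    wanted = round(duration_seconds * FIXED_FPS)
    return int(max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, wanted)))

assert frames_for_duration(1) == 8   # slider minimum
assert frames_for_duration(2) == 16  # slider maximum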
@@ -454,6 +478,7 @@ print("Creating demo models...")
 class DemoVAE(nn.Module):
     def __init__(self):
         super().__init__()
+        self._dtype = torch.float32  # Add dtype attribute
         self.encoder = nn.Sequential(
             nn.Conv2d(3, 64, 3, padding=1),
             nn.ReLU(),
@@ -470,6 +495,27 @@ class DemoVAE(nn.Module):
             'latent_channels': 4,
         })()
 
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
+
     def encode(self, x):
         # Simple encoding
         encoded = self.encoder(x)
@@ -519,18 +565,19 @@ pipe = NAGWanPipeline(
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(f"Using device: {device}")
 
-# Move models to device
-vae = vae.to(device)
-transformer = transformer.to(device)
+# Move models to device with explicit dtype
+vae = vae.to(device).to(torch.float32)
+transformer = transformer.to(device).to(torch.float32)
 
-
-
-
-
-
-print("Warning:
-
-# Skip LoRA for demo version
+# Now move pipeline to device (it will handle the components)
+try:
+    pipe = pipe.to(device)
+    print(f"Pipeline moved to {device}")
+except Exception as e:
+    print(f"Warning: Could not move pipeline to {device}: {e}")
+    # Manually set device
+    pipe._execution_device = device
 
 print("Demo version ready!")
 
 # Check if transformer has the required methods
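The try/except around `pipe.to(device)` exists because moving a pipeline whose components are plain `nn.Module`s can raise inside diffusers' component handling. A minimal sketch of the same fallback pattern in isolation (`move_pipeline` is a hypothetical name; the attribute assignment mirrors what the diff does):

import torch

def move_pipeline(pipe, device: str):
    """Best-effort device move; falls back to recording the device manually."""
    try:
        return pipe.to(device)
    except Exception as e:  # broad catch mirrors the demo's defensive style
        print(f"Warning: could not move pipeline to {device}: {e}")
        pipe._execution_device = device  # attribute the demo sets by hand
        return pipe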
@@ -748,6 +795,8 @@ def generate_video(
     if hasattr(pipe, 'vae'):
         pipe.vae = pipe.vae.to(device).to(torch.float32)
 
+    print(f"Generating video: {target_w}x{target_h}, {num_frames} frames, seed {current_seed}")
+
     with torch.inference_mode():
        nag_output_frames_list = pipe(
            prompt=prompt,
@@ -785,13 +834,21 @@ def generate_video(
 
     except Exception as e:
         print(f"Error generating video: {e}")
+        import traceback
+        traceback.print_exc()
+
         # Return a simple error video
-        error_frames =
-
+        error_frames = []
+        for i in range(8):  # Create 8 frames
+            frame = np.zeros((128, 128, 3), dtype=np.uint8)
+            frame[:, :] = [255, 0, 0]  # Red frame
+            # Add error text
+            error_frames.append(frame)
+
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
             error_video_path = tmpfile.name
-        export_to_video(
-        return error_video_path, None,
+        export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
+        return error_video_path, None, 0
 
 def update_audio_visibility(audio_mode):
     return gr.update(visible=(audio_mode == "Enable Audio"))
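Note that the `# Add error text` comment in the fallback loop never draws any text; the frames stay solid red. A hedged sketch of one way to finish that thought with PIL (the helper is hypothetical, not part of this commit):

import numpy as np
from PIL import Image, ImageDraw

def make_error_frame(msg: str = "generation failed", size: int = 128) -> np.ndarray:
    """Red 128x128 frame with a short error message rendered on it."""
    img = Image.new("RGB", (size, size), (255, 0, 0))
    ImageDraw.Draw(img).text((8, size // 2), msg, fill=(255, 255, 255))
    return np.asarray(img)

error_frames = [make_error_frame() for _ in range(8)]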
@@ -800,8 +857,8 @@ def update_audio_visibility(audio_mode):
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     with gr.Column(elem_classes="container"):
         gr.HTML("""
-            <h1 class="main-title">🎬 NAG Video Demo
-            <p class="subtitle">
+            <h1 class="main-title">🎬 NAG Video Demo</h1>
+            <p class="subtitle">Simple Text-to-Video with NAG + Audio Generation</p>
         """)
 
         gr.HTML("""
@@ -818,8 +875,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             with gr.Group(elem_classes="prompt-container"):
                 prompt = gr.Textbox(
                     label="✨ Video Prompt",
-
-
+                    value=default_prompt,
+                    placeholder="Describe your video scene...",
+                    lines=2,
                     elem_classes="prompt-input"
                 )
 
@@ -831,11 +889,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 nag_scale = gr.Slider(
                     label="NAG Scale",
-                    minimum=
+                    minimum=0.0,
                     maximum=20.0,
                     step=0.25,
-                    value=
-                    info="Higher values = stronger guidance"
+                    value=5.0,
+                    info="Higher values = stronger guidance (0 = no NAG)"
                 )
 
                 audio_mode = gr.Radio(
@@ -866,9 +924,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 audio_steps = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=25,
                     step=1,
-                    value=
+                    value=10,
                     label="🚀 Audio Steps"
                 )
                 audio_cfg_strength = gr.Slider(
@@ -885,7 +943,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             with gr.Row():
                 duration_seconds_input = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=2,
                     step=1,
                     value=DEFAULT_DURATION_SECONDS,
                     label="📱 Duration (seconds)",
@@ -893,7 +951,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 steps_slider = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=2,
                     step=1,
                     value=DEFAULT_STEPS,
                     label="🔄 Inference Steps",
@@ -964,18 +1022,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
         gr.Markdown("### 🎯 Example Prompts")
         gr.Examples(
             examples=[
-                ["A
-                 128, 128,
-
-                 "Enable Audio", "
-                ["A red
-                 128, 128,
-
-                 "Enable Audio", "car engine
-                ["
-                 128, 128,
-
-                 "Video Only", "", default_audio_negative_prompt, -1,
+                ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                 128, 128, 1,
+                 1, DEFAULT_SEED, False,
+                 "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5],
+                ["A red car driving on a cliff road", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                 128, 128, 1,
+                 1, DEFAULT_SEED, False,
+                 "Enable Audio", "car engine, wind", default_audio_negative_prompt, -1, 10, 4.5],
+                ["Glowing jellyfish floating in the sky", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                 128, 128, 1,
+                 1, DEFAULT_SEED, False,
+                 "Video Only", "", default_audio_negative_prompt, -1, 10, 4.5],
             ],
             fn=generate_video,
             inputs=[prompt, nag_negative_prompt, nag_scale,
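A reminder on the `gr.Examples` contract these rows rely on: each row supplies one value per component in `inputs`, in positional order, so the 15 values per row must line up with the (truncated) `inputs` list. A self-contained sketch of that contract (components hypothetical):

import gradio as gr

with gr.Blocks() as sketch:
    prompt = gr.Textbox(label="prompt")
    nag_scale = gr.Slider(0.0, 20.0, label="NAG Scale")
    gr.Examples(
        examples=[["A cat playing guitar on stage", 5]],  # one value per input
        inputs=[prompt, nag_scale],
    )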