Spaces:
Running
on
Zero
Running
on
Zero
update(*): HF space support.
Browse files- app.py +2 -0
- pipeline/i2v_pipeline.py +2 -0
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from einops import rearrange
|
|
| 13 |
from datetime import datetime
|
| 14 |
from typing import Optional, List, Dict
|
| 15 |
from huggingface_hub import snapshot_download
|
|
|
|
| 16 |
|
| 17 |
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
|
| 18 |
|
|
@@ -526,6 +527,7 @@ def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt,
|
|
| 526 |
|
| 527 |
return errors
|
| 528 |
|
|
|
|
| 529 |
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
|
| 530 |
# Validate inputs first
|
| 531 |
validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
|
|
|
|
| 13 |
from datetime import datetime
|
| 14 |
from typing import Optional, List, Dict
|
| 15 |
from huggingface_hub import snapshot_download
|
| 16 |
+
import spaces
|
| 17 |
|
| 18 |
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
|
| 19 |
|
|
|
|
| 527 |
|
| 528 |
return errors
|
| 529 |
|
| 530 |
+
@spaces.GPU
|
| 531 |
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
|
| 532 |
# Validate inputs first
|
| 533 |
validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
|
pipeline/i2v_pipeline.py
CHANGED
|
@@ -160,6 +160,8 @@ class WanVideoPipeline(BasePipeline):
|
|
| 160 |
state_dict, config = state_dict
|
| 161 |
config.update(config_dict or {})
|
| 162 |
model = model_cls(**config)
|
|
|
|
|
|
|
| 163 |
if "use_local_lora" in config_dict or "use_dera" in config_dict:
|
| 164 |
strict = False
|
| 165 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|
|
|
|
| 160 |
state_dict, config = state_dict
|
| 161 |
config.update(config_dict or {})
|
| 162 |
model = model_cls(**config)
|
| 163 |
+
if torch.cuda.is_available():
|
| 164 |
+
model = model.to("cuda")
|
| 165 |
if "use_local_lora" in config_dict or "use_dera" in config_dict:
|
| 166 |
strict = False
|
| 167 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|