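# wan2-2-T2V-EXP / optimization.py
# Author: rahul7star. Last commit: 64d1e7e "Update optimization.py" (verified); 4.22 kB.
"""
Optimization helpers for the Wan 2.2 text-to-video pipeline on ZeroGPU.

``optimize_pipeline_`` loads and fuses LoRA adapters, quantizes the text encoder
(int8 weight-only), and replaces the ``forward`` of ``pipeline.transformer`` and
``pipeline.transformer_2`` with wrappers that dispatch to landscape- or
portrait-specific compiled models.
"""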
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel
from optimization_utils import drain_module_parameters
P = ParamSpec('P')
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
2: TRANSFORMER_NUM_FRAMES_DIM,
},
}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def _orientation_dispatch(landscape: Callable[..., Any], portrait: Callable[..., Any]) -> Callable[..., Any]:
    """Return a forward that routes each call to ``landscape`` or ``portrait``.

    The returned callable reads ``kwargs['hidden_states']`` (so it must be passed
    by keyword) and calls ``landscape`` when its last dimension is larger than its
    second-to-last one, otherwise ``portrait``. All arguments are forwarded as-is.
    """

    def forward(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return landscape(*args, **kwargs)
        return portrait(*args, **kwargs)

    return forward


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Fuse LoRAs into, quantize, and AOT-compile ``pipeline``'s transformers in place.

    ``args``/``kwargs`` are forwarded to a single warm-up ``pipeline(*args, **kwargs)``
    call. During that call the inputs of ``pipeline.transformer`` are captured and
    then used as the example inputs for ``torch.export``.

    Side effects on ``pipeline``:
      * the "lightning" LoRA is fused into ``transformer`` (scale 3.0) and
        "lightning_2" into ``transformer_2`` (scale 1.0);
      * ``text_encoder`` gets int8 weight-only quantization and both transformers
        get float8 dynamic-activation/weight quantization;
      * ``transformer.forward`` / ``transformer_2.forward`` are replaced by wrappers
        that pick a landscape- or portrait-compiled model per call, and the
        original module parameters are drained.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # --- LoRA 1: "lightning" -> pipeline.transformer ---
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
            adapter_name="lightning",
        )
        # --- LoRA 2: "lightning_2" -> pipeline.transformer_2 ---
        kwargs_lora = {"load_into_transformer_2": True}
        pipeline.load_lora_weights(
            "deadman44/Wan2.2_T2i_T2v_LoRA",
            weight_name="lora_wan2.2_myjd_Low_v01.safetensors",
            adapter_name="lightning_2",
            **kwargs_lora,
        )
        # To add another adapter (e.g. "ostris/wan22_i2v_14b_orbit_shot_lora"), load it
        # here first, then add it to set_adapters below with its own weight, and fuse it
        # into the matching component. set_adapters rejects names that were never loaded
        # and needs exactly one weight per adapter name.
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
        # The fused deltas now live in the module weights; drop the separate LoRA tensors.
        pipeline.unload_lora_weights()

        # One warm-up run to record the exact args/kwargs pipeline.transformer receives.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Every tensor/bool input is exported as static, except the axes listed in
        # TRANSFORMER_DYNAMIC_SHAPES (axis 2 of hidden_states).
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Swapping the last two axes of the captured hidden_states gives an example
        # input of the other orientation; export one orientation per transformer.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape, hidden_states_portrait = hidden_states, hidden_states_transposed
        else:
            hidden_states_landscape, hidden_states_portrait = hidden_states_transposed, hidden_states

        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # Pair each compiled graph with the other transformer's weights, so both
        # transformers get both orientations from only two compilations. This relies on
        # transformer and transformer_2 sharing one architecture -- TODO confirm.
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)

        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    cl1, cl2, cp1, cp2 = compile_transformer()

    pipeline.transformer.forward = _orientation_dispatch(cl1, cp1)
    drain_module_parameters(pipeline.transformer)
    pipeline.transformer_2.forward = _orientation_dispatch(cl2, cp2)
    drain_module_parameters(pipeline.transformer_2)