File size: 4,221 Bytes
dc155d4
 
 
 
 
 
 
 
 
 
 
 
879ee4e
dc155d4
 
 
 
b09f6a2
dc155d4
 
 
 
988720a
 
 
 
 
 
 
dc155d4
 
 
 
 
 
 
 
 
 
 
 
 
d82d0e1
 
02c2928
 
 
 
 
 
 
 
 
 
64d1e7e
 
02c2928
 
 
 
 
 
64d1e7e
 
 
 
 
 
02c2928
 
 
64d1e7e
 
 
02c2928
 
 
 
 
 
 
64d1e7e
02c2928
 
 
 
 
 
 
 
 
 
 
82d7cc1
dc155d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988720a
 
dc155d4
988720a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel
from optimization_utils import drain_module_parameters

P = ParamSpec('P')


TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Fuse LoRA adapters into ``pipeline`` and install compiled transformers, in place.

    Steps, in order:

    1. Quantize ``pipeline.text_encoder`` with ``Int8WeightOnlyConfig``.
    2. Inside a ``spaces.GPU`` task: load and fuse two LoRA adapters, record
       the ``transformer`` call made by ``pipeline(*args, **kwargs)``, then
       quantize, export and compile both transformers.
    3. Replace ``forward`` on ``pipeline.transformer`` and
       ``pipeline.transformer_2`` with dispatchers that select a compiled
       model by comparing the last two dimensions of ``hidden_states``, and
       call ``drain_module_parameters`` on each module.

    Args:
        pipeline: The pipeline object to modify. It is called once with
            ``*args, **kwargs`` so that a transformer call can be captured.
        *args: Example positional arguments forwarded to ``pipeline``.
        **kwargs: Example keyword arguments forwarded to ``pipeline``.

    Returns:
        None. ``pipeline`` is mutated in place.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Fuse the LoRAs, then quantize, export and compile both transformers.

        Returns:
            A 4-tuple ``(landscape_1, landscape_2, portrait_1, portrait_2)``
            where the suffix is the transformer index and the prefix is the
            input orientation the model was exported for.
        """
        # Adapter fused into the default transformer.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
            adapter_name="lightning",
        )

        # Adapter fused into transformer_2.
        pipeline.load_lora_weights(
            "deadman44/Wan2.2_T2i_T2v_LoRA",
            weight_name="lora_wan2.2_myjd_Low_v01.safetensors",
            adapter_name="lightning_2",
            load_into_transformer_2=True,
        )

        # Activate only the adapters that were actually loaded, with exactly
        # one weight per adapter name.
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])

        # "lightning" is fused at 3x scale into the first transformer,
        # "lightning_2" at 1x into the second.
        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])

        # The fused weights stay applied; the adapter copies are no longer needed.
        pipeline.unload_lora_weights()

        # Run the pipeline once to record the arguments of a transformer call.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # NOTE(review): everything from here to the `return` was missing from
        # this function (it ended in a "rest unchanged" placeholder). It is
        # reconstructed from the module's otherwise-unused imports and
        # constants; it assumes `call` exposes `.args` / `.kwargs`, that
        # `aoti_compile(exported_program, configs)` returns an object with
        # `.archive_file` and `.weights`, and that
        # `ZeroGPUCompiledModel(archive_file, weights)` is the constructor —
        # confirm against optimization_utils.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Build one example input per orientation by swapping the last two
        # dimensions of the captured tensor.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states

        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # Only two graphs are compiled; the other two models reuse a compiled
        # archive paired with the other transformer's weights.
        compiled_landscape_2 = ZeroGPUCompiledModel(
            compiled_landscape_1.archive_file, compiled_portrait_2.weights,
        )
        compiled_portrait_1 = ZeroGPUCompiledModel(
            compiled_portrait_2.archive_file, compiled_landscape_1.weights,
        )

        return (
            compiled_landscape_1,
            compiled_landscape_2,
            compiled_portrait_1,
            compiled_portrait_2,
        )

    def make_orientation_dispatcher(landscape_model, portrait_model):
        """Return a ``forward`` replacement that routes by input orientation."""

        def combined_transformer(*args, **kwargs):
            hidden_states: torch.Tensor = kwargs['hidden_states']
            # Last dim larger than the second-to-last selects the landscape model.
            if hidden_states.shape[-1] > hidden_states.shape[-2]:
                return landscape_model(*args, **kwargs)
            return portrait_model(*args, **kwargs)

        return combined_transformer

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    cl1, cl2, cp1, cp2 = compile_transformer()

    pipeline.transformer.forward = make_orientation_dispatcher(cl1, cp1)
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = make_orientation_dispatcher(cl2, cp2)
    drain_module_parameters(pipeline.transformer_2)