Spaces:
Running
on
Zero
Running
on
Zero
wangshuai6
committed on
Commit
·
d7edbd1
1
Parent(s):
52d009c
app demo
Browse files
app.py
CHANGED
|
@@ -63,14 +63,14 @@ def load_model(weight_dict, denosier):
|
|
| 63 |
|
| 64 |
|
| 65 |
class Pipeline:
|
| 66 |
-
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
|
| 67 |
self.vae = vae
|
| 68 |
self.denoiser = denoiser
|
| 69 |
self.conditioner = conditioner
|
| 70 |
self.diffusion_sampler = diffusion_sampler
|
| 71 |
self.resolution = resolution
|
|
|
|
| 72 |
|
| 73 |
-
@spaces.GPU
|
| 74 |
@torch.no_grad()
|
| 75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
|
@@ -83,7 +83,7 @@ class Pipeline:
|
|
| 83 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 84 |
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
| 85 |
with torch.no_grad():
|
| 86 |
-
condition, uncondition = conditioner([y,]*num_images)
|
| 87 |
# Sample images:
|
| 88 |
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
| 89 |
samples = vae.decode(samples)
|
|
@@ -136,7 +136,15 @@ if __name__ == "__main__":
|
|
| 136 |
vae = vae.cuda()
|
| 137 |
denoiser.eval()
|
| 138 |
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
with gr.Blocks() as demo:
|
| 142 |
gr.Markdown("DDT")
|
|
@@ -144,12 +152,14 @@ if __name__ == "__main__":
|
|
| 144 |
with gr.Column(scale=1):
|
| 145 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
| 146 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
| 147 |
-
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=
|
| 148 |
-
label = gr.
|
| 149 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
| 150 |
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
| 151 |
-
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
| 154 |
with gr.Column(scale=2):
|
| 155 |
btn = gr.Button("Generate")
|
|
@@ -167,4 +177,4 @@ if __name__ == "__main__":
|
|
| 167 |
guidance_interval_max,
|
| 168 |
timeshift
|
| 169 |
], outputs=[output])
|
| 170 |
-
demo.launch(
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
class Pipeline:
|
| 66 |
+
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
|
| 67 |
self.vae = vae
|
| 68 |
self.denoiser = denoiser
|
| 69 |
self.conditioner = conditioner
|
| 70 |
self.diffusion_sampler = diffusion_sampler
|
| 71 |
self.resolution = resolution
|
| 72 |
+
self.classlabels2ids = classlabels2ids
|
| 73 |
|
|
|
|
| 74 |
@torch.no_grad()
|
| 75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
|
|
|
| 83 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 84 |
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
| 85 |
with torch.no_grad():
|
| 86 |
+
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
| 87 |
# Sample images:
|
| 88 |
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
| 89 |
samples = vae.decode(samples)
|
|
|
|
| 136 |
vae = vae.cuda()
|
| 137 |
denoiser.eval()
|
| 138 |
|
| 139 |
+
# read imagenet classlabels
|
| 140 |
+
with open("imagenet_classlabels.txt", "r") as f:
|
| 141 |
+
classlabels = f.readlines()
|
| 142 |
+
classlabels = [label.strip() for label in classlabels]
|
| 143 |
+
|
| 144 |
+
classlabels2id = {label: i for i, label in enumerate(classlabels)}
|
| 145 |
+
id2classlabels = {i: label for i, label in enumerate(classlabels)}
|
| 146 |
+
|
| 147 |
+
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
|
| 148 |
|
| 149 |
with gr.Blocks() as demo:
|
| 150 |
gr.Markdown("DDT")
|
|
|
|
| 152 |
with gr.Column(scale=1):
|
| 153 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
| 154 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
| 155 |
+
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4)
|
| 156 |
+
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
| 157 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
| 158 |
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
| 159 |
+
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
|
| 160 |
+
value=0.0)
|
| 161 |
+
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max",
|
| 162 |
+
value=1.0)
|
| 163 |
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
| 164 |
with gr.Column(scale=2):
|
| 165 |
btn = gr.Button("Generate")
|
|
|
|
| 177 |
guidance_interval_max,
|
| 178 |
timeshift
|
| 179 |
], outputs=[output])
|
| 180 |
+
demo.launch()
|
imagenet_classlabels.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/diffusion/stateful_flow_matching/sharing_sampling.py
CHANGED
|
@@ -109,7 +109,7 @@ class EulerSampler(BaseSampler):
|
|
| 109 |
timesteps.reverse()
|
| 110 |
|
| 111 |
print("recompute timesteps solved by DP: ", timesteps)
|
| 112 |
-
return timesteps[:-1]
|
| 113 |
|
| 114 |
def _impl_sampling(self, net, noise, condition, uncondition):
|
| 115 |
"""
|
|
|
|
| 109 |
timesteps.reverse()
|
| 110 |
|
| 111 |
print("recompute timesteps solved by DP: ", timesteps)
|
| 112 |
+
return timesteps[:-1][:self.num_recompute_timesteps]
|
| 113 |
|
| 114 |
def _impl_sampling(self, net, noise, condition, uncondition):
|
| 115 |
"""
|