Commit e0cb68e · Contrebande Labs
Parent: 06f2eaf
sync with working jax inference code from main repo
app.py CHANGED
@@ -16,6 +16,7 @@ from diffusers import (
 
 from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
 
+
 def get_inference_lambda(seed):
 
     tokenizer = ByT5Tokenizer()
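
For context on these imports: a minimal sketch of how the encoder hidden states used later in this file could be produced with the ByT5 classes (the google/byt5-base checkpoint and max_length=1024 are illustrative assumptions, not values taken from this diff):

from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = ByT5Tokenizer()  # byte-level tokenizer, needs no vocab file
text_encoder = FlaxT5ForConditionalGeneration.from_pretrained("google/byt5-base")  # assumed checkpoint

input_ids = tokenizer(
    "a postage stamp from california",
    padding="max_length",
    max_length=1024,  # assumed padding length
    truncation=True,
    return_tensors="jax",
).input_ids
hidden_states = text_encoder.encode(input_ids=input_ids).last_hidden_state
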
@@ -51,7 +52,7 @@ def get_inference_lambda(seed):
             "trained_betas": None,
         }
     )
-    timesteps =
+    timesteps = 20
     guidance_scale = jnp.array([7.5], dtype=jnp.float32)
 
     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
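
guidance_scale drives classifier-free guidance: __predict_image (below) runs the UNet on a batch that stacks the negative-prompt and prompt encodings, then blends the two predictions. A minimal sketch of that blend, with a hypothetical helper name:

import jax.numpy as jnp

def guide(unet_prediction, guidance_scale):
    # the batch stacks [unconditional, text-conditioned] predictions,
    # matching the context concatenation built in __predict_image
    uncond, text = jnp.split(unet_prediction, 2, axis=0)
    return uncond + guidance_scale * (text - uncond)
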
@@ -68,7 +69,13 @@ def get_inference_lambda(seed):
 
     image_width = image_height = 256
 
-
+    # Generating latent shape
+    latent_shape = (
+        negative_prompt_text_encoder_hidden_states.shape[0],
+        unet.in_channels,
+        image_width // vae_scale_factor,
+        image_height // vae_scale_factor,
+    )
 
     def __tokenize_prompt(prompt: str):
 
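
With Stable-Diffusion-style defaults (assuming unet.in_channels == 4 and vae_scale_factor == 8; neither value appears in this diff), the hoisted shape works out to:

# one negative prompt in the batch, 4 latent channels, 256 // 8 spatial size
latent_shape = (1, 4, 256 // 8, 256 // 8)  # -> (1, 4, 32, 32)
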
@@ -78,15 +85,11 @@ def get_inference_lambda(seed):
             padding="max_length",
             truncation=True,
             return_tensors="jax",
-        ).input_ids
+        ).input_ids
 
-    def __convert_image(
-
-        return
-        # return [
-        #     Image.fromarray(image)
-        #     for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
-        # ]
+    def __convert_image(image):
+        # create PIL image from JAX tensor converted to numpy
+        return Image.fromarray(np.asarray(image), mode="RGB")
 
     def __predict_image(tokenized_prompt: jnp.array):
 
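
The rewritten __convert_image expects an (height, width, 3) uint8 array, which PIL maps directly to an RGB image; a quick self-contained check of that contract:

import numpy as np
from PIL import Image

dummy = np.zeros((256, 256, 3), dtype=np.uint8)  # all-black test image
Image.fromarray(dummy, mode="RGB").save("black.png")
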
@@ -99,14 +102,6 @@ def get_inference_lambda(seed):
         context = jnp.concatenate(
             [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
         )
-        jax.debug.print("got text encoding...")
-
-        latent_shape = (
-            tokenized_prompt.shape[0],
-            unet.in_channels,
-            image_width // vae_scale_factor,
-            image_height // vae_scale_factor,
-        )
 
         def ___timestep(step, step_args):
 
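
Hoisting latent_shape out of __predict_image is safe under jax.jit (applied below) because the shape is plain Python data: values closed over at trace time are baked into the compiled program. An illustration of that capture, with a hypothetical function:

import jax

shape = (1, 4, 32, 32)  # stand-in for latent_shape

@jax.jit
def make_noise(key):
    return jax.random.normal(key, shape)  # shape captured at trace time

print(make_noise(jax.random.PRNGKey(0)).shape)  # (1, 4, 32, 32)
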
@@ -148,15 +143,12 @@ def get_inference_lambda(seed):
                 scheduler_state, guided_unet_prediction_sample, t, latents
             ).to_tuple()
 
-            jax.debug.print("did one step...")
-
             return latents, scheduler_state
 
         # initialize scheduler state
         initial_scheduler_state = scheduler.set_timesteps(
             scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
         )
-        jax.debug.print("initialized scheduler state...")
 
         # initialize latents
         initial_latents = (
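
___timestep returns (latents, scheduler_state), the carry that jax.lax.fori_loop threads through all timesteps iterations below; the pattern in miniature:

import jax

def body(step, carry):
    value, state = carry         # unpack the running carry
    return value + state, state  # stand-in for one denoising update

final, _ = jax.lax.fori_loop(0, 20, body, (0, 1))
print(final)  # 20
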
@@ -165,49 +157,33 @@ def get_inference_lambda(seed):
             )
             * initial_scheduler_state.init_noise_sigma
         )
-        jax.debug.print("initialized latents...")
 
         final_latents, _ = jax.lax.fori_loop(
             0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
         )
-        jax.debug.print("got final latents...")
-
-        # scale and decode the image latents with vae
-        image = (
-            (
-                vae.apply(
-                    {"params": vae_params},
-                    1 / vae.config.scaling_factor * final_latents,
-                    method=vae.decode,
-                ).sample
-                / 2
-                + 0.5
-            )
-            .clip(0, 1)
-            .transpose(0, 2, 3, 1)
-        )
-        jax.debug.print("got vae processed image output...")
 
-
-
+        vae_output = vae.apply(
+            {"params": vae_params},
+            1 / vae.config.scaling_factor * final_latents,
+            method=vae.decode,
+        ).sample
+
+        # return 8 bit RGB image (width, height, rgb)
+        return (
+            ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
+            .round()
+            .astype(jnp.uint8)[0]
+        )
 
-
+    jax_jit_compiled_predict_image = jax.jit(__predict_image)
 
     return lambda prompt: __convert_image(
-
+        jax_jit_compiled_predict_image(__tokenize_prompt(prompt))
     )
 
 
 generate_image_for_prompt = get_inference_lambda(87)
 
-print(f"JAX devices: {jax.devices()}")
-print(f"JAX device type: {jax.devices()[0].device_kind}")
-
-def infer_charred(prompt):
-    # your inference function for charr stable difusion control
-    generate_image_for_prompt(prompt)
-    return None
-
 
 with gr.Blocks(theme="gradio/soft") as demo:
 
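
The new decode path turns the VAE's NCHW float output (roughly in [-1, 1]) into an 8-bit NHWC image in one chain; the same arithmetic on a dummy array:

import jax.numpy as jnp

vae_output = jnp.zeros((1, 3, 256, 256), dtype=jnp.float32)  # fake decoder output
image = (
    ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
    .round()
    .astype(jnp.uint8)[0]
)
print(image.shape, image.dtype)  # (256, 256, 3) uint8

Wrapping __predict_image in jax.jit means the first prompt pays the compilation cost; later prompts with the same tokenized shape reuse the compiled program.
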
@@ -239,10 +215,12 @@ with gr.Blocks(theme="gradio/soft") as demo:
     submit_btn = gr.Button(value="Submit")
     charred_inputs = [prompt_input_charr]
     submit_btn.click(
-        fn=
+        fn=generate_image_for_prompt,
+        inputs=charred_inputs,
+        outputs=[charred_output],
     )
     # examples = [["postage stamp from california", "low quality", "charr_output.png", "charr_output.png" ]]
     # gr.Examples(fn = infer_sd, inputs = ["text", "text", "image", "image"], examples=examples, cache_examples=True)
 
 demo.queue(concurrency_count=1)
-demo.launch(debug=True, show_error=True
+demo.launch(debug=True, show_error=True)
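
With fn=generate_image_for_prompt, the click handler now feeds the returned PIL image straight into the gr.Image output. A stripped-down sketch of the same wiring (the component constructors here are assumptions about the surrounding, unchanged code):

import gradio as gr

with gr.Blocks(theme="gradio/soft") as demo:
    prompt_input_charr = gr.Textbox(label="Prompt")
    charred_output = gr.Image(label="Output")
    submit_btn = gr.Button(value="Submit")
    submit_btn.click(
        fn=lambda prompt: None,  # stand-in for generate_image_for_prompt
        inputs=[prompt_input_charr],
        outputs=[charred_output],
    )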