Update README.md
Browse files
README.md
CHANGED
|
@@ -53,6 +53,8 @@ This weights here are intended to be used with the 🧨 Diffusers library. If yo
|
|
| 53 |
|
| 54 |
We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
|
| 55 |
|
|
|
|
|
|
|
| 56 |
```bash
|
| 57 |
pip install --upgrade diffusers transformers scipy
|
| 58 |
```
|
|
@@ -119,6 +121,75 @@ with autocast("cuda"):
|
|
| 119 |
image.save("astronaut_rides_horse.png")
|
| 120 |
```
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
# Uses
|
| 123 |
|
| 124 |
## Direct Use
|
|
|
|
| 53 |
|
| 54 |
We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
|
| 55 |
|
| 56 |
+
### PyTorch
|
| 57 |
+
|
| 58 |
```bash
|
| 59 |
pip install --upgrade diffusers transformers scipy
|
| 60 |
```
|
|
|
|
| 121 |
image.save("astronaut_rides_horse.png")
|
| 122 |
```
|
| 123 |
|
| 124 |
+
### JAX/Flax
|
| 125 |
+
|
| 126 |
+
To use StableDiffusion on TPUs and GPUs for faster inference you can leverage JAX/Flax.
|
| 127 |
+
|
| 128 |
+
Running the pipeline with default PNDMScheduler
|
| 129 |
+
|
| 130 |
+
```python
|
| 131 |
+
import jax
|
| 132 |
+
import numpy as np
|
| 133 |
+
from flax.jax_utils import replicate
|
| 134 |
+
from flax.training.common_utils import shard
|
| 135 |
+
|
| 136 |
+
from diffusers import FlaxStableDiffusionPipeline
|
| 137 |
+
|
| 138 |
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 139 |
+
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
| 143 |
+
|
| 144 |
+
prng_seed = jax.random.PRNGKey(0)
|
| 145 |
+
num_inference_steps = 50
|
| 146 |
+
|
| 147 |
+
num_samples = jax.device_count()
|
| 148 |
+
prompt = num_samples * [prompt]
|
| 149 |
+
prompt_ids = pipeline.prepare_inputs(prompt)
|
| 150 |
+
|
| 151 |
+
# shard inputs and rng
|
| 152 |
+
params = replicate(params)
|
| 153 |
+
prng_seed = jax.random.split(prng_seed, 8)
|
| 154 |
+
prompt_ids = shard(prompt_ids)
|
| 155 |
+
|
| 156 |
+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
| 157 |
+
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
**Note**:
|
| 161 |
+
If you are limited by TPU memory, please make sure to load the `FlaxStableDiffusionPipeline` in `bfloat16` precision instead of the default `float32` precision as done above. You can do so by telling diffusers to load the weights from "bf16" branch.
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
import jax
|
| 165 |
+
import numpy as np
|
| 166 |
+
from flax.jax_utils import replicate
|
| 167 |
+
from flax.training.common_utils import shard
|
| 168 |
+
|
| 169 |
+
from diffusers import FlaxStableDiffusionPipeline
|
| 170 |
+
|
| 171 |
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 172 |
+
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
| 176 |
+
|
| 177 |
+
prng_seed = jax.random.PRNGKey(0)
|
| 178 |
+
num_inference_steps = 50
|
| 179 |
+
|
| 180 |
+
num_samples = jax.device_count()
|
| 181 |
+
prompt = num_samples * [prompt]
|
| 182 |
+
prompt_ids = pipeline.prepare_inputs(prompt)
|
| 183 |
+
|
| 184 |
+
# shard inputs and rng
|
| 185 |
+
params = replicate(params)
|
| 186 |
+
prng_seed = jax.random.split(prng_seed, 8)
|
| 187 |
+
prompt_ids = shard(prompt_ids)
|
| 188 |
+
|
| 189 |
+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
| 190 |
+
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
# Uses
|
| 194 |
|
| 195 |
## Direct Use
|