Spaces:
Runtime error
Runtime error
Commit
·
95ea872
1
Parent(s):
62ee77b
Enable xformers
Browse files- train_dreambooth.py +9 -1
train_dreambooth.py
CHANGED
|
@@ -18,6 +18,7 @@ from accelerate import Accelerator
|
|
| 18 |
from accelerate.logging import get_logger
|
| 19 |
from accelerate.utils import set_seed
|
| 20 |
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
|
|
|
| 21 |
from diffusers.optimization import get_scheduler
|
| 22 |
from huggingface_hub import HfFolder, Repository, whoami
|
| 23 |
from PIL import Image
|
|
@@ -533,7 +534,14 @@ def run_training(args_imported):
|
|
| 533 |
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 534 |
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 535 |
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
vae.requires_grad_(False)
|
| 538 |
if not args.train_text_encoder:
|
| 539 |
text_encoder.requires_grad_(False)
|
|
|
|
| 18 |
from accelerate.logging import get_logger
|
| 19 |
from accelerate.utils import set_seed
|
| 20 |
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
| 21 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 22 |
from diffusers.optimization import get_scheduler
|
| 23 |
from huggingface_hub import HfFolder, Repository, whoami
|
| 24 |
from PIL import Image
|
|
|
|
| 534 |
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 535 |
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 536 |
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 537 |
+
if is_xformers_available():
|
| 538 |
+
try:
|
| 539 |
+
print("Enabling memory efficient attention with xformers...")
|
| 540 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 541 |
+
except Exception as e:
|
| 542 |
+
logger.warning(
|
| 543 |
+
f"Could not enable memory efficient attention. Make sure xformers is installed correctly and a GPU is available: {e}"
|
| 544 |
+
)
|
| 545 |
vae.requires_grad_(False)
|
| 546 |
if not args.train_text_encoder:
|
| 547 |
text_encoder.requires_grad_(False)
|