Save test images with train_dreambooth_lora.py
Browse files- train_dreambooth_lora.py +7 -1
train_dreambooth_lora.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
#
|
| 4 |
-
# This file is
|
| 5 |
# The original license is as below:
|
| 6 |
#
|
| 7 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
@@ -981,6 +981,12 @@ def main(args):
|
|
| 981 |
prompt = args.num_validation_images * [args.validation_prompt]
|
| 982 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 983 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 984 |
for tracker in accelerator.trackers:
|
| 985 |
if tracker.name == "tensorboard":
|
| 986 |
np_images = np.stack([np.asarray(img) for img in images])
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
#
|
| 4 |
+
# This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
|
| 5 |
# The original license is as below:
|
| 6 |
#
|
| 7 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
|
|
| 981 |
prompt = args.num_validation_images * [args.validation_prompt]
|
| 982 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 983 |
|
| 984 |
+
test_image_dir = Path(args.output_dir) / 'test_images'
|
| 985 |
+
test_image_dir.mkdir()
|
| 986 |
+
for i, image in enumerate(images):
|
| 987 |
+
out_path = test_image_dir / f'image_{i}.png'
|
| 988 |
+
image.save(out_path)
|
| 989 |
+
|
| 990 |
for tracker in accelerator.trackers:
|
| 991 |
if tracker.name == "tensorboard":
|
| 992 |
np_images = np.stack([np.asarray(img) for img in images])
|