Add changes from https://github.com/huggingface/diffusers/pull/2106
Browse files- train_dreambooth_lora.py +20 -14
train_dreambooth_lora.py
CHANGED
|
@@ -215,7 +215,13 @@ def parse_args(input_args=None):
|
|
| 215 |
),
|
| 216 |
)
|
| 217 |
parser.add_argument(
|
| 218 |
-
"--center_crop",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
)
|
| 220 |
parser.add_argument(
|
| 221 |
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
|
@@ -988,19 +994,19 @@ def main(args):
|
|
| 988 |
out_path = test_image_dir / f'image_{i}.png'
|
| 989 |
image.save(out_path)
|
| 990 |
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
|
| 1005 |
if args.push_to_hub:
|
| 1006 |
save_model_card(
|
|
|
|
| 215 |
),
|
| 216 |
)
|
| 217 |
parser.add_argument(
|
| 218 |
+
"--center_crop",
|
| 219 |
+
default=False,
|
| 220 |
+
action="store_true",
|
| 221 |
+
help=(
|
| 222 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 223 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 224 |
+
),
|
| 225 |
)
|
| 226 |
parser.add_argument(
|
| 227 |
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
|
|
|
| 994 |
out_path = test_image_dir / f'image_{i}.png'
|
| 995 |
image.save(out_path)
|
| 996 |
|
| 997 |
+
for tracker in accelerator.trackers:
|
| 998 |
+
if tracker.name == "tensorboard":
|
| 999 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 1000 |
+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
| 1001 |
+
if tracker.name == "wandb":
|
| 1002 |
+
tracker.log(
|
| 1003 |
+
{
|
| 1004 |
+
"test": [
|
| 1005 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 1006 |
+
for i, image in enumerate(images)
|
| 1007 |
+
]
|
| 1008 |
+
}
|
| 1009 |
+
)
|
| 1010 |
|
| 1011 |
if args.push_to_hub:
|
| 1012 |
save_model_card(
|