hysts
commited on
Commit
·
69ed433
1
Parent(s):
e305340
Use Uploader to upload models in training time
Browse filesUsing two different upload methods was not a good idea.
So, stop using upload method provided by train_dreambooth_lora.py
and use Uploader class in this repo.
Also, to make it easier to port updates for train_dreambooth_lora.py
from the diffusers library, reset changes.
- train_dreambooth_lora.py +39 -44
- trainer.py +7 -0
- utils.py +38 -0
train_dreambooth_lora.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
-
# This file is adapted from https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/examples/dreambooth/train_dreambooth_lora.py
|
| 3 |
-
# The original license is as below.
|
| 4 |
-
#
|
| 5 |
# coding=utf-8
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 7 |
#
|
| 8 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -25,6 +26,7 @@ import warnings
|
|
| 25 |
from pathlib import Path
|
| 26 |
from typing import Optional
|
| 27 |
|
|
|
|
| 28 |
import torch
|
| 29 |
import torch.nn.functional as F
|
| 30 |
import torch.utils.checkpoint
|
|
@@ -48,7 +50,7 @@ from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
|
| 48 |
from diffusers.optimization import get_scheduler
|
| 49 |
from diffusers.utils import check_min_version, is_wandb_available
|
| 50 |
from diffusers.utils.import_utils import is_xformers_available
|
| 51 |
-
from huggingface_hub import HfFolder, Repository, create_repo,
|
| 52 |
from PIL import Image
|
| 53 |
from torchvision import transforms
|
| 54 |
from tqdm.auto import tqdm
|
|
@@ -61,9 +63,9 @@ check_min_version("0.12.0.dev0")
|
|
| 61 |
logger = get_logger(__name__)
|
| 62 |
|
| 63 |
|
| 64 |
-
def save_model_card(repo_name,
|
| 65 |
-
img_str =
|
| 66 |
-
for i, image in enumerate(images
|
| 67 |
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 68 |
img_str += f"\n"
|
| 69 |
|
|
@@ -71,7 +73,6 @@ def save_model_card(repo_name, base_model, instance_prompt, test_prompt="", imag
|
|
| 71 |
---
|
| 72 |
license: creativeml-openrail-m
|
| 73 |
base_model: {base_model}
|
| 74 |
-
instance_prompt: {instance_prompt}
|
| 75 |
tags:
|
| 76 |
- stable-diffusion
|
| 77 |
- stable-diffusion-diffusers
|
|
@@ -79,11 +80,11 @@ tags:
|
|
| 79 |
- diffusers
|
| 80 |
inference: true
|
| 81 |
---
|
| 82 |
-
"""
|
| 83 |
model_card = f"""
|
| 84 |
# LoRA DreamBooth - {repo_name}
|
| 85 |
|
| 86 |
-
These are LoRA adaption weights for
|
| 87 |
{img_str}
|
| 88 |
"""
|
| 89 |
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
|
@@ -364,9 +365,6 @@ def parse_args(input_args=None):
|
|
| 364 |
parser.add_argument(
|
| 365 |
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 366 |
)
|
| 367 |
-
parser.add_argument("--private_repo", action="store_true")
|
| 368 |
-
parser.add_argument("--delete_existing_repo", action="store_true")
|
| 369 |
-
parser.add_argument("--upload_to_lora_library", action="store_true")
|
| 370 |
|
| 371 |
if input_args is not None:
|
| 372 |
args = parser.parse_args(input_args)
|
|
@@ -610,17 +608,11 @@ def main(args):
|
|
| 610 |
if accelerator.is_main_process:
|
| 611 |
if args.push_to_hub:
|
| 612 |
if args.hub_model_id is None:
|
| 613 |
-
|
| 614 |
-
repo_name = get_full_repo_name(Path(args.output_dir).name, organization=organization, token=args.hub_token)
|
| 615 |
else:
|
| 616 |
repo_name = args.hub_model_id
|
| 617 |
|
| 618 |
-
|
| 619 |
-
try:
|
| 620 |
-
delete_repo(repo_name, token=args.hub_token)
|
| 621 |
-
except Exception:
|
| 622 |
-
pass
|
| 623 |
-
create_repo(repo_name, token=args.hub_token, private=args.private_repo)
|
| 624 |
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
| 625 |
|
| 626 |
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
|
@@ -826,14 +818,21 @@ def main(args):
|
|
| 826 |
dirs = os.listdir(args.output_dir)
|
| 827 |
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 828 |
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 829 |
-
path = dirs[-1]
|
| 830 |
-
accelerator.print(f"Resuming from checkpoint {path}")
|
| 831 |
-
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 832 |
-
global_step = int(path.split("-")[1])
|
| 833 |
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 837 |
|
| 838 |
# Only show the progress bar once on each machine.
|
| 839 |
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
|
@@ -943,6 +942,9 @@ def main(args):
|
|
| 943 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 944 |
|
| 945 |
for tracker in accelerator.trackers:
|
|
|
|
|
|
|
|
|
|
| 946 |
if tracker.name == "wandb":
|
| 947 |
tracker.log(
|
| 948 |
{
|
|
@@ -974,11 +976,15 @@ def main(args):
|
|
| 974 |
pipeline.unet.load_attn_procs(args.output_dir)
|
| 975 |
|
| 976 |
# run inference
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
|
|
|
| 980 |
|
| 981 |
for tracker in accelerator.trackers:
|
|
|
|
|
|
|
|
|
|
| 982 |
if tracker.name == "wandb":
|
| 983 |
tracker.log(
|
| 984 |
{
|
|
@@ -992,23 +998,12 @@ def main(args):
|
|
| 992 |
if args.push_to_hub:
|
| 993 |
save_model_card(
|
| 994 |
repo_name,
|
| 995 |
-
base_model=args.pretrained_model_name_or_path,
|
| 996 |
-
instance_prompt=args.instance_prompt,
|
| 997 |
-
test_prompt=args.validation_prompt,
|
| 998 |
images=images,
|
| 999 |
-
repo_folder=args.output_dir,
|
| 1000 |
-
)
|
| 1001 |
-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
| 1002 |
-
else:
|
| 1003 |
-
repo_name = Path(args.output_dir).name
|
| 1004 |
-
save_model_card(
|
| 1005 |
-
repo_name,
|
| 1006 |
base_model=args.pretrained_model_name_or_path,
|
| 1007 |
-
|
| 1008 |
-
test_prompt=args.validation_prompt,
|
| 1009 |
-
images=images,
|
| 1010 |
repo_folder=args.output_dir,
|
| 1011 |
)
|
|
|
|
| 1012 |
|
| 1013 |
accelerator.end_training()
|
| 1014 |
|
|
|
|
| 1 |
#!/usr/bin/env python
|
|
|
|
|
|
|
|
|
|
| 2 |
# coding=utf-8
|
| 3 |
+
#
|
| 4 |
+
# This file is copied from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
|
| 5 |
+
# The original license is as below:
|
| 6 |
+
#
|
| 7 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 8 |
#
|
| 9 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import Optional
|
| 28 |
|
| 29 |
+
import numpy as np
|
| 30 |
import torch
|
| 31 |
import torch.nn.functional as F
|
| 32 |
import torch.utils.checkpoint
|
|
|
|
| 50 |
from diffusers.optimization import get_scheduler
|
| 51 |
from diffusers.utils import check_min_version, is_wandb_available
|
| 52 |
from diffusers.utils.import_utils import is_xformers_available
|
| 53 |
+
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
| 54 |
from PIL import Image
|
| 55 |
from torchvision import transforms
|
| 56 |
from tqdm.auto import tqdm
|
|
|
|
| 63 |
logger = get_logger(__name__)
|
| 64 |
|
| 65 |
|
| 66 |
+
def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
|
| 67 |
+
img_str = ""
|
| 68 |
+
for i, image in enumerate(images):
|
| 69 |
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 70 |
img_str += f"\n"
|
| 71 |
|
|
|
|
| 73 |
---
|
| 74 |
license: creativeml-openrail-m
|
| 75 |
base_model: {base_model}
|
|
|
|
| 76 |
tags:
|
| 77 |
- stable-diffusion
|
| 78 |
- stable-diffusion-diffusers
|
|
|
|
| 80 |
- diffusers
|
| 81 |
inference: true
|
| 82 |
---
|
| 83 |
+
"""
|
| 84 |
model_card = f"""
|
| 85 |
# LoRA DreamBooth - {repo_name}
|
| 86 |
|
| 87 |
+
These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
|
| 88 |
{img_str}
|
| 89 |
"""
|
| 90 |
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
|
|
|
| 365 |
parser.add_argument(
|
| 366 |
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 367 |
)
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
if input_args is not None:
|
| 370 |
args = parser.parse_args(input_args)
|
|
|
|
| 608 |
if accelerator.is_main_process:
|
| 609 |
if args.push_to_hub:
|
| 610 |
if args.hub_model_id is None:
|
| 611 |
+
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
|
|
|
| 612 |
else:
|
| 613 |
repo_name = args.hub_model_id
|
| 614 |
|
| 615 |
+
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
| 617 |
|
| 618 |
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
|
|
|
| 818 |
dirs = os.listdir(args.output_dir)
|
| 819 |
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 820 |
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 821 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
|
|
|
|
|
|
|
|
|
| 822 |
|
| 823 |
+
if path is None:
|
| 824 |
+
accelerator.print(
|
| 825 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 826 |
+
)
|
| 827 |
+
args.resume_from_checkpoint = None
|
| 828 |
+
else:
|
| 829 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 830 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 831 |
+
global_step = int(path.split("-")[1])
|
| 832 |
+
|
| 833 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
| 834 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 835 |
+
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
| 836 |
|
| 837 |
# Only show the progress bar once on each machine.
|
| 838 |
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
|
|
|
| 942 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 943 |
|
| 944 |
for tracker in accelerator.trackers:
|
| 945 |
+
if tracker.name == "tensorboard":
|
| 946 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 947 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 948 |
if tracker.name == "wandb":
|
| 949 |
tracker.log(
|
| 950 |
{
|
|
|
|
| 976 |
pipeline.unet.load_attn_procs(args.output_dir)
|
| 977 |
|
| 978 |
# run inference
|
| 979 |
+
if args.validation_prompt and args.num_validation_images > 0:
|
| 980 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 981 |
+
prompt = args.num_validation_images * [args.validation_prompt]
|
| 982 |
+
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 983 |
|
| 984 |
for tracker in accelerator.trackers:
|
| 985 |
+
if tracker.name == "tensorboard":
|
| 986 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 987 |
+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
| 988 |
if tracker.name == "wandb":
|
| 989 |
tracker.log(
|
| 990 |
{
|
|
|
|
| 998 |
if args.push_to_hub:
|
| 999 |
save_model_card(
|
| 1000 |
repo_name,
|
|
|
|
|
|
|
|
|
|
| 1001 |
images=images,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
base_model=args.pretrained_model_name_or_path,
|
| 1003 |
+
prompt=args.instance_prompt,
|
|
|
|
|
|
|
| 1004 |
repo_folder=args.output_dir,
|
| 1005 |
)
|
| 1006 |
+
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
| 1007 |
|
| 1008 |
accelerator.end_training()
|
| 1009 |
|
trainer.py
CHANGED
|
@@ -14,6 +14,7 @@ import torch
|
|
| 14 |
from huggingface_hub import HfApi
|
| 15 |
|
| 16 |
from app_upload import LoRAModelUploader
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
|
|
@@ -125,6 +126,12 @@ class Trainer:
|
|
| 125 |
command_s = ' '.join(command.split())
|
| 126 |
f.write(command_s)
|
| 127 |
subprocess.run(shlex.split(command))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
message = 'Training completed!'
|
| 129 |
print(message)
|
| 130 |
|
|
|
|
| 14 |
from huggingface_hub import HfApi
|
| 15 |
|
| 16 |
from app_upload import LoRAModelUploader
|
| 17 |
+
from utils import save_model_card
|
| 18 |
|
| 19 |
|
| 20 |
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
|
|
|
|
| 126 |
command_s = ' '.join(command.split())
|
| 127 |
f.write(command_s)
|
| 128 |
subprocess.run(shlex.split(command))
|
| 129 |
+
save_model_card(save_dir=output_dir,
|
| 130 |
+
base_model=base_model,
|
| 131 |
+
instance_prompt=instance_prompt,
|
| 132 |
+
test_prompt=validation_prompt,
|
| 133 |
+
test_image_dir='test_images')
|
| 134 |
+
|
| 135 |
message = 'Training completed!'
|
| 136 |
print(message)
|
| 137 |
|
utils.py
CHANGED
|
@@ -18,3 +18,41 @@ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
|
|
| 18 |
exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
|
| 19 |
]
|
| 20 |
return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
|
| 19 |
]
|
| 20 |
return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def save_model_card(
|
| 24 |
+
save_dir: pathlib.Path,
|
| 25 |
+
base_model: str,
|
| 26 |
+
instance_prompt: str,
|
| 27 |
+
test_prompt: str = '',
|
| 28 |
+
test_image_dir: str = '',
|
| 29 |
+
) -> None:
|
| 30 |
+
image_str = ''
|
| 31 |
+
if test_prompt and test_image_dir:
|
| 32 |
+
image_paths = sorted((save_dir / test_image_dir).glob('*'))
|
| 33 |
+
if image_paths:
|
| 34 |
+
image_str = f'Test prompt: {test_prompt}\n'
|
| 35 |
+
for image_path in image_paths:
|
| 36 |
+
rel_path = image_path.relative_to(save_dir)
|
| 37 |
+
image_str += f'\n'
|
| 38 |
+
|
| 39 |
+
model_card = f'''---
|
| 40 |
+
license: creativeml-openrail-m
|
| 41 |
+
base_model: {base_model}
|
| 42 |
+
instance_prompt: {instance_prompt}
|
| 43 |
+
tags:
|
| 44 |
+
- stable-diffusion
|
| 45 |
+
- stable-diffusion-diffusers
|
| 46 |
+
- text-to-image
|
| 47 |
+
- diffusers
|
| 48 |
+
inference: true
|
| 49 |
+
---
|
| 50 |
+
# LoRA DreamBooth - {save_dir.name}
|
| 51 |
+
|
| 52 |
+
These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
|
| 53 |
+
|
| 54 |
+
{image_str}
|
| 55 |
+
'''
|
| 56 |
+
|
| 57 |
+
with open(save_dir / 'README.md', 'w') as f:
|
| 58 |
+
f.write(model_card)
|