Spaces:
Runtime error
Runtime error
Commit
·
7967a47
1
Parent(s):
cb9bf15
Remove GPU attribution if CUDA error
Browse files
app.py
CHANGED
|
@@ -35,15 +35,8 @@ else:
|
|
| 35 |
is_shared_ui = False
|
| 36 |
is_gpu_associated = torch.cuda.is_available()
|
| 37 |
|
| 38 |
-
|
| 39 |
-
.instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
|
| 40 |
-
.arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
|
| 41 |
-
#component-4, #component-3, #component-10{min-height: 0}
|
| 42 |
-
.duplicate-button img{margin: 0}
|
| 43 |
-
'''
|
| 44 |
-
maximum_concepts = 3
|
| 45 |
|
| 46 |
-
#Pre download the files
|
| 47 |
if(is_gpu_associated):
|
| 48 |
model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
|
| 49 |
model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
|
|
@@ -51,8 +44,25 @@ if(is_gpu_associated):
|
|
| 51 |
safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
|
| 52 |
model_to_load = model_v1
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def swap_text(option, base):
|
| 58 |
resize_width = 768 if base == "v2-1-768" else 512
|
|
@@ -60,7 +70,7 @@ def swap_text(option, base):
|
|
| 60 |
if(option == "object"):
|
| 61 |
instance_prompt_example = "cttoy"
|
| 62 |
freeze_for = 30
|
| 63 |
-
return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
| 64 |
elif(option == "person"):
|
| 65 |
instance_prompt_example = "julcto"
|
| 66 |
freeze_for = 70
|
|
@@ -70,27 +80,17 @@ def swap_text(option, base):
|
|
| 70 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation)
|
| 71 |
else:
|
| 72 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
|
| 73 |
-
return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
| 74 |
elif(option == "style"):
|
| 75 |
instance_prompt_example = "trsldamrl"
|
| 76 |
freeze_for = 10
|
| 77 |
-
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
| 78 |
-
|
| 79 |
-
def swap_base_model(selected_model):
|
| 80 |
-
if(is_gpu_associated):
|
| 81 |
-
global model_to_load
|
| 82 |
-
if(selected_model == "v1-5"):
|
| 83 |
-
model_to_load = model_v1
|
| 84 |
-
elif(selected_model == "v2-1-768"):
|
| 85 |
-
model_to_load = model_v2
|
| 86 |
-
else:
|
| 87 |
-
model_to_load = model_v2_512
|
| 88 |
|
| 89 |
def count_files(*inputs):
|
| 90 |
file_counter = 0
|
| 91 |
concept_counter = 0
|
| 92 |
for i, input in enumerate(inputs):
|
| 93 |
-
if(i < maximum_concepts
|
| 94 |
files = inputs[i]
|
| 95 |
if(files):
|
| 96 |
concept_counter+=1
|
|
@@ -133,6 +133,9 @@ def update_steps(*files_list):
|
|
| 133 |
file_counter+=len(files)
|
| 134 |
return(gr.update(value=file_counter*200))
|
| 135 |
|
|
|
|
|
|
|
|
|
|
| 136 |
def pad_image(image):
|
| 137 |
w, h = image.size
|
| 138 |
if w == h:
|
|
@@ -163,7 +166,34 @@ def validate_model_upload(hf_token, model_name):
|
|
| 163 |
if(model_name == ""):
|
| 164 |
raise gr.Error("Please fill in your model's name")
|
| 165 |
|
| 166 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
if is_shared_ui:
|
| 168 |
raise gr.Error("This Space only works in duplicated instances")
|
| 169 |
if not is_gpu_associated:
|
|
@@ -171,6 +201,9 @@ def train(*inputs):
|
|
| 171 |
hf_token = inputs[-5]
|
| 172 |
model_name = inputs[-7]
|
| 173 |
if(is_spaces):
|
|
|
|
|
|
|
|
|
|
| 174 |
remove_attribution_after = inputs[-6]
|
| 175 |
else:
|
| 176 |
remove_attribution_after = False
|
|
@@ -191,7 +224,6 @@ def train(*inputs):
|
|
| 191 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
| 192 |
if os.path.exists("hastrained.success"): os.remove("hastrained.success")
|
| 193 |
file_counter = 0
|
| 194 |
-
which_model = inputs[-10]
|
| 195 |
resolution = 512 if which_model != "v2-1-768" else 768
|
| 196 |
for i, input in enumerate(inputs):
|
| 197 |
if(i < maximum_concepts-1):
|
|
@@ -261,38 +293,22 @@ def train(*inputs):
|
|
| 261 |
print("Starting single training...")
|
| 262 |
lock_file = open("intraining.lock", "w")
|
| 263 |
lock_file.close()
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
mixed_precision="fp16",
|
| 281 |
-
train_batch_size=1,
|
| 282 |
-
gradient_accumulation_steps=1,
|
| 283 |
-
use_8bit_adam=True,
|
| 284 |
-
learning_rate=2e-6,
|
| 285 |
-
lr_scheduler="polynomial",
|
| 286 |
-
lr_warmup_steps = 0,
|
| 287 |
-
max_train_steps=Training_Steps,
|
| 288 |
-
num_class_images=200,
|
| 289 |
-
gradient_checkpointing=gradient_checkpointing,
|
| 290 |
-
cache_latents=cache_latents,
|
| 291 |
-
)
|
| 292 |
-
print("Starting multi-training...")
|
| 293 |
-
lock_file = open("intraining.lock", "w")
|
| 294 |
-
lock_file.close()
|
| 295 |
-
run_training(args_general)
|
| 296 |
gc.collect()
|
| 297 |
torch.cuda.empty_cache()
|
| 298 |
if(which_model == "v1-5"):
|
|
@@ -302,6 +318,7 @@ def train(*inputs):
|
|
| 302 |
shutil.copy(f"model_index.json", "output_model/model_index.json")
|
| 303 |
|
| 304 |
if(not remove_attribution_after):
|
|
|
|
| 305 |
print("Archiving model file...")
|
| 306 |
with tarfile.open("diffusers_model.tar", "w") as tar:
|
| 307 |
tar.add("output_model", arcname=os.path.basename("output_model"))
|
|
@@ -310,6 +327,7 @@ def train(*inputs):
|
|
| 310 |
trained_file.close()
|
| 311 |
print("Training completed!")
|
| 312 |
return [
|
|
|
|
| 313 |
gr.update(visible=True, value=["diffusers_model.tar"]), #result
|
| 314 |
gr.update(visible=True), #try_your_model
|
| 315 |
gr.update(visible=True), #push_to_hub
|
|
@@ -320,10 +338,7 @@ def train(*inputs):
|
|
| 320 |
else:
|
| 321 |
where_to_upload = inputs[-8]
|
| 322 |
push(model_name, where_to_upload, hf_token, which_model, True)
|
| 323 |
-
|
| 324 |
-
headers = { "authorization" : f"Bearer {hf_token}"}
|
| 325 |
-
body = {'flavor': 'cpu-basic'}
|
| 326 |
-
requests.post(hardware_url, json = body, headers=headers)
|
| 327 |
|
| 328 |
pipe_is_set = False
|
| 329 |
def generate(prompt, steps):
|
|
@@ -338,7 +353,7 @@ def generate(prompt, steps):
|
|
| 338 |
|
| 339 |
image = pipe(prompt, num_inference_steps=steps).images[0]
|
| 340 |
return(image)
|
| 341 |
-
|
| 342 |
def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
|
| 343 |
validate_model_upload(hf_token, model_name)
|
| 344 |
if(not os.path.exists("model.ckpt")):
|
|
@@ -425,7 +440,10 @@ Sample pictures of:
|
|
| 425 |
extra_message = "Don't forget to remove the GPU attribution after you play with it."
|
| 426 |
else:
|
| 427 |
extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
| 429 |
print("Model uploaded successfully!")
|
| 430 |
return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
|
| 431 |
|
|
@@ -488,8 +506,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 488 |
<div class="gr-prose" style="max-width: 80%">
|
| 489 |
<h2>Attention - This Space doesn't work in this shared UI</h2>
|
| 490 |
<p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it! <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></p>
|
| 491 |
-
<img class="instruction" src="file
|
| 492 |
-
<img class="arrow" src="file
|
| 493 |
</div>
|
| 494 |
''')
|
| 495 |
elif(is_spaces):
|
|
@@ -519,14 +537,15 @@ with gr.Blocks(css=css) as demo:
|
|
| 519 |
|
| 520 |
with gr.Row() as what_are_you_training:
|
| 521 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
| 522 |
-
|
| 523 |
-
|
|
|
|
| 524 |
#Very hacky approach to emulate dynamically created Gradio components
|
| 525 |
with gr.Row() as upload_your_concept:
|
| 526 |
with gr.Column():
|
| 527 |
thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
|
| 528 |
thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
|
| 529 |
-
thing_image_example = gr.HTML('''<img src="file
|
| 530 |
things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
|
| 531 |
|
| 532 |
with gr.Column():
|
|
@@ -588,6 +607,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 588 |
training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
|
| 589 |
|
| 590 |
train_btn = gr.Button("Start Training")
|
|
|
|
| 591 |
if(is_shared_ui):
|
| 592 |
training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
|
| 593 |
elif(not is_gpu_associated):
|
|
@@ -595,6 +615,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 595 |
else:
|
| 596 |
training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
|
| 597 |
|
|
|
|
| 598 |
#Post-training UI
|
| 599 |
completed_training = gr.Markdown('''# ✅ Training completed.
|
| 600 |
### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
|
|
@@ -624,9 +645,10 @@ with gr.Blocks(css=css) as demo:
|
|
| 624 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
| 625 |
|
| 626 |
#Swap the base model
|
|
|
|
| 627 |
base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
|
|
|
| 628 |
base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
|
| 629 |
-
|
| 630 |
#Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
|
| 631 |
for file in file_collection:
|
| 632 |
#file.change(fn=update_steps,inputs=file_collection, outputs=steps)
|
|
@@ -641,10 +663,12 @@ with gr.Blocks(css=css) as demo:
|
|
| 641 |
if(is_spaces):
|
| 642 |
training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
|
| 643 |
#Add a message for while it is in training
|
| 644 |
-
|
|
|
|
| 645 |
|
| 646 |
#The main train function
|
| 647 |
-
train_btn.click(
|
|
|
|
| 648 |
|
| 649 |
#Button to generate an image from your trained model after training
|
| 650 |
generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
|
|
|
|
| 35 |
is_shared_ui = False
|
| 36 |
is_gpu_associated = torch.cuda.is_available()
|
| 37 |
|
| 38 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
|
|
|
| 40 |
if(is_gpu_associated):
|
| 41 |
model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
|
| 42 |
model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
|
|
|
|
| 44 |
safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
|
| 45 |
model_to_load = model_v1
|
| 46 |
|
| 47 |
+
def swap_base_model(selected_model):
|
| 48 |
+
if(is_gpu_associated):
|
| 49 |
+
global model_to_load
|
| 50 |
+
if(selected_model == "v1-5"):
|
| 51 |
+
model_to_load = model_v1
|
| 52 |
+
elif(selected_model == "v2-1-768"):
|
| 53 |
+
model_to_load = model_v2
|
| 54 |
+
else:
|
| 55 |
+
model_to_load = model_v2_512
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
css = '''
|
| 60 |
+
.instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
|
| 61 |
+
.arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
|
| 62 |
+
#component-4, #component-3, #component-10{min-height: 0}
|
| 63 |
+
.duplicate-button img{margin: 0}
|
| 64 |
+
'''
|
| 65 |
+
maximum_concepts = 3
|
| 66 |
|
| 67 |
def swap_text(option, base):
|
| 68 |
resize_width = 768 if base == "v2-1-768" else 512
|
|
|
|
| 70 |
if(option == "object"):
|
| 71 |
instance_prompt_example = "cttoy"
|
| 72 |
freeze_for = 30
|
| 73 |
+
return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=cat-toy.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, gr.update(visible=False)]
|
| 74 |
elif(option == "person"):
|
| 75 |
instance_prompt_example = "julcto"
|
| 76 |
freeze_for = 70
|
|
|
|
| 80 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation)
|
| 81 |
else:
|
| 82 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
|
| 83 |
+
return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=person.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, prior_preservation_box_update]
|
| 84 |
elif(option == "style"):
|
| 85 |
instance_prompt_example = "trsldamrl"
|
| 86 |
freeze_for = 10
|
| 87 |
+
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=trsl_style.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}", freeze_for, gr.update(visible=False)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
def count_files(*inputs):
|
| 90 |
file_counter = 0
|
| 91 |
concept_counter = 0
|
| 92 |
for i, input in enumerate(inputs):
|
| 93 |
+
if(i < maximum_concepts):
|
| 94 |
files = inputs[i]
|
| 95 |
if(files):
|
| 96 |
concept_counter+=1
|
|
|
|
| 133 |
file_counter+=len(files)
|
| 134 |
return(gr.update(value=file_counter*200))
|
| 135 |
|
| 136 |
+
def visualise_progress_bar():
|
| 137 |
+
return gr.update(visible=True)
|
| 138 |
+
|
| 139 |
def pad_image(image):
|
| 140 |
w, h = image.size
|
| 141 |
if w == h:
|
|
|
|
| 166 |
if(model_name == ""):
|
| 167 |
raise gr.Error("Please fill in your model's name")
|
| 168 |
|
| 169 |
+
def swap_hardware(hf_token, hardware="cpu-basic"):
|
| 170 |
+
hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
|
| 171 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
| 172 |
+
body = {'flavor': hardware}
|
| 173 |
+
requests.post(hardware_url, json = body, headers=headers)
|
| 174 |
+
|
| 175 |
+
def swap_sleep_time(hf_token,sleep_time):
|
| 176 |
+
sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}/sleeptime"
|
| 177 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
| 178 |
+
body = {'seconds':sleep_time}
|
| 179 |
+
requests.post(sleep_time_url,json=body,headers=headers)
|
| 180 |
+
|
| 181 |
+
def get_sleep_time(hf_token):
|
| 182 |
+
sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}"
|
| 183 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
| 184 |
+
response = requests.get(sleep_time_url,headers=headers)
|
| 185 |
+
return response.json()['runtime']['gcTimeout']
|
| 186 |
+
|
| 187 |
+
def write_to_community(title, description,hf_token):
|
| 188 |
+
from huggingface_hub import HfApi
|
| 189 |
+
api = HfApi()
|
| 190 |
+
api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
|
| 191 |
+
|
| 192 |
+
def train(progress=gr.Progress(track_tqdm=True), *inputs):
|
| 193 |
+
which_model = inputs[-10]
|
| 194 |
+
if(which_model == ""):
|
| 195 |
+
raise gr.Error("You forgot to select a base model to use")
|
| 196 |
+
|
| 197 |
if is_shared_ui:
|
| 198 |
raise gr.Error("This Space only works in duplicated instances")
|
| 199 |
if not is_gpu_associated:
|
|
|
|
| 201 |
hf_token = inputs[-5]
|
| 202 |
model_name = inputs[-7]
|
| 203 |
if(is_spaces):
|
| 204 |
+
sleep_time = get_sleep_time(hf_token)
|
| 205 |
+
if sleep_time:
|
| 206 |
+
swap_sleep_time(hf_token, -1)
|
| 207 |
remove_attribution_after = inputs[-6]
|
| 208 |
else:
|
| 209 |
remove_attribution_after = False
|
|
|
|
| 224 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
| 225 |
if os.path.exists("hastrained.success"): os.remove("hastrained.success")
|
| 226 |
file_counter = 0
|
|
|
|
| 227 |
resolution = 512 if which_model != "v2-1-768" else 768
|
| 228 |
for i, input in enumerate(inputs):
|
| 229 |
if(i < maximum_concepts-1):
|
|
|
|
| 293 |
print("Starting single training...")
|
| 294 |
lock_file = open("intraining.lock", "w")
|
| 295 |
lock_file.close()
|
| 296 |
+
try:
|
| 297 |
+
run_training(args_general)
|
| 298 |
+
except Exception as e:
|
| 299 |
+
if(is_spaces):
|
| 300 |
+
title="There was an error on during your training"
|
| 301 |
+
description=f'''
|
| 302 |
+
Unfortunately there was an error during training your {model_name} model.
|
| 303 |
+
Please check it out below. Feel free to report this issue to [Dreambooth Training](https://huggingface.co/spaces/multimodalart/dreambooth-training):
|
| 304 |
+
```
|
| 305 |
+
{str(e)}
|
| 306 |
+
```
|
| 307 |
+
'''
|
| 308 |
+
swap_hardware(hf_token, "cpu-basic")
|
| 309 |
+
write_to_community(title,description,hf_token)
|
| 310 |
+
|
| 311 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
gc.collect()
|
| 313 |
torch.cuda.empty_cache()
|
| 314 |
if(which_model == "v1-5"):
|
|
|
|
| 318 |
shutil.copy(f"model_index.json", "output_model/model_index.json")
|
| 319 |
|
| 320 |
if(not remove_attribution_after):
|
| 321 |
+
swap_sleep_time(hf_token, sleep_time)
|
| 322 |
print("Archiving model file...")
|
| 323 |
with tarfile.open("diffusers_model.tar", "w") as tar:
|
| 324 |
tar.add("output_model", arcname=os.path.basename("output_model"))
|
|
|
|
| 327 |
trained_file.close()
|
| 328 |
print("Training completed!")
|
| 329 |
return [
|
| 330 |
+
gr.update(visible=False), #progress_bar
|
| 331 |
gr.update(visible=True, value=["diffusers_model.tar"]), #result
|
| 332 |
gr.update(visible=True), #try_your_model
|
| 333 |
gr.update(visible=True), #push_to_hub
|
|
|
|
| 338 |
else:
|
| 339 |
where_to_upload = inputs[-8]
|
| 340 |
push(model_name, where_to_upload, hf_token, which_model, True)
|
| 341 |
+
swap_hardware(hf_token, "cpu-basic")
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
pipe_is_set = False
|
| 344 |
def generate(prompt, steps):
|
|
|
|
| 353 |
|
| 354 |
image = pipe(prompt, num_inference_steps=steps).images[0]
|
| 355 |
return(image)
|
| 356 |
+
|
| 357 |
def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
|
| 358 |
validate_model_upload(hf_token, model_name)
|
| 359 |
if(not os.path.exists("model.ckpt")):
|
|
|
|
| 440 |
extra_message = "Don't forget to remove the GPU attribution after you play with it."
|
| 441 |
else:
|
| 442 |
extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
|
| 443 |
+
title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!"
|
| 444 |
+
description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}"
|
| 445 |
+
write_to_community(title, description, hf_token)
|
| 446 |
+
#api.create_discussion(repo_id=os.environ['SPACE_ID'], title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!", description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}",repo_type="space", token=hf_token)
|
| 447 |
print("Model uploaded successfully!")
|
| 448 |
return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
|
| 449 |
|
|
|
|
| 506 |
<div class="gr-prose" style="max-width: 80%">
|
| 507 |
<h2>Attention - This Space doesn't work in this shared UI</h2>
|
| 508 |
<p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it! <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></p>
|
| 509 |
+
<img class="instruction" src="file=duplicate.png">
|
| 510 |
+
<img class="arrow" src="file=arrow.png" />
|
| 511 |
</div>
|
| 512 |
''')
|
| 513 |
elif(is_spaces):
|
|
|
|
| 537 |
|
| 538 |
with gr.Row() as what_are_you_training:
|
| 539 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
| 540 |
+
with gr.Column():
|
| 541 |
+
base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-1-512", "v2-1-768"], value="v1-5", interactive=True)
|
| 542 |
+
|
| 543 |
#Very hacky approach to emulate dynamically created Gradio components
|
| 544 |
with gr.Row() as upload_your_concept:
|
| 545 |
with gr.Column():
|
| 546 |
thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
|
| 547 |
thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
|
| 548 |
+
thing_image_example = gr.HTML('''<img src="file=cat-toy.png" />''')
|
| 549 |
things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
|
| 550 |
|
| 551 |
with gr.Column():
|
|
|
|
| 607 |
training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
|
| 608 |
|
| 609 |
train_btn = gr.Button("Start Training")
|
| 610 |
+
progress_bar = gr.Textbox(visible=False)
|
| 611 |
if(is_shared_ui):
|
| 612 |
training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
|
| 613 |
elif(not is_gpu_associated):
|
|
|
|
| 615 |
else:
|
| 616 |
training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
|
| 617 |
|
| 618 |
+
|
| 619 |
#Post-training UI
|
| 620 |
completed_training = gr.Markdown('''# ✅ Training completed.
|
| 621 |
### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
|
|
|
|
| 645 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
| 646 |
|
| 647 |
#Swap the base model
|
| 648 |
+
|
| 649 |
base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
| 650 |
+
#base_model_to_use.change(fn=visualise_progress_bar, inputs=[], outputs=progress_bar)
|
| 651 |
base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
|
|
|
|
| 652 |
#Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
|
| 653 |
for file in file_collection:
|
| 654 |
#file.change(fn=update_steps,inputs=file_collection, outputs=steps)
|
|
|
|
| 663 |
if(is_spaces):
|
| 664 |
training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
|
| 665 |
#Add a message for while it is in training
|
| 666 |
+
|
| 667 |
+
#train_btn.click(lambda:gr.update(visible=True), inputs=None, outputs=training_ongoing)
|
| 668 |
|
| 669 |
#The main train function
|
| 670 |
+
train_btn.click(lambda:gr.update(visible=True), inputs=[], outputs=progress_bar)
|
| 671 |
+
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[base_model_to_use]+[thing_experimental]+[training_summary_where_to_upload]+[training_summary_model_name]+[training_summary_checkbox]+[training_summary_token]+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[progress_bar, result, try_your_model, push_to_hub, convert_button, training_ongoing, completed_training], queue=False)
|
| 672 |
|
| 673 |
#Button to generate an image from your trained model after training
|
| 674 |
generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
|