Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse files
app.py
CHANGED
|
@@ -160,6 +160,7 @@ def recursive_update(d, u):
|
|
| 160 |
def start_training(
|
| 161 |
lora_name,
|
| 162 |
concept_sentence,
|
|
|
|
| 163 |
steps,
|
| 164 |
lr,
|
| 165 |
rank,
|
|
@@ -224,7 +225,12 @@ def start_training(
|
|
| 224 |
config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
|
| 225 |
else:
|
| 226 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
if(use_more_advanced_options):
|
| 229 |
more_advanced_options_dict = yaml.safe_load(more_advanced_options)
|
| 230 |
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
|
|
@@ -291,11 +297,13 @@ def update_pricing(steps, oauth_token: Union[gr.OAuthToken, None]):
|
|
| 291 |
else:
|
| 292 |
return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
|
| 293 |
|
|
|
|
|
|
|
|
|
|
| 294 |
config_yaml = '''
|
| 295 |
device: cuda:0
|
| 296 |
model:
|
| 297 |
is_flux: true
|
| 298 |
-
name_or_path: black-forest-labs/FLUX.1-dev
|
| 299 |
quantize: true
|
| 300 |
network:
|
| 301 |
linear: 16 #it will overcome the 'rank' parameter
|
|
@@ -342,6 +350,7 @@ h3{margin-top: 0}
|
|
| 342 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
| 343 |
.tabitem{border: 0px}
|
| 344 |
.group_padding{padding: .55em}
|
|
|
|
| 345 |
"""
|
| 346 |
with gr.Blocks(theme=theme, css=css) as demo:
|
| 347 |
gr.Markdown(
|
|
@@ -352,18 +361,22 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
| 352 |
gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
|
| 353 |
with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
|
| 354 |
with gr.Column() as main_ui:
|
| 355 |
-
with gr.
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
with gr.Group(visible=True) as image_upload:
|
| 368 |
with gr.Row():
|
| 369 |
images = gr.File(
|
|
@@ -503,12 +516,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
| 503 |
inputs=[steps],
|
| 504 |
outputs=[cost_preview, cost_preview_info, payment_update, start]
|
| 505 |
)
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
|
| 508 |
fn=start_training,
|
| 509 |
inputs=[
|
| 510 |
lora_name,
|
| 511 |
concept_sentence,
|
|
|
|
| 512 |
steps,
|
| 513 |
lr,
|
| 514 |
rank,
|
|
|
|
| 160 |
def start_training(
|
| 161 |
lora_name,
|
| 162 |
concept_sentence,
|
| 163 |
+
which_model,
|
| 164 |
steps,
|
| 165 |
lr,
|
| 166 |
rank,
|
|
|
|
| 225 |
config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
|
| 226 |
else:
|
| 227 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
| 228 |
+
|
| 229 |
+
if(which_model == "[schnell] (4 step fast model)"):
|
| 230 |
+
config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
|
| 231 |
+
config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
|
| 232 |
+
config["config"]["process"][0]["sample"]["sample_steps"] = 4
|
| 233 |
+
|
| 234 |
if(use_more_advanced_options):
|
| 235 |
more_advanced_options_dict = yaml.safe_load(more_advanced_options)
|
| 236 |
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
|
|
|
|
| 297 |
else:
|
| 298 |
return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
|
| 299 |
|
| 300 |
+
def swap_base_model(model):
|
| 301 |
+
return gr.update(visible=True) if model == "[dev] (high quality model, non-commercial license)" else gr.update(visible=False)
|
| 302 |
+
|
| 303 |
config_yaml = '''
|
| 304 |
device: cuda:0
|
| 305 |
model:
|
| 306 |
is_flux: true
|
|
|
|
| 307 |
quantize: true
|
| 308 |
network:
|
| 309 |
linear: 16 #it will overcome the 'rank' parameter
|
|
|
|
| 350 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
| 351 |
.tabitem{border: 0px}
|
| 352 |
.group_padding{padding: .55em}
|
| 353 |
+
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
|
| 354 |
"""
|
| 355 |
with gr.Blocks(theme=theme, css=css) as demo:
|
| 356 |
gr.Markdown(
|
|
|
|
| 361 |
gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
|
| 362 |
with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
|
| 363 |
with gr.Column() as main_ui:
|
| 364 |
+
with gr.Group():
|
| 365 |
+
with gr.Row():
|
| 366 |
+
lora_name = gr.Textbox(
|
| 367 |
+
label="The name of your LoRA",
|
| 368 |
+
info="This has to be a unique name",
|
| 369 |
+
placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
|
| 370 |
+
)
|
| 371 |
+
concept_sentence = gr.Textbox(
|
| 372 |
+
label="Trigger word/sentence",
|
| 373 |
+
info="Trigger word or sentence to be used",
|
| 374 |
+
placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
|
| 375 |
+
interactive=True,
|
| 376 |
+
)
|
| 377 |
+
which_model = gr.Radio(["[schnell] (4 step fast model)", "[dev] (high quality model, non-commercial license - available when training locally)"], label="Which base model to train?", elem_id="space_model" if is_spaces else "local_model", value="[schnell] (4 step fast model)",)
|
| 378 |
+
model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
|
| 379 |
+
""", visible=False)
|
| 380 |
with gr.Group(visible=True) as image_upload:
|
| 381 |
with gr.Row():
|
| 382 |
images = gr.File(
|
|
|
|
| 516 |
inputs=[steps],
|
| 517 |
outputs=[cost_preview, cost_preview_info, payment_update, start]
|
| 518 |
)
|
| 519 |
+
|
| 520 |
+
which_model.change(
|
| 521 |
+
fn=swap_base_model,
|
| 522 |
+
inputs=which_model,
|
| 523 |
+
outputs=model_warning
|
| 524 |
+
)
|
| 525 |
start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
|
| 526 |
fn=start_training,
|
| 527 |
inputs=[
|
| 528 |
lora_name,
|
| 529 |
concept_sentence,
|
| 530 |
+
which_model,
|
| 531 |
steps,
|
| 532 |
lr,
|
| 533 |
rank,
|