	Update lora_trainer.py
lora_trainer.py +7 -428

lora_trainer.py CHANGED
@@ -1,430 +1,9 @@
-
-from huggingface_hub import whoami
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-import sys
-import spaces
-# Add the current working directory to the Python path
-sys.path.insert(0, os.getcwd())
-
-
-
-import torch
-import uuid
-import os
-import shutil
-import json
-import yaml
-from slugify import slugify
-from transformers import AutoProcessor, AutoModelForCausalLM
-
-sys.path.insert(0, "ai-toolkit")
-from toolkit.job import get_job
-
-MAX_IMAGES = 150
-
-def load_captioning(uploaded_files, concept_sentence):
-    uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
-    txt_files = [file for file in uploaded_files if file.endswith('.txt')]
-    txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
-    updates = []
-    if len(uploaded_images) <= 1:
-        raise gr.Error(
-            "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
-        )
-    elif len(uploaded_images) > MAX_IMAGES:
-        raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
-    # Update for the captioning_area
-    # for _ in range(3):
-    updates.append(gr.update(visible=True))
-    # Update visibility and image for each captioning row and image
-    for i in range(1, MAX_IMAGES + 1):
-        # Determine if the current row and image should be visible
-        visible = i <= len(uploaded_images)
-
-        # Update visibility of the captioning row
-        updates.append(gr.update(visible=visible))
-
-        # Update for image component - display image if available, otherwise hide
-        image_value = uploaded_images[i - 1] if visible else None
-        updates.append(gr.update(value=image_value, visible=visible))
-
-        corresponding_caption = False
-        if(image_value):
-            base_name = os.path.splitext(os.path.basename(image_value))[0]
-            print(base_name)
-            print(image_value)
-            if base_name in txt_files_dict:
-                print("entrou")
-                with open(txt_files_dict[base_name], 'r') as file:
-                    corresponding_caption = file.read()
-
-        # Update value of captioning area
-        text_value = corresponding_caption if visible and corresponding_caption else "[trigger]" if visible and concept_sentence else None
-        updates.append(gr.update(value=text_value, visible=visible))
-
-    # Update for the sample caption area
-    updates.append(gr.update(visible=True))
-    # Update prompt samples
-    updates.append(gr.update(placeholder=f'A portrait of person in a bustling cafe {concept_sentence}', value=f'A person in a bustling cafe {concept_sentence}'))
-    updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}"))
-    updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall"))
-    updates.append(gr.update(visible=True))
-    return updates
-
-def hide_captioning():
-    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
-def create_dataset(*inputs):
-    print("Creating dataset")
-    images = inputs[0]
-    destination_folder = str(f"datasets")
-    if not os.path.exists(destination_folder):
-        os.makedirs(destination_folder)
-
-    jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
-    with open(jsonl_file_path, "a") as jsonl_file:
-        for index, image in enumerate(images):
-            new_image_path = shutil.copy(image, destination_folder)
-
-            original_caption = inputs[index + 1]
-            file_name = os.path.basename(new_image_path)
-
-            data = {"file_name": file_name, "prompt": original_caption}
-
-            jsonl_file.write(json.dumps(data) + "\n")
-
-    return destination_folder
-
-
-def run_captioning(images, concept_sentence, *captions):
-    #Load internally to not consume resources for training
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    torch_dtype = torch.float16
-    model = AutoModelForCausalLM.from_pretrained(
-        "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
-    ).to(device)
-    processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
-
-    captions = list(captions)
-    for i, image_path in enumerate(images):
-        print(captions[i])
-        if isinstance(image_path, str):  # If image is a file path
-            image = Image.open(image_path).convert("RGB")
-
-        prompt = "<DETAILED_CAPTION>"
-        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
-
-        generated_ids = model.generate(
-            input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
-        )
-
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-        parsed_answer = processor.post_process_generation(
-            generated_text, task=prompt, image_size=(image.width, image.height)
-        )
-        caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
-        if concept_sentence:
-            caption_text = f"{caption_text} [trigger]"
-        captions[i] = caption_text
-
-        yield captions
-    model.to("cpu")
-    del model
-    del processor
-
-def recursive_update(d, u):
-    for k, v in u.items():
-        if isinstance(v, dict) and v:
-            d[k] = recursive_update(d.get(k, {}), v)
-        else:
-            d[k] = v
-    return d
-
-
-def get_duration(  lora_name,
-    concept_sentence,
-    steps,
-    lr,
-    rank,
-    model_to_train,
-    low_vram,
-    dataset_folder,
-    sample_1,
-    sample_2,
-    sample_3,
-    use_more_advanced_options,
-    more_advanced_options,):
-    return total_second_length * 60
-
-
-def start_training(
-    lora_name,
-    concept_sentence,
-    steps,
-    lr,
-    rank,
-    model_to_train,
-    low_vram,
-    dataset_folder,
-    sample_1,
-    sample_2,
-    sample_3,
-    use_more_advanced_options,
-    more_advanced_options,
-):
-    push_to_hub = True
-    print("flux ttain invoke ====================")
-    if not lora_name:
-        raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
-    try:
-        if whoami()["auth"]["accessToken"]["role"] == "write" or "repo.write" in whoami()["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
-            gr.Info(f"Starting training locally {whoami()['name']}. Your LoRA will be available locally and in Hugging Face after it finishes.")
-        else:
-            push_to_hub = False
-            gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
-    except:
-        push_to_hub = False
-        gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
-
-    print("Started training")
-    slugged_lora_name = slugify(lora_name)
-
-    # Load the default config
-    with open("config/examples/train_lora_flux_24gb.yaml", "r") as f:
-        config = yaml.safe_load(f)
-
-    # Update the config with user inputs
-    config["config"]["name"] = slugged_lora_name
-    config["config"]["process"][0]["model"]["low_vram"] = low_vram
-    config["config"]["process"][0]["train"]["skip_first_sample"] = True
-    config["config"]["process"][0]["train"]["steps"] = int(steps)
-    config["config"]["process"][0]["train"]["lr"] = float(lr)
-    config["config"]["process"][0]["network"]["linear"] = int(rank)
-    config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
-    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
-    config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
-    if(push_to_hub):
-        try:
-            username = whoami()["name"]
-        except:
-            raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
-        config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
-        config["config"]["process"][0]["save"]["hf_private"] = True
-    if concept_sentence:
-        config["config"]["process"][0]["trigger_word"] = concept_sentence
-
-    if sample_1 or sample_2 or sample_3:
-        config["config"]["process"][0]["train"]["disable_sampling"] = False
-        config["config"]["process"][0]["sample"]["sample_every"] = steps
-        config["config"]["process"][0]["sample"]["sample_steps"] = 28
-        config["config"]["process"][0]["sample"]["prompts"] = []
-        if sample_1:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
-        if sample_2:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
-        if sample_3:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
-    else:
-        config["config"]["process"][0]["train"]["disable_sampling"] = True
-    if(model_to_train == "schnell"):
-        config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
-        config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
-        config["config"]["process"][0]["sample"]["sample_steps"] = 4
-    if(use_more_advanced_options):
-        more_advanced_options_dict = yaml.safe_load(more_advanced_options)
-        config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
-        print(config)
-
-    # Save the updated config
-    # generate a random name for the config
-    random_config_name = str(uuid.uuid4())
-    os.makedirs("tmp", exist_ok=True)
-    config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
-    with open(config_path, "w") as f:
-        yaml.dump(config, f)
-
-    # run the job locally
-    job = get_job(config_path)
-    job.run()
-    job.cleanup()
-
-    return f"Training completed successfully. Model saved as {slugged_lora_name}"
-
-config_yaml = '''
-device: cuda:0
-model:
-  is_flux: true
-  quantize: true
-network:
-  linear: 16 #it will overcome the 'rank' parameter
-  linear_alpha: 16 #you can have an alpha different than the ranking if you'd like
-  type: lora
-sample:
-  guidance_scale: 3.5
-  height: 1024
-  neg: '' #doesn't work for FLUX
-  sample_every: 1000
-  sample_steps: 28
-  sampler: flowmatch
-  seed: 42
-  walk_seed: true
-  width: 1024
-save:
-  dtype: float16
-  hf_private: true
-  max_step_saves_to_keep: 4
-  push_to_hub: true
-  save_every: 10000
-train:
-  batch_size: 1
-  dtype: bf16
-  ema_config:
-    ema_decay: 0.99
-    use_ema: true
-  gradient_accumulation_steps: 1
-  gradient_checkpointing: true
-  noise_scheduler: flowmatch
-  optimizer: adamw8bit #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
-  train_text_encoder: false #probably doesn't work for flux
-  train_unet: true
-'''
-
-theme = gr.themes.Monochrome(
-    text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
-    font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
-)
-css = """
-h1{font-size: 2em}
-h3{margin-top: 0}
-#component-1{text-align:center}
-.main_ui_logged_out{opacity: 0.3; pointer-events: none}
-.tabitem{border: 0px}
-.group_padding{padding: .55em}
-"""
-with gr.Blocks(theme=theme, css=css) as demo:
-    gr.Markdown(
-        """# LoRA Ease for FLUX 🧞♂️
-### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)"""
-    )
-    with gr.Column() as main_ui:
-        with gr.Row():
-            lora_name = gr.Textbox(
-                label="The name of your LoRA",
-                info="This has to be a unique name",
-                placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
-            )
-            concept_sentence = gr.Textbox(
-                label="Trigger word/sentence",
-                info="Trigger word or sentence to be used",
-                placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
-                interactive=True,
-            )
-        with gr.Group(visible=True) as image_upload:
-            with gr.Row():
-                images = gr.File(
-                    file_types=["image", ".txt"],
-                    label="Upload your images",
-                    file_count="multiple",
-                    interactive=True,
-                    visible=True,
-                    scale=1,
-                )
-                with gr.Column(scale=3, visible=False) as captioning_area:
-                    with gr.Column():
-                        gr.Markdown(
-                            """# Custom captioning
-<p style="margin-top:0">You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.</p>
-""", elem_classes="group_padding")
-                        do_captioning = gr.Button("Add AI captions with Florence-2")
-                        output_components = [captioning_area]
-                        caption_list = []
-                        for i in range(1, MAX_IMAGES + 1):
-                            locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
-                            with locals()[f"captioning_row_{i}"]:
-                                locals()[f"image_{i}"] = gr.Image(
-                                    type="filepath",
-                                    width=111,
-                                    height=111,
-                                    min_width=111,
-                                    interactive=False,
-                                    scale=2,
-                                    show_label=False,
-                                    show_share_button=False,
-                                    show_download_button=False,
-                                )
-                                locals()[f"caption_{i}"] = gr.Textbox(
-                                    label=f"Caption {i}", scale=15, interactive=True
-                                )
-
-                            output_components.append(locals()[f"captioning_row_{i}"])
-                            output_components.append(locals()[f"image_{i}"])
-                            output_components.append(locals()[f"caption_{i}"])
-                            caption_list.append(locals()[f"caption_{i}"])
-
-        with gr.Accordion("Advanced options", open=False):
-            steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
-            lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
-            rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
-            model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train")
-            low_vram = gr.Checkbox(label="Low VRAM", value=True)
-            with gr.Accordion("Even more advanced options", open=False):
-                use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False)
-                more_advanced_options = gr.Code(config_yaml, language="yaml")
-
-        with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
-            gr.Markdown(
-                "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)"
-            )
-            sample_1 = gr.Textbox(label="Test prompt 1")
-            sample_2 = gr.Textbox(label="Test prompt 2")
-            sample_3 = gr.Textbox(label="Test prompt 3")
-
-        output_components.append(sample)
-        output_components.append(sample_1)
-        output_components.append(sample_2)
-        output_components.append(sample_3)
-        start = gr.Button("Start training", visible=False)
-        output_components.append(start)
-        progress_area = gr.Markdown("")
-
-    dataset_folder = gr.State()
-
-    images.upload(
-        load_captioning,
-        inputs=[images, concept_sentence],
-        outputs=output_components
-    )
-
-    images.delete(
-        load_captioning,
-        inputs=[images, concept_sentence],
-        outputs=output_components
-    )
-
-    images.clear(
-        hide_captioning,
-        outputs=[captioning_area, sample, start]
-    )
-
-    start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
-        fn=start_training,
-        inputs=[
-            lora_name,
-            concept_sentence,
-            steps,
-            lr,
-            rank,
-            model_to_train,
-            low_vram,
-            dataset_folder,
-            sample_1,
-            sample_2,
-            sample_3,
-            use_more_advanced_options,
-            more_advanced_options
-        ],
-        outputs=progress_area,
-    )
-
-    do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
-
+# script.py
+
+def greet(name):
+    print(f"Hello, {name}!")
+
+if __name__ == "__main__":
+    import sys
+    name = sys.argv[1] if len(sys.argv) > 1 else "World"
+    greet(name)
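The replacement file is now a minimal command-line script. A small usage sketch follows, assuming the code is saved under the committed filename lora_trainer.py (its own header comment calls it script.py); the module name used below is an assumption for illustration only:

# Hypothetical usage of the committed script, assuming it is saved as lora_trainer.py.
# From a shell:
#   $ python lora_trainer.py Ada     ->  Hello, Ada!
#   $ python lora_trainer.py         ->  Hello, World!
# It can also be imported, since the CLI part sits behind the __main__ guard:
import lora_trainer

lora_trainer.greet("Ada")  # prints "Hello, Ada!"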

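For context on the removed trainer: its "Even more advanced options" YAML was merged into the base ai-toolkit config through the recursive_update helper deleted above. A minimal, self-contained sketch of that merge behavior, using illustrative dictionaries rather than the real config:

import yaml

def recursive_update(d, u):
    # Recursively merge dict u into dict d; nested dicts merge, other values overwrite.
    for k, v in u.items():
        if isinstance(v, dict) and v:
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d

# Illustrative stand-ins for the base config and the user-supplied YAML overrides.
base = {"train": {"steps": 1000, "lr": 4e-4}, "network": {"linear": 16}}
overrides = yaml.safe_load("train:\n  steps: 2000\nsample:\n  seed: 42\n")

merged = recursive_update(base, overrides)
print(merged)
# {'train': {'steps': 2000, 'lr': 0.0004}, 'network': {'linear': 16}, 'sample': {'seed': 42}}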