Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import shutil | |
| import zipfile | |
| import tensorflow as tf | |
| import pandas as pd | |
| import pathlib | |
| import PIL.Image | |
| import os | |
| import subprocess | |
def pad_image(image: PIL.Image.Image, *, fill: str | int | tuple = "black") -> PIL.Image.Image:
    """Pad *image* to a square canvas, centring the original content.

    The shorter side is extended on both sides; when the difference is odd,
    the extra pixel ends up on the bottom/right edge. Square images are
    returned unchanged (the same object, not a copy).

    Args:
        image: Source image.
        fill: Padding colour, in any form ``PIL.Image.new`` accepts for the
            canvas mode. The default ``"black"`` is a colour name, which Pillow
            converts to the right value for the mode (e.g. ``0`` for "L",
            ``(0, 0, 0)`` for "RGB"). A fixed RGB tuple would be rejected for
            single-band modes.

    Returns:
        A square image whose side is ``max(width, height)``. Palette images
        ("P"/"PA") are expanded to "RGBA" before padding. Pasting raw palette
        indices onto a new canvas would pair them with the canvas's palette
        instead of the source's.
    """
    width, height = image.size
    if width == height:
        return image
    if image.mode in ("P", "PA"):
        image = image.convert("RGBA")
    side = max(width, height)
    canvas = PIL.Image.new(image.mode, (side, side), fill)
    # One of the two offsets is always 0: only the shorter axis is shifted.
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
| class ModelTrainer: | |
| def __init__(self): | |
| self.training_pictures = [] | |
| self.training_model = None | |
| def unzip_file(self, zip_file_path): | |
| with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: | |
| extracted_path = zip_file_path.replace('.zip', '') | |
| zip_ref.extractall(extracted_path) | |
| file_names = zip_ref.namelist() | |
| for file_name in file_names: | |
| if file_name.endswith(('.jpeg', '.jpg', '.png')): | |
| self.training_pictures.append(f'{extracted_path}/{file_name}') | |
| def train(self, pretrained_model_name_or_path: str, instance_images: list | None): | |
| output_model_name = 'a-xyz-model' | |
| resolution = 512 | |
| repo_dir = pathlib.Path(__file__).parent | |
| subdirs = ['train-instance', 'train-class', 'experiments'] | |
| dir_paths = [] | |
| for subdir in subdirs: | |
| dir_path = repo_dir / subdir / output_model_name | |
| dir_paths.append(dir_path) | |
| shutil.rmtree(dir_path, ignore_errors=True) | |
| os.makedirs(dir_path, exist_ok=True) | |
| instance_data_dir, class_data_dir, output_dir = dir_paths | |
| for i, temp_path in enumerate(instance_images): | |
| image = PIL.Image.open(temp_path.name) | |
| image = pad_image(image) | |
| image = image.resize((resolution, resolution)) | |
| image = image.convert('RGB') | |
| out_path = instance_data_dir / f'{i:03d}.jpg' | |
| image.save(out_path, format='JPEG', quality=100) | |
| command = [ | |
| 'python', '-u', | |
| 'train_dreambooth_cloneofsimo_lora.py', | |
| '--pretrained_model_name_or_path', pretrained_model_name_or_path, | |
| '--instance_data_dir', instance_data_dir, | |
| '--class_data_dir', class_data_dir, | |
| '--resolution', '768', | |
| '--output_dir', output_dir, | |
| '--instance_prompt', 'a photo of a pwsm dog', | |
| '--with_prior_preservation', | |
| '--class_prompt', 'a dog', | |
| '--prior_loss_weight', '1.0', | |
| '--num_class_images', '100', | |
| '--learning_rate', '0.0004', | |
| '--train_batch_size', '1', | |
| '--sample_batch_size', '1', | |
| '--max_train_steps', '400', | |
| '--gradient_accumulation_steps', '1', | |
| '--gradient_checkpointing', | |
| '--train_text_encoder', | |
| '--learning_rate_text', '5e-6', | |
| '--save_steps', '100', | |
| '--seed', '1337', | |
| '--lr_scheduler', 'constant', | |
| '--lr_warmup_steps', '0' | |
| ] | |
| result = subprocess.run(command) | |
| return result | |
| def generate_picture(self, row): | |
| num_of_training_steps, learning_rate, checkpoint_steps, abc = row | |
| return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}' | |
| def generate_pictures(self, csv_input): | |
| csv = pd.read_csv(csv_input.name) | |
| result = [] | |
| for index, row in csv.iterrows(): | |
| result.append(self.generate_picture(row)) | |
| return "\n".join(str(item) for item in result) | |
loader = ModelTrainer()


def _train_and_report(pretrained_model_name_or_path, instance_images):
    """Run ``loader.train`` and turn its outcome into Markdown text.

    ``ModelTrainer.train`` returns a ``subprocess.CompletedProcess``. A
    ``gr.Markdown`` output expects text, so the exit status is summarised as a
    string here. A ``ValueError`` raised by ``train`` (e.g. missing inputs) is
    shown as a message.
    """
    try:
        result = loader.train(pretrained_model_name_or_path, instance_images)
    except ValueError as err:
        return f'Training not started: {err}'
    if result.returncode == 0:
        return 'Training finished successfully.'
    return f'Training script exited with code {result.returncode}.'


# gr.Group is used for visual grouping because gr.Box is not available in
# Gradio 4; gr.Group exists in both Gradio 3 and 4.
with gr.Blocks() as demo:
    with gr.Group():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(
            lines=1,
            label='pretrained_model_name_or_path',
            placeholder='stabilityai/stable-diffusion-2-1',
        )
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(
            fn=_train_and_report,
            inputs=[pretrained_model_name_or_path, instance_images],
            outputs=[output_message],
        )
    with gr.Group():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(
            fn=loader.generate_pictures,
            inputs=[csv_input],
            outputs=[output_message2],
        )

if __name__ == '__main__':
    demo.launch()