import os
import time

import torch
import yaml
import lightning as L
from torch.utils.data import DataLoader
from datasets import load_dataset

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
from .model import OminiModel
from .callbacks import TrainingCallback


def get_rank():
    # LOCAL_RANK is set by distributed launchers such as torchrun;
    # default to rank 0 for single-process runs.
    try:
        rank = int(os.environ.get("LOCAL_RANK", 0))
    except ValueError:
        rank = 0
    return rank


def get_config():
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


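# For reference, a minimal sketch of the YAML that get_config() expects. The
# keys below are the ones this script actually reads; the values are
# illustrative placeholders, not a tested configuration.
#
#   flux_path: black-forest-labs/FLUX.1-dev  # any FLUX pipeline id (assumed)
#   dtype: bfloat16
#   model: {}            # forwarded to OminiModel as model_config
#   train:
#     batch_size: 1
#     dataloader_workers: 4
#     accumulate_grad_batches: 1
#     gradient_checkpointing: false
#     gradient_clip_val: 0.5
#     condition_type: subject
#     dataset:
#       type: subject    # one of: subject, img, cartoon
#       condition_size: 512
#       target_size: 512
#       image_size: 512
#       padding: 8
#       drop_text_prob: 0.1
#       drop_image_prob: 0.1
#     lora_config: {}    # schema consumed by OminiModel
#     optimizer: {}      # schema consumed by OminiModel
#     wandb:
#       project: my-project
#     save_path: ./output

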
def init_wandb(wandb_config, run_name):
    import wandb

    try:
        assert (
            os.environ.get("WANDB_API_KEY") is not None
        ), "WANDB_API_KEY is not set"
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize wandb:", e)


def main():
    # Initialize
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize wandb on the main process only
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "subject":
        dataset = load_dataset("Yuanshi/Subjects200K")

        # Keep only samples whose quality-assessment scores are all >= 5
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # Filter dataset (the result is cached on disk for later runs)
        os.makedirs("./cache/dataset", exist_ok=True)
        data_valid = dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        dataset = Subject200KDataset(
            data_valid,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    elif training_config["dataset"]["type"] == "img":
        # Load the text-to-image-2M dataset from webdataset shards
        dataset = load_dataset(
            "webdataset",
            data_files={"train": training_config["dataset"]["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
            position_scale=training_config["dataset"].get("position_scale", 1.0),
        )
    elif training_config["dataset"]["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    else:
        raise NotImplementedError(
            f"Unknown dataset type: {training_config['dataset']['type']}"
        )

| print("Dataset length:", len(dataset)) | |
| train_loader = DataLoader( | |
| dataset, | |
| batch_size=training_config["batch_size"], | |
| shuffle=True, | |
| num_workers=training_config["dataloader_workers"], | |
| ) | |
    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
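
    # lora_config and optimizer are passed through to OminiModel; their exact
    # schema lives in .model. If OminiModel forwards lora_config to
    # peft.LoraConfig (an assumption, not verified here), an entry could look
    # like: {"r": 4, "lora_alpha": 4, "target_modules": "to_k|to_q|to_v"}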

    # Callbacks for logging and saving checkpoints (main process only)
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),  # -1 = no step limit
        max_epochs=training_config.get("max_epochs", -1),  # -1 = no epoch limit
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
    trainer.training_config = training_config

    # Save config on the main process
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
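
# Example launch (illustrative; the relative imports above mean this file must
# be run as a module, so adjust the module path to your repo layout):
#   XFL_CONFIG=./train/config/subject.yaml python -m src.train.train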