#!/usr/bin/env python3

import os
import re
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Iterable, List, Optional, Tuple

import gradio as gr
import spaces  # noqa: F401  (ZeroGPU support on Hugging Face Spaces)

# Local modules
from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR

# Defaults matching train_QIE.sh expectations
DEFAULT_DATA_ROOT = "/data"
DEFAULT_IMAGE_FOLDER = "image"
DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA"
DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml"
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR  # "/Qwen-Image_models"
WORKSPACE_AUTO_DIR = "/auto"

# musubi-tuner settings
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
    "MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
)

TRAINING_DIR = Path(__file__).resolve().parent

# Runtime-resolved paths with fallbacks for non-root environments
MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR
MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT
AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR
DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT


def _bash_quote(s: str) -> str:
    """Return a POSIX-safe single-quoted string literal representing s."""
    if s is None:
        return "''"
    return "'" + str(s).replace("'", "'\"'\"'") + "'"


def _ensure_workspace_auto_files() -> None:
    """Ensure the runtime auto dir (default /auto) holds the helper files from this repo.

    Copies training/create_image_caption_json.py and training/dataset_QIE.toml
    into the auto dir so that train_QIE.sh can locate them at run time.
    """
    global AUTO_DIR_RUNTIME
    try:
        os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True)
    except PermissionError:
        home_auto = os.path.join(os.path.expanduser("~"), "auto")
        os.makedirs(home_auto, exist_ok=True)
        AUTO_DIR_RUNTIME = home_auto
    src_py = TRAINING_DIR / "create_image_caption_json.py"
    src_toml = TRAINING_DIR / "dataset_QIE.toml"
    dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py"
    dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"
    try:
        shutil.copy2(src_py, dst_py)
    except Exception:
        pass
    try:
        if src_toml.exists():
            shutil.copy2(src_toml, dst_toml)
    except Exception:
        pass


def _ensure_dir_writable(path: str) -> str:
    try:
        os.makedirs(path, exist_ok=True)
        return path
    except PermissionError:
        home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\")))
        os.makedirs(home_path, exist_ok=True)
        return home_path
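
# Illustrative only: when the directory cannot be created (PermissionError),
# the helper keeps the basename and falls back under the home directory,
# e.g. assuming a home of /home/user:
#   _ensure_dir_writable("/data") -> "/data" if creatable, else "/home/user/data"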


def _ensure_data_root(candidate: Optional[str]) -> str:
    root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT
    try:
        os.makedirs(root, exist_ok=True)
        return root
    except PermissionError:
        home_root = os.path.join(os.path.expanduser("~"), "data")
        os.makedirs(home_root, exist_ok=True)
        return home_root


def _extract_paths(files: Any) -> List[Tuple[str, str]]:
    """Extract a list of (abs_path, orig_basename) from Gradio Files input.

    Supports various gradio return shapes across versions.
    """
    out: List[Tuple[str, str]] = []
    if not files:
        return out
    # Gradio Files often returns a list
    if isinstance(files, (list, tuple)):
        items = files
    else:
        items = [files]
    for item in items:
        p: Optional[str] = None
        orig: Optional[str] = None
        # dict-like
        if isinstance(item, dict):
            p = item.get("path") or item.get("name") or item.get("file")
            orig = item.get("orig_name") or item.get("name")
        else:
            # object with attributes (or a bare path string)
            p = getattr(item, "name", None) or getattr(item, "path", None) or str(item)
            # best-effort original name attribute
            orig = getattr(item, "orig_name", None) or (os.path.basename(p) if p else None)
        if p:
            abs_p = os.path.abspath(p)
            out.append((abs_p, os.path.basename(orig or abs_p)))
    return out
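
# Illustrative only: representative input shapes this helper absorbs (exact
# shapes vary across Gradio versions, so these are examples, not a contract):
#   [{"path": "/tmp/abc.png", "orig_name": "cat.png"}]  -> [("/tmp/abc.png", "cat.png")]
#   [<tempfile wrapper with .name = "/tmp/abc.png">]    -> [("/tmp/abc.png", "abc.png")]
#   "/tmp/abc.png" (bare string)                        -> [("/tmp/abc.png", "abc.png")]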


def _norm_key(filename: str, prefix: str, suffix: str) -> str:
    stem = os.path.splitext(os.path.basename(filename))[0]
    if prefix and stem.startswith(prefix):
        stem = stem[len(prefix):]
    if suffix and stem.endswith(suffix):
        stem = stem[: -len(suffix)]
    return stem
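
# Illustrative only: how keys are derived for target/control name matching:
#   _norm_key("IMG_cat_v2.png", "IMG_", "_v2") -> "cat"
#   _norm_key("cat.png", "", "")               -> "cat"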


def _copy_uploads(
    uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None
) -> List[str]:
    os.makedirs(dest_dir, exist_ok=True)
    used_names: List[str] = []
    for idx, (src, orig) in enumerate(uploads):
        # Determine target stem
        if rename_to and idx < len(rename_to):
            stem = os.path.splitext(rename_to[idx])[0]
        else:
            stem = os.path.splitext(orig)[0]
        dst_name = f"{stem}.png"
        # ensure unique within this batch
        final_name = dst_name
        dup_idx = 1
        while final_name in used_names:
            final_name = f"{stem}_{dup_idx}.png"
            dup_idx += 1
        dst_path = os.path.join(dest_dir, final_name)
        # Convert to PNG during save
        try:
            try:
                from PIL import Image  # type: ignore

                with Image.open(src) as img:
                    img.save(dst_path, format="PNG")
            except Exception:
                # Fallback: copy the file as-is under the .png name
                shutil.copy2(src, dst_path)
        except Exception:
            # Last resort: plain copy without metadata
            shutil.copy(src, dst_path)
        used_names.append(final_name)
    return used_names
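
# Illustrative only: duplicate stems within one upload batch are disambiguated
# with a numeric suffix, and every file is re-saved as PNG, e.g. uploading
# cat.jpg and cat.png in one batch lands as cat.png and cat_1.png.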


def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
    try:
        if not out_dir or not os.path.isdir(out_dir):
            return []
        items: List[Tuple[float, str]] = []
        for root, _, files in os.walk(out_dir):
            for fn in files:
                if fn.lower().endswith('.safetensors'):
                    full = os.path.join(root, fn)
                    try:
                        items.append((os.path.getmtime(full), full))
                    except Exception:
                        pass
        items.sort(reverse=True)
        return [p for _, p in items[:limit]]
    except Exception:
        return []
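
# Illustrative only (hypothetical paths): checkpoints come back newest-first
# by mtime and capped at `limit`, e.g.
#   ["/data/ds/run/lora-000020.safetensors", "/data/ds/run/lora-000010.safetensors"]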


def _files_to_gallery(files: Any) -> List[str]:
    items: List[str] = []
    if not files:
        return items
    seq = files if isinstance(files, (list, tuple)) else [files]
    for f in seq:
        p = None
        if isinstance(f, str):
            p = f
        elif isinstance(f, dict):
            p = f.get("path") or f.get("name")
        else:
            p = getattr(f, "path", None) or getattr(f, "name", None)
        if p:
            items.append(p)
    return items


def _prepare_script(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control_folders: List[Optional[str]],
    models_root: str,
    output_dir_base: Optional[str] = None,
    dataset_config: Optional[str] = None,
    override_max_epochs: Optional[int] = None,
    override_save_every: Optional[int] = None,
    override_run_name: Optional[str] = None,
    target_prefix: Optional[str] = None,
    target_suffix: Optional[str] = None,
    control_prefixes: Optional[List[Optional[str]]] = None,
    control_suffixes: Optional[List[Optional[str]]] = None,
    override_learning_rate: Optional[str] = None,
    override_network_dim: Optional[int] = None,
    override_seed: Optional[int] = None,
) -> Path:
    """Create a temporary copy of train_QIE.sh with injected variables.

    Only variables that must vary per run are replaced. The rest of the script
    remains as-is to preserve behavior.
    """
    src = TRAINING_DIR / "train_QIE.sh"
    txt = src.read_text(encoding="utf-8")
    # Replace core variables
    replacements = {
        r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}",
        r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}",
        r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}",
        r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}",
    }
    if output_dir_base:
        replacements[r"^OUTPUT_DIR_BASE=\".*\""] = (
            f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}"
        )
    if dataset_config:
        replacements[r"^DATASET_CONFIG=\".*\""] = (
            f"DATASET_CONFIG={_bash_quote(dataset_config)}"
        )
    for pat, val in replacements.items():
        # Callable replacement: user-supplied values (e.g. captions) may contain
        # backslashes that re.sub template processing would otherwise mangle.
        txt = re.sub(pat, lambda _m, _v=val: _v, txt, flags=re.MULTILINE)
    # Inject CONTROL_FOLDER_i if provided (uncomment/override or append)
    for i in range(8):
        val = control_folders[i] if i < len(control_folders) else None
        if not val:
            continue
        assign = f"CONTROL_FOLDER_{i}={_bash_quote(val)}"
        # Try to replace the commented placeholder first
        pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
        if re.search(pattern, txt, flags=re.MULTILINE):
            txt = re.sub(pattern, lambda _m, _a=assign: _a, txt, flags=re.MULTILINE)
        else:
            # Append after the IMAGE_FOLDER definition
            txt = re.sub(
                r"^(IMAGE_FOLDER=.*)$",
                lambda m, _a=assign: f"{m.group(1)}\n{_a}",
                txt,
                count=1,
                flags=re.MULTILINE,
            )
    # Point model paths to the selected models_root
    def _replace_model_path(txt: str, key: str, rel: str) -> str:
        return re.sub(
            rf"--{key} \"[^\"]+\"",
            f"--{key} \"{models_root.rstrip('/')}/{rel}\"",
            txt,
        )

    txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors")
    txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
    txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
    # Rewrite the metadata-generation working dir to the runtime auto dir
    txt = re.sub(
        r"^cd\s+/workspace/auto\s*$",
        lambda _m: f"cd {AUTO_DIR_RUNTIME}",
        txt,
        flags=re.MULTILINE,
    )
    # Ensure the musubi-tuner path matches the runtime location. Note: re.escape
    # is for patterns, not replacements, so a callable inserts the path verbatim.
    txt = re.sub(
        r"^cd\s+/musubi-tuner\s*$",
        lambda _m: f"cd {MUSUBI_TUNER_DIR_RUNTIME}",
        txt,
        flags=re.MULTILINE,
    )
    # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
    # Run the training module directly in-process so the GPU stays attached
    # to the same Python request context.
    txt = re.sub(
        r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
        r"python src/musubi_tuner/qwen_image_train_network.py",
        txt,
        flags=re.MULTILINE,
    )
    # Optionally override epochs and save frequency for ZeroGPU time slicing
    if override_max_epochs is not None and override_max_epochs > 0:
        txt = re.sub(r"--max_train_epochs\s+\d+",
                     f"--max_train_epochs {override_max_epochs}", txt)
    if override_save_every is not None and override_save_every > 0:
        txt = re.sub(r"--save_every_n_epochs\s+\d+",
                     f"--save_every_n_epochs {override_save_every}", txt)
    if override_run_name:
        txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
    # Inject prefix/suffix flags for metadata creation
    extra_lines: List[str] = []
    if target_prefix:
        extra_lines.append(f" --main_prefix {_bash_quote(target_prefix)} \\")
    if target_suffix:
        extra_lines.append(f" --main_suffix {_bash_quote(target_suffix)} \\")
    for i in range(8):
        pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
        suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
        if pre:
            extra_lines.append(f" --control_prefix_{i} {_bash_quote(pre)} \\")
        if suf:
            extra_lines.append(f" --control_suffix_{i} {_bash_quote(suf)} \\")
    if extra_lines:
        extra_block = "\n".join(extra_lines)
        # Callable replacement: extra_block ends each line with a bash "\" line
        # continuation, which re.sub template escaping would otherwise mangle.
        txt = re.sub(
            r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"',
            lambda m: f'{extra_block}\n{m.group(1)}"${{CONTROL_ARGS[@]}}"',
            txt,
            flags=re.MULTILINE,
        )
    # Override CLI hyperparameters if provided
    if override_learning_rate:
        txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt)
    if override_network_dim is not None:
        txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt)
    if override_seed is not None:
        txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
    # Prefer overriding variable definitions at the top of the script (safer than CLI regex)
    def _set_var(name: str, value: str) -> None:
        nonlocal txt
        pattern = rf"(?m)^\s*{name}\s*=.*$"
        replacement = f'{name}="{value}"' if not str(value).isdigit() else f"{name}={value}"
        if re.search(pattern, txt):
            txt = re.sub(pattern, replacement, txt)
        else:
            txt = f"{replacement}\n" + txt

    if override_learning_rate:
        _set_var('LEARNING_RATE', override_learning_rate)
    if override_network_dim is not None:
        _set_var('NETWORK_DIM', str(override_network_dim))
    if override_seed is not None:
        _set_var('SEED', str(override_seed))
    if override_max_epochs is not None and override_max_epochs > 0:
        _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
    if override_save_every is not None and override_save_every > 0:
        _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
    # Write to a temp file alongside this repo for easier inspection
    run_dir = TRAINING_DIR / ".gradio_runs"
    run_dir.mkdir(parents=True, exist_ok=True)
    tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh"
    tmp.write_text(txt, encoding="utf-8", newline="\n")
    try:
        os.chmod(tmp, 0o755)
    except Exception:
        pass
    return tmp
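
# Illustrative only (hypothetical script contents): given a train_QIE.sh line
#   DATASET_NAME="placeholder"
# a call such as _prepare_script(dataset_name="dataset_1700000000", ...) rewrites it to
#   DATASET_NAME='dataset_1700000000'
# and writes the patched copy under .gradio_runs/ next to this file for inspection.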


def _pick_shell() -> str:
    for sh in ("bash", "sh"):
        if shutil.which(sh):
            return sh
    raise RuntimeError("No POSIX shell found. Please install bash or sh.")


def _is_git_repo(path: str) -> bool:
    try:
        out = subprocess.run(
            ["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
            capture_output=True,
            text=True,
            check=False,
        )
        return out.returncode == 0 and out.stdout.strip() == "true"
    except Exception:
        return False


def _startup_clone_musubi_tuner() -> None:
    global MUSUBI_TUNER_DIR_RUNTIME
    target = MUSUBI_TUNER_DIR_RUNTIME
    repo = DEFAULT_MUSUBI_TUNER_REPO
    parent = os.path.dirname(target.rstrip("/\\")) or "/"
    try:
        os.makedirs(parent, exist_ok=True)
    except PermissionError:
        # Fall back to the home directory
        target = os.path.join(os.path.expanduser("~"), "musubi-tuner")
        MUSUBI_TUNER_DIR_RUNTIME = target
        os.makedirs(os.path.dirname(target), exist_ok=True)
    except Exception:
        pass
    if os.path.isdir(target) and _is_git_repo(target):
        print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
        try:
            subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
            subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
        except Exception as e:
            print(f"[QIE] git pull failed: {e}")
        return
    if os.path.exists(target) and not _is_git_repo(target):
        print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
        return
    print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
    try:
        subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
        print("[QIE] Clone completed.")
    except subprocess.CalledProcessError as e:
        print(f"[QIE] Clone failed at {target}: {e}")
        # Last-chance fallback into home
        if not target.startswith(os.path.expanduser("~")):
            fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner")
            print(f"[QIE] Retrying clone into {fallback}...")
            try:
                subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True)
                MUSUBI_TUNER_DIR_RUNTIME = fallback
                print("[QIE] Clone completed in fallback.")
            except Exception as e2:
                print(f"[QIE] Clone failed in fallback as well: {e2}")


def _run_pip(args: List[str], cwd: Optional[str] = None) -> None:
    cmd = [sys.executable, "-m", "pip"] + args
    try:
        print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})")
        subprocess.run(cmd, check=True, cwd=cwd)
    except subprocess.CalledProcessError as e:
        print(f"[QIE] pip failed: {e}")


def _startup_install_musubi_deps() -> None:
    repo_dir = MUSUBI_TUNER_DIR_RUNTIME
    if not os.path.isdir(repo_dir):
        print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}")
        return
    # Upgrade basic build tooling (best-effort)
    try:
        _run_pip(["install", "-U", "pip", "setuptools", "wheel"])
    except Exception:
        pass
    # Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128
    extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip()
    editable_spec = "." if not extra else f".[{extra}]"
    # Install musubi-tuner in editable mode to expose entrypoints and deps
    try:
        _run_pip(["install", "-e", editable_spec], cwd=repo_dir)
    except Exception:
        # Fallback: plain install without editable mode
        try:
            _run_pip(["install", editable_spec], cwd=repo_dir)
        except Exception:
            print("[QIE] WARN: musubi-tuner installation failed. Continuing.")


def run_training(
    output_name: str,
    caption: str,
    image_uploads: Any,
    target_prefix: str,
    target_suffix: str,
    control0_uploads: Any,
    ctrl0_prefix: str,
    ctrl0_suffix: str,
    control1_uploads: Any,
    ctrl1_prefix: str,
    ctrl1_suffix: str,
    control2_uploads: Any,
    ctrl2_prefix: str,
    ctrl2_suffix: str,
    control3_uploads: Any,
    ctrl3_prefix: str,
    ctrl3_suffix: str,
    control4_uploads: Any,
    ctrl4_prefix: str,
    ctrl4_suffix: str,
    control5_uploads: Any,
    ctrl5_prefix: str,
    ctrl5_suffix: str,
    control6_uploads: Any,
    ctrl6_prefix: str,
    ctrl6_suffix: str,
    control7_uploads: Any,
    ctrl7_prefix: str,
    ctrl7_suffix: str,
    learning_rate: str,
    network_dim: int,
    seed: int,
    max_epochs: int,
    save_every: int,
) -> Iterable[Tuple[str, List[str], Optional[str]]]:
    # Basic validation
    log_buf = ""
    ckpts: List[str] = []
    if not output_name.strip():
        log_buf += "[ERROR] OUTPUT NAME is required.\n"
        yield (log_buf, ckpts, None)
        return
    if not caption.strip():
        log_buf += "[ERROR] CAPTION is required.\n"
        yield (log_buf, ckpts, None)
        return
    # Ensure the auto dir holds the helper files expected by the script
    _ensure_workspace_auto_files()
    # Resolve the data root and create dataset directories (auto-decided)
    global DATA_ROOT_RUNTIME
    DATA_ROOT_RUNTIME = _ensure_data_root(None)
    # Auto-generate a unique dataset directory name (time imported at module level)
    ds_name = f"dataset_{int(time.time())}"
    ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
    img_folder_name = DEFAULT_IMAGE_FOLDER
    img_dir = os.path.join(ds_dir, img_folder_name)
    os.makedirs(img_dir, exist_ok=True)
    # Ingest uploads into dataset folders
    base_files = _extract_paths(image_uploads)
    if not base_files:
        log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
        yield (log_buf, ckpts, None)
        return
    base_filenames = _copy_uploads(base_files, img_dir)
    log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
    yield (log_buf, ckpts, None)
    # Prepare control sets
    control_upload_sets = [
        _extract_paths(control0_uploads),
        _extract_paths(control1_uploads),
        _extract_paths(control2_uploads),
        _extract_paths(control3_uploads),
        _extract_paths(control4_uploads),
        _extract_paths(control5_uploads),
        _extract_paths(control6_uploads),
        _extract_paths(control7_uploads),
    ]
    # Require control_0; the others are optional
    if not control_upload_sets[0]:
        log_buf += "[ERROR] control_0 images are required.\n"
        yield (log_buf, ckpts, None)
        return
    control_dirs: List[Optional[str]] = []
    for i, uploads in enumerate(control_upload_sets):
        if not uploads:
            control_dirs.append(None)
            continue
        folder_name = f"control_{i}"
        cdir = os.path.join(ds_dir, folder_name)
        os.makedirs(cdir, exist_ok=True)
        # Simply copy; name matching is handled by create_image_caption_json.py
        _copy_uploads(uploads, cdir)
        control_dirs.append(folder_name)
        log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
        yield (log_buf, ckpts, None)
    # metadata.jsonl is generated by create_image_caption_json.py inside train_QIE.sh.
    # control_dirs already holds one entry (folder name or None) per control slot.
    control_folders = control_dirs
    # Decide the dataset_config path, falling back to the runtime auto dir
    ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
    # Resolve models_root and set output_dir_base to the unique dataset dir
    models_root = MODELS_ROOT_RUNTIME
    out_base = ds_dir
    try:
        os.makedirs(out_base, exist_ok=True)
    except Exception:
        pass
    # Prepare the script with user parameters
    tmp_script = _prepare_script(
        dataset_name=ds_name,
        caption=caption,
        data_root=DATA_ROOT_RUNTIME,
        image_folder=img_folder_name,
        control_folders=control_folders,
        models_root=models_root,
        output_dir_base=out_base,
        dataset_config=ds_conf,
        override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
        override_save_every=save_every if save_every and save_every > 0 else None,
        override_run_name=output_name.strip(),
        target_prefix=(target_prefix or ""),
        target_suffix=(target_suffix or ""),
        control_prefixes=[
            ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix,
            ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix,
        ],
        control_suffixes=[
            ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix,
            ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix,
        ],
        override_learning_rate=(learning_rate or None),
        override_network_dim=int(network_dim) if network_dim is not None else None,
        override_seed=int(seed) if seed is not None else None,
    )
    shell = _pick_shell()
    log_buf += f"[QIE] Using shell: {shell}\n"
    log_buf += f"[QIE] Running script: {tmp_script}\n"
    out_dir = os.path.join(out_base, output_name.strip())
    ckpts = _list_checkpoints(out_dir)
    yield (log_buf, ckpts, None)
    # Run and stream output
    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )
    try:
        assert proc.stdout is not None
        i = 0
        for line in proc.stdout:
            log_buf += line
            i += 1
            if i % 30 == 0:
                ckpts = _list_checkpoints(out_dir)
            yield (log_buf, ckpts, None)
    finally:
        code = proc.wait()
        # Try to locate the latest LoRA file for download
        try:
            ckpts = _list_checkpoints(out_dir)
        except Exception:
            pass
        lora_path = ckpts[0] if ckpts else None
        log_buf += f"[QIE] Exit code: {code}\n"
        yield (log_buf, ckpts, lora_path)


def build_ui() -> gr.Blocks:
    with gr.Blocks(title="Qwen-Image-Edit: Trainer") as demo:
        gr.Markdown("""
# Qwen-Image-Edit Trainer
Upload the images to train on, optionally specify any shared leading/trailing text (prefix/suffix)
in their file names, and the app builds the dataset and starts training automatically. No complicated setup is required.
How images are matched by file name
- prefix/suffix: shared leading/trailing text stripped from image file names so that matching files can be identified (e.g., IMG_ or _v2)
- First, each target image's key is its file name (without extension) with the specified Target prefix/suffix removed.
- Each control set is matched additively: the expected name is control_prefix_i + key + control_suffix_i + ".png".
- Example: target IMG_cat_v2.png with Target prefix IMG_ and suffix _v2 yields key cat; with control_0 prefix C0_, the expected control file is C0_cat.png.
- Uploaded images are converted to .png on save (the original base file name is preserved).
- Control 0 is required; Control 1-7 are optional. If a control set contains exactly one image, it is applied to every target image.
""")
        with gr.Row():
            output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
            caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
        # Training options near OUTPUT NAME
        with gr.Row():
            lr_input = gr.Textbox(label="Learning rate", value="1e-3")
            dim_input = gr.Number(label="Network dim", value=4, precision=0)
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
            save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
        with gr.Row():
            with gr.Column(scale=3):
                images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath")
                main_gallery = gr.Gallery(label="Target preview", columns=4, height=200)
            with gr.Column(scale=1):
                main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_")
                main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2")
        # control_0 is required and shown outside the accordion
        with gr.Row():
            with gr.Column(scale=3):
                ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath")
                ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=200)
            with gr.Column(scale=1):
                ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_")
                ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask")
        # Optional controls start from 1; the accordion is closed by default
        with gr.Accordion("Optional control images (control_1..control_7)", open=False):
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath")
                    ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="")
                    ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath")
                    ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="")
                    ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath")
                    ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="")
                    ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath")
                    ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="")
                    ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath")
                    ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="")
                    ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath")
                    ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="")
                    ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="")
            with gr.Row():
                with gr.Column(scale=3):
                    ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath")
                    ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=200)
                with gr.Column(scale=1):
                    ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="")
                    ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="")
        # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.
        run_btn = gr.Button("Start Training", variant="primary")
        logs = gr.Textbox(label="Logs", lines=20)
        ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
        # Download slot for the newest LoRA; run_training yields (logs, checkpoints, lora_path)
        lora_file = gr.File(label="Latest LoRA (download)", interactive=False)
        # max_epochs/save_every moved above, next to OUTPUT NAME
        # Wire previews
        images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
        ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
        ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
        ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
        ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery)
        ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery)
        ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery)
        ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery)
        ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery)
        run_btn.click(
            fn=run_training,
            inputs=[
                output_name, caption, images_input, main_prefix, main_suffix,
                ctrl0_files, ctrl0_prefix, ctrl0_suffix,
                ctrl1_files, ctrl1_prefix, ctrl1_suffix,
                ctrl2_files, ctrl2_prefix, ctrl2_suffix,
                ctrl3_files, ctrl3_prefix, ctrl3_suffix,
                ctrl4_files, ctrl4_prefix, ctrl4_suffix,
                ctrl5_files, ctrl5_prefix, ctrl5_suffix,
                ctrl6_files, ctrl6_prefix, ctrl6_suffix,
                ctrl7_files, ctrl7_prefix, ctrl7_suffix,
                lr_input, dim_input, seed_input, max_epochs, save_every,
            ],
            outputs=[logs, ckpt_files, lora_file],
        )
    return demo


def _startup_download_models() -> None:
    global MODELS_ROOT_RUNTIME
    # Pick a writable models directory
    candidate = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT)
    try:
        os.makedirs(candidate, exist_ok=True)
        MODELS_ROOT_RUNTIME = candidate
    except PermissionError:
        MODELS_ROOT_RUNTIME = os.path.join(os.path.expanduser("~"), "Qwen-Image_models")
        os.makedirs(MODELS_ROOT_RUNTIME, exist_ok=True)
    print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}")
    try:
        download_all_models(MODELS_ROOT_RUNTIME)
    except Exception as e:
        print(f"[QIE] Model download failed: {e}")


if __name__ == "__main__":
    # 1) Ensure musubi-tuner is cloned before anything else
    _startup_clone_musubi_tuner()
    # 1.1) Install musubi-tuner dependencies (best-effort)
    _startup_install_musubi_deps()
    # 2) Download models at startup (blocking by design)
    _startup_download_models()
    # 3) Launch the Gradio app
    ui = build_ui()
    # Limit concurrency (training is heavy). Enable the queue for Spaces compatibility.
    # Use a generic signature to support multiple gradio versions.
    try:
        ui = ui.queue(max_size=16)
    except TypeError:
        ui = ui.queue()
    # Allow Gradio to serve files saved under our runtime dirs
    try:
        allowed = [
            AUTO_DIR_RUNTIME,
            os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"),
            DEFAULT_DATA_ROOT,
            DATA_ROOT_RUNTIME,
            os.path.join(os.path.expanduser("~"), "auto"),
            os.path.join(os.path.expanduser("~"), "data"),
        ]
        ui.launch(server_name="0.0.0.0", allowed_paths=allowed)
    except TypeError:
        # Older gradio without allowed_paths
        ui.launch(server_name="0.0.0.0")