#!/usr/bin/env python3 import os import re import shutil import subprocess import sys import tempfile from pathlib import Path from typing import Dict, Iterable, List, Optional, Any, Tuple import time import json import gradio as gr import importlib import spaces import signal from threading import Lock # Local modules from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR # Defaults matching train_QIE.sh expectations DEFAULT_DATA_ROOT = "/data" DEFAULT_IMAGE_FOLDER = "image" DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA" DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml" DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR # "/Qwen-Image_models" WORKSPACE_AUTO_DIR = "/auto" # musubi-tuner settings DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner") DEFAULT_MUSUBI_TUNER_REPO = os.environ.get( "MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git" ) TRAINING_DIR = Path(__file__).resolve().parent # Runtime-resolved paths with fallbacks for non-root environments MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT # Active process management for hard stop (Ubuntu) _ACTIVE_LOCK: Lock = Lock() _ACTIVE_PROC: Optional[subprocess.Popen] = None _ACTIVE_PGID: Optional[int] = None def _bash_quote(s: str) -> str: """Return a POSIX-safe single-quoted string literal representing s.""" if s is None: return "''" return "'" + str(s).replace("'", "'\"'\"'") + "'" def _ensure_workspace_auto_files() -> None: """Ensure /workspace/auto has required helper files from this repo. Copies training/create_image_caption_json.py and training/dataset_QIE.toml into /workspace/auto so that train_QIE.sh can run unmodified. """ global AUTO_DIR_RUNTIME try: os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True) except PermissionError: home_auto = os.path.join(os.path.expanduser("~"), "auto") os.makedirs(home_auto, exist_ok=True) AUTO_DIR_RUNTIME = home_auto # type: ignore src_py = TRAINING_DIR / "create_image_caption_json.py" src_toml = TRAINING_DIR / "dataset_QIE.toml" dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py" dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml" try: shutil.copy2(src_py, dst_py) except Exception: pass try: if src_toml.exists(): shutil.copy2(src_toml, dst_toml) except Exception: pass def _update_dataset_toml( path: str, *, img_res_w: Optional[int] = None, img_res_h: Optional[int] = None, train_batch_size: Optional[int] = None, control_res_w: Optional[int] = None, control_res_h: Optional[int] = None, ) -> None: """Update dataset TOML for resolution/batch/control resolution in-place. - Updates [general] resolution and batch_size if provided. - Updates first [[datasets]] qwen_image_edit_control_resolution if provided. - Creates sections/keys if missing. """ try: txt = Path(path).read_text(encoding="utf-8") except Exception: return def _set_in_general(block: str, key: str, value_line: str) -> str: import re as _re if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block): block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block) else: block = block.rstrip() + "\n" + value_line + "\n" return block import re m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt) if not m: gen = "[general]\n" if img_res_w and img_res_h: gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n" if train_batch_size is not None: gen += f"batch_size = {int(train_batch_size)}\n" txt = gen + "\n" + txt else: head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):] if img_res_w and img_res_h: block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]") if train_batch_size is not None: block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}") txt = head + block + tail if control_res_w and control_res_h: m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt) if m2: head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):] line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]" if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block): block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block) else: block = block.rstrip() + "\n" + line + "\n" txt = head + block + tail try: Path(path).write_text(txt, encoding="utf-8") except Exception: pass def _ensure_dir_writable(path: str) -> str: try: os.makedirs(path, exist_ok=True) return path except PermissionError: home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\"))) os.makedirs(home_path, exist_ok=True) return home_path def _ensure_data_root(candidate: Optional[str]) -> str: root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT try: os.makedirs(root, exist_ok=True) return root except PermissionError: home_root = os.path.join(os.path.expanduser("~"), "data") os.makedirs(home_root, exist_ok=True) return home_root def _extract_paths(files: Any) -> List[Tuple[str, str]]: """Extract a list of (abs_path, orig_basename) from Gradio Files input. Supports various gradio return shapes across versions. """ out: List[Tuple[str, str]] = [] if not files: return out # Gradio Files often returns a list if isinstance(files, (list, tuple)): items = files else: items = [files] for item in items: p: Optional[str] = None orig: Optional[str] = None # dict-like if isinstance(item, dict): p = item.get("path") or item.get("name") or item.get("file") orig = item.get("orig_name") or item.get("name") else: # object with attributes p = getattr(item, "name", None) or getattr(item, "path", None) or str(item) # best-effort original name attribute orig = getattr(item, "orig_name", None) or os.path.basename(p) if p else None if p: abs_p = os.path.abspath(p) out.append((abs_p, os.path.basename(orig or abs_p))) return out def _norm_key(filename: str, prefix: str, suffix: str) -> str: stem = os.path.splitext(os.path.basename(filename))[0] if prefix and stem.startswith(prefix): stem = stem[len(prefix):] if suffix and stem.endswith(suffix): stem = stem[: -len(suffix)] return stem def _copy_uploads( uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None ) -> List[str]: os.makedirs(dest_dir, exist_ok=True) used_names: List[str] = [] for idx, (src, orig) in enumerate(uploads): # Determine target stem if rename_to and idx < len(rename_to): stem = os.path.splitext(rename_to[idx])[0] else: stem = os.path.splitext(orig)[0] dst_name = f"{stem}.png" # ensure unique within this batch final_name = dst_name dup_idx = 1 while final_name in used_names: final_name = f"{stem}_{dup_idx}.png" dup_idx += 1 dst_path = os.path.join(dest_dir, final_name) # Convert to PNG during save try: try: from PIL import Image # type: ignore with Image.open(src) as img: img.save(dst_path, format="PNG") except Exception: # Fallback: copy then rename shutil.copy2(src, dst_path) except Exception: # Last resort shutil.copy(src, dst_path) used_names.append(final_name) return used_names def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]: try: if not out_dir or not os.path.isdir(out_dir): return [] import time now = time.time() min_age_sec = 3.0 # treat files newer than this as possibly in-flight items: List[Tuple[float, str]] = [] for root, _, files in os.walk(out_dir): for fn in files: if fn.lower().endswith('.safetensors'): full = os.path.join(root, fn) try: # Skip zero-length, too-new, or unreadable files (likely in-flight) size = os.path.getsize(full) if size <= 0: continue mtime = os.path.getmtime(full) if (now - mtime) < min_age_sec: continue # Try opening a small read to ensure readability with open(full, 'rb') as rf: rf.read(64) items.append((mtime, full)) except Exception: pass items.sort(reverse=True) return [p for _, p in items[:limit]] except Exception: return [] def _find_latest_dataset_dir(root: str) -> Optional[str]: try: if not os.path.isdir(root): return None cand: List[Tuple[float, str]] = [] for name in os.listdir(root): if not name.startswith("dataset_"): continue full = os.path.join(root, name) if os.path.isdir(full): try: cand.append((os.path.getmtime(full), full)) except Exception: pass if not cand: return None cand.sort(reverse=True) return cand[0][1] except Exception: return None def _collect_scripts_and_config(ds_dir: Optional[str]) -> List[str]: files: List[str] = [] try: ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml") if os.path.isfile(ds_conf): files.append(ds_conf) if ds_dir and os.path.isdir(ds_dir): used_script = os.path.join(ds_dir, "train_QIE_used.sh") if os.path.isfile(used_script): files.append(used_script) meta = os.path.join(ds_dir, "metadata.jsonl") if os.path.isfile(meta): files.append(meta) except Exception: pass return files def _files_to_gallery(files: Any) -> List[str]: items: List[str] = [] if not files: return items seq = files if isinstance(files, (list, tuple)) else [files] for f in seq: p = None if isinstance(f, str): p = f elif isinstance(f, dict): p = f.get("path") or f.get("name") else: p = getattr(f, "path", None) or getattr(f, "name", None) if p: items.append(p) return items def _prepare_script( dataset_name: str, caption: str, data_root: str, image_folder: str, control_folders: List[Optional[str]], models_root: str, output_dir_base: Optional[str] = None, dataset_config: Optional[str] = None, override_max_epochs: Optional[int] = None, override_save_every: Optional[int] = None, override_run_name: Optional[str] = None, target_prefix: Optional[str] = None, target_suffix: Optional[str] = None, control_prefixes: Optional[List[Optional[str]]] = None, control_suffixes: Optional[List[Optional[str]]] = None, override_learning_rate: Optional[str] = None, override_network_dim: Optional[int] = None, override_seed: Optional[int] = None, override_te_cache_bs: Optional[int] = None, ) -> Path: """Create a temporary copy of train_QIE.sh with injected variables. Only variables that must vary per-run are replaced. The rest of the script remains as-is to preserve behavior. """ src = TRAINING_DIR / "train_QIE.sh" txt = src.read_text(encoding="utf-8") # Replace core variables replacements = { r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}", r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}", r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}", r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}", } if output_dir_base: replacements[r"^OUTPUT_DIR_BASE=\".*\""] = ( f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}" ) if dataset_config: replacements[r"^DATASET_CONFIG=\".*\""] = ( f"DATASET_CONFIG={_bash_quote(dataset_config)}" ) for pat, val in replacements.items(): txt = re.sub(pat, val, txt, flags=re.MULTILINE) # Inject CONTROL_FOLDER_i if provided (uncomment/override or append) for i in range(8): val = control_folders[i] if i < len(control_folders) else None if not val: continue # Try to replace commented placeholder first pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\"" if re.search(pattern, txt, flags=re.MULTILINE): txt = re.sub( pattern, f"CONTROL_FOLDER_{i}={_bash_quote(val)}", txt, flags=re.MULTILINE, ) else: # Append after IMAGE_FOLDER definition txt = re.sub( r"^(IMAGE_FOLDER=.*)$", rf"\1\nCONTROL_FOLDER_{i}={_bash_quote(val)}", txt, count=1, flags=re.MULTILINE, ) # Point model paths to the selected models_root def _replace_model_path(txt: str, key: str, rel: str) -> str: return re.sub( rf"--{key} \"[^\"]+\"", f"--{key} \"{models_root.rstrip('/')}/{rel}\"", txt, ) txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors") txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors") txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors") # Replace working dir for metadata generation to runtime /auto txt = re.sub(r"^cd\s+/workspace/auto\s*$", f"cd {AUTO_DIR_RUNTIME}", txt, flags=re.MULTILINE) # Ensure musubi-tuner path matches runtime location txt = re.sub(r"^cd\s+/musubi-tuner\s*$", f"cd {re.escape(MUSUBI_TUNER_DIR_RUNTIME)}", txt, flags=re.MULTILINE) # ZeroGPU compatibility: avoid spawning via 'accelerate launch'. # Run the training module directly in-process so GPU stays attached # to the same Python request context. txt = re.sub( r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py", r"python -u src/musubi_tuner/qwen_image_train_network.py", txt, flags=re.MULTILINE, ) # Optionally override epochs and save frequency for ZeroGPU time slicing if override_max_epochs is not None and override_max_epochs > 0: txt = re.sub(r"--max_train_epochs\s+\d+", f"--max_train_epochs {override_max_epochs}", txt) if override_save_every is not None and override_save_every > 0: txt = re.sub(r"--save_every_n_epochs\s+\d+", f"--save_every_n_epochs {override_save_every}", txt) if override_run_name: txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE) # Inject prefix/suffix flags for metadata creation extra_lines: List[str] = [] if (target_prefix or ""): extra_lines.append(f" --target_prefix {_bash_quote(target_prefix)} \\") if (target_suffix or ""): extra_lines.append(f" --target_suffix {_bash_quote(target_suffix)} \\") for i in range(8): pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None if pre: extra_lines.append(f" --control_prefix_{i} {_bash_quote(pre)} \\") if suf: extra_lines.append(f" --control_suffix_{i} {_bash_quote(suf)} \\") if extra_lines: extra_block = "\n".join(extra_lines) # Insert extra flags just before the CONTROL_ARGS line, preserving indentation. txt = re.sub( r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"', lambda m: f"{extra_block}\n{m.group(1)}\"${{CONTROL_ARGS[@]}}\"", txt, flags=re.MULTILINE, ) # Override CLI hyperparameters if provided if override_learning_rate: txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt) if override_network_dim is not None: txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt) if override_seed is not None: txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt) # Optionally override text-encoder cache batch size if override_te_cache_bs is not None and override_te_cache_bs > 0: txt = re.sub( r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+", rf"\g<1>{int(override_te_cache_bs)}", txt, flags=re.MULTILINE, ) # Prefer overriding variable definitions at top of script (safer than CLI regex) def _set_var(name: str, value: str) -> None: nonlocal txt pattern = rf"(?m)^\s*{name}\s*=.*$" replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}' if re.search(pattern, txt): txt = re.sub(pattern, replacement, txt) else: txt = f"{replacement}\n" + txt if override_learning_rate: _set_var('LEARNING_RATE', override_learning_rate) if override_network_dim is not None: _set_var('NETWORK_DIM', str(override_network_dim)) if override_seed is not None: _set_var('SEED', str(override_seed)) if override_max_epochs is not None and override_max_epochs > 0: _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs)) if override_save_every is not None and override_save_every > 0: _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every)) # Write to a temp file alongside this repo for easier inspection run_dir = TRAINING_DIR / ".gradio_runs" run_dir.mkdir(parents=True, exist_ok=True) tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh" tmp.write_text(txt, encoding="utf-8", newline="\n") try: os.chmod(tmp, 0o755) except Exception: pass return tmp def _pick_shell() -> str: for sh in ("bash", "sh"): if shutil.which(sh): return sh raise RuntimeError("No POSIX shell found. Please install bash or sh.") def _is_git_repo(path: str) -> bool: try: out = subprocess.run( ["git", "-C", path, "rev-parse", "--is-inside-work-tree"], capture_output=True, text=True, check=False, ) return out.returncode == 0 and out.stdout.strip() == "true" except Exception: return False def _startup_clone_musubi_tuner() -> None: global MUSUBI_TUNER_DIR_RUNTIME target = MUSUBI_TUNER_DIR_RUNTIME repo = DEFAULT_MUSUBI_TUNER_REPO parent = os.path.dirname(target.rstrip("/\\")) or "/" try: os.makedirs(parent, exist_ok=True) except PermissionError: # Fallback to home directory target = os.path.join(os.path.expanduser("~"), "musubi-tuner") MUSUBI_TUNER_DIR_RUNTIME = target os.makedirs(os.path.dirname(target), exist_ok=True) except Exception: pass if os.path.isdir(target) and _is_git_repo(target): print(f"[QIE] musubi-tuner exists at {target}; pulling latest...") try: subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False) subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False) except Exception as e: print(f"[QIE] git pull failed: {e}") return if os.path.exists(target) and not _is_git_repo(target): print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.") return print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...") try: subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True) print("[QIE] Clone completed.") except subprocess.CalledProcessError as e: print(f"[QIE] Clone failed at {target}: {e}") # Last-chance fallback into home if not target.startswith(os.path.expanduser("~")): fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner") print(f"[QIE] Retrying clone into {fallback}...") try: subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True) MUSUBI_TUNER_DIR_RUNTIME = fallback print("[QIE] Clone completed in fallback.") except Exception as e2: print(f"[QIE] Clone failed in fallback as well: {e2}") def _run_pip(args: List[str], cwd: Optional[str] = None) -> None: cmd = [sys.executable, "-m", "pip"] + args try: print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})") subprocess.run(cmd, check=True, cwd=cwd) except subprocess.CalledProcessError as e: print(f"[QIE] pip failed: {e}") def _startup_install_musubi_deps() -> None: repo_dir = MUSUBI_TUNER_DIR_RUNTIME if not os.path.isdir(repo_dir): print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}") return # Upgrade basic build tooling (best-effort) try: _run_pip(["install", "-U", "pip", "setuptools", "wheel"]) except Exception: pass # Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128 extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip() editable_spec = "." if not extra else f".[{extra}]" # Install musubi-tuner in editable mode to expose entrypoints and deps try: _run_pip(["install", "-e", editable_spec], cwd=repo_dir) except Exception: # Fallback: plain install without editable try: _run_pip(["install", editable_spec], cwd=repo_dir) except Exception: print("[QIE] WARN: musubi-tuner installation failed. Continuing.") @spaces.GPU def run_training( output_name: str, caption: str, image_uploads: Any, target_prefix: str, target_suffix: str, control0_uploads: Any, ctrl0_prefix: str, ctrl0_suffix: str, control1_uploads: Any, ctrl1_prefix: str, ctrl1_suffix: str, control2_uploads: Any, ctrl2_prefix: str, ctrl2_suffix: str, control3_uploads: Any, ctrl3_prefix: str, ctrl3_suffix: str, control4_uploads: Any, ctrl4_prefix: str, ctrl4_suffix: str, control5_uploads: Any, ctrl5_prefix: str, ctrl5_suffix: str, control6_uploads: Any, ctrl6_prefix: str, ctrl6_suffix: str, control7_uploads: Any, ctrl7_prefix: str, ctrl7_suffix: str, learning_rate: str, network_dim: int, train_res_w: int, train_res_h: int, train_batch_size: int, control_res_w: int, control_res_h: int, te_cache_batch_size: int, seed: int, max_epochs: int, save_every: int, ) -> Iterable[tuple]: global _ACTIVE_PROC, _ACTIVE_PGID # Basic validation log_buf = "[QIE] Start Training invoked.\n" ckpts: List[str] = [] artifacts: List[str] = [] # Emit an initial line so UI can confirm invocation yield (log_buf, ckpts, artifacts) if not output_name.strip(): log_buf += "[ERROR] OUTPUT NAME is required.\n" yield (log_buf, ckpts, artifacts) def _stop_active_training() -> None: """Ubuntu向けのハード停止: 実行中の学習プロセスのプロセスグループを終了する""" with _ACTIVE_LOCK: proc = _ACTIVE_PROC pgid = _ACTIVE_PGID if not proc: return try: if pgid is not None: os.killpg(pgid, signal.SIGTERM) else: os.kill(proc.pid, signal.SIGTERM) except Exception: pass try: proc.wait(timeout=5) except Exception: try: if pgid is not None: os.killpg(pgid, signal.SIGKILL) else: os.kill(proc.pid, signal.SIGKILL) except Exception: pass return if not caption.strip(): log_buf += "[ERROR] CAPTION is required.\n" yield (log_buf, ckpts, artifacts) return # Ensure /auto holds helper files expected by the script _ensure_workspace_auto_files() # Resolve data root and create dataset directories (auto-decide) global DATA_ROOT_RUNTIME DATA_ROOT_RUNTIME = _ensure_data_root(None) # Auto-generate dataset directory name import time ds_name = f"dataset_{int(time.time())}" ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name) img_folder_name = DEFAULT_IMAGE_FOLDER img_dir = os.path.join(ds_dir, img_folder_name) os.makedirs(img_dir, exist_ok=True) # Ingest uploads into dataset folders base_files = _extract_paths(image_uploads) if not base_files: log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n" yield (log_buf, ckpts, artifacts) return base_filenames = _copy_uploads(base_files, img_dir) log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n" yield (log_buf, ckpts, artifacts) # Prepare control sets control_upload_sets = [ _extract_paths(control0_uploads), _extract_paths(control1_uploads), _extract_paths(control2_uploads), _extract_paths(control3_uploads), _extract_paths(control4_uploads), _extract_paths(control5_uploads), _extract_paths(control6_uploads), _extract_paths(control7_uploads), ] # Require control_0; others optional if not control_upload_sets[0]: log_buf += "[ERROR] control_0 images are required.\n" yield (log_buf, ckpts, artifacts) return control_dirs: List[Optional[str]] = [] for i, uploads in enumerate(control_upload_sets): if not uploads: control_dirs.append(None) continue folder_name = f"control_{i}" cdir = os.path.join(ds_dir, folder_name) os.makedirs(cdir, exist_ok=True) # Simply copy; name matching will be handled by create_image_caption_json.py _copy_uploads(uploads, cdir) control_dirs.append(folder_name) log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n" yield (log_buf, ckpts, artifacts) # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh # Prepare script with user parameters control_folders = [ (control_dirs[i] if control_dirs[i] else None) for i in range(8) ] # Decide dataset_config path with fallback to runtime auto dir ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml") # Update dataset config with requested resolution/batch settings try: _update_dataset_toml( ds_conf, img_res_w=int(train_res_w) if train_res_w else None, img_res_h=int(train_res_h) if train_res_h else None, train_batch_size=int(train_batch_size) if train_batch_size else None, control_res_w=int(control_res_w) if control_res_w else None, control_res_h=int(control_res_h) if control_res_h else None, ) log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n" except Exception as e: log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n" # Expose dataset config for download (if exists) if os.path.isfile(ds_conf): artifacts = [ds_conf] # Resolve models_root and set output_dir_base to the unique dataset dir models_root = MODELS_ROOT_RUNTIME out_base = ds_dir try: os.makedirs(out_base, exist_ok=True) except Exception: pass tmp_script = _prepare_script( dataset_name=ds_name, caption=caption, data_root=DATA_ROOT_RUNTIME, image_folder=img_folder_name, control_folders=control_folders, models_root=models_root, output_dir_base=out_base, dataset_config=ds_conf, override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None, override_save_every=save_every if save_every and save_every > 0 else None, override_run_name=output_name.strip(), target_prefix=(target_prefix or ""), target_suffix=(target_suffix or ""), control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix], control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix], override_learning_rate=(learning_rate or None), override_network_dim=int(network_dim) if network_dim is not None else None, override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None, override_seed=int(seed) if seed is not None else None, ) shell = _pick_shell() log_buf += f"[QIE] Using shell: {shell}\n" log_buf += f"[QIE] Running script: {tmp_script}\n" out_dir = os.path.join(out_base, output_name.strip()) ckpts = _list_checkpoints(out_dir) # Copy the final script to dataset dir for download used_script_path = os.path.join(out_base, "train_QIE_used.sh") try: shutil.copy2(str(tmp_script), used_script_path) try: os.chmod(used_script_path, 0o755) except Exception: pass if used_script_path not in artifacts: artifacts.append(used_script_path) except Exception: pass yield (log_buf, ckpts, artifacts) # Run and stream output # Ensure child Python processes are unbuffered for real-time logs child_env = os.environ.copy() child_env["PYTHONUNBUFFERED"] = "1" child_env["PYTHONIOENCODING"] = "utf-8" proc = subprocess.Popen( [shell, str(tmp_script)], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True, env=child_env, preexec_fn=os.setsid, ) # Register active process for hard stop with _ACTIVE_LOCK: _ACTIVE_PROC = proc try: _ACTIVE_PGID = os.getpgid(proc.pid) except Exception: _ACTIVE_PGID = None try: assert proc.stdout is not None i = 0 for line in proc.stdout: log_buf += line i += 1 if i % 30 == 0: ckpts = _list_checkpoints(out_dir) # Try to add metadata.jsonl once available metadata_json = os.path.join(out_base, "metadata.jsonl") if os.path.isfile(metadata_json) and metadata_json not in artifacts: artifacts.append(metadata_json) yield (log_buf, ckpts, artifacts) finally: code = proc.wait() # Clear active process registration if this proc with _ACTIVE_LOCK: if _ACTIVE_PROC is proc: _ACTIVE_PROC = None _ACTIVE_PGID = None # Try to locate latest LoRA file for download lora_path = None try: ckpts = _list_checkpoints(out_dir) except Exception: pass lora_path = ckpts[0] if ckpts else None if code < 0: try: sig = -code log_buf += f"[QIE] Terminated by signal: {sig}\n" except Exception: log_buf += f"[QIE] Terminated by signal.\n" log_buf += f"[QIE] Exit code: {code}\n" # Final attempt to include metadata.jsonl metadata_json = os.path.join(out_base, "metadata.jsonl") if os.path.isfile(metadata_json) and metadata_json not in artifacts: artifacts.append(metadata_json) yield (log_buf, ckpts, artifacts) def build_ui() -> gr.Blocks: css = """ .pad-section { padding: 6px; margin-bottom: 12px; border: 1px solid var(--color-border, #e5e7eb); border-radius: 8px; background: var(--color-background-secondary, #ffffff); } .pad-section_0 { padding: 6px; margin-bottom: 12px; border: 1px solid var(--color-border, #e5e7eb); border-radius: 8px; background: var(--color-background-secondary, #fafafa); } .pad-section_1 { padding: 6px; margin-bottom: 12px; border: 1px solid var(--color-border, #e5e7eb); border-radius: 8px; background: var(--color-background-secondary, #eaeaea); } """ with gr.Blocks(title="Qwen-Image-Edit: Trainer", css=css) as demo: with gr.Tabs() as tabs: with gr.TabItem("Training"): gr.Markdown(""" # Qwen-Image-Edit Trainer 学習に使う画像をアップロードし、必要ならファイル名の前後にある共通の文字(prefix/suffix)を指定して、 自動でデータセットを作成し学習を開始します。難しい操作は不要です。 """) with gr.Accordion("Settings", elem_classes=["pad-section"]): with gr.Group(): with gr.Row(): output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1) caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2) with gr.Row(): lr_input = gr.Textbox(label="Learning rate", value="1e-3") dim_input = gr.Number(label="Network dim", value=4, precision=0) train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0) seed_input = gr.Number(label="Seed", value=42, precision=0) max_epochs = gr.Number(label="Max epochs", value=100, precision=0) save_every = gr.Number(label="Save every N epochs", value=10, precision=0) with gr.Row(): tr_w = gr.Number(label="Image resolution W", value=1024, precision=0) tr_h = gr.Number(label="Image resolution H", value=1024, precision=0) cr_w = gr.Number(label="Control resolution W", value=1024, precision=0) cr_h = gr.Number(label="Control resolution H", value=1024, precision=0) te_bs = gr.Number(label="TE cache batch size", value=16, precision=0) with gr.Accordion("Target Image", elem_classes=["pad-section_0"]): with gr.Group(): with gr.Row(): images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath", height=220, scale=3) main_gallery = gr.Gallery(label="Target preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_") main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2") with gr.Accordion("prefix/sufixについて", open=False): gr.Markdown(""" ファイルの同名判定のため、画像のファイル名から共通の先頭/末尾文字を取り除く指定(例: IMG_ や _v2) - まずターゲット画像のファイル名(拡張子なし)から、指定した Target prefix/suffix を取り除いたものを key とします。 - 各コントロールは「付加」規則で、期待名 = control_prefix_i + key + control_suffix_i + ".png" を探して対応付けます。 - アップロード時に画像は自動で .png に変換して保存します(元のファイル名のベースは維持)。 - Control 0 は必須、Control 1〜7 は任意。コントロール画像が1枚だけのときは、すべてのターゲット画像に適用します。 """) # control_0 is required and shown outside the accordion with gr.Accordion("Control 0", elem_classes=["pad-section_1"]): with gr.Group(): with gr.Row(): ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath", height=220, scale=3) ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_") ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask") # Optional controls start from 1, accordion closed by default with gr.Accordion("Control 1", open=False, elem_classes=["pad-section_0"]): with gr.Group(): with gr.Row(): ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="") ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="") with gr.Accordion("Control 2", open=False, elem_classes=["pad-section_1"]): with gr.Group(): with gr.Row(): ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="") ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="") with gr.Accordion("Control 3", open=False, elem_classes=["pad-section_0"]): with gr.Group(): with gr.Row(): ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="") ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="") with gr.Accordion("Control 4", open=False, elem_classes=["pad-section_1"]): with gr.Group(): with gr.Row(): ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="") ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="") with gr.Accordion("Control 5", open=False, elem_classes=["pad-section_0"]): with gr.Group(): with gr.Row(): ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="") ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="") with gr.Accordion("Control 6", open=False, elem_classes=["pad-section_1"]): with gr.Group(): with gr.Row(): ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="") ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="") with gr.Accordion("Control 7", open=False, elem_classes=["pad-section_0"]): with gr.Group(): with gr.Row(): ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath", height=220, scale=3) ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3) with gr.Column(scale=1): with gr.Row(): ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="") ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="") # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed. run_btn = gr.Button("Start Training", variant="primary") logs = gr.Textbox(label="Logs", lines=20) ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False) scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False) with gr.Row(): stop_btn = gr.Button("学習を停止", variant="secondary") refresh_scripts_btn = gr.Button("ファイルを再取得", variant="secondary") # moved max_epochs/save_every above next to OUTPUT NAME # Wire previews images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery) ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery) ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery) ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery) ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery) ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery) ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery) ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery) ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery) run_btn.click( fn=run_training, inputs=[ output_name, caption, images_input, main_prefix, main_suffix, ctrl0_files, ctrl0_prefix, ctrl0_suffix, ctrl1_files, ctrl1_prefix, ctrl1_suffix, ctrl2_files, ctrl2_prefix, ctrl2_suffix, ctrl3_files, ctrl3_prefix, ctrl3_suffix, ctrl4_files, ctrl4_prefix, ctrl4_suffix, ctrl5_files, ctrl5_prefix, ctrl5_suffix, ctrl6_files, ctrl6_prefix, ctrl6_suffix, ctrl7_files, ctrl7_prefix, ctrl7_suffix, lr_input, dim_input, tr_w, tr_h, train_bs, cr_w, cr_h, te_bs, seed_input, max_epochs, save_every, ], outputs=[logs, ckpt_files, scripts_files], ) # 回収ボタン: 直近の dataset_ ディレクトリからチェックポイントとスクリプト/設定を再取得 def _refresh_all() -> tuple: try: ds_dir = _find_latest_dataset_dir(DATA_ROOT_RUNTIME) except Exception: ds_dir = None try: ck = _list_checkpoints(ds_dir) if ds_dir else [] except Exception: ck = [] try: sc = _collect_scripts_and_config(ds_dir) except Exception: sc = _collect_scripts_and_config(None) return ck, sc refresh_scripts_btn.click( fn=_refresh_all, inputs=[], outputs=[ckpt_files, scripts_files], ) # Hard stop button (Ubuntu): kill active process group def _on_stop_click(): _stop_active_training() return stop_btn.click(fn=_on_stop_click, inputs=[], outputs=[]) with gr.TabItem("Prompt Generator"): gr.Markdown(""" # 🎨 A→B 変換プロンプト自動生成 画像A(入力)と画像B(出力)、補足説明を入力すると、 A→B の変換内容を英語プロンプトとして自動生成し、タスク名候補(3件)も提案します。 モデルは `gpt-5` を使用します。 """) api_key_pg = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...") with gr.Row(): img_a_pg = gr.Image(type="filepath", label="Image A (Input)", height=300) img_b_pg = gr.Image(type="filepath", label="Image B (Output)", height=300) notes_pg = gr.Textbox(label="補足説明(日本語可)", lines=4, value="この画像は例であって、汎用的なプロンプトにする") want_japanese_pg = gr.Checkbox(label="日本語訳を含める", value=True) run_btn_pg = gr.Button("生成する", variant="primary") english_out_pg = gr.Textbox(label="English Prompt", lines=8) names_out_pg = gr.Textbox(label="Name Suggestions", lines=4) japanese_out_pg = gr.Textbox(label="日本語訳(任意)", lines=8) def _on_click_prompt(api_key_in, a_path, b_path, notes_in, ja_flag): # Lazy import to avoid constructing extra Blocks at startup qpg = importlib.import_module("QIE_prompt_generator") a_url = qpg.file_to_data_url(a_path) if a_path else None b_url = qpg.file_to_data_url(b_path) if b_path else None return qpg.call_openai_chat(api_key_in, a_url, b_url, notes_in, ja_flag) run_btn_pg.click( fn=_on_click_prompt, inputs=[api_key_pg, img_a_pg, img_b_pg, notes_pg, want_japanese_pg], outputs=[english_out_pg, names_out_pg, japanese_out_pg], ) return demo def _startup_download_models() -> None: global MODELS_ROOT_RUNTIME # Pick a writable models directory candidate = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT) try: os.makedirs(candidate, exist_ok=True) MODELS_ROOT_RUNTIME = candidate except PermissionError: MODELS_ROOT_RUNTIME = os.path.join(os.path.expanduser("~"), "Qwen-Image_models") os.makedirs(MODELS_ROOT_RUNTIME, exist_ok=True) print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}") try: download_all_models(MODELS_ROOT_RUNTIME) except Exception as e: print(f"[QIE] Model download failed: {e}") if __name__ == "__main__": # 1) Ensure musubi-tuner is cloned before anything else _startup_clone_musubi_tuner() # 1.1) Install musubi-tuner dependencies (best-effort) _startup_install_musubi_deps() # 2) Download models at startup (blocking by design) _startup_download_models() # 3) Launch Gradio app ui = build_ui() # Limit concurrency (training is heavy). Enable queue for Spaces compatibility. # Use generic signature to support multiple gradio versions. try: ui = ui.queue(max_size=16) except TypeError: ui = ui.queue() # Allow Gradio to serve files saved under our runtime dirs try: allowed = [ AUTO_DIR_RUNTIME, os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"), DEFAULT_DATA_ROOT, DATA_ROOT_RUNTIME, os.path.join(os.path.expanduser("~"), "auto"), os.path.join(os.path.expanduser("~"), "data"), ] ui.launch(server_name="0.0.0.0", allowed_paths=allowed, ssr_mode=False) except TypeError: # Older gradio without allowed_paths try: ui.launch(server_name="0.0.0.0", ssr_mode=False) except TypeError: # Very old gradio without ssr_mode ui.launch(server_name="0.0.0.0")