#!/usr/bin/env python3
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Any, Tuple
import time
import json
import gradio as gr
import importlib
import spaces
import signal
from threading import Lock

# Local modules
from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR

# Defaults matching train_QIE.sh expectations
DEFAULT_DATA_ROOT = "/data"
DEFAULT_IMAGE_FOLDER = "image"
DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA"
DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml"
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR  # "/Qwen-Image_models"
WORKSPACE_AUTO_DIR = "/auto"

# musubi-tuner settings
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
    "MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
)

TRAINING_DIR = Path(__file__).resolve().parent

# Runtime-resolved paths with fallbacks for non-root environments
MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR
MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT
AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR
DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT

# Active process management for hard stop (Ubuntu)
_ACTIVE_LOCK: Lock = Lock()
_ACTIVE_PROC: Optional[subprocess.Popen] = None
_ACTIVE_PGID: Optional[int] = None

def _bash_quote(s: str) -> str:
    """Return a POSIX-safe single-quoted string literal representing s."""
    if s is None:
        return "''"
    return "'" + str(s).replace("'", "'\"'\"'") + "'"

def _ensure_workspace_auto_files() -> None:
    """Ensure /workspace/auto has required helper files from this repo.

    Copies training/create_image_caption_json.py and training/dataset_QIE.toml
    into /workspace/auto so that train_QIE.sh can run unmodified.
    """
    global AUTO_DIR_RUNTIME
    try:
        os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True)
    except PermissionError:
        home_auto = os.path.join(os.path.expanduser("~"), "auto")
        os.makedirs(home_auto, exist_ok=True)
        AUTO_DIR_RUNTIME = home_auto  # type: ignore
    src_py = TRAINING_DIR / "create_image_caption_json.py"
    src_toml = TRAINING_DIR / "dataset_QIE.toml"
    dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py"
    dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"
    try:
        shutil.copy2(src_py, dst_py)
    except Exception:
        pass
    try:
        if src_toml.exists():
            shutil.copy2(src_toml, dst_toml)
    except Exception:
        pass

def _update_dataset_toml(
    path: str,
    *,
    img_res_w: Optional[int] = None,
    img_res_h: Optional[int] = None,
    train_batch_size: Optional[int] = None,
    control_res_w: Optional[int] = None,
    control_res_h: Optional[int] = None,
) -> None:
    """Update dataset TOML for resolution/batch/control resolution in-place.

    - Updates [general] resolution and batch_size if provided.
    - Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
    - Creates sections/keys if missing.
    """
    try:
        txt = Path(path).read_text(encoding="utf-8")
    except Exception:
        return

    def _set_in_general(block: str, key: str, value_line: str) -> str:
        import re as _re
        if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block):
            block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block)
        else:
            block = block.rstrip() + "\n" + value_line + "\n"
        return block

    import re
    m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
    if not m:
        gen = "[general]\n"
        if img_res_w and img_res_h:
            gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
        if train_batch_size is not None:
            gen += f"batch_size = {int(train_batch_size)}\n"
        txt = gen + "\n" + txt
    else:
        head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
        if img_res_w and img_res_h:
            block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
        if train_batch_size is not None:
            block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
        txt = head + block + tail
    if control_res_w and control_res_h:
        m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
        if m2:
            head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
            line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
            if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
                block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
            else:
                block = block.rstrip() + "\n" + line + "\n"
            txt = head + block + tail
    try:
        Path(path).write_text(txt, encoding="utf-8")
    except Exception:
        pass

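# For reference, the dataset config this helper edits has roughly this shape
# (key names taken from the code above; any other keys are left untouched):
#
#   [general]
#   resolution = [1024, 1024]
#   batch_size = 1
#
#   [[datasets]]
#   qwen_image_edit_control_resolution = [1024, 1024]
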
def _ensure_dir_writable(path: str) -> str:
    try:
        os.makedirs(path, exist_ok=True)
        return path
    except PermissionError:
        home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\")))
        os.makedirs(home_path, exist_ok=True)
        return home_path


def _ensure_data_root(candidate: Optional[str]) -> str:
    root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT
    try:
        os.makedirs(root, exist_ok=True)
        return root
    except PermissionError:
        home_root = os.path.join(os.path.expanduser("~"), "data")
        os.makedirs(home_root, exist_ok=True)
        return home_root

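# Both helpers fall back to a directory under $HOME when the preferred path is not
# writable (e.g. a non-root Space container), so callers always receive a usable path.
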
def _extract_paths(files: Any) -> List[Tuple[str, str]]:
    """Extract a list of (abs_path, orig_basename) from Gradio Files input.

    Supports various gradio return shapes across versions.
    """
    out: List[Tuple[str, str]] = []
    if not files:
        return out
    # Gradio Files often returns a list
    if isinstance(files, (list, tuple)):
        items = files
    else:
        items = [files]
    for item in items:
        p: Optional[str] = None
        orig: Optional[str] = None
        # dict-like
        if isinstance(item, dict):
            p = item.get("path") or item.get("name") or item.get("file")
            orig = item.get("orig_name") or item.get("name")
        else:
            # object with attributes
            p = getattr(item, "name", None) or getattr(item, "path", None) or str(item)
            # best-effort original name attribute
            orig = getattr(item, "orig_name", None) or (os.path.basename(p) if p else None)
        if p:
            abs_p = os.path.abspath(p)
            out.append((abs_p, os.path.basename(orig or abs_p)))
    return out

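# Depending on the Gradio version, each uploaded item may arrive as a plain path string,
# a dict with "path"/"name"/"orig_name" keys, or a tempfile-like object exposing .name;
# the helper above normalizes all of these to (absolute path, original basename) pairs.
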
def _norm_key(filename: str, prefix: str, suffix: str) -> str:
    stem = os.path.splitext(os.path.basename(filename))[0]
    if prefix and stem.startswith(prefix):
        stem = stem[len(prefix):]
    if suffix and stem.endswith(suffix):
        stem = stem[: -len(suffix)]
    return stem

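# Example (illustrative): _norm_key("IMG_cat_v2.png", "IMG_", "_v2") -> "cat"
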
def _copy_uploads(
    uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None
) -> List[str]:
    os.makedirs(dest_dir, exist_ok=True)
    used_names: List[str] = []
    for idx, (src, orig) in enumerate(uploads):
        # Determine target stem
        if rename_to and idx < len(rename_to):
            stem = os.path.splitext(rename_to[idx])[0]
        else:
            stem = os.path.splitext(orig)[0]
        dst_name = f"{stem}.png"
        # ensure unique within this batch
        final_name = dst_name
        dup_idx = 1
        while final_name in used_names:
            final_name = f"{stem}_{dup_idx}.png"
            dup_idx += 1
        dst_path = os.path.join(dest_dir, final_name)
        # Convert to PNG during save
        try:
            try:
                from PIL import Image  # type: ignore
                with Image.open(src) as img:
                    img.save(dst_path, format="PNG")
            except Exception:
                # Fallback: copy then rename
                shutil.copy2(src, dst_path)
        except Exception:
            # Last resort
            shutil.copy(src, dst_path)
        used_names.append(final_name)
    return used_names

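# Every upload is re-encoded to PNG (or copied as-is if Pillow cannot open it); duplicate
# stems within one batch are disambiguated as "<stem>_1.png", "<stem>_2.png", and so on.
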
def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
    try:
        if not out_dir or not os.path.isdir(out_dir):
            return []
        import time
        now = time.time()
        min_age_sec = 3.0  # treat files newer than this as possibly in-flight
        items: List[Tuple[float, str]] = []
        for root, _, files in os.walk(out_dir):
            for fn in files:
                if fn.lower().endswith('.safetensors'):
                    full = os.path.join(root, fn)
                    try:
                        # Skip zero-length, too-new, or unreadable files (likely in-flight)
                        size = os.path.getsize(full)
                        if size <= 0:
                            continue
                        mtime = os.path.getmtime(full)
                        if (now - mtime) < min_age_sec:
                            continue
                        # Try opening a small read to ensure readability
                        with open(full, 'rb') as rf:
                            rf.read(64)
                        items.append((mtime, full))
                    except Exception:
                        pass
        items.sort(reverse=True)
        return [p for _, p in items[:limit]]
    except Exception:
        return []

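# Results are newest-first (sorted by mtime) and capped at `limit`, so the UI's
# "Checkpoints (live)" list always surfaces the most recent .safetensors files.
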
def _find_latest_dataset_dir(root: str) -> Optional[str]:
    try:
        if not os.path.isdir(root):
            return None
        cand: List[Tuple[float, str]] = []
        for name in os.listdir(root):
            if not name.startswith("dataset_"):
                continue
            full = os.path.join(root, name)
            if os.path.isdir(full):
                try:
                    cand.append((os.path.getmtime(full), full))
                except Exception:
                    pass
        if not cand:
            return None
        cand.sort(reverse=True)
        return cand[0][1]
    except Exception:
        return None

def _collect_scripts_and_config(ds_dir: Optional[str]) -> List[str]:
    files: List[str] = []
    try:
        ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
        if os.path.isfile(ds_conf):
            files.append(ds_conf)
        if ds_dir and os.path.isdir(ds_dir):
            used_script = os.path.join(ds_dir, "train_QIE_used.sh")
            if os.path.isfile(used_script):
                files.append(used_script)
            meta = os.path.join(ds_dir, "metadata.jsonl")
            if os.path.isfile(meta):
                files.append(meta)
    except Exception:
        pass
    return files

def _files_to_gallery(files: Any) -> List[str]:
    items: List[str] = []
    if not files:
        return items
    seq = files if isinstance(files, (list, tuple)) else [files]
    for f in seq:
        p = None
        if isinstance(f, str):
            p = f
        elif isinstance(f, dict):
            p = f.get("path") or f.get("name")
        else:
            p = getattr(f, "path", None) or getattr(f, "name", None)
        if p:
            items.append(p)
    return items

def _prepare_script(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control_folders: List[Optional[str]],
    models_root: str,
    output_dir_base: Optional[str] = None,
    dataset_config: Optional[str] = None,
    override_max_epochs: Optional[int] = None,
    override_save_every: Optional[int] = None,
    override_run_name: Optional[str] = None,
    target_prefix: Optional[str] = None,
    target_suffix: Optional[str] = None,
    control_prefixes: Optional[List[Optional[str]]] = None,
    control_suffixes: Optional[List[Optional[str]]] = None,
    override_learning_rate: Optional[str] = None,
    override_network_dim: Optional[int] = None,
    override_seed: Optional[int] = None,
    override_te_cache_bs: Optional[int] = None,
) -> Path:
    """Create a temporary copy of train_QIE.sh with injected variables.

    Only variables that must vary per-run are replaced. The rest of the script
    remains as-is to preserve behavior.
    """
    src = TRAINING_DIR / "train_QIE.sh"
    txt = src.read_text(encoding="utf-8")
    # Replace core variables
    replacements = {
        r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}",
        r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}",
        r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}",
        r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}",
    }
    if output_dir_base:
        replacements[r"^OUTPUT_DIR_BASE=\".*\""] = (
            f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}"
        )
    if dataset_config:
        replacements[r"^DATASET_CONFIG=\".*\""] = (
            f"DATASET_CONFIG={_bash_quote(dataset_config)}"
        )
    for pat, val in replacements.items():
        txt = re.sub(pat, val, txt, flags=re.MULTILINE)
    # Inject CONTROL_FOLDER_i if provided (uncomment/override or append)
    for i in range(8):
        val = control_folders[i] if i < len(control_folders) else None
        if not val:
            continue
        # Try to replace commented placeholder first
        pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
        if re.search(pattern, txt, flags=re.MULTILINE):
            txt = re.sub(
                pattern,
                f"CONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                flags=re.MULTILINE,
            )
        else:
            # Append after IMAGE_FOLDER definition
            txt = re.sub(
                r"^(IMAGE_FOLDER=.*)$",
                rf"\1\nCONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                count=1,
                flags=re.MULTILINE,
            )

    # Point model paths to the selected models_root
    def _replace_model_path(txt: str, key: str, rel: str) -> str:
        return re.sub(
            rf"--{key} \"[^\"]+\"",
            f"--{key} \"{models_root.rstrip('/')}/{rel}\"",
            txt,
        )

    txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors")
    txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
    txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
    # Replace working dir for metadata generation to runtime /auto
    txt = re.sub(r"^cd\s+/workspace/auto\s*$", f"cd {AUTO_DIR_RUNTIME}", txt, flags=re.MULTILINE)
    # Ensure musubi-tuner path matches runtime location
    txt = re.sub(r"^cd\s+/musubi-tuner\s*$", f"cd {re.escape(MUSUBI_TUNER_DIR_RUNTIME)}", txt, flags=re.MULTILINE)
    # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
    # Run the training module directly in-process so GPU stays attached
    # to the same Python request context.
    txt = re.sub(
        r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
        r"python -u src/musubi_tuner/qwen_image_train_network.py",
        txt,
        flags=re.MULTILINE,
    )
    # Optionally override epochs and save frequency for ZeroGPU time slicing
    if override_max_epochs is not None and override_max_epochs > 0:
        txt = re.sub(r"--max_train_epochs\s+\d+",
                     f"--max_train_epochs {override_max_epochs}", txt)
    if override_save_every is not None and override_save_every > 0:
        txt = re.sub(r"--save_every_n_epochs\s+\d+",
                     f"--save_every_n_epochs {override_save_every}", txt)
    if override_run_name:
        txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
    # Inject prefix/suffix flags for metadata creation
    extra_lines: List[str] = []
    if (target_prefix or ""):
        extra_lines.append(f" --target_prefix {_bash_quote(target_prefix)} \\")
    if (target_suffix or ""):
        extra_lines.append(f" --target_suffix {_bash_quote(target_suffix)} \\")
    for i in range(8):
        pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
        suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
        if pre:
            extra_lines.append(f" --control_prefix_{i} {_bash_quote(pre)} \\")
        if suf:
            extra_lines.append(f" --control_suffix_{i} {_bash_quote(suf)} \\")
    if extra_lines:
        extra_block = "\n".join(extra_lines)
        # Insert extra flags just before the CONTROL_ARGS line, preserving indentation.
        txt = re.sub(
            r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"',
            lambda m: f"{extra_block}\n{m.group(1)}\"${{CONTROL_ARGS[@]}}\"",
            txt,
            flags=re.MULTILINE,
        )
    # Override CLI hyperparameters if provided
    if override_learning_rate:
        txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt)
    if override_network_dim is not None:
        txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt)
    if override_seed is not None:
        txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
    # Optionally override text-encoder cache batch size
    if override_te_cache_bs is not None and override_te_cache_bs > 0:
        txt = re.sub(
            r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
            rf"\g<1>{int(override_te_cache_bs)}",
            txt,
            flags=re.MULTILINE,
        )

    # Prefer overriding variable definitions at top of script (safer than CLI regex)
    def _set_var(name: str, value: str) -> None:
        nonlocal txt
        pattern = rf"(?m)^\s*{name}\s*=.*$"
        replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
        if re.search(pattern, txt):
            txt = re.sub(pattern, replacement, txt)
        else:
            txt = f"{replacement}\n" + txt

    if override_learning_rate:
        _set_var('LEARNING_RATE', override_learning_rate)
    if override_network_dim is not None:
        _set_var('NETWORK_DIM', str(override_network_dim))
    if override_seed is not None:
        _set_var('SEED', str(override_seed))
    if override_max_epochs is not None and override_max_epochs > 0:
        _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
    if override_save_every is not None and override_save_every > 0:
        _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
    # Write to a temp file alongside this repo for easier inspection
    run_dir = TRAINING_DIR / ".gradio_runs"
    run_dir.mkdir(parents=True, exist_ok=True)
    tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh"
    tmp.write_text(txt, encoding="utf-8", newline="\n")
    try:
        os.chmod(tmp, 0o755)
    except Exception:
        pass
    return tmp

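# The patched script is written to .gradio_runs/train_QIE_run_<pid>.sh next to this file,
# so the exact command sequence executed for a run can be inspected (or re-run) later.
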
def _pick_shell() -> str:
    for sh in ("bash", "sh"):
        if shutil.which(sh):
            return sh
    raise RuntimeError("No POSIX shell found. Please install bash or sh.")

def _is_git_repo(path: str) -> bool:
    try:
        out = subprocess.run(
            ["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
            capture_output=True,
            text=True,
            check=False,
        )
        return out.returncode == 0 and out.stdout.strip() == "true"
    except Exception:
        return False

def _startup_clone_musubi_tuner() -> None:
    global MUSUBI_TUNER_DIR_RUNTIME
    target = MUSUBI_TUNER_DIR_RUNTIME
    repo = DEFAULT_MUSUBI_TUNER_REPO
    parent = os.path.dirname(target.rstrip("/\\")) or "/"
    try:
        os.makedirs(parent, exist_ok=True)
    except PermissionError:
        # Fallback to home directory
        target = os.path.join(os.path.expanduser("~"), "musubi-tuner")
        MUSUBI_TUNER_DIR_RUNTIME = target
        os.makedirs(os.path.dirname(target), exist_ok=True)
    except Exception:
        pass
    if os.path.isdir(target) and _is_git_repo(target):
        print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
        try:
            subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
            subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
        except Exception as e:
            print(f"[QIE] git pull failed: {e}")
        return
    if os.path.exists(target) and not _is_git_repo(target):
        print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
        return
    print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
    try:
        subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
        print("[QIE] Clone completed.")
    except subprocess.CalledProcessError as e:
        print(f"[QIE] Clone failed at {target}: {e}")
        # Last-chance fallback into home
        if not target.startswith(os.path.expanduser("~")):
            fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner")
            print(f"[QIE] Retrying clone into {fallback}...")
            try:
                subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True)
                MUSUBI_TUNER_DIR_RUNTIME = fallback
                print("[QIE] Clone completed in fallback.")
            except Exception as e2:
                print(f"[QIE] Clone failed in fallback as well: {e2}")

def _run_pip(args: List[str], cwd: Optional[str] = None) -> None:
    cmd = [sys.executable, "-m", "pip"] + args
    try:
        print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})")
        subprocess.run(cmd, check=True, cwd=cwd)
    except subprocess.CalledProcessError as e:
        print(f"[QIE] pip failed: {e}")

def _startup_install_musubi_deps() -> None:
    repo_dir = MUSUBI_TUNER_DIR_RUNTIME
    if not os.path.isdir(repo_dir):
        print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}")
        return
    # Upgrade basic build tooling (best-effort)
    try:
        _run_pip(["install", "-U", "pip", "setuptools", "wheel"])
    except Exception:
        pass
    # Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128
    extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip()
    editable_spec = "." if not extra else f".[{extra}]"
    # Install musubi-tuner in editable mode to expose entrypoints and deps
    try:
        _run_pip(["install", "-e", editable_spec], cwd=repo_dir)
    except Exception:
        # Fallback: plain install without editable
        try:
            _run_pip(["install", editable_spec], cwd=repo_dir)
        except Exception:
            print("[QIE] WARN: musubi-tuner installation failed. Continuing.")

def _stop_active_training() -> None:
    """Hard stop (Ubuntu): terminate the process group of the running training process."""
    with _ACTIVE_LOCK:
        proc = _ACTIVE_PROC
        pgid = _ACTIVE_PGID
    if not proc:
        return
    try:
        if pgid is not None:
            os.killpg(pgid, signal.SIGTERM)
        else:
            os.kill(proc.pid, signal.SIGTERM)
    except Exception:
        pass
    try:
        proc.wait(timeout=5)
    except Exception:
        try:
            if pgid is not None:
                os.killpg(pgid, signal.SIGKILL)
            else:
                os.kill(proc.pid, signal.SIGKILL)
        except Exception:
            pass
    return


def run_training(
    output_name: str,
    caption: str,
    image_uploads: Any,
    target_prefix: str,
    target_suffix: str,
    control0_uploads: Any,
    ctrl0_prefix: str,
    ctrl0_suffix: str,
    control1_uploads: Any,
    ctrl1_prefix: str,
    ctrl1_suffix: str,
    control2_uploads: Any,
    ctrl2_prefix: str,
    ctrl2_suffix: str,
    control3_uploads: Any,
    ctrl3_prefix: str,
    ctrl3_suffix: str,
    control4_uploads: Any,
    ctrl4_prefix: str,
    ctrl4_suffix: str,
    control5_uploads: Any,
    ctrl5_prefix: str,
    ctrl5_suffix: str,
    control6_uploads: Any,
    ctrl6_prefix: str,
    ctrl6_suffix: str,
    control7_uploads: Any,
    ctrl7_prefix: str,
    ctrl7_suffix: str,
    learning_rate: str,
    network_dim: int,
    train_res_w: int,
    train_res_h: int,
    train_batch_size: int,
    control_res_w: int,
    control_res_h: int,
    te_cache_batch_size: int,
    seed: int,
    max_epochs: int,
    save_every: int,
) -> Iterable[tuple]:
    global _ACTIVE_PROC, _ACTIVE_PGID
    # Basic validation
    log_buf = "[QIE] Start Training invoked.\n"
    ckpts: List[str] = []
    artifacts: List[str] = []
    # Emit an initial line so UI can confirm invocation
    yield (log_buf, ckpts, artifacts)
    if not output_name.strip():
        log_buf += "[ERROR] OUTPUT NAME is required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    if not caption.strip():
        log_buf += "[ERROR] CAPTION is required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    # Ensure /auto holds helper files expected by the script
    _ensure_workspace_auto_files()
    # Resolve data root and create dataset directories (auto-decide)
    global DATA_ROOT_RUNTIME
    DATA_ROOT_RUNTIME = _ensure_data_root(None)
    # Auto-generate dataset directory name
    import time
    ds_name = f"dataset_{int(time.time())}"
    ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
    img_folder_name = DEFAULT_IMAGE_FOLDER
    img_dir = os.path.join(ds_dir, img_folder_name)
    os.makedirs(img_dir, exist_ok=True)
    # Ingest uploads into dataset folders
    base_files = _extract_paths(image_uploads)
    if not base_files:
        log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
        yield (log_buf, ckpts, artifacts)
        return
    base_filenames = _copy_uploads(base_files, img_dir)
    log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
    yield (log_buf, ckpts, artifacts)
    # Prepare control sets
    control_upload_sets = [
        _extract_paths(control0_uploads),
        _extract_paths(control1_uploads),
        _extract_paths(control2_uploads),
        _extract_paths(control3_uploads),
        _extract_paths(control4_uploads),
        _extract_paths(control5_uploads),
        _extract_paths(control6_uploads),
        _extract_paths(control7_uploads),
    ]
    # Require control_0; others optional
    if not control_upload_sets[0]:
        log_buf += "[ERROR] control_0 images are required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    control_dirs: List[Optional[str]] = []
    for i, uploads in enumerate(control_upload_sets):
        if not uploads:
            control_dirs.append(None)
            continue
        folder_name = f"control_{i}"
        cdir = os.path.join(ds_dir, folder_name)
        os.makedirs(cdir, exist_ok=True)
        # Simply copy; name matching will be handled by create_image_caption_json.py
        _copy_uploads(uploads, cdir)
        control_dirs.append(folder_name)
        log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
        yield (log_buf, ckpts, artifacts)
    # metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
    # Prepare script with user parameters
    control_folders = [
        (control_dirs[i] if control_dirs[i] else None)
        for i in range(8)
    ]
    # Decide dataset_config path with fallback to runtime auto dir
    ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
    # Update dataset config with requested resolution/batch settings
    try:
        _update_dataset_toml(
            ds_conf,
            img_res_w=int(train_res_w) if train_res_w else None,
            img_res_h=int(train_res_h) if train_res_h else None,
            train_batch_size=int(train_batch_size) if train_batch_size else None,
            control_res_w=int(control_res_w) if control_res_w else None,
            control_res_h=int(control_res_h) if control_res_h else None,
        )
        log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
    except Exception as e:
        log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
    # Expose dataset config for download (if exists)
    if os.path.isfile(ds_conf):
        artifacts = [ds_conf]
    # Resolve models_root and set output_dir_base to the unique dataset dir
    models_root = MODELS_ROOT_RUNTIME
    out_base = ds_dir
    try:
        os.makedirs(out_base, exist_ok=True)
    except Exception:
        pass
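    # At this point the on-disk layout is roughly (folder names per the code above):
    #   <DATA_ROOT_RUNTIME>/dataset_<timestamp>/
    #       image/        target PNGs
    #       control_0/    required control PNGs
    #       control_1..7/ optional control PNGs
    #   metadata.jsonl is produced later by create_image_caption_json.py via train_QIE.sh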
    tmp_script = _prepare_script(
        dataset_name=ds_name,
        caption=caption,
        data_root=DATA_ROOT_RUNTIME,
        image_folder=img_folder_name,
        control_folders=control_folders,
        models_root=models_root,
        output_dir_base=out_base,
        dataset_config=ds_conf,
        override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
        override_save_every=save_every if save_every and save_every > 0 else None,
        override_run_name=output_name.strip(),
        target_prefix=(target_prefix or ""),
        target_suffix=(target_suffix or ""),
        control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
        control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
        override_learning_rate=(learning_rate or None),
        override_network_dim=int(network_dim) if network_dim is not None else None,
        override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
        override_seed=int(seed) if seed is not None else None,
    )
    shell = _pick_shell()
    log_buf += f"[QIE] Using shell: {shell}\n"
    log_buf += f"[QIE] Running script: {tmp_script}\n"
    out_dir = os.path.join(out_base, output_name.strip())
    ckpts = _list_checkpoints(out_dir)
    # Copy the final script to dataset dir for download
    used_script_path = os.path.join(out_base, "train_QIE_used.sh")
    try:
        shutil.copy2(str(tmp_script), used_script_path)
        try:
            os.chmod(used_script_path, 0o755)
        except Exception:
            pass
        if used_script_path not in artifacts:
            artifacts.append(used_script_path)
    except Exception:
        pass
    yield (log_buf, ckpts, artifacts)
    # Run and stream output
    # Ensure child Python processes are unbuffered for real-time logs
    child_env = os.environ.copy()
    child_env["PYTHONUNBUFFERED"] = "1"
    child_env["PYTHONIOENCODING"] = "utf-8"
    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        universal_newlines=True,
        env=child_env,
        preexec_fn=os.setsid,
    )
    # Register active process for hard stop
    with _ACTIVE_LOCK:
        _ACTIVE_PROC = proc
        try:
            _ACTIVE_PGID = os.getpgid(proc.pid)
        except Exception:
            _ACTIVE_PGID = None
    try:
        assert proc.stdout is not None
        i = 0
        for line in proc.stdout:
            log_buf += line
            i += 1
            if i % 30 == 0:
                ckpts = _list_checkpoints(out_dir)
                # Try to add metadata.jsonl once available
                metadata_json = os.path.join(out_base, "metadata.jsonl")
                if os.path.isfile(metadata_json) and metadata_json not in artifacts:
                    artifacts.append(metadata_json)
            yield (log_buf, ckpts, artifacts)
    finally:
        code = proc.wait()
        # Clear active process registration if this proc
        with _ACTIVE_LOCK:
            if _ACTIVE_PROC is proc:
                _ACTIVE_PROC = None
                _ACTIVE_PGID = None
        # Try to locate latest LoRA file for download
        lora_path = None
        try:
            ckpts = _list_checkpoints(out_dir)
        except Exception:
            pass
        lora_path = ckpts[0] if ckpts else None
        if code < 0:
            try:
                sig = -code
                log_buf += f"[QIE] Terminated by signal: {sig}\n"
            except Exception:
                log_buf += "[QIE] Terminated by signal.\n"
        log_buf += f"[QIE] Exit code: {code}\n"
        # Final attempt to include metadata.jsonl
        metadata_json = os.path.join(out_base, "metadata.jsonl")
        if os.path.isfile(metadata_json) and metadata_json not in artifacts:
            artifacts.append(metadata_json)
        yield (log_buf, ckpts, artifacts)

def build_ui() -> gr.Blocks:
    css = """
    .pad-section {
        padding: 6px;
        margin-bottom: 12px;
        border: 1px solid var(--color-border, #e5e7eb);
        border-radius: 8px;
        background: var(--color-background-secondary, #ffffff);
    }
    .pad-section_0 {
        padding: 6px;
        margin-bottom: 12px;
        border: 1px solid var(--color-border, #e5e7eb);
        border-radius: 8px;
        background: var(--color-background-secondary, #fafafa);
    }
    .pad-section_1 {
        padding: 6px;
        margin-bottom: 12px;
        border: 1px solid var(--color-border, #e5e7eb);
        border-radius: 8px;
        background: var(--color-background-secondary, #eaeaea);
    }
    """
    with gr.Blocks(title="Qwen-Image-Edit: Trainer", css=css) as demo:
        with gr.Tabs() as tabs:
            with gr.TabItem("Training"):
                gr.Markdown("""
# Qwen-Image-Edit Trainer
Upload the images to train on and, if needed, specify the shared leading/trailing
characters of their filenames (prefix/suffix). A dataset is then built automatically
and training starts. No complicated setup is required.
""")
                with gr.Accordion("Settings", elem_classes=["pad-section"]):
                    with gr.Group():
                        with gr.Row():
                            output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
                            caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
                        with gr.Row():
                            lr_input = gr.Textbox(label="Learning rate", value="1e-3")
                            dim_input = gr.Number(label="Network dim", value=4, precision=0)
                            train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
                            seed_input = gr.Number(label="Seed", value=42, precision=0)
                            max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
                            save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
                        with gr.Row():
                            tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
                            tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
                            cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
                            cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
                            te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)
                with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath", height=220, scale=3)
                            main_gallery = gr.Gallery(label="Target preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_")
                                    main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2")
                                with gr.Accordion("About prefix/suffix", open=False):
                                    gr.Markdown("""
Prefix/suffix strip the shared leading/trailing characters from image filenames so that files can be matched by name (e.g., IMG_ or _v2).
- First, the specified Target prefix/suffix is removed from each target image's filename (without extension); the result is used as the key.
- Each control set uses an additive rule: the expected name control_prefix_i + key + control_suffix_i + ".png" is looked up and paired with the target.
- Uploaded images are automatically converted to .png when saved (the original filename base is kept).
- Control 0 is required; Controls 1-7 are optional. If a control set contains only one image, it is applied to every target image.
""")
                # control_0 is required and shown outside the accordion
                with gr.Accordion("Control 0", elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_")
                                    ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask")
                # Optional controls start from 1, accordion closed by default
                with gr.Accordion("Control 1", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="")
                                    ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="")
                with gr.Accordion("Control 2", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="")
                                    ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="")
                with gr.Accordion("Control 3", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="")
                                    ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="")
                with gr.Accordion("Control 4", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="")
                                    ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="")
                with gr.Accordion("Control 5", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="")
                                    ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="")
                with gr.Accordion("Control 6", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="")
                                    ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="")
                with gr.Accordion("Control 7", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="")
                                    ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="")
                # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.
                run_btn = gr.Button("Start Training", variant="primary")
                logs = gr.Textbox(label="Logs", lines=20)
                ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
                scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)
                with gr.Row():
                    stop_btn = gr.Button("Stop Training", variant="secondary")
                    refresh_scripts_btn = gr.Button("Refresh Files", variant="secondary")
                # moved max_epochs/save_every above next to OUTPUT NAME
                # Wire previews
                images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
                ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
                ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
                ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
                ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery)
                ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery)
                ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery)
                ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery)
                ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery)
                run_btn.click(
                    fn=run_training,
                    inputs=[
                        output_name, caption, images_input, main_prefix, main_suffix,
                        ctrl0_files, ctrl0_prefix, ctrl0_suffix,
                        ctrl1_files, ctrl1_prefix, ctrl1_suffix,
                        ctrl2_files, ctrl2_prefix, ctrl2_suffix,
                        ctrl3_files, ctrl3_prefix, ctrl3_suffix,
                        ctrl4_files, ctrl4_prefix, ctrl4_suffix,
                        ctrl5_files, ctrl5_prefix, ctrl5_suffix,
                        ctrl6_files, ctrl6_prefix, ctrl6_suffix,
                        ctrl7_files, ctrl7_prefix, ctrl7_suffix,
                        lr_input, dim_input,
                        tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
                        seed_input, max_epochs, save_every,
                    ],
                    outputs=[logs, ckpt_files, scripts_files],
                )
                # Refresh button: re-collect checkpoints and scripts/config from the most recent dataset_ directory
                def _refresh_all() -> tuple:
                    try:
                        ds_dir = _find_latest_dataset_dir(DATA_ROOT_RUNTIME)
                    except Exception:
                        ds_dir = None
                    try:
                        ck = _list_checkpoints(ds_dir) if ds_dir else []
                    except Exception:
                        ck = []
                    try:
                        sc = _collect_scripts_and_config(ds_dir)
                    except Exception:
                        sc = _collect_scripts_and_config(None)
                    return ck, sc

                refresh_scripts_btn.click(
                    fn=_refresh_all,
                    inputs=[],
                    outputs=[ckpt_files, scripts_files],
                )

                # Hard stop button (Ubuntu): kill active process group
                def _on_stop_click():
                    _stop_active_training()
                    return

                stop_btn.click(fn=_on_stop_click, inputs=[], outputs=[])
            with gr.TabItem("Prompt Generator"):
                gr.Markdown("""
# 🎨 Automatic A→B Transformation Prompt Generation
Provide image A (input), image B (output), and optional notes; this tab automatically
generates an English prompt describing the A→B transformation and suggests three
candidate task names. The `gpt-5` model is used.
""")
                api_key_pg = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
                with gr.Row():
                    img_a_pg = gr.Image(type="filepath", label="Image A (Input)", height=300)
                    img_b_pg = gr.Image(type="filepath", label="Image B (Output)", height=300)
                notes_pg = gr.Textbox(label="Additional notes (Japanese OK)", lines=4, value="These images are only examples; make the prompt generic")
                want_japanese_pg = gr.Checkbox(label="Include Japanese translation", value=True)
                run_btn_pg = gr.Button("Generate", variant="primary")
                english_out_pg = gr.Textbox(label="English Prompt", lines=8)
                names_out_pg = gr.Textbox(label="Name Suggestions", lines=4)
                japanese_out_pg = gr.Textbox(label="Japanese translation (optional)", lines=8)

                def _on_click_prompt(api_key_in, a_path, b_path, notes_in, ja_flag):
                    # Lazy import to avoid constructing extra Blocks at startup
                    qpg = importlib.import_module("QIE_prompt_generator")
                    a_url = qpg.file_to_data_url(a_path) if a_path else None
                    b_url = qpg.file_to_data_url(b_path) if b_path else None
                    return qpg.call_openai_chat(api_key_in, a_url, b_url, notes_in, ja_flag)

                run_btn_pg.click(
                    fn=_on_click_prompt,
                    inputs=[api_key_pg, img_a_pg, img_b_pg, notes_pg, want_japanese_pg],
                    outputs=[english_out_pg, names_out_pg, japanese_out_pg],
                )
    return demo

def _startup_download_models() -> None:
    global MODELS_ROOT_RUNTIME
    # Pick a writable models directory
    candidate = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT)
    try:
        os.makedirs(candidate, exist_ok=True)
        MODELS_ROOT_RUNTIME = candidate
    except PermissionError:
        MODELS_ROOT_RUNTIME = os.path.join(os.path.expanduser("~"), "Qwen-Image_models")
        os.makedirs(MODELS_ROOT_RUNTIME, exist_ok=True)
    print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}")
    try:
        download_all_models(MODELS_ROOT_RUNTIME)
    except Exception as e:
        print(f"[QIE] Model download failed: {e}")

if __name__ == "__main__":
    # 1) Ensure musubi-tuner is cloned before anything else
    _startup_clone_musubi_tuner()
    # 1.1) Install musubi-tuner dependencies (best-effort)
    _startup_install_musubi_deps()
    # 2) Download models at startup (blocking by design)
    _startup_download_models()
    # 3) Launch Gradio app
    ui = build_ui()
    # Limit concurrency (training is heavy). Enable queue for Spaces compatibility.
    # Use generic signature to support multiple gradio versions.
    try:
        ui = ui.queue(max_size=16)
    except TypeError:
        ui = ui.queue()
    # Allow Gradio to serve files saved under our runtime dirs
    try:
        allowed = [
            AUTO_DIR_RUNTIME,
            os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"),
            DEFAULT_DATA_ROOT,
            DATA_ROOT_RUNTIME,
            os.path.join(os.path.expanduser("~"), "auto"),
            os.path.join(os.path.expanduser("~"), "data"),
        ]
        ui.launch(server_name="0.0.0.0", allowed_paths=allowed, ssr_mode=False)
    except TypeError:
        # Older gradio without allowed_paths
        try:
            ui.launch(server_name="0.0.0.0", ssr_mode=False)
        except TypeError:
            # Very old gradio without ssr_mode
            ui.launch(server_name="0.0.0.0")