#!/usr/bin/env python3
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Any, Tuple
import time
import json
import gradio as gr
import importlib
import spaces
import signal
from threading import Lock
# Local modules
from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
# Defaults matching train_QIE.sh expectations
DEFAULT_DATA_ROOT = "/data"
DEFAULT_IMAGE_FOLDER = "image"
DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA"
DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml"
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR # "/Qwen-Image_models"
WORKSPACE_AUTO_DIR = "/auto"
# musubi-tuner settings
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
"MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
)
TRAINING_DIR = Path(__file__).resolve().parent
# Runtime-resolved paths with fallbacks for non-root environments
MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR
MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT
AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR
DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT
# Active process management for hard stop (Ubuntu)
_ACTIVE_LOCK: Lock = Lock()
_ACTIVE_PROC: Optional[subprocess.Popen] = None
_ACTIVE_PGID: Optional[int] = None
def _stop_active_training() -> None:
    """Hard stop (Ubuntu): terminate the process group of the active training run."""
    with _ACTIVE_LOCK:
        proc = _ACTIVE_PROC
        pgid = _ACTIVE_PGID
    if not proc:
        return
    try:
        if pgid is not None:
            os.killpg(pgid, signal.SIGTERM)
        else:
            os.kill(proc.pid, signal.SIGTERM)
    except Exception:
        pass
    try:
        proc.wait(timeout=5)
    except Exception:
        try:
            if pgid is not None:
                os.killpg(pgid, signal.SIGKILL)
            else:
                os.kill(proc.pid, signal.SIGKILL)
        except Exception:
            pass
def _bash_quote(s: str) -> str:
"""Return a POSIX-safe single-quoted string literal representing s."""
if s is None:
return "''"
return "'" + str(s).replace("'", "'\"'\"'") + "'"
def _ensure_workspace_auto_files() -> None:
"""Ensure /workspace/auto has required helper files from this repo.
Copies training/create_image_caption_json.py and training/dataset_QIE.toml
into /workspace/auto so that train_QIE.sh can run unmodified.
"""
global AUTO_DIR_RUNTIME
try:
os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True)
except PermissionError:
home_auto = os.path.join(os.path.expanduser("~"), "auto")
os.makedirs(home_auto, exist_ok=True)
AUTO_DIR_RUNTIME = home_auto # type: ignore
src_py = TRAINING_DIR / "create_image_caption_json.py"
src_toml = TRAINING_DIR / "dataset_QIE.toml"
dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py"
dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"
try:
shutil.copy2(src_py, dst_py)
except Exception:
pass
try:
if src_toml.exists():
shutil.copy2(src_toml, dst_toml)
except Exception:
pass
def _update_dataset_toml(
path: str,
*,
img_res_w: Optional[int] = None,
img_res_h: Optional[int] = None,
train_batch_size: Optional[int] = None,
control_res_w: Optional[int] = None,
control_res_h: Optional[int] = None,
) -> None:
"""Update dataset TOML for resolution/batch/control resolution in-place.
- Updates [general] resolution and batch_size if provided.
- Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
- Creates sections/keys if missing.
"""
try:
txt = Path(path).read_text(encoding="utf-8")
except Exception:
return
    def _set_in_general(block: str, key: str, value_line: str) -> str:
        if re.search(rf"(?m)^\s*{re.escape(key)}\s*=", block):
            block = re.sub(rf"(?m)^\s*{re.escape(key)}\s*=.*$", value_line, block)
        else:
            block = block.rstrip() + "\n" + value_line + "\n"
        return block
m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
if not m:
gen = "[general]\n"
if img_res_w and img_res_h:
gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
if train_batch_size is not None:
gen += f"batch_size = {int(train_batch_size)}\n"
txt = gen + "\n" + txt
else:
head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
if img_res_w and img_res_h:
block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
if train_batch_size is not None:
block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
txt = head + block + tail
if control_res_w and control_res_h:
m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
if m2:
head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
else:
block = block.rstrip() + "\n" + line + "\n"
txt = head + block + tail
try:
Path(path).write_text(txt, encoding="utf-8")
except Exception:
pass
def _ensure_dir_writable(path: str) -> str:
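    """Create path (mkdir -p); on PermissionError, fall back to a directory of
    the same basename under the user's home."""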
try:
os.makedirs(path, exist_ok=True)
return path
except PermissionError:
home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\")))
os.makedirs(home_path, exist_ok=True)
return home_path
def _ensure_data_root(candidate: Optional[str]) -> str:
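    """Resolve the dataset root (default /data), creating it if needed; falls
    back to ~/data when the default is not writable."""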
root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT
try:
os.makedirs(root, exist_ok=True)
return root
except PermissionError:
home_root = os.path.join(os.path.expanduser("~"), "data")
os.makedirs(home_root, exist_ok=True)
return home_root
def _extract_paths(files: Any) -> List[Tuple[str, str]]:
"""Extract a list of (abs_path, orig_basename) from Gradio Files input.
Supports various gradio return shapes across versions.
"""
out: List[Tuple[str, str]] = []
if not files:
return out
# Gradio Files often returns a list
if isinstance(files, (list, tuple)):
items = files
else:
items = [files]
for item in items:
p: Optional[str] = None
orig: Optional[str] = None
# dict-like
if isinstance(item, dict):
p = item.get("path") or item.get("name") or item.get("file")
orig = item.get("orig_name") or item.get("name")
        else:
            # object with attributes
            p = getattr(item, "name", None) or getattr(item, "path", None) or str(item)
            # best-effort original name attribute
            orig = (getattr(item, "orig_name", None) or os.path.basename(p)) if p else None
if p:
abs_p = os.path.abspath(p)
out.append((abs_p, os.path.basename(orig or abs_p)))
return out
def _norm_key(filename: str, prefix: str, suffix: str) -> str:
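    # e.g. _norm_key("IMG_001_v2.png", "IMG_", "_v2") -> "001"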
stem = os.path.splitext(os.path.basename(filename))[0]
if prefix and stem.startswith(prefix):
stem = stem[len(prefix):]
if suffix and stem.endswith(suffix):
stem = stem[: -len(suffix)]
return stem
def _copy_uploads(
uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None
) -> List[str]:
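    """Copy uploads into dest_dir, converting to PNG where possible; target
    stems come from rename_to (if given) or the original names, de-duplicated
    within the batch. Returns the final file names used."""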
os.makedirs(dest_dir, exist_ok=True)
used_names: List[str] = []
for idx, (src, orig) in enumerate(uploads):
# Determine target stem
if rename_to and idx < len(rename_to):
stem = os.path.splitext(rename_to[idx])[0]
else:
stem = os.path.splitext(orig)[0]
dst_name = f"{stem}.png"
# ensure unique within this batch
final_name = dst_name
dup_idx = 1
while final_name in used_names:
final_name = f"{stem}_{dup_idx}.png"
dup_idx += 1
dst_path = os.path.join(dest_dir, final_name)
# Convert to PNG during save
try:
try:
from PIL import Image # type: ignore
with Image.open(src) as img:
img.save(dst_path, format="PNG")
except Exception:
# Fallback: copy then rename
shutil.copy2(src, dst_path)
except Exception:
# Last resort
shutil.copy(src, dst_path)
used_names.append(final_name)
return used_names
def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
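    """Return up to `limit` .safetensors paths under out_dir, newest first,
    skipping empty, very recently modified, or unreadable files that are
    likely still being written."""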
try:
if not out_dir or not os.path.isdir(out_dir):
return []
        now = time.time()
min_age_sec = 3.0 # treat files newer than this as possibly in-flight
items: List[Tuple[float, str]] = []
for root, _, files in os.walk(out_dir):
for fn in files:
if fn.lower().endswith('.safetensors'):
full = os.path.join(root, fn)
try:
# Skip zero-length, too-new, or unreadable files (likely in-flight)
size = os.path.getsize(full)
if size <= 0:
continue
mtime = os.path.getmtime(full)
if (now - mtime) < min_age_sec:
continue
# Try opening a small read to ensure readability
with open(full, 'rb') as rf:
rf.read(64)
items.append((mtime, full))
except Exception:
pass
items.sort(reverse=True)
return [p for _, p in items[:limit]]
except Exception:
return []
def _find_latest_dataset_dir(root: str) -> Optional[str]:
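    """Return the most recently modified dataset_* directory under root, or None."""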
try:
if not os.path.isdir(root):
return None
cand: List[Tuple[float, str]] = []
for name in os.listdir(root):
if not name.startswith("dataset_"):
continue
full = os.path.join(root, name)
if os.path.isdir(full):
try:
cand.append((os.path.getmtime(full), full))
except Exception:
pass
if not cand:
return None
cand.sort(reverse=True)
return cand[0][1]
except Exception:
return None
def _collect_scripts_and_config(ds_dir: Optional[str]) -> List[str]:
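    """Collect downloadable artifacts: the dataset config plus, when present,
    the exact training script used and metadata.jsonl from ds_dir."""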
files: List[str] = []
try:
ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
if os.path.isfile(ds_conf):
files.append(ds_conf)
if ds_dir and os.path.isdir(ds_dir):
used_script = os.path.join(ds_dir, "train_QIE_used.sh")
if os.path.isfile(used_script):
files.append(used_script)
meta = os.path.join(ds_dir, "metadata.jsonl")
if os.path.isfile(meta):
files.append(meta)
except Exception:
pass
return files
def _files_to_gallery(files: Any) -> List[str]:
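    """Normalize Gradio Files values (str / dict / object) into a list of file
    paths for gallery preview."""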
items: List[str] = []
if not files:
return items
seq = files if isinstance(files, (list, tuple)) else [files]
for f in seq:
p = None
if isinstance(f, str):
p = f
elif isinstance(f, dict):
p = f.get("path") or f.get("name")
else:
p = getattr(f, "path", None) or getattr(f, "name", None)
if p:
items.append(p)
return items
def _prepare_script(
dataset_name: str,
caption: str,
data_root: str,
image_folder: str,
control_folders: List[Optional[str]],
models_root: str,
output_dir_base: Optional[str] = None,
dataset_config: Optional[str] = None,
override_max_epochs: Optional[int] = None,
override_save_every: Optional[int] = None,
override_run_name: Optional[str] = None,
target_prefix: Optional[str] = None,
target_suffix: Optional[str] = None,
control_prefixes: Optional[List[Optional[str]]] = None,
control_suffixes: Optional[List[Optional[str]]] = None,
override_learning_rate: Optional[str] = None,
override_network_dim: Optional[int] = None,
override_seed: Optional[int] = None,
override_te_cache_bs: Optional[int] = None,
) -> Path:
"""Create a temporary copy of train_QIE.sh with injected variables.
Only variables that must vary per-run are replaced. The rest of the script
remains as-is to preserve behavior.
"""
src = TRAINING_DIR / "train_QIE.sh"
txt = src.read_text(encoding="utf-8")
# Replace core variables
replacements = {
r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}",
r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}",
r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}",
r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}",
}
if output_dir_base:
replacements[r"^OUTPUT_DIR_BASE=\".*\""] = (
f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}"
)
if dataset_config:
replacements[r"^DATASET_CONFIG=\".*\""] = (
f"DATASET_CONFIG={_bash_quote(dataset_config)}"
)
for pat, val in replacements.items():
txt = re.sub(pat, val, txt, flags=re.MULTILINE)
# Inject CONTROL_FOLDER_i if provided (uncomment/override or append)
for i in range(8):
val = control_folders[i] if i < len(control_folders) else None
if not val:
continue
# Try to replace commented placeholder first
pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
if re.search(pattern, txt, flags=re.MULTILINE):
txt = re.sub(
pattern,
f"CONTROL_FOLDER_{i}={_bash_quote(val)}",
txt,
flags=re.MULTILINE,
)
else:
# Append after IMAGE_FOLDER definition
txt = re.sub(
r"^(IMAGE_FOLDER=.*)$",
rf"\1\nCONTROL_FOLDER_{i}={_bash_quote(val)}",
txt,
count=1,
flags=re.MULTILINE,
)
# Point model paths to the selected models_root
def _replace_model_path(txt: str, key: str, rel: str) -> str:
return re.sub(
rf"--{key} \"[^\"]+\"",
f"--{key} \"{models_root.rstrip('/')}/{rel}\"",
txt,
)
txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors")
txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
# Replace working dir for metadata generation to runtime /auto
txt = re.sub(r"^cd\s+/workspace/auto\s*$", f"cd {AUTO_DIR_RUNTIME}", txt, flags=re.MULTILINE)
# Ensure musubi-tuner path matches runtime location
txt = re.sub(r"^cd\s+/musubi-tuner\s*$", f"cd {re.escape(MUSUBI_TUNER_DIR_RUNTIME)}", txt, flags=re.MULTILINE)
# ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
# Run the training module directly in-process so GPU stays attached
# to the same Python request context.
txt = re.sub(
r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
r"python -u src/musubi_tuner/qwen_image_train_network.py",
txt,
flags=re.MULTILINE,
)
# Optionally override epochs and save frequency for ZeroGPU time slicing
if override_max_epochs is not None and override_max_epochs > 0:
txt = re.sub(r"--max_train_epochs\s+\d+",
f"--max_train_epochs {override_max_epochs}", txt)
if override_save_every is not None and override_save_every > 0:
txt = re.sub(r"--save_every_n_epochs\s+\d+",
f"--save_every_n_epochs {override_save_every}", txt)
if override_run_name:
txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
# Inject prefix/suffix flags for metadata creation
extra_lines: List[str] = []
if (target_prefix or ""):
extra_lines.append(f" --target_prefix {_bash_quote(target_prefix)} \\")
if (target_suffix or ""):
extra_lines.append(f" --target_suffix {_bash_quote(target_suffix)} \\")
for i in range(8):
pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
if pre:
extra_lines.append(f" --control_prefix_{i} {_bash_quote(pre)} \\")
if suf:
extra_lines.append(f" --control_suffix_{i} {_bash_quote(suf)} \\")
if extra_lines:
extra_block = "\n".join(extra_lines)
# Insert extra flags just before the CONTROL_ARGS line, preserving indentation.
txt = re.sub(
r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"',
lambda m: f"{extra_block}\n{m.group(1)}\"${{CONTROL_ARGS[@]}}\"",
txt,
flags=re.MULTILINE,
)
# Override CLI hyperparameters if provided
if override_learning_rate:
txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt)
if override_network_dim is not None:
txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt)
if override_seed is not None:
txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
# Optionally override text-encoder cache batch size
if override_te_cache_bs is not None and override_te_cache_bs > 0:
txt = re.sub(
r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
rf"\g<1>{int(override_te_cache_bs)}",
txt,
flags=re.MULTILINE,
)
# Prefer overriding variable definitions at top of script (safer than CLI regex)
def _set_var(name: str, value: str) -> None:
nonlocal txt
pattern = rf"(?m)^\s*{name}\s*=.*$"
replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
if re.search(pattern, txt):
txt = re.sub(pattern, replacement, txt)
else:
txt = f"{replacement}\n" + txt
if override_learning_rate:
_set_var('LEARNING_RATE', override_learning_rate)
if override_network_dim is not None:
_set_var('NETWORK_DIM', str(override_network_dim))
if override_seed is not None:
_set_var('SEED', str(override_seed))
if override_max_epochs is not None and override_max_epochs > 0:
_set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
if override_save_every is not None and override_save_every > 0:
_set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
# Write to a temp file alongside this repo for easier inspection
run_dir = TRAINING_DIR / ".gradio_runs"
run_dir.mkdir(parents=True, exist_ok=True)
tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh"
tmp.write_text(txt, encoding="utf-8", newline="\n")
try:
os.chmod(tmp, 0o755)
except Exception:
pass
return tmp
def _pick_shell() -> str:
for sh in ("bash", "sh"):
if shutil.which(sh):
return sh
raise RuntimeError("No POSIX shell found. Please install bash or sh.")
def _is_git_repo(path: str) -> bool:
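    """Return True if path is inside a git work tree (via `git rev-parse`)."""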
try:
out = subprocess.run(
["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
capture_output=True,
text=True,
check=False,
)
return out.returncode == 0 and out.stdout.strip() == "true"
except Exception:
return False
def _startup_clone_musubi_tuner() -> None:
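    """Clone musubi-tuner at startup, or pull if it already exists; falls back
    to a clone under the user's home when the default location is unwritable."""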
global MUSUBI_TUNER_DIR_RUNTIME
target = MUSUBI_TUNER_DIR_RUNTIME
repo = DEFAULT_MUSUBI_TUNER_REPO
parent = os.path.dirname(target.rstrip("/\\")) or "/"
try:
os.makedirs(parent, exist_ok=True)
except PermissionError:
# Fallback to home directory
target = os.path.join(os.path.expanduser("~"), "musubi-tuner")
MUSUBI_TUNER_DIR_RUNTIME = target
os.makedirs(os.path.dirname(target), exist_ok=True)
except Exception:
pass
if os.path.isdir(target) and _is_git_repo(target):
print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
try:
subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
except Exception as e:
print(f"[QIE] git pull failed: {e}")
return
if os.path.exists(target) and not _is_git_repo(target):
print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
return
print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
try:
subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
print("[QIE] Clone completed.")
except subprocess.CalledProcessError as e:
print(f"[QIE] Clone failed at {target}: {e}")
# Last-chance fallback into home
if not target.startswith(os.path.expanduser("~")):
fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner")
print(f"[QIE] Retrying clone into {fallback}...")
try:
subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True)
MUSUBI_TUNER_DIR_RUNTIME = fallback
print("[QIE] Clone completed in fallback.")
except Exception as e2:
print(f"[QIE] Clone failed in fallback as well: {e2}")
def _run_pip(args: List[str], cwd: Optional[str] = None) -> None:
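    """Run `python -m pip <args>`; pip failures are logged rather than raised."""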
cmd = [sys.executable, "-m", "pip"] + args
try:
print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})")
subprocess.run(cmd, check=True, cwd=cwd)
except subprocess.CalledProcessError as e:
print(f"[QIE] pip failed: {e}")
def _startup_install_musubi_deps() -> None:
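    """Best-effort install of musubi-tuner in editable mode (optionally with a
    torch extra selected via MUSUBI_TUNER_TORCH_EXTRA) plus basic build tooling."""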
repo_dir = MUSUBI_TUNER_DIR_RUNTIME
if not os.path.isdir(repo_dir):
print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}")
return
# Upgrade basic build tooling (best-effort)
try:
_run_pip(["install", "-U", "pip", "setuptools", "wheel"])
except Exception:
pass
# Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128
extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip()
editable_spec = "." if not extra else f".[{extra}]"
# Install musubi-tuner in editable mode to expose entrypoints and deps
try:
_run_pip(["install", "-e", editable_spec], cwd=repo_dir)
except Exception:
# Fallback: plain install without editable
try:
_run_pip(["install", editable_spec], cwd=repo_dir)
except Exception:
print("[QIE] WARN: musubi-tuner installation failed. Continuing.")
@spaces.GPU
def run_training(
output_name: str,
caption: str,
image_uploads: Any,
target_prefix: str,
target_suffix: str,
control0_uploads: Any,
ctrl0_prefix: str,
ctrl0_suffix: str,
control1_uploads: Any,
ctrl1_prefix: str,
ctrl1_suffix: str,
control2_uploads: Any,
ctrl2_prefix: str,
ctrl2_suffix: str,
control3_uploads: Any,
ctrl3_prefix: str,
ctrl3_suffix: str,
control4_uploads: Any,
ctrl4_prefix: str,
ctrl4_suffix: str,
control5_uploads: Any,
ctrl5_prefix: str,
ctrl5_suffix: str,
control6_uploads: Any,
ctrl6_prefix: str,
ctrl6_suffix: str,
control7_uploads: Any,
ctrl7_prefix: str,
ctrl7_suffix: str,
learning_rate: str,
network_dim: int,
train_res_w: int,
train_res_h: int,
train_batch_size: int,
control_res_w: int,
control_res_h: int,
te_cache_batch_size: int,
seed: int,
max_epochs: int,
save_every: int,
) -> Iterable[tuple]:
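    """Build the dataset from uploads, render train_QIE.sh, and run it.

    Generator used for Gradio streaming: yields (log_text, checkpoint_paths,
    artifact_paths) after each preparation step and as training output arrives.
    """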
global _ACTIVE_PROC, _ACTIVE_PGID
# Basic validation
log_buf = "[QIE] Start Training invoked.\n"
ckpts: List[str] = []
artifacts: List[str] = []
# Emit an initial line so UI can confirm invocation
yield (log_buf, ckpts, artifacts)
    if not output_name.strip():
        log_buf += "[ERROR] OUTPUT NAME is required.\n"
        yield (log_buf, ckpts, artifacts)
        return
if not caption.strip():
log_buf += "[ERROR] CAPTION is required.\n"
yield (log_buf, ckpts, artifacts)
return
# Ensure /auto holds helper files expected by the script
_ensure_workspace_auto_files()
# Resolve data root and create dataset directories (auto-decide)
global DATA_ROOT_RUNTIME
DATA_ROOT_RUNTIME = _ensure_data_root(None)
# Auto-generate dataset directory name
    ds_name = f"dataset_{int(time.time())}"
ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
img_folder_name = DEFAULT_IMAGE_FOLDER
img_dir = os.path.join(ds_dir, img_folder_name)
os.makedirs(img_dir, exist_ok=True)
# Ingest uploads into dataset folders
base_files = _extract_paths(image_uploads)
if not base_files:
log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
yield (log_buf, ckpts, artifacts)
return
base_filenames = _copy_uploads(base_files, img_dir)
log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
yield (log_buf, ckpts, artifacts)
# Prepare control sets
control_upload_sets = [
_extract_paths(control0_uploads),
_extract_paths(control1_uploads),
_extract_paths(control2_uploads),
_extract_paths(control3_uploads),
_extract_paths(control4_uploads),
_extract_paths(control5_uploads),
_extract_paths(control6_uploads),
_extract_paths(control7_uploads),
]
# Require control_0; others optional
if not control_upload_sets[0]:
log_buf += "[ERROR] control_0 images are required.\n"
yield (log_buf, ckpts, artifacts)
return
control_dirs: List[Optional[str]] = []
for i, uploads in enumerate(control_upload_sets):
if not uploads:
control_dirs.append(None)
continue
folder_name = f"control_{i}"
cdir = os.path.join(ds_dir, folder_name)
os.makedirs(cdir, exist_ok=True)
# Simply copy; name matching will be handled by create_image_caption_json.py
_copy_uploads(uploads, cdir)
control_dirs.append(folder_name)
log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
yield (log_buf, ckpts, artifacts)
# Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
# Prepare script with user parameters
control_folders = [
(control_dirs[i] if control_dirs[i] else None)
for i in range(8)
]
# Decide dataset_config path with fallback to runtime auto dir
ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
# Update dataset config with requested resolution/batch settings
try:
_update_dataset_toml(
ds_conf,
img_res_w=int(train_res_w) if train_res_w else None,
img_res_h=int(train_res_h) if train_res_h else None,
train_batch_size=int(train_batch_size) if train_batch_size else None,
control_res_w=int(control_res_w) if control_res_w else None,
control_res_h=int(control_res_h) if control_res_h else None,
)
log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
except Exception as e:
log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
# Expose dataset config for download (if exists)
if os.path.isfile(ds_conf):
artifacts = [ds_conf]
# Resolve models_root and set output_dir_base to the unique dataset dir
models_root = MODELS_ROOT_RUNTIME
out_base = ds_dir
try:
os.makedirs(out_base, exist_ok=True)
except Exception:
pass
tmp_script = _prepare_script(
dataset_name=ds_name,
caption=caption,
data_root=DATA_ROOT_RUNTIME,
image_folder=img_folder_name,
control_folders=control_folders,
models_root=models_root,
output_dir_base=out_base,
dataset_config=ds_conf,
override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
override_save_every=save_every if save_every and save_every > 0 else None,
override_run_name=output_name.strip(),
target_prefix=(target_prefix or ""),
target_suffix=(target_suffix or ""),
control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
override_learning_rate=(learning_rate or None),
override_network_dim=int(network_dim) if network_dim is not None else None,
override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
override_seed=int(seed) if seed is not None else None,
)
shell = _pick_shell()
log_buf += f"[QIE] Using shell: {shell}\n"
log_buf += f"[QIE] Running script: {tmp_script}\n"
out_dir = os.path.join(out_base, output_name.strip())
ckpts = _list_checkpoints(out_dir)
# Copy the final script to dataset dir for download
used_script_path = os.path.join(out_base, "train_QIE_used.sh")
try:
shutil.copy2(str(tmp_script), used_script_path)
try:
os.chmod(used_script_path, 0o755)
except Exception:
pass
if used_script_path not in artifacts:
artifacts.append(used_script_path)
except Exception:
pass
yield (log_buf, ckpts, artifacts)
# Run and stream output
# Ensure child Python processes are unbuffered for real-time logs
child_env = os.environ.copy()
child_env["PYTHONUNBUFFERED"] = "1"
child_env["PYTHONIOENCODING"] = "utf-8"
    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        env=child_env,
        preexec_fn=os.setsid,
    )
# Register active process for hard stop
with _ACTIVE_LOCK:
_ACTIVE_PROC = proc
try:
_ACTIVE_PGID = os.getpgid(proc.pid)
except Exception:
_ACTIVE_PGID = None
try:
assert proc.stdout is not None
i = 0
for line in proc.stdout:
log_buf += line
i += 1
if i % 30 == 0:
ckpts = _list_checkpoints(out_dir)
# Try to add metadata.jsonl once available
metadata_json = os.path.join(out_base, "metadata.jsonl")
if os.path.isfile(metadata_json) and metadata_json not in artifacts:
artifacts.append(metadata_json)
yield (log_buf, ckpts, artifacts)
finally:
code = proc.wait()
# Clear active process registration if this proc
with _ACTIVE_LOCK:
if _ACTIVE_PROC is proc:
_ACTIVE_PROC = None
_ACTIVE_PGID = None
# Try to locate latest LoRA file for download
lora_path = None
try:
ckpts = _list_checkpoints(out_dir)
except Exception:
pass
lora_path = ckpts[0] if ckpts else None
        if code < 0:
            log_buf += f"[QIE] Terminated by signal: {-code}\n"
log_buf += f"[QIE] Exit code: {code}\n"
# Final attempt to include metadata.jsonl
metadata_json = os.path.join(out_base, "metadata.jsonl")
if os.path.isfile(metadata_json) and metadata_json not in artifacts:
artifacts.append(metadata_json)
yield (log_buf, ckpts, artifacts)
def build_ui() -> gr.Blocks:
css = """
.pad-section {
padding: 6px;
margin-bottom: 12px;
border: 1px solid var(--color-border, #e5e7eb);
border-radius: 8px;
background: var(--color-background-secondary, #ffffff);
}
.pad-section_0 {
padding: 6px;
margin-bottom: 12px;
border: 1px solid var(--color-border, #e5e7eb);
border-radius: 8px;
background: var(--color-background-secondary, #fafafa);
}
.pad-section_1 {
padding: 6px;
margin-bottom: 12px;
border: 1px solid var(--color-border, #e5e7eb);
border-radius: 8px;
background: var(--color-background-secondary, #eaeaea);
}
"""
with gr.Blocks(title="Qwen-Image-Edit: Trainer", css=css) as demo:
with gr.Tabs() as tabs:
with gr.TabItem("Training"):
gr.Markdown("""
# Qwen-Image-Edit Trainer
                    Upload the images to train on and, if needed, specify the shared characters at the
                    start/end of their file names (prefix/suffix); the dataset is then built automatically
                    and training starts. No complicated steps are required.
""")
with gr.Accordion("Settings", elem_classes=["pad-section"]):
with gr.Group():
with gr.Row():
output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
with gr.Row():
lr_input = gr.Textbox(label="Learning rate", value="1e-3")
dim_input = gr.Number(label="Network dim", value=4, precision=0)
train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
seed_input = gr.Number(label="Seed", value=42, precision=0)
max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
with gr.Row():
tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)
with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
with gr.Group():
with gr.Row():
images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath", height=220, scale=3)
main_gallery = gr.Gallery(label="Target preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_")
main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2")
with gr.Accordion("prefix/sufixについて", open=False):
gr.Markdown("""
ファイルの同名判定のため、画像のファイル名から共通の先頭/末尾文字を取り除く指定(例: IMG_ や _v2)
- まずターゲット画像のファイル名(拡張子なし)から、指定した Target prefix/suffix を取り除いたものを key とします。
- 各コントロールは「付加」規則で、期待名 = control_prefix_i + key + control_suffix_i + ".png" を探して対応付けます。
- アップロード時に画像は自動で .png に変換して保存します(元のファイル名のベースは維持)。
- Control 0 は必須、Control 1〜7 は任意。コントロール画像が1枚だけのときは、すべてのターゲット画像に適用します。
""")
# control_0 is required and shown outside the accordion
with gr.Accordion("Control 0", elem_classes=["pad-section_1"]):
with gr.Group():
with gr.Row():
ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath", height=220, scale=3)
ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_")
ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask")
# Optional controls start from 1, accordion closed by default
with gr.Accordion("Control 1", open=False, elem_classes=["pad-section_0"]):
with gr.Group():
with gr.Row():
ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="")
ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="")
with gr.Accordion("Control 2", open=False, elem_classes=["pad-section_1"]):
with gr.Group():
with gr.Row():
ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="")
ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="")
with gr.Accordion("Control 3", open=False, elem_classes=["pad-section_0"]):
with gr.Group():
with gr.Row():
ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="")
ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="")
with gr.Accordion("Control 4", open=False, elem_classes=["pad-section_1"]):
with gr.Group():
with gr.Row():
ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="")
ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="")
with gr.Accordion("Control 5", open=False, elem_classes=["pad-section_0"]):
with gr.Group():
with gr.Row():
ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="")
ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="")
with gr.Accordion("Control 6", open=False, elem_classes=["pad-section_1"]):
with gr.Group():
with gr.Row():
ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="")
ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="")
with gr.Accordion("Control 7", open=False, elem_classes=["pad-section_0"]):
with gr.Group():
with gr.Row():
ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath", height=220, scale=3)
ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
with gr.Column(scale=1):
with gr.Row():
ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="")
ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="")
# Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.
run_btn = gr.Button("Start Training", variant="primary")
logs = gr.Textbox(label="Logs", lines=20)
ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)
with gr.Row():
                    stop_btn = gr.Button("Stop Training", variant="secondary")
                    refresh_scripts_btn = gr.Button("Refresh Files", variant="secondary")
# moved max_epochs/save_every above next to OUTPUT NAME
# Wire previews
images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery)
ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery)
ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery)
ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery)
ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery)
run_btn.click(
fn=run_training,
inputs=[
output_name, caption, images_input, main_prefix, main_suffix,
ctrl0_files, ctrl0_prefix, ctrl0_suffix,
ctrl1_files, ctrl1_prefix, ctrl1_suffix,
ctrl2_files, ctrl2_prefix, ctrl2_suffix,
ctrl3_files, ctrl3_prefix, ctrl3_suffix,
ctrl4_files, ctrl4_prefix, ctrl4_suffix,
ctrl5_files, ctrl5_prefix, ctrl5_suffix,
ctrl6_files, ctrl6_prefix, ctrl6_suffix,
ctrl7_files, ctrl7_prefix, ctrl7_suffix,
lr_input, dim_input,
tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
seed_input, max_epochs, save_every,
],
outputs=[logs, ckpt_files, scripts_files],
)
                # Refresh button: re-fetch checkpoints and scripts/config from the most recent dataset_ directory
def _refresh_all() -> tuple:
try:
ds_dir = _find_latest_dataset_dir(DATA_ROOT_RUNTIME)
except Exception:
ds_dir = None
try:
ck = _list_checkpoints(ds_dir) if ds_dir else []
except Exception:
ck = []
try:
sc = _collect_scripts_and_config(ds_dir)
except Exception:
sc = _collect_scripts_and_config(None)
return ck, sc
refresh_scripts_btn.click(
fn=_refresh_all,
inputs=[],
outputs=[ckpt_files, scripts_files],
)
# Hard stop button (Ubuntu): kill active process group
def _on_stop_click():
_stop_active_training()
return
stop_btn.click(fn=_on_stop_click, inputs=[], outputs=[])
with gr.TabItem("Prompt Generator"):
gr.Markdown("""
                # 🎨 Automatic A→B Transformation Prompt Generation
                Provide image A (input), image B (output), and optional notes; the app
                generates an English prompt describing the A→B transformation and suggests
                three task-name candidates. The `gpt-5` model is used.
""")
api_key_pg = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
with gr.Row():
img_a_pg = gr.Image(type="filepath", label="Image A (Input)", height=300)
img_b_pg = gr.Image(type="filepath", label="Image B (Output)", height=300)
                notes_pg = gr.Textbox(label="Notes (Japanese OK)", lines=4, value="These images are examples; keep the prompt generic")
                want_japanese_pg = gr.Checkbox(label="Include Japanese translation", value=True)
                run_btn_pg = gr.Button("Generate", variant="primary")
english_out_pg = gr.Textbox(label="English Prompt", lines=8)
names_out_pg = gr.Textbox(label="Name Suggestions", lines=4)
                japanese_out_pg = gr.Textbox(label="Japanese translation (optional)", lines=8)
def _on_click_prompt(api_key_in, a_path, b_path, notes_in, ja_flag):
# Lazy import to avoid constructing extra Blocks at startup
qpg = importlib.import_module("QIE_prompt_generator")
a_url = qpg.file_to_data_url(a_path) if a_path else None
b_url = qpg.file_to_data_url(b_path) if b_path else None
return qpg.call_openai_chat(api_key_in, a_url, b_url, notes_in, ja_flag)
run_btn_pg.click(
fn=_on_click_prompt,
inputs=[api_key_pg, img_a_pg, img_b_pg, notes_pg, want_japanese_pg],
outputs=[english_out_pg, names_out_pg, japanese_out_pg],
)
return demo
def _startup_download_models() -> None:
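    """Pick a writable models directory (QWEN_IMAGE_MODELS_DIR or the default)
    and download all required model weights into it."""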
global MODELS_ROOT_RUNTIME
# Pick a writable models directory
candidate = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT)
try:
os.makedirs(candidate, exist_ok=True)
MODELS_ROOT_RUNTIME = candidate
except PermissionError:
MODELS_ROOT_RUNTIME = os.path.join(os.path.expanduser("~"), "Qwen-Image_models")
os.makedirs(MODELS_ROOT_RUNTIME, exist_ok=True)
print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}")
try:
download_all_models(MODELS_ROOT_RUNTIME)
except Exception as e:
print(f"[QIE] Model download failed: {e}")
if __name__ == "__main__":
# 1) Ensure musubi-tuner is cloned before anything else
_startup_clone_musubi_tuner()
# 1.1) Install musubi-tuner dependencies (best-effort)
_startup_install_musubi_deps()
# 2) Download models at startup (blocking by design)
_startup_download_models()
# 3) Launch Gradio app
ui = build_ui()
# Limit concurrency (training is heavy). Enable queue for Spaces compatibility.
# Use generic signature to support multiple gradio versions.
try:
ui = ui.queue(max_size=16)
except TypeError:
ui = ui.queue()
# Allow Gradio to serve files saved under our runtime dirs
try:
allowed = [
AUTO_DIR_RUNTIME,
os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"),
DEFAULT_DATA_ROOT,
DATA_ROOT_RUNTIME,
os.path.join(os.path.expanduser("~"), "auto"),
os.path.join(os.path.expanduser("~"), "data"),
]
ui.launch(server_name="0.0.0.0", allowed_paths=allowed, ssr_mode=False)
except TypeError:
# Older gradio without allowed_paths
try:
ui.launch(server_name="0.0.0.0", ssr_mode=False)
except TypeError:
# Very old gradio without ssr_mode
ui.launch(server_name="0.0.0.0")