Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	File size: 51,702 Bytes
			
			| 26b93ae 517e4c9 fc9a363 873a3a3 26b93ae 46b0d09 5f9199a 21d8059 5f9199a 26b93ae b742d84 26b93ae b742d84 26b93ae 6b04281 26b93ae 06aa83a 517e4c9 06aa83a 21d8059 26b93ae 7c1bc29 06aa83a 26b93ae 06aa83a 26b93ae 325c528 517e4c9 873a3a3 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 fc9a363 b9b8012 fc9a363 b9b8012 c02f846 b9b8012 fc9a363 67b33c8 28ba4eb 26b93ae 94acb06 f03ecf2 28ba4eb a7503dc 7c9164c 325c528 f03ecf2 26b93ae 06aa83a b742d84 94acb06 fd3ce40 94acb06 f03ecf2 94acb06 a7503dc 28ba4eb a7357ae 28ba4eb a7357ae a7503dc a7357ae a7503dc 7c9164c 325c528 0a49e69 26b93ae 6b04281 06aa83a 6b04281 06aa83a 6b04281 06aa83a 6b04281 86a20eb 873a3a3 26b93ae f03ecf2 26b93ae 517e4c9 28ba4eb 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 517e4c9 a7503dc 6fabdaf 325c528 6fabdaf 94acb06 59635a0 96f0e61 26b93ae bb4717a fc9a363 4cf412e bb4717a f03ecf2 fc9a363 4cf412e 21d8059 26b93ae fc9a363 4cf412e 26b93ae 7c1bc29 26b93ae 7c1bc29 517e4c9 7c1bc29 f03ecf2 517e4c9 7c1bc29 517e4c9 fc9a363 4cf412e 517e4c9 3d6f715 4cf412e 517e4c9 2f0d2a8 fc9a363 4cf412e 2f0d2a8 517e4c9 a170a6e 517e4c9 7c1bc29 517e4c9 a7503dc 517e4c9 3d6f715 4cf412e 26b93ae a7503dc 26b93ae 517e4c9 26b93ae 517e4c9 59635a0 325c528 4cf412e 325c528 3e32a9e 59635a0 3e32a9e 59635a0 517e4c9 26b93ae 517e4c9 26b93ae 517e4c9 26b93ae 59635a0 517e4c9 94acb06 f03ecf2 28ba4eb a7503dc 6fabdaf 325c528 6fabdaf 26b93ae 94acb06 26b93ae 3d6f715 fc9a363 4cf412e 26b93ae fd3ce40 26b93ae fd3ce40 21d8059 26b93ae 21d8059 26b93ae fc9a363 26b93ae 3d6f715 fc9a363 4cf412e 26b93ae 21d8059 59635a0 fc9a363 59635a0 fc9a363 21d8059 3d6f715 4cf412e 26b93ae df2aec8 2065727 df2aec8 2065727 df2aec8 46b0d09 f2d2b36 fe91d7e 46b0d09 325c528 46b0d09 325c528 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 7fc39ca 46b0d09 4cf412e 67b33c8 21d8059 67b33c8 
46b0d09 325c528 46b0d09 4cf412e 46b0d09 67b33c8 08d5f7b 67b33c8 08d5f7b 67b33c8 08d5f7b 67b33c8 08d5f7b 67b33c8 46b0d09 21d8059 46b0d09 26b93ae 06aa83a 26b93ae 06aa83a 26b93ae 31c43c1 86a20eb 6b04281 31c43c1 26b93ae 6b04281 26b93ae 94acb06 06aa83a 3e32a9e fd3ce40 3e32a9e fd3ce40 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 
457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 
957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 | #!/usr/bin/env python3
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Any, Tuple
import time
import json
import gradio as gr
import importlib
import spaces
import signal
from threading import Lock
# Local modules
from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
# Defaults matching train_QIE.sh expectations
DEFAULT_DATA_ROOT = "/data"                    # where per-run dataset_* dirs are created
DEFAULT_IMAGE_FOLDER = "image"                 # target-image subfolder name inside a dataset dir
DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA"   # base dir for LoRA training outputs
DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml"  # dataset TOML consumed by musubi-tuner
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR  # "/Qwen-Image_models"
WORKSPACE_AUTO_DIR = "/auto"                   # helper-scripts dir expected by train_QIE.sh
# musubi-tuner settings (overridable via environment)
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
    "MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
)
# Directory containing this script (and the train_QIE.sh template).
TRAINING_DIR = Path(__file__).resolve().parent
# Runtime-resolved paths with fallbacks for non-root environments.
# These are rebound at startup/run time when the defaults are not writable.
MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR
MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT
AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR
DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT
# Active process management for hard stop (Ubuntu).
# _ACTIVE_LOCK guards reads/writes of the Popen handle and its process group id.
_ACTIVE_LOCK: Lock = Lock()
_ACTIVE_PROC: Optional[subprocess.Popen] = None
_ACTIVE_PGID: Optional[int] = None
def _bash_quote(s: str) -> str:
    """Return a POSIX-safe single-quoted string literal representing s."""
    if s is None:
        return "''"
    return "'" + str(s).replace("'", "'\"'\"'") + "'"
def _ensure_workspace_auto_files() -> None:
    """Ensure the runtime auto dir holds the helper files train_QIE.sh expects.

    Copies training/create_image_caption_json.py and training/dataset_QIE.toml
    from this repo into AUTO_DIR_RUNTIME (default "/auto") so that train_QIE.sh
    can run unmodified. If the default location is not writable, falls back to
    ~/auto and rebinds the module-global AUTO_DIR_RUNTIME.
    """
    global AUTO_DIR_RUNTIME
    try:
        os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True)
    except PermissionError:
        # Non-root environments: use a per-user auto dir instead.
        home_auto = os.path.join(os.path.expanduser("~"), "auto")
        os.makedirs(home_auto, exist_ok=True)
        AUTO_DIR_RUNTIME = home_auto  # type: ignore
    src_py = TRAINING_DIR / "create_image_caption_json.py"
    src_toml = TRAINING_DIR / "dataset_QIE.toml"
    dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py"
    dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"
    try:
        # Best-effort copy (metadata preserved); a missing source is ignored.
        shutil.copy2(src_py, dst_py)
    except Exception:
        pass
    try:
        if src_toml.exists():
            shutil.copy2(src_toml, dst_toml)
    except Exception:
        pass
def _update_dataset_toml(
    path: str,
    *,
    img_res_w: Optional[int] = None,
    img_res_h: Optional[int] = None,
    train_batch_size: Optional[int] = None,
    control_res_w: Optional[int] = None,
    control_res_h: Optional[int] = None,
) -> None:
    """Update dataset TOML for resolution/batch/control resolution in-place.
    - Updates [general] resolution and batch_size if provided.
    - Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
    - Creates sections/keys if missing.
    """
    try:
        txt = Path(path).read_text(encoding="utf-8")
    except Exception:
        return
    def _set_in_general(block: str, key: str, value_line: str) -> str:
        import re as _re
        if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block):
            block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block)
        else:
            block = block.rstrip() + "\n" + value_line + "\n"
        return block
    import re
    m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
    if not m:
        gen = "[general]\n"
        if img_res_w and img_res_h:
            gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
        if train_batch_size is not None:
            gen += f"batch_size = {int(train_batch_size)}\n"
        txt = gen + "\n" + txt
    else:
        head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
        if img_res_w and img_res_h:
            block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
        if train_batch_size is not None:
            block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
        txt = head + block + tail
    if control_res_w and control_res_h:
        m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
        if m2:
            head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
            line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
            if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
                block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
            else:
                block = block.rstrip() + "\n" + line + "\n"
            txt = head + block + tail
    try:
        Path(path).write_text(txt, encoding="utf-8")
    except Exception:
        pass
def _ensure_dir_writable(path: str) -> str:
    try:
        os.makedirs(path, exist_ok=True)
        return path
    except PermissionError:
        home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\")))
        os.makedirs(home_path, exist_ok=True)
        return home_path
def _ensure_data_root(candidate: Optional[str]) -> str:
    root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT
    try:
        os.makedirs(root, exist_ok=True)
        return root
    except PermissionError:
        home_root = os.path.join(os.path.expanduser("~"), "data")
        os.makedirs(home_root, exist_ok=True)
        return home_root
def _extract_paths(files: Any) -> List[Tuple[str, str]]:
    """Extract a list of (abs_path, orig_basename) from Gradio Files input.
    Supports various gradio return shapes across versions.
    """
    out: List[Tuple[str, str]] = []
    if not files:
        return out
    # Gradio Files often returns a list
    if isinstance(files, (list, tuple)):
        items = files
    else:
        items = [files]
    for item in items:
        p: Optional[str] = None
        orig: Optional[str] = None
        # dict-like
        if isinstance(item, dict):
            p = item.get("path") or item.get("name") or item.get("file")
            orig = item.get("orig_name") or item.get("name")
        else:
            # object with attributes
            p = getattr(item, "name", None) or getattr(item, "path", None) or str(item)
            # best-effort original name attribute
            orig = getattr(item, "orig_name", None) or os.path.basename(p) if p else None
        if p:
            abs_p = os.path.abspath(p)
            out.append((abs_p, os.path.basename(orig or abs_p)))
    return out
def _norm_key(filename: str, prefix: str, suffix: str) -> str:
    stem = os.path.splitext(os.path.basename(filename))[0]
    if prefix and stem.startswith(prefix):
        stem = stem[len(prefix):]
    if suffix and stem.endswith(suffix):
        stem = stem[: -len(suffix)]
    return stem
def _copy_uploads(
    uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None
) -> List[str]:
    os.makedirs(dest_dir, exist_ok=True)
    used_names: List[str] = []
    for idx, (src, orig) in enumerate(uploads):
        # Determine target stem
        if rename_to and idx < len(rename_to):
            stem = os.path.splitext(rename_to[idx])[0]
        else:
            stem = os.path.splitext(orig)[0]
        dst_name = f"{stem}.png"
        # ensure unique within this batch
        final_name = dst_name
        dup_idx = 1
        while final_name in used_names:
            final_name = f"{stem}_{dup_idx}.png"
            dup_idx += 1
        dst_path = os.path.join(dest_dir, final_name)
        # Convert to PNG during save
        try:
            try:
                from PIL import Image  # type: ignore
                with Image.open(src) as img:
                    img.save(dst_path, format="PNG")
            except Exception:
                # Fallback: copy then rename
                shutil.copy2(src, dst_path)
        except Exception:
            # Last resort
            shutil.copy(src, dst_path)
        used_names.append(final_name)
    return used_names
def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
    try:
        if not out_dir or not os.path.isdir(out_dir):
            return []
        import time
        now = time.time()
        min_age_sec = 3.0  # treat files newer than this as possibly in-flight
        items: List[Tuple[float, str]] = []
        for root, _, files in os.walk(out_dir):
            for fn in files:
                if fn.lower().endswith('.safetensors'):
                    full = os.path.join(root, fn)
                    try:
                        # Skip zero-length, too-new, or unreadable files (likely in-flight)
                        size = os.path.getsize(full)
                        if size <= 0:
                            continue
                        mtime = os.path.getmtime(full)
                        if (now - mtime) < min_age_sec:
                            continue
                        # Try opening a small read to ensure readability
                        with open(full, 'rb') as rf:
                            rf.read(64)
                        items.append((mtime, full))
                    except Exception:
                        pass
        items.sort(reverse=True)
        return [p for _, p in items[:limit]]
    except Exception:
        return []
def _find_latest_dataset_dir(root: str) -> Optional[str]:
    try:
        if not os.path.isdir(root):
            return None
        cand: List[Tuple[float, str]] = []
        for name in os.listdir(root):
            if not name.startswith("dataset_"):
                continue
            full = os.path.join(root, name)
            if os.path.isdir(full):
                try:
                    cand.append((os.path.getmtime(full), full))
                except Exception:
                    pass
        if not cand:
            return None
        cand.sort(reverse=True)
        return cand[0][1]
    except Exception:
        return None
def _collect_scripts_and_config(ds_dir: Optional[str]) -> List[str]:
    files: List[str] = []
    try:
        ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
        if os.path.isfile(ds_conf):
            files.append(ds_conf)
        if ds_dir and os.path.isdir(ds_dir):
            used_script = os.path.join(ds_dir, "train_QIE_used.sh")
            if os.path.isfile(used_script):
                files.append(used_script)
            meta = os.path.join(ds_dir, "metadata.jsonl")
            if os.path.isfile(meta):
                files.append(meta)
    except Exception:
        pass
    return files
def _files_to_gallery(files: Any) -> List[str]:
    items: List[str] = []
    if not files:
        return items
    seq = files if isinstance(files, (list, tuple)) else [files]
    for f in seq:
        p = None
        if isinstance(f, str):
            p = f
        elif isinstance(f, dict):
            p = f.get("path") or f.get("name")
        else:
            p = getattr(f, "path", None) or getattr(f, "name", None)
        if p:
            items.append(p)
    return items
def _prepare_script(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control_folders: List[Optional[str]],
    models_root: str,
    output_dir_base: Optional[str] = None,
    dataset_config: Optional[str] = None,
    override_max_epochs: Optional[int] = None,
    override_save_every: Optional[int] = None,
    override_run_name: Optional[str] = None,
    target_prefix: Optional[str] = None,
    target_suffix: Optional[str] = None,
    control_prefixes: Optional[List[Optional[str]]] = None,
    control_suffixes: Optional[List[Optional[str]]] = None,
    override_learning_rate: Optional[str] = None,
    override_network_dim: Optional[int] = None,
    override_seed: Optional[int] = None,
    override_te_cache_bs: Optional[int] = None,
) -> Path:
    """Create a per-run copy of train_QIE.sh with injected variables.

    The template script (TRAINING_DIR/train_QIE.sh) is rewritten via a chain
    of regex substitutions: core dataset variables, optional CONTROL_FOLDER_i
    entries, model paths anchored under *models_root*, runtime working
    directories, ZeroGPU-compatible launch command, and optional
    hyperparameter overrides. Only variables that must vary per-run are
    replaced; the rest of the script remains as-is to preserve behavior.

    Returns:
        Path to the generated, chmod +x script under TRAINING_DIR/.gradio_runs/.
    """
    src = TRAINING_DIR / "train_QIE.sh"
    txt = src.read_text(encoding="utf-8")
    # Replace core variables (patterns anchored to whole assignment lines;
    # values are shell-quoted via _bash_quote).
    replacements = {
        r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}",
        r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}",
        r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}",
        r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}",
    }
    if output_dir_base:
        replacements[r"^OUTPUT_DIR_BASE=\".*\""] = (
            f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}"
        )
    if dataset_config:
        replacements[r"^DATASET_CONFIG=\".*\""] = (
            f"DATASET_CONFIG={_bash_quote(dataset_config)}"
        )
    for pat, val in replacements.items():
        txt = re.sub(pat, val, txt, flags=re.MULTILINE)
    # Inject CONTROL_FOLDER_i if provided (uncomment/override or append).
    # The template supports up to 8 control-image folders.
    for i in range(8):
        val = control_folders[i] if i < len(control_folders) else None
        if not val:
            continue
        # Try to replace a commented-out placeholder line first.
        pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
        if re.search(pattern, txt, flags=re.MULTILINE):
            txt = re.sub(
                pattern,
                f"CONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                flags=re.MULTILINE,
            )
        else:
            # No placeholder: append right after the IMAGE_FOLDER definition.
            txt = re.sub(
                r"^(IMAGE_FOLDER=.*)$",
                rf"\1\nCONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                count=1,
                flags=re.MULTILINE,
            )
    # Point model paths (--vae/--text_encoder/--dit) to the selected models_root.
    def _replace_model_path(txt: str, key: str, rel: str) -> str:
        # Rewrites every `--{key} "<anything>"` occurrence in the script.
        return re.sub(
            rf"--{key} \"[^\"]+\"",
            f"--{key} \"{models_root.rstrip('/')}/{rel}\"",
            txt,
        )
    txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors")
    txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
    txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
    # Replace working dir for metadata generation with the runtime auto dir.
    txt = re.sub(r"^cd\s+/workspace/auto\s*$", f"cd {AUTO_DIR_RUNTIME}", txt, flags=re.MULTILINE)
    # Ensure musubi-tuner path matches the runtime location.
    # NOTE(review): re.escape here is applied to the *replacement* text, not a
    # pattern — harmless for plain paths but would insert literal backslashes
    # if the path contained regex metacharacters; verify intent.
    txt = re.sub(r"^cd\s+/musubi-tuner\s*$", f"cd {re.escape(MUSUBI_TUNER_DIR_RUNTIME)}", txt, flags=re.MULTILINE)
    # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
    # Run the training module directly in-process so GPU stays attached
    # to the same Python request context.
    txt = re.sub(
        r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
        r"python -u src/musubi_tuner/qwen_image_train_network.py",
        txt,
        flags=re.MULTILINE,
    )
    # Optionally override epochs and save frequency for ZeroGPU time slicing.
    if override_max_epochs is not None and override_max_epochs > 0:
        txt = re.sub(r"--max_train_epochs\s+\d+",
                     f"--max_train_epochs {override_max_epochs}", txt)
    if override_save_every is not None and override_save_every > 0:
        txt = re.sub(r"--save_every_n_epochs\s+\d+",
                     f"--save_every_n_epochs {override_save_every}", txt)
    if override_run_name:
        txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
    # Inject prefix/suffix flags for the metadata-creation step.
    # Each line ends with ' \' so it splices into the multi-line command.
    extra_lines: List[str] = []
    if (target_prefix or ""):
        extra_lines.append(f"  --target_prefix {_bash_quote(target_prefix)} \\")
    if (target_suffix or ""):
        extra_lines.append(f"  --target_suffix {_bash_quote(target_suffix)} \\")
    for i in range(8):
        pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
        suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
        if pre:
            extra_lines.append(f"  --control_prefix_{i} {_bash_quote(pre)} \\")
        if suf:
            extra_lines.append(f"  --control_suffix_{i} {_bash_quote(suf)} \\")
    if extra_lines:
        extra_block = "\n".join(extra_lines)
        # Insert extra flags just before the CONTROL_ARGS line, preserving indentation.
        txt = re.sub(
            r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"',
            lambda m: f"{extra_block}\n{m.group(1)}\"${{CONTROL_ARGS[@]}}\"",
            txt,
            flags=re.MULTILINE,
        )
    # Override CLI hyperparameters directly in the command lines, if provided.
    if override_learning_rate:
        txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt)
    if override_network_dim is not None:
        txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt)
    if override_seed is not None:
        txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
    # Optionally override the text-encoder cache batch size.
    if override_te_cache_bs is not None and override_te_cache_bs > 0:
        txt = re.sub(
            r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
            rf"\g<1>{int(override_te_cache_bs)}",
            txt,
            flags=re.MULTILINE,
        )
    # Also override variable definitions at the top of the script (belt-and-
    # suspenders alongside the CLI regexes above; safer if the template uses
    # shell variables rather than literal CLI values).
    def _set_var(name: str, value: str) -> None:
        # Replace `NAME=...` in-place, or prepend the assignment if absent.
        # Numeric values are left unquoted.
        nonlocal txt
        pattern = rf"(?m)^\s*{name}\s*=.*$"
        replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
        if re.search(pattern, txt):
            txt = re.sub(pattern, replacement, txt)
        else:
            txt = f"{replacement}\n" + txt
    if override_learning_rate:
        _set_var('LEARNING_RATE', override_learning_rate)
    if override_network_dim is not None:
        _set_var('NETWORK_DIM', str(override_network_dim))
    if override_seed is not None:
        _set_var('SEED', str(override_seed))
    if override_max_epochs is not None and override_max_epochs > 0:
        _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
    if override_save_every is not None and override_save_every > 0:
        _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
    # Write to a per-PID file alongside this repo for easier inspection.
    run_dir = TRAINING_DIR / ".gradio_runs"
    run_dir.mkdir(parents=True, exist_ok=True)
    tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh"
    tmp.write_text(txt, encoding="utf-8", newline="\n")
    try:
        # Best-effort: mark the script executable.
        os.chmod(tmp, 0o755)
    except Exception:
        pass
    return tmp
def _pick_shell() -> str:
    for sh in ("bash", "sh"):
        if shutil.which(sh):
            return sh
    raise RuntimeError("No POSIX shell found. Please install bash or sh.")
def _is_git_repo(path: str) -> bool:
    try:
        out = subprocess.run(
            ["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
            capture_output=True,
            text=True,
            check=False,
        )
        return out.returncode == 0 and out.stdout.strip() == "true"
    except Exception:
        return False
def _startup_clone_musubi_tuner() -> None:
    """Ensure a musubi-tuner checkout exists at MUSUBI_TUNER_DIR_RUNTIME.

    Behavior:
    - Existing git repo at the target: fetch + fast-forward pull (best-effort).
    - Target exists but is not a git repo: warn and skip.
    - Otherwise: shallow-clone DEFAULT_MUSUBI_TUNER_REPO; on PermissionError
      or clone failure, fall back to ~/musubi-tuner and rebind the
      module-global MUSUBI_TUNER_DIR_RUNTIME.
    """
    global MUSUBI_TUNER_DIR_RUNTIME
    target = MUSUBI_TUNER_DIR_RUNTIME
    repo = DEFAULT_MUSUBI_TUNER_REPO
    parent = os.path.dirname(target.rstrip("/\\")) or "/"
    try:
        os.makedirs(parent, exist_ok=True)
    except PermissionError:
        # Fallback to home directory when the default parent is not writable.
        target = os.path.join(os.path.expanduser("~"), "musubi-tuner")
        MUSUBI_TUNER_DIR_RUNTIME = target
        os.makedirs(os.path.dirname(target), exist_ok=True)
    except Exception:
        pass
    if os.path.isdir(target) and _is_git_repo(target):
        print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
        try:
            # Best-effort update; failures leave the existing checkout usable.
            subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
            subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
        except Exception as e:
            print(f"[QIE] git pull failed: {e}")
        return
    if os.path.exists(target) and not _is_git_repo(target):
        # Don't clobber an unrelated directory that happens to exist here.
        print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
        return
    print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
    try:
        # Shallow clone keeps startup fast; history is not needed.
        subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
        print("[QIE] Clone completed.")
    except subprocess.CalledProcessError as e:
        print(f"[QIE] Clone failed at {target}: {e}")
        # Last-chance fallback into home.
        if not target.startswith(os.path.expanduser("~")):
            fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner")
            print(f"[QIE] Retrying clone into {fallback}...")
            try:
                subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True)
                MUSUBI_TUNER_DIR_RUNTIME = fallback
                print("[QIE] Clone completed in fallback.")
            except Exception as e2:
                print(f"[QIE] Clone failed in fallback as well: {e2}")
def _run_pip(args: List[str], cwd: Optional[str] = None) -> None:
    """Run `python -m pip <args>` (optionally in *cwd*), logging failures.

    Best-effort: a non-zero pip exit is printed, never raised.
    """
    pip_cmd = [sys.executable, "-m", "pip", *args]
    try:
        print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})")
        subprocess.run(pip_cmd, check=True, cwd=cwd)
    except subprocess.CalledProcessError as e:
        print(f"[QIE] pip failed: {e}")
def _startup_install_musubi_deps() -> None:
    """Install musubi-tuner and its dependencies into the current environment.

    Best-effort at startup: upgrades build tooling, then installs the
    checkout at MUSUBI_TUNER_DIR_RUNTIME in editable mode (with an optional
    Torch extra from MUSUBI_TUNER_TORCH_EXTRA, e.g. cu124/cu128), falling
    back to a plain install. All failures are logged, never raised.
    """
    repo_dir = MUSUBI_TUNER_DIR_RUNTIME
    if not os.path.isdir(repo_dir):
        print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}")
        return
    # Upgrade basic build tooling (best-effort).
    try:
        _run_pip(["install", "-U", "pip", "setuptools", "wheel"])
    except Exception:
        pass
    # Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128
    extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip()
    editable_spec = "." if not extra else f".[{extra}]"
    # Install musubi-tuner in editable mode to expose entrypoints and deps.
    try:
        _run_pip(["install", "-e", editable_spec], cwd=repo_dir)
    except Exception:
        # Fallback: plain install without editable mode.
        try:
            _run_pip(["install", editable_spec], cwd=repo_dir)
        except Exception:
            print("[QIE] WARN: musubi-tuner installation failed. Continuing.")
@spaces.GPU
def run_training(
    output_name: str,
    caption: str,
    image_uploads: Any,
    target_prefix: str,
    target_suffix: str,
    control0_uploads: Any,
    ctrl0_prefix: str,
    ctrl0_suffix: str,
    control1_uploads: Any,
    ctrl1_prefix: str,
    ctrl1_suffix: str,
    control2_uploads: Any,
    ctrl2_prefix: str,
    ctrl2_suffix: str,
    control3_uploads: Any,
    ctrl3_prefix: str,
    ctrl3_suffix: str,
    control4_uploads: Any,
    ctrl4_prefix: str,
    ctrl4_suffix: str,
    control5_uploads: Any,
    ctrl5_prefix: str,
    ctrl5_suffix: str,
    control6_uploads: Any,
    ctrl6_prefix: str,
    ctrl6_suffix: str,
    control7_uploads: Any,
    ctrl7_prefix: str,
    ctrl7_suffix: str,
    learning_rate: str,
    network_dim: int,
    train_res_w: int,
    train_res_h: int,
    train_batch_size: int,
    control_res_w: int,
    control_res_h: int,
    te_cache_batch_size: int,
    seed: int,
    max_epochs: int,
    save_every: int,
) -> Iterable[tuple]:
    """Generator driving a LoRA training run from the Gradio UI.

    Yields (log_text, checkpoint_paths, artifact_paths) tuples so the UI can
    stream progress. Decorated with @spaces.GPU for ZeroGPU scheduling.

    NOTE(review): this file appears corrupted — the function body is cut off
    here by the `_stop_active_training` definition below, and the remaining
    validation/dataset-preparation steps (referencing `caption`, `log_buf`,
    etc.) appear after it at the wrong nesting level. Also, the OUTPUT NAME
    error path below yields but does not `return`, unlike the sibling checks.
    Verify against the original repository before relying on this block.
    """
    global _ACTIVE_PROC, _ACTIVE_PGID
    # Basic validation
    log_buf = "[QIE] Start Training invoked.\n"
    ckpts: List[str] = []
    artifacts: List[str] = []
    # Emit an initial line so UI can confirm invocation
    yield (log_buf, ckpts, artifacts)
    if not output_name.strip():
        log_buf += "[ERROR] OUTPUT NAME is required.\n"
        yield (log_buf, ckpts, artifacts)
def _stop_active_training() -> None:
    """Ubuntu向けのハード停止: 実行中の学習プロセスのプロセスグループを終了する"""
    with _ACTIVE_LOCK:
        proc = _ACTIVE_PROC
        pgid = _ACTIVE_PGID
    if not proc:
        return
    try:
        if pgid is not None:
            os.killpg(pgid, signal.SIGTERM)
        else:
            os.kill(proc.pid, signal.SIGTERM)
    except Exception:
        pass
    try:
        proc.wait(timeout=5)
    except Exception:
        try:
            if pgid is not None:
                os.killpg(pgid, signal.SIGKILL)
            else:
                os.kill(proc.pid, signal.SIGKILL)
        except Exception:
            pass
        return
    if not caption.strip():
        log_buf += "[ERROR] CAPTION is required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    # Ensure /auto holds helper files expected by the script
    _ensure_workspace_auto_files()
    # Resolve data root and create dataset directories (auto-decide)
    global DATA_ROOT_RUNTIME
    DATA_ROOT_RUNTIME = _ensure_data_root(None)
    # Auto-generate dataset directory name
    import time
    ds_name = f"dataset_{int(time.time())}"
    ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
    img_folder_name = DEFAULT_IMAGE_FOLDER
    img_dir = os.path.join(ds_dir, img_folder_name)
    os.makedirs(img_dir, exist_ok=True)
    # Ingest uploads into dataset folders
    base_files = _extract_paths(image_uploads)
    if not base_files:
        log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
        yield (log_buf, ckpts, artifacts)
        return
    base_filenames = _copy_uploads(base_files, img_dir)
    log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
    yield (log_buf, ckpts, artifacts)
    # Prepare control sets
    control_upload_sets = [
        _extract_paths(control0_uploads),
        _extract_paths(control1_uploads),
        _extract_paths(control2_uploads),
        _extract_paths(control3_uploads),
        _extract_paths(control4_uploads),
        _extract_paths(control5_uploads),
        _extract_paths(control6_uploads),
        _extract_paths(control7_uploads),
    ]
    # Require control_0; others optional
    if not control_upload_sets[0]:
        log_buf += "[ERROR] control_0 images are required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    # One entry per control slot: the folder name when populated, None when unused.
    control_dirs: List[Optional[str]] = []
    for i, uploads in enumerate(control_upload_sets):
        if not uploads:
            control_dirs.append(None)
            continue
        folder_name = f"control_{i}"
        cdir = os.path.join(ds_dir, folder_name)
        os.makedirs(cdir, exist_ok=True)
        # Simply copy; name matching will be handled by create_image_caption_json.py
        _copy_uploads(uploads, cdir)
        control_dirs.append(folder_name)
        log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
        yield (log_buf, ckpts, artifacts)
    # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
    # Prepare script with user parameters
    control_folders = [
        (control_dirs[i] if control_dirs[i] else None)
        for i in range(8)
    ]
    # Decide dataset_config path with fallback to runtime auto dir
    ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
    # Update dataset config with requested resolution/batch settings
    try:
        _update_dataset_toml(
            ds_conf,
            img_res_w=int(train_res_w) if train_res_w else None,
            img_res_h=int(train_res_h) if train_res_h else None,
            train_batch_size=int(train_batch_size) if train_batch_size else None,
            control_res_w=int(control_res_w) if control_res_w else None,
            control_res_h=int(control_res_h) if control_res_h else None,
        )
        log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
    except Exception as e:
        # Best-effort: training can still proceed with the existing config.
        log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
    # Expose dataset config for download (if exists)
    if os.path.isfile(ds_conf):
        artifacts = [ds_conf]
    # Resolve models_root and set output_dir_base to the unique dataset dir
    models_root = MODELS_ROOT_RUNTIME
    out_base = ds_dir
    try:
        os.makedirs(out_base, exist_ok=True)
    except Exception:
        pass
    tmp_script = _prepare_script(
        dataset_name=ds_name,
        caption=caption,
        data_root=DATA_ROOT_RUNTIME,
        image_folder=img_folder_name,
        control_folders=control_folders,
        models_root=models_root,
        output_dir_base=out_base,
        dataset_config=ds_conf,
        override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
        override_save_every=save_every if save_every and save_every > 0 else None,
        override_run_name=output_name.strip(),
        target_prefix=(target_prefix or ""),
        target_suffix=(target_suffix or ""),
        control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
        control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
        override_learning_rate=(learning_rate or None),
        override_network_dim=int(network_dim) if network_dim is not None else None,
        override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
        override_seed=int(seed) if seed is not None else None,
    )
    shell = _pick_shell()
    log_buf += f"[QIE] Using shell: {shell}\n"
    log_buf += f"[QIE] Running script: {tmp_script}\n"
    out_dir = os.path.join(out_base, output_name.strip())
    ckpts = _list_checkpoints(out_dir)
    # Copy the final script to dataset dir for download
    used_script_path = os.path.join(out_base, "train_QIE_used.sh")
    try:
        shutil.copy2(str(tmp_script), used_script_path)
        try:
            os.chmod(used_script_path, 0o755)
        except Exception:
            pass
        if used_script_path not in artifacts:
            artifacts.append(used_script_path)
    except Exception:
        pass
    yield (log_buf, ckpts, artifacts)
    # Run and stream output
    # Ensure child Python processes are unbuffered for real-time logs
    child_env = os.environ.copy()
    child_env["PYTHONUNBUFFERED"] = "1"
    child_env["PYTHONIOENCODING"] = "utf-8"
    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        universal_newlines=True,
        env=child_env,
        # Start in a new session so the whole process group can be signalled
        # (os.killpg) when the user presses the stop button.
        preexec_fn=os.setsid,
    )
    # Register active process for hard stop
    with _ACTIVE_LOCK:
        _ACTIVE_PROC = proc
        try:
            _ACTIVE_PGID = os.getpgid(proc.pid)
        except Exception:
            _ACTIVE_PGID = None
    try:
        assert proc.stdout is not None
        i = 0
        for line in proc.stdout:
            log_buf += line
            i += 1
            # Rescan outputs every 30 log lines to limit filesystem churn.
            if i % 30 == 0:
                ckpts = _list_checkpoints(out_dir)
                # Try to add metadata.jsonl once available
                metadata_json = os.path.join(out_base, "metadata.jsonl")
                if os.path.isfile(metadata_json) and metadata_json not in artifacts:
                    artifacts.append(metadata_json)
            yield (log_buf, ckpts, artifacts)
    finally:
        code = proc.wait()
        # Clear active process registration if this proc
        with _ACTIVE_LOCK:
            if _ACTIVE_PROC is proc:
                _ACTIVE_PROC = None
                _ACTIVE_PGID = None
        # Try to locate latest LoRA file for download
        lora_path = None
        try:
            ckpts = _list_checkpoints(out_dir)
        except Exception:
            pass
        lora_path = ckpts[0] if ckpts else None
        # Negative Popen return codes mean the child died from a signal.
        if code < 0:
            try:
                sig = -code
                log_buf += f"[QIE] Terminated by signal: {sig}\n"
            except Exception:
                log_buf += f"[QIE] Terminated by signal.\n"
        log_buf += f"[QIE] Exit code: {code}\n"
        # Final attempt to include metadata.jsonl
        metadata_json = os.path.join(out_base, "metadata.jsonl")
        if os.path.isfile(metadata_json) and metadata_json not in artifacts:
            artifacts.append(metadata_json)
        yield (log_buf, ckpts, artifacts)
def build_ui() -> gr.Blocks:
    """Build the Gradio Blocks app: a Training tab and a Prompt Generator tab.

    Returns the (unlaunched) gr.Blocks instance. All event handlers are
    wired here; the heavy lifting happens in run_training /
    _stop_active_training and in the lazily imported QIE_prompt_generator.
    """
    # Alternating section backgrounds so stacked accordions are easy to scan.
    css = """
    .pad-section {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #ffffff);
    }
    .pad-section_0 {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #fafafa);
    }
    .pad-section_1 {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #eaeaea);
    }
    """
    with gr.Blocks(title="Qwen-Image-Edit: Trainer", css=css) as demo:
        with gr.Tabs() as tabs:
            with gr.TabItem("Training"):
                gr.Markdown("""
                # Qwen-Image-Edit Trainer
                学習に使う画像をアップロードし、必要ならファイル名の前後にある共通の文字(prefix/suffix)を指定して、
                自動でデータセットを作成し学習を開始します。難しい操作は不要です。
                """)
                with gr.Accordion("Settings", elem_classes=["pad-section"]):
                    with gr.Group():
                        with gr.Row():
                            output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
                            caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
                        with gr.Row():
                            lr_input = gr.Textbox(label="Learning rate", value="1e-3")
                            dim_input = gr.Number(label="Network dim", value=4, precision=0)
                            train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
                            seed_input = gr.Number(label="Seed", value=42, precision=0)
                            max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
                            save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
                        with gr.Row():
                            tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
                            tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
                            cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
                            cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
                            te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)
                with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath", height=220, scale=3)
                            main_gallery = gr.Gallery(label="Target preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_")
                                    main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2")
                                with gr.Accordion("prefix/sufixについて", open=False):
                                    gr.Markdown("""
                                    ファイルの同名判定のため、画像のファイル名から共通の先頭/末尾文字を取り除く指定(例: IMG_ や _v2)
                                    - まずターゲット画像のファイル名(拡張子なし)から、指定した Target prefix/suffix を取り除いたものを key とします。
                                    - 各コントロールは「付加」規則で、期待名 = control_prefix_i + key + control_suffix_i + ".png" を探して対応付けます。
                                    - アップロード時に画像は自動で .png に変換して保存します(元のファイル名のベースは維持)。
                                    - Control 0 は必須、Control 1〜7 は任意。コントロール画像が1枚だけのときは、すべてのターゲット画像に適用します。
                                    """)
                # control_0 is required and shown outside the accordion
                with gr.Accordion("Control 0", elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_")
                                    ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask")
                # Optional controls start from 1, accordion closed by default
                with gr.Accordion("Control 1", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="")
                                    ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="")
                with gr.Accordion("Control 2", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="")
                                    ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="")
                with gr.Accordion("Control 3", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="")
                                    ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="")
                with gr.Accordion("Control 4", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="")
                                    ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="")
                with gr.Accordion("Control 5", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="")
                                    ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="")
                with gr.Accordion("Control 6", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="")
                                    ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="")
                with gr.Accordion("Control 7", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="")
                                    ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="")
                # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.
                run_btn = gr.Button("Start Training", variant="primary")
                logs = gr.Textbox(label="Logs", lines=20)
                ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
                scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)
                with gr.Row():
                    stop_btn = gr.Button("学習を停止", variant="secondary")
                    refresh_scripts_btn = gr.Button("ファイルを再取得", variant="secondary")
                # moved max_epochs/save_every above next to OUTPUT NAME
                # Wire previews
                images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
                ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
                ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
                ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
                ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery)
                ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery)
                ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery)
                ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery)
                ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery)
                # run_training is a generator, so logs/files stream to the UI.
                run_btn.click(
                    fn=run_training,
                    inputs=[
                        output_name, caption, images_input, main_prefix, main_suffix,
                        ctrl0_files, ctrl0_prefix, ctrl0_suffix,
                        ctrl1_files, ctrl1_prefix, ctrl1_suffix,
                        ctrl2_files, ctrl2_prefix, ctrl2_suffix,
                        ctrl3_files, ctrl3_prefix, ctrl3_suffix,
                        ctrl4_files, ctrl4_prefix, ctrl4_suffix,
                        ctrl5_files, ctrl5_prefix, ctrl5_suffix,
                        ctrl6_files, ctrl6_prefix, ctrl6_suffix,
                        ctrl7_files, ctrl7_prefix, ctrl7_suffix,
                        lr_input, dim_input,
                        tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
                        seed_input, max_epochs, save_every,
                    ],
                    outputs=[logs, ckpt_files, scripts_files],
                )
                # Recovery button: re-fetch checkpoints and scripts/config from the latest dataset_ directory.
                def _refresh_all() -> tuple:
                    """Re-scan the most recent dataset_ dir for checkpoints and scripts/config."""
                    try:
                        ds_dir = _find_latest_dataset_dir(DATA_ROOT_RUNTIME)
                    except Exception:
                        ds_dir = None
                    try:
                        ck = _list_checkpoints(ds_dir) if ds_dir else []
                    except Exception:
                        ck = []
                    try:
                        sc = _collect_scripts_and_config(ds_dir)
                    except Exception:
                        sc = _collect_scripts_and_config(None)
                    return ck, sc
                refresh_scripts_btn.click(
                    fn=_refresh_all,
                    inputs=[],
                    outputs=[ckpt_files, scripts_files],
                )
                # Hard stop button (Ubuntu): kill active process group
                def _on_stop_click():
                    _stop_active_training()
                    return
                stop_btn.click(fn=_on_stop_click, inputs=[], outputs=[])
            with gr.TabItem("Prompt Generator"):
                gr.Markdown("""
                # 🎨 A→B 変換プロンプト自動生成
                画像A(入力)と画像B(出力)、補足説明を入力すると、  
                A→B の変換内容を英語プロンプトとして自動生成し、タスク名候補(3件)も提案します。  
                モデルは `gpt-5` を使用します。
                """)
                api_key_pg = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
                with gr.Row():
                    img_a_pg = gr.Image(type="filepath", label="Image A (Input)", height=300)
                    img_b_pg = gr.Image(type="filepath", label="Image B (Output)", height=300)
                notes_pg = gr.Textbox(label="補足説明(日本語可)", lines=4, value="この画像は例であって、汎用的なプロンプトにする")
                want_japanese_pg = gr.Checkbox(label="日本語訳を含める", value=True)
                run_btn_pg = gr.Button("生成する", variant="primary")
                english_out_pg = gr.Textbox(label="English Prompt", lines=8)
                names_out_pg = gr.Textbox(label="Name Suggestions", lines=4)
                japanese_out_pg = gr.Textbox(label="日本語訳(任意)", lines=8)
                def _on_click_prompt(api_key_in, a_path, b_path, notes_in, ja_flag):
                    """Convert the two images to data URLs and call the prompt generator."""
                    # Lazy import to avoid constructing extra Blocks at startup
                    qpg = importlib.import_module("QIE_prompt_generator")
                    a_url = qpg.file_to_data_url(a_path) if a_path else None
                    b_url = qpg.file_to_data_url(b_path) if b_path else None
                    return qpg.call_openai_chat(api_key_in, a_url, b_url, notes_in, ja_flag)
                run_btn_pg.click(
                    fn=_on_click_prompt,
                    inputs=[api_key_pg, img_a_pg, img_b_pg, notes_pg, want_japanese_pg],
                    outputs=[english_out_pg, names_out_pg, japanese_out_pg],
                )
    return demo
def _startup_download_models() -> None:
    """Resolve a writable models directory and fetch the required models.

    Prefers the QWEN_IMAGE_MODELS_DIR environment variable (falling back to
    DEFAULT_MODELS_ROOT); if that location is not writable, uses
    ~/Qwen-Image_models instead. Download failures are logged but do not
    abort startup.
    """
    global MODELS_ROOT_RUNTIME
    preferred = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT)
    try:
        os.makedirs(preferred, exist_ok=True)
    except PermissionError:
        # Preferred location is read-only -> fall back to the home directory.
        fallback = os.path.join(os.path.expanduser("~"), "Qwen-Image_models")
        os.makedirs(fallback, exist_ok=True)
        MODELS_ROOT_RUNTIME = fallback
    else:
        MODELS_ROOT_RUNTIME = preferred
    print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}")
    try:
        download_all_models(MODELS_ROOT_RUNTIME)
    except Exception as e:
        print(f"[QIE] Model download failed: {e}")
if __name__ == "__main__":
    # 1) Ensure musubi-tuner is cloned before anything else
    _startup_clone_musubi_tuner()
    # 1.1) Install musubi-tuner dependencies (best-effort)
    _startup_install_musubi_deps()
    # 2) Download models at startup (blocking by design)
    _startup_download_models()
    # 3) Build and launch the Gradio app
    app = build_ui()
    # Training is heavy, so bound the request queue. Older gradio versions
    # ship a queue() without max_size, hence the TypeError fallback.
    try:
        app = app.queue(max_size=16)
    except TypeError:
        app = app.queue()
    try:
        # Allow Gradio to serve files saved under our runtime dirs.
        home = os.path.expanduser("~")
        serve_dirs = [
            AUTO_DIR_RUNTIME,
            os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"),
            DEFAULT_DATA_ROOT,
            DATA_ROOT_RUNTIME,
            os.path.join(home, "auto"),
            os.path.join(home, "data"),
        ]
        app.launch(server_name="0.0.0.0", allowed_paths=serve_dirs, ssr_mode=False)
    except TypeError:
        # Older gradio without allowed_paths
        try:
            app.launch(server_name="0.0.0.0", ssr_mode=False)
        except TypeError:
            # Very old gradio without ssr_mode
            app.launch(server_name="0.0.0.0")
 |