#!/usr/bin/env python3
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Any, Tuple
import time
import json

import gradio as gr
import importlib
import spaces
import signal
from threading import Lock

# Local modules
from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR


# Defaults matching train_QIE.sh expectations
DEFAULT_DATA_ROOT = "/data"
DEFAULT_IMAGE_FOLDER = "image"
DEFAULT_OUTPUT_DIR_BASE = "/auto/train_LoRA"
DEFAULT_DATASET_CONFIG = "/auto/dataset_QIE.toml"
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR  # "/Qwen-Image_models"
WORKSPACE_AUTO_DIR = "/auto"

# musubi-tuner settings
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
    "MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
)


TRAINING_DIR = Path(__file__).resolve().parent

# Runtime-resolved paths with fallbacks for non-root environments
MUSUBI_TUNER_DIR_RUNTIME = DEFAULT_MUSUBI_TUNER_DIR
MODELS_ROOT_RUNTIME = DEFAULT_MODELS_ROOT
AUTO_DIR_RUNTIME = WORKSPACE_AUTO_DIR
DATA_ROOT_RUNTIME = DEFAULT_DATA_ROOT

# Active process management for hard stop (Ubuntu)
_ACTIVE_LOCK: Lock = Lock()
_ACTIVE_PROC: Optional[subprocess.Popen] = None
_ACTIVE_PGID: Optional[int] = None


def _bash_quote(s: str) -> str:
    """Return a POSIX-safe single-quoted string literal representing s."""
    if s is None:
        return "''"
    return "'" + str(s).replace("'", "'\"'\"'") + "'"


def _ensure_workspace_auto_files() -> None:
    """Ensure /workspace/auto has required helper files from this repo.

    Copies training/create_image_caption_json.py and training/dataset_QIE.toml
    into /workspace/auto so that train_QIE.sh can run unmodified.
    """
    global AUTO_DIR_RUNTIME
    try:
        os.makedirs(AUTO_DIR_RUNTIME, exist_ok=True)
    except PermissionError:
        home_auto = os.path.join(os.path.expanduser("~"), "auto")
        os.makedirs(home_auto, exist_ok=True)
        AUTO_DIR_RUNTIME = home_auto  # type: ignore
    src_py = TRAINING_DIR / "create_image_caption_json.py"
    src_toml = TRAINING_DIR / "dataset_QIE.toml"
    dst_py = Path(AUTO_DIR_RUNTIME) / "create_image_caption_json.py"
    dst_toml = Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"

    try:
        shutil.copy2(src_py, dst_py)
    except Exception:
        pass
    try:
        if src_toml.exists():
            shutil.copy2(src_toml, dst_toml)
    except Exception:
        pass


def _update_dataset_toml(
    path: str,
    *,
    img_res_w: Optional[int] = None,
    img_res_h: Optional[int] = None,
    train_batch_size: Optional[int] = None,
    control_res_w: Optional[int] = None,
    control_res_h: Optional[int] = None,
) -> None:
    """Update dataset TOML for resolution/batch/control resolution in-place.

    - Updates [general] resolution and batch_size if provided.
    - Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
    - Creates sections/keys if missing.
    """
    try:
        txt = Path(path).read_text(encoding="utf-8")
    except Exception:
        return

    def _set_in_general(block: str, key: str, value_line: str) -> str:
        import re as _re
        if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block):
            block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block)
        else:
            block = block.rstrip() + "\n" + value_line + "\n"
        return block

    import re
    m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
    if not m:
        gen = "[general]\n"
        if img_res_w and img_res_h:
            gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
        if train_batch_size is not None:
            gen += f"batch_size = {int(train_batch_size)}\n"
        txt = gen + "\n" + txt
    else:
        head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
        if img_res_w and img_res_h:
            block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
        if train_batch_size is not None:
            block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
        txt = head + block + tail

    if control_res_w and control_res_h:
        m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
        if m2:
            head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
            line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
            if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
                block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
            else:
                block = block.rstrip() + "\n" + line + "\n"
            txt = head + block + tail

    try:
        Path(path).write_text(txt, encoding="utf-8")
    except Exception:
        pass


def _ensure_dir_writable(path: str) -> str:
    try:
        os.makedirs(path, exist_ok=True)
        return path
    except PermissionError:
        home_path = os.path.join(os.path.expanduser("~"), os.path.basename(path.strip("/\\")))
        os.makedirs(home_path, exist_ok=True)
        return home_path


def _ensure_data_root(candidate: Optional[str]) -> str:
    root = (candidate or DEFAULT_DATA_ROOT).strip() or DEFAULT_DATA_ROOT
    try:
        os.makedirs(root, exist_ok=True)
        return root
    except PermissionError:
        home_root = os.path.join(os.path.expanduser("~"), "data")
        os.makedirs(home_root, exist_ok=True)
        return home_root


def _extract_paths(files: Any) -> List[Tuple[str, str]]:
    """Extract a list of (abs_path, orig_basename) from Gradio Files input.

    Supports various gradio return shapes across versions.
    """
    out: List[Tuple[str, str]] = []
    if not files:
        return out
    # Gradio Files often returns a list
    if isinstance(files, (list, tuple)):
        items = files
    else:
        items = [files]

    for item in items:
        p: Optional[str] = None
        orig: Optional[str] = None
        # dict-like
        if isinstance(item, dict):
            p = item.get("path") or item.get("name") or item.get("file")
            orig = item.get("orig_name") or item.get("name")
        else:
            # object with attributes
            p = getattr(item, "name", None) or getattr(item, "path", None) or str(item)
            # best-effort original name attribute
            orig = (getattr(item, "orig_name", None) or os.path.basename(p)) if p else None
        if p:
            abs_p = os.path.abspath(p)
            out.append((abs_p, os.path.basename(orig or abs_p)))
    return out


def _norm_key(filename: str, prefix: str, suffix: str) -> str:
    stem = os.path.splitext(os.path.basename(filename))[0]
    if prefix and stem.startswith(prefix):
        stem = stem[len(prefix):]
    if suffix and stem.endswith(suffix):
        stem = stem[: -len(suffix)]
    return stem
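
# For example, _norm_key("IMG_001_v2.png", "IMG_", "_v2") returns "001"; this
# key is later used to pair target images with their control images.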


def _copy_uploads(
    uploads: List[Tuple[str, str]], dest_dir: str, rename_to: Optional[List[str]] = None
) -> List[str]:
    os.makedirs(dest_dir, exist_ok=True)
    used_names: List[str] = []
    for idx, (src, orig) in enumerate(uploads):
        # Determine target stem
        if rename_to and idx < len(rename_to):
            stem = os.path.splitext(rename_to[idx])[0]
        else:
            stem = os.path.splitext(orig)[0]
        dst_name = f"{stem}.png"
        # ensure unique within this batch
        final_name = dst_name
        dup_idx = 1
        while final_name in used_names:
            final_name = f"{stem}_{dup_idx}.png"
            dup_idx += 1
        dst_path = os.path.join(dest_dir, final_name)
        # Convert to PNG during save
        try:
            try:
                from PIL import Image  # type: ignore
                with Image.open(src) as img:
                    img.save(dst_path, format="PNG")
            except Exception:
                # Fallback: raw copy to the .png path without conversion
                shutil.copy2(src, dst_path)
        except Exception:
            # Last resort
            shutil.copy(src, dst_path)
        used_names.append(final_name)
    return used_names


def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
    try:
        if not out_dir or not os.path.isdir(out_dir):
            return []
        now = time.time()
        min_age_sec = 3.0  # treat files newer than this as possibly in-flight

        items: List[Tuple[float, str]] = []
        for root, _, files in os.walk(out_dir):
            for fn in files:
                if fn.lower().endswith('.safetensors'):
                    full = os.path.join(root, fn)
                    try:
                        # Skip zero-length, too-new, or unreadable files (likely in-flight)
                        size = os.path.getsize(full)
                        if size <= 0:
                            continue
                        mtime = os.path.getmtime(full)
                        if (now - mtime) < min_age_sec:
                            continue
                        # Try opening a small read to ensure readability
                        with open(full, 'rb') as rf:
                            rf.read(64)
                        items.append((mtime, full))
                    except Exception:
                        pass
        items.sort(reverse=True)
        return [p for _, p in items[:limit]]
    except Exception:
        return []


def _find_latest_dataset_dir(root: str) -> Optional[str]:
    try:
        if not os.path.isdir(root):
            return None
        cand: List[Tuple[float, str]] = []
        for name in os.listdir(root):
            if not name.startswith("dataset_"):
                continue
            full = os.path.join(root, name)
            if os.path.isdir(full):
                try:
                    cand.append((os.path.getmtime(full), full))
                except Exception:
                    pass
        if not cand:
            return None
        cand.sort(reverse=True)
        return cand[0][1]
    except Exception:
        return None


def _collect_scripts_and_config(ds_dir: Optional[str]) -> List[str]:
    files: List[str] = []
    try:
        ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
        if os.path.isfile(ds_conf):
            files.append(ds_conf)
        if ds_dir and os.path.isdir(ds_dir):
            used_script = os.path.join(ds_dir, "train_QIE_used.sh")
            if os.path.isfile(used_script):
                files.append(used_script)
            meta = os.path.join(ds_dir, "metadata.jsonl")
            if os.path.isfile(meta):
                files.append(meta)
    except Exception:
        pass
    return files


def _files_to_gallery(files: Any) -> List[str]:
    items: List[str] = []
    if not files:
        return items
    seq = files if isinstance(files, (list, tuple)) else [files]
    for f in seq:
        p = None
        if isinstance(f, str):
            p = f
        elif isinstance(f, dict):
            p = f.get("path") or f.get("name")
        else:
            p = getattr(f, "path", None) or getattr(f, "name", None)
        if p:
            items.append(p)
    return items


def _prepare_script(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control_folders: List[Optional[str]],
    models_root: str,
    output_dir_base: Optional[str] = None,
    dataset_config: Optional[str] = None,
    override_max_epochs: Optional[int] = None,
    override_save_every: Optional[int] = None,
    override_run_name: Optional[str] = None,
    target_prefix: Optional[str] = None,
    target_suffix: Optional[str] = None,
    control_prefixes: Optional[List[Optional[str]]] = None,
    control_suffixes: Optional[List[Optional[str]]] = None,
    override_learning_rate: Optional[str] = None,
    override_network_dim: Optional[int] = None,
    override_seed: Optional[int] = None,
    override_te_cache_bs: Optional[int] = None,
) -> Path:
    """Create a temporary copy of train_QIE.sh with injected variables.

    Only variables that must vary per-run are replaced. The rest of the script
    remains as-is to preserve behavior.
    """
    src = TRAINING_DIR / "train_QIE.sh"
    txt = src.read_text(encoding="utf-8")

    # Replace core variables
    replacements = {
        r"^DATA_ROOT=\".*\"": f"DATA_ROOT={_bash_quote(data_root)}",
        r"^DATASET_NAME=\".*\"": f"DATASET_NAME={_bash_quote(dataset_name)}",
        r"^CAPTION=\".*\"": f"CAPTION={_bash_quote(caption)}",
        r"^IMAGE_FOLDER=\".*\"": f"IMAGE_FOLDER={_bash_quote(image_folder)}",
    }
    if output_dir_base:
        replacements[r"^OUTPUT_DIR_BASE=\".*\""] = (
            f"OUTPUT_DIR_BASE={_bash_quote(output_dir_base)}"
        )
    if dataset_config:
        replacements[r"^DATASET_CONFIG=\".*\""] = (
            f"DATASET_CONFIG={_bash_quote(dataset_config)}"
        )

    for pat, val in replacements.items():
        txt = re.sub(pat, val, txt, flags=re.MULTILINE)

    # Inject CONTROL_FOLDER_i if provided (uncomment/override or append)
    for i in range(8):
        val = control_folders[i] if i < len(control_folders) else None
        if not val:
            continue
        # Try to replace commented placeholder first
        pattern = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
        if re.search(pattern, txt, flags=re.MULTILINE):
            txt = re.sub(
                pattern,
                f"CONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                flags=re.MULTILINE,
            )
        else:
            # Append after IMAGE_FOLDER definition
            txt = re.sub(
                r"^(IMAGE_FOLDER=.*)$",
                rf"\1\nCONTROL_FOLDER_{i}={_bash_quote(val)}",
                txt,
                count=1,
                flags=re.MULTILINE,
            )

    # Point model paths to the selected models_root
    def _replace_model_path(txt: str, key: str, rel: str) -> str:
        return re.sub(
            rf"--{key} \"[^\"]+\"",
            f"--{key} \"{models_root.rstrip('/')}/{rel}\"",
            txt,
        )

    txt = _replace_model_path(txt, "vae", "vae/diffusion_pytorch_model.safetensors")
    txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
    txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")

    # Replace working dir for metadata generation to runtime /auto
    txt = re.sub(r"^cd\s+/workspace/auto\s*$", f"cd {AUTO_DIR_RUNTIME}", txt, flags=re.MULTILINE)
    # Ensure musubi-tuner path matches runtime location
    txt = re.sub(r"^cd\s+/musubi-tuner\s*$", f"cd {re.escape(MUSUBI_TUNER_DIR_RUNTIME)}", txt, flags=re.MULTILINE)

    # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
    # Run the training module directly in-process so GPU stays attached
    # to the same Python request context.
    txt = re.sub(
        r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
        r"python -u src/musubi_tuner/qwen_image_train_network.py",
        txt,
        flags=re.MULTILINE,
    )

    # Optionally override epochs and save frequency for ZeroGPU time slicing
    if override_max_epochs is not None and override_max_epochs > 0:
        txt = re.sub(r"--max_train_epochs\s+\d+",
                     f"--max_train_epochs {override_max_epochs}", txt)
    if override_save_every is not None and override_save_every > 0:
        txt = re.sub(r"--save_every_n_epochs\s+\d+",
                     f"--save_every_n_epochs {override_save_every}", txt)
    if override_run_name:
        txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)

    # Inject prefix/suffix flags for metadata creation
    extra_lines: List[str] = []
    if (target_prefix or ""):
        extra_lines.append(f"  --target_prefix {_bash_quote(target_prefix)} \\")
    if (target_suffix or ""):
        extra_lines.append(f"  --target_suffix {_bash_quote(target_suffix)} \\")
    for i in range(8):
        pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
        suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
        if pre:
            extra_lines.append(f"  --control_prefix_{i} {_bash_quote(pre)} \\")
        if suf:
            extra_lines.append(f"  --control_suffix_{i} {_bash_quote(suf)} \\")

    if extra_lines:
        extra_block = "\n".join(extra_lines)
        # Insert extra flags just before the CONTROL_ARGS line, preserving indentation.
        txt = re.sub(
            r'^(\s*)"\$\{CONTROL_ARGS\[@\]\}"',
            lambda m: f"{extra_block}\n{m.group(1)}\"${{CONTROL_ARGS[@]}}\"",
            txt,
            flags=re.MULTILINE,
        )

    # Override CLI hyperparameters if provided
    if override_learning_rate:
        txt = re.sub(r"--learning_rate\s+[-+eE0-9\.]+", f"--learning_rate {override_learning_rate}", txt)
    if override_network_dim is not None:
        txt = re.sub(r"--network_dim\s+\d+", f"--network_dim {override_network_dim}", txt)
    if override_seed is not None:
        txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)

    # Optionally override text-encoder cache batch size
    if override_te_cache_bs is not None and override_te_cache_bs > 0:
        txt = re.sub(
            r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
            rf"\g<1>{int(override_te_cache_bs)}",
            txt,
            flags=re.MULTILINE,
        )

    # Prefer overriding variable definitions at top of script (safer than CLI regex)
    def _set_var(name: str, value: str) -> None:
        nonlocal txt
        pattern = rf"(?m)^\s*{name}\s*=.*$"
        replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
        if re.search(pattern, txt):
            txt = re.sub(pattern, replacement, txt)
        else:
            txt = f"{replacement}\n" + txt

    if override_learning_rate:
        _set_var('LEARNING_RATE', override_learning_rate)
    if override_network_dim is not None:
        _set_var('NETWORK_DIM', str(override_network_dim))
    if override_seed is not None:
        _set_var('SEED', str(override_seed))
    if override_max_epochs is not None and override_max_epochs > 0:
        _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
    if override_save_every is not None and override_save_every > 0:
        _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))

    # Write to a temp file alongside this repo for easier inspection
    run_dir = TRAINING_DIR / ".gradio_runs"
    run_dir.mkdir(parents=True, exist_ok=True)
    tmp = run_dir / f"train_QIE_run_{os.getpid()}.sh"
    tmp.write_text(txt, encoding="utf-8", newline="\n")
    try:
        os.chmod(tmp, 0o755)
    except Exception:
        pass
    return tmp


def _pick_shell() -> str:
    for sh in ("bash", "sh"):
        if shutil.which(sh):
            return sh
    raise RuntimeError("No POSIX shell found. Please install bash or sh.")


def _is_git_repo(path: str) -> bool:
    try:
        out = subprocess.run(
            ["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
            capture_output=True,
            text=True,
            check=False,
        )
        return out.returncode == 0 and out.stdout.strip() == "true"
    except Exception:
        return False


def _startup_clone_musubi_tuner() -> None:
    global MUSUBI_TUNER_DIR_RUNTIME
    target = MUSUBI_TUNER_DIR_RUNTIME
    repo = DEFAULT_MUSUBI_TUNER_REPO
    parent = os.path.dirname(target.rstrip("/\\")) or "/"
    try:
        os.makedirs(parent, exist_ok=True)
    except PermissionError:
        # Fallback to home directory
        target = os.path.join(os.path.expanduser("~"), "musubi-tuner")
        MUSUBI_TUNER_DIR_RUNTIME = target
        os.makedirs(os.path.dirname(target), exist_ok=True)
    except Exception:
        pass

    if os.path.isdir(target) and _is_git_repo(target):
        print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
        try:
            subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
            subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
        except Exception as e:
            print(f"[QIE] git pull failed: {e}")
        return

    if os.path.exists(target) and not _is_git_repo(target):
        print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
        return

    print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
    try:
        subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
        print("[QIE] Clone completed.")
    except subprocess.CalledProcessError as e:
        print(f"[QIE] Clone failed at {target}: {e}")
        # Last-chance fallback into home
        if not target.startswith(os.path.expanduser("~")):
            fallback = os.path.join(os.path.expanduser("~"), "musubi-tuner")
            print(f"[QIE] Retrying clone into {fallback}...")
            try:
                subprocess.run(["git", "clone", "--depth", "1", repo, fallback], check=True)
                MUSUBI_TUNER_DIR_RUNTIME = fallback
                print("[QIE] Clone completed in fallback.")
            except Exception as e2:
                print(f"[QIE] Clone failed in fallback as well: {e2}")


def _run_pip(args: List[str], cwd: Optional[str] = None) -> None:
    cmd = [sys.executable, "-m", "pip"] + args
    try:
        print(f"[QIE] pip {' '.join(args)} (cwd={cwd or os.getcwd()})")
        subprocess.run(cmd, check=True, cwd=cwd)
    except subprocess.CalledProcessError as e:
        print(f"[QIE] pip failed: {e}")


def _startup_install_musubi_deps() -> None:
    repo_dir = MUSUBI_TUNER_DIR_RUNTIME
    if not os.path.isdir(repo_dir):
        print(f"[QIE] Skip deps: musubi-tuner not found at {repo_dir}")
        return
    # Upgrade basic build tooling (best-effort)
    try:
        _run_pip(["install", "-U", "pip", "setuptools", "wheel"])
    except Exception:
        pass

    # Optional Torch extra via env: MUSUBI_TUNER_TORCH_EXTRA=cu124|cu128
    extra = os.environ.get("MUSUBI_TUNER_TORCH_EXTRA", "").strip()
    editable_spec = "." if not extra else f".[{extra}]"

    # Install musubi-tuner in editable mode to expose entrypoints and deps
    try:
        _run_pip(["install", "-e", editable_spec], cwd=repo_dir)
    except Exception:
        # Fallback: plain install without editable
        try:
            _run_pip(["install", editable_spec], cwd=repo_dir)
        except Exception:
            print("[QIE] WARN: musubi-tuner installation failed. Continuing.")


@spaces.GPU
def run_training(
    output_name: str,
    caption: str,
    image_uploads: Any,
    target_prefix: str,
    target_suffix: str,
    control0_uploads: Any,
    ctrl0_prefix: str,
    ctrl0_suffix: str,
    control1_uploads: Any,
    ctrl1_prefix: str,
    ctrl1_suffix: str,
    control2_uploads: Any,
    ctrl2_prefix: str,
    ctrl2_suffix: str,
    control3_uploads: Any,
    ctrl3_prefix: str,
    ctrl3_suffix: str,
    control4_uploads: Any,
    ctrl4_prefix: str,
    ctrl4_suffix: str,
    control5_uploads: Any,
    ctrl5_prefix: str,
    ctrl5_suffix: str,
    control6_uploads: Any,
    ctrl6_prefix: str,
    ctrl6_suffix: str,
    control7_uploads: Any,
    ctrl7_prefix: str,
    ctrl7_suffix: str,
    learning_rate: str,
    network_dim: int,
    train_res_w: int,
    train_res_h: int,
    train_batch_size: int,
    control_res_w: int,
    control_res_h: int,
    te_cache_batch_size: int,
    seed: int,
    max_epochs: int,
    save_every: int,
) -> Iterable[tuple]:
    global _ACTIVE_PROC, _ACTIVE_PGID
    # Basic validation
    log_buf = "[QIE] Start Training invoked.\n"
    ckpts: List[str] = []
    artifacts: List[str] = []
    # Emit an initial line so UI can confirm invocation
    yield (log_buf, ckpts, artifacts)
    if not output_name.strip():
        log_buf += "[ERROR] OUTPUT NAME is required.\n"
        yield (log_buf, ckpts, artifacts)
        return
    if not caption.strip():
        log_buf += "[ERROR] CAPTION is required.\n"
        yield (log_buf, ckpts, artifacts)
        return

    # Ensure /auto holds helper files expected by the script
    _ensure_workspace_auto_files()
    # Resolve data root and create dataset directories (auto-decide)
    global DATA_ROOT_RUNTIME
    DATA_ROOT_RUNTIME = _ensure_data_root(None)
    # Auto-generate dataset directory name
    ds_name = f"dataset_{int(time.time())}"
    ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
    img_folder_name = DEFAULT_IMAGE_FOLDER
    img_dir = os.path.join(ds_dir, img_folder_name)
    os.makedirs(img_dir, exist_ok=True)

    # Ingest uploads into dataset folders
    base_files = _extract_paths(image_uploads)
    if not base_files:
        log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
        yield (log_buf, ckpts, artifacts)
        return
    base_filenames = _copy_uploads(base_files, img_dir)
    log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
    yield (log_buf, ckpts, artifacts)

    # Prepare control sets
    control_upload_sets = [
        _extract_paths(control0_uploads),
        _extract_paths(control1_uploads),
        _extract_paths(control2_uploads),
        _extract_paths(control3_uploads),
        _extract_paths(control4_uploads),
        _extract_paths(control5_uploads),
        _extract_paths(control6_uploads),
        _extract_paths(control7_uploads),
    ]
    # Require control_0; others optional
    if not control_upload_sets[0]:
        log_buf += "[ERROR] control_0 images are required.\n"
        yield (log_buf, ckpts, artifacts)
        return

    control_dirs: List[Optional[str]] = []
    for i, uploads in enumerate(control_upload_sets):
        if not uploads:
            control_dirs.append(None)
            continue
        folder_name = f"control_{i}"
        cdir = os.path.join(ds_dir, folder_name)
        os.makedirs(cdir, exist_ok=True)
        # Simply copy; name matching will be handled by create_image_caption_json.py
        _copy_uploads(uploads, cdir)
        control_dirs.append(folder_name)
        log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
        yield (log_buf, ckpts, artifacts)

    # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh

    # Prepare script with user parameters
    control_folders = [
        (control_dirs[i] if control_dirs[i] else None)
        for i in range(8)
    ]

    # Decide dataset_config path with fallback to runtime auto dir
    ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")

    # Update dataset config with requested resolution/batch settings
    try:
        _update_dataset_toml(
            ds_conf,
            img_res_w=int(train_res_w) if train_res_w else None,
            img_res_h=int(train_res_h) if train_res_h else None,
            train_batch_size=int(train_batch_size) if train_batch_size else None,
            control_res_w=int(control_res_w) if control_res_w else None,
            control_res_h=int(control_res_h) if control_res_h else None,
        )
        log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
    except Exception as e:
        log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
    # Expose dataset config for download (if exists)
    if os.path.isfile(ds_conf):
        artifacts = [ds_conf]

    # Resolve models_root and set output_dir_base to the unique dataset dir
    models_root = MODELS_ROOT_RUNTIME
    out_base = ds_dir
    try:
        os.makedirs(out_base, exist_ok=True)
    except Exception:
        pass

    tmp_script = _prepare_script(
        dataset_name=ds_name,
        caption=caption,
        data_root=DATA_ROOT_RUNTIME,
        image_folder=img_folder_name,
        control_folders=control_folders,
        models_root=models_root,
        output_dir_base=out_base,
        dataset_config=ds_conf,
        override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
        override_save_every=save_every if save_every and save_every > 0 else None,
        override_run_name=output_name.strip(),
        target_prefix=(target_prefix or ""),
        target_suffix=(target_suffix or ""),
        control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
        control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
        override_learning_rate=(learning_rate or None),
        override_network_dim=int(network_dim) if network_dim is not None else None,
        override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
        override_seed=int(seed) if seed is not None else None,
    )


    shell = _pick_shell()
    log_buf += f"[QIE] Using shell: {shell}\n"
    log_buf += f"[QIE] Running script: {tmp_script}\n"
    out_dir = os.path.join(out_base, output_name.strip())
    ckpts = _list_checkpoints(out_dir)
    # Copy the final script to dataset dir for download
    used_script_path = os.path.join(out_base, "train_QIE_used.sh")
    try:
        shutil.copy2(str(tmp_script), used_script_path)
        try:
            os.chmod(used_script_path, 0o755)
        except Exception:
            pass
        if used_script_path not in artifacts:
            artifacts.append(used_script_path)
    except Exception:
        pass
    yield (log_buf, ckpts, artifacts)

    # Run and stream output
    # Ensure child Python processes are unbuffered for real-time logs
    child_env = os.environ.copy()
    child_env["PYTHONUNBUFFERED"] = "1"
    child_env["PYTHONIOENCODING"] = "utf-8"

    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        universal_newlines=True,
        env=child_env,
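        # Start the child in its own session/process group so the stop button
        # can signal the entire process tree via os.killpg.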
        preexec_fn=os.setsid,
    )
    # Register active process for hard stop
    with _ACTIVE_LOCK:
        _ACTIVE_PROC = proc
        try:
            _ACTIVE_PGID = os.getpgid(proc.pid)
        except Exception:
            _ACTIVE_PGID = None
    try:
        assert proc.stdout is not None
        i = 0
        for line in proc.stdout:
            log_buf += line
            i += 1
            if i % 30 == 0:
                ckpts = _list_checkpoints(out_dir)
                # Try to add metadata.jsonl once available
                metadata_json = os.path.join(out_base, "metadata.jsonl")
                if os.path.isfile(metadata_json) and metadata_json not in artifacts:
                    artifacts.append(metadata_json)
            yield (log_buf, ckpts, artifacts)
    finally:
        code = proc.wait()
        # Clear active process registration if this proc
        with _ACTIVE_LOCK:
            if _ACTIVE_PROC is proc:
                _ACTIVE_PROC = None
                _ACTIVE_PGID = None
        # Try to locate latest LoRA file for download
        lora_path = None
        try:
            ckpts = _list_checkpoints(out_dir)
        except Exception:
            pass
        lora_path = ckpts[0] if ckpts else None
        if code < 0:
            try:
                sig = -code
                log_buf += f"[QIE] Terminated by signal: {sig}\n"
            except Exception:
                log_buf += f"[QIE] Terminated by signal.\n"
        log_buf += f"[QIE] Exit code: {code}\n"
        # Final attempt to include metadata.jsonl
        metadata_json = os.path.join(out_base, "metadata.jsonl")
        if os.path.isfile(metadata_json) and metadata_json not in artifacts:
            artifacts.append(metadata_json)
        yield (log_buf, ckpts, artifacts)


def _stop_active_training() -> None:
    """Hard stop (Ubuntu): terminate the process group of the running training job."""
    with _ACTIVE_LOCK:
        proc = _ACTIVE_PROC
        pgid = _ACTIVE_PGID
    if not proc:
        return
    try:
        if pgid is not None:
            os.killpg(pgid, signal.SIGTERM)
        else:
            os.kill(proc.pid, signal.SIGTERM)
    except Exception:
        pass
    try:
        proc.wait(timeout=5)
    except Exception:
        try:
            if pgid is not None:
                os.killpg(pgid, signal.SIGKILL)
            else:
                os.kill(proc.pid, signal.SIGKILL)
        except Exception:
            pass


def build_ui() -> gr.Blocks:
    css = """
    .pad-section {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #ffffff);
    }
    .pad-section_0 {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #fafafa);
    }
    .pad-section_1 {
      padding: 6px;
      margin-bottom: 12px;
      border: 1px solid var(--color-border, #e5e7eb);
      border-radius: 8px;
      background: var(--color-background-secondary, #eaeaea);
    }
    """
    with gr.Blocks(title="Qwen-Image-Edit: Trainer", css=css) as demo:
        with gr.Tabs() as tabs:
            with gr.TabItem("Training"):
                gr.Markdown("""
                # Qwen-Image-Edit Trainer
                Upload the images used for training and, if needed, specify the common
                text at the start/end of their file names (prefix/suffix); the app then
                builds the dataset automatically and starts training. No complex setup is required.
                """)

                with gr.Accordion("Settings", elem_classes=["pad-section"]):
                    with gr.Group():
                        with gr.Row():
                            output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
                            caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)

                        with gr.Row():
                            lr_input = gr.Textbox(label="Learning rate", value="1e-3")
                            dim_input = gr.Number(label="Network dim", value=4, precision=0)
                            train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
                            seed_input = gr.Number(label="Seed", value=42, precision=0)
                            max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
                            save_every = gr.Number(label="Save every N epochs", value=10, precision=0)

                        with gr.Row():
                            tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
                            tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
                            cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
                            cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
                            te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)

                with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath", height=220, scale=3)
                            main_gallery = gr.Gallery(label="Target preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    main_prefix = gr.Textbox(label="Target prefix", placeholder="e.g., IMG_")
                                    main_suffix = gr.Textbox(label="Target suffix", placeholder="e.g., _v2")
                                with gr.Accordion("prefix/sufixについて", open=False):
                                    gr.Markdown("""
                                    ファイルの同名判定のため、画像のファイル名から共通の先頭/末尾文字を取り除く指定(例: IMG_ や _v2)
                                    - まずターゲット画像のファイル名(拡張子なし)から、指定した Target prefix/suffix を取り除いたものを key とします。
                                    - 各コントロールは「付加」規則で、期待名 = control_prefix_i + key + control_suffix_i + ".png" を探して対応付けます。
                                    - アップロード時に画像は自動で .png に変換して保存します(元のファイル名のベースは維持)。
                                    - Control 0 は必須、Control 1〜7 は任意。コントロール画像が1枚だけのときは、すべてのターゲット画像に適用します。
                                    """)

                # control_0 is required and shown outside the accordion
                with gr.Accordion("Control 0", elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl0_files = gr.File(label="Upload control_0 images (required)", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl0_gallery = gr.Gallery(label="control_0 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl0_prefix = gr.Textbox(label="control_0 prefix", placeholder="e.g., C0_")
                                    ctrl0_suffix = gr.Textbox(label="control_0 suffix", placeholder="e.g., _mask")

                # Optional controls start from 1, accordion closed by default
                with gr.Accordion("Control 1", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl1_files = gr.File(label="Upload control_1 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl1_gallery = gr.Gallery(label="control_1 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl1_prefix = gr.Textbox(label="control_1 prefix", placeholder="")
                                    ctrl1_suffix = gr.Textbox(label="control_1 suffix", placeholder="")
                with gr.Accordion("Control 2", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl2_files = gr.File(label="Upload control_2 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl2_gallery = gr.Gallery(label="control_2 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl2_prefix = gr.Textbox(label="control_2 prefix", placeholder="")
                                    ctrl2_suffix = gr.Textbox(label="control_2 suffix", placeholder="")
                with gr.Accordion("Control 3", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl3_files = gr.File(label="Upload control_3 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl3_gallery = gr.Gallery(label="control_3 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl3_prefix = gr.Textbox(label="control_3 prefix", placeholder="")
                                    ctrl3_suffix = gr.Textbox(label="control_3 suffix", placeholder="")
                with gr.Accordion("Control 4", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl4_files = gr.File(label="Upload control_4 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl4_gallery = gr.Gallery(label="control_4 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl4_prefix = gr.Textbox(label="control_4 prefix", placeholder="")
                                    ctrl4_suffix = gr.Textbox(label="control_4 suffix", placeholder="")
                with gr.Accordion("Control 5", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl5_files = gr.File(label="Upload control_5 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl5_gallery = gr.Gallery(label="control_5 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl5_prefix = gr.Textbox(label="control_5 prefix", placeholder="")
                                    ctrl5_suffix = gr.Textbox(label="control_5 suffix", placeholder="")
                with gr.Accordion("Control 6", open=False, elem_classes=["pad-section_1"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl6_files = gr.File(label="Upload control_6 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl6_gallery = gr.Gallery(label="control_6 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl6_prefix = gr.Textbox(label="control_6 prefix", placeholder="")
                                    ctrl6_suffix = gr.Textbox(label="control_6 suffix", placeholder="")
                with gr.Accordion("Control 7", open=False, elem_classes=["pad-section_0"]):
                    with gr.Group():
                        with gr.Row():
                            ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath", height=220, scale=3)
                            ctrl7_gallery = gr.Gallery(label="control_7 preview", columns=4, height=220, object_fit='contain', preview=True, scale=3)
                            with gr.Column(scale=1):
                                with gr.Row():
                                    ctrl7_prefix = gr.Textbox(label="control_7 prefix", placeholder="")
                                    ctrl7_suffix = gr.Textbox(label="control_7 suffix", placeholder="")

                # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.

                run_btn = gr.Button("Start Training", variant="primary")
                logs = gr.Textbox(label="Logs", lines=20)
                ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
                scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)
                with gr.Row():
                    stop_btn = gr.Button("学習を停止", variant="secondary")
                    refresh_scripts_btn = gr.Button("ファイルを再取得", variant="secondary")

                # moved max_epochs/save_every above next to OUTPUT NAME

                # Wire previews
                images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
                ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
                ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
                ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
                ctrl3_files.change(fn=_files_to_gallery, inputs=ctrl3_files, outputs=ctrl3_gallery)
                ctrl4_files.change(fn=_files_to_gallery, inputs=ctrl4_files, outputs=ctrl4_gallery)
                ctrl5_files.change(fn=_files_to_gallery, inputs=ctrl5_files, outputs=ctrl5_gallery)
                ctrl6_files.change(fn=_files_to_gallery, inputs=ctrl6_files, outputs=ctrl6_gallery)
                ctrl7_files.change(fn=_files_to_gallery, inputs=ctrl7_files, outputs=ctrl7_gallery)

                run_btn.click(
                    fn=run_training,
                    inputs=[
                        output_name, caption, images_input, main_prefix, main_suffix,
                        ctrl0_files, ctrl0_prefix, ctrl0_suffix,
                        ctrl1_files, ctrl1_prefix, ctrl1_suffix,
                        ctrl2_files, ctrl2_prefix, ctrl2_suffix,
                        ctrl3_files, ctrl3_prefix, ctrl3_suffix,
                        ctrl4_files, ctrl4_prefix, ctrl4_suffix,
                        ctrl5_files, ctrl5_prefix, ctrl5_suffix,
                        ctrl6_files, ctrl6_prefix, ctrl6_suffix,
                        ctrl7_files, ctrl7_prefix, ctrl7_suffix,
                        lr_input, dim_input,
                        tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
                        seed_input, max_epochs, save_every,
                    ],
                    outputs=[logs, ckpt_files, scripts_files],
                )

                # Refresh button: re-collect checkpoints and scripts/config from the most recent dataset_ directory
                def _refresh_all() -> tuple:
                    try:
                        ds_dir = _find_latest_dataset_dir(DATA_ROOT_RUNTIME)
                    except Exception:
                        ds_dir = None
                    try:
                        ck = _list_checkpoints(ds_dir) if ds_dir else []
                    except Exception:
                        ck = []
                    try:
                        sc = _collect_scripts_and_config(ds_dir)
                    except Exception:
                        sc = _collect_scripts_and_config(None)
                    return ck, sc

                refresh_scripts_btn.click(
                    fn=_refresh_all,
                    inputs=[],
                    outputs=[ckpt_files, scripts_files],
                )

                # Hard stop button (Ubuntu): kill active process group
                def _on_stop_click():
                    _stop_active_training()
                    return
                stop_btn.click(fn=_on_stop_click, inputs=[], outputs=[])

            with gr.TabItem("Prompt Generator"):
                gr.Markdown("""
                # 🎨 A→B Transformation Prompt Auto-Generator
                Provide image A (input), image B (output), and optional notes, and this tab  
                automatically generates an English prompt describing the A→B transformation and suggests three candidate task names.  
                The model used is `gpt-5`.
                """)

                api_key_pg = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...")
                with gr.Row():
                    img_a_pg = gr.Image(type="filepath", label="Image A (Input)", height=300)
                    img_b_pg = gr.Image(type="filepath", label="Image B (Output)", height=300)

                notes_pg = gr.Textbox(label="補足説明(日本語可)", lines=4, value="この画像は例であって、汎用的なプロンプトにする")
                want_japanese_pg = gr.Checkbox(label="日本語訳を含める", value=True)
                run_btn_pg = gr.Button("生成する", variant="primary")

                english_out_pg = gr.Textbox(label="English Prompt", lines=8)
                names_out_pg = gr.Textbox(label="Name Suggestions", lines=4)
                japanese_out_pg = gr.Textbox(label="日本語訳(任意)", lines=8)

                def _on_click_prompt(api_key_in, a_path, b_path, notes_in, ja_flag):
                    # Lazy import to avoid constructing extra Blocks at startup
                    qpg = importlib.import_module("QIE_prompt_generator")
                    a_url = qpg.file_to_data_url(a_path) if a_path else None
                    b_url = qpg.file_to_data_url(b_path) if b_path else None
                    return qpg.call_openai_chat(api_key_in, a_url, b_url, notes_in, ja_flag)

                run_btn_pg.click(
                    fn=_on_click_prompt,
                    inputs=[api_key_pg, img_a_pg, img_b_pg, notes_pg, want_japanese_pg],
                    outputs=[english_out_pg, names_out_pg, japanese_out_pg],
                )

    return demo


def _startup_download_models() -> None:
    global MODELS_ROOT_RUNTIME
    # Pick a writable models directory
    candidate = os.environ.get("QWEN_IMAGE_MODELS_DIR", DEFAULT_MODELS_ROOT)
    try:
        os.makedirs(candidate, exist_ok=True)
        MODELS_ROOT_RUNTIME = candidate
    except PermissionError:
        MODELS_ROOT_RUNTIME = os.path.join(os.path.expanduser("~"), "Qwen-Image_models")
        os.makedirs(MODELS_ROOT_RUNTIME, exist_ok=True)

    print(f"[QIE] Ensuring models in: {MODELS_ROOT_RUNTIME}")
    try:
        download_all_models(MODELS_ROOT_RUNTIME)
    except Exception as e:
        print(f"[QIE] Model download failed: {e}")


if __name__ == "__main__":
    # 1) Ensure musubi-tuner is cloned before anything else
    _startup_clone_musubi_tuner()
    # 1.1) Install musubi-tuner dependencies (best-effort)
    _startup_install_musubi_deps()

    # 2) Download models at startup (blocking by design)
    _startup_download_models()

    # 3) Launch Gradio app
    ui = build_ui()
    # Limit concurrency (training is heavy). Enable queue for Spaces compatibility.
    # Use generic signature to support multiple gradio versions.
    try:
        ui = ui.queue(max_size=16)
    except TypeError:
        ui = ui.queue()
    # Allow Gradio to serve files saved under our runtime dirs
    try:
        allowed = [
            AUTO_DIR_RUNTIME,
            os.path.join(AUTO_DIR_RUNTIME, "train_LoRA"),
            DEFAULT_DATA_ROOT,
            DATA_ROOT_RUNTIME,
            os.path.join(os.path.expanduser("~"), "auto"),
            os.path.join(os.path.expanduser("~"), "data"),
        ]
        ui.launch(server_name="0.0.0.0", allowed_paths=allowed, ssr_mode=False)
    except TypeError:
        # Older gradio without allowed_paths
        try:
            ui.launch(server_name="0.0.0.0", ssr_mode=False)
        except TypeError:
            # Very old gradio without ssr_mode
            ui.launch(server_name="0.0.0.0")