# NOTE(review): the lines below are artifacts of a web-page scrape (commit
# message and file-viewer chrome), commented out so the script parses.
# yeq6x's picture
# Refactor app.py to update prefix/suffix naming conventions for metadata creation and enhance UI with new training hyperparameter inputs. Modify train_QIE.sh to utilize dynamic hyperparameter values for training execution, improving configurability and user experience.
# 6fabdaf
# raw
# history blame
# 4.49 kB
#!/usr/bin/env bash
# ==============================
# Generate metadata.jsonl before training
# Configure variables directly in this file.
# No environment variable overrides are used.
# ==============================

# Report the installed torch version up front (non-fatal if torch is missing,
# so the rest of the pipeline can still run and fail with a clearer error).
echo "[QIE] Torch version check"
python - <<'PY'
# NOTE(review): the original heredoc body was unindented, which raises an
# IndentationError before the import is even attempted; re-indented here.
try:
    import torch
    print(f"[QIE] torch: {torch.__version__}")
except Exception as e:
    print(f"[QIE] torch: not available ({e})")
PY
# --- Dataset location -------------------------------------------------
# Root folder that holds all datasets.
DATA_ROOT="/workspace/data"
# Dataset subfolder under DATA_ROOT. NOTE(review): empty here — presumably
# filled in by the companion app before execution; confirm against app.py.
DATASET_NAME=""
# Required inputs
# Caption applied to every image (empty here; set by the app).
CAPTION=""
# Subfolder of the dataset containing the target images.
IMAGE_FOLDER="image"
# Optional explicit control-image folders; leave commented to rely on the
# auto-detection (control_0..control_7, or "control" for slot 0) done below.
# CONTROL_FOLDER_0=""
# CONTROL_FOLDER_1=""
# CONTROL_FOLDER_2=""
# CONTROL_FOLDER_3=""
# CONTROL_FOLDER_4=""
# CONTROL_FOLDER_5=""
# CONTROL_FOLDER_6=""
# CONTROL_FOLDER_7=""
# Run name doubles as the output subfolder and LoRA file name; "%/" strips a
# trailing slash so the derived paths below stay clean.
RUN_NAME="${DATASET_NAME%/}"
DATASET_DIR="${DATA_ROOT%/}/${DATASET_NAME}"
# Where trained LoRA checkpoints are written (one subfolder per RUN_NAME).
OUTPUT_DIR_BASE="/workspace/auto/train_LoRA"
# TOML dataset config consumed by musubi-tuner; patched in place below.
DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
# Generated caption/image manifest consumed via the dataset config.
OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"
# Training hyperparameters (can be overridden by app)
LEARNING_RATE="1e-3"
NETWORK_DIM=4
SEED=42
MAX_TRAIN_EPOCHS=100
SAVE_EVERY_N_EPOCHS=10
# Assemble --control_dir_N arguments for up to 8 control-image slots.
# Resolution priority for slot N:
#   1. an explicit CONTROL_FOLDER_N override,
#   2. an existing "<dataset>/control_N" directory,
#   3. for slot 0 only, a plain "<dataset>/control" directory.
CONTROL_ARGS=()
for ((slot = 0; slot < 8; slot++)); do
  override_var="CONTROL_FOLDER_${slot}"
  override="${!override_var}"
  candidate=""
  if [[ -n "$override" ]]; then
    candidate="${DATASET_DIR%/}/${override}"
  elif [[ -d "${DATASET_DIR%/}/control_${slot}" ]]; then
    candidate="${DATASET_DIR%/}/control_${slot}"
  elif (( slot == 0 )) && [[ -d "${DATASET_DIR%/}/control" ]]; then
    # Special fallback: a single folder named "control" maps to slot 0.
    candidate="${DATASET_DIR%/}/control"
  fi
  if [[ -n "$candidate" ]]; then
    CONTROL_ARGS+=("--control_dir_${slot}" "$candidate")
  fi
done
# Sync the dataset config's image_jsonl_file with OUTPUT_JSON (if the config
# exists) and point cache_directory at a writable folder beside the config.
if [[ -f "$DATASET_CONFIG" ]]; then
  python - "$DATASET_CONFIG" "$OUTPUT_JSON" <<'PY'
import os
import re
import sys

# NOTE(review): the original heredoc body was unindented, which raises an
# IndentationError on the first if-block; re-indented here.
# argv[1] = path to the TOML config, argv[2] = metadata.jsonl path.
path, out = sys.argv[1], sys.argv[2]
with open(path, 'r', encoding='utf-8') as f:
    txt = f.read()
base = os.path.dirname(path)
# Normalize path separators for the TOML. The original replaced the
# two-character sequence '\\' + '\\'; a single backslash is what
# os.path.join can actually introduce on Windows, so replace that.
cache = os.path.join(base, 'cache').replace('\\', '/')

# Update image_jsonl_file in place, or append it if the key is absent.
new = re.sub(r"(?m)^\s*image_jsonl_file\s*=.*$", f'image_jsonl_file = "{out}"', txt)
if new == txt and 'image_jsonl_file' not in txt:
    new = txt.rstrip('\n') + f"\nimage_jsonl_file = \"{out}\"\n"

# Update cache_directory to a writable folder under the config directory,
# or append it if the key is absent.
if re.search(r"(?m)^\s*cache_directory\s*=", new):
    new = re.sub(r"(?m)^\s*cache_directory\s*=.*$", f'cache_directory = "{cache}"', new)
else:
    new = new.rstrip('\n') + f"\ncache_directory = \"{cache}\"\n"

with open(path, 'w', encoding='utf-8') as f:
    f.write(new)
print("[QIE] Updated {}: image_jsonl_file -> {}".format(path, out))
print("[QIE] Updated {}: cache_directory -> {}".format(path, cache))
PY
  # Make sure the cache directory the config now references actually exists.
  mkdir -p "$(dirname "$DATASET_CONFIG")/cache"
else
  echo "[QIE] WARN: Dataset config not found at $DATASET_CONFIG. Ensure it points to $OUTPUT_JSON"
fi
# Generate metadata.jsonl pairing every image with its caption and any
# control folders detected above. Abort if the tool directory is missing —
# otherwise the generator would run from whatever the current directory is.
cd /workspace/auto || { echo "[QIE] ERROR: /workspace/auto not found" >&2; exit 1; }
echo "[QIE] Generating metadata: $OUTPUT_JSON"
python create_image_caption_json.py \
-i "${DATASET_DIR%/}/${IMAGE_FOLDER}" \
-c "$CAPTION" \
-o "$OUTPUT_JSON" \
--image-dir "${DATASET_DIR%/}/${IMAGE_FOLDER}" \
"${CONTROL_ARGS[@]}"
# Pre-compute VAE latents and text-encoder outputs so the training run can
# stream them from cache. Abort if the tuner checkout is missing — otherwise
# both cache scripts would fail from the wrong working directory.
cd /musubi-tuner || { echo "[QIE] ERROR: /musubi-tuner not found" >&2; exit 1; }
python qwen_image_cache_latents.py \
--dataset_config "$DATASET_CONFIG" \
--vae "/workspace/Qwen-Image_models/vae/diffusion_pytorch_model.safetensors" \
--edit_plus \
--vae_spatial_tile_sample_min_size 16384
python qwen_image_cache_text_encoder_outputs.py \
--dataset_config "$DATASET_CONFIG" \
--text_encoder "/workspace/Qwen-Image_models/text_encoder/qwen_2.5_vl_7b.safetensors" \
--edit_plus \
--batch_size 16
# Launch LoRA training for Qwen-Image-Edit via musubi-tuner.
# Hyperparameters (LEARNING_RATE, NETWORK_DIM, SEED, MAX_TRAIN_EPOCHS,
# SAVE_EVERY_N_EPOCHS) come from the variables defined at the top of this
# file, which the companion app can rewrite before running the script.
# Outputs go to ${OUTPUT_DIR_BASE}/${RUN_NAME} as ${RUN_NAME}*.safetensors.
# NOTE(review): comments cannot be interleaved between backslash
# continuations, hence this single header block.
accelerate launch src/musubi_tuner/qwen_image_train_network.py \
--edit_plus \
--dit "/workspace/Qwen-Image_models/dit/qwen_image_edit_2509_bf16.safetensors" \
--vae "/workspace/Qwen-Image_models/vae/diffusion_pytorch_model.safetensors" \
--text_encoder "/workspace/Qwen-Image_models/text_encoder/qwen_2.5_vl_7b.safetensors" \
--dataset_config "$DATASET_CONFIG" \
--mixed_precision bf16 \
--sdpa \
--timestep_sampling shift \
--weighting_scheme none \
--discrete_flow_shift 2.0 \
--optimizer_type adamw8bit \
--learning_rate "$LEARNING_RATE" \
--gradient_checkpointing \
--max_data_loader_n_workers 2 \
--persistent_data_loader_workers \
--network_module networks.lora_qwen_image \
--network_dim "$NETWORK_DIM" \
--max_train_epochs "$MAX_TRAIN_EPOCHS" \
--save_every_n_epochs "$SAVE_EVERY_N_EPOCHS" \
--seed "$SEED" \
--output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
--output_name "${RUN_NAME}" \
--ddp_gradient_as_bucket_view \
--ddp_static_graph