Update script.py
script.py CHANGED

@@ -1,9 +1,20 @@
 import os
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
+os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
+
+HF_TOKEN = os.environ.get("HF_TOKEN")
+#HF_DATASET = os.environ.get("HF_DATASET")
+HF_DATASET = "DevWild/autotrain-pr0b0rk"
+repo_id = os.environ.get("MODEL_REPO_ID")
+
 from huggingface_hub import snapshot_download, delete_repo, metadata_update
 import uuid
 import json
 import yaml
 import subprocess
+import sys
+from typing import Optional
 
 #from huggingface_hub import login
 #HF_TOKEN = os.getenv("HF_TOKEN")
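Note on the hunk above: the cache variables are exported before the huggingface_hub import deliberately, because huggingface_hub resolves its cache paths when the module is first imported; exporting them afterwards has no effect. A minimal sketch of that ordering (the constants check is illustrative only):

import os

# Must be set before the import: huggingface_hub reads these environment
# variables at import time, not at call time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

from huggingface_hub import constants

print(constants.HF_HUB_CACHE)  # /tmp/huggingface/hub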
@@ -12,9 +23,36 @@ import subprocess
 #else:
 #    raise ValueError("HF_TOKEN environment variable not found!")
 
-
-
-
+
+if not HF_TOKEN:
+    raise ValueError("Missing HF_TOKEN")
+
+if not HF_DATASET:
+    raise ValueError("Missing HF_DATASET")
+
+if not repo_id:
+    raise ValueError("Missing MODEL_REPO_ID")
+
+# Prevent running script.py twice
+LOCKFILE = "/tmp/.script_lock"
+if os.path.exists(LOCKFILE):
+    print("🔁 Script already ran once - skipping.")
+    exit(0)
+
+with open(LOCKFILE, "w") as f:
+    f.write("lock")
+
+print("🚀 Running script for the first time")
+
+# START logging
+print("🔍 ENV DEBUG START")
+print("HF_TOKEN present?", bool(HF_TOKEN))
+print("HF_DATASET:", HF_DATASET)
+print("MODEL_REPO_ID:", repo_id)
+print("🔍 ENV DEBUG END")
+
+
+#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)
 
 
 def download_dataset(hf_dataset_path: str):
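Review note on the lockfile in the hunk above: os.path.exists() followed by open() leaves a small window in which two concurrent launches can both pass the check. If that matters for this Space, creating the file with O_CREAT | O_EXCL is atomic. A sketch, not part of the commit:

import os
import sys

LOCKFILE = "/tmp/.script_lock"
try:
    # O_EXCL makes creation fail atomically if the file already exists.
    fd = os.open(LOCKFILE, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    os.write(fd, b"lock")
    os.close(fd)
except FileExistsError:
    print("Script already ran once - skipping.")
    sys.exit(0)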
@@ -70,39 +108,80 @@ def run_training(hf_dataset_path: str):
 
     dataset_dir = download_dataset(hf_dataset_path)
     dataset_dir = process_dataset(dataset_dir)
+    # Force repo_id override in config.yaml
+    config_path = os.path.join(dataset_dir, "config.yaml")
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
+
+    with open(config_path, "w") as f:
+        yaml.dump(config, f)
+
+    print("✅ Updated config.yaml with MODEL_REPO_ID:", repo_id)
+
 
     # run training
     if not os.path.exists("ai-toolkit"):
-        commands = "git clone https://github.com/
+        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
         subprocess.run(commands, shell=True)
 
+    patch_ai_toolkit_typing()
     commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
-    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ)
+    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
 
+    # Stream logs to Space output
+    for line in process.stdout:
+        sys.stdout.write(line)
+        sys.stdout.flush()
+
     return process, dataset_dir
 
-
-
-process, dataset_dir = run_training(HF_DATASET)
-process.wait()  # Wait for the training process to finish
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-"
-
-
-
-
+def patch_ai_toolkit_typing():
+    config_path = "ai-toolkit/toolkit/config_modules.py"
+    if os.path.exists(config_path):
+        with open(config_path, "r") as f:
+            content = f.read()
+
+        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
+
+        with open(config_path, "w") as f:
+            f.write(content)
+        print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
+    else:
+        print("⚠️ Could not patch config_modules.py - file not found")
+
+
+if __name__ == "__main__":
+    try:
+        process, dataset_dir = run_training(HF_DATASET)
+        # process.wait()  # Wait for the training process to finish
+        exit_code = process.wait()
+        print("Training finished with exit code:", exit_code)
+
+        if exit_code != 0:
+            raise RuntimeError(f"Training failed with exit code {exit_code}")
+
+        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
+            config = yaml.safe_load(f)
+        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
+        #repo_id = os.environ.get("MODEL_REPO_ID")
+        #repo_id = os.getenv("MODEL_REPO_ID")
+        #repo_id = "DevWild/suppab0rk"
+
+        metadata = {
+            "tags": [
+                "autotrain",
+                "spacerunner",
+                "text-to-image",
+                "flux",
+                "lora",
+                "diffusers",
+                "template:sd-lora",
+            ]
+        }
+        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
+
+    finally:
+        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
+        print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")
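Two review notes on the hunk above. First, because run_training() now iterates process.stdout until EOF, the function itself blocks until training exits, so the later process.wait() in __main__ mostly just collects the exit code. Second, the hf_repo_id override assumes the nested config.yaml layout sketched below; only the keys on the indexed path are implied by the commit, and the repo value is a placeholder:

import yaml

# Shape assumed by config["config"]["process"][0]["save"]["hf_repo_id"].
# Keys beyond that path are omitted; "user/model-repo" is hypothetical.
example = yaml.safe_load("""
config:
  process:
    - save:
        hf_repo_id: user/model-repo
""")
assert example["config"]["process"][0]["save"]["hf_repo_id"] == "user/model-repo"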
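On patch_ai_toolkit_typing(): the PEP 604 annotation form torch.Tensor | None is only valid when evaluated at runtime on Python 3.10+, which is presumably what the rewrite works around. One caveat: the replacement text references Optional, so config_modules.py must itself import it; the from typing import Optional added at the top of script.py does not carry over to the patched file. A hypothetical extension of the patch that also injects the import (assumes the target module has no from __future__ import that must stay on the first line):

import os

def patch_ai_toolkit_typing_with_import():
    # Hypothetical variant of the commit's patch: besides rewriting the
    # annotation, ensure the patched module can resolve Optional.
    path = "ai-toolkit/toolkit/config_modules.py"
    if not os.path.exists(path):
        print("Could not patch config_modules.py - file not found")
        return
    with open(path, "r") as f:
        content = f.read()
    content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
    if "from typing import Optional" not in content:
        # Substring check is crude (misses "from typing import List, Optional"),
        # but a duplicate import of Optional is harmless.
        content = "from typing import Optional\n" + content
    with open(path, "w") as f:
        f.write(content)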