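# Hugging Face Space entrypoint: download a training dataset, turn its metadata.jsonl
# prompts into caption .txt files, run ai-toolkit LoRA training on it, and tag the
# resulting model repo.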
import os

# Keep Hugging Face caches on /tmp so downloads go to writable, ephemeral storage.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("HF_DATASET")
#HF_DATASET = "DevWild/autotrain-pr0b0rk"
repo_id = os.environ.get("MODEL_REPO_ID")

from huggingface_hub import snapshot_download, delete_repo, metadata_update
import uuid
import json
import yaml
import shutil
import subprocess
import sys
from typing import Optional
#from huggingface_hub import login
#HF_TOKEN = os.getenv("HF_TOKEN")
#if HF_TOKEN:
#    login(token=HF_TOKEN)
#else:
#    raise ValueError("HF_TOKEN environment variable not found!")
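# Fail fast if any of the required Space secrets/variables is missing.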
if not HF_TOKEN:
    raise ValueError("Missing HF_TOKEN")
if not HF_DATASET:
    raise ValueError("Missing HF_DATASET")
if not repo_id:
    raise ValueError("Missing MODEL_REPO_ID")
# Prevent running script.py twice
LOCKFILE = "/tmp/.script_lock"
if os.path.exists(LOCKFILE):
    print("🔁 Script already ran once — skipping.")
    sys.exit(0)
with open(LOCKFILE, "w") as f:
    f.write("lock")
print("🚀 Running script for the first time")
# START logging
print("🚀 ENV DEBUG START")
print("HF_TOKEN present?", bool(HF_TOKEN))
print("HF_DATASET:", HF_DATASET)
print("MODEL_REPO_ID:", repo_id)
print("🚀 ENV DEBUG END")

#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)
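# Pull the dataset repo into a uniquely named /tmp directory so repeated runs don't collide.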
def download_dataset(hf_dataset_path: str):
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"
def process_dataset(dataset_dir: str):
    # The dataset dir contains images, a config.yaml, and an optional metadata.jsonl
    # with fields: file_name, prompt. For each entry, write a .txt file with the same
    # name as the image containing the prompt, then remove metadata.jsonl and return
    # the path to the processed dataset.
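    # Example metadata.jsonl line (file name and prompt are hypothetical):
    #   {"file_name": "0001.jpg", "prompt": "a photo of sks person"}
    # becomes 0001.txt containing "a photo of sks person".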
    # check if config.yaml exists
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")
    # check if metadata.jsonl exists
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))
        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])
        # remove metadata.jsonl
        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))
    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
    # update config with new dataset
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir
    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)
    return dataset_dir
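# Download and prepare the dataset, then launch ai-toolkit's run.py on the rewritten config.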
def run_training(hf_dataset_path: str):
    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)
    # Force repo_id override in config.yaml
    config_path = os.path.join(dataset_dir, "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    print("✅ Updated config.yaml with MODEL_REPO_ID:", repo_id)
    # run training
    if not os.path.exists("ai-toolkit"):
        toolkit_src = "ai-toolkit"
        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
        subprocess.run(commands, shell=True, check=True)
        # Strip the clone's git metadata so the Space repo does not treat it as a nested repo.
        shutil.rmtree(os.path.join(toolkit_src, ".git"), ignore_errors=True)
        if os.path.exists(os.path.join(toolkit_src, ".gitmodules")):
            os.remove(os.path.join(toolkit_src, ".gitmodules"))
    # patch_ai_toolkit_typing()
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(
        commands,
        shell=True,
        cwd="ai-toolkit",
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    # Stream logs to Space output
    for line in process.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()
    return process, dataset_dir
#def patch_ai_toolkit_typing():
#    config_path = "ai-toolkit/toolkit/config_modules.py"
#    if os.path.exists(config_path):
#        with open(config_path, "r") as f:
#            content = f.read()
#        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
#        with open(config_path, "w") as f:
#            f.write(content)
#        print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
#    else:
#        print("⚠️ Could not patch config_modules.py — file not found")
if __name__ == "__main__":
    try:
        process, dataset_dir = run_training(HF_DATASET)
        # process.wait()  # Wait for the training process to finish
        exit_code = process.wait()
        print("Training finished with exit code:", exit_code)
        if exit_code != 0:
            raise RuntimeError(f"Training failed with exit code {exit_code}")
        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
            config = yaml.safe_load(f)
        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
        #repo_id = os.environ.get("MODEL_REPO_ID")
        #repo_id = os.getenv("MODEL_REPO_ID")
        #repo_id = "DevWild/suppab0rk"
        metadata = {
            "tags": [
                "autotrain",
                "spacerunner",
                "text-to-image",
                "flux",
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        }
        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
    finally:
        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
        print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")