#!/usr/bin/env python3
"""
FastAPI application for the EXAONE fine-tuning Space, using the corrected AutoTrain command.
"""
import os
import subprocess
import logging
from datetime import datetime, timezone
from pathlib import Path

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
| # ๋ก๊น ์ค์ | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |

app = FastAPI(
    title="EXAONE Fine-tuning",
    description="Fine-tuning API for the EXAONE 4.0 1.2B model",
    version="1.0.0"
)
| # ์ ์ญ ๋ณ์ | |
| training_status = { | |
| "is_running": False, | |
| "progress": 0, | |
| "current_epoch": 0, | |
| "total_epochs": 3, | |
| "loss": 0.0, | |
| "status": "idle", | |
| "log_file": "/tmp/training.log" | |
| } | |
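# This dict is shared by every endpoint; the design assumes a single
# Uvicorn worker, since separate worker processes would each hold their own copy.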

class TrainingRequest(BaseModel):
    model_name: str = "amis5895/exaone-1p2b-nutrition-kdri"
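# The model_name field doubles as the Hub repo id that is passed to
# --hub-model-id when the fine-tuned model is pushed.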

# NOTE: the route decorators below are reconstructed; the exact paths are
# assumptions and may differ from the original deployment.
@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "EXAONE Fine-tuning API",
        "status": "running",
        "version": "1.0.0"
    }

@app.post("/train")
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start training."""
    global training_status
    if training_status["is_running"]:
        raise HTTPException(status_code=400, detail="Training is already running")
    training_status.update({
        "is_running": True,
        "progress": 0,
        "current_epoch": 0,
        "status": "starting"
    })
    # Kick off training in the background
    background_tasks.add_task(run_corrected_training, request)
    return {
        "message": "Training started",
        "status": "starting",
        "model_name": request.model_name
    }
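# Design note: BackgroundTasks runs run_corrected_training only after the
# response above has been sent, so the POST returns immediately while
# training continues in-process.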

async def run_corrected_training(request: TrainingRequest):
    """Run training using the corrected AutoTrain command."""
    global training_status
    try:
        logger.info("Starting corrected AutoTrain training process...")
        training_status["status"] = "running"
| # ๋ฐ์ดํฐ ํ์ผ ํ์ธ | |
| train_file = Path("/app/train.csv") | |
| val_file = Path("/app/validation.csv") | |
| config_file = Path("/app/autotrain_ultra_low_final.yaml") | |
| if not train_file.exists(): | |
| logger.error(f"Training file not found: {train_file}") | |
| training_status.update({ | |
| "is_running": False, | |
| "status": "failed", | |
| "error": "Training file not found" | |
| }) | |
| return | |
| if not val_file.exists(): | |
| logger.error(f"Validation file not found: {val_file}") | |
| training_status.update({ | |
| "is_running": False, | |
| "status": "failed", | |
| "error": "Validation file not found" | |
| }) | |
| return | |
| if not config_file.exists(): | |
| logger.error(f"Config file not found: {config_file}") | |
| training_status.update({ | |
| "is_running": False, | |
| "status": "failed", | |
| "error": "Config file not found" | |
| }) | |
| return | |
| logger.info("All files found, starting corrected AutoTrain training...") | |
| # ๋ก๊ทธ ํ์ผ ์ด๊ธฐํ | |
| log_file = Path(training_status["log_file"]) | |
| try: | |
| log_file.write_text("Starting corrected AutoTrain training...\n", encoding="utf-8") | |
| except Exception as e: | |
| logger.warning(f"Could not write to log file: {e}") | |
| training_status["log_content"] = "Starting corrected AutoTrain training...\n" | |
| # ํ๊ฒฝ๋ณ์ ์ค์ | |
| env = os.environ.copy() | |
| env["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" | |
| env["HF_HOME"] = "/tmp/huggingface" | |
| env["OMP_NUM_THREADS"] = "1" | |
| # ์์ ๋ AutoTrain ๋ช ๋ น์ด (์ฌ๋ฐ๋ฅธ ํ์ ์ฌ์ฉ) | |
| cmd = [ | |
| "autotrain", "llm", | |
| "--train", | |
| "--project_name", "exaone-finetuning", | |
| "--model", "LGAI-EXAONE/EXAONE-4.0-1.2B", | |
| "--data_path", "/app", | |
| "--text_column", "text", | |
| "--use-peft", # --use_peft ๋์ --use-peft | |
| "--quantization", "int4", | |
| "--lora-r", "16", # --lora_r ๋์ --lora-r | |
| "--lora-alpha", "32", # --lora_alpha ๋์ --lora-alpha | |
| "--lora-dropout", "0.05", # --lora_dropout ๋์ --lora-dropout | |
| "--target-modules", "all-linear", # --target_modules ๋์ --target-modules | |
| "--epochs", "3", | |
| "--batch-size", "4", # --batch_size ๋์ --batch-size | |
| "--gradient-accumulation", "4", # --gradient_accumulation ๋์ --gradient-accumulation | |
| "--learning-rate", "2e-4", # --learning_rate ๋์ --learning-rate | |
| "--warmup-ratio", "0.03", # --warmup_ratio ๋์ --warmup-ratio | |
| "--mixed-precision", "fp16", # --mixed_precision ๋์ --mixed-precision | |
| "--push-to-hub", # --push_to_hub ๋์ --push-to-hub | |
| "--hub-model-id", request.model_name, # --hub_model_id ๋์ --hub-model-id | |
| "--username", "amis5895" | |
| ] | |
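        # Design note: int4 quantization combined with LoRA adapters is a
        # QLoRA-style recipe; the effective batch size is 4 * 4 = 16
        # (--batch-size times --gradient-accumulation).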
| logger.info(f"Running corrected command: {' '.join(cmd)}") | |
| # ๋ก๊ทธ ํ์ผ์ ๋ช ๋ น์ด ๊ธฐ๋ก | |
| try: | |
| with open(log_file, "a", encoding="utf-8") as f: | |
| f.write(f"Corrected Command: {' '.join(cmd)}\n") | |
| f.write("=" * 50 + "\n") | |
| except: | |
| if "log_content" not in training_status: | |
| training_status["log_content"] = "" | |
| training_status["log_content"] += f"Corrected Command: {' '.join(cmd)}\n" + "=" * 50 + "\n" | |
        # Launch the AutoTrain process
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,  # line-buffered so output can be streamed as it arrives
            cwd="/app",
            env=env
        )
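        # Design note: stderr is merged into stdout above, so the single read
        # loop below drains both streams and avoids a two-pipe deadlock.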
| # ํ์ต ์งํ ์ํฉ ๋ชจ๋ํฐ๋ง | |
| for line in process.stdout: | |
| logger.info(line.strip()) | |
| # ๋ก๊ทธ ํ์ผ์ ๊ธฐ๋ก | |
| try: | |
| with open(log_file, "a", encoding="utf-8") as f: | |
| f.write(line) | |
| except: | |
| if "log_content" not in training_status: | |
| training_status["log_content"] = "" | |
| training_status["log_content"] += line | |
| # ์งํ๋ฅ ํ์ฑ | |
| if "epoch" in line.lower() and "/" in line: | |
| try: | |
| # "Epoch 1/3" ํํ์์ ์งํ๋ฅ ์ถ์ถ | |
| parts = line.split() | |
| for i, part in enumerate(parts): | |
| if part.lower() == "epoch" and i + 1 < len(parts): | |
| epoch_info = parts[i + 1] | |
| if "/" in epoch_info: | |
| current, total = epoch_info.split("/") | |
| training_status["current_epoch"] = int(current) | |
| training_status["total_epochs"] = int(total) | |
| training_status["progress"] = (int(current) / int(total)) * 100 | |
| break | |
| except: | |
| pass | |
| # ์์ค๊ฐ ํ์ฑ | |
| if "loss" in line.lower(): | |
| try: | |
| parts = line.split() | |
| for i, part in enumerate(parts): | |
| if part.lower() == "loss" and i + 1 < len(parts): | |
| loss_value = float(parts[i + 1]) | |
| training_status["loss"] = loss_value | |
| break | |
| except: | |
| pass | |
        process.wait()
        if process.returncode == 0:
            training_status.update({
                "is_running": False,
                "progress": 100,
                "status": "completed"
            })
            logger.info("Training completed successfully!")
            # Record completion in the log
            try:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write("\n" + "=" * 50 + "\n")
                    f.write("Training completed successfully!\n")
            except Exception:
                if "log_content" not in training_status:
                    training_status["log_content"] = ""
                training_status["log_content"] += "\n" + "=" * 50 + "\nTraining completed successfully!\n"
        else:
            training_status.update({
                "is_running": False,
                "status": "failed"
            })
            logger.error("Training failed!")
            # Record the failure in the log
            try:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write("\n" + "=" * 50 + "\n")
                    f.write(f"Training failed with return code: {process.returncode}\n")
            except Exception:
                if "log_content" not in training_status:
                    training_status["log_content"] = ""
                training_status["log_content"] += "\n" + "=" * 50 + f"\nTraining failed with return code: {process.returncode}\n"
    except Exception as e:
        logger.error(f"Training error: {str(e)}")
        training_status.update({
            "is_running": False,
            "status": "error",
            "error": str(e)
        })
        # Record the error in the log; re-resolve the path here because
        # log_file is not yet bound if the failure happened before log setup
        log_file = Path(training_status["log_file"])
        try:
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"\nError: {str(e)}\n")
        except Exception:
            if "log_content" not in training_status:
                training_status["log_content"] = ""
            training_status["log_content"] += f"\nError: {str(e)}\n"

@app.get("/status")
async def get_status():
    """Get training status."""
    return training_status

@app.get("/logs")
async def get_logs():
    """Get logs."""
    log_file = Path(training_status["log_file"])
    if log_file.exists():
        try:
            with open(log_file, "r", encoding="utf-8") as f:
                return {"logs": f.read()}
        except Exception:
            pass
    # Fall back to the in-memory copy if the file cannot be read
    if "log_content" in training_status:
        return {"logs": training_status["log_content"]}
    return {"logs": "No logs available"}

@app.get("/logs/stream")
async def stream_logs():
    """Stream logs in real time as Server-Sent Events."""
    def generate_logs():
        log_file = Path(training_status["log_file"])
        if log_file.exists():
            try:
                with open(log_file, "r", encoding="utf-8") as f:
                    for line in f:
                        # a blank line terminates each SSE message
                        yield f"data: {line.rstrip()}\n\n"
                return
            except Exception:
                pass
        # Fall back to the in-memory copy if the file cannot be read
        if "log_content" in training_status:
            for line in training_status["log_content"].split("\n"):
                yield f"data: {line}\n\n"
        else:
            yield "data: No logs available\n\n"
    return StreamingResponse(generate_logs(), media_type="text/event-stream")

@app.post("/stop")
async def stop_training():
    """Stop training."""
    global training_status
    if not training_status["is_running"]:
        raise HTTPException(status_code=400, detail="No training is running")
    # Note: this only flips the status flag; the AutoTrain subprocess
    # itself is not terminated here.
    training_status.update({
        "is_running": False,
        "status": "stopped"
    })
    return {"message": "Training stopped"}

@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}

@app.get("/data")
async def get_data_info():
    """Get data file info."""
    train_file = Path("/app/train.csv")
    val_file = Path("/app/validation.csv")
    config_file = Path("/app/autotrain_ultra_low_final.yaml")
    return {
        "train_file_exists": train_file.exists(),
        "validation_file_exists": val_file.exists(),
        "config_file_exists": config_file.exists(),
        "train_file_size": train_file.stat().st_size if train_file.exists() else 0,
        "validation_file_size": val_file.stat().st_size if val_file.exists() else 0,
        "config_file_size": config_file.stat().st_size if config_file.exists() else 0
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
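
# Example usage against the reconstructed routes above (a sketch under the
# assumed paths, not verified against the original deployment):
#   curl -X POST http://localhost:7860/train \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "amis5895/exaone-1p2b-nutrition-kdri"}'
#   curl http://localhost:7860/status
#   curl http://localhost:7860/logs
#   curl -X POST http://localhost:7860/stop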