Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Simple EXAONE fine-tuning Space: a FastAPI application. | |
| """ | |
| import os | |
| import json | |
| import subprocess | |
| import asyncio | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| import logging | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| import uvicorn | |
# Logging configuration: emit INFO and above through the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object for the fine-tuning service.
# The description had been mis-decoded into mojibake; it is restored here and
# reads "EXAONE 4.0 1.2B model fine-tuning API".
# NOTE(review): no route decorators are visible on the handler functions in
# this view of the file — confirm each handler is registered on `app`.
app = FastAPI(
    title="EXAONE Fine-tuning",
    description="EXAONE 4.0 1.2B 모델 파인튜닝 API",
    version="1.0.0",
)
# Global, mutable training state. The handlers in this module read and update
# this one dictionary in place; it is also what the status handler returns.
training_status = dict(
    is_running=False,
    progress=0,
    current_epoch=0,
    total_epochs=3,
    loss=0.0,
    status="idle",
)
class TrainingRequest(BaseModel):
    # Request model accepted by start_training.
    # model_name: echoed back in the start response; falls back to the
    # default string below when the client does not supply one.
    model_name: str = "amis5895/exaone-1p2b-nutrition-kdri"
async def root():
    """Root endpoint: return a static banner describing the service."""
    banner = dict(
        message="EXAONE Fine-tuning API",
        status="running",
        version="1.0.0",
    )
    return banner
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start a training run in the background.

    Args:
        request: Start parameters; only ``model_name`` is read here.
        background_tasks: Task queue on which ``run_training_simple`` is
            scheduled.

    Returns:
        A dict confirming the start and echoing ``request.model_name``.

    Raises:
        HTTPException: 400 when ``training_status["is_running"]`` is already
            true.
    """
    if training_status["is_running"]:
        raise HTTPException(status_code=400, detail="Training is already running")

    # Drop the error message left behind by a previous failed run; otherwise
    # the status of this (possibly successful) run keeps reporting it.
    training_status.pop("error", None)
    training_status.update({
        "is_running": True,
        "progress": 0,
        "current_epoch": 0,
        "loss": 0.0,  # do not show the previous run's final loss
        "status": "starting",
    })

    # Hand the training coroutine to the background task queue.
    background_tasks.add_task(run_training_simple, request)

    return {
        "message": "Training started",
        "status": "starting",
        "model_name": request.model_name,
    }
async def run_training_simple(request: TrainingRequest):
    """Run a simulated training loop, publishing progress in ``training_status``.

    No model is trained. After checking that both data files exist, the
    function steps through ``training_status["total_epochs"]`` epochs,
    sleeping 5 seconds in each and writing a synthetic loss value.

    ``training_status["is_running"]`` is checked on entry, before every epoch
    and before marking completion. Clearing that flag therefore ends the run,
    and the status written by whoever cleared it is left untouched instead of
    being overwritten with "completed".

    Args:
        request: The start request; not read by the simulation.
    """
    try:
        if not training_status["is_running"]:
            logger.info("Training was stopped before it started")
            return

        logger.info("Starting simple training process...")
        training_status["status"] = "running"

        # Check the data files before simulating anything.
        required_files = (
            ("Training", Path("/app/train.csv")),
            ("Validation", Path("/app/validation.csv")),
        )
        for label, data_file in required_files:
            if not data_file.exists():
                logger.error(f"{label} file not found: {data_file}")
                training_status.update({
                    "is_running": False,
                    "status": "failed",
                    "error": f"{label} file not found",
                })
                return

        logger.info("Data files found, starting training simulation...")

        # Simple training simulation. The epoch count comes from the shared
        # status so the loop agrees with the advertised "total_epochs".
        total_epochs = training_status["total_epochs"]
        for epoch in range(1, total_epochs + 1):
            if not training_status["is_running"]:
                logger.info(
                    "Training stopped after %d of %d epochs", epoch - 1, total_epochs
                )
                return

            training_status["current_epoch"] = epoch
            training_status["progress"] = (epoch / total_epochs) * 100
            training_status["loss"] = 2.5 - (epoch * 0.5)  # simulated loss value
            logger.info(
                f"Epoch {epoch}/{total_epochs} - Loss: {training_status['loss']:.3f}"
            )
            await asyncio.sleep(5)  # wait 5 seconds (simulation)

        # A stop may have arrived during the final sleep.
        if not training_status["is_running"]:
            logger.info("Training stopped before completion was recorded")
            return

        training_status.update({
            "is_running": False,
            "progress": 100,
            "status": "completed",
        })
        logger.info("Training completed successfully!")
    except Exception as e:
        # Top-level boundary of the background task: record the failure in
        # the status and log it with its traceback.
        logger.exception("Training error: %s", e)
        training_status.update({
            "is_running": False,
            "status": "error",
            "error": str(e),
        })
async def get_status():
    """Report training state by returning the shared ``training_status`` dict."""
    return training_status
async def get_logs():
    """Return the full contents of the training log.

    Returns:
        ``{"logs": <file text>}`` with ``/app/training.log`` read as UTF-8,
        or ``{"logs": "No logs available"}`` when that file does not exist.
    """
    log_file = Path("/app/training.log")
    try:
        # Read directly rather than testing exists() first: the file could
        # disappear between the check and the open.
        logs = log_file.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return {"logs": "No logs available"}
    return {"logs": logs}
def iter_log_events(log_file):
    """Yield the lines of *log_file* as ``data: <line>`` frames.

    Each frame ends with a real blank line (``"\n\n"``), and the line's own
    trailing newline is stripped so it cannot split the frame. When the file
    does not exist, a single ``data: No logs available`` frame is yielded.
    """
    try:
        handle = open(log_file, "r", encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        yield "data: No logs available\n\n"
        return
    with handle:
        for line in handle:
            payload = line.rstrip("\r\n")
            yield f"data: {payload}\n\n"


async def stream_logs():
    """Stream the training log as server-sent events.

    The current contents of ``/app/training.log`` are sent one line per
    event. The file is read through once; lines appended after the end has
    been reached are not followed.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``, which
        matches the ``data: ...`` framing of the body.
    """
    return StreamingResponse(
        iter_log_events(Path("/app/training.log")),
        media_type="text/event-stream",
    )
async def stop_training():
    """Mark the current training run as stopped.

    Only the shared status flags are changed here: ``is_running`` is cleared
    and ``status`` becomes ``"stopped"``.

    Returns:
        ``{"message": "Training stopped"}``.

    Raises:
        HTTPException: 400 when no run is marked as running.
    """
    currently_running = training_status["is_running"]
    if not currently_running:
        raise HTTPException(status_code=400, detail="No training is running")

    training_status["is_running"] = False
    training_status["status"] = "stopped"
    return {"message": "Training stopped"}
async def health_check():
    """Health check: report liveness together with the current UTC time.

    Returns:
        ``{"status": "healthy", "timestamp": <UTC time>}`` where the
        timestamp is formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    # Imported locally so this block does not depend on a file-level import.
    from datetime import datetime, timezone

    # Report the actual time of the check instead of a fixed constant, so the
    # field shows that the process is answering now.
    now_utc = datetime.now(timezone.utc)
    return {
        "status": "healthy",
        "timestamp": now_utc.strftime("%Y-%m-%dT%H:%M:%SZ"),
    }
async def get_data_info():
    """Report whether the training and validation CSV files exist, with sizes.

    Returns:
        A dict with ``train_file_exists`` / ``validation_file_exists``
        (bool) and ``train_file_size`` / ``validation_file_size`` (bytes,
        0 when the file is missing) for ``/app/train.csv`` and
        ``/app/validation.csv``.
    """
    def probe(data_file):
        # A single stat() answers both "exists?" and "how large?", so the
        # file cannot vanish between an exists() check and a later stat().
        try:
            return True, data_file.stat().st_size
        except (FileNotFoundError, NotADirectoryError):
            return False, 0

    train_exists, train_size = probe(Path("/app/train.csv"))
    val_exists, val_size = probe(Path("/app/validation.csv"))

    return {
        "train_file_exists": train_exists,
        "validation_file_exists": val_exists,
        "train_file_size": train_size,
        "validation_file_size": val_size,
    }
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )