Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,122 +1,81 @@
|
|
| 1 |
-
# app.py (
|
| 2 |
import gradio as gr
|
| 3 |
import subprocess
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
-
import pandas as pd
|
| 7 |
|
| 8 |
-
# 공통 설정
|
| 9 |
STATIONS = [
|
| 10 |
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
|
| 11 |
"DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
|
| 12 |
"DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
|
| 13 |
]
|
| 14 |
-
# 지원하는 관측소 목록
|
| 15 |
|
| 16 |
def get_common_args(station_id):
|
| 17 |
-
|
| 18 |
-
# 이 부분은 inference.py의 기본값과 일치하거나, .sh 파일의 설정을 따릅니다.
|
| 19 |
-
return [
|
| 20 |
-
"--model", "TimeXer",
|
| 21 |
-
"--features", "MS",
|
| 22 |
-
"--seq_len", "144",
|
| 23 |
-
"--pred_len", "72",
|
| 24 |
-
"--label_len", "96",
|
| 25 |
-
"--enc_in", "5",
|
| 26 |
-
"--dec_in", "5",
|
| 27 |
-
"--c_out", "1",
|
| 28 |
-
"--d_model", "256",
|
| 29 |
-
"--d_ff", "512",
|
| 30 |
-
"--n_heads", "8",
|
| 31 |
-
"--e_layers", "1",
|
| 32 |
-
"--d_layers", "1",
|
| 33 |
-
"--factor", "3",
|
| 34 |
-
"--patch_len", "16",
|
| 35 |
-
"--expand", "2",
|
| 36 |
-
"--d_conv", "4",
|
| 37 |
-
]
|
| 38 |
|
| 39 |
def execute_inference(command):
|
| 40 |
-
|
| 41 |
-
try:
|
| 42 |
-
print(f"Running command: {' '.join(command)}")
|
| 43 |
-
# 스크립트 실행
|
| 44 |
-
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
| 45 |
-
# 스크립트의 표준 출력(stdout)을 그대로 반환 (디버깅용)
|
| 46 |
-
# stderr에 담긴 진행 과정 로그도 함께 보여주면 좋습니다.
|
| 47 |
-
output = f"--- STDOUT ---\n{result.stdout}\n\n--- STDERR ---\n{result.stderr}"
|
| 48 |
-
return output
|
| 49 |
-
except subprocess.CalledProcessError as e:
|
| 50 |
-
# 스크립트 실행 중 오류 발생 시
|
| 51 |
-
error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}"
|
| 52 |
-
raise gr.Error(error_message)
|
| 53 |
-
except Exception as e:
|
| 54 |
-
raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
|
| 55 |
|
| 56 |
# --- 모드 1: 단일 미래 예측 ---
|
| 57 |
def single_prediction(station_id, input_csv_file):
|
| 58 |
if input_csv_file is None:
|
| 59 |
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
|
| 60 |
|
| 61 |
-
# 공통 인자 가져오기
|
| 62 |
common_args = get_common_args(station_id)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth"
|
| 66 |
-
scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz"
|
| 67 |
|
| 68 |
-
# 실행할 명령어 조합
|
| 69 |
command = [
|
| 70 |
"python", "inference.py",
|
| 71 |
"--checkpoint_path", checkpoint_path,
|
| 72 |
"--scaler_path", scaler_path,
|
| 73 |
-
"--predict_input_file", input_csv_file.name,
|
| 74 |
] + common_args
|
| 75 |
|
| 76 |
return execute_inference(command)
|
| 77 |
|
| 78 |
# --- 모드 2: 전체 기간 롤링 평가 ---
|
| 79 |
-
def rolling_evaluation(station_id):
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
common_args = get_common_args(station_id)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
|
| 85 |
-
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
|
| 86 |
-
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
|
| 87 |
-
evaluate_file_path = f"./dataset/{station_id}_final_training_data.csv"
|
| 88 |
|
| 89 |
-
# 실행할 명령어 조합
|
| 90 |
command = [
|
| 91 |
"python", "inference.py",
|
| 92 |
"--checkpoint_path", checkpoint_path,
|
| 93 |
"--scaler_path", scaler_path,
|
| 94 |
-
"--evaluate_file",
|
| 95 |
] + common_args
|
| 96 |
|
| 97 |
return execute_inference(command)
|
| 98 |
|
| 99 |
|
| 100 |
-
# --- Gradio UI 구성 ---
|
| 101 |
with gr.Blocks() as demo:
|
| 102 |
-
gr.Markdown("# 조위 예측 모델
|
| 103 |
|
| 104 |
with gr.Tabs():
|
| 105 |
with gr.TabItem("단일 미래 예측"):
|
| 106 |
with gr.Row():
|
| 107 |
station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
|
| 108 |
-
|
| 109 |
submit_btn1 = gr.Button("예측 실행")
|
| 110 |
output1 = gr.Textbox(label="실행 결과", lines=15)
|
| 111 |
|
| 112 |
with gr.TabItem("전체 기간 롤링 평가"):
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
| 114 |
submit_btn2 = gr.Button("평가 실행")
|
| 115 |
output2 = gr.Textbox(label="실행 결과", lines=15)
|
| 116 |
|
| 117 |
-
# 버튼 클릭 이벤트 연결
|
| 118 |
-
submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1,
|
| 119 |
-
submit_btn2.click(fn=rolling_evaluation, inputs=station_dropdown2, outputs=output2)
|
| 120 |
|
| 121 |
if __name__ == "__main__":
|
| 122 |
demo.launch()
|
|
|
|
| 1 |
+
# app.py (v4.0 - 최종 버전)
|
| 2 |
import gradio as gr
|
| 3 |
import subprocess
|
| 4 |
import json
|
| 5 |
import os
|
|
|
|
| 6 |
|
|
|
|
| 7 |
STATIONS = [
|
| 8 |
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
|
| 9 |
"DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
|
| 10 |
"DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
|
| 11 |
]
|
|
|
|
| 12 |
|
| 13 |
def get_common_args(station_id):
|
| 14 |
+
# ... (이전과 동일) ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def execute_inference(command):
|
| 17 |
+
# ... (이전과 동일) ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# --- 모드 1: 단일 미래 예측 ---
|
| 20 |
def single_prediction(station_id, input_csv_file):
|
| 21 |
if input_csv_file is None:
|
| 22 |
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
|
| 23 |
|
|
|
|
| 24 |
common_args = get_common_args(station_id)
|
| 25 |
+
checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../checkpoint.pth" # 이름이 길어서 생략
|
| 26 |
+
scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../scaler.gz"
|
|
|
|
|
|
|
| 27 |
|
|
|
|
| 28 |
command = [
|
| 29 |
"python", "inference.py",
|
| 30 |
"--checkpoint_path", checkpoint_path,
|
| 31 |
"--scaler_path", scaler_path,
|
| 32 |
+
"--predict_input_file", input_csv_file.name,
|
| 33 |
] + common_args
|
| 34 |
|
| 35 |
return execute_inference(command)
|
| 36 |
|
| 37 |
# --- 모드 2: 전체 기간 롤링 평가 ---
|
| 38 |
+
def rolling_evaluation(station_id, eval_csv_file):
|
| 39 |
+
if eval_csv_file is None:
|
| 40 |
+
raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
|
| 41 |
+
|
| 42 |
common_args = get_common_args(station_id)
|
| 43 |
+
checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../checkpoint.pth"
|
| 44 |
+
scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../scaler.gz"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
|
|
|
| 46 |
command = [
|
| 47 |
"python", "inference.py",
|
| 48 |
"--checkpoint_path", checkpoint_path,
|
| 49 |
"--scaler_path", scaler_path,
|
| 50 |
+
"--evaluate_file", eval_csv_file.name,
|
| 51 |
] + common_args
|
| 52 |
|
| 53 |
return execute_inference(command)
|
| 54 |
|
| 55 |
|
| 56 |
+
# --- Gradio UI 구성 (⭐️⭐️⭐️ 이 부분이 변경되었습니다 ⭐️⭐️⭐️) ---
|
| 57 |
with gr.Blocks() as demo:
|
| 58 |
+
gr.Markdown("# 조위 예측 모델 v4.0")
|
| 59 |
|
| 60 |
with gr.Tabs():
|
| 61 |
with gr.TabItem("단일 미래 예측"):
|
| 62 |
with gr.Row():
|
| 63 |
station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
|
| 64 |
+
file_input1 = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
|
| 65 |
submit_btn1 = gr.Button("예측 실행")
|
| 66 |
output1 = gr.Textbox(label="실행 결과", lines=15)
|
| 67 |
|
| 68 |
with gr.TabItem("전체 기간 롤링 평가"):
|
| 69 |
+
with gr.Row():
|
| 70 |
+
station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
|
| 71 |
+
# ⭐️ '롤링 평가' 탭에도 파일 업로드 기능 추가
|
| 72 |
+
file_input2 = gr.File(label="평가용 전체 데이터 (.csv 파일)", file_types=[".csv"])
|
| 73 |
submit_btn2 = gr.Button("평가 실행")
|
| 74 |
output2 = gr.Textbox(label="실행 결과", lines=15)
|
| 75 |
|
| 76 |
+
# 버튼 클릭 이벤트 연결 (file_input2 추가)
|
| 77 |
+
submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input1], outputs=output1)
|
| 78 |
+
submit_btn2.click(fn=rolling_evaluation, inputs=[station_dropdown2, file_input2], outputs=output2)
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
| 81 |
demo.launch()
|