my-tide-env / app.py
alwaysgood's picture
Update app.py
5431a0f verified
raw
history blame
4.71 kB
import gradio as gr
import subprocess
import json
import os
STATIONS = [
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
"DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
"DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
]
# --- ⭐️⭐️⭐️ 이 함수 안의 내용이 빠져 있었습니다 ⭐️⭐️⭐️ ---
def get_common_args(station_id):
"""관측소별 공통 인자를 반환하는 함수"""
# inference.py의 인자들을 그대로 가져옵니다.
return [
"--model", "TimeXer",
"--features", "MS",
"--seq_len", "144",
"--pred_len", "72",
"--label_len", "96",
"--enc_in", "5",
"--dec_in", "5",
"--c_out", "1",
"--d_model", "256",
"--d_ff", "512",
"--n_heads", "8",
"--e_layers", "1",
"--d_layers", "1",
"--factor", "3",
"--patch_len", "16",
"--expand", "2",
"--d_conv", "4",
]
def execute_inference(command):
"""subprocess를 실행하고 결과를 처리하는 공통 함수"""
try:
print(f"Running command: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True, check=True)
# inference.py의 마지막 줄에 있는 JSON 출력만 파싱합니다.
last_line = result.stdout.strip().split('\n')[-1]
parsed_json = json.loads(last_line)
# JSON을 예쁘게 포맷팅해서 반환합니다.
return json.dumps(parsed_json, indent=2)
except subprocess.CalledProcessError as e:
error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}\n--- STDOUT ---\n{e.stdout}"
raise gr.Error(error_message)
except Exception as e:
raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
# --- 모드 1: 단일 미래 예측 ---
def single_prediction(station_id, input_csv_file):
if input_csv_file is None:
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
common_args = get_common_args(station_id)
# ⭐️ 긴 체크포인트 경로를 변수로 정리
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
command = [
"python", "inference.py",
"--checkpoint_path", checkpoint_path,
"--scaler_path", scaler_path,
"--predict_input_file", input_csv_file.name,
] + common_args
return execute_inference(command)
# --- 모드 2: 전체 기간 롤링 평가 ---
def rolling_evaluation(station_id, eval_csv_file):
if eval_csv_file is None:
raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
common_args = get_common_args(station_id)
# ⭐️ 긴 체크포인트 경로를 변수로 정리
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
command = [
"python", "inference.py",
"--checkpoint_path", checkpoint_path,
"--scaler_path", scaler_path,
"--evaluate_file", eval_csv_file.name,
] + common_args
return execute_inference(command)
# --- Gradio UI 구성 ---
with gr.Blocks() as demo:
gr.Markdown("# 조위 예측 모델 v4.0")
with gr.Tabs():
with gr.TabItem("단일 미래 예측"):
with gr.Row():
station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
file_input1 = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
submit_btn1 = gr.Button("예측 실행")
output1 = gr.Textbox(label="실행 결과", lines=15)
with gr.TabItem("전체 기간 롤링 평가"):
with gr.Row():
station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
file_input2 = gr.File(label="평가용 전체 데이터 (.csv 파일)", file_types=[".csv"])
submit_btn2 = gr.Button("평가 실행")
output2 = gr.Textbox(label="실행 결과", lines=15)
submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input1], outputs=output1)
submit_btn2.click(fn=rolling_evaluation, inputs=[station_dropdown2, file_input2], outputs=output2)
if __name__ == "__main__":
demo.launch()