Spaces:
Sleeping
Sleeping
File size: 4,706 Bytes
e1ccef5 4ccea30 e1ccef5 4ccea30 e7d8f04 5431a0f e7d8f04 e1ccef5 5431a0f 4ccea30 5431a0f e1ccef5 4ccea30 5431a0f e1ccef5 4ccea30 5431a0f 4ccea30 5431a0f e1ccef5 4ccea30 4cd8f35 4ccea30 5431a0f 4ccea30 e1ccef5 4ccea30 4cd8f35 4ccea30 5431a0f e1ccef5 4ccea30 4cd8f35 4ccea30 e1ccef5 5431a0f e1ccef5 5431a0f 4ccea30 4cd8f35 4ccea30 4cd8f35 4ccea30 4cd8f35 4ccea30 e1ccef5 4cd8f35 e1ccef5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import subprocess
import json
import os
STATIONS = [
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
"DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
"DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
]
# --- ⭐️⭐️⭐️ 이 함수 안의 내용이 빠져 있었습니다 ⭐️⭐️⭐️ ---
def get_common_args(station_id):
"""관측소별 공통 인자를 반환하는 함수"""
# inference.py의 인자들을 그대로 가져옵니다.
return [
"--model", "TimeXer",
"--features", "MS",
"--seq_len", "144",
"--pred_len", "72",
"--label_len", "96",
"--enc_in", "5",
"--dec_in", "5",
"--c_out", "1",
"--d_model", "256",
"--d_ff", "512",
"--n_heads", "8",
"--e_layers", "1",
"--d_layers", "1",
"--factor", "3",
"--patch_len", "16",
"--expand", "2",
"--d_conv", "4",
]
def execute_inference(command):
"""subprocess를 실행하고 결과를 처리하는 공통 함수"""
try:
print(f"Running command: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True, check=True)
# inference.py의 마지막 줄에 있는 JSON 출력만 파싱합니다.
last_line = result.stdout.strip().split('\n')[-1]
parsed_json = json.loads(last_line)
# JSON을 예쁘게 포맷팅해서 반환합니다.
return json.dumps(parsed_json, indent=2)
except subprocess.CalledProcessError as e:
error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}\n--- STDOUT ---\n{e.stdout}"
raise gr.Error(error_message)
except Exception as e:
raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
# --- 모드 1: 단일 미래 예측 ---
def single_prediction(station_id, input_csv_file):
if input_csv_file is None:
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
common_args = get_common_args(station_id)
# ⭐️ 긴 체크포인트 경로를 변수로 정리
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
command = [
"python", "inference.py",
"--checkpoint_path", checkpoint_path,
"--scaler_path", scaler_path,
"--predict_input_file", input_csv_file.name,
] + common_args
return execute_inference(command)
# --- 모드 2: 전체 기간 롤링 평가 ---
def rolling_evaluation(station_id, eval_csv_file):
if eval_csv_file is None:
raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
common_args = get_common_args(station_id)
# ⭐️ 긴 체크포인트 경로를 변수로 정리
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
command = [
"python", "inference.py",
"--checkpoint_path", checkpoint_path,
"--scaler_path", scaler_path,
"--evaluate_file", eval_csv_file.name,
] + common_args
return execute_inference(command)
# --- Gradio UI 구성 ---
with gr.Blocks() as demo:
gr.Markdown("# 조위 예측 모델 v4.0")
with gr.Tabs():
with gr.TabItem("단일 미래 예측"):
with gr.Row():
station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
file_input1 = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
submit_btn1 = gr.Button("예측 실행")
output1 = gr.Textbox(label="실행 결과", lines=15)
with gr.TabItem("전체 기간 롤링 평가"):
with gr.Row():
station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
file_input2 = gr.File(label="평가용 전체 데이터 (.csv 파일)", file_types=[".csv"])
submit_btn2 = gr.Button("평가 실행")
output2 = gr.Textbox(label="실행 결과", lines=15)
submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input1], outputs=output1)
submit_btn2.click(fn=rolling_evaluation, inputs=[station_dropdown2, file_input2], outputs=output2)
if __name__ == "__main__":
demo.launch() |