Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import json | |
| import os | |
| STATIONS = [ | |
| "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002", | |
| "DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037", | |
| "DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066" | |
| ] | |
| # --- ⭐️⭐️⭐️ 이 함수 안의 내용이 빠져 있었습니다 ⭐️⭐️⭐️ --- | |
| def get_common_args(station_id): | |
| """관측소별 공통 인자를 반환하는 함수""" | |
| # inference.py의 인자들을 그대로 가져옵니다. | |
| return [ | |
| "--model", "TimeXer", | |
| "--features", "MS", | |
| "--seq_len", "144", | |
| "--pred_len", "72", | |
| "--label_len", "96", | |
| "--enc_in", "5", | |
| "--dec_in", "5", | |
| "--c_out", "1", | |
| "--d_model", "256", | |
| "--d_ff", "512", | |
| "--n_heads", "8", | |
| "--e_layers", "1", | |
| "--d_layers", "1", | |
| "--factor", "3", | |
| "--patch_len", "16", | |
| "--expand", "2", | |
| "--d_conv", "4", | |
| ] | |
| def execute_inference(command): | |
| """subprocess를 실행하고 결과를 처리하는 공통 함수""" | |
| try: | |
| print(f"Running command: {' '.join(command)}") | |
| result = subprocess.run(command, capture_output=True, text=True, check=True) | |
| # inference.py의 마지막 줄에 있는 JSON 출력만 파싱합니다. | |
| last_line = result.stdout.strip().split('\n')[-1] | |
| parsed_json = json.loads(last_line) | |
| # JSON을 예쁘게 포맷팅해서 반환합니다. | |
| return json.dumps(parsed_json, indent=2) | |
| except subprocess.CalledProcessError as e: | |
| error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}\n--- STDOUT ---\n{e.stdout}" | |
| raise gr.Error(error_message) | |
| except Exception as e: | |
| raise gr.Error(f"알 수 없는 오류 발생: {str(e)}") | |
| # --- 모드 1: 단일 미래 예측 --- | |
| def single_prediction(station_id, input_csv_file): | |
| if input_csv_file is None: | |
| raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.") | |
| common_args = get_common_args(station_id) | |
| # ⭐️ 긴 체크포인트 경로를 변수로 정리 | |
| setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0" | |
| checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth" | |
| scaler_path = f"./checkpoints/{setting_name}/scaler.gz" | |
| command = [ | |
| "python", "inference.py", | |
| "--checkpoint_path", checkpoint_path, | |
| "--scaler_path", scaler_path, | |
| "--predict_input_file", input_csv_file.name, | |
| ] + common_args | |
| return execute_inference(command) | |
| # --- 모드 2: 전체 기간 롤링 평가 --- | |
| def rolling_evaluation(station_id, eval_csv_file): | |
| if eval_csv_file is None: | |
| raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.") | |
| common_args = get_common_args(station_id) | |
| # ⭐️ 긴 체크포인트 경로를 변수로 정리 | |
| setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0" | |
| checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth" | |
| scaler_path = f"./checkpoints/{setting_name}/scaler.gz" | |
| command = [ | |
| "python", "inference.py", | |
| "--checkpoint_path", checkpoint_path, | |
| "--scaler_path", scaler_path, | |
| "--evaluate_file", eval_csv_file.name, | |
| ] + common_args | |
| return execute_inference(command) | |
| # --- Gradio UI 구성 --- | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# 조위 예측 모델 v4.0") | |
| with gr.Tabs(): | |
| with gr.TabItem("단일 미래 예측"): | |
| with gr.Row(): | |
| station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)") | |
| file_input1 = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"]) | |
| submit_btn1 = gr.Button("예측 실행") | |
| output1 = gr.Textbox(label="실행 결과", lines=15) | |
| with gr.TabItem("전체 기간 롤링 평가"): | |
| with gr.Row(): | |
| station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)") | |
| file_input2 = gr.File(label="평가용 전체 데이터 (.csv 파일)", file_types=[".csv"]) | |
| submit_btn2 = gr.Button("평가 실행") | |
| output2 = gr.Textbox(label="실행 결과", lines=15) | |
| submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input1], outputs=output1) | |
| submit_btn2.click(fn=rolling_evaluation, inputs=[station_dropdown2, file_input2], outputs=output2) | |
| if __name__ == "__main__": | |
| demo.launch() |