Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,85 +1,117 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import joblib
|
| 6 |
-
import os
|
| 7 |
import json
|
| 8 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
args.d_layers = 1
|
| 35 |
-
args.factor = 3
|
| 36 |
-
args.enc_in = 5
|
| 37 |
-
args.dec_in = 5
|
| 38 |
-
args.c_out = 1
|
| 39 |
-
args.d_model = 256
|
| 40 |
-
args.n_heads = 8
|
| 41 |
-
args.d_ff = 512
|
| 42 |
-
args.output_attention = True
|
| 43 |
-
args.device = torch.device('cpu')
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
# --- 예측을 수행하는 함수 ---
|
| 60 |
-
def predict_tide(input_csv_string):
|
| 61 |
-
# (이 부분은 수정할 필요 없습니다)
|
| 62 |
-
# ... 이전 코드와 동일 ...
|
| 63 |
-
if not MODEL_LOADED:
|
| 64 |
-
raise gr.Error(f"모델 로딩 실패: {MODEL_ERROR}")
|
| 65 |
-
# ...
|
| 66 |
-
# ...
|
| 67 |
-
return json.dumps({"prediction": prediction.flatten().tolist()}, indent=2)
|
| 68 |
|
| 69 |
-
# --- Gradio
|
| 70 |
-
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
fn=
|
| 78 |
-
|
| 79 |
-
outputs="json",
|
| 80 |
-
title="조위 예측 모델 (TimeXer)",
|
| 81 |
-
description=desc_text
|
| 82 |
-
)
|
| 83 |
|
| 84 |
if __name__ == "__main__":
|
| 85 |
demo.launch()
|
|
|
|
| 1 |
+
# app.py (v3.0 - 최종 버전)
|
| 2 |
import gradio as gr
|
| 3 |
+
import subprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import json
|
| 5 |
+
import os
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
# 공통 설정
|
| 9 |
+
STATIONS = ["DT_0001", "DT_0052"] # 지원하는 관측소 목록
|
| 10 |
|
| 11 |
+
def get_common_args(station_id):
|
| 12 |
+
"""관측소별 공통 인자를 반환하는 함수"""
|
| 13 |
+
# 이 부분은 inference.py의 기본값과 일치하거나, .sh 파일의 설정을 따릅니다.
|
| 14 |
+
return [
|
| 15 |
+
"--model", "TimeXer",
|
| 16 |
+
"--features", "MS",
|
| 17 |
+
"--seq_len", "144",
|
| 18 |
+
"--pred_len", "72",
|
| 19 |
+
"--label_len", "96",
|
| 20 |
+
"--enc_in", "5",
|
| 21 |
+
"--dec_in", "5",
|
| 22 |
+
"--c_out", "1",
|
| 23 |
+
"--d_model", "256",
|
| 24 |
+
"--d_ff", "512",
|
| 25 |
+
"--n_heads", "8",
|
| 26 |
+
"--e_layers", "1",
|
| 27 |
+
"--d_layers", "1",
|
| 28 |
+
"--factor", "3",
|
| 29 |
+
"--patch_len", "16",
|
| 30 |
+
"--expand", "2",
|
| 31 |
+
"--d_conv", "4",
|
| 32 |
+
]
|
| 33 |
|
| 34 |
+
def execute_inference(command):
|
| 35 |
+
"""subprocess를 실행하고 결과를 처리하는 공통 함수"""
|
| 36 |
+
try:
|
| 37 |
+
print(f"Running command: {' '.join(command)}")
|
| 38 |
+
# 스크립트 실행
|
| 39 |
+
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
| 40 |
+
# 스크립트의 표준 출력(stdout)을 그대로 반환 (디버깅용)
|
| 41 |
+
# stderr에 담긴 진행 과정 로그도 함께 보여주면 좋습니다.
|
| 42 |
+
output = f"--- STDOUT ---\n{result.stdout}\n\n--- STDERR ---\n{result.stderr}"
|
| 43 |
+
return output
|
| 44 |
+
except subprocess.CalledProcessError as e:
|
| 45 |
+
# 스크립트 실행 중 오류 발생 시
|
| 46 |
+
error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}"
|
| 47 |
+
raise gr.Error(error_message)
|
| 48 |
+
except Exception as e:
|
| 49 |
+
raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
|
| 50 |
|
| 51 |
+
# --- 모드 1: 단일 미래 예측 ---
|
| 52 |
+
def single_prediction(station_id, input_csv_file):
|
| 53 |
+
if input_csv_file is None:
|
| 54 |
+
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
|
| 55 |
+
|
| 56 |
+
# 공통 인자 가져오기
|
| 57 |
+
common_args = get_common_args(station_id)
|
| 58 |
+
|
| 59 |
+
# 모델 및 스케일러 경로 설정
|
| 60 |
+
checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth"
|
| 61 |
+
scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
# 실행할 명령어 조합
|
| 64 |
+
command = [
|
| 65 |
+
"python", "inference.py",
|
| 66 |
+
"--checkpoint_path", checkpoint_path,
|
| 67 |
+
"--scaler_path", scaler_path,
|
| 68 |
+
"--predict_input_file", input_csv_file.name, # 업로드된 파일의 경로
|
| 69 |
+
] + common_args
|
| 70 |
+
|
| 71 |
+
return execute_inference(command)
|
| 72 |
|
| 73 |
+
# --- 모드 2: 전체 기간 롤링 평가 ---
|
| 74 |
+
def rolling_evaluation(station_id):
|
| 75 |
+
# 공통 인자 가져오기
|
| 76 |
+
common_args = get_common_args(station_id)
|
| 77 |
+
|
| 78 |
+
# 모델, 스케일러, 평가 데이터 경로 설정
|
| 79 |
+
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
|
| 80 |
+
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
|
| 81 |
+
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
|
| 82 |
+
evaluate_file_path = f"./dataset/{station_id}_final_training_data.csv"
|
| 83 |
|
| 84 |
+
# 실행할 명령어 조합
|
| 85 |
+
command = [
|
| 86 |
+
"python", "inference.py",
|
| 87 |
+
"--checkpoint_path", checkpoint_path,
|
| 88 |
+
"--scaler_path", scaler_path,
|
| 89 |
+
"--evaluate_file", evaluate_file_path,
|
| 90 |
+
] + common_args
|
| 91 |
+
|
| 92 |
+
return execute_inference(command)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
# --- Gradio UI 구성 ---
|
| 96 |
+
with gr.Blocks() as demo:
|
| 97 |
+
gr.Markdown("# 조위 예측 모델 v3.0")
|
| 98 |
+
|
| 99 |
+
with gr.Tabs():
|
| 100 |
+
with gr.TabItem("단일 미래 예측"):
|
| 101 |
+
with gr.Row():
|
| 102 |
+
station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
|
| 103 |
+
file_input = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
|
| 104 |
+
submit_btn1 = gr.Button("예측 실행")
|
| 105 |
+
output1 = gr.Textbox(label="실행 결과", lines=15)
|
| 106 |
+
|
| 107 |
+
with gr.TabItem("전체 기간 롤링 평가"):
|
| 108 |
+
station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
|
| 109 |
+
submit_btn2 = gr.Button("평가 실행")
|
| 110 |
+
output2 = gr.Textbox(label="실행 결과", lines=15)
|
| 111 |
|
| 112 |
+
# 버튼 클릭 이벤트 연결
|
| 113 |
+
submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input], outputs=output1)
|
| 114 |
+
submit_btn2.click(fn=rolling_evaluation, inputs=station_dropdown2, outputs=output2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
demo.launch()
|