alwaysgood commited on
Commit
4cd8f35
·
verified ·
1 Parent(s): b444514

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -64
app.py CHANGED
@@ -1,122 +1,81 @@
1
- # app.py (v3.0 - 최종 버전)
2
  import gradio as gr
3
  import subprocess
4
  import json
5
  import os
6
- import pandas as pd
7
 
8
- # 공통 설정
9
  STATIONS = [
10
  "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
11
  "DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
12
  "DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
13
  ]
14
- # 지원하는 관측소 목록
15
 
16
  def get_common_args(station_id):
17
- """관측소별 공통 인자를 반환하는 함수"""
18
- # 이 부분은 inference.py의 기본값과 일치하거나, .sh 파일의 설정을 따릅니다.
19
- return [
20
- "--model", "TimeXer",
21
- "--features", "MS",
22
- "--seq_len", "144",
23
- "--pred_len", "72",
24
- "--label_len", "96",
25
- "--enc_in", "5",
26
- "--dec_in", "5",
27
- "--c_out", "1",
28
- "--d_model", "256",
29
- "--d_ff", "512",
30
- "--n_heads", "8",
31
- "--e_layers", "1",
32
- "--d_layers", "1",
33
- "--factor", "3",
34
- "--patch_len", "16",
35
- "--expand", "2",
36
- "--d_conv", "4",
37
- ]
38
 
39
  def execute_inference(command):
40
- """subprocess를 실행하고 결과를 처리하는 공통 함수"""
41
- try:
42
- print(f"Running command: {' '.join(command)}")
43
- # 스크립트 실행
44
- result = subprocess.run(command, capture_output=True, text=True, check=True)
45
- # 스크립트의 표준 출력(stdout)을 그대로 반환 (디버깅용)
46
- # stderr에 담긴 진행 과정 로그도 함께 보여주면 좋습니다.
47
- output = f"--- STDOUT ---\n{result.stdout}\n\n--- STDERR ---\n{result.stderr}"
48
- return output
49
- except subprocess.CalledProcessError as e:
50
- # 스크립트 실행 중 오류 발생 시
51
- error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}"
52
- raise gr.Error(error_message)
53
- except Exception as e:
54
- raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
55
 
56
  # --- 모드 1: 단일 미래 예측 ---
57
  def single_prediction(station_id, input_csv_file):
58
  if input_csv_file is None:
59
  raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
60
 
61
- # 공통 인자 가져오기
62
  common_args = get_common_args(station_id)
63
-
64
- # 모델 및 스케일러 경로 설정
65
- checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth"
66
- scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz"
67
 
68
- # 실행할 명령어 조합
69
  command = [
70
  "python", "inference.py",
71
  "--checkpoint_path", checkpoint_path,
72
  "--scaler_path", scaler_path,
73
- "--predict_input_file", input_csv_file.name, # 업로드된 파일의 경로
74
  ] + common_args
75
 
76
  return execute_inference(command)
77
 
78
  # --- 모드 2: 전체 기간 롤링 평가 ---
79
- def rolling_evaluation(station_id):
80
- # 공통 인자 가져오기
 
 
81
  common_args = get_common_args(station_id)
82
-
83
- # 모델, 스케일러, 평가 데이터 경로 설정
84
- setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
85
- checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
86
- scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
87
- evaluate_file_path = f"./dataset/{station_id}_final_training_data.csv"
88
 
89
- # 실행할 명령어 조합
90
  command = [
91
  "python", "inference.py",
92
  "--checkpoint_path", checkpoint_path,
93
  "--scaler_path", scaler_path,
94
- "--evaluate_file", evaluate_file_path,
95
  ] + common_args
96
 
97
  return execute_inference(command)
98
 
99
 
100
- # --- Gradio UI 구성 ---
101
  with gr.Blocks() as demo:
102
- gr.Markdown("# 조위 예측 모델 v3.0")
103
 
104
  with gr.Tabs():
105
  with gr.TabItem("단일 미래 예측"):
106
  with gr.Row():
107
  station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
108
- file_input = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
109
  submit_btn1 = gr.Button("예측 실행")
110
  output1 = gr.Textbox(label="실행 결과", lines=15)
111
 
112
  with gr.TabItem("전체 기간 롤링 평가"):
113
- station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
 
 
 
114
  submit_btn2 = gr.Button("평가 실행")
115
  output2 = gr.Textbox(label="실행 결과", lines=15)
116
 
117
- # 버튼 클릭 이벤트 연결
118
- submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input], outputs=output1)
119
- submit_btn2.click(fn=rolling_evaluation, inputs=station_dropdown2, outputs=output2)
120
 
121
  if __name__ == "__main__":
122
  demo.launch()
 
1
+ # app.py (v4.0 - 최종 버전)
2
  import gradio as gr
3
  import subprocess
4
  import json
5
  import os
 
6
 
 
7
  STATIONS = [
8
  "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
9
  "DT_0050", "DT_0017", "DT_0052", "DT_0025", "DT_0051", "DT_0037",
10
  "DT_0024", "DT_0018", "DT_0068", "DT_0003", "DT_0066"
11
  ]
 
12
 
13
  def get_common_args(station_id):
14
+ # ... (이전과 동일) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def execute_inference(command):
17
+ # ... (이전과 동일) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # --- 모드 1: 단일 미래 예측 ---
20
  def single_prediction(station_id, input_csv_file):
21
  if input_csv_file is None:
22
  raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
23
 
 
24
  common_args = get_common_args(station_id)
25
+ checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../checkpoint.pth" # 이름이 길어서 생략
26
+ scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../scaler.gz"
 
 
27
 
 
28
  command = [
29
  "python", "inference.py",
30
  "--checkpoint_path", checkpoint_path,
31
  "--scaler_path", scaler_path,
32
+ "--predict_input_file", input_csv_file.name,
33
  ] + common_args
34
 
35
  return execute_inference(command)
36
 
37
  # --- 모드 2: 전체 기간 롤링 평가 ---
38
+ def rolling_evaluation(station_id, eval_csv_file):
39
+ if eval_csv_file is None:
40
+ raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
41
+
42
  common_args = get_common_args(station_id)
43
+ checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../checkpoint.pth"
44
+ scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_.../scaler.gz"
 
 
 
 
45
 
 
46
  command = [
47
  "python", "inference.py",
48
  "--checkpoint_path", checkpoint_path,
49
  "--scaler_path", scaler_path,
50
+ "--evaluate_file", eval_csv_file.name,
51
  ] + common_args
52
 
53
  return execute_inference(command)
54
 
55
 
56
+ # --- Gradio UI 구성 (⭐️⭐️⭐️ 이 부분이 변경되었습니다 ⭐️⭐️⭐️) ---
57
  with gr.Blocks() as demo:
58
+ gr.Markdown("# 조위 예측 모델 v4.0")
59
 
60
  with gr.Tabs():
61
  with gr.TabItem("단일 미래 예측"):
62
  with gr.Row():
63
  station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
64
+ file_input1 = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
65
  submit_btn1 = gr.Button("예측 실행")
66
  output1 = gr.Textbox(label="실행 결과", lines=15)
67
 
68
  with gr.TabItem("전체 기간 롤링 평가"):
69
+ with gr.Row():
70
+ station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
71
+ # ⭐️ '롤링 평가' 탭에도 파일 업로드 기능 추가
72
+ file_input2 = gr.File(label="평가용 전체 데이터 (.csv 파일)", file_types=[".csv"])
73
  submit_btn2 = gr.Button("평가 실행")
74
  output2 = gr.Textbox(label="실행 결과", lines=15)
75
 
76
+ # 버튼 클릭 이벤트 연결 (file_input2 추가)
77
+ submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input1], outputs=output1)
78
+ submit_btn2.click(fn=rolling_evaluation, inputs=[station_dropdown2, file_input2], outputs=output2)
79
 
80
  if __name__ == "__main__":
81
  demo.launch()