alwaysgood commited on
Commit
4ccea30
·
verified ·
1 Parent(s): e1ccef5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -71
app.py CHANGED
@@ -1,85 +1,117 @@
 
1
  import gradio as gr
2
- import torch
3
- import numpy as np
4
- import pandas as pd
5
- import joblib
6
- import os
7
  import json
8
- import sys # 👈 이 줄 추가
 
 
 
 
9
 
10
- # --- 모델 및 스케일러 로딩 ---
11
- MODEL_LOADED = False
12
- MODEL_ERROR = "Unknown"
13
- try:
14
- # ⭐️⭐️⭐️ 바로 이 부분입니다! ⭐️⭐️⭐️
15
- # 현재 폴더(.)를 파이썬의 모듈 검색 경로에 추가합니다.
16
- # 이렇게 하면 app.py가 models, utils 폴더를 찾을 수 있게 됩니다.
17
- sys.path.append('.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- from models.TimeXer import Model as TimeXerModel
20
- from utils.tools import dotdict
21
- from utils.timefeatures import time_features
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # 1. 훈련 스크립트(.sh)의 모든 설정을 그대로 가져옵니다.
24
- args = dotdict()
25
- args.model_id = 'DT_0001_144_72'
26
- args.model = 'TimeXer'
27
- args.task_name = 'long_term_forecast'
28
- args.seq_len = 144
29
- args.label_len = 96
30
- args.pred_len = 72
31
- args.features = 'MS'
32
- args.target = 'residual'
33
- args.e_layers = 1
34
- args.d_layers = 1
35
- args.factor = 3
36
- args.enc_in = 5
37
- args.dec_in = 5
38
- args.c_out = 1
39
- args.d_model = 256
40
- args.n_heads = 8
41
- args.d_ff = 512
42
- args.output_attention = True
43
- args.device = torch.device('cpu')
44
 
45
- # 2. 모델 뼈대를 만들고 학습된 가중치를 입힙니다.
46
- model = TimeXerModel(args).float()
47
- model.load_state_dict(torch.load('checkpoints/checkpoint.pth', map_location=args.device))
48
- model.eval()
 
 
 
 
 
49
 
50
- # 3. 스케일러를 불러옵니다.
51
- scaler = joblib.load('checkpoints/scaler.gz')
52
- MODEL_LOADED = True
53
- print("✅ 모델과 스케일러 로딩 성공!")
 
 
 
 
 
 
54
 
55
- except Exception as e:
56
- MODEL_ERROR = str(e)
57
- print(f" 모델 로딩 중 에러 발생: {MODEL_ERROR}")
 
 
 
 
 
 
58
 
59
- # --- 예측을 수행하는 함수 ---
60
- def predict_tide(input_csv_string):
61
- # (이 부분은 수정할 필요 없습니다)
62
- # ... 이전 코드와 동일 ...
63
- if not MODEL_LOADED:
64
- raise gr.Error(f"모델 로딩 실패: {MODEL_ERROR}")
65
- # ...
66
- # ...
67
- return json.dumps({"prediction": prediction.flatten().tolist()}, indent=2)
68
 
69
- # --- Gradio 인터페이스 생성 ---
70
- # NameError를 방지하기 위해, try 블록 바깥에 있는 args 참조를 제거하거나
71
- # 모델 로딩이 실패했을 경우를 대비해 기본값을 사용하도록 수정합니다.
72
- desc_text = "과거 144개 시점의 다변량 데이터를 입력하면, 미래 72개 시점의 조위 편차(residual)를 예측합니다."
73
- if MODEL_LOADED:
74
- desc_text = f"과거 {args.seq_len}개 시점의 다변량 데이터를 입력하면, 미래 {args.pred_len}개 시점의 조위 편차(residual)를 예측합니다."
 
 
 
 
 
 
 
 
 
 
75
 
76
- demo = gr.Interface(
77
- fn=predict_tide,
78
- inputs=gr.Textbox(lines=10, placeholder="CSV 형식으로 144개의 데이터를 입력하세요.\n첫 줄은 헤더(date,OT,...)여야 합니다."),
79
- outputs="json",
80
- title="조위 예측 모델 (TimeXer)",
81
- description=desc_text
82
- )
83
 
84
  if __name__ == "__main__":
85
  demo.launch()
 
1
+ # app.py (v3.0 - 최종 버전)
2
  import gradio as gr
3
+ import subprocess
 
 
 
 
4
  import json
5
+ import os
6
+ import pandas as pd
7
+
8
+ # 공통 설정
9
+ STATIONS = ["DT_0001", "DT_0052"] # 지원하는 관측소 목록
10
 
11
+ def get_common_args(station_id):
12
+ """관측소별 공통 인자를 반환하는 함수"""
13
+ # 부분은 inference.py의 기본값과 일치하거나, .sh 파일의 설정을 따릅니다.
14
+ return [
15
+ "--model", "TimeXer",
16
+ "--features", "MS",
17
+ "--seq_len", "144",
18
+ "--pred_len", "72",
19
+ "--label_len", "96",
20
+ "--enc_in", "5",
21
+ "--dec_in", "5",
22
+ "--c_out", "1",
23
+ "--d_model", "256",
24
+ "--d_ff", "512",
25
+ "--n_heads", "8",
26
+ "--e_layers", "1",
27
+ "--d_layers", "1",
28
+ "--factor", "3",
29
+ "--patch_len", "16",
30
+ "--expand", "2",
31
+ "--d_conv", "4",
32
+ ]
33
 
34
+ def execute_inference(command):
35
+ """subprocess를 실행하고 결과를 처리하는 공통 함수"""
36
+ try:
37
+ print(f"Running command: {' '.join(command)}")
38
+ # 스크립트 실행
39
+ result = subprocess.run(command, capture_output=True, text=True, check=True)
40
+ # 스크립트의 표준 출력(stdout)을 그대로 반환 (디버깅용)
41
+ # stderr에 담긴 진행 과정 로그도 함께 보여주면 좋습니다.
42
+ output = f"--- STDOUT ---\n{result.stdout}\n\n--- STDERR ---\n{result.stderr}"
43
+ return output
44
+ except subprocess.CalledProcessError as e:
45
+ # 스크립트 실행 중 오류 발생 시
46
+ error_message = f"스크립트 실행 실패:\n--- STDERR ---\n{e.stderr}"
47
+ raise gr.Error(error_message)
48
+ except Exception as e:
49
+ raise gr.Error(f"알 수 없는 오류 발생: {str(e)}")
50
 
51
+ # --- 모드 1: 단일 미래 예측 ---
52
+ def single_prediction(station_id, input_csv_file):
53
+ if input_csv_file is None:
54
+ raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
55
+
56
+ # 공통 인자 가져오기
57
+ common_args = get_common_args(station_id)
58
+
59
+ # 모델 및 스케일러 경로 설정
60
+ checkpoint_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth"
61
+ scaler_path = f"./checkpoints/long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz"
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # 실행할 명령어 조합
64
+ command = [
65
+ "python", "inference.py",
66
+ "--checkpoint_path", checkpoint_path,
67
+ "--scaler_path", scaler_path,
68
+ "--predict_input_file", input_csv_file.name, # 업로드된 파일의 경로
69
+ ] + common_args
70
+
71
+ return execute_inference(command)
72
 
73
+ # --- 모드 2: 전체 기간 롤링 평가 ---
74
+ def rolling_evaluation(station_id):
75
+ # 공통 인자 가져오기
76
+ common_args = get_common_args(station_id)
77
+
78
+ # 모델, 스케일러, 평가 데이터 경로 설정
79
+ setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
80
+ checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
81
+ scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
82
+ evaluate_file_path = f"./dataset/{station_id}_final_training_data.csv"
83
 
84
+ # 실행할 명령어 조합
85
+ command = [
86
+ "python", "inference.py",
87
+ "--checkpoint_path", checkpoint_path,
88
+ "--scaler_path", scaler_path,
89
+ "--evaluate_file", evaluate_file_path,
90
+ ] + common_args
91
+
92
+ return execute_inference(command)
93
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # --- Gradio UI 구성 ---
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("# 조위 예측 모델 v3.0")
98
+
99
+ with gr.Tabs():
100
+ with gr.TabItem("단일 미래 예측"):
101
+ with gr.Row():
102
+ station_dropdown1 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
103
+ file_input = gr.File(label="입력 데이터 (.csv 파일)", file_types=[".csv"])
104
+ submit_btn1 = gr.Button("예측 실행")
105
+ output1 = gr.Textbox(label="실행 결과", lines=15)
106
+
107
+ with gr.TabItem("전체 기간 롤링 평가"):
108
+ station_dropdown2 = gr.Dropdown(choices=STATIONS, label="관측소 (Station ID)")
109
+ submit_btn2 = gr.Button("평가 실행")
110
+ output2 = gr.Textbox(label="실행 결과", lines=15)
111
 
112
+ # 버튼 클릭 이벤트 연결
113
+ submit_btn1.click(fn=single_prediction, inputs=[station_dropdown1, file_input], outputs=output1)
114
+ submit_btn2.click(fn=rolling_evaluation, inputs=station_dropdown2, outputs=output2)
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  demo.launch()