alwaysgood committed on
Commit
5eccb5b
·
verified ·
1 Parent(s): 4ccea30

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +180 -0
inference.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pandas as pd
4
+ import argparse
5
+ import joblib
6
+ import os
7
+ from tqdm import tqdm
8
+ import json # 👈 JSON 라이브러리 추가
9
+
10
+ # ⭐️ 수정 사항 1: 경로 문제를 피하기 위해 명시적으로 import 경로 추가
11
+ import sys
12
+ sys.path.append('.')
13
+
14
+ from models import TimeXer # 사용하는 모델에 맞게 수정
15
+ from utils.metrics import metric # 성능 평가를 위해 추가
16
+
17
# --- 1. Argument parsing ---
arg_parser = argparse.ArgumentParser(description='Time Series Prediction')

# Artifact locations and input files.
arg_parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to the model checkpoint file (.pth)')
arg_parser.add_argument('--scaler_path', type=str, required=True, help='Path to the saved scaler file (.gz)')
arg_parser.add_argument('--predict_input_file', type=str, default=None, help='[Mode 1] Path to the CSV file for single future prediction')
arg_parser.add_argument('--evaluate_file', type=str, default=None, help='[Mode 2] Path to the CSV file for rolling evaluation')
arg_parser.add_argument('--model', type=str, default='TimeXer', help='model name')
arg_parser.add_argument('--task_name', type=str, default='long_term_forecast', help='task name')

# Required window / architecture hyper-parameters, registered in bulk.
# Order matters only for --help output and mirrors the original script.
_required_specs = (
    ('seq_len', int, 'input sequence length'),
    ('pred_len', int, 'prediction sequence length'),
    ('label_len', int, 'start token length'),
    ('features', str, 'M, S, or MS'),
    ('enc_in', int, 'encoder input size'),
    ('dec_in', int, 'decoder input size'),
    ('c_out', int, 'output size'),
    ('d_model', int, 'dimension of model'),
    ('n_heads', int, 'num of heads'),
    ('e_layers', int, 'num of encoder layers'),
    ('d_layers', int, 'num of decoder layers'),
    ('d_ff', int, 'dimension of fcn'),
    ('factor', int, 'attn factor'),
    ('patch_len', int, 'patch length for TimeXer'),
    ('expand', int, None),
    ('d_conv', int, None),
)
for _flag, _kind, _help in _required_specs:
    arg_parser.add_argument(f'--{_flag}', type=_kind, required=True, help=_help)

# Optional knobs whose defaults must match the training configuration.
arg_parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
arg_parser.add_argument('--embed', type=str, default='timeF', help='time features encoding')
arg_parser.add_argument('--activation', type=str, default='gelu', help='activation')
arg_parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
arg_parser.add_argument('--use_norm', type=int, default=1, help='whether to use normalize')
arg_parser.add_argument('--freq', type=str, default='t', help='freq for time features encoding')

args = arg_parser.parse_args()
49
+
50
+
51
# --- 2. Shared helper: load model and scaler ---
def load_model_and_scaler(args):
    """Build the model from CLI args and load its checkpoint and scaler.

    Args:
        args: parsed argparse namespace; `args.device` is set here because
            the TimeXer constructor may read the target device from it.

    Returns:
        (model, scaler, device): the model in eval mode on `device`, and the
        de/normalization scaler saved at training time.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    args.device = device
    model = TimeXer.Model(args).float().to(device)
    # weights_only=True: the checkpoint is a plain state_dict, so refuse to
    # unpickle arbitrary objects (the legacy torch.load default is unsafe on
    # untrusted files).
    model.load_state_dict(torch.load(args.checkpoint_path, map_location=device, weights_only=True))
    model.eval()
    scaler = joblib.load(args.scaler_path)
    # Progress messages go to stderr so stdout stays clean for the JSON result.
    print(f"Using device: {device}", file=sys.stderr)
    print("Model and scaler loaded successfully.", file=sys.stderr)
    return model, scaler, device
64
+
65
# --- 3. Mode 1: single future prediction ---
def predict_future(args, model, scaler, device):
    """Predict `pred_len` future steps from the last `seq_len` rows of a CSV.

    Reads `args.predict_input_file`, drops a `date` column if present, scales
    the last `seq_len` rows, runs one forward pass, and returns the
    inverse-scaled prediction as a 2-D numpy array (pred_len, c_out).

    Raises:
        ValueError: if the CSV holds fewer than `seq_len` rows (previously
            `tail()` silently produced a shorter window).
    """
    df_input = pd.read_csv(args.predict_input_file)
    if 'date' in df_input.columns:
        df_input = df_input.drop(columns=['date'])
    if len(df_input) < args.seq_len:
        raise ValueError(
            f"predict_input_file has {len(df_input)} rows but seq_len={args.seq_len} are required."
        )
    raw_input = df_input.tail(args.seq_len).values

    input_scaled = scaler.transform(raw_input)
    batch_x = torch.from_numpy(input_scaled).float().unsqueeze(0).to(device)

    with torch.no_grad():
        # Only batch_x is assumed to be needed by the model's forward;
        # add batch_x_mark etc. here if the model requires time features.
        outputs = model(batch_x)

    prediction_scaled = outputs.detach().cpu().numpy()[0]

    # Undo scaling. In 'MS' mode the scaler was fitted on all input features,
    # so pad the target-only prediction back to full feature width, placing
    # the target in the last column(s) before inverting.
    if args.features == 'MS' and scaler.n_features_in_ > 1:
        padding = np.zeros((prediction_scaled.shape[0], scaler.n_features_in_ - args.c_out))
        prediction_padded = np.concatenate((padding, prediction_scaled), axis=1)
        prediction = scaler.inverse_transform(prediction_padded)[:, -args.c_out:]
    else:
        prediction = scaler.inverse_transform(prediction_scaled)

    return prediction
94
+
95
+
96
# --- 4. Mode 2: rolling evaluation over a full file ---
def evaluate_performance(args, model, scaler, device):
    """Roll a (seq_len -> pred_len) window over `args.evaluate_file`.

    For every window position, predicts the next `pred_len` steps and
    collects inverse-scaled predictions alongside the ground truth.

    Returns:
        (preds, trues): numpy arrays of shape (num_windows, pred_len, c_out).

    Raises:
        ValueError: if the file is shorter than seq_len + pred_len rows
            (previously the loop silently ran zero times and empty arrays
            flowed into the metric computation).
    """
    df_eval = pd.read_csv(args.evaluate_file)
    if 'date' in df_eval.columns:
        df_eval = df_eval.drop(columns=['date'])
    raw_data = df_eval.values
    data_scaled = scaler.transform(raw_data)

    num_samples = len(data_scaled) - args.seq_len - args.pred_len + 1
    if num_samples <= 0:
        raise ValueError(
            f"evaluate_file has {len(data_scaled)} rows; at least "
            f"{args.seq_len + args.pred_len} are required for one window."
        )

    preds_unscaled = []
    trues_unscaled = []

    # Progress bar writes to stderr so stdout stays clean for the JSON result.
    for i in tqdm(range(num_samples), desc="Evaluating", file=sys.stderr):
        # Input window [i, i + seq_len) and the ground truth right after it.
        s_begin = i
        s_end = s_begin + args.seq_len
        input_scaled = data_scaled[s_begin:s_end]
        batch_x = torch.from_numpy(input_scaled).float().unsqueeze(0).to(device)

        true_begin = s_end
        true_end = true_begin + args.pred_len
        true_scaled = data_scaled[true_begin:true_end]

        with torch.no_grad():
            outputs = model(batch_x)

        pred_scaled = outputs.detach().cpu().numpy()[0]

        # Same inverse-scaling scheme as predict_future: in 'MS' mode pad the
        # target-only prediction back to full feature width before inverting.
        if args.features == 'MS' and scaler.n_features_in_ > 1:
            padding = np.zeros((pred_scaled.shape[0], scaler.n_features_in_ - args.c_out))
            pred_padded = np.concatenate((padding, pred_scaled), axis=1)
            pred_unscaled = scaler.inverse_transform(pred_padded)[:, -args.c_out:]
        else:
            pred_unscaled = scaler.inverse_transform(pred_scaled)

        # Ground truth rows carry all features; keep only the target column(s).
        true_unscaled = scaler.inverse_transform(true_scaled)[:, -args.c_out:]

        preds_unscaled.append(pred_unscaled)
        trues_unscaled.append(true_unscaled)

    return np.array(preds_unscaled), np.array(trues_unscaled)
138
+
139
# --- 5. Main entry point: run one mode and emit a single JSON object on stdout ---
if __name__ == '__main__':

    final_output = {}  # Final result dictionary serialized to stdout.

    try:
        model, scaler, device = load_model_and_scaler(args)

        if args.predict_input_file:
            print("--- Running in Single Prediction Mode ---", file=sys.stderr)
            prediction = predict_future(args, model, scaler, device)
            final_output = {
                "status": "success",
                "mode": "single_prediction",
                "prediction": prediction.flatten().tolist()
            }

        elif args.evaluate_file:
            print("--- Running in Rolling Evaluation Mode ---", file=sys.stderr)
            eval_preds, eval_trues = evaluate_performance(args, model, scaler, device)

            # Compute performance metrics.
            mae, mse, _, _, _ = metric(eval_preds, eval_trues)

            final_output = {
                "status": "success",
                "mode": "rolling_evaluation",
                # Cast to built-in float: metric() may return numpy scalar
                # types (e.g. np.float32) that json.dumps cannot serialize.
                "mse": float(mse),
                "mae": float(mae),
                # Returning every prediction would be huge; ship a small sample.
                "prediction_samples": [p.flatten().tolist() for p in eval_preds[:5]]
            }

        else:
            final_output = {"status": "error", "message": "No mode selected. Use --predict_input_file or --evaluate_file."}

    except Exception as e:
        # Top-level boundary: report any failure as a JSON error object
        # instead of a traceback, so the caller can always parse stdout.
        final_output = {"status": "error", "message": str(e)}

    # The JSON result is printed to stdout; app.py reads it as the API response.
    print(json.dumps(final_output, indent=2))