Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,17 @@ import pandas as pd
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import plotly.express as px
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
STATIONS = [
|
| 12 |
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
|
|
@@ -15,25 +26,124 @@ STATIONS = [
|
|
| 15 |
]
|
| 16 |
|
| 17 |
STATION_NAMES = {
|
| 18 |
-
"DT_0001": "인천",
|
| 19 |
-
"
|
| 20 |
-
"
|
| 21 |
-
"DT_0067": "안흥",
|
| 22 |
-
"
|
| 23 |
-
"DT_0002": "평택",
|
| 24 |
-
"DT_0050": "태안",
|
| 25 |
-
"DT_0017": "대산",
|
| 26 |
-
"DT_0052": "인천송도",
|
| 27 |
-
"DT_0025": "보령",
|
| 28 |
-
"DT_0051": "서천마량",
|
| 29 |
-
"DT_0037": "어청도",
|
| 30 |
-
"DT_0024": "장항",
|
| 31 |
-
"DT_0018": "군산",
|
| 32 |
-
"DT_0068": "위도",
|
| 33 |
-
"DT_0003": "영광",
|
| 34 |
-
"DT_0066": "향화도"
|
| 35 |
}
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def get_common_args(station_id):
|
| 38 |
return [
|
| 39 |
"--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72",
|
|
@@ -82,32 +192,18 @@ def execute_inference_and_get_results(command):
|
|
| 82 |
except Exception as e:
|
| 83 |
raise gr.Error(f"내부 오류: {str(e)}")
|
| 84 |
|
| 85 |
-
def
|
| 86 |
-
"""
|
| 87 |
-
print(f"Creating plot - predictions type: {type(predictions)}, shape: {predictions.shape}")
|
| 88 |
-
|
| 89 |
-
# 입력 데이터에서 시간 정보 추출
|
| 90 |
-
input_df = pd.read_csv(input_data.name)
|
| 91 |
-
input_df['date'] = pd.to_datetime(input_df['date'])
|
| 92 |
-
|
| 93 |
-
# 최근 24포인트(2시간)만 표시
|
| 94 |
-
recent_data = input_df.tail(24)
|
| 95 |
-
|
| 96 |
-
# predictions를 안전하게 변환
|
| 97 |
-
pred_values = np.array(predictions).flatten()
|
| 98 |
-
print(f"Prediction values shape: {pred_values.shape}, first 5: {pred_values[:5]}")
|
| 99 |
-
|
| 100 |
-
# 미래 시간 생성 - 더 안전한 방법
|
| 101 |
try:
|
| 102 |
-
#
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
future_times =
|
| 107 |
-
start=last_time,
|
| 108 |
-
periods=len(pred_values) + 1, # +1을 해서 시작점 포함
|
| 109 |
-
freq='5min'
|
| 110 |
-
)[1:] # 첫 번째 (현재 시간) 제외
|
| 111 |
|
| 112 |
# 플롯 생성
|
| 113 |
fig = go.Figure()
|
|
@@ -122,148 +218,64 @@ def create_prediction_plot(predictions, input_data, station_name):
|
|
| 122 |
marker=dict(size=4)
|
| 123 |
))
|
| 124 |
|
| 125 |
-
# 예측
|
| 126 |
fig.add_trace(go.Scatter(
|
| 127 |
x=future_times,
|
| 128 |
-
y=
|
| 129 |
mode='lines+markers',
|
| 130 |
-
name='예측
|
| 131 |
line=dict(color='red', width=2, dash='dash'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
marker=dict(size=4)
|
| 133 |
))
|
| 134 |
|
| 135 |
# 구분선 추가
|
|
|
|
| 136 |
fig.add_vline(x=last_time, line_dash="dot", line_color="gray",
|
| 137 |
annotation_text="예측 시작점")
|
| 138 |
|
| 139 |
fig.update_layout(
|
| 140 |
-
title=f'{station_name}
|
| 141 |
xaxis_title='시간',
|
| 142 |
-
yaxis_title='
|
| 143 |
hovermode='x unified',
|
| 144 |
-
height=
|
| 145 |
showlegend=True,
|
| 146 |
xaxis=dict(
|
| 147 |
-
tickformat='%H:%M<br>%m/%d'
|
| 148 |
)
|
| 149 |
)
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
except Exception as time_error:
|
| 154 |
-
print(f"Time axis creation failed: {time_error}, falling back to sequence")
|
| 155 |
|
| 156 |
-
|
|
|
|
|
|
|
| 157 |
fig = go.Figure()
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
x=recent_indices,
|
| 163 |
-
y=recent_data['residual'],
|
| 164 |
-
mode='lines+markers',
|
| 165 |
-
name='실제 잔차조위',
|
| 166 |
-
line=dict(color='blue', width=2),
|
| 167 |
-
marker=dict(size=4)
|
| 168 |
-
))
|
| 169 |
-
|
| 170 |
-
# 예측 데이터 - 순서 기반
|
| 171 |
-
future_indices = list(range(1, len(pred_values) + 1))
|
| 172 |
-
fig.add_trace(go.Scatter(
|
| 173 |
-
x=future_indices,
|
| 174 |
-
y=pred_values,
|
| 175 |
-
mode='lines+markers',
|
| 176 |
-
name='예측 잔차조위',
|
| 177 |
-
line=dict(color='red', width=2, dash='dash'),
|
| 178 |
-
marker=dict(size=4)
|
| 179 |
-
))
|
| 180 |
-
|
| 181 |
-
# 구분선 추가
|
| 182 |
-
fig.add_vline(x=0, line_dash="dot", line_color="gray",
|
| 183 |
-
annotation_text="예측 시작점")
|
| 184 |
-
|
| 185 |
-
fig.update_layout(
|
| 186 |
-
title=f'{station_name} 잔차조위 예측 결과',
|
| 187 |
-
xaxis_title='시간 순서 (5분 간격)',
|
| 188 |
-
yaxis_title='잔차조위 (cm)',
|
| 189 |
-
hovermode='x unified',
|
| 190 |
-
height=500,
|
| 191 |
-
showlegend=True
|
| 192 |
)
|
| 193 |
-
|
| 194 |
-
return fig
|
| 195 |
-
|
| 196 |
-
def create_evaluation_plot(predictions, truths, station_name):
|
| 197 |
-
"""평가 결과 시각화"""
|
| 198 |
-
# 성능 메트릭 계산
|
| 199 |
-
mae = np.mean(np.abs(predictions - truths))
|
| 200 |
-
mse = np.mean((predictions - truths) ** 2)
|
| 201 |
-
rmse = np.sqrt(mse)
|
| 202 |
-
|
| 203 |
-
# 서브플롯 생성
|
| 204 |
-
fig = make_subplots(
|
| 205 |
-
rows=2, cols=2,
|
| 206 |
-
subplot_titles=[
|
| 207 |
-
'첫 번째 샘플: 예측 vs 실제',
|
| 208 |
-
'전체 샘플 MAE 분포',
|
| 209 |
-
'예측값 vs 실제값 산점도 (처음 1000개 포인트)',
|
| 210 |
-
'시간별 오차 분포'
|
| 211 |
-
],
|
| 212 |
-
specs=[[{"secondary_y": False}, {"secondary_y": False}],
|
| 213 |
-
[{"secondary_y": False}, {"secondary_y": False}]]
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
# 1. 첫 번째 샘플 비교
|
| 217 |
-
time_steps = list(range(72))
|
| 218 |
-
fig.add_trace(
|
| 219 |
-
go.Scatter(x=time_steps, y=predictions[0].flatten(),
|
| 220 |
-
name='예측', line=dict(color='red')),
|
| 221 |
-
row=1, col=1
|
| 222 |
-
)
|
| 223 |
-
fig.add_trace(
|
| 224 |
-
go.Scatter(x=time_steps, y=truths[0].flatten(),
|
| 225 |
-
name='실제', line=dict(color='blue')),
|
| 226 |
-
row=1, col=1
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
# 2. MAE 분포
|
| 230 |
-
sample_maes = np.mean(np.abs(predictions - truths), axis=(1,2))
|
| 231 |
-
fig.add_trace(
|
| 232 |
-
go.Histogram(x=sample_maes, name='MAE 분포', nbinsx=20),
|
| 233 |
-
row=1, col=2
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
# 3. 산점도 (메모리 절약을 위해 처음 1000개 포인트만)
|
| 237 |
-
pred_flat = predictions.flatten()[:1000]
|
| 238 |
-
true_flat = truths.flatten()[:1000]
|
| 239 |
-
fig.add_trace(
|
| 240 |
-
go.Scatter(x=true_flat, y=pred_flat, mode='markers',
|
| 241 |
-
name='예측 vs 실제', marker=dict(size=3, opacity=0.6)),
|
| 242 |
-
row=2, col=1
|
| 243 |
-
)
|
| 244 |
-
# 완벽한 예측선 추가
|
| 245 |
-
min_val, max_val = min(min(pred_flat), min(true_flat)), max(max(pred_flat), max(true_flat))
|
| 246 |
-
fig.add_trace(
|
| 247 |
-
go.Scatter(x=[min_val, max_val], y=[min_val, max_val],
|
| 248 |
-
mode='lines', name='완벽한 예측', line=dict(dash='dash', color='gray')),
|
| 249 |
-
row=2, col=1
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
# 4. 시간별 오차
|
| 253 |
-
time_errors = np.mean(np.abs(predictions - truths), axis=0).flatten()
|
| 254 |
-
fig.add_trace(
|
| 255 |
-
go.Scatter(x=time_steps, y=time_errors,
|
| 256 |
-
name='시간별 평균 오차', line=dict(color='orange')),
|
| 257 |
-
row=2, col=2
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
fig.update_layout(
|
| 261 |
-
title=f'{station_name} 성능 평가 결과<br>MAE: {mae:.3f} | MSE: {mse:.3f} | RMSE: {rmse:.3f}',
|
| 262 |
-
height=800,
|
| 263 |
-
showlegend=False
|
| 264 |
-
)
|
| 265 |
-
|
| 266 |
-
return fig, {"MAE": mae, "MSE": mse, "RMSE": rmse}
|
| 267 |
|
| 268 |
def single_prediction(station_id, input_csv_file):
|
| 269 |
if input_csv_file is None:
|
|
@@ -292,7 +304,7 @@ def single_prediction(station_id, input_csv_file):
|
|
| 292 |
"--scaler_path", scaler_path,
|
| 293 |
"--predict_input_file", input_csv_file.name] + common_args
|
| 294 |
|
| 295 |
-
gr.Info(f"{station_name}({station_id}) 조위 예측을 실행중입니다...")
|
| 296 |
|
| 297 |
# inference 실행
|
| 298 |
success, output = execute_inference_and_get_results(command)
|
|
@@ -301,51 +313,41 @@ def single_prediction(station_id, input_csv_file):
|
|
| 301 |
try:
|
| 302 |
prediction_file = "pred_results/prediction_future.npy"
|
| 303 |
if os.path.exists(prediction_file):
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
-
#
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
# 예측 결과 테이블 생성
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
last_datetime = pd.to_datetime(last_date).to_pydatetime()
|
| 321 |
-
|
| 322 |
-
# 5분 간격으로 72개 시점 생성
|
| 323 |
-
time_points = []
|
| 324 |
-
for i in range(72):
|
| 325 |
-
minutes_to_add = (i + 1) * 5
|
| 326 |
-
next_time = last_datetime + datetime.timedelta(minutes=minutes_to_add)
|
| 327 |
-
time_points.append(next_time)
|
| 328 |
-
|
| 329 |
-
result_df = pd.DataFrame({
|
| 330 |
-
'예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in time_points],
|
| 331 |
-
'예측 잔차조위 (cm)': [f"{val:.2f}" for val in predictions]
|
| 332 |
-
})
|
| 333 |
-
|
| 334 |
-
except Exception as e:
|
| 335 |
-
print(f"Table creation error: {e}")
|
| 336 |
-
# 에러 시 간단한 테이블 생성
|
| 337 |
-
result_df = pd.DataFrame({
|
| 338 |
-
'예측 순서': [f"{i+1}번째 (+{(i+1)*5}분)" for i in range(len(predictions))],
|
| 339 |
-
'예측 잔차조위 (cm)': [f"{val:.2f}" for val in predictions]
|
| 340 |
-
})
|
| 341 |
|
| 342 |
-
return plot, result_df, f"✅ 예측
|
| 343 |
else:
|
| 344 |
return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
|
| 345 |
|
| 346 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
| 347 |
return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
|
| 348 |
-
|
| 349 |
def rolling_evaluation(station_id, eval_csv_file):
|
| 350 |
if eval_csv_file is None:
|
| 351 |
raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
|
|
@@ -405,17 +407,22 @@ def rolling_evaluation(station_id, eval_csv_file):
|
|
| 405 |
return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
|
| 406 |
|
| 407 |
# Gradio 인터페이스
|
| 408 |
-
with gr.Blocks(title="조위 예측 모델
|
| 409 |
-
gr.Markdown("# 🌊 조위 예측
|
| 410 |
-
gr.Markdown("TimeXer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
with gr.Tabs():
|
| 413 |
-
with gr.TabItem("🔮
|
| 414 |
gr.Markdown("""
|
| 415 |
-
###
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
| 419 |
""")
|
| 420 |
|
| 421 |
with gr.Row():
|
|
@@ -432,47 +439,15 @@ with gr.Blocks(title="조위 예측 모델 v4.2", theme=gr.themes.Soft()) as dem
|
|
| 432 |
file_count="single"
|
| 433 |
)
|
| 434 |
|
| 435 |
-
submit_btn1 = gr.Button("🚀 예측 실행", variant="primary", size="lg")
|
| 436 |
|
| 437 |
with gr.Row():
|
| 438 |
with gr.Column(scale=2):
|
| 439 |
-
plot_output1 = gr.Plot(label="예측 결과 시각화")
|
| 440 |
with gr.Column(scale=1):
|
| 441 |
table_output1 = gr.Dataframe(label="예측 결과 테이블")
|
| 442 |
|
| 443 |
text_output1 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
|
| 444 |
-
|
| 445 |
-
with gr.TabItem("📊 전체 기간 롤링 평가"):
|
| 446 |
-
gr.Markdown("""
|
| 447 |
-
### 사용 방법
|
| 448 |
-
1. 관측소를 선택하세요
|
| 449 |
-
2. 평가용 전체 데이터 CSV 파일을 업로드하세요 (최소 300행 필요)
|
| 450 |
-
3. 평가 실행 → 성능 메트릭과 예측 정확도를 확인하세요
|
| 451 |
-
""")
|
| 452 |
-
|
| 453 |
-
with gr.Row():
|
| 454 |
-
with gr.Column(scale=1):
|
| 455 |
-
station_dropdown2 = gr.Dropdown(
|
| 456 |
-
choices=[(f"{STATION_NAMES[s]} ({s})", s) for s in STATIONS],
|
| 457 |
-
label="관측소 선택",
|
| 458 |
-
value=STATIONS[0]
|
| 459 |
-
)
|
| 460 |
-
with gr.Column(scale=2):
|
| 461 |
-
file_input2 = gr.File(
|
| 462 |
-
label="평가용 전체 데이터 (.csv 파일)",
|
| 463 |
-
file_types=[".csv"],
|
| 464 |
-
file_count="single"
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
submit_btn2 = gr.Button("📈 평가 실행", variant="primary", size="lg")
|
| 468 |
-
|
| 469 |
-
with gr.Row():
|
| 470 |
-
with gr.Column(scale=2):
|
| 471 |
-
plot_output2 = gr.Plot(label="평가 결과 시각화")
|
| 472 |
-
with gr.Column(scale=1):
|
| 473 |
-
table_output2 = gr.Dataframe(label="성능 메트릭")
|
| 474 |
-
|
| 475 |
-
text_output2 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
|
| 476 |
|
| 477 |
# 이벤트 바인딩
|
| 478 |
submit_btn1.click(
|
|
@@ -481,42 +456,19 @@ with gr.Blocks(title="조위 예측 모델 v4.2", theme=gr.themes.Soft()) as dem
|
|
| 481 |
outputs=[plot_output1, table_output1, text_output1],
|
| 482 |
show_progress=True
|
| 483 |
)
|
| 484 |
-
submit_btn2.click(
|
| 485 |
-
fn=rolling_evaluation,
|
| 486 |
-
inputs=[station_dropdown2, file_input2],
|
| 487 |
-
outputs=[plot_output2, table_output2, text_output2],
|
| 488 |
-
show_progress=True
|
| 489 |
-
)
|
| 490 |
|
| 491 |
# 사용 안내
|
| 492 |
gr.Markdown("""
|
| 493 |
-
##
|
| 494 |
-
|
| 495 |
-
```csv
|
| 496 |
-
date,air_pres,wind_dir,wind_speed,air_temp,residual
|
| 497 |
-
2024-01-01 00:00:00,1013.2,180,3.4,15.2,120.5
|
| 498 |
-
2024-01-01 00:05:00,1013.1,185,3.2,15.1,118.3
|
| 499 |
-
2024-01-01 00:10:00,1012.9,190,3.0,15.0,116.1
|
| 500 |
-
...
|
| 501 |
-
```
|
| 502 |
-
|
| 503 |
-
**컬럼 설명:**
|
| 504 |
-
- **date**: 관측 시간 (5분 간격)
|
| 505 |
-
- **air_pres**: 기압 (hPa)
|
| 506 |
-
- **wind_dir**: 풍향 (도)
|
| 507 |
-
- **wind_speed**: 풍속 (m/s)
|
| 508 |
-
- **air_temp**: 기온 (°C)
|
| 509 |
-
- **residual**: 잔차 조위 (cm) - 예측 대상
|
| 510 |
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
- 평가용: 최소 300행 데이터 필요
|
| 517 |
-
- 5분 간격 데이터 사용
|
| 518 |
-
- 처리 시간: 1-5분 소요 가능
|
| 519 |
-
- **예측 대상**: residual (잔차 조위) 컬럼
|
| 520 |
""")
|
| 521 |
|
| 522 |
if __name__ == "__main__":
|
|
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import plotly.express as px
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
+
|
| 14 |
+
# Supabase 연동 추가
|
| 15 |
+
try:
|
| 16 |
+
from supabase import create_client, Client
|
| 17 |
+
SUPABASE_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
SUPABASE_AVAILABLE = False
|
| 20 |
+
print("Supabase 패키지가 설치되지 않았습니다.")
|
| 21 |
|
| 22 |
STATIONS = [
|
| 23 |
"DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
|
|
|
|
| 26 |
]
|
| 27 |
|
| 28 |
STATION_NAMES = {
|
| 29 |
+
"DT_0001": "인천", "DT_0002": "평택", "DT_0003": "영광", "DT_0008": "안산",
|
| 30 |
+
"DT_0017": "대산", "DT_0018": "군산", "DT_0024": "장항", "DT_0025": "보령",
|
| 31 |
+
"DT_0037": "어청도", "DT_0043": "영흥도", "DT_0050": "태안", "DT_0051": "서천마량",
|
| 32 |
+
"DT_0052": "인천송도", "DT_0065": "덕적도", "DT_0066": "향화도", "DT_0067": "안흥",
|
| 33 |
+
"DT_0068": "위도"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
|
| 36 |
+
def get_supabase_client():
|
| 37 |
+
"""Supabase 클라이언트 생성"""
|
| 38 |
+
if not SUPABASE_AVAILABLE:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
url = os.getenv("SUPABASE_URL")
|
| 43 |
+
key = os.getenv("SUPABASE_KEY")
|
| 44 |
+
|
| 45 |
+
if not url or not key:
|
| 46 |
+
print("Supabase 환경변수가 설정되지 않았습니다.")
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
return create_client(url, key)
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Supabase 연결 오류: {e}")
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
def get_harmonic_predictions(station_id, start_time, end_time):
|
| 55 |
+
"""해당 시간 범위의 조화 예측값 조회"""
|
| 56 |
+
supabase = get_supabase_client()
|
| 57 |
+
if not supabase:
|
| 58 |
+
return []
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
result = supabase.table('harmonic_predictions')\
|
| 62 |
+
.select('predicted_at, harmonic_level')\
|
| 63 |
+
.eq('station_id', station_id)\
|
| 64 |
+
.gte('predicted_at', start_time.isoformat())\
|
| 65 |
+
.lte('predicted_at', end_time.isoformat())\
|
| 66 |
+
.order('predicted_at')\
|
| 67 |
+
.execute()
|
| 68 |
+
|
| 69 |
+
return result.data if result.data else []
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"조화 예측값 조회 오류: {e}")
|
| 72 |
+
return []
|
| 73 |
+
|
| 74 |
+
def calculate_final_tide(residual_predictions, station_id, last_time):
|
| 75 |
+
"""잔차 예측 + 조화 예측 = 최종 조위 계산"""
|
| 76 |
+
# 미래 시간 범위 계산 (72개 포인트, 5분 간격)
|
| 77 |
+
start_time = last_time + timedelta(minutes=5)
|
| 78 |
+
end_time = last_time + timedelta(minutes=72*5)
|
| 79 |
+
|
| 80 |
+
# 조화 예측값 조회
|
| 81 |
+
harmonic_data = get_harmonic_predictions(station_id, start_time, end_time)
|
| 82 |
+
|
| 83 |
+
if not harmonic_data:
|
| 84 |
+
print("조화 예측 데이터를 찾을 수 없습니다. 잔차 예측만 반환합니다.")
|
| 85 |
+
return {
|
| 86 |
+
'times': [last_time + timedelta(minutes=(i+1)*5) for i in range(len(residual_predictions))],
|
| 87 |
+
'residual': residual_predictions.flatten(),
|
| 88 |
+
'harmonic': [0] * len(residual_predictions),
|
| 89 |
+
'final_tide': residual_predictions.flatten()
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# 시간별 매칭
|
| 93 |
+
final_results = {
|
| 94 |
+
'times': [],
|
| 95 |
+
'residual': [],
|
| 96 |
+
'harmonic': [],
|
| 97 |
+
'final_tide': []
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
for i, residual in enumerate(residual_predictions.flatten()):
|
| 101 |
+
pred_time = last_time + timedelta(minutes=(i+1)*5)
|
| 102 |
+
|
| 103 |
+
# 가장 가까운 조화 예측값 찾기
|
| 104 |
+
harmonic_value = 0
|
| 105 |
+
for h_data in harmonic_data:
|
| 106 |
+
h_time = datetime.fromisoformat(h_data['predicted_at'].replace('Z', '+00:00'))
|
| 107 |
+
if abs((h_time - pred_time).total_seconds()) < 300: # 5분 이내
|
| 108 |
+
harmonic_value = h_data['harmonic_level']
|
| 109 |
+
break
|
| 110 |
+
|
| 111 |
+
final_tide = residual + harmonic_value
|
| 112 |
+
|
| 113 |
+
final_results['times'].append(pred_time)
|
| 114 |
+
final_results['residual'].append(residual)
|
| 115 |
+
final_results['harmonic'].append(harmonic_value)
|
| 116 |
+
final_results['final_tide'].append(final_tide)
|
| 117 |
+
|
| 118 |
+
return final_results
|
| 119 |
+
|
| 120 |
+
def save_predictions_to_supabase(station_id, prediction_results):
|
| 121 |
+
"""예측 결과를 Supabase에 저장"""
|
| 122 |
+
supabase = get_supabase_client()
|
| 123 |
+
if not supabase:
|
| 124 |
+
return 0
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
insert_data = []
|
| 128 |
+
for i in range(len(prediction_results['times'])):
|
| 129 |
+
insert_data.append({
|
| 130 |
+
'station_id': station_id,
|
| 131 |
+
'predicted_at': prediction_results['times'][i].isoformat(),
|
| 132 |
+
'predicted_residual': float(prediction_results['residual'][i]),
|
| 133 |
+
'harmonic_level': float(prediction_results['harmonic'][i]),
|
| 134 |
+
'final_tide_level': float(prediction_results['final_tide'][i]),
|
| 135 |
+
'model_version': 'TimeXer-v1'
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
+
result = supabase.table('tide_predictions')\
|
| 139 |
+
.upsert(insert_data, on_conflict='station_id,predicted_at,model_version')\
|
| 140 |
+
.execute()
|
| 141 |
+
|
| 142 |
+
return len(insert_data)
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"예측 결과 저장 오류: {e}")
|
| 145 |
+
return 0
|
| 146 |
+
|
| 147 |
def get_common_args(station_id):
|
| 148 |
return [
|
| 149 |
"--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72",
|
|
|
|
| 192 |
except Exception as e:
|
| 193 |
raise gr.Error(f"내부 오류: {str(e)}")
|
| 194 |
|
| 195 |
+
def create_enhanced_prediction_plot(prediction_results, input_data, station_name):
|
| 196 |
+
"""잔차 + 조화 + 최종 조위를 모두 표시하는 향상된 플롯"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
try:
|
| 198 |
+
# 입력 데이터에서 시간 정보 추출
|
| 199 |
+
input_df = pd.read_csv(input_data.name)
|
| 200 |
+
input_df['date'] = pd.to_datetime(input_df['date'])
|
| 201 |
+
|
| 202 |
+
# 최근 24포인트(2시간)만 표시
|
| 203 |
+
recent_data = input_df.tail(24)
|
| 204 |
|
| 205 |
+
# 미래 시간들
|
| 206 |
+
future_times = prediction_results['times']
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# 플롯 생성
|
| 209 |
fig = go.Figure()
|
|
|
|
| 218 |
marker=dict(size=4)
|
| 219 |
))
|
| 220 |
|
| 221 |
+
# 잔차 예측
|
| 222 |
fig.add_trace(go.Scatter(
|
| 223 |
x=future_times,
|
| 224 |
+
y=prediction_results['residual'],
|
| 225 |
mode='lines+markers',
|
| 226 |
+
name='잔차 예측',
|
| 227 |
line=dict(color='red', width=2, dash='dash'),
|
| 228 |
+
marker=dict(size=3)
|
| 229 |
+
))
|
| 230 |
+
|
| 231 |
+
# 조화 예측
|
| 232 |
+
fig.add_trace(go.Scatter(
|
| 233 |
+
x=future_times,
|
| 234 |
+
y=prediction_results['harmonic'],
|
| 235 |
+
mode='lines',
|
| 236 |
+
name='조화 예측',
|
| 237 |
+
line=dict(color='orange', width=2)
|
| 238 |
+
))
|
| 239 |
+
|
| 240 |
+
# 최종 조위 (잔차 + 조화)
|
| 241 |
+
fig.add_trace(go.Scatter(
|
| 242 |
+
x=future_times,
|
| 243 |
+
y=prediction_results['final_tide'],
|
| 244 |
+
mode='lines+markers',
|
| 245 |
+
name='최종 조위',
|
| 246 |
+
line=dict(color='green', width=3),
|
| 247 |
marker=dict(size=4)
|
| 248 |
))
|
| 249 |
|
| 250 |
# 구분선 추가
|
| 251 |
+
last_time = input_df['date'].iloc[-1]
|
| 252 |
fig.add_vline(x=last_time, line_dash="dot", line_color="gray",
|
| 253 |
annotation_text="예측 시작점")
|
| 254 |
|
| 255 |
fig.update_layout(
|
| 256 |
+
title=f'{station_name} 통합 조위 예측 결과',
|
| 257 |
xaxis_title='시간',
|
| 258 |
+
yaxis_title='수위 (cm)',
|
| 259 |
hovermode='x unified',
|
| 260 |
+
height=600,
|
| 261 |
showlegend=True,
|
| 262 |
xaxis=dict(
|
| 263 |
+
tickformat='%H:%M<br>%m/%d'
|
| 264 |
)
|
| 265 |
)
|
| 266 |
|
| 267 |
+
return fig
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
+
except Exception as e:
|
| 270 |
+
print(f"Enhanced plot creation error: {e}")
|
| 271 |
+
# 에러 시 기본 플롯 반환
|
| 272 |
fig = go.Figure()
|
| 273 |
+
fig.add_annotation(
|
| 274 |
+
text=f"시각화 생성 중 오류: {str(e)}",
|
| 275 |
+
xref="paper", yref="paper",
|
| 276 |
+
x=0.5, y=0.5, showarrow=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
)
|
| 278 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
def single_prediction(station_id, input_csv_file):
|
| 281 |
if input_csv_file is None:
|
|
|
|
| 304 |
"--scaler_path", scaler_path,
|
| 305 |
"--predict_input_file", input_csv_file.name] + common_args
|
| 306 |
|
| 307 |
+
gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
|
| 308 |
|
| 309 |
# inference 실행
|
| 310 |
success, output = execute_inference_and_get_results(command)
|
|
|
|
| 313 |
try:
|
| 314 |
prediction_file = "pred_results/prediction_future.npy"
|
| 315 |
if os.path.exists(prediction_file):
|
| 316 |
+
residual_predictions = np.load(prediction_file)
|
| 317 |
+
|
| 318 |
+
# 입력 데이터에서 마지막 시간 추출
|
| 319 |
+
input_df = pd.read_csv(input_csv_file.name)
|
| 320 |
+
input_df['date'] = pd.to_datetime(input_df['date'])
|
| 321 |
+
last_time = input_df['date'].iloc[-1]
|
| 322 |
|
| 323 |
+
# 최종 조위 계산 (��차 + 조화)
|
| 324 |
+
prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
|
| 325 |
+
|
| 326 |
+
# 향상된 시각화 생성
|
| 327 |
+
plot = create_enhanced_prediction_plot(prediction_results, input_csv_file, station_name)
|
| 328 |
|
| 329 |
# 예측 결과 테이블 생성
|
| 330 |
+
result_df = pd.DataFrame({
|
| 331 |
+
'예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']],
|
| 332 |
+
'잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']],
|
| 333 |
+
'조화 예측 (cm)': [f"{val:.2f}" for val in prediction_results['harmonic']],
|
| 334 |
+
'최종 조위 (cm)': [f"{val:.2f}" for val in prediction_results['final_tide']]
|
| 335 |
+
})
|
| 336 |
+
|
| 337 |
+
# Supabase에 결과 저장
|
| 338 |
+
saved_count = save_predictions_to_supabase(station_id, prediction_results)
|
| 339 |
+
save_message = f"\n💾 Supabase에 {saved_count}개 예측 결과 저장 완료!" if saved_count > 0 else "\n⚠️ Supabase 저장 실패"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
+
return plot, result_df, f"✅ 통합 예측 완료!{save_message}\n\n{output}"
|
| 342 |
else:
|
| 343 |
return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
|
| 344 |
|
| 345 |
except Exception as e:
|
| 346 |
+
print(f"Result processing error: {e}")
|
| 347 |
+
import traceback
|
| 348 |
+
traceback.print_exc()
|
| 349 |
return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
|
| 350 |
+
|
| 351 |
def rolling_evaluation(station_id, eval_csv_file):
|
| 352 |
if eval_csv_file is None:
|
| 353 |
raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
|
|
|
|
| 407 |
return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
|
| 408 |
|
| 409 |
# Gradio 인터페이스
|
| 410 |
+
with gr.Blocks(title="조위 예측 모델 v5.0", theme=gr.themes.Soft()) as demo:
|
| 411 |
+
gr.Markdown("# 🌊 통합 조위 예측 시스템 v5.0 (TimeXer + 조화분석)")
|
| 412 |
+
gr.Markdown("TimeXer 잔차 예측 + MATLAB 조화 예측을 결합한 최종 조위 예측 시스템입니다.")
|
| 413 |
+
|
| 414 |
+
# Supabase 연결 상태 표시
|
| 415 |
+
supabase_status = "🟢 연결됨" if get_supabase_client() else "🔴 연결 안됨"
|
| 416 |
+
gr.Markdown(f"**Supabase 상태**: {supabase_status}")
|
| 417 |
|
| 418 |
with gr.Tabs():
|
| 419 |
+
with gr.TabItem("🔮 통합 조위 예측"):
|
| 420 |
gr.Markdown("""
|
| 421 |
+
### 🌟 새로운 기능
|
| 422 |
+
- **잔차 예측**: TimeXer 모델로 기상 영향 예측
|
| 423 |
+
- **조화 예측**: MATLAB 조화분석으로 천체 영향 예측
|
| 424 |
+
- **최종 조위**: 잔차 + 조화 = 완전한 조위 예측
|
| 425 |
+
- **자동 저장**: 예측 결과를 데이터베이스에 자동 저장
|
| 426 |
""")
|
| 427 |
|
| 428 |
with gr.Row():
|
|
|
|
| 439 |
file_count="single"
|
| 440 |
)
|
| 441 |
|
| 442 |
+
submit_btn1 = gr.Button("🚀 통합 예측 실행", variant="primary", size="lg")
|
| 443 |
|
| 444 |
with gr.Row():
|
| 445 |
with gr.Column(scale=2):
|
| 446 |
+
plot_output1 = gr.Plot(label="통합 예측 결과 시각화")
|
| 447 |
with gr.Column(scale=1):
|
| 448 |
table_output1 = gr.Dataframe(label="예측 결과 테이블")
|
| 449 |
|
| 450 |
text_output1 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
# 이벤트 바인딩
|
| 453 |
submit_btn1.click(
|
|
|
|
| 456 |
outputs=[plot_output1, table_output1, text_output1],
|
| 457 |
show_progress=True
|
| 458 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
# 사용 안내
|
| 461 |
gr.Markdown("""
|
| 462 |
+
## 🎯 통합 예측 시스템 특징
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
|
| 464 |
+
### 📊 4가지 예측 라인
|
| 465 |
+
- **파란색**: 실제 잔차조위 (과거 데이터)
|
| 466 |
+
- **빨간색**: 잔차 예측 (기상 영향)
|
| 467 |
+
- **주황색**: 조화 예측 (천체 영향)
|
| 468 |
+
- **초록색**: 최종 조위 (잔차 + 조화)
|
| 469 |
|
| 470 |
+
### 💾 데이터 저장
|
| 471 |
+
모든 예측 결과는 Supabase 데이터베이스에 자동 저장되어 이후 분석 및 API 활용이 가능합니다.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
""")
|
| 473 |
|
| 474 |
if __name__ == "__main__":
|