alwaysgood commited on
Commit
301ef28
·
verified ·
1 Parent(s): 95421e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -271
app.py CHANGED
@@ -7,6 +7,17 @@ import pandas as pd
7
  import plotly.graph_objects as go
8
  from plotly.subplots import make_subplots
9
  import plotly.express as px
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  STATIONS = [
12
  "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
@@ -15,25 +26,124 @@ STATIONS = [
15
  ]
16
 
17
  STATION_NAMES = {
18
- "DT_0001": "인천",
19
- "DT_0065": "덕적도",
20
- "DT_0008": "안산",
21
- "DT_0067": "안흥",
22
- "DT_0043": "영흥도",
23
- "DT_0002": "평택",
24
- "DT_0050": "태안",
25
- "DT_0017": "대산",
26
- "DT_0052": "인천송도",
27
- "DT_0025": "보령",
28
- "DT_0051": "서천마량",
29
- "DT_0037": "어청도",
30
- "DT_0024": "장항",
31
- "DT_0018": "군산",
32
- "DT_0068": "위도",
33
- "DT_0003": "영광",
34
- "DT_0066": "향화도"
35
  }
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def get_common_args(station_id):
38
  return [
39
  "--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72",
@@ -82,32 +192,18 @@ def execute_inference_and_get_results(command):
82
  except Exception as e:
83
  raise gr.Error(f"내부 오류: {str(e)}")
84
 
85
- def create_prediction_plot(predictions, input_data, station_name):
86
- """예측 결과 시각화"""
87
- print(f"Creating plot - predictions type: {type(predictions)}, shape: {predictions.shape}")
88
-
89
- # 입력 데이터에서 시간 정보 추출
90
- input_df = pd.read_csv(input_data.name)
91
- input_df['date'] = pd.to_datetime(input_df['date'])
92
-
93
- # 최근 24포인트(2시간)만 표시
94
- recent_data = input_df.tail(24)
95
-
96
- # predictions를 안전하게 변환
97
- pred_values = np.array(predictions).flatten()
98
- print(f"Prediction values shape: {pred_values.shape}, first 5: {pred_values[:5]}")
99
-
100
- # 미래 시간 생성 - 더 안전한 방법
101
  try:
102
- # 마지막 시간부터 시작
103
- last_time = input_df['date'].iloc[-1]
 
 
 
 
104
 
105
- # 5분 간격으로 미래 시간 생성 - pandas.date_range 사용
106
- future_times = pd.date_range(
107
- start=last_time,
108
- periods=len(pred_values) + 1, # +1을 해서 시작점 포함
109
- freq='5min'
110
- )[1:] # 첫 번째 (현재 시간) 제외
111
 
112
  # 플롯 생성
113
  fig = go.Figure()
@@ -122,148 +218,64 @@ def create_prediction_plot(predictions, input_data, station_name):
122
  marker=dict(size=4)
123
  ))
124
 
125
- # 예측 데이터
126
  fig.add_trace(go.Scatter(
127
  x=future_times,
128
- y=pred_values,
129
  mode='lines+markers',
130
- name='예측 잔차조위',
131
  line=dict(color='red', width=2, dash='dash'),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  marker=dict(size=4)
133
  ))
134
 
135
  # 구분선 추가
 
136
  fig.add_vline(x=last_time, line_dash="dot", line_color="gray",
137
  annotation_text="예측 시작점")
138
 
139
  fig.update_layout(
140
- title=f'{station_name} 잔차조위 예측 결과',
141
  xaxis_title='시간',
142
- yaxis_title='잔차조위 (cm)',
143
  hovermode='x unified',
144
- height=500,
145
  showlegend=True,
146
  xaxis=dict(
147
- tickformat='%H:%M<br>%m/%d' # 시:분 + 월/일
148
  )
149
  )
150
 
151
- print("Plot with time axis created successfully")
152
-
153
- except Exception as time_error:
154
- print(f"Time axis creation failed: {time_error}, falling back to sequence")
155
 
156
- # 시간축 실패 시 순서 기반으로 fallback
 
 
157
  fig = go.Figure()
158
-
159
- # 과거 데이터 (최근 2시간) - 순서 기반
160
- recent_indices = list(range(-len(recent_data), 0))
161
- fig.add_trace(go.Scatter(
162
- x=recent_indices,
163
- y=recent_data['residual'],
164
- mode='lines+markers',
165
- name='실제 잔차조위',
166
- line=dict(color='blue', width=2),
167
- marker=dict(size=4)
168
- ))
169
-
170
- # 예측 데이터 - 순서 기반
171
- future_indices = list(range(1, len(pred_values) + 1))
172
- fig.add_trace(go.Scatter(
173
- x=future_indices,
174
- y=pred_values,
175
- mode='lines+markers',
176
- name='예측 잔차조위',
177
- line=dict(color='red', width=2, dash='dash'),
178
- marker=dict(size=4)
179
- ))
180
-
181
- # 구분선 추가
182
- fig.add_vline(x=0, line_dash="dot", line_color="gray",
183
- annotation_text="예측 시작점")
184
-
185
- fig.update_layout(
186
- title=f'{station_name} 잔차조위 예측 결과',
187
- xaxis_title='시간 순서 (5분 간격)',
188
- yaxis_title='잔차조위 (cm)',
189
- hovermode='x unified',
190
- height=500,
191
- showlegend=True
192
  )
193
-
194
- return fig
195
-
196
- def create_evaluation_plot(predictions, truths, station_name):
197
- """평가 결과 시각화"""
198
- # 성능 메트릭 계산
199
- mae = np.mean(np.abs(predictions - truths))
200
- mse = np.mean((predictions - truths) ** 2)
201
- rmse = np.sqrt(mse)
202
-
203
- # 서브플롯 생성
204
- fig = make_subplots(
205
- rows=2, cols=2,
206
- subplot_titles=[
207
- '첫 번째 샘플: 예측 vs 실제',
208
- '전체 샘플 MAE 분포',
209
- '예측값 vs 실제값 산점도 (처음 1000개 포인트)',
210
- '시간별 오차 분포'
211
- ],
212
- specs=[[{"secondary_y": False}, {"secondary_y": False}],
213
- [{"secondary_y": False}, {"secondary_y": False}]]
214
- )
215
-
216
- # 1. 첫 번째 샘플 비교
217
- time_steps = list(range(72))
218
- fig.add_trace(
219
- go.Scatter(x=time_steps, y=predictions[0].flatten(),
220
- name='예측', line=dict(color='red')),
221
- row=1, col=1
222
- )
223
- fig.add_trace(
224
- go.Scatter(x=time_steps, y=truths[0].flatten(),
225
- name='실제', line=dict(color='blue')),
226
- row=1, col=1
227
- )
228
-
229
- # 2. MAE 분포
230
- sample_maes = np.mean(np.abs(predictions - truths), axis=(1,2))
231
- fig.add_trace(
232
- go.Histogram(x=sample_maes, name='MAE 분포', nbinsx=20),
233
- row=1, col=2
234
- )
235
-
236
- # 3. 산점도 (메모리 절약을 위해 처음 1000개 포인트만)
237
- pred_flat = predictions.flatten()[:1000]
238
- true_flat = truths.flatten()[:1000]
239
- fig.add_trace(
240
- go.Scatter(x=true_flat, y=pred_flat, mode='markers',
241
- name='예측 vs 실제', marker=dict(size=3, opacity=0.6)),
242
- row=2, col=1
243
- )
244
- # 완벽한 예측선 추가
245
- min_val, max_val = min(min(pred_flat), min(true_flat)), max(max(pred_flat), max(true_flat))
246
- fig.add_trace(
247
- go.Scatter(x=[min_val, max_val], y=[min_val, max_val],
248
- mode='lines', name='완벽한 예측', line=dict(dash='dash', color='gray')),
249
- row=2, col=1
250
- )
251
-
252
- # 4. 시간별 오차
253
- time_errors = np.mean(np.abs(predictions - truths), axis=0).flatten()
254
- fig.add_trace(
255
- go.Scatter(x=time_steps, y=time_errors,
256
- name='시간별 평균 오차', line=dict(color='orange')),
257
- row=2, col=2
258
- )
259
-
260
- fig.update_layout(
261
- title=f'{station_name} 성능 평가 결과<br>MAE: {mae:.3f} | MSE: {mse:.3f} | RMSE: {rmse:.3f}',
262
- height=800,
263
- showlegend=False
264
- )
265
-
266
- return fig, {"MAE": mae, "MSE": mse, "RMSE": rmse}
267
 
268
  def single_prediction(station_id, input_csv_file):
269
  if input_csv_file is None:
@@ -292,7 +304,7 @@ def single_prediction(station_id, input_csv_file):
292
  "--scaler_path", scaler_path,
293
  "--predict_input_file", input_csv_file.name] + common_args
294
 
295
- gr.Info(f"{station_name}({station_id}) 조위 예측을 실행중입니다...")
296
 
297
  # inference 실행
298
  success, output = execute_inference_and_get_results(command)
@@ -301,51 +313,41 @@ def single_prediction(station_id, input_csv_file):
301
  try:
302
  prediction_file = "pred_results/prediction_future.npy"
303
  if os.path.exists(prediction_file):
304
- predictions = np.load(prediction_file)
 
 
 
 
 
305
 
306
- # 시각화 생성
307
- plot = create_prediction_plot(predictions, input_csv_file, station_name)
 
 
 
308
 
309
  # 예측 결과 테이블 생성
310
- try:
311
- import datetime
312
- input_df = pd.read_csv(input_csv_file.name)
313
- input_df['date'] = pd.to_datetime(input_df['date'])
314
- last_date = input_df['date'].iloc[-1]
315
-
316
- # datetime 직접 사용으로 pandas 의존성 제거
317
- if isinstance(last_date, pd.Timestamp):
318
- last_datetime = last_date.to_pydatetime()
319
- else:
320
- last_datetime = pd.to_datetime(last_date).to_pydatetime()
321
-
322
- # 5분 간격으로 72개 시점 생성
323
- time_points = []
324
- for i in range(72):
325
- minutes_to_add = (i + 1) * 5
326
- next_time = last_datetime + datetime.timedelta(minutes=minutes_to_add)
327
- time_points.append(next_time)
328
-
329
- result_df = pd.DataFrame({
330
- '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in time_points],
331
- '예측 잔차조위 (cm)': [f"{val:.2f}" for val in predictions]
332
- })
333
-
334
- except Exception as e:
335
- print(f"Table creation error: {e}")
336
- # 에러 시 간단한 테이블 생성
337
- result_df = pd.DataFrame({
338
- '예측 순서': [f"{i+1}번째 (+{(i+1)*5}분)" for i in range(len(predictions))],
339
- '예측 잔차조위 (cm)': [f"{val:.2f}" for val in predictions]
340
- })
341
 
342
- return plot, result_df, f"✅ 예측 완료!\n\n{output}"
343
  else:
344
  return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
345
 
346
  except Exception as e:
 
 
 
347
  return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
348
-
349
  def rolling_evaluation(station_id, eval_csv_file):
350
  if eval_csv_file is None:
351
  raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
@@ -405,17 +407,22 @@ def rolling_evaluation(station_id, eval_csv_file):
405
  return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
406
 
407
  # Gradio 인터페이스
408
- with gr.Blocks(title="조위 예측 모델 v4.2", theme=gr.themes.Soft()) as demo:
409
- gr.Markdown("# 🌊 조위 예측 모델 v4.2 (TimeXer 기반)")
410
- gr.Markdown("TimeXer 모델을 사용한 한국 연안 17개 지점의 잔차조위 예측 시스템입니다.")
 
 
 
 
411
 
412
  with gr.Tabs():
413
- with gr.TabItem("🔮 단일 미래 예측"):
414
  gr.Markdown("""
415
- ### 사용 방법
416
- 1. 관측소를 선택하세요
417
- 2. 최근 144개 포인트(12시간, 5분 간격)의 CSV 파일을 업로드하세요
418
- 3. 예측 실행 72개 포인트(6시간) 미래 잔차조위 예측 결과를 확인하세요
 
419
  """)
420
 
421
  with gr.Row():
@@ -432,47 +439,15 @@ with gr.Blocks(title="조위 예측 모델 v4.2", theme=gr.themes.Soft()) as dem
432
  file_count="single"
433
  )
434
 
435
- submit_btn1 = gr.Button("🚀 예측 실행", variant="primary", size="lg")
436
 
437
  with gr.Row():
438
  with gr.Column(scale=2):
439
- plot_output1 = gr.Plot(label="예측 결과 시각화")
440
  with gr.Column(scale=1):
441
  table_output1 = gr.Dataframe(label="예측 결과 테이블")
442
 
443
  text_output1 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
444
-
445
- with gr.TabItem("📊 전체 기간 롤링 평가"):
446
- gr.Markdown("""
447
- ### 사용 방법
448
- 1. 관측소를 선택하세요
449
- 2. 평가용 전체 데이터 CSV 파일을 업로드하세요 (최소 300행 필요)
450
- 3. 평가 실행 → 성능 메트릭과 예측 정확도를 확인하세요
451
- """)
452
-
453
- with gr.Row():
454
- with gr.Column(scale=1):
455
- station_dropdown2 = gr.Dropdown(
456
- choices=[(f"{STATION_NAMES[s]} ({s})", s) for s in STATIONS],
457
- label="관측소 선택",
458
- value=STATIONS[0]
459
- )
460
- with gr.Column(scale=2):
461
- file_input2 = gr.File(
462
- label="평가용 전체 데이터 (.csv 파일)",
463
- file_types=[".csv"],
464
- file_count="single"
465
- )
466
-
467
- submit_btn2 = gr.Button("📈 평가 실행", variant="primary", size="lg")
468
-
469
- with gr.Row():
470
- with gr.Column(scale=2):
471
- plot_output2 = gr.Plot(label="평가 결과 시각화")
472
- with gr.Column(scale=1):
473
- table_output2 = gr.Dataframe(label="성능 메트릭")
474
-
475
- text_output2 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
476
 
477
  # 이벤트 바인딩
478
  submit_btn1.click(
@@ -481,42 +456,19 @@ with gr.Blocks(title="조위 예측 모델 v4.2", theme=gr.themes.Soft()) as dem
481
  outputs=[plot_output1, table_output1, text_output1],
482
  show_progress=True
483
  )
484
- submit_btn2.click(
485
- fn=rolling_evaluation,
486
- inputs=[station_dropdown2, file_input2],
487
- outputs=[plot_output2, table_output2, text_output2],
488
- show_progress=True
489
- )
490
 
491
  # 사용 안내
492
  gr.Markdown("""
493
- ## 📋 CSV 파일 형식
494
-
495
- ```csv
496
- date,air_pres,wind_dir,wind_speed,air_temp,residual
497
- 2024-01-01 00:00:00,1013.2,180,3.4,15.2,120.5
498
- 2024-01-01 00:05:00,1013.1,185,3.2,15.1,118.3
499
- 2024-01-01 00:10:00,1012.9,190,3.0,15.0,116.1
500
- ...
501
- ```
502
-
503
- **컬럼 설명:**
504
- - **date**: 관측 시간 (5분 간격)
505
- - **air_pres**: 기압 (hPa)
506
- - **wind_dir**: 풍향 (도)
507
- - **wind_speed**: 풍속 (m/s)
508
- - **air_temp**: 기온 (°C)
509
- - **residual**: 잔차 조위 (cm) - 예측 대상
510
 
511
- ## 🏖️ 지원 관측소
512
- 인천, 덕적도, 안산, 안흥, 영흥도, 평택, 태안, 대산, 인천송도, 보령, 서천마량, 어청도, 장항, 군산, 위도, 영광, 향화도
 
 
 
513
 
514
- ## ⚠️ 주의사항
515
- - 예측용: 최소 144행 (12시간) 데이터 필요
516
- - 평가용: 최소 300행 데이터 필요
517
- - 5분 간격 데이터 사용
518
- - 처리 시간: 1-5분 소요 가능
519
- - **예측 대상**: residual (잔차 조위) 컬럼
520
  """)
521
 
522
  if __name__ == "__main__":
 
7
  import plotly.graph_objects as go
8
  from plotly.subplots import make_subplots
9
  import plotly.express as px
10
+ from datetime import datetime, timedelta
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ # Supabase 연동 추가
15
+ try:
16
+ from supabase import create_client, Client
17
+ SUPABASE_AVAILABLE = True
18
+ except ImportError:
19
+ SUPABASE_AVAILABLE = False
20
+ print("Supabase 패키지가 설치되지 않았습니다.")
21
 
22
  STATIONS = [
23
  "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
 
26
  ]
27
 
28
  STATION_NAMES = {
29
+ "DT_0001": "인천", "DT_0002": "평택", "DT_0003": "영광", "DT_0008": "안산",
30
+ "DT_0017": "대산", "DT_0018": "군산", "DT_0024": "장항", "DT_0025": "보령",
31
+ "DT_0037": "어청도", "DT_0043": "영흥도", "DT_0050": "태안", "DT_0051": "서천마량",
32
+ "DT_0052": "인천송도", "DT_0065": "덕적도", "DT_0066": "향화도", "DT_0067": "안흥",
33
+ "DT_0068": "위도"
 
 
 
 
 
 
 
 
 
 
 
 
34
  }
35
 
36
+ def get_supabase_client():
37
+ """Supabase 클라이언트 생성"""
38
+ if not SUPABASE_AVAILABLE:
39
+ return None
40
+
41
+ try:
42
+ url = os.getenv("SUPABASE_URL")
43
+ key = os.getenv("SUPABASE_KEY")
44
+
45
+ if not url or not key:
46
+ print("Supabase 환경변수가 설정되지 않았습니다.")
47
+ return None
48
+
49
+ return create_client(url, key)
50
+ except Exception as e:
51
+ print(f"Supabase 연결 오류: {e}")
52
+ return None
53
+
54
+ def get_harmonic_predictions(station_id, start_time, end_time):
55
+ """해당 시간 범위의 조화 예측값 조회"""
56
+ supabase = get_supabase_client()
57
+ if not supabase:
58
+ return []
59
+
60
+ try:
61
+ result = supabase.table('harmonic_predictions')\
62
+ .select('predicted_at, harmonic_level')\
63
+ .eq('station_id', station_id)\
64
+ .gte('predicted_at', start_time.isoformat())\
65
+ .lte('predicted_at', end_time.isoformat())\
66
+ .order('predicted_at')\
67
+ .execute()
68
+
69
+ return result.data if result.data else []
70
+ except Exception as e:
71
+ print(f"조화 예측값 조회 오류: {e}")
72
+ return []
73
+
74
+ def calculate_final_tide(residual_predictions, station_id, last_time):
75
+ """잔차 예측 + 조화 예측 = 최종 조위 계산"""
76
+ # 미래 시간 범위 계산 (72개 포인트, 5분 간격)
77
+ start_time = last_time + timedelta(minutes=5)
78
+ end_time = last_time + timedelta(minutes=72*5)
79
+
80
+ # 조화 예측값 조회
81
+ harmonic_data = get_harmonic_predictions(station_id, start_time, end_time)
82
+
83
+ if not harmonic_data:
84
+ print("조화 예측 데이터를 찾을 수 없습니다. 잔차 예측만 반환합니다.")
85
+ return {
86
+ 'times': [last_time + timedelta(minutes=(i+1)*5) for i in range(len(residual_predictions))],
87
+ 'residual': residual_predictions.flatten(),
88
+ 'harmonic': [0] * len(residual_predictions),
89
+ 'final_tide': residual_predictions.flatten()
90
+ }
91
+
92
+ # 시간별 매칭
93
+ final_results = {
94
+ 'times': [],
95
+ 'residual': [],
96
+ 'harmonic': [],
97
+ 'final_tide': []
98
+ }
99
+
100
+ for i, residual in enumerate(residual_predictions.flatten()):
101
+ pred_time = last_time + timedelta(minutes=(i+1)*5)
102
+
103
+ # 가장 가까운 조화 예측값 찾기
104
+ harmonic_value = 0
105
+ for h_data in harmonic_data:
106
+ h_time = datetime.fromisoformat(h_data['predicted_at'].replace('Z', '+00:00'))
107
+ if abs((h_time - pred_time).total_seconds()) < 300: # 5분 이내
108
+ harmonic_value = h_data['harmonic_level']
109
+ break
110
+
111
+ final_tide = residual + harmonic_value
112
+
113
+ final_results['times'].append(pred_time)
114
+ final_results['residual'].append(residual)
115
+ final_results['harmonic'].append(harmonic_value)
116
+ final_results['final_tide'].append(final_tide)
117
+
118
+ return final_results
119
+
120
+ def save_predictions_to_supabase(station_id, prediction_results):
121
+ """예측 결과를 Supabase에 저장"""
122
+ supabase = get_supabase_client()
123
+ if not supabase:
124
+ return 0
125
+
126
+ try:
127
+ insert_data = []
128
+ for i in range(len(prediction_results['times'])):
129
+ insert_data.append({
130
+ 'station_id': station_id,
131
+ 'predicted_at': prediction_results['times'][i].isoformat(),
132
+ 'predicted_residual': float(prediction_results['residual'][i]),
133
+ 'harmonic_level': float(prediction_results['harmonic'][i]),
134
+ 'final_tide_level': float(prediction_results['final_tide'][i]),
135
+ 'model_version': 'TimeXer-v1'
136
+ })
137
+
138
+ result = supabase.table('tide_predictions')\
139
+ .upsert(insert_data, on_conflict='station_id,predicted_at,model_version')\
140
+ .execute()
141
+
142
+ return len(insert_data)
143
+ except Exception as e:
144
+ print(f"예측 결과 저장 오류: {e}")
145
+ return 0
146
+
147
  def get_common_args(station_id):
148
  return [
149
  "--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72",
 
192
  except Exception as e:
193
  raise gr.Error(f"내부 오류: {str(e)}")
194
 
195
+ def create_enhanced_prediction_plot(prediction_results, input_data, station_name):
196
+ """잔차 + 조화 + 최종 조위를 모두 표시하는 향상된 플롯"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  try:
198
+ # 입력 데이터에서 시간 정보 추출
199
+ input_df = pd.read_csv(input_data.name)
200
+ input_df['date'] = pd.to_datetime(input_df['date'])
201
+
202
+ # 최근 24포인트(2시간)만 표시
203
+ recent_data = input_df.tail(24)
204
 
205
+ # 미래 시간들
206
+ future_times = prediction_results['times']
 
 
 
 
207
 
208
  # 플롯 생성
209
  fig = go.Figure()
 
218
  marker=dict(size=4)
219
  ))
220
 
221
+ # 잔차 예측
222
  fig.add_trace(go.Scatter(
223
  x=future_times,
224
+ y=prediction_results['residual'],
225
  mode='lines+markers',
226
+ name='잔차 예측',
227
  line=dict(color='red', width=2, dash='dash'),
228
+ marker=dict(size=3)
229
+ ))
230
+
231
+ # 조화 예측
232
+ fig.add_trace(go.Scatter(
233
+ x=future_times,
234
+ y=prediction_results['harmonic'],
235
+ mode='lines',
236
+ name='조화 예측',
237
+ line=dict(color='orange', width=2)
238
+ ))
239
+
240
+ # 최종 조위 (잔차 + 조화)
241
+ fig.add_trace(go.Scatter(
242
+ x=future_times,
243
+ y=prediction_results['final_tide'],
244
+ mode='lines+markers',
245
+ name='최종 조위',
246
+ line=dict(color='green', width=3),
247
  marker=dict(size=4)
248
  ))
249
 
250
  # 구분선 추가
251
+ last_time = input_df['date'].iloc[-1]
252
  fig.add_vline(x=last_time, line_dash="dot", line_color="gray",
253
  annotation_text="예측 시작점")
254
 
255
  fig.update_layout(
256
+ title=f'{station_name} 통합 조위 예측 결과',
257
  xaxis_title='시간',
258
+ yaxis_title='수위 (cm)',
259
  hovermode='x unified',
260
+ height=600,
261
  showlegend=True,
262
  xaxis=dict(
263
+ tickformat='%H:%M<br>%m/%d'
264
  )
265
  )
266
 
267
+ return fig
 
 
 
268
 
269
+ except Exception as e:
270
+ print(f"Enhanced plot creation error: {e}")
271
+ # 에러 시 기본 플롯 반환
272
  fig = go.Figure()
273
+ fig.add_annotation(
274
+ text=f"시각화 생성 오류: {str(e)}",
275
+ xref="paper", yref="paper",
276
+ x=0.5, y=0.5, showarrow=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  )
278
+ return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  def single_prediction(station_id, input_csv_file):
281
  if input_csv_file is None:
 
304
  "--scaler_path", scaler_path,
305
  "--predict_input_file", input_csv_file.name] + common_args
306
 
307
+ gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
308
 
309
  # inference 실행
310
  success, output = execute_inference_and_get_results(command)
 
313
  try:
314
  prediction_file = "pred_results/prediction_future.npy"
315
  if os.path.exists(prediction_file):
316
+ residual_predictions = np.load(prediction_file)
317
+
318
+ # 입력 데이터에서 마지막 시간 추출
319
+ input_df = pd.read_csv(input_csv_file.name)
320
+ input_df['date'] = pd.to_datetime(input_df['date'])
321
+ last_time = input_df['date'].iloc[-1]
322
 
323
+ # 최종 조위 계산 (��차 + 조화)
324
+ prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
325
+
326
+ # 향상된 시각화 생성
327
+ plot = create_enhanced_prediction_plot(prediction_results, input_csv_file, station_name)
328
 
329
  # 예측 결과 테이블 생성
330
+ result_df = pd.DataFrame({
331
+ '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']],
332
+ '잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']],
333
+ '조화 예측 (cm)': [f"{val:.2f}" for val in prediction_results['harmonic']],
334
+ '최종 조위 (cm)': [f"{val:.2f}" for val in prediction_results['final_tide']]
335
+ })
336
+
337
+ # Supabase에 결과 저장
338
+ saved_count = save_predictions_to_supabase(station_id, prediction_results)
339
+ save_message = f"\n💾 Supabase에 {saved_count}개 예측 결과 저장 완료!" if saved_count > 0 else "\n⚠️ Supabase 저장 실패"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
+ return plot, result_df, f"✅ 통합 예측 완료!{save_message}\n\n{output}"
342
  else:
343
  return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
344
 
345
  except Exception as e:
346
+ print(f"Result processing error: {e}")
347
+ import traceback
348
+ traceback.print_exc()
349
  return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
350
+
351
  def rolling_evaluation(station_id, eval_csv_file):
352
  if eval_csv_file is None:
353
  raise gr.Error("평가를 위한 입력 파일을 업로드해주세요.")
 
407
  return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"
408
 
409
  # Gradio 인터페이스
410
+ with gr.Blocks(title="조위 예측 모델 v5.0", theme=gr.themes.Soft()) as demo:
411
+ gr.Markdown("# 🌊 통합 조위 예측 시스템 v5.0 (TimeXer + 조화분석)")
412
+ gr.Markdown("TimeXer 잔차 예측 + MATLAB 조화 예측을 결합한 최종 조위 예측 시스템입니다.")
413
+
414
+ # Supabase 연결 상태 표시
415
+ supabase_status = "🟢 연결됨" if get_supabase_client() else "🔴 연결 안됨"
416
+ gr.Markdown(f"**Supabase 상태**: {supabase_status}")
417
 
418
  with gr.Tabs():
419
+ with gr.TabItem("🔮 통합 조위 예측"):
420
  gr.Markdown("""
421
+ ### 🌟 새로운 기능
422
+ - **잔차 예측**: TimeXer 모델로 기상 영향 예측
423
+ - **조화 예측**: MATLAB 조화분석으로 천체 영향 예측
424
+ - **최종 조위**: 잔차 + 조화 = 완전한 조위 예측
425
+ - **자동 저장**: 예측 결과를 데이터베이스에 자동 저장
426
  """)
427
 
428
  with gr.Row():
 
439
  file_count="single"
440
  )
441
 
442
+ submit_btn1 = gr.Button("🚀 통합 예측 실행", variant="primary", size="lg")
443
 
444
  with gr.Row():
445
  with gr.Column(scale=2):
446
+ plot_output1 = gr.Plot(label="통합 예측 결과 시각화")
447
  with gr.Column(scale=1):
448
  table_output1 = gr.Dataframe(label="예측 결과 테이블")
449
 
450
  text_output1 = gr.Textbox(label="실행 로그", lines=5, show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  # 이벤트 바인딩
453
  submit_btn1.click(
 
456
  outputs=[plot_output1, table_output1, text_output1],
457
  show_progress=True
458
  )
 
 
 
 
 
 
459
 
460
  # 사용 안내
461
  gr.Markdown("""
462
+ ## 🎯 통합 예측 시스템 특징
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
 
464
+ ### 📊 4가지 예측 라인
465
+ - **파란색**: 실제 잔차조위 (과거 데이터)
466
+ - **빨간색**: 잔차 예측 (기상 영향)
467
+ - **주황색**: 조화 예측 (천체 영향)
468
+ - **초록색**: 최종 조위 (잔차 + 조화)
469
 
470
+ ### 💾 데이터 저장
471
+ 모든 예측 결과는 Supabase 데이터베이스에 자동 저장되어 이후 분석 및 API 활용이 가능합니다.
 
 
 
 
472
  """)
473
 
474
  if __name__ == "__main__":