my-tide-env / prediction.py
alwaysgood's picture
Update prediction.py
7d0339b verified
import os
import subprocess
import traceback
from datetime import datetime, timedelta
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pytz
from config import STATION_NAMES
from supabase_utils import (
get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client
)
from preprocessing import preprocess_uploaded_file
def get_common_args(station_id):
return [
"--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72",
"--label_len", "96", "--enc_in", "5", "--dec_in", "5", "--c_out", "1",
"--d_model", "256", "--d_ff", "512", "--n_heads", "8", "--e_layers", "1",
"--d_layers", "1", "--factor", "3", "--patch_len", "16", "--expand", "2", "--d_conv", "4"
]
def validate_csv_file(file_path, required_rows=144):
"""CSV 파일 유효성 검사 - tide_level 또는 residual 지원"""
try:
df = pd.read_csv(file_path)
# 기본 필수 컬럼 (tide_level 또는 residual 중 하나는 있어야 함)
base_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp']
missing_base = [col for col in base_columns if col not in df.columns]
if missing_base:
return False, f"필수 컬럼이 누락되었습니다: {missing_base}"
# tide_level 또는 residual 중 하나는 있어야 함
has_tide_level = 'tide_level' in df.columns
has_residual = 'residual' in df.columns
if not has_tide_level and not has_residual:
return False, "tide_level 또는 residual 컬럼이 필요합니다."
if len(df) < required_rows:
return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행"
data_type = "tide_level" if has_tide_level else "residual"
return True, f"파일이 유효합니다. (데이터 형태: {data_type})"
except Exception as e:
return False, f"파일 읽기 오류: {str(e)}"
def execute_inference_and_get_results(command):
"""inference 실행하고 결과 파일을 읽어서 반환"""
try:
print(f"실행 명령어: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
error_message = (
f"실행 실패 (Exit Code: {result.returncode}):\n\n"
f"--- 에러 로그 ---\n{result.stderr}\n\n"
f"--- 일반 출력 ---\n{result.stdout}"
)
raise gr.Error(error_message)
return True, result.stdout
except subprocess.TimeoutExpired:
raise gr.Error("실행 시간이 초과되었습니다. (5분 제한)")
except Exception as e:
raise gr.Error(f"내부 오류: {str(e)}")
def calculate_final_tide(residual_predictions, station_id, last_time):
"""잔차 예측 + 조화 예측 = 최종 조위 계산"""
if isinstance(last_time, pd.Timestamp):
last_time = last_time.to_pydatetime()
kst = pytz.timezone('Asia/Seoul')
if last_time.tzinfo is None:
last_time = kst.localize(last_time)
start_time = last_time + timedelta(minutes=5)
end_time = last_time + timedelta(minutes=72*5)
harmonic_data = get_harmonic_predictions(station_id, start_time, end_time)
residual_flat = residual_predictions.flatten()
num_points = len(residual_flat)
if not harmonic_data:
print("조화 예측 데이터를 찾을 수 없습니다. 잔차 예측만 반환합니다.")
return {
'times': [last_time + timedelta(minutes=(i+1)*5) for i in range(num_points)],
'residual': residual_flat.tolist(),
'harmonic': [0.0] * num_points,
'final_tide': residual_flat.tolist()
}
final_results = {
'times': [],
'residual': [],
'harmonic': [],
'final_tide': []
}
harmonic_dict = {}
for h_data in harmonic_data:
h_time_str = h_data['predicted_at']
try:
if 'T' in h_time_str:
if h_time_str.endswith('Z'):
h_time = datetime.fromisoformat(h_time_str[:-1] + '+00:00')
elif '+' in h_time_str or '-' in h_time_str[-6:]:
h_time = datetime.fromisoformat(h_time_str)
else:
h_time = datetime.fromisoformat(h_time_str + '+00:00')
else:
from dateutil import parser
h_time = parser.parse(h_time_str)
if h_time.tzinfo is None:
h_time = pytz.UTC.localize(h_time)
h_time = h_time.astimezone(kst)
except Exception as e:
print(f"시간 파싱 오류: {h_time_str}, {e}")
continue
minutes = (h_time.minute // 5) * 5
h_time = h_time.replace(minute=minutes, second=0, microsecond=0)
harmonic_value = float(h_data['harmonic_level'])
harmonic_dict[h_time] = harmonic_value
for i, residual in enumerate(residual_flat):
pred_time = last_time + timedelta(minutes=(i+1)*5)
pred_time = pred_time.replace(second=0, microsecond=0)
harmonic_value = harmonic_dict.get(pred_time, 0.0)
if harmonic_value == 0.0 and harmonic_dict:
min_diff = float('inf')
for h_time, h_val in harmonic_dict.items():
diff = abs((h_time - pred_time).total_seconds())
if diff < min_diff and diff < 300:
min_diff = diff
harmonic_value = h_val
final_tide = float(residual) + harmonic_value
final_results['times'].append(pred_time)
final_results['residual'].append(float(residual))
final_results['harmonic'].append(harmonic_value)
final_results['final_tide'].append(final_tide)
return final_results
def create_enhanced_prediction_plot(prediction_results, input_data, station_name):
"""잔차 + 조화 + 최종 조위를 모두 표시하는 향상된 플롯"""
try:
input_df = pd.read_csv(input_data.name)
input_df['date'] = pd.to_datetime(input_df['date'])
recent_data = input_df.tail(24)
future_times = pd.to_datetime(prediction_results['times'])
fig = go.Figure()
fig.add_trace(go.Scatter(
x=recent_data['date'],
y=recent_data['residual'],
mode='lines+markers',
name='실제 잔차조위',
line=dict(color='blue', width=2),
marker=dict(size=4)
))
fig.add_trace(go.Scatter(
x=future_times,
y=prediction_results['residual'],
mode='lines+markers',
name='잔차 예측',
line=dict(color='red', width=2, dash='dash'),
marker=dict(size=3)
))
if any(h != 0 for h in prediction_results['harmonic']):
fig.add_trace(go.Scatter(
x=future_times,
y=prediction_results['harmonic'],
mode='lines',
name='조화 예측',
line=dict(color='orange', width=2)
))
fig.add_trace(go.Scatter(
x=future_times,
y=prediction_results['final_tide'],
mode='lines+markers',
name='최종 조위',
line=dict(color='green', width=3),
marker=dict(size=4)
))
last_time = recent_data['date'].iloc[-1]
fig.add_annotation(
x=last_time,
y=0,
text="← 과거 | 미래 →",
showarrow=False,
yref="paper",
yshift=10,
font=dict(size=12, color="gray")
)
fig.update_layout(
title=f'{station_name} 통합 조위 예측 결과',
xaxis_title='시간',
yaxis_title='수위 (cm)',
hovermode='x unified',
height=600,
showlegend=True,
xaxis=dict(tickformat='%H:%M<br>%m/%d', gridcolor='lightgray', showgrid=True),
yaxis=dict(gridcolor='lightgray', showgrid=True),
plot_bgcolor='white'
)
return fig
except Exception as e:
print(f"Enhanced plot creation error: {e}")
traceback.print_exc()
fig = go.Figure()
fig.add_annotation(
text=f"시각화 생성 중 오류: {str(e)}",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False
)
return fig
def single_prediction(station_id, input_csv_file):
if input_csv_file is None:
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
# 1. 초기 파일 검증
is_valid, message = validate_csv_file(input_csv_file.name)
if not is_valid:
raise gr.Error(f"파일 오류: {message}")
station_name = STATION_NAMES.get(station_id, station_id)
# 2. 전처리 수행 (tide_level → residual 변환 포함)
gr.Info(f"📊 {station_name}({station_id}) 데이터 전처리 중...")
processed_data, preprocess_result = preprocess_uploaded_file(input_csv_file.name, station_id)
if processed_data is None:
raise gr.Error(f"전처리 실패: {preprocess_result}")
# 전처리 결과가 문자열(에러)인지 딕셔너리(성공)인지 확인
if isinstance(preprocess_result, str):
raise gr.Error(f"전처리 오류: {preprocess_result}")
# 전처리된 파일 경로 사용
processed_file_path = preprocess_result['output_file']
common_args = get_common_args(station_id)
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
scaler_path = f"./checkpoints/{setting_name}/scaler.gz"
if not os.path.exists(checkpoint_path):
raise gr.Error(f"모델 파일을 찾을 수 없습니다: {checkpoint_path}")
if not os.path.exists(scaler_path):
raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}")
# 전처리된 파일을 inference에 전달
command = ["python", "inference.py",
"--checkpoint_path", checkpoint_path,
"--scaler_path", scaler_path,
"--predict_input_file", processed_file_path] + common_args
gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
success, output = execute_inference_and_get_results(command)
try:
prediction_file = "pred_results/prediction_future.npy"
if os.path.exists(prediction_file):
residual_predictions = np.load(prediction_file)
# 전처리된 데이터 사용
input_df = processed_data
last_time = input_df['date'].iloc[-1]
prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
# 플롯은 전처리된 데이터 파일을 사용
plot = create_enhanced_prediction_plot(prediction_results, type('obj', (object,), {'name': processed_file_path}), station_name)
has_harmonic = any(h != 0 for h in prediction_results['harmonic'])
if has_harmonic:
result_df = pd.DataFrame({
'예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']],
'잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']],
'조화 예측 (cm)': [f"{val:.2f}" for val in prediction_results['harmonic']],
'최종 조위 (cm)': [f"{val:.2f}" for val in prediction_results['final_tide']]
})
else:
result_df = pd.DataFrame({
'예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']],
'잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']]
})
saved_count = save_predictions_to_supabase(station_id, prediction_results)
if saved_count > 0:
save_message = f"\n💾 Supabase에 {saved_count}개 예측 결과 저장 완료!"
elif get_supabase_client() is None:
save_message = "\n⚠️ Supabase 연결 실패 (환경변수 확인 필요)"
else:
save_message = "\n⚠️ Supabase 저장 실패"
# 전처리 정보 추가
preprocess_info = f"""📊 전처리 결과:
- 원본 데이터: {preprocess_result['original_rows']}
- 처리 데이터: {preprocess_result['processed_rows']}
- Residual 평균: {preprocess_result['residual_mean']:.2f}cm
- Residual 표준편차: {preprocess_result['residual_std']:.2f}cm"""
return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{preprocess_info}\n\n{output}"
else:
return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
except Exception as e:
print(f"Result processing error: {e}")
traceback.print_exc()
return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}"