Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| import traceback | |
| from datetime import datetime, timedelta | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import pytz | |
| from config import STATION_NAMES | |
| from supabase_utils import ( | |
| get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client | |
| ) | |
| from preprocessing import preprocess_uploaded_file | |
| def get_common_args(station_id): | |
| return [ | |
| "--model", "TimeXer", "--features", "MS", "--seq_len", "144", "--pred_len", "72", | |
| "--label_len", "96", "--enc_in", "5", "--dec_in", "5", "--c_out", "1", | |
| "--d_model", "256", "--d_ff", "512", "--n_heads", "8", "--e_layers", "1", | |
| "--d_layers", "1", "--factor", "3", "--patch_len", "16", "--expand", "2", "--d_conv", "4" | |
| ] | |
| def validate_csv_file(file_path, required_rows=144): | |
| """CSV 파일 유효성 검사 - tide_level 또는 residual 지원""" | |
| try: | |
| df = pd.read_csv(file_path) | |
| # 기본 필수 컬럼 (tide_level 또는 residual 중 하나는 있어야 함) | |
| base_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp'] | |
| missing_base = [col for col in base_columns if col not in df.columns] | |
| if missing_base: | |
| return False, f"필수 컬럼이 누락되었습니다: {missing_base}" | |
| # tide_level 또는 residual 중 하나는 있어야 함 | |
| has_tide_level = 'tide_level' in df.columns | |
| has_residual = 'residual' in df.columns | |
| if not has_tide_level and not has_residual: | |
| return False, "tide_level 또는 residual 컬럼이 필요합니다." | |
| if len(df) < required_rows: | |
| return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행" | |
| data_type = "tide_level" if has_tide_level else "residual" | |
| return True, f"파일이 유효합니다. (데이터 형태: {data_type})" | |
| except Exception as e: | |
| return False, f"파일 읽기 오류: {str(e)}" | |
| def execute_inference_and_get_results(command): | |
| """inference 실행하고 결과 파일을 읽어서 반환""" | |
| try: | |
| print(f"실행 명령어: {' '.join(command)}") | |
| result = subprocess.run(command, capture_output=True, text=True, timeout=300) | |
| if result.returncode != 0: | |
| error_message = ( | |
| f"실행 실패 (Exit Code: {result.returncode}):\n\n" | |
| f"--- 에러 로그 ---\n{result.stderr}\n\n" | |
| f"--- 일반 출력 ---\n{result.stdout}" | |
| ) | |
| raise gr.Error(error_message) | |
| return True, result.stdout | |
| except subprocess.TimeoutExpired: | |
| raise gr.Error("실행 시간이 초과되었습니다. (5분 제한)") | |
| except Exception as e: | |
| raise gr.Error(f"내부 오류: {str(e)}") | |
| def calculate_final_tide(residual_predictions, station_id, last_time): | |
| """잔차 예측 + 조화 예측 = 최종 조위 계산""" | |
| if isinstance(last_time, pd.Timestamp): | |
| last_time = last_time.to_pydatetime() | |
| kst = pytz.timezone('Asia/Seoul') | |
| if last_time.tzinfo is None: | |
| last_time = kst.localize(last_time) | |
| start_time = last_time + timedelta(minutes=5) | |
| end_time = last_time + timedelta(minutes=72*5) | |
| harmonic_data = get_harmonic_predictions(station_id, start_time, end_time) | |
| residual_flat = residual_predictions.flatten() | |
| num_points = len(residual_flat) | |
| if not harmonic_data: | |
| print("조화 예측 데이터를 찾을 수 없습니다. 잔차 예측만 반환합니다.") | |
| return { | |
| 'times': [last_time + timedelta(minutes=(i+1)*5) for i in range(num_points)], | |
| 'residual': residual_flat.tolist(), | |
| 'harmonic': [0.0] * num_points, | |
| 'final_tide': residual_flat.tolist() | |
| } | |
| final_results = { | |
| 'times': [], | |
| 'residual': [], | |
| 'harmonic': [], | |
| 'final_tide': [] | |
| } | |
| harmonic_dict = {} | |
| for h_data in harmonic_data: | |
| h_time_str = h_data['predicted_at'] | |
| try: | |
| if 'T' in h_time_str: | |
| if h_time_str.endswith('Z'): | |
| h_time = datetime.fromisoformat(h_time_str[:-1] + '+00:00') | |
| elif '+' in h_time_str or '-' in h_time_str[-6:]: | |
| h_time = datetime.fromisoformat(h_time_str) | |
| else: | |
| h_time = datetime.fromisoformat(h_time_str + '+00:00') | |
| else: | |
| from dateutil import parser | |
| h_time = parser.parse(h_time_str) | |
| if h_time.tzinfo is None: | |
| h_time = pytz.UTC.localize(h_time) | |
| h_time = h_time.astimezone(kst) | |
| except Exception as e: | |
| print(f"시간 파싱 오류: {h_time_str}, {e}") | |
| continue | |
| minutes = (h_time.minute // 5) * 5 | |
| h_time = h_time.replace(minute=minutes, second=0, microsecond=0) | |
| harmonic_value = float(h_data['harmonic_level']) | |
| harmonic_dict[h_time] = harmonic_value | |
| for i, residual in enumerate(residual_flat): | |
| pred_time = last_time + timedelta(minutes=(i+1)*5) | |
| pred_time = pred_time.replace(second=0, microsecond=0) | |
| harmonic_value = harmonic_dict.get(pred_time, 0.0) | |
| if harmonic_value == 0.0 and harmonic_dict: | |
| min_diff = float('inf') | |
| for h_time, h_val in harmonic_dict.items(): | |
| diff = abs((h_time - pred_time).total_seconds()) | |
| if diff < min_diff and diff < 300: | |
| min_diff = diff | |
| harmonic_value = h_val | |
| final_tide = float(residual) + harmonic_value | |
| final_results['times'].append(pred_time) | |
| final_results['residual'].append(float(residual)) | |
| final_results['harmonic'].append(harmonic_value) | |
| final_results['final_tide'].append(final_tide) | |
| return final_results | |
| def create_enhanced_prediction_plot(prediction_results, input_data, station_name): | |
| """잔차 + 조화 + 최종 조위를 모두 표시하는 향상된 플롯""" | |
| try: | |
| input_df = pd.read_csv(input_data.name) | |
| input_df['date'] = pd.to_datetime(input_df['date']) | |
| recent_data = input_df.tail(24) | |
| future_times = pd.to_datetime(prediction_results['times']) | |
| fig = go.Figure() | |
| fig.add_trace(go.Scatter( | |
| x=recent_data['date'], | |
| y=recent_data['residual'], | |
| mode='lines+markers', | |
| name='실제 잔차조위', | |
| line=dict(color='blue', width=2), | |
| marker=dict(size=4) | |
| )) | |
| fig.add_trace(go.Scatter( | |
| x=future_times, | |
| y=prediction_results['residual'], | |
| mode='lines+markers', | |
| name='잔차 예측', | |
| line=dict(color='red', width=2, dash='dash'), | |
| marker=dict(size=3) | |
| )) | |
| if any(h != 0 for h in prediction_results['harmonic']): | |
| fig.add_trace(go.Scatter( | |
| x=future_times, | |
| y=prediction_results['harmonic'], | |
| mode='lines', | |
| name='조화 예측', | |
| line=dict(color='orange', width=2) | |
| )) | |
| fig.add_trace(go.Scatter( | |
| x=future_times, | |
| y=prediction_results['final_tide'], | |
| mode='lines+markers', | |
| name='최종 조위', | |
| line=dict(color='green', width=3), | |
| marker=dict(size=4) | |
| )) | |
| last_time = recent_data['date'].iloc[-1] | |
| fig.add_annotation( | |
| x=last_time, | |
| y=0, | |
| text="← 과거 | 미래 →", | |
| showarrow=False, | |
| yref="paper", | |
| yshift=10, | |
| font=dict(size=12, color="gray") | |
| ) | |
| fig.update_layout( | |
| title=f'{station_name} 통합 조위 예측 결과', | |
| xaxis_title='시간', | |
| yaxis_title='수위 (cm)', | |
| hovermode='x unified', | |
| height=600, | |
| showlegend=True, | |
| xaxis=dict(tickformat='%H:%M<br>%m/%d', gridcolor='lightgray', showgrid=True), | |
| yaxis=dict(gridcolor='lightgray', showgrid=True), | |
| plot_bgcolor='white' | |
| ) | |
| return fig | |
| except Exception as e: | |
| print(f"Enhanced plot creation error: {e}") | |
| traceback.print_exc() | |
| fig = go.Figure() | |
| fig.add_annotation( | |
| text=f"시각화 생성 중 오류: {str(e)}", | |
| xref="paper", yref="paper", | |
| x=0.5, y=0.5, showarrow=False | |
| ) | |
| return fig | |
| def single_prediction(station_id, input_csv_file): | |
| if input_csv_file is None: | |
| raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.") | |
| # 1. 초기 파일 검증 | |
| is_valid, message = validate_csv_file(input_csv_file.name) | |
| if not is_valid: | |
| raise gr.Error(f"파일 오류: {message}") | |
| station_name = STATION_NAMES.get(station_id, station_id) | |
| # 2. 전처리 수행 (tide_level → residual 변환 포함) | |
| gr.Info(f"📊 {station_name}({station_id}) 데이터 전처리 중...") | |
| processed_data, preprocess_result = preprocess_uploaded_file(input_csv_file.name, station_id) | |
| if processed_data is None: | |
| raise gr.Error(f"전처리 실패: {preprocess_result}") | |
| # 전처리 결과가 문자열(에러)인지 딕셔너리(성공)인지 확인 | |
| if isinstance(preprocess_result, str): | |
| raise gr.Error(f"전처리 오류: {preprocess_result}") | |
| # 전처리된 파일 경로 사용 | |
| processed_file_path = preprocess_result['output_file'] | |
| common_args = get_common_args(station_id) | |
| setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0" | |
| checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth" | |
| scaler_path = f"./checkpoints/{setting_name}/scaler.gz" | |
| if not os.path.exists(checkpoint_path): | |
| raise gr.Error(f"모델 파일을 찾을 수 없습니다: {checkpoint_path}") | |
| if not os.path.exists(scaler_path): | |
| raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}") | |
| # 전처리된 파일을 inference에 전달 | |
| command = ["python", "inference.py", | |
| "--checkpoint_path", checkpoint_path, | |
| "--scaler_path", scaler_path, | |
| "--predict_input_file", processed_file_path] + common_args | |
| gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...") | |
| success, output = execute_inference_and_get_results(command) | |
| try: | |
| prediction_file = "pred_results/prediction_future.npy" | |
| if os.path.exists(prediction_file): | |
| residual_predictions = np.load(prediction_file) | |
| # 전처리된 데이터 사용 | |
| input_df = processed_data | |
| last_time = input_df['date'].iloc[-1] | |
| prediction_results = calculate_final_tide(residual_predictions, station_id, last_time) | |
| # 플롯은 전처리된 데이터 파일을 사용 | |
| plot = create_enhanced_prediction_plot(prediction_results, type('obj', (object,), {'name': processed_file_path}), station_name) | |
| has_harmonic = any(h != 0 for h in prediction_results['harmonic']) | |
| if has_harmonic: | |
| result_df = pd.DataFrame({ | |
| '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']], | |
| '잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']], | |
| '조화 예측 (cm)': [f"{val:.2f}" for val in prediction_results['harmonic']], | |
| '최종 조위 (cm)': [f"{val:.2f}" for val in prediction_results['final_tide']] | |
| }) | |
| else: | |
| result_df = pd.DataFrame({ | |
| '예측 시간': [t.strftime('%Y-%m-%d %H:%M') for t in prediction_results['times']], | |
| '잔차 예측 (cm)': [f"{val:.2f}" for val in prediction_results['residual']] | |
| }) | |
| saved_count = save_predictions_to_supabase(station_id, prediction_results) | |
| if saved_count > 0: | |
| save_message = f"\n💾 Supabase에 {saved_count}개 예측 결과 저장 완료!" | |
| elif get_supabase_client() is None: | |
| save_message = "\n⚠️ Supabase 연결 실패 (환경변수 확인 필요)" | |
| else: | |
| save_message = "\n⚠️ Supabase 저장 실패" | |
| # 전처리 정보 추가 | |
| preprocess_info = f"""📊 전처리 결과: | |
| - 원본 데이터: {preprocess_result['original_rows']}행 | |
| - 처리 데이터: {preprocess_result['processed_rows']}행 | |
| - Residual 평균: {preprocess_result['residual_mean']:.2f}cm | |
| - Residual 표준편차: {preprocess_result['residual_std']:.2f}cm""" | |
| return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{preprocess_info}\n\n{output}" | |
| else: | |
| return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}" | |
| except Exception as e: | |
| print(f"Result processing error: {e}") | |
| traceback.print_exc() | |
| return None, None, f"❌ 결과 처리 중 오류: {str(e)}\n\n{output}" |