Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| import os | |
| import numpy as np | |
| from datetime import datetime | |
| from dynamix.forecaster import DynaMixForecaster | |
| from dynamix.utilities import load_hf_model, auto_model_selection | |
| from dynamix.utilities import create_forecast_plot | |
# --- Gradio UI ---
# Layout only: builds the `demo` Blocks app and the module-level component
# handles (inputs, button, outputs) that the forecast callback is wired to.
with gr.Blocks(title="DynaMix 🧨 - Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DynaMix 🧨 - Forecasting")
    with gr.Row():
        # Left sidebar for configuration
        with gr.Column(scale=1):
            gr.Markdown("Upload your data or choose a preset, then generate forecasts.")

            # Data upload section
            gr.Markdown("## Data Selection")
            with gr.Group():
                file_input = gr.File(
                    file_types=[".csv", ".npy"],
                    label="Upload CSV / NPY",
                    height=200,
                )
                preset_dropdown = gr.Dropdown(
                    choices=["-- No preset selected --", "Noisy Sine", "Lorenz63", "Chua", "Selkov"],
                    value="-- No preset selected --",
                    label="Or choose a preset",
                    info="Select from predefined datasets",
                )

            # Forecast settings
            gr.Markdown("## Forecast Settings")
            with gr.Group():
                model_selection = gr.Dropdown(
                    choices=["Auto"],
                    value="Auto",
                    label="Model Selection",
                    info="Choose the DynaMix model to use for forecasting",
                )
                horizon_slider = gr.Slider(
                    minimum=1,
                    maximum=2001,
                    value=512,
                    step=100,
                    label="Forecast Length",
                    info="Choose how many future steps to forecast",
                )

            # Advanced settings
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                preprocessing_method = gr.Dropdown(
                    choices=["pos_embedding", "zero_embedding", "delay_embedding", "delay_embedding_random"],
                    value="pos_embedding",
                    label="Preprocessing Method",
                    info="Select the embedding method for time series with dimension < model dimension",
                )
                standardize = gr.Checkbox(
                    value=True,
                    label="Standardize",
                    info="Normalize the data to zero mean and unit variance",
                )
                fit_nonstationary = gr.Checkbox(
                    value=False,
                    label="Fit Nonstationary",
                    info="Account for non-stationary trends in the data",
                )
                # precision=0 makes gr.Number deliver an int. The forecast callback
                # slices the input array with this value, and a float slice index
                # raises TypeError.
                context_steps = gr.Number(
                    value=2048,
                    precision=0,
                    label="Context Steps",
                    info="Maximum number of steps to use as context from provided data (default: 2048)",
                )

            plot_btn = gr.Button("► Plot Forecasts", variant="primary", size="lg")

            gr.Markdown("# Instructions")
            # Blank lines separate the markdown paragraphs so each heading renders
            # on its own line instead of running into the following text.
            instructions = gr.Markdown("""
**📊 Data Format Requirements**

**NPY Files**: Shape: `(time_steps, dimensions)` or `(time_steps,)`\n
**CSV Files**: Each column = one dimension, each row = one time step

**⚡ Quick Start**

1. **Upload** a single dynamical system or time series (CSV or NPY)
2. **Configure** forecast length and settings
3. **Generate** predictions with "Plot Forecasts" (up to 15 dims of data are plotted)
4. **Download** the forecast as CSV or NPY
""")

        # Right section for plots and downloads
        with gr.Column(scale=3):
            gr.Markdown("# Forecast Plot")
            plot_output = gr.Plot(show_label=False)
            with gr.Row():
                csv_output = gr.File(label="Download Forecast CSV", visible=True)
                npy_output = gr.File(label="Download Forecast NPY", visible=True)
def load_preset_data(preset_name):
    """Return the path of the ``.npy`` file backing a preset, or ``None``.

    Preset files live in a ``data/`` folder. The lookup first tries ``data/``
    relative to the current working directory (the original behaviour). It then
    tries ``data/`` next to this script, so presets still load when the app is
    launched from a different directory.

    Args:
        preset_name: Value of the preset dropdown.

    Returns:
        Path to an existing preset file. Returns ``None`` when no preset is
        selected, the name is unknown, or the file cannot be found.
    """
    preset_files = {
        "Lorenz63": "lorenz63.npy",
        "Noisy Sine": "sine.npy",
        "Chua": "chua.npy",
        "Selkov": "selkov.npy",
    }
    file_name = preset_files.get(preset_name)
    if file_name is None:
        # Covers "-- No preset selected --" as well as unknown names.
        return None

    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (
        f"data/{file_name}",  # relative to the working directory
        os.path.join(script_dir, "data", file_name),  # next to this script
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None
# Upper bound on input dimensions accepted by the app.
_MAX_INPUT_DIMS = 100


class _InputError(ValueError):
    """Validation problem whose message is shown to the user verbatim."""


def _timestamp():
    """Return the current local time formatted for the console log lines."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def _read_input_values(file, preset_selection):
    """Load the raw series from the selected preset or the uploaded file.

    A selected preset takes precedence over an uploaded file.

    Args:
        file: Upload value from ``gr.File``. This is either a path string or an
            object exposing the path as ``.name``, or ``None`` when nothing was
            uploaded.
        preset_selection: Value of the preset dropdown.

    Returns:
        ``(values, column_names)``. ``column_names`` lists the numeric CSV
        columns that were kept, or is ``None`` for NPY input.

    Raises:
        _InputError: if no input is available, the format is unsupported, or a
            CSV file is in an unsupported layout or has no numeric columns.
    """
    preset_file_path = load_preset_data(preset_selection)
    if not file and not preset_file_path:
        if preset_selection and preset_selection != "-- No preset selected --":
            # A preset was chosen, but its data file could not be located.
            raise _InputError(f"Preset data for '{preset_selection}' could not be found.")
        raise _InputError("Please upload a file or select a preset.")

    if preset_file_path:
        file_path, ext = preset_file_path, ".npy"
    else:
        # Depending on the Gradio version the upload is a plain path string or a
        # wrapper whose ``.name`` holds the path.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        ext = os.path.splitext(file_path)[1].lower()

    if ext == ".csv":
        df = pd.read_csv(file_path)
        if "series_name" in df.columns:
            raise _InputError("Unsupported CSV format: only column-per-dimension format is supported.")
        # Keep only numeric columns; their names label the forecast CSV.
        df = df.select_dtypes(include=[np.number])
        if df.shape[1] == 0:
            raise _InputError("No numeric columns found in CSV file.")
        return df.to_numpy(), df.columns.tolist()
    if ext == ".npy":
        return np.load(file_path), None
    raise _InputError("Unsupported file format. Please upload .csv or .npy")


def _prepare_context(values, context_steps):
    """Validate the input and return the float32 context window.

    Args:
        values: Array-like loaded from the input file. A 1D series is treated as
            a single dimension.
        context_steps: Maximum number of trailing time steps to keep.
            ``gr.Number`` may deliver a float or ``None``. ``None`` or a
            non-positive value keeps the whole series, and at least 2 steps are
            always kept.

    Returns:
        A ``float32`` array of shape ``(time_steps, dimensions)``.

    Raises:
        _InputError: on wrong rank, fewer than 2 steps, no dimensions, too many
            dimensions, or non-finite values in the context window.
    """
    values = np.asarray(values)
    if values.ndim == 1:
        values = values.reshape(-1, 1)
    elif values.ndim != 2:
        raise _InputError("Input must be 2D with shape (time_steps, dimensions).")

    n_steps, n_dims = values.shape
    if n_steps < 2:
        raise _InputError("Input must contain at least 2 time steps.")
    if n_dims < 1:
        raise _InputError("Input must contain at least 1 dimension.")
    if n_dims > _MAX_INPUT_DIMS:
        raise _InputError(f"Too many dimensions: {n_dims} > 100. Reduce dimensions to ≤ 100.")

    if context_steps is not None:
        # Slice indices must be ints. A slice longer than the series keeps all of it.
        n_context = int(context_steps)
        if n_context > 0:
            values = values[-max(n_context, 2):]

    values = values.astype(np.float32)
    if not np.isfinite(values).all():
        # NaN/inf (e.g. gaps in a CSV, or values overflowing float32) would
        # propagate through the model.
        raise _InputError("Input contains NaN or infinite values.")
    return values


def _write_forecast_files(forecast, column_names):
    """Save the forecast as ``forecast.csv`` and ``forecast.npy`` in the working directory.

    Returns:
        ``(csv_path, npy_path)``.
    """
    csv_path = "forecast.csv"
    pd.DataFrame(forecast, columns=column_names).to_csv(csv_path, index=False)
    npy_path = "forecast.npy"
    np.save(npy_path, forecast)
    return csv_path, npy_path


def run_forecast(file, horizon, model_selection, preprocessing_method, standardize,
                 fit_nonstationary, context_steps, preset_selection):
    """Gradio callback: load the data, run DynaMix, and return plot and downloads.

    Args:
        file: Upload value from the file component (or ``None``).
        horizon: Number of future steps to forecast (cast to ``int``).
        model_selection: ``"Auto"`` or a model identifier for ``load_hf_model``.
        preprocessing_method: Embedding method forwarded to the forecaster.
        standardize: Forwarded to ``DynaMixForecaster.forecast``.
        fit_nonstationary: Forwarded to ``DynaMixForecaster.forecast``.
        context_steps: Maximum number of trailing input steps used as context.
        preset_selection: Value of the preset dropdown. A preset takes
            precedence over an uploaded file.

    Returns:
        ``(figure, csv_path, npy_path)`` on success, or ``(None, None, None)``
        on any error. Errors are shown to the user via ``gr.Warning`` and logged
        to the console; nothing is raised to Gradio.
    """
    import traceback

    try:
        values, column_names = _read_input_values(file, preset_selection)
        values = _prepare_context(values, context_steps)
        if column_names is None:
            # NPY input has no column names; generate dim_1..dim_N.
            column_names = [f"dim_{i + 1}" for i in range(values.shape[1])]
        context_ts_tensor = torch.tensor(values, dtype=torch.float32)

        # Load the selected model.
        if model_selection == "Auto":
            model_name = auto_model_selection(context_ts_tensor)
        else:
            model_name = model_selection
        current_model = load_hf_model(model_name)
        forecaster = DynaMixForecaster(current_model)

        # Warn against the loaded model's own dimension instead of a hard-coded
        # threshold of 3.
        if values.shape[1] > current_model.N:
            gr.Warning(f"Input dimension {values.shape[1]} > model dimension {current_model.N}. This may lead to performance degradation.")

        horizon = int(horizon)
        with torch.no_grad():
            reconstruction_ts = forecaster.forecast(
                context=context_ts_tensor,
                horizon=horizon,
                preprocessing_method=preprocessing_method,
                standardize=standardize,
                fit_nonstationary=fit_nonstationary,
            )
        reconstruction_ts_np = reconstruction_ts.cpu().numpy()

        fig = create_forecast_plot(values, reconstruction_ts_np, horizon)
        csv_path, npy_path = _write_forecast_files(reconstruction_ts_np, column_names)

        print(f"[{_timestamp()}] - Forecast completed successfully!")
        return fig, csv_path, npy_path
    except _InputError as e:
        # Expected validation problem: show the message, no traceback needed.
        gr.Warning(str(e))
        print(f"[{_timestamp()}] - Forecast error: {str(e)}")
    except Exception as e:
        # Unexpected failure (model download, forecasting, plotting, I/O):
        # tell the user something went wrong and keep the traceback in the logs.
        gr.Warning(f"Forecast failed: {e or type(e).__name__}")
        print(f"[{_timestamp()}] - Forecast error: {str(e)}")
        traceback.print_exc()
    return None, None, None
# Register the forecast callback on the button. Gradio only accepts event
# listeners inside a Blocks context, so re-enter `demo` instead of relying on
# this statement being nested inside the layout's `with` block.
with demo:
    plot_btn.click(
        run_forecast,
        # Order must match run_forecast's positional parameters.
        inputs=[
            file_input, horizon_slider, model_selection, preprocessing_method, standardize,
            fit_nonstationary, context_steps, preset_dropdown,
        ],
        outputs=[plot_output, csv_output, npy_output],
    )
def main():
    """Launch the DynaMix forecasting Gradio app."""
    demo.launch()


if __name__ == "__main__":
    main()