# Hugging Face Space app file (revision 776877d, original size 10,869 bytes).
import gradio as gr
import pandas as pd
import torch
import os
import numpy as np
from datetime import datetime
from dynamix.forecaster import DynaMixForecaster
from dynamix.utilities import load_hf_model, auto_model_selection
from dynamix.utilities import create_forecast_plot
# --- Forecasting helpers ---
# Dropdown entry meaning "no preset chosen"; also the preset dropdown's default.
_NO_PRESET = "-- No preset selected --"
# Preset name -> bundled .npy file (path relative to the working directory).
# Insertion order is also the order of the entries shown in the preset dropdown.
_PRESET_FILES = {
    "Noisy Sine": "data/sine.npy",
    "Lorenz63": "data/lorenz63.npy",
    "Chua": "data/chua.npy",
    "Selkov": "data/selkov.npy",
}
# Largest number of input dimensions the app accepts.
_MAX_DIMENSIONS = 100
# Fallback used when the loaded model has no ``N`` attribute.
# NOTE(review): assumed to be the model's dimension for released DynaMix models — confirm.
_DEFAULT_MODEL_DIM = 3


class _InputError(ValueError):
    """Invalid user input; a warning has already been shown in the UI."""


def _fail(message):
    """Show ``message`` as a Gradio warning, then raise it as ``_InputError``."""
    gr.Warning(message)
    raise _InputError(message)


def _log(message):
    """Print ``message`` to stdout, prefixed with the local timestamp."""
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{current_time}] - {message}")


def load_preset_data(preset_name):
    """Return the file path of a preset dataset, or None.

    None is returned for the placeholder entry, for unknown preset names, and
    when the preset's file does not exist on disk.
    """
    file_path = _PRESET_FILES.get(preset_name)
    if file_path is not None and os.path.exists(file_path):
        return file_path
    return None


def _resolve_input_path(file, preset_selection):
    """Return the data file to use; an available preset wins over an upload.

    Raises:
        _InputError: if neither a usable preset nor an upload is available.
    """
    preset_path = load_preset_data(preset_selection)
    if preset_path:
        return preset_path
    if file:
        # Gradio may pass a path string or an object exposing ``.name``.
        return getattr(file, "name", file)
    if preset_selection in _PRESET_FILES:
        _fail(f"Preset data for '{preset_selection}' was not found.")
    _fail("Please upload a file or select a preset.")


def _read_values(file_path):
    """Load a .csv or .npy file.

    Returns:
        ``(values, column_names)``, where ``column_names`` holds the numeric
        CSV column labels, or is None for .npy input.

    Raises:
        _InputError: for unsupported extensions or CSV layouts.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".csv":
        df = pd.read_csv(file_path)
        if "series_name" in df.columns:
            _fail("Unsupported CSV format: only column-per-dimension format is supported.")
        # Keep only numeric columns; each one is treated as a dimension.
        df = df.select_dtypes(include=[np.number])
        if df.shape[1] == 0:
            _fail("No numeric columns found in CSV file.")
        return df.to_numpy(), df.columns.tolist()
    if ext == ".npy":
        return np.load(file_path), None
    _fail("Unsupported file format. Please upload .csv or .npy")


def _validate_values(values):
    """Return ``values`` as a 2-D array of shape (time_steps, dimensions).

    1-D input is reshaped to a single column.

    Raises:
        _InputError: for other ranks, fewer than 2 steps, or a bad dimension count.
    """
    values = np.asarray(values)
    if values.ndim == 1:
        values = values.reshape(-1, 1)
    elif values.ndim != 2:
        _fail("Input must be 2D with shape (time_steps, dimensions).")
    n_steps, n_dims = values.shape
    if n_steps < 2:
        _fail("Input must contain at least 2 time steps.")
    if n_dims < 1:
        _fail("Input must contain at least 1 dimension.")
    if n_dims > _MAX_DIMENSIONS:
        _fail(
            f"Too many dimensions: {n_dims} > {_MAX_DIMENSIONS}. "
            f"Reduce dimensions to ≤ {_MAX_DIMENSIONS}."
        )
    return values


def _trim_context(values, context_steps):
    """Keep at most the last ``context_steps`` rows of ``values``.

    Raises:
        _InputError: if ``context_steps`` is missing or below 2.
    """
    # gr.Number can yield None (cleared field) or a float; slicing needs an int,
    # and 0 / negative values would otherwise silently keep the wrong rows.
    if context_steps is None or context_steps < 2:
        _fail("Context Steps must be at least 2.")
    n_ctx = int(context_steps)
    return values[-n_ctx:] if n_ctx < values.shape[0] else values


def run_forecast(file, horizon, model_selection, preprocessing_method, standardize,
                 fit_nonstationary, context_steps, preset_selection):
    """Gradio handler: load data, run a DynaMix forecast, and export it.

    Args:
        file: Uploaded file from ``gr.File`` (path string or object with ``.name``), or None.
        horizon: Number of future steps to forecast.
        model_selection: "Auto", or a model name passed to ``load_hf_model``.
        preprocessing_method: Forwarded to ``DynaMixForecaster.forecast``.
        standardize: Forwarded to ``DynaMixForecaster.forecast``.
        fit_nonstationary: Forwarded to ``DynaMixForecaster.forecast``.
        context_steps: Maximum number of trailing input steps used as context.
        preset_selection: Preset dropdown value; an available preset overrides the upload.

    Returns:
        ``(figure, csv_path, npy_path)`` on success, or ``(None, None, None)``
        on any error. Errors are shown to the user via ``gr.Warning`` and
        logged to stdout.
    """
    try:
        # 1. Load the data (preset takes precedence over an uploaded file)
        file_path = _resolve_input_path(file, preset_selection)
        raw_values, column_names = _read_values(file_path)

        # 2. Validate shape, restrict to the context window, check values
        values = _validate_values(raw_values)
        values = _trim_context(values, context_steps).astype(np.float32)
        if not np.isfinite(values).all():
            _fail("Input contains NaN or infinite values. Please clean the data first.")
        context_ts_tensor = torch.tensor(values, dtype=torch.float32)

        # 3. Load the selected model
        if model_selection == "Auto":
            model_name = auto_model_selection(context_ts_tensor)
        else:
            model_name = model_selection
        current_model = load_hf_model(model_name)
        forecaster = DynaMixForecaster(current_model)
        model_dim = getattr(current_model, "N", _DEFAULT_MODEL_DIM)
        if values.shape[1] > model_dim:
            gr.Warning(
                f"Input dimension {values.shape[1]} > model dimension {model_dim}. "
                "This may lead to performance degradation."
            )

        # 4. Run forecast
        horizon_steps = int(horizon)
        with torch.no_grad():
            reconstruction_ts = forecaster.forecast(
                context=context_ts_tensor,
                horizon=horizon_steps,
                preprocessing_method=preprocessing_method,
                standardize=standardize,
                fit_nonstationary=fit_nonstationary,
            )
        reconstruction_ts_np = reconstruction_ts.cpu().numpy()

        # 5. Create the plot
        fig = create_forecast_plot(values, reconstruction_ts_np, horizon_steps)

        # 6. Save the forecast (all dimensions) as CSV and NPY
        if column_names is None:
            column_names = [f"dim_{i + 1}" for i in range(values.shape[1])]
        csv_path = "forecast.csv"
        pd.DataFrame(reconstruction_ts_np, columns=column_names).to_csv(csv_path, index=False)
        npy_path = "forecast.npy"
        np.save(npy_path, reconstruction_ts_np)

        _log("Forecast completed successfully!")
        return fig, csv_path, npy_path
    except _InputError as e:
        # The user has already been warned by _fail; just log it.
        _log(f"Forecast error: {e}")
        return None, None, None
    except Exception as e:  # UI boundary: report instead of failing silently
        gr.Warning(f"Forecast failed: {e}")
        _log(f"Forecast error: {e}")
        return None, None, None


# --- Gradio UI ---
with gr.Blocks(title="DynaMix 🧨 - Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DynaMix 🧨 - Forecasting")
    with gr.Row():
        # Left sidebar for configuration
        with gr.Column(scale=1):
            gr.Markdown("Upload your data or choose a preset, then generate forecasts.")

            # Data upload section
            gr.Markdown("## Data Selection")
            with gr.Group():
                file_input = gr.File(
                    file_types=[".csv", ".npy"],
                    label="Upload CSV / NPY",
                    height=200,
                )
                preset_dropdown = gr.Dropdown(
                    choices=[_NO_PRESET, *_PRESET_FILES],
                    value=_NO_PRESET,
                    label="Or choose a preset",
                    info="Select from predefined datasets",
                )

            # Forecast settings
            gr.Markdown("## Forecast Settings")
            with gr.Group():
                model_selection = gr.Dropdown(
                    choices=["Auto"],
                    value="Auto",
                    label="Model Selection",
                    info="Choose the DynaMix model to use for forecasting",
                )
                horizon_slider = gr.Slider(
                    minimum=1,
                    maximum=2001,
                    value=512,
                    step=100,
                    label="Forecast Length",
                    info="Choose how many future steps to forecast",
                )

                # Advanced settings
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    preprocessing_method = gr.Dropdown(
                        choices=["pos_embedding", "zero_embedding", "delay_embedding", "delay_embedding_random"],
                        value="pos_embedding",
                        label="Preprocessing Method",
                        info="Select the embedding method for time series with dimension < model dimension",
                    )
                    standardize = gr.Checkbox(
                        value=True,
                        label="Standardize",
                        info="Normalize the data to zero mean and unit variance",
                    )
                    fit_nonstationary = gr.Checkbox(
                        value=False,
                        label="Fit Nonstationary",
                        info="Account for non-stationary trends in the data",
                    )
                    context_steps = gr.Number(
                        value=2048,
                        precision=0,  # deliver an int, usable directly as a slice bound
                        label="Context Steps",
                        info="Maximum number of steps to use as context from provided data (default: 2048)",
                    )

            plot_btn = gr.Button("► Plot Forecasts", variant="primary", size="lg")

            gr.Markdown("# Instructions")
            instructions = gr.Markdown("""
**📊 Data Format Requirements**

**NPY Files**: Shape: `(time_steps, dimensions)` or `(time_steps,)`

**CSV Files**: Each column = one dimension, each row = one time step

**⚡ Quick Start**

1. **Upload** a single dynamical system or time series (CSV or NPY)
2. **Configure** forecast length and settings
3. **Generate** predictions with "Plot Forecasts" (up to 15 dims of data are plotted)
4. **Download** the forecast as CSV or NPY
""")

        # Right section for plots and downloads
        with gr.Column(scale=3):
            gr.Markdown("# Forecast Plot")
            plot_output = gr.Plot(show_label=False)
            with gr.Row():
                csv_output = gr.File(label="Download Forecast CSV", visible=True)
                npy_output = gr.File(label="Download Forecast NPY", visible=True)

    # Event listeners must be registered inside the Blocks context.
    plot_btn.click(
        run_forecast,
        inputs=[
            file_input, horizon_slider, model_selection, preprocessing_method, standardize,
            fit_nonstationary, context_steps, preset_dropdown,
        ],
        outputs=[plot_output, csv_output, npy_output],
    )
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()