Dschobby committed on
Commit 776877d · verified · 1 Parent(s): c085019

Upload 14 files
.gitignore ADDED
@@ -0,0 +1,31 @@
+ example_evaluation.ipynb
+ test.py
+ forecast.csv
+ forecast.npy
+
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ .tox/
+ .nox/
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ .vscode/
README.md CHANGED
@@ -1,14 +1,66 @@
  ---
  title: DynaMix
- emoji: 🌖
- colorFrom: gray
- colorTo: yellow
  sdk: gradio
- sdk_version: 5.46.1
  app_file: app.py
  pinned: false
  license: cc-by-4.0
  short_description: Zero-shot forecasting of Dynamical Systems using DynaMix
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
  title: DynaMix
+ emoji: 🧨
+ colorFrom: blue
+ colorTo: red
  sdk: gradio
+ sdk_version: 5.43.1
  app_file: app.py
  pinned: false
  license: cc-by-4.0
  short_description: Zero-shot forecasting of Dynamical Systems using DynaMix
  ---

+ # DynaMix: Zero-shot Forecasting of Dynamical Systems
+
+ This DynaMix demo is an interactive Gradio app for zero-shot dynamical systems reconstruction using the DynaMix architecture (NeurIPS 2025 paper [![arXiv](https://img.shields.io/badge/arXiv-2505.13192-b31b1b.svg)](https://arxiv.org/abs/2505.13192)). It loads pretrained models from the Hugging Face Hub (see [DynaMix model](https://huggingface.co/DurstewitzLab/dynamix)) and generates forecasts from uploaded context data.
+
+ ### Key Features
+ - **Zero-shot forecasting**: Powered by the DynaMix model architecture
+ - **Custom Context Upload**: Upload your own CSV/NPY data or choose a preset (Lorenz63, Noisy Sine, Chua, Selkov)
+ - **Interactive Settings**: Configure forecast settings
+ - **Visualizations**: Plots of context data and forecast
+ - **Exports**: Download the forecast as CSV and NPY
+
+
+ ## Using the Application
+
+ ### Data Input
+ You can either upload your own data or choose a preset dataset from the left panel.
+
+ - **Upload**: Accepts `.csv` or `.npy` files
+ - **Presets**: `Noisy Sine`, `Lorenz63`, `Chua`, `Selkov`
+
+ Supported data formats:
+ - **NPY files**: NumPy array of shape `(time_steps, dimensions)`. 1D time series arrays are auto-expanded to `(time_steps, 1)`; otherwise the array must be 2D with at least 2 time steps and ≥1 dimension.
+
+ - **CSV files**: Each column is a dimension; each row is a time step. Only numeric columns are used. Data must be 2D with at least 2 time steps and ≥1 dimension.
+
+ Example CSV format:
+ ```csv
+ dim_1,dim_2,dim_3
+ 0.1,0.2,0.3
+ 0.4,0.5,0.6
+ 0.7,0.8,0.9
+ ```
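+
+ For reference, a minimal sketch of preparing a matching `.npy` upload with NumPy (array values and file name are illustrative):
+ ```python
+ import numpy as np
+
+ # 500 time steps of a 3-dimensional system -> shape (time_steps, dimensions)
+ context = np.random.randn(500, 3).astype(np.float32)
+ np.save("my_context.npy", context)  # upload my_context.npy in the app
+ ```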
+
+ ### Forecast Settings
+ - **Model Selection**: Select the pretrained model to use for forecasting.
+
+ - **Forecast Length**: Number of future steps to generate (`1`–`2001`, step `100`, default `512`)
+
+ - **Advanced Settings**
+   - **Preprocessing Method**: Embedding method used to preprocess the context data (for cases where input dims < model dims)
+   - **Standardize**: Normalize the context to zero mean and unit variance (default: enabled)
+   - **Fit Nonstationary**: Account for non-stationary trends in the data (default: disabled)
+   - **Context Steps**: Maximum number of trailing steps from the uploaded data to use as context. If your uploaded sequence is longer, it is truncated to the most recent `Context Steps`. (default `2048`)
+
+ ### Outputs
+ - **Interactive Plot**: Shows historical context (blue) and forecast (red) per dimension, up to 15 dimensions.
+ - **Files**:
+   - `forecast.csv`: Full forecast for all dimensions.
+   - `forecast.npy`: Full forecast ndarray including all dimensions.
+
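+ A minimal sketch of loading the downloaded outputs (file names as produced by the app; CSV columns are the input column names, or `dim_1, dim_2, ...` for NPY inputs):
+ ```python
+ import numpy as np
+ import pandas as pd
+
+ forecast = np.load("forecast.npy")         # shape (forecast_length, dimensions)
+ forecast_df = pd.read_csv("forecast.csv")  # one column per dimension
+ ```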
+
+ ## License
+ This project is released under the **CC BY 4.0** license.
app.py ADDED
@@ -0,0 +1,243 @@
+ import gradio as gr
+ import pandas as pd
+ import torch
+ import os
+ import numpy as np
+ from datetime import datetime
+
+ from dynamix.forecaster import DynaMixForecaster
+ from dynamix.utilities import load_hf_model, auto_model_selection
+ from dynamix.utilities import create_forecast_plot
+
+
+ # --- Gradio UI ---
+ with gr.Blocks(title="DynaMix 🧨 - Forecasting", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# DynaMix 🧨 - Forecasting")
+
+     with gr.Row():
+         # Left sidebar for configuration
+         with gr.Column(scale=1):
+             gr.Markdown("Upload your data or choose a preset, then generate forecasts.")
+
+             # Data upload section
+             gr.Markdown("## Data Selection")
+             with gr.Group():
+                 file_input = gr.File(
+                     file_types=[".csv", ".npy"],
+                     label="Upload CSV / NPY",
+                     height=200
+                 )
+
+                 preset_dropdown = gr.Dropdown(
+                     choices=["-- No preset selected --", "Noisy Sine", "Lorenz63", "Chua", "Selkov"],
+                     value="-- No preset selected --",
+                     label="Or choose a preset",
+                     info="Select from predefined datasets"
+                 )
+
+             # Forecast settings
+             gr.Markdown("## Forecast Settings")
+             with gr.Group():
+                 model_selection = gr.Dropdown(
+                     choices=["Auto"],
+                     value="Auto",
+                     label="Model Selection",
+                     info="Choose the DynaMix model to use for forecasting"
+                 )
+                 horizon_slider = gr.Slider(
+                     minimum=1,
+                     maximum=2001,
+                     value=512,
+                     step=100,
+                     label="Forecast Length",
+                     info="Choose how many future steps to forecast"
+                 )
+
+             # Advanced settings
+             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                 preprocessing_method = gr.Dropdown(
+                     choices=["pos_embedding", "zero_embedding", "delay_embedding", "delay_embedding_random"],
+                     value="pos_embedding",
+                     label="Preprocessing Method",
+                     info="Select the embedding method for time series with dimension < model dimension"
+                 )
+                 standardize = gr.Checkbox(
+                     value=True,
+                     label="Standardize",
+                     info="Normalize the data to zero mean and unit variance"
+                 )
+                 fit_nonstationary = gr.Checkbox(
+                     value=False,
+                     label="Fit Nonstationary",
+                     info="Account for non-stationary trends in the data"
+                 )
+                 context_steps = gr.Number(
+                     value=2048,
+                     label="Context Steps",
+                     info="Maximum number of steps to use as context from provided data (default: 2048)"
+                 )
+
+             plot_btn = gr.Button("► Plot Forecasts", variant="primary", size="lg")
+
+             gr.Markdown("# Instructions")
+             instructions = gr.Markdown("""
+             **📊 Data Format Requirements**
+
+             **NPY Files**: Shape: `(time_steps, dimensions)` or `(time_steps,)`\n
+             **CSV Files**: Each column = one dimension, each row = one time step
+
+             **⚡ Quick Start**
+             1. **Upload** a single dynamical system or time series (CSV or NPY)
+             2. **Configure** forecast length and settings
+             3. **Generate** predictions with "Plot Forecasts" (up to 15 dims of data are plotted)
+             4. **Download** the forecast as CSV or NPY
+             """)
+
+         # Right section for plots and downloads
+         with gr.Column(scale=3):
+             gr.Markdown("# Forecast Plot")
+             plot_output = gr.Plot(show_label=False)
+             with gr.Row():
+                 csv_output = gr.File(label="Download Forecast CSV", visible=True)
+                 npy_output = gr.File(label="Download Forecast NPY", visible=True)
+
+     def load_preset_data(preset_name):
+         """Load preset data from the data folder"""
+         if preset_name == "-- No preset selected --":
+             return None
+
+         preset_files = {
+             "Lorenz63": "data/lorenz63.npy",
+             "Noisy Sine": "data/sine.npy",
+             "Chua": "data/chua.npy",
+             "Selkov": "data/selkov.npy"
+         }
+
+         if preset_name in preset_files:
+             file_path = preset_files[preset_name]
+             if os.path.exists(file_path):
+                 return file_path
+         return None
+
+     def run_forecast(file, horizon, model_selection, preprocessing_method, standardize, fit_nonstationary, context_steps, preset_selection):
+         try:
+
+             # 1. Load the data
+             # Check if preset is selected
+             preset_file_path = load_preset_data(preset_selection)
+
+             if not file and not preset_file_path:
+                 gr.Warning("Please upload a file or select a preset.")
+                 raise ValueError("Please upload a file or select a preset.")
+
+             # Use preset file if selected, otherwise use uploaded file
+             if preset_file_path:
+                 file_path = preset_file_path
+                 ext = ".npy"
+             else:
+                 file_path = file.name
+                 ext = os.path.splitext(file.name)[1].lower()
+
+             # Load input file (.csv or .npy)
+             if ext == ".csv":
+                 df = pd.read_csv(file_path)
+
+                 if 'series_name' in df.columns:
+                     gr.Warning("Unsupported CSV format: only column-per-dimension format is supported.")
+                     raise ValueError("Unsupported CSV format: only column-per-dimension format is supported.")
+
+                 # Keep only numeric columns
+                 df = df.select_dtypes(include=[np.number]).copy()
+                 if df.shape[1] == 0:
+                     gr.Warning("No numeric columns found in CSV file.")
+                     raise ValueError("No numeric columns found in CSV file.")
+                 values = df.values
+             elif ext == ".npy":
+                 values = np.load(file_path)
+                 # Defer DataFrame creation until after shape validation (handles 1D arrays)
+                 df = None
+             else:
+                 gr.Warning("Unsupported file format. Please upload .csv or .npy")
+                 raise ValueError("Unsupported file format. Please upload .csv or .npy")
+
+             # 2. Validate shape and dimensions, then construct context
+             if not isinstance(values, np.ndarray):
+                 values = np.asarray(values)
+             if values.ndim != 2:
+                 if values.ndim == 1:
+                     values = np.reshape(values, (-1, 1))
+                 else:
+                     gr.Warning("Input must be 2D with shape (time_steps, dimensions).")
+                     raise ValueError("Input must be 2D with shape (time_steps, dimensions).")
+             if values.shape[0] < 2:
+                 gr.Warning("Input must contain at least 2 time steps.")
+                 raise ValueError("Input must contain at least 2 time steps.")
+             if values.shape[1] < 1:
+                 gr.Warning("Input must contain at least 1 dimension.")
+                 raise ValueError("Input must contain at least 1 dimension.")
+             if values.shape[1] > 100:
+                 gr.Warning(f"Too many dimensions: {values.shape[1]} > 100. Reduce dimensions to ≤ 100.")
+                 raise ValueError(f"Too many dimensions: {values.shape[1]} > 100. Reduce dimensions to ≤ 100.")
+             if int(context_steps) < values.shape[0]:
+                 values = values[-int(context_steps):]  # Use only the last n steps (gr.Number yields a float)
+             values = values.astype(np.float32)
+             context_ts_tensor = torch.tensor(values, dtype=torch.float32)
+
+             # 3. Load the selected model
+             if model_selection == "Auto":
+                 current_model = load_hf_model(auto_model_selection(context_ts_tensor))
+             else:
+                 current_model = load_hf_model(model_selection)
+             forecaster = DynaMixForecaster(current_model)
+             if values.shape[1] > current_model.N:
+                 gr.Warning(f"Input dimension {values.shape[1]} > model dimension {current_model.N}. This may lead to performance degradation.")
+
+             # 4. Run forecast
+             with torch.no_grad():
+                 reconstruction_ts = forecaster.forecast(
+                     context=context_ts_tensor,
+                     horizon=int(horizon),
+                     preprocessing_method=preprocessing_method,
+                     standardize=standardize,
+                     fit_nonstationary=fit_nonstationary,
+                 )
+             reconstruction_ts_np = reconstruction_ts.cpu().numpy()
+
+             # 5. Create Plotly figure
+             fig = create_forecast_plot(values, reconstruction_ts_np, horizon)
+
+             # 6. Save forecast as CSV (all dimensions) and NPY (all dimensions)
+             if df is None:
+                 # Create column names for NPY input after shape normalization
+                 df = pd.DataFrame(values, columns=[f"dim_{i+1}" for i in range(values.shape[1])])
+             forecast_df = pd.DataFrame(reconstruction_ts_np, columns=df.columns.tolist())
+             csv_path = "forecast.csv"
+             forecast_df.to_csv(csv_path, index=False)
+
+             # 7. Save full forecast as NPY (all dimensions)
+             npy_path = "forecast.npy"
+             np.save(npy_path, reconstruction_ts_np)
+
+             # 8. Print success notification with timestamp
+             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             print(f"[{current_time}] - Forecast completed successfully!")
+
+             return fig, csv_path, npy_path
+
+         except Exception as e:
+             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             print(f"[{current_time}] - Forecast error: {str(e)}")
+             return None, None, None
+
+     plot_btn.click(
+         run_forecast,
+         inputs=[
+             file_input, horizon_slider, model_selection, preprocessing_method, standardize,
+             fit_nonstationary, context_steps, preset_dropdown
+         ],
+         outputs=[plot_output, csv_output, npy_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
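The validation logic in `run_forecast` boils down to a small contract on the uploaded array; a condensed standalone sketch of that contract (function name illustrative, not defined in the commit):

```python
import numpy as np

def normalize_context(values, context_steps=2048):
    """Mirror run_forecast's checks: 2D (time_steps, dims), >=2 steps, <=100 dims."""
    values = np.asarray(values)
    if values.ndim == 1:
        values = values.reshape(-1, 1)  # 1D series -> (time_steps, 1)
    if values.ndim != 2 or values.shape[0] < 2 or not (1 <= values.shape[1] <= 100):
        raise ValueError("Expected 2D (time_steps, dimensions) with >=2 steps and <=100 dims")
    return values[-int(context_steps):].astype(np.float32)  # keep most recent steps
```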
data/chua.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc0a1090e555f13aab17aa70feca3dd0fe64f50edbd33849105a84ce86f08d11
+ size 24128
data/lorenz63.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aafcc2759e8981b44b4cc9f335967934647b20e27b1952d89fb0f371e1a835a6
+ size 48128
data/selkov.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2576908c5fd0e55267a41022f81b5b9c8f5f8fbec326a0977d8a702982ee4fef
+ size 1160
data/sine.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:429315428d533a103b6d772cbb2d5d341f0a73145bb196758717cf73b0a655b2
+ size 4224
dynamix/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """
+ Model module for Zero-shot DSR.
+ """
+
+ from .dynamix import *
+ from .preprocessing_utilities import *
+ from .preprocessing import *
+ from .forecaster import *
+ from .utilities import *
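Since the package re-exports its submodules at the top level, downstream imports can be flat; a minimal sketch (class names as defined in the modules below):

```python
from dynamix import DynaMix, DynaMixForecaster, DataPreprocessor
```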
dynamix/dynamix.py ADDED
@@ -0,0 +1,266 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ class GatingNetwork(nn.Module):
+     def __init__(self, N, M, Experts, dtype=torch.float32):
+         super().__init__()
+         self.conv = nn.Conv1d(N, N, kernel_size=2, padding=0, bias=True, dtype=dtype)
+         self.softmax_temp1 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
+         self.D = nn.Parameter(torch.zeros(N, M, dtype=dtype))
+         self.D.data[:, :N] = torch.eye(N, dtype=dtype)
+         self.mlp_layer1 = nn.Linear(M + N, Experts, dtype=dtype)
+         self.mlp_layer2 = nn.Linear(Experts, Experts, dtype=dtype)
+         self.softmax_temp2 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
+         self.sigma = nn.Parameter(torch.ones(N, dtype=dtype) * 0.05, requires_grad=True)
+
+     def forward(self, context, z, precomputed_cnn=None):
+         # context: (seq_length, batch_size, N)
+         # z: (M, batch_size)
+         # precomputed_cnn: Optional precomputed CNN features for inference (seq_length-1, batch_size, N)
+
+         seq_length, batch_size, N = context.shape
+         M = z.shape[0]
+
+         # Compute attention weights
+         z_obs = self.D @ z.detach()
+         z_current = z_obs + self.sigma.unsqueeze(1) * torch.randn(N, batch_size, dtype=z.dtype, device=z.device)
+
+         z_current_t = z_current.transpose(0, 1)
+         context_frames = context[:-1]
+
+         distances = torch.sum(torch.abs(context_frames - z_current_t.unsqueeze(0)), dim=2)
+         attention_weights = F.softmax(-distances / torch.abs(self.softmax_temp1[0]), dim=0)
+
+         # Process context with convolution
+         # Use precomputed CNN features if provided, otherwise compute them
+         if precomputed_cnn is not None:
+             encoded = precomputed_cnn
+         else:
+             context_for_conv = context.permute(1, 2, 0)
+             encoded = self.conv(context_for_conv)
+             encoded = encoded.permute(2, 0, 1)
+
+         # Build weighted embedding
+         weighted_encoded = encoded * attention_weights.unsqueeze(2)
+         embedding = torch.sum(weighted_encoded, dim=0)
+         embedding = embedding.transpose(0, 1)
+
+         # Predict expert weights
+         combined = torch.cat([embedding, z], dim=0)
+         combined_t = combined.transpose(0, 1)
+         mlp_output = self.mlp_layer2(F.relu(self.mlp_layer1(combined_t)))
+         w_exp = F.softmax(-mlp_output.transpose(0, 1) / torch.abs(self.softmax_temp2[0]), dim=0)
+         return w_exp
+
+     def gaussian_init(self, M, N, dtype=torch.float32):
+         return torch.randn(M, N, dtype=dtype) * 0.01
+
+ class ExpertNetwork(nn.Module):
+     """Base class for different expert architectures."""
+     def __init__(self, M, P=0, probabilistic=False, dtype=torch.float32):
+         super().__init__()
+         self.M = M
+         self.P = P
+         self.probabilistic = probabilistic
+         self.dtype = dtype
+
+         # Parameter for probabilistic experts
+         if probabilistic:
+             self.sigma = nn.Parameter(torch.ones(1, dtype=dtype) * 0.05, requires_grad=True)
+
+     def forward(self, z):
+         raise NotImplementedError("Subclasses must implement forward method")
+
+     def add_noise(self, z):
+         """Add stochasticity to the latent state if in probabilistic mode.
+
+         Args:
+             z: Input tensor
+         """
+         if self.probabilistic:
+             batch_size = z.shape[1]
+             noise = torch.randn(self.M, batch_size, dtype=z.dtype, device=z.device)
+             return z + self.sigma * noise
+         return z
+
+     def gaussian_init(self, M, N):
+         return torch.randn(M, N, dtype=self.dtype) * 0.01
+
+     def normalized_positive_definite(self, M):
+         R = np.random.randn(M, M).astype(np.float32)
+         K = R.T @ R / M + np.eye(M)
+         lambd = np.max(np.abs(np.linalg.eigvals(K)))
+         return K / lambd
+
+ class AlmostLinearRNN(ExpertNetwork):
+     """Almost linear RNN expert architecture."""
+     def __init__(self, M, P, probabilistic=False, dtype=torch.float32):
+         super().__init__(M, P, probabilistic, dtype=dtype)
+         self.A, self.W, self.h = self.initialize_A_W_h(M)
+
+     def forward(self, z):
+         # z: (M, batch_size)
+         # Split z into regular and ReLU parts
+         z1 = z[:-self.P, :]
+         z2 = F.relu(z[-self.P:, :])
+         zcat = torch.cat([z1, z2], dim=0)
+
+         output = self.A.unsqueeze(-1) * z + self.W @ zcat + self.h.unsqueeze(-1)
+
+         # Add stochasticity if probabilistic
+         if self.probabilistic:
+             output = self.add_noise(output)
+
+         return output
+
+     def initialize_A_W_h(self, M):
+         A = torch.nn.Parameter(torch.diag(torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)))
+         W = torch.nn.Parameter(self.gaussian_init(M, M))
+         h = torch.nn.Parameter(torch.zeros(M, dtype=self.dtype))
+         return A, W, h
+
+ class ClippedShallowPLRNN(ExpertNetwork):
+     """Clipped shallow PLRNN expert architecture."""
+     def __init__(self, M, hidden_dim=50, probabilistic=False, dtype=torch.float32):
+         super().__init__(M, hidden_dim, probabilistic, dtype=dtype)
+         self.A = torch.nn.Parameter(torch.diag(torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)))
+         self.W1 = torch.nn.Parameter(self.gaussian_init(M, hidden_dim))
+         self.W2 = torch.nn.Parameter(self.gaussian_init(hidden_dim, M))
+         self.h1 = torch.nn.Parameter(torch.zeros(M, dtype=self.dtype))
+         self.h2 = torch.nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype))
+
+     def forward(self, z):
+         # z: (M, batch_size)
+         W2z = self.W2 @ z
+         output = (self.A.unsqueeze(-1) * z +
+                   self.W1 @ (F.relu(W2z + self.h2.unsqueeze(-1)) - F.relu(W2z)) +
+                   self.h1.unsqueeze(-1))
+
+         # Add stochasticity if probabilistic
+         if self.probabilistic:
+             output = self.add_noise(output)
+
+         return output
+
+ class DynaMix(nn.Module):
+     def __init__(self, M, N, Experts, P=2, hidden_dim=50, expert_type="almost_linear_rnn",
+                  probabilistic_expert=False, dtype=torch.float32):
+         """
+         Initialize a DynaMix model.
+
+         Args:
+             M: Dimension of latent state
+             N: Dimension of observation space
+             Experts: Number of experts
+             P: Number of ReLU dimensions
+             hidden_dim: Hidden dimension for clipped shallow PLRNN
+             expert_type: Type of expert to use ("almost_linear_rnn" or "clipped_shallow_plrnn")
+             probabilistic_expert: Whether to use probabilistic experts
+             dtype: Data type for model parameters (default: torch.float32)
+         """
+         super().__init__()
+
+         self.expert_type = expert_type
+         self.probabilistic_expert = probabilistic_expert
+         self.experts = nn.ModuleList()
+         self.dtype = dtype
+
+         for _ in range(Experts):
+             if expert_type == "almost_linear_rnn":
+                 self.experts.append(AlmostLinearRNN(M, P, probabilistic=probabilistic_expert, dtype=dtype))
+             elif expert_type == "clipped_shallow_plrnn":
+                 self.experts.append(ClippedShallowPLRNN(M, hidden_dim, probabilistic=probabilistic_expert, dtype=dtype))
+             else:
+                 raise ValueError(f"Unknown expert type: {expert_type}")
+
+         self.gating_network = GatingNetwork(N, M, Experts, dtype=dtype)
+         self.B = nn.Parameter(self.uniform_init((N, M), dtype=dtype))
+         self.N = N
+         self.Experts = Experts
+         self.P = P
+         self.hidden_dim = hidden_dim
+         self.M = M
+
+     def step(self, z, context, precomputed_cnn=None):
+         # z: (M, batch_size)
+         # context: (seq_length, batch_size, N)
+         # precomputed_cnn: Optional precomputed CNN features
+
+         # Compute expert weights
+         w_exp = self.gating_network(context, z, precomputed_cnn=precomputed_cnn)  # (Experts, batch_size)
+         results = []
+
+         # Compute expert outputs
+         for i in range(self.Experts):
+             expert_output = self.experts[i](z)
+             results.append(expert_output * w_exp[i, :].unsqueeze(0))
+
+         # Combine expert outputs
+         return torch.sum(torch.stack(results, dim=0), dim=0)
+
+     def forward(self, z, context, precomputed_cnn=None):
+         """
+         Forward pass through the DynaMix model.
+
+         Args:
+             z: Latent state of shape (M, batch_size)
+             context: Context data of shape (seq_length, batch_size, N)
+             precomputed_cnn: Optional precomputed CNN features to avoid redundant computation for inference
+                 Shape should be (seq_length-1, batch_size, N)
+
+         Returns:
+             Updated latent state
+         """
+         return self.step(z, context, precomputed_cnn=precomputed_cnn)
+
+     def precompute_cnn(self, context):
+         """
+         Precompute CNN features for more efficient inference.
+
+         Args:
+             context: Context data of shape (seq_length, batch_size, N)
+
+         Returns:
+             Precomputed CNN features of shape (seq_length-1, batch_size, N)
+         """
+         # Process context with convolution
+         context_for_conv = context.permute(1, 2, 0)
+         encoded = self.gating_network.conv(context_for_conv)
+
+         return encoded.permute(2, 0, 1)
+
+     def uniform_init(self, shape, dtype=torch.float32):
+         din = shape[-1]
+         r = 1 / np.sqrt(din)
+         return (torch.rand(shape, dtype=dtype) * 2 - 1) * r
+
+     def gaussian_init(self, M, N):
+         return torch.randn(M, N, dtype=self.dtype) * 0.01
+
+ def print_model_parameters(model):
+     """Print simplified breakdown of model parameters by component."""
+     total_params = sum(p.numel() for p in model.parameters())
+
+     print("\n" + "-"*60)
+     print("Model Parameter Summary:")
+     print(f" Architecture: DynaMix with {model.expert_type} experts")
+     if model.expert_type == "almost_linear_rnn":
+         print(f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, P={model.P}")
+     else:
+         print(f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, Hidden dim={model.hidden_dim}")
+     print(f" Probabilistic experts: {model.probabilistic_expert}")
+
+     # Count parameters
+     gating_params = sum(p.numel() for p in model.gating_network.parameters())
+     expert_params = sum(p.numel() for expert in model.experts for p in expert.parameters())
+     b_params = model.B.numel()
+
+     # Print parameter counts
+     print(f"\nParameter counts:")
+     print(f" Gating Network: {gating_params:,} ({gating_params/total_params:.1%})")
+     print(f" Experts: {expert_params:,} ({expert_params/total_params:.1%})")
+     print(f" Observation matrix: {b_params:,} ({b_params/total_params:.1%})")
+     print(f" Total: {total_params:,} parameters")
+     print("-"*60)
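A minimal shape-check sketch for a single DynaMix latent step (dimensions illustrative; `forward` delegates to `step` as defined above):

```python
import torch
from dynamix.dynamix import DynaMix, print_model_parameters

model = DynaMix(M=8, N=3, Experts=4)  # latent dim 8, observation dim 3, 4 experts
context = torch.randn(64, 1, 3)       # (seq_length, batch_size, N)
z = torch.randn(8, 1)                 # (M, batch_size)
z_next = model(z, context)            # one latent step -> shape (8, 1)
print_model_parameters(model)         # parameter breakdown by component
```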
dynamix/forecaster.py ADDED
@@ -0,0 +1,199 @@
+ import torch
+ import torch.nn.functional as F
+ from .preprocessing import DataPreprocessor
+
+
+ class DynaMixForecaster:
+     """
+     Forecasting pipeline for DynaMix models with batch processing support.
+     """
+     def __init__(self, model):
+         """
+         Initialize the forecaster with a DynaMix model.
+
+         Args:
+             model: DynaMix model instance
+         """
+         self.model = model
+
+     def _init_latent_state(self, initial_condition):
+         """
+         Initialize the latent state from the initial condition.
+
+         Args:
+             initial_condition: Initial state of shape (batch_size, N)
+
+         Returns:
+             Initial latent state z
+         """
+         N = self.model.N
+
+         # Initialize latent state
+         z = torch.matmul(initial_condition, self.model.B).t()  # (M, batch_size)
+         z[:N, :] = initial_condition.t()
+
+         return z
+
+     def _reshape_for_model(self, context, initial_x=None, device=None):
+         """
+         Prepare and reshape input data for the model.
+         Handles tensor conversion, dimension adjustments, and reshaping when feature_dim > model_dim.
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, feature_dim) or (seq_length, feature_dim)
+             initial_x: Optional initial condition of shape (batch_size, feature_dim) or (feature_dim,)
+             device: Device to place tensors on
+
+         Returns:
+             Processed context, initial_x, dimensions, and reshaping metadata
+         """
+         # Get the dtype from model parameters
+         model_dtype = next(self.model.parameters()).dtype
+
+         # Convert to torch tensor if needed
+         if not isinstance(context, torch.Tensor):
+             context = torch.tensor(context, dtype=model_dtype, device=device)
+         elif context.device != device or context.dtype != model_dtype:
+             context = context.to(device=device, dtype=model_dtype)
+
+         if initial_x is not None and not isinstance(initial_x, torch.Tensor):
+             initial_x = torch.tensor(initial_x, dtype=model_dtype, device=device)
+         elif initial_x is not None and (initial_x.device != device or initial_x.dtype != model_dtype):
+             initial_x = initial_x.to(device=device, dtype=model_dtype)
+
+         # Check data dimensions and reshape if needed
+         original_dim = context.dim()
+         if original_dim == 2:
+             context = context.unsqueeze(1)  # (seq_length, feature_dim) -> (seq_length, 1, feature_dim)
+         elif original_dim != 3:
+             raise ValueError(f"Expected 2D or 3D tensor for context, got shape {context.shape} with {context.dim()} dimensions")
+         if initial_x is not None and initial_x.dim() == 1:
+             initial_x = initial_x.unsqueeze(0)  # (feature_dim,) -> (1, feature_dim)
+         if initial_x is not None and initial_x.shape[1] != context.shape[2]:
+             raise ValueError(f"Initial condition has {initial_x.shape[1]} features, but context has {context.shape[2]} features")
+
+         # Data shape
+         seq_length, batch_size, feature_dim = context.shape
+
+         # Check if reshaping is needed for model dimension
+         if feature_dim <= self.model.N:
+             return context, initial_x, (batch_size, feature_dim, False, None, None, original_dim)
+
+         print(f"Warning: Input feature dimension {feature_dim} exceeds model dimension {self.model.N}. "
+               f"This may lead to performance degradation. "
+               f"Reshaping data to treat each feature as a separate time series.")
+
+         # Store original dimensions for reshaping back later
+         original_batch_size = batch_size
+         original_feature_dim = feature_dim
+
+         # Reshape context to (seq_length, batch_size * feature_dim, 1)
+         transposed = context.permute(0, 2, 1)
+         new_batch_size = batch_size * feature_dim
+         reshaped_context = transposed.reshape(seq_length, new_batch_size, 1)
+
+         # Similarly reshape initial_x if provided
+         reshaped_initial_x = initial_x
+         if initial_x is not None:
+             # Reshape from (batch_size, feature_dim) to (batch_size * feature_dim, 1)
+             reshaped_initial_x = initial_x.transpose(0, 1).reshape(new_batch_size, 1)
+
+         return reshaped_context, reshaped_initial_x, (new_batch_size, 1, True, original_batch_size, original_feature_dim, original_dim)
+
+     def _reshape_to_original(self, output, reshape_metadata):
+         """
+         Reshape output back to original dimensions.
+         Handles both high-dimensional reshaping and 2D input restoration.
+
+         Args:
+             output: Model output of shape (T, batch_size, N)
+             reshape_metadata: Tuple containing (was_reshaped, original_batch_size, original_feature_dim, original_dim)
+
+         Returns:
+             Output with original shape restored
+         """
+         _, _, was_reshaped, original_batch_size, original_feature_dim, original_dim = reshape_metadata
+
+         # Step 1: Reshape back to original dimensions if needed
+         if was_reshaped:
+             # Current shape: (T, batch_size=original_batch_size*original_feature_dim, 1)
+             T = output.shape[0]
+
+             # First reshape to (T, original_feature_dim, original_batch_size)
+             # by treating the batch dimension as (original_feature_dim, original_batch_size)
+             reshaped = output.reshape(T, original_feature_dim, original_batch_size, -1)
+
+             # Then permute to (T, original_batch_size, original_feature_dim)
+             output = reshaped.permute(0, 2, 1, 3).squeeze(-1)
+
+         # Step 2: If input was 2D, remove batch dimension from output
+         if original_dim == 2 and output.shape[1] == 1:
+             output = output.squeeze(1)
+
+         return output
+
+     @torch.no_grad()
+     def forecast(self, context, horizon, preprocessing_method="pos_embedding",
+                  standardize=True, fit_nonstationary=False, initial_x=None):
+         """
+         Efficient batched forecasting with the DynaMix model.
+
+         This method implements a complete forecasting pipeline including:
+         - Data preprocessing (Box-Cox, detrending, standardization)
+         - Embedding techniques for dimensionality matching
+         - DynaMix model prediction
+         - Data postprocessing (inverse transformations)
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, feature_dim) or (seq_length, feature_dim)
+             horizon: Forecast horizon (number of steps to predict)
+             preprocessing_method: Data preprocessing method ('pos_embedding', 'zero_embedding',
+                 'delay_embedding', or 'delay_embedding_random') (default: 'pos_embedding')
+             standardize: Whether to standardize the data (default: True)
+             fit_nonstationary: Whether to fit a non-stationary time series (default: False)
+             initial_x: Optional initial condition of shape (batch_size, feature_dim) or (feature_dim,)
+
+         Returns:
+             Predicted sequence of shape (horizon, batch_size, feature_dim)
+         """
+         # Get model dimensions
+         M = self.model.M
+         N = self.model.N
+         device = context.device if isinstance(context, torch.Tensor) else self.model.B.device
+         model_dtype = next(self.model.parameters()).dtype
+
+         # Apply context reshaping if needed
+         context, initial_x, shape_metadata = self._reshape_for_model(context, initial_x, device)
+
+         # Create data preprocessor
+         preprocessor = DataPreprocessor(
+             standardize=standardize,
+             box_cox=fit_nonstationary,
+             detrending=fit_nonstationary,
+             preprocessing_method=preprocessing_method
+         )
+
+         # Step 1: Apply preprocessing pipeline
+         context_embedded, initial_condition = preprocessor.preprocess(context, self.model.N, initial_x)
+
+         # Step 2: Initialize latent state
+         z = self._init_latent_state(initial_condition)
+
+         # Step 3: Perform forecasting loop
+         Z_gen = torch.empty(horizon, M, shape_metadata[0], device=device, dtype=model_dtype)
+         with torch.amp.autocast(device_type='cuda' if device.type == 'cuda' else 'cpu', enabled=device.type == 'cuda'):
+             precomputed_cnn = self.model.precompute_cnn(context_embedded)
+             for t in range(horizon):
+                 z = self.model(z, context_embedded, precomputed_cnn=precomputed_cnn)
+                 Z_gen[t] = z
+
+         # Step 4: Apply observation generation
+         output = Z_gen[:, :shape_metadata[1], :].permute(0, 2, 1)  # (horizon, batch_size, feature_dim)
+
+         # Step 5: Apply inverse data transformations (e.g. standardization, ...)
+         output = preprocessor.postprocess(output)
+
+         # Step 6: Reshape back to original dimensions if needed
+         output = self._reshape_to_original(output, shape_metadata)
+
+         return output
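A minimal end-to-end sketch of the forecasting pipeline, mirroring how `app.py` wires it up (`load_hf_model` and `auto_model_selection` come from `dynamix.utilities`, which is not shown in this excerpt):

```python
import torch
from dynamix.forecaster import DynaMixForecaster
from dynamix.utilities import load_hf_model, auto_model_selection

context = torch.randn(512, 3)  # (seq_length, feature_dim); 2D input implies batch_size 1
model = load_hf_model(auto_model_selection(context))  # pick and load a pretrained model
forecaster = DynaMixForecaster(model)
forecast = forecaster.forecast(context, horizon=256)  # -> (256, 3) for 2D input
```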
dynamix/preprocessing.py ADDED
@@ -0,0 +1,262 @@
+ import torch
+ import numpy as np
+ from .preprocessing_utilities import (TimeSeriesProcessor, Embedding,
+                                       BoxCoxTransformer, Detrending, estimate_initial_condition)
+
+
+ class DataPreprocessor:
+     """
+     Main class for data preprocessing that orchestrates all transformations.
+     """
+     def __init__(self, standardize=True, box_cox=False, detrending=False, preprocessing_method="pos_embedding"):
+         """
+         Initialize the data preprocessor.
+
+         Args:
+             standardize: Whether to standardize the data
+             box_cox: Whether to apply Box-Cox transformation
+             detrending: Whether to apply exponential detrending
+             preprocessing_method: Method for embedding ('pos_embedding', 'zero_embedding',
+                 'delay_embedding', 'delay_embedding_random')
+         """
+         self.standardize = standardize
+         self.box_cox = box_cox
+         self.detrending = detrending
+         self.preprocessing_method = preprocessing_method
+
+         # Parameters for inverse transformations
+         self.box_cox_params_list = None
+         self.detrending_params_list = None
+         self.context_mean = None
+         self.context_std = None
+         self.original_context = None
+         self.batch_size = None
+         self.feature_dim = None
+
+     def _apply_transformations(self, context):
+         """
+         Apply Box-Cox transformation and/or detrending to each batch in the context data.
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, N_data)
+
+         Returns:
+             Transformed context data
+         """
+         # Store original context for inverse transformations
+         self.original_context = context.clone()
+
+         # Apply Box-Cox transformation for each batch
+         if self.box_cox:
+             transformed_context = torch.zeros_like(context)
+             self.box_cox_params_list = []
+
+             for b in range(self.batch_size):
+                 batch_context = context[:, b, :]
+                 transformed, params = BoxCoxTransformer.transform(batch_context)
+                 transformed_context[:, b, :] = transformed
+                 self.box_cox_params_list.append(params)
+
+             context = transformed_context
+
+         # Apply detrending for each batch
+         if self.detrending:
+             detrended_context = torch.zeros_like(context)
+             self.detrending_params_list = []
+
+             for b in range(self.batch_size):
+                 batch_context = context[:, b, :]
+                 detrended, params = Detrending.apply_detrending(batch_context)
+                 detrended_context[:, b, :] = detrended
+                 self.detrending_params_list.append(params)
+
+             context = detrended_context
+
+         return context
+
+     def _apply_transformations_inverse(self, output):
+         """
+         Apply inverse Box-Cox and detrending transformations.
+
+         Args:
+             output: Model output of shape (T, batch_size, N)
+
+         Returns:
+             Output with transformations reversed
+         """
+         # Apply inverse detrending for each batch
+         if self.detrending and self.detrending_params_list is not None:
+             for b in range(self.batch_size):
+                 batch_output = output[:, b, :]
+                 batch_context = self.original_context[:, b, :]
+                 batch_output = Detrending.apply_detrending_inverse(batch_context, batch_output, self.detrending_params_list[b])
+                 output[:, b, :] = batch_output
+
+         # Apply inverse Box-Cox transformation for each batch
+         if self.box_cox and self.box_cox_params_list is not None:
+             for b in range(self.batch_size):
+                 batch_output = output[:, b, :]
+                 batch_output = BoxCoxTransformer.inverse_transform(batch_output, self.box_cox_params_list[b])
+                 output[:, b, :] = batch_output
+
+         return output
+
+     def _standardize_data(self, context):
+         """
+         Standardize each batch in the context data.
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, N_data)
+
+         Returns:
+             Standardized context
+         """
+         if not self.standardize:
+             return context
+
+         # Calculate mean and std across time dimension for each batch separately
+         self.context_mean = torch.mean(context, dim=0)  # (batch_size, N_data)
+         self.context_std = torch.std(context, dim=0)    # (batch_size, N_data)
+         self.context_std = torch.clamp(self.context_std, min=1e-6)  # Avoid division by zero
+
+         # Standardize using broadcasting
+         context = (context - self.context_mean.unsqueeze(0)) / self.context_std.unsqueeze(0)
+
+         return context
+
+     def _unstandardize_data(self, output):
+         """
+         Undo standardization by applying the inverse transformation.
+
+         Args:
+             output: Model output of shape (T, batch_size, N)
+
+         Returns:
+             Output with standardization reversed
+         """
+         if self.standardize and self.context_mean is not None and self.context_std is not None:
+             return output * self.context_std.unsqueeze(0) + self.context_mean.unsqueeze(0)
+         return output
+
+     def _apply_embedding(self, context, model_dim):
+         """
+         Apply data preprocessing to each batch to reach model dimension.
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, N_data)
+             model_dim: Target model dimension
+
+         Returns:
+             Preprocessed context data tensor
+         """
+         context_embedded_batch = []
+
+         for b in range(self.batch_size):
+             batch_context = context[:, b, :]
+             batch_embedded = Embedding.apply_embedding(batch_context, model_dim, self.preprocessing_method)
+             context_embedded_batch.append(batch_embedded)
+
+         # Align sequence lengths across batches
+         seq_lengths = [emb.shape[0] for emb in context_embedded_batch]
+         min_seq_len = min(seq_lengths)
+         context_embedded_batch = [emb[-min_seq_len:] for emb in context_embedded_batch]
+
+         # Stack along batch dimension
+         return torch.stack(context_embedded_batch, dim=1)
+
+     def _prepare_initial_condition(self, context_embedded, initial_x, model_dim):
+         """
+         Prepare initial condition for forecasting.
+
+         Args:
+             context_embedded: Preprocessed context data
+             initial_x: Optional initial condition
+             model_dim: Model dimension
+
+         Returns:
+             Initial condition for forecasting
+
+         Raises:
+             ValueError: If initial condition is provided with Box-Cox or detrending enabled
+         """
+         if initial_x is None:
+             # Use last context value for each batch
+             return context_embedded[-1]
+
+         # Raise error if initial condition is provided with Box-Cox or detrending enabled
+         if (self.box_cox or self.detrending):
+             raise ValueError(
+                 "Using initial conditions with Box-Cox or detrending is not supported. "
+                 "Either disable Box-Cox and detrending or do not provide an initial condition."
+             )
+
+         # Process initial conditions for each batch
+         initial_x_processed = torch.zeros(self.batch_size, model_dim, device=context_embedded.device)
+         for b in range(self.batch_size):
+             batch_initial = initial_x[b]
+
+             # Apply standardization if enabled
+             if self.standardize and self.context_mean is not None and self.context_std is not None:
+                 batch_initial = (batch_initial - self.context_mean[b]) / (self.context_std[b] + 1e-8)
+
+             # If dimensions are smaller than model_dim, estimate full initial condition
+             if initial_x.shape[1] < model_dim:
+                 # Find matching state in context_embedded
+                 batch_initial = estimate_initial_condition(
+                     batch_initial,
+                     context_embedded[:, b, :],
+                 )
+
+             initial_x_processed[b] = batch_initial
+
+         return initial_x_processed
+
+     def preprocess(self, context, model_dim, initial_x=None):
+         """
+         Apply the complete preprocessing pipeline to the input data.
+
+         Args:
+             context: Context data tensor of shape (seq_length, batch_size, N_data) or (seq_length, N_data)
+             model_dim: Target model dimension
+             initial_x: Optional initial condition of shape (batch_size, N_data) or (N_data,)
+
+         Returns:
+             Preprocessed context data and initial condition
+         """
+         # Store dimensions
+         self.batch_size = context.shape[1]
+         self.feature_dim = context.shape[2]
+
+         # Apply transformations (Box-Cox, detrending)
+         context = self._apply_transformations(context)
+
+         # Standardize data if requested
+         context = self._standardize_data(context)
+
+         # Apply embedding to reach model dimension
+         context_embedded = self._apply_embedding(context, model_dim)
+
+         # Prepare initial condition
+         initial_condition = self._prepare_initial_condition(context_embedded, initial_x, model_dim)
+
+         return context_embedded, initial_condition
+
+     def postprocess(self, output):
+         """
+         Apply inverse transformations to restore original data scaling.
+
+         Args:
+             output: Model output of shape (T, batch_size, N)
+
+         Returns:
+             Output with inverse transformations applied
+         """
+         # Undo standardization
+         output = self._unstandardize_data(output)
+
+         # Apply inverse transformations (Box-Cox, detrending)
+         output = self._apply_transformations_inverse(output)
+
+         return output
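A minimal sketch of the preprocessing step on a univariate context (shapes illustrative):

```python
import torch
from dynamix.preprocessing import DataPreprocessor

pre = DataPreprocessor(standardize=True, preprocessing_method="zero_embedding")
context = torch.randn(256, 1, 1)  # (seq_length, batch_size, N_data)
embedded, x0 = pre.preprocess(context, model_dim=3)  # pad 1D series up to model dim 3
# embedded: (256, 1, 3); x0: (1, 3) -- last embedded state as initial condition
```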
dynamix/preprocessing_utilities.py ADDED
@@ -0,0 +1,536 @@
+ import torch
+ import numpy as np
+ from scipy import stats
+ from scipy.signal import find_peaks
+ import random
+ from statsmodels.tsa.stattools import acf
+ from scipy.ndimage import gaussian_filter1d
+ from scipy import optimize
+
+
+ class TimeSeriesProcessor:
+     """
+     Utility class for converting between numpy and torch.
+     """
+     @staticmethod
+     def to_numpy(data):
+         """Convert torch tensor to numpy array while preserving device and dtype info"""
+         is_torch = isinstance(data, torch.Tensor)
+         if is_torch:
+             device = data.device
+             dtype = data.dtype
+             return data.detach().cpu().numpy(), is_torch, device, dtype
+         return data, False, None, None
+
+     @staticmethod
+     def to_torch(data_np, is_torch, device=None, dtype=None):
+         """Convert numpy array back to torch tensor if original was a tensor"""
+         if is_torch:
+             return torch.tensor(data_np, device=device, dtype=dtype)
+         return data_np
+
+
+ class Embedding:
+     """
+     Class for embedding methods to transform time series to target dimension.
+     """
+     @staticmethod
+     def estimate_TDM_tau(data, acorr_threshold=1/np.e):
+         """
+         Estimate tau using autocorrelation function with threshold method
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             acorr_threshold: Autocorrelation threshold
+
+         Returns:
+             Maximum estimated tau across all dimensions
+         """
+         # Convert to numpy
+         data_np, _, _, _ = TimeSeriesProcessor.to_numpy(data)
+
+         seq_length, n_dims = data_np.shape
+         tau_vals = np.zeros(n_dims, dtype=int)
+
+         for dim in range(n_dims):
+             # Calculate autocorrelation
+             autocorr_vals = acf(data_np[:, dim] - np.mean(data_np[:, dim]), nlags=seq_length//2)
+
+             # Find first value below threshold (after lag 0)
+             below_threshold = np.where(autocorr_vals[1:] < acorr_threshold)[0]
+             if len(below_threshold) > 0:
+                 tau_vals[dim] = below_threshold[0] + 1  # +1 because skipping lag 0
+             else:
+                 tau_vals[dim] = 1  # Default if no value below threshold
+
+         return int(np.max(tau_vals))
+
+     @staticmethod
+     def estimate_pos_tau(data, max_lag=None, min_lag=None):
+         """
+         Estimate autocorrelation time for positional embedding
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             max_lag: Maximum lag to consider
+             min_lag: Minimum lag to consider
+
+         Returns:
+             Maximum autocorrelation time across dimensions
+         """
+         data_np, _, _, _ = TimeSeriesProcessor.to_numpy(data)
+         seq_length, n = data_np.shape
+
+         if max_lag is None:
+             max_lag = seq_length - 1
+         if min_lag is None:
+             min_lag = seq_length // 10
+
+         tau_vals = np.zeros(n, dtype=int)
+
+         for dim in range(n):
+             ts = data_np[:, dim] if not isinstance(data, torch.Tensor) else data[:, dim].cpu().numpy()
+             autocorr_vals = acf(ts - np.mean(ts), nlags=max_lag)
+
+             # Determine max autocorrelation with tau > tau_min
+             peaks, _ = find_peaks(autocorr_vals)
+             valid_peaks = [i for i in peaks if i > min_lag and i < len(autocorr_vals)]
+             if valid_peaks:
+                 peak_values = autocorr_vals[valid_peaks]
+                 max_peak_idx = np.argmax(peak_values)
+                 tau_vals[dim] = valid_peaks[max_peak_idx]
+             else:
+                 start_idx = min_lag + 1
+                 segment = autocorr_vals[start_idx:]
+                 tau_vals[dim] = start_idx + int(np.argmax(segment))
+
+         return np.max(tau_vals)
+
+     @staticmethod
+     def delay_embedding(data, model_dim, tau=None):
+         """
+         Standard delay embedding with optimal tau
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             model_dim: Target dimension
+             tau: Time delay (if None, estimated from autocorrelation)
+
+         Returns:
+             Delay embedded data of shape (shortened_length, model_dim)
+         """
+         seq_length, N_data = data.shape
+         needed_dims = model_dim - N_data
+
+         if needed_dims <= 0:
+             return data
+
+         processed_data = data.clone()
+
+         # Estimate tau if not provided
+         if tau is None:
+             tau = Embedding.estimate_TDM_tau(processed_data)
+
+         # Select the last column for embedding
+         ts = processed_data[:, -1].clone()
+
+         # Calculate starting index
+         start_idx = needed_dims * tau
+
+         # Handle case where start_idx is too large
+         if start_idx >= seq_length:
+             tau = max(1, seq_length // (needed_dims + 1))
+             start_idx = needed_dims * tau
+
+         # Create shortened data
+         shortened_data = processed_data[start_idx:].clone()
+         result = shortened_data
+
+         # Add delayed versions
+         for i in range(1, needed_dims + 1):
+             delayed = ts[start_idx - i * tau:seq_length - i * tau].unsqueeze(1)
+             result = torch.cat([result, delayed], dim=1)
+
+         return result
+
+     @staticmethod
+     def delay_embedding_random(data, model_dim, upper_tau=10, lower_tau=3):
+         """
+         Random delay embedding with random tau values
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             model_dim: Target dimension
+             upper_tau: Upper bound for random tau values
+             lower_tau: Lower bound for random tau values
+
+         Returns:
+             Random delay embedded data
+         """
+         seq_length, N_data = data.shape
+         needed_dims = model_dim - N_data
+
+         if needed_dims <= 0:
+             return data
+
+         processed_data = data.clone()
+
+         # Generate random tau values
+         taus = [random.randint(lower_tau, upper_tau) for _ in range(needed_dims)]
+         max_tau = max(taus)
+
+         # Select the first column for embedding
+         ts = processed_data[:, 0].clone()
+
+         # Create shortened data
+         result = processed_data[max_tau:].clone()
+
+         # Add delayed versions
+         for i in range(needed_dims):
+             delayed = ts[max_tau - taus[i]:seq_length - taus[i]].unsqueeze(1)
+             result = torch.cat([result, delayed], dim=1)
+
+         return result
+
+     @staticmethod
+     def zero_embedding(data, model_dim):
+         """
+         Zero embedding: appends zeros to reach model dimensions
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             model_dim: Target dimension
+
+         Returns:
+             Tensor with zeros appended to reach model_dim
+         """
+         seq_length, N_data = data.shape
+         needed_dims = model_dim - N_data
+
+         if needed_dims > 0:
+             zeros = torch.zeros(seq_length, needed_dims, device=data.device, dtype=data.dtype)
+             data = torch.cat([data, zeros], dim=1)
+
+         return data
+
+     @staticmethod
+     def positional_embedding(data, model_dim, tau=None):
+         """
+         Positional embedding: adds sinusoidal signals based on autocorrelation time
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             model_dim: Target dimension
+             tau: Optional fixed value for tau. If None, estimated from data.
+
+         Returns:
+             Data with positional embeddings added
+         """
+         seq_length, N_data = data.shape
+         needed_dims = model_dim - N_data
+
+         if needed_dims <= 0:
+             return data
+
+         if needed_dims != 1:
+             shifts = torch.linspace(0, np.pi/2, needed_dims, device=data.device)
+         else:
+             shifts = torch.tensor([0.0], device=data.device)
+
+         tau_val = tau if tau is not None else Embedding.estimate_pos_tau(data)
+         t = torch.arange(1, seq_length + 1, dtype=data.dtype, device=data.device)
+
+         result = data.clone()
+         for shift in shifts:
+             pos_feature = torch.sin(2 * np.pi / tau_val * t + shift).unsqueeze(1)
+             result = torch.cat([result, pos_feature], dim=1)
+
+         return result
+
+     @staticmethod
+     def apply_embedding(data, model_dim, method="pos_embedding", **kwargs):
+         """
+         Apply selected embedding method to the data
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+             model_dim: Target dimension
+             method: Embedding method ('pos_embedding', 'zero_embedding',
+                 'delay_embedding', or 'delay_embedding_random')
+             **kwargs: Additional parameters to pass to the specific embedding method
+
+         Returns:
+             Embedded data
+         """
+         if method == "pos_embedding":
+             return Embedding.positional_embedding(data, model_dim, **kwargs)
+         elif method == "zero_embedding":
+             return Embedding.zero_embedding(data, model_dim)
+         elif method == "delay_embedding":
+             return Embedding.delay_embedding(data, model_dim, **kwargs)
+         elif method == "delay_embedding_random":
+             return Embedding.delay_embedding_random(data, model_dim, **kwargs)
+         else:
+             raise ValueError(f"Unsupported embedding method: {method}")
+
+
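+ # Illustrative usage sketch (assumed, matching the signatures above):
+ #   Embedding.apply_embedding(torch.randn(100, 1), model_dim=3, method="zero_embedding")
+ # returns a (100, 3) tensor; "pos_embedding" appends sinusoidal features instead,
+ # and the delay variants shorten the sequence by the chosen delays.
+
+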
277
+ class BoxCoxTransformer:
278
+ """
279
+ Applies Box-Cox transformation to data for variance stabilization.
280
+ """
281
+ def __init__(self, lambda_range=(-2, 2)):
282
+ """
283
+ Initialize BoxCoxTransformer.
284
+
285
+ Args:
286
+ lambda_range: Range for lambda parameter search
287
+ """
288
+ self.lambda_range = lambda_range
289
+ self.params = None
290
+
291
+ @staticmethod
292
+ def transform(data, lambda_range=(-2, 2)):
293
+ """
294
+ Apply Box-Cox transformation to data for stabilization
295
+
296
+ Args:
297
+ data: Input data tensor of shape (seq_length, N)
298
+ lambda_range: Range for lambda parameter search
299
+
300
+ Returns:
301
+ Transformed data and parameters for inverse transformation
302
+ """
303
+ # Convert to numpy
304
+ data_np, is_torch, device, dtype = TimeSeriesProcessor.to_numpy(data)
305
+
306
+ seq_length, n_dims = data_np.shape
307
+ transformed_data = np.zeros_like(data_np)
308
+ box_cox_params = []
309
+
310
+ for dim in range(n_dims):
311
+ # Add constant to ensure positivity
312
+ if np.min(data_np[:, dim]) <= 0:
313
+ offset = abs(np.min(data_np[:, dim])) + 1.2
314
+ data_shifted = data_np[:, dim] + offset
315
+ else:
316
+ offset = 1.2
317
+ data_shifted = data_np[:, dim] + offset
318
+
319
+ try:
320
+ # Find optimal lambda for Box-Cox transformation
321
+ transformed, lambda_param = stats.boxcox(data_shifted)
322
+
323
+ # Limit lambda to a reasonable range to prevent numerical issues
324
+ lambda_param = max(min(lambda_param, 2.0), -2.0)
325
+
326
+ # Recalculate transformation with bounded lambda for consistency
327
+ if abs(lambda_param) < 1e-8:
328
+ # For lambda near zero, use logarithmic transformation
329
+ transformed = np.log(data_shifted)
330
+ else:
331
+ transformed = (data_shifted ** lambda_param - 1) / lambda_param
332
+
333
+ # Store transformed data and parameters
334
+ transformed_data[:, dim] = transformed
335
+ except:
336
+ # If transformation fails, just use the original data
337
+ transformed_data[:, dim] = data_np[:, dim]
338
+ lambda_param = 1.0 # Identity transform
339
+
340
+ box_cox_params.append((lambda_param, offset))
341
+
342
+ # Convert back to torch if needed
343
+ return TimeSeriesProcessor.to_torch(transformed_data, is_torch, device, dtype), box_cox_params
344
+
345
+ @staticmethod
346
+ def inverse_transform(data, box_cox_params):
347
+ """
348
+ Apply inverse Box-Cox transformation
349
+
350
+ Args:
351
+ data: Transformed data tensor
352
+ box_cox_params: Parameters from Box-Cox transformation
353
+
354
+ Returns:
355
+ Original scale data
356
+ """
357
+ # Convert to numpy for computation
358
+ data_np, is_torch, device, dtype = TimeSeriesProcessor.to_numpy(data)
359
+
360
+ seq_length, n_dims = data_np.shape
361
+ inverse_data = np.zeros_like(data_np)
362
+
363
+ for dim in range(min(n_dims, len(box_cox_params))):
364
+ lambda_param, offset = box_cox_params[dim]
365
+
366
+ # Apply inverse transformation
367
+ if abs(lambda_param) < 1e-8:
368
+ # For lambda near zero, the transformation is logarithmic
369
+ inverse_data[:, dim] = np.exp(data_np[:, dim]) - offset
370
+ elif abs(lambda_param - 1.0) < 1e-8:
371
+ # For lambda=1 (identity transform), just subtract offset
372
+ inverse_data[:, dim] = data_np[:, dim] - offset
373
+ else:
374
+ # For other lambda values
375
+ base = lambda_param * data_np[:, dim] + 1
376
+
377
+ # Simple clipping approach to ensure base is positive
378
+ # This avoids complex numbers while preserving most data characteristics
379
+ base = np.maximum(base, 1e-10)
380
+
381
+ # Apply power transformation
382
+ result = base ** (1/lambda_param)
383
+ inverse_data[:, dim] = result - offset
384
+
385
+ # Convert back to torch if needed
386
+ return TimeSeriesProcessor.to_torch(inverse_data, is_torch, device, dtype)
387
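As a sanity check on the method pair above, a minimal round-trip sketch (import path again hypothetical; the round-trip is exact wherever the inverse's positivity clipping does not kick in):

```python
import torch
from dynamix.preprocessing import BoxCoxTransformer  # hypothetical module path

# A positive series whose variance grows over time
t = torch.arange(1.0, 201.0).unsqueeze(1)
data = t * (1.0 + 0.1 * torch.randn(200, 1))

stabilized, params = BoxCoxTransformer.transform(data)
restored = BoxCoxTransformer.inverse_transform(stabilized, params)

print(torch.max(torch.abs(restored - data)))  # should be close to zero
```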
+
+
+ class Detrending:
+     """
+     Removes a trend from time series data by fitting the model
+     a * t**b + c (a power law in t, called the "exponential model" below).
+     """
+     @staticmethod
+     def exp_model(t, params):
+         """
+         Trend model for detrending: a * t**b + c
+
+         Args:
+             t: Time points
+             params: Model parameters [a, b, c]
+
+         Returns:
+             Model values
+         """
+         a, b, c = params
+         return a * (t ** b) + c
+
+     @staticmethod
+     def fit_objective(params, data):
+         """
+         Objective function for fitting the trend model
+
+         Args:
+             params: Model parameters
+             data: Data to fit
+
+         Returns:
+             Sum of squared errors
+         """
+         t = np.arange(1, len(data) + 1)
+         predicted = Detrending.exp_model(t, params)
+         return np.sum((data - predicted) ** 2)
+
+     @staticmethod
+     def apply_detrending(data):
+         """
+         Apply detrending to the data
+
+         Args:
+             data: Input data tensor of shape (seq_length, N)
+
+         Returns:
+             Detrended data and the parameters needed to restore the trend
+         """
+         # Convert to numpy
+         data_np, is_torch, device, dtype = TimeSeriesProcessor.to_numpy(data)
+
+         seq_length, n_dims = data_np.shape
+         detrended_data = np.zeros_like(data_np)
+         detrending_params = []
+
+         for dim in range(n_dims):
+             # Objective function for this dimension
+             objective = lambda params: Detrending.fit_objective(params, data_np[:, dim])
+
+             # Initial parameter guess: no trend, starting at the first observation
+             initial_params = [0.0, 1.0, data_np[0, dim]]
+
+             # Bounds for the parameters (exponent restricted to a sensible range)
+             bounds = [(None, None), (0.1, 3.0), (None, None)]
+
+             # Optimize
+             result = optimize.minimize(
+                 objective,
+                 initial_params,
+                 method='L-BFGS-B',
+                 bounds=bounds,
+                 options={
+                     'maxiter': 1000,
+                     'gtol': 1e-6,
+                     'maxfun': 1500,
+                     'maxcor': 10
+                 }
+             )
+             optimal_params = np.round(result.x, 3)
+
+             # Compute the trend and subtract it from the data
+             t = np.arange(1, seq_length + 1)
+             trend = Detrending.exp_model(t, optimal_params)
+             detrended_data[:, dim] = data_np[:, dim] - trend
+
+             # Store the parameters for the inverse transformation
+             detrending_params.append(optimal_params)
+
+         # Convert back to torch if needed
+         return TimeSeriesProcessor.to_torch(detrended_data, is_torch, device, dtype), detrending_params
+
+     @staticmethod
+     def apply_detrending_inverse(context, data, detrending_params):
+         """
+         Restore the fitted trend on forecasted data
+
+         Args:
+             context: Original context data
+             data: Forecasted data
+             detrending_params: Parameters returned by apply_detrending()
+
+         Returns:
+             Forecasted data with the trend restored
+         """
+         # Convert to numpy for computation
+         data_np, is_torch, device, dtype = TimeSeriesProcessor.to_numpy(data)
+         context_np, _, _, _ = TimeSeriesProcessor.to_numpy(context)
+
+         # Get dimensions
+         forecast_length, n_dims = data_np.shape
+         context_length = len(context_np)
+
+         # Time points of the forecast horizon continue the context's time axis
+         t = np.arange(context_length + 1, context_length + forecast_length + 1)
+
+         # Add the trend back to each dimension
+         for dim in range(min(n_dims, len(detrending_params))):
+             params = detrending_params[dim]
+             trend = Detrending.exp_model(t, params)
+             data_np[:, dim] = data_np[:, dim] + trend
+
+         # Convert back to torch if needed
+         return TimeSeriesProcessor.to_torch(data_np, is_torch, device, dtype)
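A minimal sketch of the intended workflow: detrend the context, forecast in the detrended space, then restore the trend on the forecast. The import path is hypothetical, and the model call is replaced by a placeholder:

```python
import torch
from dynamix.preprocessing import Detrending  # hypothetical module path

# Context with a growing trend plus an oscillation
t = torch.arange(1.0, 301.0)
context = (0.05 * t ** 1.5 + torch.sin(0.2 * t)).unsqueeze(1)

detrended, params = Detrending.apply_detrending(context)

# ... run the model on `detrended` to obtain `forecast`, shape (horizon, 1) ...
forecast = torch.zeros(100, 1)  # placeholder for an actual model forecast

# Map the forecast back to the original scale; its time indices continue the context
restored = Detrending.apply_detrending_inverse(context, forecast, params)
```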
+
+
+ def estimate_initial_condition(initial_x, context_embedded):
+     """
+     Estimate the full initial condition from a partial observation
+
+     Args:
+         initial_x: Partial initial condition of shape (N_partial,)
+         context_embedded: Context data of shape (seq_length, N)
+
+     Returns:
+         Complete initial condition of shape (N,)
+     """
+     T, N = context_embedded.shape
+     N_partial = initial_x.shape[0]
+
+     assert N_partial <= N, "Initial condition dimension must be <= embedding dimension"
+
+     # Find the timestep whose first N_partial dimensions are closest (in squared
+     # Euclidean distance) to the partial initial condition
+     distances = torch.sum((context_embedded[:, :N_partial] - initial_x) ** 2, dim=1)
+     closest_t = torch.argmin(distances)
+
+     # Fill the unobserved dimensions from the closest matching context state
+     return torch.cat([initial_x, context_embedded[closest_t, N_partial:]])
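For instance, completing a partially observed initial state (shapes illustrative):

```python
import torch

# Embedded context with N = 3 dimensions
context_embedded = torch.randn(500, 3)

# Only the first two dimensions of the initial state are observed
initial_x = torch.tensor([0.1, -0.4])

# The third dimension is borrowed from the most similar context state
x0 = estimate_initial_condition(initial_x, context_embedded)
print(x0.shape)  # torch.Size([3])
```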
dynamix/utilities.py ADDED
@@ -0,0 +1,174 @@
+ import json
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from dynamix.dynamix import DynaMix
+ import plotly.graph_objects as go
+ import plotly.subplots as sp
+ import numpy as np
+
+ """
+ Loading models from HuggingFace Hub
+ """
+
+ def load_hf_model_config(model_name):
+     """Load a model configuration from the HuggingFace Hub"""
+
+     config_path = hf_hub_download(
+         repo_id="DurstewitzLab/dynamix",
+         filename="config_" + model_name.replace("dynamix-", "") + ".json"
+     )
+
+     with open(config_path, 'r') as f:
+         model_config = json.load(f)
+
+     return model_config
+
+ def load_hf_model(model_name):
+     """Load a specific DynaMix model together with its configuration"""
+     try:
+         # Load the model configuration
+         model_config = load_hf_model_config(model_name)
+         architecture = model_config["architecture"]
+
+         # Extract hyperparameters from the config
+         M = architecture["M"]  # Latent state dimension
+         N = architecture["N"]  # Observation space dimension
+         EXPERTS = architecture["Experts"]  # Number of experts
+         P = architecture["P"]  # Number of ReLU dimensions
+         HIDDEN_DIM = architecture["hidden_dim"]
+         expert_type = architecture["expert_type"]
+         probabilistic_expert = architecture["probabilistic_expert"]
+
+         # Create the model with the config parameters
+         model = DynaMix(
+             M=M,
+             N=N,
+             Experts=EXPERTS,
+             expert_type=expert_type,
+             P=P,
+             hidden_dim=HIDDEN_DIM,
+             probabilistic_expert=probabilistic_expert,
+         )
+
+         # Load the model weights
+         model_path = hf_hub_download(
+             repo_id="DurstewitzLab/dynamix",
+             filename=model_name + ".safetensors",
+         )
+         model_state_dict = load_file(model_path)
+         model.load_state_dict(model_state_dict)
+         model.eval()
+
+     except Exception as e:
+         print(f"Error loading model {model_name}: {e}")
+         raise ValueError(f"Model {model_name} could not be loaded") from e
+
+     return model
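Putting the loaders together, a short usage sketch (the model names are those returned by `auto_model_selection` just below; both calls download from the Hub on first use, and the forecasting call itself lives outside this file, so it is omitted):

```python
import numpy as np
from dynamix.utilities import load_hf_model, load_hf_model_config, auto_model_selection

context = np.random.randn(256, 3)           # e.g., a 3-dimensional context
model_name = auto_model_selection(context)  # -> "dynamix-3d-alrnn-v1.0"

config = load_hf_model_config(model_name)   # architecture hyperparameters
model = load_hf_model(model_name)           # weights loaded, eval mode
```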
+
+
+ # Model selection function
+ def auto_model_selection(context):
+     """
+     Select the pretrained model to use for forecasting, based on the
+     dimensionality of the context
+     """
+     if context.shape[1] == 1:
+         return "dynamix-6d-alrnn-v1.0"
+     elif 2 <= context.shape[1] <= 3:
+         return "dynamix-3d-alrnn-v1.0"
+     else:
+         # Contexts with 4 or more dimensions use the 6d model
+         return "dynamix-6d-alrnn-v1.0"
+
+
+
+ """
84
+ Plotting functions
85
+ """
86
+
87
+ def create_forecast_plot(values, reconstruction_ts_np, horizon):
88
+ """
89
+ Create a Plotly figure with dark theme styling matching the reference image
90
+ """
91
+ dims = reconstruction_ts_np.shape[-1]
92
+ plot_dims = min(dims, 15) # plot up to 15 dimensions
93
+
94
+ context_time = np.arange(-len(values), 0)
95
+ forecast_time = np.arange(0, int(horizon))
96
+
97
+ # Create subplots
98
+ # Adjust spacing based on number of dimensions
99
+ if plot_dims <= 3:
100
+ vertical_spacing = 0.1
101
+ elif plot_dims <= 6:
102
+ vertical_spacing = 0.05
103
+ elif plot_dims <= 15:
104
+ vertical_spacing = 0.02
105
+
106
+ fig = sp.make_subplots(
107
+ rows=plot_dims,
108
+ cols=1,
109
+ vertical_spacing=vertical_spacing
110
+ )
111
+
112
+ # Add traces for each dimension
113
+ for d in range(plot_dims):
114
+ # Historical data
115
+ historical_trace = go.Scatter(
116
+ x=context_time,
117
+ y=values[:, d],
118
+ mode='lines',
119
+ line=dict(color='#4169E1', width=2.5),
120
+ name=f"context_{d+1}",
121
+ showlegend=False,
122
+ hovertemplate=f"context_{d+1}<br>x: %{{x}}<br>y: %{{y}}<extra></extra>"
123
+ )
124
+
125
+ # Forecast
126
+ forecast_trace = go.Scatter(
127
+ x=forecast_time,
128
+ y=reconstruction_ts_np[:, d],
129
+ mode='lines',
130
+ line=dict(color='#FF4242', width=2.5),
131
+ name=f"forecast_{d+1}",
132
+ showlegend=False,
133
+ hovertemplate=f"forecast_{d+1}<br>x: %{{x}}<br>y: %{{y}}<extra></extra>"
134
+ )
135
+
136
+ fig.add_trace(historical_trace, row=d+1, col=1)
137
+ fig.add_trace(forecast_trace, row=d+1, col=1)
138
+
139
+ fig.update_layout(
140
+ plot_bgcolor='#1f2937',
141
+ paper_bgcolor='#1f2937',
142
+ font=dict(color='white'),
143
+ showlegend=False,
144
+ title=None,
145
+ margin=dict(l=50, r=50, t=30, b=50),
146
+ xaxis=dict(
147
+ gridcolor='rgba(255, 255, 255, 0.2)',
148
+ zerolinecolor='rgba(255, 255, 255, 0.2)',
149
+ showgrid=True
150
+ ),
151
+ yaxis=dict(
152
+ gridcolor='rgba(255, 255, 255, 0.2)',
153
+ zerolinecolor='rgba(255, 255, 255, 0.2)',
154
+ showgrid=True,
155
+ ),
156
+ height=300 if plot_dims == 1 else 250 * plot_dims,
157
+ width=None
158
+ )
159
+
160
+ for i in range(plot_dims):
161
+ fig.update_xaxes(
162
+ gridcolor='rgba(255, 255, 255, 0.2)',
163
+ zerolinecolor='rgba(255, 255, 255, 0.2)',
164
+ showgrid=True,
165
+ row=i+1, col=1
166
+ )
167
+ fig.update_yaxes(
168
+ gridcolor='rgba(255, 255, 255, 0.2)',
169
+ zerolinecolor='rgba(255, 255, 255, 0.2)',
170
+ showgrid=True,
171
+ row=i+1, col=1
172
+ )
173
+
174
+ return fig
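To exercise the plot function standalone, a quick sketch with synthetic arrays shaped the way the function indexes them (note that `horizon` should match the forecast's length so the traces align):

```python
import numpy as np
from dynamix.utilities import create_forecast_plot

values = np.random.randn(200, 3)    # context, shape (time_steps, dims)
forecast = np.random.randn(512, 3)  # forecast, shape (horizon, dims)

fig = create_forecast_plot(values, forecast, horizon=512)
fig.show()  # or return the figure from a Gradio callback
```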
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=1.10.0
+ numpy>=1.20.0
+ matplotlib>=3.4.0
+ plotly>=6.3.0
+ scipy>=1.7.0
+ pandas>=1.3.0
+ safetensors>=0.4.0
+ huggingface_hub>=0.19.0
+ statsmodels>=0.14.4
+ gradio>=5.43.1