A tutorial on how to build a Foundation Model for Univariate Time Series Forecasting
A concise, reproducible recipe for training a transformer-based, patch-to-patch forecasting model for univariate time series. The approach mirrors Large Language Model (LLM) practices (next-token → next-patch) while remaining lightweight and practical compared to a classic LLM.
Highlights
- Next-patch prediction objective (autoregressive, causal)
- Patch-based representation of time series (tokens → patches)
- Causal masking self-attention with RoPE (relative positions)
- RevIN (Reversible Instance Normalization) with causal statistics
- SwiGLU feed-forward networks
- Multi-quantile outputs (median + uncertainty bands)
- Efficient rollout with KV caching
This model has been pushed to the Hub using the PytorchModelHubMixin integration:
- Code: GitHub
- Paper: Incoming
Installation
```bash
git clone https://github.com/vilhess/PatchFM
cd PatchFM
pip install -r requirements.txt
```
Quick Start
```python
import torch
from model import Forecaster
from configs import PatchFMConfig

# --- Instantiate model ---
config = PatchFMConfig(load_from_hub=True)
model = Forecaster(config)

# --- Inference ---
forecast_horizon = 64
seq = torch.randn(1, 1024)  # (batch, time)
pred_median, pred_quantiles = model(seq, forecast_horizon=forecast_horizon, quantiles=[0.1, 0.5, 0.9])  # (batch, time, quantiles)
```
We provide an extended quick start example in notebooks/tutorial.ipynb. If you don't have suitable hardware, you can also run the extended quick start example in Google Colab:
Method (TL;DR)
- Patching: Split a context signal of length $w$ into $P_{num} = w / P_{len}$ patches of length $P_{len}$.
- RevIN: Normalize patches using causal running mean/variance over past patches, and denormalize outputs to the original scale (a sketch follows this list).
- Architecture: Input residual MLP → stacked Transformer blocks (MHA + SwiGLU FFN, pre-norm, residual) → $|\mathcal{Q}|$ output heads mapping back to patch space.
- Positional encoding: Rotary Position Embeddings (RoPE) applied to queries/keys.
- Training: Multi-quantile (pinball) loss across positions, elements, and quantiles $\mathcal{Q}$.
- Inference: Predict next patch; roll out autoregressively with KV caching for long horizons.
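To make the patching and normalization steps concrete, below is a minimal PyTorch sketch, assuming a (batch, time) input whose length is a multiple of the patch length; the function names, and the choice to normalize each patch with running statistics up to and including that patch, are illustrative assumptions rather than the repository's exact implementation.

```python
import torch

def patchify(x: torch.Tensor, patch_len: int = 32) -> torch.Tensor:
    # x: (batch, time) -> (batch, num_patches, patch_len); time must be divisible by patch_len
    b, t = x.shape
    return x.view(b, t // patch_len, patch_len)

def causal_revin(patches: torch.Tensor, eps: float = 1e-5):
    # Normalize each patch with the running mean/variance of all values seen so far,
    # so that no future information leaks into the statistics.
    b, n, p = patches.shape
    flat = patches.reshape(b, n * p)
    counts = torch.arange(1, n * p + 1, device=flat.device, dtype=flat.dtype)
    run_mean = flat.cumsum(dim=1) / counts
    run_sq = (flat ** 2).cumsum(dim=1) / counts
    last = torch.arange(p - 1, n * p, p, device=flat.device)   # last index of each patch
    mean = run_mean[:, last].unsqueeze(-1)                     # (batch, num_patches, 1)
    var = run_sq[:, last].unsqueeze(-1) - mean ** 2
    normed = (patches - mean) / torch.sqrt(var + eps)
    return normed, mean, var  # keep mean/var to denormalize the model's outputs later
```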
Problem Formulation
Given context patches $x_{p_1}, \ldots, x_{p_n}$, predict the next patch $x_{p_{i+1}}$ for each position $i$ using only past patches (causality). The model outputs quantiles $\{\hat{x}_{p_{i+1}}^{(q)} : q \in \mathcal{Q}\}$ with the median ($q = 0.5$) as the point forecast.
Loss: Multi-Quantile (Pinball)
For the residual $u = x - \hat{x}^{(q)}$, the pinball loss for quantile $q$ is $\rho_q(u) = \max\big(q\,u,\ (q-1)\,u\big)$. The training loss averages this over positions, patch elements, and quantiles $q \in \mathcal{Q}$.
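A compact sketch of this loss in PyTorch; the tensor shapes are assumptions for illustration and may differ from the trainer's actual interface:

```python
import torch

def multi_quantile_loss(preds: torch.Tensor, target: torch.Tensor, quantiles) -> torch.Tensor:
    # preds:  (batch, num_patches, patch_len, num_quantiles) -- one prediction per quantile
    # target: (batch, num_patches, patch_len)
    q = torch.tensor(quantiles, device=preds.device, dtype=preds.dtype).view(1, 1, 1, -1)
    u = target.unsqueeze(-1) - preds                  # residual u = x - x_hat^(q)
    pinball = torch.maximum(q * u, (q - 1.0) * u)     # rho_q(u)
    return pinball.mean()                             # average over positions, elements, quantiles
```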
Architecture
- Input MLP: $\mathbb{R}^{P_{len}} \to \mathbb{R}^{dim}$ residual 2-layer MLP (ReLU)
- Multi-Head Attention: causal mask, RoPE; queries/keys/values per head
- FFN: SwiGLU (SiLU-gated), pre-norm + residual
- Output heads: $|\mathcal{Q}|$ linear maps $\mathbb{R}^{dim} \to \mathbb{R}^{P_{len}}$ (one per quantile); a minimal block sketch follows this list.
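The sketch below illustrates one pre-norm Transformer block with a causal mask and a SwiGLU feed-forward network. It uses PyTorch's built-in multi-head attention for brevity, so RoPE, the input residual MLP, and the per-quantile output heads are omitted; module names and the FFN width are assumptions, not the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # SiLU-gated feed-forward network: W_down(SiLU(W_gate x) * W_up x)
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

class Block(nn.Module):
    # Pre-norm Transformer block: causal self-attention + SwiGLU FFN, both with residuals.
    def __init__(self, dim: int = 2048, n_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = SwiGLU(dim, 4 * dim)

    def forward(self, x):  # x: (batch, num_patches, dim)
        n = x.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=causal, need_weights=False)[0]
        return x + self.ffn(self.norm2(x))
```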
Model Details
- Patch size: 32
- Max context: 32 patches (1024 steps)
- Forecast horizon: 32 steps per forward pass
- Quantiles $\mathcal{Q}$: {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
- Layers: 6
- Attention heads: 64 (head dim 32)
- Model dim: 2048
- Parameters: ~300M
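For reference, the same numbers as a hypothetical configuration object; the field names are illustrative and need not match the actual PatchFMConfig:

```python
from dataclasses import dataclass

@dataclass
class PatchFMDims:
    # Hypothetical summary of the hyperparameters listed above.
    patch_len: int = 32
    max_patches: int = 32        # 32 patches x 32 steps = 1024-step max context
    n_layers: int = 6
    n_heads: int = 64
    head_dim: int = 32
    model_dim: int = 2048        # = n_heads * head_dim
    quantiles: tuple = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
```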
Inference
- Single step: predict next patch ($P_{len}$ values)
- Long-horizon: append prediction to context and repeat (optionally drop oldest patch to keep window fixed)
- KV caching: reuse cached keys/values for past patches; compute new Q/K/V only for the appended patch (an illustrative rollout sketch follows this list)
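A rough sketch of the rollout loop, assuming a hypothetical `model` callable that returns the median forecast for the next patch given a (batch, time) context; the real Forecaster manages quantile outputs and the KV cache internally:

```python
import torch

@torch.no_grad()
def rollout(model, context: torch.Tensor, horizon: int, patch_len: int = 32) -> torch.Tensor:
    # Autoregressive long-horizon forecasting: predict one patch, append it to the
    # context, drop the oldest patch to keep the window fixed, and repeat.
    preds = []
    n_steps = -(-horizon // patch_len)            # ceil(horizon / patch_len)
    for _ in range(n_steps):
        next_patch = model(context)               # (batch, patch_len), median forecast
        preds.append(next_patch)
        context = torch.cat([context, next_patch], dim=1)[:, patch_len:]
    return torch.cat(preds, dim=1)[:, :horizon]   # trim to the requested horizon
```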
Datasets
- UTSD (Unified Time Series Dataset) [UTSD]: seven domains (Energy, IoT, Nature, Web, Health, Transport, Environment). We start with UTSD-1G (~55M series after preprocessing).
- Artificial: ~1M synthetic series (sinusoidal, linear, polynomial, logarithmic) plus mixtures via TSMixup [Chronos]; Gaussian Process samples via KernelSynth (mixtures of RBF/periodic/linear kernels with swept hyperparameters); a rough sampler sketch follows.
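As an illustration of the KernelSynth idea, here is a rough sampler using scikit-learn kernels; the kernel bank, hyperparameter ranges, and composition rules are assumptions and not the exact settings used to build the dataset:

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, DotProduct

def kernel_synth_sample(length: int = 1024, max_extra_kernels: int = 4, rng=None) -> np.ndarray:
    # Compose a few random RBF / periodic / linear kernels by + or * and draw one
    # series from the resulting Gaussian Process prior (KernelSynth-style, after Chronos).
    rng = np.random.default_rng() if rng is None else rng
    bank = [RBF(length_scale=rng.uniform(0.1, 10.0)),
            ExpSineSquared(periodicity=rng.uniform(0.05, 1.0)),
            DotProduct(sigma_0=rng.uniform(0.0, 1.0))]
    kernel = bank[rng.integers(len(bank))]
    for _ in range(rng.integers(0, max_extra_kernels + 1)):
        other = bank[rng.integers(len(bank))]
        kernel = kernel + other if rng.random() < 0.5 else kernel * other
    t = np.linspace(0.0, 1.0, length)[:, None]
    cov = kernel(t) + 1e-6 * np.eye(length)       # jitter for numerical stability
    return rng.multivariate_normal(np.zeros(length), cov)
```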
Repository Layout
- `model/training/` – main PatchFM model class
  - `modules.py` – core modules (Residual Layers, MHA, SwiGLU, RoPE, Transformer Encoder, ...)
  - `revin.py` – causal RevIN
  - `loss.py` – multi-quantile (pinball) loss
  - `trainer.py` – PyTorch Lightning trainer class
- `model/inference/` – main PatchFM model class for inference with KV caching
  - `modules.py` – core modules with caching support
  - `forecaster.py` – forecasting model with KV caching and rollout logic
- `dataset/` – data loading and preprocessing
  - `artificial.py` – synthetic dataset: artificial signals + TSMixup + KernelSynth
  - `utsd.py` – Unified Time Series Dataset (UTSD) loading and preprocessing
  - `get_data.py` – utility to fetch and preprocess datasets
  - `generate_data.py` – utility to generate and save the KernelSynth dataset (slow to generate)
- `configs/` – model and training configurations
- `notebooks/` – inference: how to load a trained model and generate forecasts
- `training.py` – training script using PyTorch Lightning
Acknowledgements
We thank the authors of the following repositories for inspiration and code snippets: