AshenH committed
Commit 94db2b6 · verified · 1 Parent(s): 113176c

Update tools/ts_forecast_tool.py

Files changed (1)
  1. tools/ts_forecast_tool.py +105 -20
tools/ts_forecast_tool.py CHANGED
@@ -1,30 +1,115 @@
 # space/tools/ts_forecast_tool.py
+import os
+from typing import Optional, Dict
+
 import torch
 import pandas as pd
-from transformers import AutoModelForTimeSeriesForecasting, AutoTokenizer
+
+from utils.tracing import Tracer
+from utils.config import AppConfig
+
+# We avoid unavailable task-specific heads.
+# Use a generic AutoModel and attempt capability-based calls.
+from transformers import AutoModel, AutoConfig
+

 class TimeseriesForecastTool:
     """
-    Lightweight wrapper around ibm-granite/granite-timeseries-ttm-r1
-    using the Transformers interface.
+    Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1' for zero-shot forecasting.
+
+    This wrapper:
+      - loads the model with `AutoModel.from_pretrained`
+      - checks for a `.predict(...)` method first
+      - else tries calling the model with `prediction_length=horizon`
+      - returns a Pandas DataFrame with a single 'forecast' column
+
+    Expected input:
+      - series: pd.Series with a DatetimeIndex (regular frequency recommended)
+      - horizon: int, number of future steps
+
+    NOTE:
+      Different library versions expose different APIs. If your environment/model
+      lacks a compatible inference method, we raise a clear RuntimeError with
+      guidance rather than failing at import time.
     """
-    def __init__(self,
-                 model_id="ibm-granite/granite-timeseries-ttm-r1",
-                 device=None):
+
+    def __init__(
+        self,
+        cfg: Optional[AppConfig],
+        tracer: Optional[Tracer],
+        model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
+        device: Optional[str] = None,
+    ):
+        self.cfg = cfg
+        self.tracer = tracer
+        self.model_id = model_id
+
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = AutoModelForTimeSeriesForecasting.from_pretrained(model_id).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96):
-        """
-        series: pd.Series indexed by datetime
-        horizon: forecast steps
-        """
-        values = series.values.astype("float32")
-        inputs = torch.tensor(values, dtype=torch.float32).unsqueeze(0).to(self.device)
+        # Load config + model generically
+        self.config = AutoConfig.from_pretrained(self.model_id)
+        self.model = AutoModel.from_pretrained(self.model_id)
+        self.model.to(self.device)
+        self.model.eval()
+
+    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
+        if not isinstance(series, pd.Series):
+            raise ValueError("series must be a pandas Series")
+        if series.empty:
+            return pd.DataFrame(columns=["forecast"])
+
+        # Ensure numeric tensor
+        values = series.astype("float32").to_numpy()
+        x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)
+
         with torch.no_grad():
-            preds = self.model(inputs, prediction_length=horizon).predictions
-        return pd.DataFrame(
-            preds.squeeze().cpu().numpy(),
-            columns=["forecast"]
+            # 1) Preferred: explicit .predict API
+            if hasattr(self.model, "predict"):
+                try:
+                    preds = self.model.predict(x, prediction_length=horizon)
+                    yhat = preds if isinstance(preds, torch.Tensor) else torch.tensor(preds)
+                    out = yhat.squeeze().detach().cpu().numpy()
+                    return pd.DataFrame({"forecast": out})
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Granite model has a 'predict' method but it failed at runtime: {e}"
+                    )
+
+            # 2) Fallback: call forward with a 'prediction_length' kwarg if supported
+            try:
+                outputs = self.model(x, prediction_length=horizon)
+                # Try common attribute names
+                for k in ("predictions", "prediction", "logits", "output"):
+                    if hasattr(outputs, k):
+                        tensor = getattr(outputs, k)
+                        if isinstance(tensor, (tuple, list)):
+                            tensor = tensor[0]
+                        if not isinstance(tensor, torch.Tensor):
+                            tensor = torch.tensor(tensor)
+                        out = tensor.squeeze().detach().cpu().numpy()
+                        # If multi-dim, take last dimension as forecast
+                        if out.ndim > 1:
+                            out = out[-1] if out.shape[0] == horizon else out.reshape(-1)
+                        return pd.DataFrame({"forecast": out})
+                # If outputs is a raw tensor
+                if isinstance(outputs, torch.Tensor):
+                    out = outputs.squeeze().detach().cpu().numpy()
+                    if out.ndim > 1:
+                        out = out[-1] if out.shape[0] == horizon else out.reshape(-1)
+                    return pd.DataFrame({"forecast": out})
+            except TypeError:
+                # Some builds may not accept prediction_length at all
+                pass
+            except Exception as e:
+                raise RuntimeError(
+                    f"Calling the model forward for forecasting failed: {e}"
+                )
+
+        # If we get here, the installed combo doesn't expose an inference entrypoint we can use.
+        raise RuntimeError(
+            "The installed transformers/model combo does not expose a usable zero-shot "
+            "forecasting interface (no `.predict` and forward(...) didn't accept "
+            "`prediction_length`). Consider:\n"
+            "  • Upgrading transformers/torch versions\n"
+            "  • Using the 'granite-tsfm-public' PyPI if available in your region\n"
+            "  • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost)\n"
         )
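
For reviewers, a minimal usage sketch of the rewritten tool follows (illustration only, not part of this commit). It assumes the class is imported from `tools.ts_forecast_tool`, that passing `None` for the optional `cfg`/`tracer` arguments is acceptable, and it uses a toy sine series plus a naive last-value fallback to exercise the RuntimeError path described in the docstring:

# Hypothetical usage sketch; import path, sample data, and fallback are assumptions.
import numpy as np
import pandas as pd

from tools.ts_forecast_tool import TimeseriesForecastTool

# Toy hourly series; in the Space this would come from the user's dataset.
idx = pd.date_range("2024-01-01", periods=512, freq="h")
series = pd.Series(np.sin(np.arange(512) / 24.0), index=idx)

# cfg and tracer are Optional in the new signature, so None is fine for a smoke test.
tool = TimeseriesForecastTool(cfg=None, tracer=None)

try:
    forecast_df = tool.zeroshot_forecast(series, horizon=96)
except RuntimeError as err:
    # The wrapper raises RuntimeError when no usable inference entrypoint exists;
    # fall back to a trivial last-value forecast so the pipeline keeps running.
    print(f"Granite zero-shot path unavailable: {err}")
    forecast_df = pd.DataFrame({"forecast": np.repeat(series.iloc[-1], 96)})

print(forecast_df.head())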