Snxt1 commited on
Commit
6e031de
Β·
verified Β·
1 Parent(s): 2742aad

Update app.py

Browse files

Cleaner dashboard with trend line and moving average function

Files changed (1) hide show
  1. app.py +71 -19
app.py CHANGED
@@ -14,10 +14,10 @@ model = load_model("NX-AI/TiRex")
14
 
15
  def load_columns(file):
16
  if file is None:
17
- return (gr.Dropdown(choices=[], label="Select Time Column:", interactive=True),
18
- gr.Dropdown(choices=[], label="Select Column to Forecast:", interactive=True),
19
- gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
20
- gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
21
 
22
  try:
23
  # Handle file as path string (Gradio convention)
@@ -45,14 +45,14 @@ def load_columns(file):
45
  time_dropdown = gr.Dropdown(
46
  choices=time_choices,
47
  value=time_value,
48
- label="Select Time Column:",
49
  interactive=True
50
  )
51
 
52
  value_dropdown = gr.Dropdown(
53
  choices=value_choices,
54
  value=value_value,
55
- label="Select Column to Forecast:",
56
  interactive=True
57
  ) if value_choices else gr.Dropdown(
58
  choices=[],
@@ -63,12 +63,12 @@ def load_columns(file):
63
 
64
  start_slider = gr.Slider(
65
  minimum=1, maximum=n_rows, value=1, step=1,
66
- label="Historical Start Index (1-based)"
67
  )
68
 
69
  end_slider = gr.Slider(
70
  minimum=1, maximum=n_rows, value=n_rows, step=1,
71
- label="Historical End Index (1-based)"
72
  )
73
 
74
  return time_dropdown, value_dropdown, start_slider, end_slider
@@ -84,10 +84,13 @@ def load_columns(file):
84
  value=None,
85
  label=f"Error loading CSV: {str(e)}",
86
  interactive=False
87
- ), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
88
- gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
89
 
90
- def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence):
 
 
 
91
  if file is None or time_col is None or selected_col is None:
92
  return None, "### Error\nPlease upload a CSV and select time and value columns!"
93
 
@@ -237,6 +240,20 @@ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_le
237
  ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
238
  if not held_out_df.empty:
239
  ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
241
 
242
  # Fan chart: non-overlapping bands
@@ -288,6 +305,20 @@ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_le
288
 
289
  # Create the Gradio interface
290
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  gr.Markdown("""
292
  # TiRex Forecaster Dashboard
293
  Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
@@ -309,23 +340,23 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
309
  )
310
  column_dropdown = gr.Dropdown(
311
  choices=[],
312
- label="Select Column to Forecast",
313
  interactive=True,
314
  elem_id="column_select"
315
  )
316
  start_slider = gr.Slider(
317
  minimum=1, maximum=1, value=1, step=1,
318
- label="Historical Start Index (1-based)",
319
  elem_id="start_idx"
320
  )
321
  end_slider = gr.Slider(
322
  minimum=1, maximum=1, value=1, step=1,
323
- label="Historical End Index (1-based)",
324
  elem_id="end_idx"
325
  )
326
  prediction_length = gr.Slider(
327
- minimum=1, maximum=720, value=12, step=1,
328
- label="Prediction Length (Periods)",
329
  elem_id="pred_length"
330
  )
331
  confidence = gr.Slider(
@@ -333,8 +364,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
333
  label="Confidence Level (%)",
334
  elem_id="confidence"
335
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  run_button = gr.Button(
337
- "Run TiRex Forecast",
338
  variant="primary",
339
  size="lg",
340
  elem_id="run_btn"
@@ -350,7 +395,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
350
  elem_id="output"
351
  )
352
 
353
- gr.Markdown("**Built by** [next one gmbh](https://www.nextone.at)")
354
 
355
  # Event for updating dropdowns on file upload
356
  csv_file.change(
@@ -359,10 +404,17 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
359
  outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
360
  )
361
 
 
 
 
 
 
 
 
362
  # Event for running forecast
363
  run_button.click(
364
  run_forecast,
365
- inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence],
366
  outputs=[forecast_plot, output_text]
367
  )
368
 
 
14
 
15
  def load_columns(file):
16
  if file is None:
17
+ return (gr.Dropdown(choices=[], label="Select Time Column", interactive=True),
18
+ gr.Dropdown(choices=[], label="Select Value Column", interactive=True),
19
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
20
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"))
21
 
22
  try:
23
  # Handle file as path string (Gradio convention)
 
45
  time_dropdown = gr.Dropdown(
46
  choices=time_choices,
47
  value=time_value,
48
+ label="Select Time Column",
49
  interactive=True
50
  )
51
 
52
  value_dropdown = gr.Dropdown(
53
  choices=value_choices,
54
  value=value_value,
55
+ label="Select Value Column",
56
  interactive=True
57
  ) if value_choices else gr.Dropdown(
58
  choices=[],
 
63
 
64
  start_slider = gr.Slider(
65
  minimum=1, maximum=n_rows, value=1, step=1,
66
+ label="Historical Start Index"
67
  )
68
 
69
  end_slider = gr.Slider(
70
  minimum=1, maximum=n_rows, value=n_rows, step=1,
71
+ label="Historical End Index"
72
  )
73
 
74
  return time_dropdown, value_dropdown, start_slider, end_slider
 
84
  value=None,
85
  label=f"Error loading CSV: {str(e)}",
86
  interactive=False
87
+ ), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
88
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"))
89
 
90
+ def update_ma_visibility(add_ma):
91
+ return gr.Slider(visible=add_ma)
92
+
93
+ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence, add_trendline, add_moving_average, ma_window):
94
  if file is None or time_col is None or selected_col is None:
95
  return None, "### Error\nPlease upload a CSV and select time and value columns!"
96
 
 
240
  ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
241
  if not held_out_df.empty:
242
  ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
243
+
244
+ if add_trendline:
245
+ x = np.arange(len(context_df))
246
+ y = context_df['sales'].values
247
+ if len(x) > 1:
248
+ coeffs = np.polyfit(x, y, 1)
249
+ trend = np.polyval(coeffs, x)
250
+ ax.plot(context_df.index, trend, label='Trendline', color='black', linestyle='-', linewidth=1.5)
251
+
252
+ if add_moving_average:
253
+ window = int(ma_window)
254
+ ma = context_df['sales'].rolling(window=window, min_periods=1).mean()
255
+ ax.plot(context_df.index, ma, label=f'Moving Average ({window} periods)', color='purple', linewidth=2)
256
+
257
  ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
258
 
259
  # Fan chart: non-overlapping bands
 
305
 
306
  # Create the Gradio interface
307
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
308
+ gr.HTML("""
309
+ <link rel="preconnect" href="https://fonts.googleapis.com">
310
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
311
+ <link href="https://fonts.googleapis.com/css2?family=Inter:[email protected]&display=swap" rel="stylesheet">
312
+ <style>
313
+ :root {
314
+ --font-family: Inter, ui-sans-serif, system-ui, sans-serif;
315
+ }
316
+ .gradio-container * {
317
+ font-family: var(--font-family) !important;
318
+ }
319
+ </style>
320
+ """)
321
+
322
  gr.Markdown("""
323
  # TiRex Forecaster Dashboard
324
  Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
 
340
  )
341
  column_dropdown = gr.Dropdown(
342
  choices=[],
343
+ label="Select Value Column",
344
  interactive=True,
345
  elem_id="column_select"
346
  )
347
  start_slider = gr.Slider(
348
  minimum=1, maximum=1, value=1, step=1,
349
+ label="Historical Start Index",
350
  elem_id="start_idx"
351
  )
352
  end_slider = gr.Slider(
353
  minimum=1, maximum=1, value=1, step=1,
354
+ label="Historical End Index",
355
  elem_id="end_idx"
356
  )
357
  prediction_length = gr.Slider(
358
+ minimum=1, maximum=720, value=100, step=1,
359
+ label="Prediction Length",
360
  elem_id="pred_length"
361
  )
362
  confidence = gr.Slider(
 
364
  label="Confidence Level (%)",
365
  elem_id="confidence"
366
  )
367
+ trend_checkbox = gr.Checkbox(
368
+ label="Add Trendline",
369
+ value=False
370
+ )
371
+ ma_checkbox = gr.Checkbox(
372
+ label="Add Moving Average",
373
+ value=False
374
+ )
375
+ ma_slider = gr.Slider(
376
+ minimum=3, maximum=30, value=7, step=1,
377
+ label="Moving Average Window (Periods)",
378
+ elem_id="ma_window",
379
+ visible=False
380
+ )
381
  run_button = gr.Button(
382
+ "Run forecast",
383
  variant="primary",
384
  size="lg",
385
  elem_id="run_btn"
 
395
  elem_id="output"
396
  )
397
 
398
+ gr.Markdown("**Built by** [next one gmbh](https://nextone.at/?utm_source=dashboard&utm_medium=referrer&utm_campaign=tirex)")
399
 
400
  # Event for updating dropdowns on file upload
401
  csv_file.change(
 
404
  outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
405
  )
406
 
407
+ # Event for updating MA slider visibility
408
+ ma_checkbox.change(
409
+ update_ma_visibility,
410
+ inputs=[ma_checkbox],
411
+ outputs=[ma_slider]
412
+ )
413
+
414
  # Event for running forecast
415
  run_button.click(
416
  run_forecast,
417
+ inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence, trend_checkbox, ma_checkbox, ma_slider],
418
  outputs=[forecast_plot, output_text]
419
  )
420