Spaces:
Running
Running
Update app.py
Browse filesCleaner dashboard with trend line and moving average function
app.py
CHANGED
|
@@ -14,10 +14,10 @@ model = load_model("NX-AI/TiRex")
|
|
| 14 |
|
| 15 |
def load_columns(file):
|
| 16 |
if file is None:
|
| 17 |
-
return (gr.Dropdown(choices=[], label="Select Time Column
|
| 18 |
-
gr.Dropdown(choices=[], label="Select Column
|
| 19 |
-
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index
|
| 20 |
-
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index
|
| 21 |
|
| 22 |
try:
|
| 23 |
# Handle file as path string (Gradio convention)
|
|
@@ -45,14 +45,14 @@ def load_columns(file):
|
|
| 45 |
time_dropdown = gr.Dropdown(
|
| 46 |
choices=time_choices,
|
| 47 |
value=time_value,
|
| 48 |
-
label="Select Time Column
|
| 49 |
interactive=True
|
| 50 |
)
|
| 51 |
|
| 52 |
value_dropdown = gr.Dropdown(
|
| 53 |
choices=value_choices,
|
| 54 |
value=value_value,
|
| 55 |
-
label="Select Column
|
| 56 |
interactive=True
|
| 57 |
) if value_choices else gr.Dropdown(
|
| 58 |
choices=[],
|
|
@@ -63,12 +63,12 @@ def load_columns(file):
|
|
| 63 |
|
| 64 |
start_slider = gr.Slider(
|
| 65 |
minimum=1, maximum=n_rows, value=1, step=1,
|
| 66 |
-
label="Historical Start Index
|
| 67 |
)
|
| 68 |
|
| 69 |
end_slider = gr.Slider(
|
| 70 |
minimum=1, maximum=n_rows, value=n_rows, step=1,
|
| 71 |
-
label="Historical End Index
|
| 72 |
)
|
| 73 |
|
| 74 |
return time_dropdown, value_dropdown, start_slider, end_slider
|
|
@@ -84,10 +84,13 @@ def load_columns(file):
|
|
| 84 |
value=None,
|
| 85 |
label=f"Error loading CSV: {str(e)}",
|
| 86 |
interactive=False
|
| 87 |
-
), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index
|
| 88 |
-
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index
|
| 89 |
|
| 90 |
-
def
|
|
|
|
|
|
|
|
|
|
| 91 |
if file is None or time_col is None or selected_col is None:
|
| 92 |
return None, "### Error\nPlease upload a CSV and select time and value columns!"
|
| 93 |
|
|
@@ -237,6 +240,20 @@ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_le
|
|
| 237 |
ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
|
| 238 |
if not held_out_df.empty:
|
| 239 |
ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
|
| 241 |
|
| 242 |
# Fan chart: non-overlapping bands
|
|
@@ -288,6 +305,20 @@ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_le
|
|
| 288 |
|
| 289 |
# Create the Gradio interface
|
| 290 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
gr.Markdown("""
|
| 292 |
# TiRex Forecaster Dashboard
|
| 293 |
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
|
|
@@ -309,23 +340,23 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
|
|
| 309 |
)
|
| 310 |
column_dropdown = gr.Dropdown(
|
| 311 |
choices=[],
|
| 312 |
-
label="Select Column
|
| 313 |
interactive=True,
|
| 314 |
elem_id="column_select"
|
| 315 |
)
|
| 316 |
start_slider = gr.Slider(
|
| 317 |
minimum=1, maximum=1, value=1, step=1,
|
| 318 |
-
label="Historical Start Index
|
| 319 |
elem_id="start_idx"
|
| 320 |
)
|
| 321 |
end_slider = gr.Slider(
|
| 322 |
minimum=1, maximum=1, value=1, step=1,
|
| 323 |
-
label="Historical End Index
|
| 324 |
elem_id="end_idx"
|
| 325 |
)
|
| 326 |
prediction_length = gr.Slider(
|
| 327 |
-
minimum=1, maximum=720, value=
|
| 328 |
-
label="Prediction Length
|
| 329 |
elem_id="pred_length"
|
| 330 |
)
|
| 331 |
confidence = gr.Slider(
|
|
@@ -333,8 +364,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
|
|
| 333 |
label="Confidence Level (%)",
|
| 334 |
elem_id="confidence"
|
| 335 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
run_button = gr.Button(
|
| 337 |
-
"Run
|
| 338 |
variant="primary",
|
| 339 |
size="lg",
|
| 340 |
elem_id="run_btn"
|
|
@@ -350,7 +395,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
|
|
| 350 |
elem_id="output"
|
| 351 |
)
|
| 352 |
|
| 353 |
-
gr.Markdown("**Built by** [next one gmbh](https://
|
| 354 |
|
| 355 |
# Event for updating dropdowns on file upload
|
| 356 |
csv_file.change(
|
|
@@ -359,10 +404,17 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), ti
|
|
| 359 |
outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
|
| 360 |
)
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
# Event for running forecast
|
| 363 |
run_button.click(
|
| 364 |
run_forecast,
|
| 365 |
-
inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence],
|
| 366 |
outputs=[forecast_plot, output_text]
|
| 367 |
)
|
| 368 |
|
|
|
|
| 14 |
|
| 15 |
def load_columns(file):
|
| 16 |
if file is None:
|
| 17 |
+
return (gr.Dropdown(choices=[], label="Select Time Column", interactive=True),
|
| 18 |
+
gr.Dropdown(choices=[], label="Select Value Column", interactive=True),
|
| 19 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
|
| 20 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"))
|
| 21 |
|
| 22 |
try:
|
| 23 |
# Handle file as path string (Gradio convention)
|
|
|
|
| 45 |
time_dropdown = gr.Dropdown(
|
| 46 |
choices=time_choices,
|
| 47 |
value=time_value,
|
| 48 |
+
label="Select Time Column",
|
| 49 |
interactive=True
|
| 50 |
)
|
| 51 |
|
| 52 |
value_dropdown = gr.Dropdown(
|
| 53 |
choices=value_choices,
|
| 54 |
value=value_value,
|
| 55 |
+
label="Select Value Column",
|
| 56 |
interactive=True
|
| 57 |
) if value_choices else gr.Dropdown(
|
| 58 |
choices=[],
|
|
|
|
| 63 |
|
| 64 |
start_slider = gr.Slider(
|
| 65 |
minimum=1, maximum=n_rows, value=1, step=1,
|
| 66 |
+
label="Historical Start Index"
|
| 67 |
)
|
| 68 |
|
| 69 |
end_slider = gr.Slider(
|
| 70 |
minimum=1, maximum=n_rows, value=n_rows, step=1,
|
| 71 |
+
label="Historical End Index"
|
| 72 |
)
|
| 73 |
|
| 74 |
return time_dropdown, value_dropdown, start_slider, end_slider
|
|
|
|
| 84 |
value=None,
|
| 85 |
label=f"Error loading CSV: {str(e)}",
|
| 86 |
interactive=False
|
| 87 |
+
), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
|
| 88 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"))
|
| 89 |
|
| 90 |
+
def update_ma_visibility(add_ma):
|
| 91 |
+
return gr.Slider(visible=add_ma)
|
| 92 |
+
|
| 93 |
+
def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence, add_trendline, add_moving_average, ma_window):
|
| 94 |
if file is None or time_col is None or selected_col is None:
|
| 95 |
return None, "### Error\nPlease upload a CSV and select time and value columns!"
|
| 96 |
|
|
|
|
| 240 |
ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
|
| 241 |
if not held_out_df.empty:
|
| 242 |
ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
|
| 243 |
+
|
| 244 |
+
if add_trendline:
|
| 245 |
+
x = np.arange(len(context_df))
|
| 246 |
+
y = context_df['sales'].values
|
| 247 |
+
if len(x) > 1:
|
| 248 |
+
coeffs = np.polyfit(x, y, 1)
|
| 249 |
+
trend = np.polyval(coeffs, x)
|
| 250 |
+
ax.plot(context_df.index, trend, label='Trendline', color='black', linestyle='-', linewidth=1.5)
|
| 251 |
+
|
| 252 |
+
if add_moving_average:
|
| 253 |
+
window = int(ma_window)
|
| 254 |
+
ma = context_df['sales'].rolling(window=window, min_periods=1).mean()
|
| 255 |
+
ax.plot(context_df.index, ma, label=f'Moving Average ({window} periods)', color='purple', linewidth=2)
|
| 256 |
+
|
| 257 |
ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
|
| 258 |
|
| 259 |
# Fan chart: non-overlapping bands
|
|
|
|
| 305 |
|
| 306 |
# Create the Gradio interface
|
| 307 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
|
| 308 |
+
gr.HTML("""
|
| 309 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 310 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 311 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:[email protected]&display=swap" rel="stylesheet">
|
| 312 |
+
<style>
|
| 313 |
+
:root {
|
| 314 |
+
--font-family: Inter, ui-sans-serif, system-ui, sans-serif;
|
| 315 |
+
}
|
| 316 |
+
.gradio-container * {
|
| 317 |
+
font-family: var(--font-family) !important;
|
| 318 |
+
}
|
| 319 |
+
</style>
|
| 320 |
+
""")
|
| 321 |
+
|
| 322 |
gr.Markdown("""
|
| 323 |
# TiRex Forecaster Dashboard
|
| 324 |
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
|
|
|
|
| 340 |
)
|
| 341 |
column_dropdown = gr.Dropdown(
|
| 342 |
choices=[],
|
| 343 |
+
label="Select Value Column",
|
| 344 |
interactive=True,
|
| 345 |
elem_id="column_select"
|
| 346 |
)
|
| 347 |
start_slider = gr.Slider(
|
| 348 |
minimum=1, maximum=1, value=1, step=1,
|
| 349 |
+
label="Historical Start Index",
|
| 350 |
elem_id="start_idx"
|
| 351 |
)
|
| 352 |
end_slider = gr.Slider(
|
| 353 |
minimum=1, maximum=1, value=1, step=1,
|
| 354 |
+
label="Historical End Index",
|
| 355 |
elem_id="end_idx"
|
| 356 |
)
|
| 357 |
prediction_length = gr.Slider(
|
| 358 |
+
minimum=1, maximum=720, value=100, step=1,
|
| 359 |
+
label="Prediction Length",
|
| 360 |
elem_id="pred_length"
|
| 361 |
)
|
| 362 |
confidence = gr.Slider(
|
|
|
|
| 364 |
label="Confidence Level (%)",
|
| 365 |
elem_id="confidence"
|
| 366 |
)
|
| 367 |
+
trend_checkbox = gr.Checkbox(
|
| 368 |
+
label="Add Trendline",
|
| 369 |
+
value=False
|
| 370 |
+
)
|
| 371 |
+
ma_checkbox = gr.Checkbox(
|
| 372 |
+
label="Add Moving Average",
|
| 373 |
+
value=False
|
| 374 |
+
)
|
| 375 |
+
ma_slider = gr.Slider(
|
| 376 |
+
minimum=3, maximum=30, value=7, step=1,
|
| 377 |
+
label="Moving Average Window (Periods)",
|
| 378 |
+
elem_id="ma_window",
|
| 379 |
+
visible=False
|
| 380 |
+
)
|
| 381 |
run_button = gr.Button(
|
| 382 |
+
"Run forecast",
|
| 383 |
variant="primary",
|
| 384 |
size="lg",
|
| 385 |
elem_id="run_btn"
|
|
|
|
| 395 |
elem_id="output"
|
| 396 |
)
|
| 397 |
|
| 398 |
+
gr.Markdown("**Built by** [next one gmbh](https://nextone.at/?utm_source=dashboard&utm_medium=referrer&utm_campaign=tirex)")
|
| 399 |
|
| 400 |
# Event for updating dropdowns on file upload
|
| 401 |
csv_file.change(
|
|
|
|
| 404 |
outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
|
| 405 |
)
|
| 406 |
|
| 407 |
+
# Event for updating MA slider visibility
|
| 408 |
+
ma_checkbox.change(
|
| 409 |
+
update_ma_visibility,
|
| 410 |
+
inputs=[ma_checkbox],
|
| 411 |
+
outputs=[ma_slider]
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
# Event for running forecast
|
| 415 |
run_button.click(
|
| 416 |
run_forecast,
|
| 417 |
+
inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence, trend_checkbox, ma_checkbox, ma_slider],
|
| 418 |
outputs=[forecast_plot, output_text]
|
| 419 |
)
|
| 420 |
|