Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import gradio as gr | |
| from typing import Dict, List | |
| from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting | |
| from src.logic.graph_settings import Grouping | |
| from src.logic.utils import set_alpha | |
| from datatrove.utils.stats import MetricStatsDict | |
def plot_scatter(
    data: Dict[str, Dict[float, float]],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Build a line plot with one trace per run.

    Args:
        data: Mapping of run name to that run's histogram.
        metric_name: Used in the figure title and as the x-axis title.
        log_scale_x: Request a logarithmic x axis. Only applied when the
            last plotted run has more than one x value.
        log_scale_y: Request a logarithmic y axis. Only applied when the
            last plotted run has more than one y value.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``;
            also selects the y-axis title ("Frequency" vs. "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Replace each run's y values with their running sum.
        perc: Multiply each run's y values by 100.
        progress: Gradio progress tracker wrapping the per-run loop.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly

    # Defined up front so the axis-type checks after the loop do not hit an
    # unbound name when ``data`` is empty and the loop body never runs.
    x = []
    y = []

    # Iterate runs in name order so trace order and the index-based color
    # assignment do not depend on the caller's insertion order.
    data = dict(sorted(data.items()))
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                # Cycle through the palette when there are more runs than colors.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"

    # NOTE: ``x`` / ``y`` here are those of the last plotted run (or empty
    # lists when there were no runs), matching how plot_bars decides.
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_bars(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    show_stds: bool,
    progress: gr.Progress,
):
    """Build a bar chart with one bar series per run.

    Each run's stats are passed to ``prepare_for_group_plotting`` together
    with ``top_k``, ``direction``, ``regex`` and ``rounding``; the three
    sequences it returns are used as the bar x values, bar heights and
    error-bar sizes respectively.

    Args:
        data: Mapping of run name to that run's grouped stats.
        metric_name: Used in the figure title and as the x-axis title.
        top_k, direction, regex, rounding: Forwarded unchanged to
            ``prepare_for_group_plotting``.
        log_scale_x: Request a logarithmic x axis. Only applied when the
            last plotted run has more than one x value.
        log_scale_y: Request a logarithmic y axis. Only applied when the
            last plotted run has more than one y value.
        show_stds: Whether the error bars are visible.
        progress: Gradio progress tracker wrapping the per-run loop.

    Returns:
        The assembled ``go.Figure``.
    """
    figure = go.Figure()
    palette = px.colors.qualitative.Plotly

    # Kept outside the loop: the axis-type decision below reads the values
    # from the last run, and must still work when there are no runs at all.
    labels, means = [], []

    tracked_runs = progress.tqdm(data.items(), total=len(data), desc="Plotting...")
    for index, (run_name, stats) in enumerate(tracked_runs):
        labels, means, deviations = prepare_for_group_plotting(stats, top_k, direction, regex, rounding)
        # Cycle through the palette when there are more runs than colors.
        bar_color = set_alpha(palette[index % len(palette)], 0.5)
        figure.add_trace(
            go.Bar(
                x=labels,
                y=means,
                name=f"{run_name} Mean",
                marker={"color": bar_color},
                error_y={"type": "data", "array": deviations, "visible": show_stds},
            )
        )

    x_axis_kind = "log" if log_scale_x and len(labels) > 1 else None
    y_axis_kind = "log" if log_scale_y and len(means) > 1 else None

    figure.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type=x_axis_kind,
        yaxis_type=y_axis_kind,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return figure
| # Add any other necessary functions | |
def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Dispatch to the right plot for ``grouping`` and build the UI outputs.

    When ``grouping == "histogram"`` a line plot is produced via
    ``plot_scatter`` and a markdown min/max table is generated for the
    runs; for any other grouping a bar chart is produced via ``plot_bars``
    and the table string is empty.

    Args:
        metric_data: Mapping of run name to that run's stats.
        metric_name: Display name of the metric, forwarded to the plotter.
        normalize, rounding, cdf, perc: Forwarded to ``plot_scatter``
            (histogram grouping only, except ``rounding`` which both
            plotters receive).
        grouping: Selects the plot type; ``"histogram"`` means line plot.
        top_n, direction, group_regex, show_stds: Forwarded to
            ``plot_bars`` (non-histogram groupings only).
        log_scale_x, log_scale_y: Forwarded to whichever plotter is used.

    Returns:
        A 3-tuple of (figure, update making the target row visible,
        markdown min/max table or ``""``).
    """
    if grouping == "histogram":
        fig = plot_scatter(
            metric_data,
            metric_name,
            log_scale_x,
            log_scale_y,
            normalize,
            rounding,
            cdf,
            perc,
            gr.Progress(),
        )
        min_max_hist_data = generate_min_max_hist_data(metric_data)
        # gr.update(...) rather than gr.Row.update(...): the per-component
        # ``.update()`` classmethods were removed in Gradio 4, whereas
        # ``gr.update`` is available in both 3.x and 4.x.
        return fig, gr.update(visible=True), min_max_hist_data

    fig = plot_bars(
        metric_data,
        metric_name,
        top_n,
        direction,
        group_regex,
        rounding,
        log_scale_x,
        log_scale_y,
        show_stds,
        gr.Progress(),
    )
    # Same version-independent visibility update as in the histogram branch.
    return fig, gr.update(visible=True), ""
def generate_min_max_hist_data(data: Dict[str, MetricStatsDict]) -> str:
    """Render a markdown table with the smallest and largest key of each run.

    Every key of a run's histogram is converted with ``float`` and the
    minimum and maximum are printed with four decimals. A run whose
    histogram has no keys gets ``N/A`` in both columns instead of raising
    ``ValueError`` from ``min()`` / ``max()`` on an empty sequence.

    Args:
        data: Mapping of run name to that run's histogram; only the
            histogram's ``keys()`` are read.

    Returns:
        The markdown table as a string: a two-line header followed by one
        row per run, in the iteration order of ``data``. With no runs, only
        the header is returned.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"

    rows = []
    for run, histogram in data.items():
        # Convert once and reuse for both min and max. ``.keys()`` is called
        # explicitly because the histogram container is declared elsewhere
        # and may not support direct iteration.
        bucket_values = [float(key) for key in histogram.keys()]
        if bucket_values:
            rows.append(f"| {run} | {min(bucket_values):.4f} | {max(bucket_values):.4f} |")
        else:
            # Empty histogram: keep the run visible rather than failing the
            # whole table.
            rows.append(f"| {run} | N/A | N/A |")

    return header + "\n".join(rows)