"""Gradio demo of ifBO (In-context Freeze-Thaw Bayesian Optimization).

Users sample synthetic learning curves from the FT-PFN curve prior, then step
through ifBO: each step extends one curve by a single epoch and the plot shows
the surrogate's 5%-95% predictive band for the unobserved part of every curve.
"""
import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
import torch
from ifbo import FTPFN, Curve
from ifbo.priors.ftpfn_prior import sample_curves

# Global state shared by the Gradio callbacks below.
# NOTE(review): this state is module-level, so every browser session shares one run.
observed_points = []  # (epoch, curve_id) pairs, in observation order
generated_curves = []  # becomes a tensor with one row per sampled curve
generated_configurations = []  # becomes a tensor with one hyperparameter row per curve
history = {}  # curve_id -> number of leading epochs observed so far
rank = []  # drawing order of curve ids; the most recently observed curve is drawn last
surrogate_model = FTPFN()


# Function to generate curves
def generate_curves(num_curves, max_length):
    """Sample synthetic curves from the curve prior, plot them, and reset the optimizer.

    Args:
        num_curves: Number of curves (hyperparameter configurations) to sample.
            Cast to ``int`` because ``gr.Number`` may deliver a float.
        max_length: Number of steps per curve; cast to ``int`` for the same reason.

    Returns:
        ``(figure, None, text)``: the plot of all sampled curves, ``None`` to clear
        the optimization plot, and a text listing each curve's hyperparameters.
    """
    global generated_curves, generated_configurations
    num_curves = int(num_curves)
    max_length = int(max_length)
    configurations, curves = sample_curves(
        num_hyperparameters=num_curves,
        # Three dimensions so each configuration can double as an RGB color.
        hyperparameter_dimensions=3,
        curve_length=max_length,
    )

    fig, ax = plt.subplots()
    display_lines = ["Hyperparameter values (corresponding to RGB colors):"]
    for i, (x, curve) in enumerate(zip(configurations, curves)):
        color = x.tolist()  # RGB value
        ax.plot(np.arange(1, len(curve) + 1), curve, color=color, label=f'Curve {i+1}')
        display_lines.append(f"Curve {i+1}: {np.round(color, 2)}")  # Display RGB
    ax.set_xlabel("t")
    ax.set_ylabel("y")
    ax.set_ylim(0, 1)
    ax.legend()
    # Detach from pyplot's figure registry so repeated clicks don't leak figures;
    # the Figure object itself stays renderable for Gradio.
    plt.close(fig)

    generated_curves = torch.FloatTensor(curves)
    generated_configurations = torch.FloatTensor(configurations)
    # Reset only after the new curves are stored, so history/rank are sized for them
    # (resetting first would size them for the previous set of curves).
    reset_optimization()

    hyperparam_display = "\n".join(display_lines) + "\n"
    return fig, None, hyperparam_display


def _build_context(history, curves, configurations):
    """Return one ``Curve`` holding the observed prefix of every curve with >= 1 observed epoch."""
    max_length = len(curves[0])
    context = []
    for curve_id in range(len(curves)):
        epoch = history.get(curve_id, 0)
        if epoch > 0:
            context.append(Curve(
                hyperparameters=configurations[curve_id],
                # Steps are normalized to (0, 1] by the curve length.
                t=torch.arange(1, epoch + 1) / max_length,
                y=curves[curve_id][:epoch],
            ))
    return context


# Function to predict the next observed point
def next_observed_point(history, curves, configurations):
    """Choose the next curve to extend (MFPI-random acquisition, per the UI text).

    Samples a random improvement threshold above the best value seen so far and
    a random horizon. It then queries each tracked curve ``horizon`` steps past
    its last observation (capped at the curve length), and picks the curve whose
    prediction scores highest under ``pred.pi(threshold)``.
    NOTE(review): ``pi`` is presumably the probability of exceeding the
    threshold; confirm against the ifbo API.

    Args:
        history: Mapping ``curve_id -> epochs observed``.
        curves: Indexable of curves; all curves have the same length.
        configurations: Hyperparameter vector per curve.

    Returns:
        The ``int`` id of the selected curve.
    """
    best_sofar = max(
        [max(curves[curve_id][:epoch], default=0) for curve_id, epoch in history.items()],
        default=0,
    )
    # Random threshold between the incumbent and 1.
    best_sofar = best_sofar + np.exp(np.random.uniform(-4, -1)) * (1 - best_sofar)
    max_length = len(curves[0])
    # randint's upper bound is exclusive and must exceed 1, so curves of length 1
    # still get horizon 1 instead of raising.
    horizon = np.random.randint(1, max(max_length, 2))
    target = {k: min(max_length, v + horizon) for k, v in history.items()}

    context = _build_context(history, curves, configurations)
    query_ids = [curve_id for curve_id in range(len(curves)) if curve_id in history]
    query = [
        Curve(
            hyperparameters=configurations[curve_id],
            t=torch.FloatTensor([target[curve_id]]) / max_length,
        )
        for curve_id in query_ids
    ]
    predictions = surrogate_model.predict(context, query)
    pis = [pred.pi(best_sofar).item() for pred in predictions]
    # Map the position in the query list back to the curve id it was built for.
    return query_ids[int(np.argmax(pis))]


# Bayesian optimization update function
def bayesian_update(history, curves, configurations):
    """Predict every tracked curve from its last observed step to the end.

    Returns:
        ``(mean, q05, q95)``: one list per tracked curve (in curve-id order) with
        the predictive mean and the 5% / 95% quantiles at steps
        ``max(observed, 1) .. max_length``.
    """
    max_length = len(curves[0])
    context = _build_context(history, curves, configurations)
    query = [
        Curve(
            hyperparameters=configurations[curve_id],
            t=torch.arange(max(history[curve_id], 1), max_length + 1) / max_length,
        )
        for curve_id in range(len(curves))
        if curve_id in history
    ]
    predictions = surrogate_model.predict(context, query)
    mean, q05, q95 = [], [], []
    for pred in predictions:
        mean.append(pred.criterion.mean(pred.logits).numpy().tolist())
        q05.append(pred.quantile(0.05).numpy().tolist())
        q95.append(pred.quantile(0.95).numpy().tolist())
    return mean, q05, q95


# Function for Bayesian optimization step
def bayesian_optimization_step(num_curves, max_length):
    """Run one ifBO step: observe one more epoch of one curve and re-plot predictions.

    ``num_curves`` / ``max_length`` are accepted for the Gradio wiring, but the
    step uses the curves that were actually generated: the UI fields may have
    been edited since "Generate" was clicked.

    Returns:
        A figure with the true curves (faint), the observed prefixes, and the
        5%-95% predictive band per curve, or ``gr.update()`` (no change) when no
        curves exist yet or some curve is already fully observed.
    """
    global observed_points, generated_curves, generated_configurations, history
    curves = generated_curves
    # Nothing to optimize before curves are generated (also avoids rank.remove on []).
    if len(curves) == 0:
        return gr.update()
    n_steps = len(curves[0])
    # Stop once any curve has been observed to its full length.
    if any(epoch == n_steps for epoch in history.values()):
        return gr.update()

    if not observed_points:
        # No data for the surrogate yet: pick a generated curve uniformly at random.
        next_point = int(np.random.randint(0, len(curves)))
    else:
        next_point = next_observed_point(history, curves, generated_configurations)

    # Move the chosen curve to the end of the drawing order so it is drawn on top.
    rank.remove(next_point)
    rank.append(next_point)
    history[next_point] = history.get(next_point, 0) + 1
    observed_points.append((history[next_point], next_point))

    _, q05, q95 = bayesian_update(history, curves, generated_configurations)

    # Plot the updated curves with uncertainty and observed points
    fig, ax = plt.subplots()
    for curve_id in rank:
        epoch = history[curve_id]
        curve = curves[curve_id]
        color = generated_configurations[curve_id].numpy().tolist()
        # Full (ground-truth) curve with reduced opacity.
        ax.plot(np.arange(1, n_steps + 1), curve, alpha=0.1, color=color)
        # Uncertainty band over the same steps that bayesian_update queried.
        start = max(epoch, 1)
        x = np.arange(start, start + len(q05[curve_id]))
        ax.fill_between(x, q05[curve_id], q95[curve_id], alpha=0.3, color=color)
        # Observed prefix as markers joined by a line, in the curve's own color.
        ax.plot(np.arange(1, epoch + 1), curve[:epoch], marker="o", color=color)
    ax.set_xlim(0, n_steps + 1)
    ax.set_ylim(0, 1)
    ax.set_title(f"Iteration {len(observed_points)}: Observing curve {next_point+1} at step {history[next_point]}")
    ax.set_xlabel("t")
    ax.set_ylabel("y")
    # Detach from pyplot's registry so repeated steps don't accumulate open figures.
    plt.close(fig)
    return fig


# Reset function for Bayesian optimization
def reset_optimization():
    """Forget all observations of the current curves; returns ``None`` to clear the plot."""
    global observed_points, history, rank
    observed_points = []
    n_curves = len(generated_curves)
    history = {i: 0 for i in range(n_curves)}
    # rank must keep containing every curve id, because each step calls rank.remove(id).
    rank = list(range(n_curves))
    return None


# NOTE(review): empty in this copy of the file; presumably once held a <script> snippet.
scroll_script = """ """

# Gradio Interface
with gr.Blocks() as demo:
    gr.HTML(scroll_script)
    # Add a title
    gr.Markdown("<h1 style='text-align: center;'>ifBO: In-context Freeze-Thaw Bayesian Optimization</h1>")
    gr.Markdown("Paper: [https://arxiv.org/pdf/2404.16795](https://arxiv.org/pdf/2404.16795)")
    gr.Markdown("Code: [http://github.com/automl/ifBO](http://github.com/automl/ifBO)")

    # First section for curve generation
    gr.Markdown("## Curve Generation")
    gr.Markdown("""
To generate a set of synthetic curves (according to the proposed curve prior, see section 4.1), please specify the number of curves and the maximum length of each curve, then click 'Generate'.
""")
    with gr.Row():
        with gr.Column(scale=1):
            num_curves = gr.Number(label="Number of curves", value=3)
            max_length = gr.Number(label="Max length of curves", value=10)
            generate_btn = gr.Button("Generate")
        with gr.Column(scale=1):
            hyperparam_text = gr.Textbox(label="")
        with gr.Column(scale=2):
            curve_plot = gr.Plot()

    # Separate section for Bayesian optimization
    gr.Markdown("## Bayesian Optimization")
    gr.Markdown("""
After generating the curves, click 'One ifBO Step' to take an optimization step. During this step, **ifBO** will select the next point to evaluate based on the previously observed points, using the MFPI-random acquisition function (see section 4.2). The plot will then update to show the newly observed point and the updated predictions for each curve.
""")
    with gr.Row():
        with gr.Column():
            next_step_btn = gr.Button("One ifBO Step")
            reset_btn = gr.Button("Reset")
        with gr.Column():
            optimization_plot = gr.Plot()

    gr.HTML("")

    # Link buttons to actions
    generate_btn.click(generate_curves, inputs=[num_curves, max_length], outputs=[curve_plot, optimization_plot, hyperparam_text])
    next_step_btn.click(bayesian_optimization_step, inputs=[num_curves, max_length], outputs=optimization_plot)
    reset_btn.click(reset_optimization, outputs=optimization_plot)

# Launch the demo
demo.launch()