import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
import torch

from ifbo import FTPFN, Curve
from ifbo.priors.ftpfn_prior import sample_curves

# Global variables to store state
observed_points = []
generated_curves = []
generated_configurations = []
history = {}
rank = []
surrogate_model = FTPFN()


# Function to generate curves
def generate_curves(num_curves, max_length):
    global generated_curves, generated_configurations, history, rank
    reset_optimization()
    configurations, curves = sample_curves(
        num_hyperparameters=num_curves,
        hyperparameter_dimensions=3,
        curve_length=max_length,
    )

    fig, ax = plt.subplots()
    hyperparam_display = "Hyperparameter values (corresponding to RGB colors):\n"
    for i, (x, curve) in enumerate(zip(configurations, curves)):
        color = x.tolist()  # RGB value
        ax.plot(np.arange(1, len(curve) + 1), curve, color=color, label=f"Curve {i + 1}")
        hyperparam_display += f"Curve {i + 1}: {np.round(color, 2)}\n"  # Display RGB
    plt.xlabel("t")
    plt.ylabel("y")
    plt.ylim(0, 1)
    plt.legend()

    generated_curves = torch.FloatTensor(curves)
    generated_configurations = torch.FloatTensor(configurations)
    rank = list(range(num_curves))
    for i in range(num_curves):
        history[i] = 0
    return fig, None, hyperparam_display


# Function to predict the next observed point
def next_observed_point(history, curves, configurations):
    best_sofar = max(
        max(curves[curve_id][:epoch], default=0) for curve_id, epoch in history.items()
    )
    best_sofar = best_sofar + np.exp(np.random.uniform(-4, -1)) * (1 - best_sofar)
    max_length = len(curves[0])
    horizon = np.random.randint(1, max_length)
    target = {k: min(max_length, v + horizon) for k, v in history.items()}

    context, query = [], []
    for curve_id in range(len(curves)):
        if curve_id in history:
            if history[curve_id] > 0:
                context.append(Curve(
                    hyperparameters=configurations[curve_id],
                    t=torch.arange(1, history[curve_id] + 1) / max_length,
                    y=curves[curve_id][:history[curve_id]],
                ))
            query.append(Curve(
                hyperparameters=configurations[curve_id],
                t=torch.FloatTensor([target[curve_id]]) / max_length,
            ))

    predictions = surrogate_model.predict(context, query)
    pis = []
    for pred in predictions:
        pis.append(pred.pi(best_sofar).item())
    return np.argmax(pis)


# Bayesian optimization update function
def bayesian_update(history, curves, configurations):
    context, query = [], []
    max_length = len(curves[0])
    for curve_id in range(len(curves)):
        if curve_id in history:
            if history[curve_id] > 0:
                context.append(Curve(
                    hyperparameters=configurations[curve_id],
                    t=torch.arange(1, history[curve_id] + 1) / max_length,
                    y=curves[curve_id][:history[curve_id]],
                ))
            query.append(Curve(
                hyperparameters=configurations[curve_id],
                t=torch.arange(max(history[curve_id], 1), max_length + 1) / max_length,
            ))

    predictions = surrogate_model.predict(context, query)
    mean, q05, q95 = [], [], []
    for pred in predictions:
        mean.append(pred.criterion.mean(pred.logits).numpy().tolist())
        q05.append(pred.quantile(0.05).numpy().tolist())
        q95.append(pred.quantile(0.95).numpy().tolist())
    return mean, q05, q95


# Function for Bayesian optimization step
def bayesian_optimization_step(num_curves, max_length):
    global observed_points, generated_curves, generated_configurations, history
    curves = generated_curves
    if any(epoch == len(curves[0]) for epoch in history.values()):
        return gr.update()

    if len(observed_points) == 0:
        next_point = np.random.randint(0, num_curves)
    else:
        next_point = next_observed_point(history, generated_curves, generated_configurations)

    rank.remove(next_point)
    rank.append(next_point)
    if next_point in history:
        history[next_point] += 1
    else:
        history[next_point] = 1
    next_observation = (history[next_point], next_point)
    observed_points.append(next_observation)

    mean, q05, q95 = bayesian_update(history, generated_curves, generated_configurations)

    # Plot the updated curves with uncertainty and observed points
    fig, ax = plt.subplots()
    for curve_id in rank:
        epoch = history[curve_id]
        curve = curves[curve_id]
        color = generated_configurations[curve_id].numpy().tolist()
        # Plot the full curve with reduced opacity
        ax.plot(np.arange(1, len(curves[0]) + 1), curve, alpha=0.1, color=color)
        # Plot the uncertainty region predicted by the surrogate
        x = np.arange(max(epoch, 1), len(q05[curve_id]) + max(epoch, 1))
        ax.fill_between(x, q05[curve_id], q95[curve_id], alpha=0.3, color=color)
        # Plot markers and a line for the observed part of the curve
        ax.plot(np.arange(1, epoch + 1), curve[:epoch], "o", color=color)
        ax.plot(np.arange(1, epoch + 1), curve[:epoch], color=color)

    plt.xlim(0, len(curves[0]) + 1)
    plt.ylim(0, 1)
    plt.title(
        f"Iteration {len(observed_points)}: "
        f"Observing curve {next_point + 1} at step {history[next_point]}"
    )
    plt.xlabel("t")
    plt.ylabel("y")
    return fig


# Reset function for Bayesian optimization
def reset_optimization():
    global observed_points, generated_curves, generated_configurations, history, rank
    observed_points = []
    history = {}
    rank = []
    for i in range(len(generated_curves)):
        history[i] = 0
    return None


scroll_script = """
"""

# Gradio Interface
with gr.Blocks() as demo:
    gr.HTML(scroll_script)
    # Add a title
    gr.Markdown("