herilalaina commited on
Commit
fb5c7f8
·
1 Parent(s): 4657628
Files changed (2) hide show
  1. app.py +187 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from matplotlib import pyplot as plt
4
+ import torch
5
+ from ifbo import FTPFN, Curve
6
+ from ifbo.priors.ftpfn_prior import sample_curves
7
+
8
+ # Global variables to store state
9
+ observed_points = []
10
+ generated_curves = []
11
+ generated_configurations = []
12
+ history = {}
13
+ surrogate_model = FTPFN()
14
+
15
+ # Function to generate curves
16
+ def generate_curves(num_curves, max_length):
17
+ global generated_curves, generated_configurations, history
18
+ reset_optimization()
19
+
20
+ configurations, curves = sample_curves(num_hyperparameters=num_curves,
21
+ hyperparameter_dimensions=3,
22
+ curve_length=max_length)
23
+ fig, ax = plt.subplots()
24
+ hyperparam_display = "Hyperparameter values (corresponding to RGB colors):\n"
25
+
26
+ for i, (x, curve) in enumerate(zip(configurations, curves)):
27
+ color = x.tolist() # RGB value
28
+ ax.plot(np.arange(1, len(curve)+1), curve, color=color, label=f'Curve {i+1}')
29
+ hyperparam_display += f"Curve {i+1}: {np.round(color, 2)}\n" # Display RGB
30
+
31
+ plt.xlabel("t")
32
+ plt.ylabel("y")
33
+ plt.ylim(0, 1)
34
+ plt.legend()
35
+
36
+ generated_curves = torch.FloatTensor(curves)
37
+ generated_configurations = torch.FloatTensor(configurations)
38
+ for i in range(num_curves):
39
+ history[i] = 0
40
+ return fig, None, hyperparam_display
41
+
42
+ # Function to predict the next observed point
43
+ def next_observed_point(history, curves, configurations):
44
+ best_sofar = max([max(curves[curve_id][:epoch], default=0) for curve_id, epoch in history.items()])
45
+ best_sofar = best_sofar + np.exp(np.random.uniform(-4, -1)) * (1 - best_sofar)
46
+ max_length = len(curves[0])
47
+ horizon = np.random.randint(1, max_length)
48
+ target = {k: min(max_length, v + horizon) for k, v in history.items()}
49
+
50
+ context, query = [], []
51
+ for curve_id in range(len(curves)):
52
+ if curve_id in history:
53
+ if history[curve_id] > 0:
54
+ context.append(Curve(
55
+ hyperparameters=configurations[curve_id],
56
+ t=torch.arange(1, history[curve_id] + 1) / max_length,
57
+ y=curves[curve_id][:history[curve_id]]
58
+ ))
59
+ query.append(Curve(
60
+ hyperparameters=configurations[curve_id],
61
+ t=torch.FloatTensor([target[curve_id]]) / max_length,
62
+ ))
63
+ predictions = surrogate_model.predict(context, query)
64
+
65
+ pis = []
66
+ for i, pred in enumerate(predictions):
67
+ pis.append(pred.pi(best_sofar).item())
68
+
69
+ return np.argmax(pis)
70
+
71
+ # Bayesian optimization update function
72
+ def bayesian_update(history, curves, configurations):
73
+ context, query = [], []
74
+ max_length = len(curves[0])
75
+ for curve_id in range(len(curves)):
76
+ if curve_id in history:
77
+ if history[curve_id] > 0:
78
+ context.append(Curve(
79
+ hyperparameters=configurations[curve_id],
80
+ t=torch.arange(1, history[curve_id] + 1) / max_length,
81
+ y=curves[curve_id][:history[curve_id]]
82
+ ))
83
+ query.append(Curve(
84
+ hyperparameters=configurations[curve_id],
85
+ t=torch.arange(max(history[curve_id], 1), max_length + 1) / max_length
86
+ ))
87
+ predictions = surrogate_model.predict(context, query)
88
+ mean, q05, q95 = [], [], []
89
+
90
+ for i, pred in enumerate(predictions):
91
+ mean.append(pred.criterion.mean(pred.logits).numpy().tolist())
92
+ q05.append(pred.quantile(0.05).numpy().tolist())
93
+ q95.append(pred.quantile(0.95).numpy().tolist())
94
+
95
+ return mean, q05, q95
96
+
97
+ # Function for Bayesian optimization step
98
+ def bayesian_optimization_step(num_curves, max_length):
99
+ global observed_points, generated_curves, generated_configurations, history
100
+ curves = generated_curves
101
+
102
+ if sum([epoch == len(curves[0]) for epoch in history.values()]) > 0:
103
+ return gr.update()
104
+
105
+ if len(observed_points) == 0:
106
+ next_point = np.random.randint(0, num_curves)
107
+ else:
108
+ next_point = next_observed_point(history, generated_curves, generated_configurations)
109
+
110
+ if next_point in history:
111
+ history[next_point] += 1
112
+ else:
113
+ history[next_point] = 1
114
+
115
+ next_observation = (history[next_point], next_point)
116
+ observed_points.append(next_observation)
117
+
118
+ mean, q05, q95 = bayesian_update(history, generated_curves, generated_configurations)
119
+
120
+ # Plot the updated curves with uncertainty and observed points
121
+ fig, ax = plt.subplots()
122
+
123
+ for i, curve in enumerate(curves):
124
+ color = generated_configurations[i].numpy().tolist()
125
+ ax.plot(np.arange(1, len(curves[0]) + 1), curve, alpha=0.1, color=color)
126
+ x = np.arange(max(history[i], 1), len(q05[i]) + max(history[i], 1))
127
+ ax.fill_between(x, q05[i], q95[i], alpha=0.3, color=color)
128
+
129
+ for curve_id, epoch in history.items():
130
+ color = generated_configurations[curve_id].numpy().tolist()
131
+ ax.plot(np.arange(1, epoch+1), curves[curve_id][:epoch], 'ro', color=color)
132
+ ax.plot(np.arange(1, epoch+1), curves[curve_id][:epoch], color=color)
133
+
134
+ plt.xlim(0, len(curves[0]) + 1)
135
+ plt.ylim(0, 1)
136
+ plt.title(f"Step {len(observed_points)}")
137
+ plt.xlabel("t")
138
+ plt.ylabel("y")
139
+ return fig
140
+
141
+ # Reset function for Bayesian optimization
142
+ def reset_optimization():
143
+ global observed_points, generated_curves, generated_configurations, history
144
+ observed_points = []
145
+ history = {}
146
+ for i in range(len(generated_curves)):
147
+ history[i] = 0
148
+ return None
149
+
150
+ # Gradio Interface
151
+ with gr.Blocks() as demo:
152
+ # Add a title
153
+ gr.Markdown("# ifBO: In-context Freeze-Thaw Bayesian Optimization")
154
+
155
+ # First section for curve generation
156
+ gr.Markdown("### Curve Generation")
157
+ gr.Markdown("Input the number of curves and the maximum length of curves, then click 'Generate' to sample curves from our prior.")
158
+
159
+ with gr.Row():
160
+ with gr.Column(scale=1):
161
+ num_curves = gr.Number(label="Number of curves", value=3)
162
+ max_length = gr.Number(label="Max length of curves", value=10)
163
+ generate_btn = gr.Button("Generate")
164
+ with gr.Column(scale=1):
165
+ hyperparam_text = gr.Textbox(label="")
166
+ with gr.Column(scale=2):
167
+ curve_plot = gr.Plot()
168
+
169
+ # Separate section for Bayesian optimization
170
+ gr.Markdown("### Bayesian Optimization")
171
+ gr.Markdown("After generating the curves, click 'BO Step' to perform one step of Bayesian optimization. Use 'Reset' to start the process again.")
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ next_step_btn = gr.Button("BO Step")
176
+ reset_btn = gr.Button("Reset")
177
+ with gr.Column():
178
+ optimization_plot = gr.Plot()
179
+
180
+
181
+ # Link buttons to actions
182
+ generate_btn.click(generate_curves, inputs=[num_curves, max_length], outputs=[curve_plot, optimization_plot, hyperparam_text])
183
+ next_step_btn.click(bayesian_optimization_step, inputs=[num_curves, max_length], outputs=optimization_plot)
184
+ reset_btn.click(reset_optimization, outputs=optimization_plot)
185
+
186
+ # Launch the demo
187
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ifbo==0.3.10
2
+ matplotlib
3
+ gradio