import matplotlib.pyplot as plt
import pydantic
import time
import numpy as np
from tqdm import tqdm, trange
import torch
from torch import nn
from diffusers import StableDiffusionPipeline
import clip
from dreamsim import dreamsim
from ribs.archives import GridArchive
from ribs.schedulers import Scheduler
from ribs.emitters import GaussianEmitter
import itertools
from ribs.visualize import grid_archive_heatmap

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("Torch device:", DEVICE)

# Use float16 for GPU, float32 for CPU.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
print("Torch dtype:", TORCH_DTYPE)

IMG_WIDTH = 256
IMG_HEIGHT = 256
SD_IN_HEIGHT = 32
SD_IN_WIDTH = 32
SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"

BATCH_SIZE = 4
SD_IN_CHANNELS = 4
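# The miniSD checkpoint generates 256x256 images, and the Stable Diffusion VAE
# downsamples by a factor of 8, so the latents optimized here are
# 4 x 32 x 32 tensors (SD_IN_CHANNELS x SD_IN_HEIGHT x SD_IN_WIDTH).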
SD_IN_SHAPE = (
    BATCH_SIZE,
    SD_IN_CHANNELS,
    SD_IN_HEIGHT,
    SD_IN_WIDTH,
)

SDPIPE = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=TORCH_DTYPE,
    safety_checker=None,  # For faster inference.
    requires_safety_checker=False,
)
SDPIPE.set_progress_bar_config(disable=True)
SDPIPE = SDPIPE.to(DEVICE)

GRID_SIZE = (20, 20)
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# INIT_POP = 200  # Initial population.
# TOTAL_ITRS = 200  # Total number of iterations.
class DivProj(nn.Module):
    def __init__(self, input_dim, latent_dim=2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=latent_dim),
        )

    def forward(self, x):
        """Get diversity representations."""
        x = self.proj(x)
        return x

    def calc_dis(self, x1, x2):
        """Calculate diversity distance as (squared) L2 distance."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        return torch.sum(torch.square(x1 - x2), -1)

    def triplet_delta_dis(self, ref, x1, x2):
        """Calculate delta distance comparing x1 and x2 to ref."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        ref = self.forward(ref)
        return (torch.sum(torch.square(ref - x1), -1) -
                torch.sum(torch.square(ref - x2), -1))


# Triplet loss with margin 0.05. The binary preference labels are scaled to
# y = 1 or -1 for the loss, where y = 1 means x2 is more similar to ref than x1.
loss_fn = lambda y, delta_dis: torch.max(
    torch.tensor([0.0]).to(DEVICE), 0.05 - (y * 2 - 1) * delta_dis
).mean()
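# A worked example of the margin loss above (illustrative numbers only):
#   y = 1 (x2 closer to ref), delta_dis = 0.30  ->  max(0, 0.05 - 0.30) = 0.0
#   y = 1 (x2 closer to ref), delta_dis = -0.10 ->  max(0, 0.05 + 0.10) = 0.15
#   y = 0 (x1 closer to ref), delta_dis = 0.30  ->  max(0, 0.05 + 0.30) = 0.35
# The projection is only penalized when it disagrees with (or barely satisfies)
# the preference label.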
def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
    """Trains the DivProj model on ground-truth labels."""
    t = time.time()
    model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    n_pref_data = inputs.shape[0]
    ref = inputs[:, 0]
    x1 = inputs[:, 1]
    x2 = inputs[:, 2]

    n_train = int(n_pref_data * 0.75)
    n_val = n_pref_data - n_train

    # Split data into train and val.
    ref_train = ref[:n_train]
    x1_train = x1[:n_train]
    x2_train = x2[:n_train]
    ref_val = ref[n_train:]
    x1_val = x1[n_train:]
    x2_val = x2[n_train:]

    # Split DreamSim features into train and val.
    ref_dreamsim_features = dreamsim_features[:, 0]
    x1_dreamsim_features = dreamsim_features[:, 1]
    x2_dreamsim_features = dreamsim_features[:, 2]
    ref_gt_train = ref_dreamsim_features[:n_train]
    x1_gt_train = x1_dreamsim_features[:n_train]
    x2_gt_train = x2_dreamsim_features[:n_train]
    ref_gt_val = ref_dreamsim_features[n_train:]
    x1_gt_val = x1_dreamsim_features[n_train:]
    x2_gt_val = x2_dreamsim_features[n_train:]

    val_acc = []
    n_iters_per_epoch = max(n_train // batch_size, 1)
    for epoch in range(200):
        for _ in range(n_iters_per_epoch):
            optimizer.zero_grad()
            idx = np.random.choice(n_train, batch_size)
            batch_ref = ref_train[idx].float()
            batch1 = x1_train[idx].float()
            batch2 = x2_train[idx].float()
            # Get delta distance from model.
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
            # Get preference labels from DreamSim features.
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x2_gt_train[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x1_gt_train[idx], dim=-1
            )
            # Label is 1 if x2 is closer to ref than x1 according to DreamSim, else 0.
            gt = (gt_dis > 0).to(TORCH_DTYPE)
            loss = loss_fn(gt, delta_dis)
            loss.backward()
            optimizer.step()

        # Validate.
        n_correct = 0
        n_total = 0
        with torch.no_grad():
            idx = np.arange(n_val)
            batch_ref = ref_val[idx].float()
            batch1 = x1_val[idx].float()
            batch2 = x2_val[idx].float()
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
            pred = delta_dis > 0
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x2_gt_val[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x1_gt_val[idx], dim=-1
            )
            gt = gt_dis > 0
            n_correct += (pred == gt).sum().item()
            n_total += len(idx)
        acc = n_correct / n_total
        val_acc.append(acc)

        # Early stopping if val acc. does not improve for 10 epochs.
        if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
            break

    print(
        f"{np.round(time.time() - t, 1)}s ({epoch + 1} epochs) | "
        f"DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
    )
    return model.to(TORCH_DTYPE), acc
def compute_diversity_measures(clip_features, diversity_model):
    with torch.no_grad():
        measures = diversity_model(clip_features).detach().cpu().numpy()
    return measures


def tensor_to_list(tensor):
    sols = tensor.detach().cpu().numpy().astype(np.float32)
    return sols.reshape(sols.shape[0], -1)


def list_to_tensor(list_):
    sols = np.array(list_).reshape(
        len(list_), 4, SD_IN_HEIGHT, SD_IN_WIDTH
    )  # Hard-coded for now.
    return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
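# pyribs stores solutions as flat vectors, so the helpers above flatten each
# latent to SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH = 4096 values and
# restore the (4, 32, 32) shape before passing latents back to the pipeline.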
def create_scheduler(
    sols,
    objs,
    clip_features,
    diversity_model,
    seed=None,
):
    measures = compute_diversity_measures(clip_features, diversity_model)
    archive_bounds = np.array(
        [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
    ).T
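    # archive_bounds has shape (2, 2): one (low, high) range per diversity
    # measure, taken at the 1st/99th percentiles to keep outliers from
    # stretching the grid.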
    sols = tensor_to_list(sols)

    # Set up archive.
    archive = GridArchive(
        solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=SEED
    )

    # Add initial solutions to the archive.
    archive.add(sols, objs, measures)

    # Set up the GaussianEmitter.
    emitters = [
        GaussianEmitter(
            archive=archive,
            sigma=0.1,
            initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
            batch_size=BATCH_SIZE,
            seed=SEED,
        )
    ]

    # Return the archive and scheduler.
    return archive, Scheduler(archive, emitters)
def plot_archive(archive):
    plt.figure(figsize=(6, 4.5))
    grid_archive_heatmap(archive, vmin=0, vmax=100)
    plt.xlabel("Diversity Metric 1")
    plt.ylabel("Diversity Metric 2")
    return plt
def run_qdhf(prompt: str, init_pop: int = 200, total_itrs: int = 200):
    INIT_POP = init_pop
    TOTAL_ITRS = total_itrs

    # This tutorial uses ViT-B/32; you may use other checkpoints depending on
    # your resources and needs.
    CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
    CLIP_MODEL.eval()
    for p in CLIP_MODEL.parameters():
        p.requires_grad_(False)

    def compute_clip_scores(imgs, text, return_clip_features=False):
        """Computes CLIP scores for a batch of images and a given text prompt."""
        img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
        tokenized_text = clip.tokenize([text]).to(DEVICE)
        img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
        img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
        img_logits = 1 / img_logits * 100
        # Remap the objective from minimizing [0, 10] to maximizing [0, 100].
        img_logits = (10.0 - img_logits) * 10.0
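        # Worked example of the remapping (illustrative numbers): a raw CLIP
        # logit of 25 becomes 1 / 25 * 100 = 4.0, then (10 - 4.0) * 10 = 60;
        # a logit of 20 maps to 50. Higher raw logits (better text-image
        # alignment) therefore yield higher objectives on the [0, 100] scale.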
        if return_clip_features:
            clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
            return img_logits, clip_features
        else:
            return img_logits

    DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
        pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
    )

    def evaluate_lsi(
        latents,
        prompt,
        return_features=False,
        diversity_model=None,
    ):
        """Evaluates the objective of LSI for a batch of latents and a given text prompt."""
        images = SDPIPE(
            prompt,
            num_images_per_prompt=latents.shape[0],
            latents=latents,
            # num_inference_steps=1,  # For testing.
        ).images
        objs, clip_features = compute_clip_scores(
            images,
            prompt,
            return_clip_features=True,
        )
        images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
        dreamsim_features = DREAMSIM_MODEL.embed(images)
        if diversity_model is not None:
            measures = compute_diversity_measures(clip_features, diversity_model)
        else:
            measures = None
        if return_features:
            return objs, measures, clip_features, dreamsim_features
        else:
            return objs, measures

    update_schedule = [1, 21, 51, 101]  # Iterations on which to update the archive.
    n_pref_data = 1000  # Number of preferences used in each update.
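    # At each iteration in update_schedule, new preference triplets are
    # collected, the diversity projection is refit, and the archive and
    # scheduler are rebuilt so the measure space reflects the updated
    # projection.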
    archive = None
    best = 0.0
    for itr in trange(1, TOTAL_ITRS + 1):
        # Update archive and scheduler if needed.
        if itr in update_schedule:
            if archive is None:
                tqdm.write("Initializing archive and diversity projection.")
                all_sols = []
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                # Sample random solutions and get judgment on similarity.
                n_batches = INIT_POP // BATCH_SIZE
                for _ in range(n_batches):
                    sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_sols.append(sols)
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_sols = torch.concat(all_sols, dim=0)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)
                # Initialize the diversity projection model.
                div_proj_data = []
                div_proj_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    div_proj_data.append(all_clip_features[idx])
                    div_proj_labels.append(all_dreamsim_features[idx])
                div_proj_data = torch.concat(div_proj_data, dim=0)
                div_proj_labels = torch.concat(div_proj_labels, dim=0)
                div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
                div_proj_label = div_proj_labels.reshape(n_pref_data, 3, -1)
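                # div_proj_data: (n_pref_data, 3, clip_dim) CLIP features per
                # (ref, x1, x2) triplet (clip_dim is 512 for ViT-B/32);
                # div_proj_label: the matching DreamSim features that act as
                # the ground-truth similarity judge.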
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )
            else:
                tqdm.write("Updating archive and diversity projection.")
                # Get all the current solutions and collect feedback.
                all_sols = list_to_tensor(archive.data("solution"))
                n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                for i in range(n_batches):
                    sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_clip_features = torch.concat(
                    all_clip_features, dim=0
                )  # (n_sols, dim)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)
                # Update the diversity projection model.
                additional_features = []
                additional_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    additional_features.append(all_clip_features[idx])
                    additional_labels.append(all_dreamsim_features[idx])
                additional_features = torch.concat(additional_features, dim=0)
                additional_labels = torch.concat(additional_labels, dim=0)
                additional_div_proj_data = additional_features.reshape(n_pref_data, 3, -1)
                additional_div_proj_label = additional_labels.reshape(n_pref_data, 3, -1)
                div_proj_data = torch.concat(
                    (div_proj_data, additional_div_proj_data), dim=0
                )
                div_proj_label = torch.concat(
                    (div_proj_label, additional_div_proj_label), dim=0
                )
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )
            archive, scheduler = create_scheduler(
                all_sols,
                all_objs,
                all_clip_features,
                diversity_model,
                seed=SEED,
            )

        # Primary QD loop.
        sols = scheduler.ask()
        sols = list_to_tensor(sols)
        objs, measures, clip_features, dreamsim_features = evaluate_lsi(
            sols, prompt, return_features=True, diversity_model=diversity_model
        )
        best = max(best, max(objs))
        scheduler.tell(objs, measures)

        # This can be used as a flag to save on the final iteration, but note
        # that we do not save results in this tutorial.
        final_itr = itr == TOTAL_ITRS

        # Update the summary statistics for the archive.
        qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage
        tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100}")
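        # norm_qd_score and coverage come from pyribs' archive.stats: the QD
        # score normalized by the number of archive cells, and the fraction of
        # the 20x20 grid's cells that contain an elite.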
        plt = plot_archive(archive)
        yield archive, plt

    plt = plot_archive(archive)
    return archive, plt
def many_pictures(archive, prompt: str):
    # Modify this to determine how many images to plot along each dimension.
    img_freq = (
        4,  # Number of columns of images.
        4,  # Number of rows of images.
    )

    # List of images.
    imgs = []

    # Convert archive to a df with solutions available.
    df = archive.data(return_type="pandas")

    # Compute the min and max measures for which solutions were found.
    measure_bounds = np.array(
        [
            (df["measures_0"].min(), df["measures_0"].max()),
            (df["measures_1"].min(), df["measures_1"].max()),
        ]
    )
    archive_bounds = np.array(
        [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
    )

    delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / img_freq[0]
    delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / img_freq[1]

    for col, row in itertools.product(range(img_freq[1]), range(img_freq[0])):
        # Compute bounds of a box in measure space.
        measures_0_low = archive_bounds[0][0] + delta_measures_0 * row
        measures_0_high = archive_bounds[0][0] + delta_measures_0 * (row + 1)
        measures_1_low = archive_bounds[1][0] + delta_measures_1 * col
        measures_1_high = archive_bounds[1][0] + delta_measures_1 * (col + 1)
        if row == 0:
            measures_0_low = measure_bounds[0][0]
        if col == 0:
            measures_1_low = measure_bounds[1][0]
        if row == img_freq[0] - 1:
            measures_0_high = measure_bounds[0][1]
        if col == img_freq[1] - 1:
            measures_1_high = measure_bounds[1][1]
        # Query for a solution with measures within this box.
        query_string = (
            f"{measures_0_low} <= measures_0 & measures_0 <= {measures_0_high} & "
            f"{measures_1_low} <= measures_1 & measures_1 <= {measures_1_high}"
        )
        df_box = df.query(query_string)

        if not df_box.empty:
            # Randomly sample a solution from the box.
            # Stable Diffusion solutions have SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH
            # dimensions, so the final solution col is solution_(x-1).
            sol = (
                df_box.loc[
                    :,
                    "solution_0":"solution_{}".format(
                        SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1
                    ),
                ]
                .sample(n=1)
                .iloc[0]
            )

            # Convert the latent vector solution to an image.
            latents = torch.tensor(sol.to_numpy()).reshape(
                (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
            )
            latents = latents.to(TORCH_DTYPE).to(DEVICE)
            img = SDPIPE(
                prompt,
                num_images_per_prompt=1,
                latents=latents,
                # num_inference_steps=1,  # For testing.
            ).images[0]
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
            imgs.append(img)
        else:
            imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))

    from torchvision.utils import make_grid

    def create_archive_tick_labels(measure_range, num_ticks):
        delta = (measure_range[1] - measure_range[0]) / num_ticks
        ticklabels = [
            round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)
        ]
        return ticklabels
    plt.figure(figsize=(img_freq[0] * 2, img_freq[0] * 2))
    img_grid = make_grid(imgs, nrow=img_freq[0], padding=0)
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(img_grid)

    plt.xlabel("")
    num_x_ticks = img_freq[0]
    x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
    x_tick_range = img_grid.shape[1]
    x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
    plt.xticks(x_ticks, x_ticklabels)

    plt.ylabel("")
    num_y_ticks = img_freq[1]
    y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
    y_ticklabels.reverse()
    y_tick_range = img_grid.shape[0]
    y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
    plt.yticks(y_ticks, y_ticklabels)
    plt.tight_layout()
    return plt