Spaces:
Build error
Build error
| """Utils for visual iterative prompting. | |
| A number of utility functions for VIP. | |
| """ | |
| import re | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import scipy.spatial.distance as distance | |
def min_dist(coord, coords):
    """Return the smallest Euclidean distance from ``coord`` to any of ``coords``.

    Args:
      coord: Object exposing an ``xy`` attribute holding an (x, y) pair.
      coords: Sequence of objects exposing the same ``xy`` attribute.

    Returns:
      The minimum distance, or ``np.inf`` when ``coords`` is empty.
    """
    if not coords:
        return np.inf
    # One (x, y) row per existing coordinate. The loop variable is named
    # ``other`` so it does not shadow the ``coord`` parameter.
    others_xy = np.asarray([other.xy for other in coords])
    return np.linalg.norm(others_xy - np.asarray(coord.xy), axis=-1).min()
def coord_outside_image(coord, image, radius):
    """Check whether ``coord`` falls outside the usable interior of ``image``.

    A point is treated as outside when it lies within ``2 * radius`` pixels of
    any image edge (or beyond the edge). Points exactly on the margin line
    count as inside.

    Args:
      coord: Object exposing an ``xy`` attribute holding (x, y) pixel coords.
      image: Array whose first two dimensions are (height, width). Both
        (H, W) and (H, W, C) images are accepted.
      radius: Marker radius in pixels; the border margin is twice this value.

    Returns:
      True if the coordinate is in the border margin or off the image.
    """
    # Only the spatial dims matter; slicing also supports 2-D images, which
    # the previous 3-way unpacking rejected.
    height, width = image.shape[:2]
    x, y = coord.xy
    margin = 2 * radius
    x_outside = x > width - margin or x < margin
    y_outside = y > height - margin or y < margin
    return x_outside or y_outside
def is_invalid_coord(coord, coords, radius, image):
    """Decide whether ``coord`` should be rejected as a candidate position.

    A coordinate is rejected when it sits closer than ``1.5 * radius`` to any
    coordinate already in ``coords``, or when ``coord_outside_image`` reports
    it as lying in the image border margin.
    """
    overlaps_existing = min_dist(coord, coords) < 1.5 * radius
    if overlaps_existing:
        return overlaps_existing
    # Only consulted when there is no overlap (mirrors `a or b` short-circuit).
    return coord_outside_image(coord, image, radius)
def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
    """Offset ``arm_coord`` by the polar vector (``angle``, ``mag``).

    ``angle`` is passed to ``np.cos``/``np.sin`` and is therefore in radians.
    Each component of the offset is truncated toward zero with ``int``.

    When ``is_circle`` is True, the point is pushed a further ``radius`` pixels
    along the same ray. The direction of that push follows the sign of
    ``mag``, and it is truncated separately.

    Returns:
      The resulting (x, y) pair.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    x, y = arm_coord
    x, y = x + int(cos_a * mag), y + int(sin_a * mag)
    if is_circle:
        direction = np.sign(mag)
        x, y = (
            x + int(cos_a * radius * direction),
            y + int(sin_a * radius * direction),
        )
    return x, y
def coord_to_text_coord(coord, arm_coord, radius):
    """Return a position ``radius`` pixels beyond ``coord``, away from ``arm_coord``.

    The point is moved along the unit direction from ``arm_coord`` to
    ``coord.xy``, and each component is truncated toward zero with ``int``.

    Args:
      coord: Object exposing an ``xy`` attribute holding (x, y).
      arm_coord: Reference (x, y) point the direction is measured from.
      radius: Distance to move past ``coord`` along that direction.

    Returns:
      An (x, y) tuple of ints. If ``coord`` coincides with ``arm_coord``, the
      direction is undefined and ``arm_coord`` is returned unchanged.
    """
    delta_coord = np.asarray(coord.xy) - arm_coord
    # Compute the length once instead of once per use.
    delta_norm = np.linalg.norm(delta_coord)
    if delta_norm == 0:
        return arm_coord
    return (
        int(coord.xy[0] + radius * delta_coord[0] / delta_norm),
        int(coord.xy[1] + radius * delta_coord[1] / delta_norm),
    )
def parse_response(response, answer_key='Arrow: ['):
    """Extract the integers a model response selected.

    If ``answer_key`` occurs in ``response``, only the text after its last
    occurrence, up to the next ``]``, is scanned. Otherwise every run of
    digits in the whole response is used. A message naming the chosen path is
    printed in either case.

    Returns:
      List of non-negative ints, in order of appearance.
    """
    if answer_key not in response:
        print('parse_response for all ints')
        search_text = response
    else:
        print('parse_response from answer_key')
        after_last_key = response.rpartition(answer_key)[2]
        search_text = after_last_key.partition(']')[0]
    return [int(digits) for digits in re.findall(r'\d+', search_text)]
def compute_errors(action, true_action, verbose=False):
    """Compute errors between a predicted action and true action.

    Args:
      action: Predicted action vector (array-like of numbers).
      true_action: Ground-truth action vector of the same length.
      verbose: If True, print both actions and every metric.

    Returns:
      Dict with keys 'l2', 'cos_sim', 'l2_xy_error', 'cos_xy_sim' and
      'z_error'. The "xy" metrics use the last two components, and 'z_error'
      uses the first component. This presumably means the action layout is
      (z, ..., x, y); TODO confirm with the code that produces actions.
    """
    # Coerce to arrays so plain lists/tuples work; list subtraction would
    # otherwise raise TypeError.
    action = np.asarray(action)
    true_action = np.asarray(true_action)

    l2_error = np.linalg.norm(action - true_action)
    cos_sim = 1 - distance.cosine(action, true_action)
    l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
    cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
    z_error = np.abs(action[0] - true_action[0])
    errors = {
        'l2': l2_error,
        'cos_sim': cos_sim,
        'l2_xy_error': l2_xy_error,
        'cos_xy_sim': cos_xy_sim,
        'z_error': z_error,
    }
    if verbose:
        print('action: \t', [f'{a:.3f}' for a in action])
        print('true_action \t', [f'{a:.3f}' for a in true_action])
        print(f'l2: \t\t{l2_error:.3f}')
        print(f'l2_xy_error: \t{l2_xy_error:.3f}')
        print(f'cos_sim: \t{cos_sim:.3f}')
        print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
        print(f'z_error: \t{z_error:.3f}')
    return errors
def plot_errors(all_errors, error_types=None):
    """Plot errors across iterations.

    Draws one subplot per error type, showing the mean of that error at each
    iteration across all calls, and then calls ``plt.show()``.

    Args:
      all_errors: Iterable of per-call mappings. Each maps an iteration key to
        an errors dict such as the one returned by ``compute_errors``. It is
        traversed once per error type, so it must be re-iterable.
      error_types: Metric keys to plot. Defaults to the five keys produced by
        ``compute_errors``.
    """
    if error_types is None:
        error_types = [
            'l2',
            'l2_xy_error',
            'z_error',
            'cos_sim',
            'cos_xy_sim',
        ]
    ncols = 3
    # Keep the original 2x3 / (15, 8) layout for up to six metrics. Add rows
    # for more instead of failing with an IndexError on the axes grid.
    nrows = max(2, -(-len(error_types) // ncols))
    _, axs = plt.subplots(nrows, ncols, figsize=(15, 4 * nrows))
    for i, error_type in enumerate(error_types):
        errors_by_iter = _group_errors_by_iteration(all_errors, error_type)
        # Materialize the keys; dict views are not reliably array-like for plot.
        iterations = list(errors_by_iter)
        mean_iter_errors = [np.mean(errors_by_iter[itr]) for itr in iterations]
        ax = axs[i // ncols, i % ncols]
        ax.plot(iterations, mean_iter_errors)
        ax.set_title(error_type)
    plt.show()


def _group_errors_by_iteration(all_errors, error_type):
    """Collect ``error_type`` values from every call, keyed by iteration (first-seen order)."""
    grouped = {}
    for error_by_iter in all_errors:  # each call
        for itr in error_by_iter:  # each iteration within that call
            grouped.setdefault(itr, []).append(error_by_iter[itr][error_type])
    return grouped