Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| import numpy | |
| def plot_tensor_images(data, **kwargs): | |
| data = ((data + 1) / 2 * 255).permute(0, 2, 3, 1).byte().cpu().numpy() | |
| width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) | |
| height = int(numpy.ceil(data.shape[0] / float(width))) | |
| kwargs = dict(kwargs) | |
| margin = 0.01 | |
| if 'figsize' not in kwargs: | |
| # Size figure to one display pixel per data pixel | |
| dpi = plt.rcParams['figure.dpi'] | |
| kwargs['figsize'] = ( | |
| (1 + margin) * (width * data.shape[2] / dpi), | |
| (1 + margin) * (height * data.shape[1] / dpi)) | |
| f, axarr = plt.subplots(height, width, **kwargs) | |
| if len(numpy.shape(axarr)) == 0: | |
| axarr = numpy.array([[axarr]]) | |
| if len(numpy.shape(axarr)) == 1: | |
| axarr = axarr[None,:] | |
| for i, im in enumerate(data): | |
| ax = axarr[i // width, i % width] | |
| ax.imshow(data[i]) | |
| ax.axis('off') | |
| for i in range(i, width * height): | |
| ax = axarr[i // width, i % width] | |
| ax.axis('off') | |
| plt.subplots_adjust(wspace=margin, hspace=margin, | |
| left=0, right=1, bottom=0, top=1) | |
| plt.show() | |
| def plot_max_heatmap(data, shape=None, **kwargs): | |
| if shape is None: | |
| shape = data.shape[2:] | |
| data = data.max(1)[0].cpu().numpy() | |
| vmin = data.min() | |
| vmax = data.max() | |
| width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) | |
| height = int(numpy.ceil(data.shape[0] / float(width))) | |
| kwargs = dict(kwargs) | |
| margin = 0.01 | |
| if 'figsize' not in kwargs: | |
| # Size figure to one display pixel per data pixel | |
| dpi = plt.rcParams['figure.dpi'] | |
| kwargs['figsize'] = ( | |
| width * shape[1] / dpi, height * shape[0] / dpi) | |
| f, axarr = plt.subplots(height, width, **kwargs) | |
| if len(numpy.shape(axarr)) == 0: | |
| axarr = numpy.array([[axarr]]) | |
| if len(numpy.shape(axarr)) == 1: | |
| axarr = axarr[None,:] | |
| for i, im in enumerate(data): | |
| ax = axarr[i // width, i % width] | |
| img = ax.imshow(data[i], vmin=vmin, vmax=vmax, cmap='hot') | |
| ax.axis('off') | |
| for i in range(i, width * height): | |
| ax = axarr[i // width, i % width] | |
| ax.axis('off') | |
| plt.subplots_adjust(wspace=margin, hspace=margin, | |
| left=0, right=1, bottom=0, top=1) | |
| plt.show() | |