Spaces:
Runtime error
Runtime error
| '''Artist Classifier | |
| prototype | |
| --- | |
| - 2022-01-18 jkang first created | |
| ''' | |
| from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import matplotlib.image as mpimg | |
| import seaborn as sns | |
| import io | |
| import json | |
| import numpy as np | |
| import skimage | |
| import skimage.io | |
| from skimage.transform import resize | |
| from loguru import logger | |
| from huggingface_hub import from_pretrained_keras | |
| import gradio as gr | |
| import tensorflow as tf | |
| tfk = tf.keras | |
| from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap | |
# ---------- Settings ----------
# JSON files mapping class names to the integer ids predicted by each model
# (predict() inverts them into id -> name lookups).
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'

# Example inputs offered in the Gradio UI (file names kept verbatim).
EXAMPLES = [
    'monet2.jpg',
    'surrelaism.png',
    'graffitiart.png',
    'lichtenstein_popart.jpg',
    'pierre_augste_renoir.png',
]

# Value passed as `alpha` to align_image_with_heatmap when overlaying the
# Grad-CAM heatmap on the input image.
ALPHA = 0.9

# Every input image is resized to this size (pixels) before inference.
IMG_WIDTH = IMG_HEIGHT = 299
# ---------- Logging ----------
# Append to app.log across restarts; a banner marks each new process start.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Both Keras models are fetched from the Hugging Face Hub once, at import time,
# and shared by every request (artist first, then trend).
logger.info('loading models...')
artist_model, trend_model = (
    from_pretrained_keras(repo_id)
    for repo_id in (
        "jkang/drawing-artist-classifier",
        "jkang/drawing-artistic-trend-classifier",
    )
)
logger.info('both models loaded')
def load_json_as_dict(json_file):
    """Read a JSON file and return its top-level value as a dict.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        ``dict(parsed_json)`` - an object is returned as a dict; a top-level
        list of ``[key, value]`` pairs is also accepted by ``dict()``.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates)
    instead of the platform default, so non-ASCII label names load
    identically on every host.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        out = json.load(f)
    return dict(out)
def load_image_as_array(image_file):
    """Load an image and return it as a 3-channel (H, W, 3) array.

    Args:
        image_file: Anything ``skimage.io.imread`` accepts (path / file-like),
            or an already-decoded ``numpy.ndarray`` (used as-is, no re-read).

    Returns:
        Array of shape (H, W, 3). The dtype is whatever the reader produced.

    Channel handling:
        * (H, W) grayscale      -> gray replicated into 3 channels
        * (H, W, 2) gray+alpha  -> alpha dropped, gray replicated
        * (H, W, 4+) RGBA/extra -> only the first 3 channels kept
    The previous check ``img.shape[-1] > 3`` misfired on 2-D grayscale input,
    where ``shape[-1]`` is the image width: it sliced off the last pixel
    column and left the image single-channel.
    """
    if isinstance(image_file, np.ndarray):
        img = image_file
    else:
        img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')

    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 2:
        img = np.repeat(img[..., :1], 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] > 3:
        img = img[..., :3]
    return img
def resize_image(img_array, width, height):
    """Resample ``img_array`` to ``height`` x ``width`` and return it as uint8.

    skimage's ``resize`` (anti-aliased, output rescaled to floats since
    ``preserve_range=False``) is followed by ``img_as_ubyte`` to bring the
    result back to 8-bit.
    """
    target_shape = (height, width)  # skimage expects (rows, cols)
    resized = resize(
        img_array,
        target_shape,
        anti_aliasing=True,
        preserve_range=False,
    )
    return skimage.img_as_ubyte(resized)
def _explain_with_gradcam(model, img_4d_array, id2label):
    """Run one model with Grad-CAM and summarise its prediction.

    Returns:
        (overlay, label, prob, label_probs) where ``overlay`` is the heatmap
        blended onto the image as a float32 array scaled by 1/255, ``label`` /
        ``prob`` are the top prediction and its score, and ``label_probs``
        maps every class name to its score as a Python float.
    """
    heatmap, pred_id, pred_out = make_gradcam_heatmap(model,
                                                      img_4d_array,
                                                      pred_idx=None)
    overlay_pil = align_image_with_heatmap(
        img_4d_array, heatmap, alpha=ALPHA, cmap='jet')
    overlay = np.asarray(overlay_pil).astype('float32') / 255
    label = id2label[pred_id]
    prob = pred_out[pred_id]
    label_probs = {id2label[i]: float(p) for i, p in enumerate(pred_out)}
    return overlay, label, prob, label_probs


def _render_figure(img_3d_array, a_img, a_label, a_prob, t_img, t_label, t_prob):
    """Draw input, artist overlay and style overlay side by side; return a JPEG PIL image."""
    with sns.plotting_context('poster', font_scale=0.7):
        fig, (ax1, ax2, ax3) = plt.subplots(
            1, 3, figsize=(12, 6), facecolor='white')
        try:
            for ax in (ax1, ax2, ax3):
                ax.set_xticks([])
                ax.set_yticks([])
            ax1.imshow(img_3d_array)
            ax2.imshow(a_img)
            ax3.imshow(t_img)
            ax1.set_title('Input Image', ha='left', x=0, y=1.05)
            ax2.set_title(f'Artist Prediction:\n => {a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
            ax3.set_title(f'Style Prediction:\n => {t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, bbox_inches='tight', format='jpg')
        finally:
            # Close *this* figure explicitly: bare plt.close() closes pyplot's
            # "current" figure, which under a queued server may belong to
            # another request; the finally also prevents leaks on errors.
            plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(input_image):
    """Predict artist and artistic style of an image, with Grad-CAM overlays.

    Args:
        input_image: The value supplied by the Gradio image input (configured
            with ``type='file'``); passed straight to load_image_as_array.

    Returns:
        (artist_probs, style_probs, pil_img): two dicts mapping each class
        name to its score (for ``gr.outputs.Label``) and a PIL image showing
        the input next to both Grad-CAM overlays.
    """
    img_3d_array = load_image_as_array(input_image)
    img_3d_array = resize_image(img_3d_array, IMG_WIDTH, IMG_HEIGHT)
    img_4d_array = img_3d_array[np.newaxis, ...]  # add a leading batch axis
    logger.info(f'--- {input_image} loaded')

    # Metadata files map name -> id; the models predict ids, so invert them.
    id2artist = {idx: name for name, idx in load_json_as_dict(ARTIST_META).items()}
    id2trend = {idx: name for name, idx in load_json_as_dict(TREND_META).items()}

    a_img, a_label, a_prob, a_labels = _explain_with_gradcam(
        artist_model, img_4d_array, id2artist)
    t_img, t_label, t_prob, t_labels = _explain_with_gradcam(
        trend_model, img_4d_array, id2trend)

    pil_img = _render_figure(img_3d_array, a_img, a_label, a_prob,
                             t_img, t_label, t_prob)
    logger.info('--- image generated')
    return a_labels, t_labels, pil_img
# ---------- Gradio UI ----------
# The title previously contained mojibake ("π¨π¨π»βπ¨"): UTF-8 emoji bytes
# decoded as ISO-8859-7. Restored here with escapes so the file's encoding
# cannot garble it again: U+1F3A8 artist palette, then U+1F468 U+1F3FB
# U+200D U+1F3A8 (man artist, light skin tone).
_TITLE = ('Predict Artist and Artistic Style of Drawings '
          '\U0001F3A8\U0001F468\U0001F3FB\u200d\U0001F3A8 (prototype)')

iface = gr.Interface(
    predict,
    title=_TITLE,
    description='Upload a drawing/image and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Label(label='Artists', num_top_classes=5, type='auto'),
        gr.outputs.Label(label='Styles', num_top_classes=5, type='auto'),
        gr.outputs.Image(label='Prediction with GradCAM')
    ],
    examples=EXAMPLES,
)
iface.launch(debug=True, enable_queue=True)