Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import random | |
| import gradio as gr | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from CCAgT_utils.categories import CategoriesInfos | |
| from CCAgT_utils.types.mask import Mask | |
| from CCAgT_utils.visualization import plot | |
| from PIL import Image | |
| from torch import nn | |
| from transformers import SegformerFeatureExtractor | |
| from transformers import SegformerForSemanticSegmentation | |
| from transformers.modeling_outputs import SemanticSegmenterOutput | |
# Select the non-interactive Agg backend for matplotlib.
matplotlib.use('Agg')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hub identifier shared by the model weights and the feature extractor.
model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'

# Load the fine-tuned SegFormer, move it to the chosen device and switch it
# to evaluation mode.
model = SegformerForSemanticSegmentation.from_pretrained(model_hub_name)
model = model.to(device)
model.eval()

# Pre-processing pipeline published alongside the model.
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_hub_name)
def segment(
    image: Image.Image,
) -> SemanticSegmenterOutput:
    '''Run the segmentation model on a single image.

    The image is pre-processed by the module-level ``feature_extractor``
    (returning PyTorch tensors), moved to ``device`` and fed to ``model``.

    Args:
        image: The image to segment.

    Returns:
        The raw output of the model forward pass.
    '''
    inputs = feature_extractor(image, return_tensors='pt').to(device)

    # Inference only: disabling gradient tracking avoids building the
    # autograd graph, which would otherwise cost memory on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs
def post_processing(
    outputs: SemanticSegmenterOutput,
    target_size: tuple[int, int],
) -> np.ndarray:
    '''Convert model outputs into a 2-D map of predicted indices.

    The logits are moved to the CPU, resized to ``target_size`` with
    bilinear interpolation, and reduced with an argmax over dim 1
    (presumably the class dimension of the logits - confirm against the
    model's output layout). Only the first item of the batch is returned.

    Args:
        outputs: Object exposing the model ``logits``.
        target_size: Spatial size of the returned map, passed as ``size``
            to the interpolation.

    Returns:
        A numpy array holding the argmax index for every position.
    '''
    resized = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    best_index = resized.argmax(dim=1)
    return np.array(best_index[0])
def colorize(
    mask: Mask,
) -> np.ndarray:
    '''Return the colorized version of ``mask`` divided by 255.

    The colors come from ``mask.colorized`` called with a fresh
    ``CategoriesInfos``; the division presumably rescales 0-255 channel
    values into the 0-1 range (confirm against ``Mask.colorized``).
    '''
    categories = CategoriesInfos()
    colored = mask.colorized(categories)
    return colored / 255
# Copied from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L35
def get_random_crop_coords(
    height: int,
    width: int,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
):
    '''Compute the corners of a crop window inside an image.

    ``h_start`` and ``w_start`` are fractions that select where the window
    starts among the ``height - crop_height + 1`` (respectively
    ``width - crop_width + 1``) possible offsets.

    Returns:
        The tuple ``(x1, y1, x2, y2)``: left, top, right and bottom edges
        of the window, with the right/bottom edges exclusive when used as
        slice bounds.
    '''
    top = int(h_start * (height - crop_height + 1))
    left = int(w_start * (width - crop_width + 1))
    return left, top, left + crop_width, top + crop_height
# Copied from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L46
def random_crop(
    img: np.ndarray,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
) -> np.ndarray:
    '''Cut a ``crop_height`` x ``crop_width`` window out of ``img``.

    The window position is derived from the fractions ``h_start`` and
    ``w_start`` by ``get_random_crop_coords``; the first two axes of
    ``img`` are treated as height and width.

    Returns:
        The slice of ``img`` covered by the window.
    '''
    img_height, img_width = img.shape[:2]
    left, top, right, bottom = get_random_crop_coords(
        img_height,
        img_width,
        crop_height,
        crop_width,
        h_start,
        w_start,
    )
    return img[top:bottom, left:right]
def process_big_images(
    image: Image.Image,
) -> tuple[np.ndarray, Mask]:
    '''Segment an image, randomly cropping it first if it exceeds 400x300.

    Args:
        image: The input image.

    Returns:
        A tuple ``(img, mask)`` where ``img`` is the (possibly cropped)
        image as a numpy array and ``mask`` is the predicted segmentation
        wrapped in a ``Mask``, with the same height and width as ``img``.
    '''
    max_height, max_width = 300, 400

    img = np.asarray(image)
    height, width = img.shape[:2]

    if height > max_height or width > max_width:
        # Clamp the crop window to the image size: when only one side
        # exceeds its limit, asking for a full 300x400 window would give a
        # negative start offset and an unpredictable crop.
        img = random_crop(
            img,
            min(height, max_height),
            min(width, max_width),
            random.random(),
            random.random(),
        )

    target_size = (img.shape[0], img.shape[1])
    outputs = segment(Image.fromarray(img))
    msk = post_processing(outputs, target_size)

    return img, Mask(msk)
def image_with_mask(
    image: Image.Image,
    mask: Mask,
) -> plt.Figure:
    '''Build a figure showing ``image`` with ``mask`` drawn on top of it.

    The mask is rendered semi-transparently (alpha 0.4) with the colormap
    provided by ``mask.cmap``, using nearest-neighbour interpolation and
    color limits spanning the smallest to the largest id in
    ``mask.unique_ids``. Axes are hidden and the layout has no padding.

    Returns:
        The newly created 600-dpi figure.
    '''
    fig = plt.figure(dpi=600)
    axes = fig.gca()

    # Base layer: the image itself.
    axes.imshow(image)

    # Overlay: the categorical mask.
    category_ids = mask.unique_ids
    axes.imshow(
        mask.categorical,
        cmap=mask.cmap(CategoriesInfos()),
        vmin=min(category_ids),
        vmax=max(category_ids),
        interpolation='nearest',
        alpha=0.4,
    )

    axes.axis('off')
    fig.tight_layout(pad=0)
    return fig
def categories_map(
    mask: Mask,
) -> plt.Figure:
    '''Build a figure that contains only a legend for the mask categories.

    Legend handles are created by ``plot.create_handles`` for the ids in
    ``mask.unique_ids``; the legend is centered and the axes are hidden.

    Returns:
        The newly created 600-dpi figure.
    '''
    fig = plt.figure(dpi=600)
    legend_handles = plot.create_handles(
        CategoriesInfos(),
        selected_categories=mask.unique_ids,
    )

    axes = fig.gca()
    axes.legend(handles=legend_handles, fontsize=24, loc='center')
    axes.axis('off')
    return fig
def main(image):
    '''Gradio callback: segment ``image`` and build the four demo outputs.

    Args:
        image: The uploaded image as an array accepted by
            ``Image.fromarray``, or ``None`` when nothing was provided.

    Returns:
        A tuple ``(categories_figure, image, colorized_mask, overlay_figure)``
        where ``image`` is the (possibly cropped) input as a PIL image.

    Raises:
        gr.Error: If no image was provided.
    '''
    # Without an upload the callback receives None; fail with a readable
    # message instead of an opaque error from Image.fromarray.
    if image is None:
        raise gr.Error('Please provide an image to segment.')

    image = Image.fromarray(image)
    img, mask = process_big_images(image)
    mask_colorized = colorize(mask)
    fig = image_with_mask(img, mask)
    return categories_map(mask), Image.fromarray(img), mask_colorized, fig
# Title shown at the top of the demo page.
title = 'SegFormer (b3) - CCAgT dataset'

# Description rendered below the title.
description = f"""
This is demo for the SegFormer fine-tuned on sub-dataset from
[CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). This model
was trained to segment cervical cells silver-stained (AgNOR technique)
images with resolution of 400x300. The model was available at HF hub at
[{model_hub_name}](https://huggingface.co/{model_hub_name}). If input
an image bigger than 400x300, the demo will random crop it.
"""

# Example inputs: two samples stored next to the model, followed by a few
# test images of the dataset. A tuple (not a set) is iterated so that the
# examples always appear in the listed order.
examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
] + [
    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
    for x in (3, 10, 12, 18, 35, 78, 89)
]
# Gradio interface: one image input and four outputs, listed in the same
# order as the tuple returned by ``main``.
demo = gr.Interface(
    fn=main,
    title=title,
    description=description,
    inputs=[gr.Image()],
    outputs=[
        gr.Plot(label='Categories map'),
        gr.Image(label='Image'),
        gr.Image(label='Mask'),
        gr.Plot(label='Image with mask'),
    ],
    examples=examples,
    cache_examples=False,
    allow_flagging='never',
)

# Start the web server only when the file is executed as a script.
if __name__ == '__main__':
    demo.launch()