Spaces:
Build error
Build error
| import os | |
| import random | |
| import logging | |
| import gradio as gr | |
| from PIL import Image | |
| from zipfile import ZipFile | |
| from typing import Any, Dict,List | |
| from transformers import pipeline | |
class Image_classification:
    """Gradio demo that classifies images with Hugging Face pipelines.

    The app extracts ``image_dataset.zip`` into ``dataset/`` to offer example
    images, lets the user pick a model from a dropdown, and displays the best
    score per label for the submitted image.
    """

    # Lower-case file suffixes that count as example images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')

    def __init__(self):
        # Building a pipeline loads the model weights, so keep one pipeline
        # per model and reuse it across button clicks.
        self._classifiers: Dict[Any, Any] = {}

    def unzip_image_data(self) -> str:
        """
        Unzips the image dataset into the ``dataset`` directory.

        The directory is created if missing and reused if it already exists,
        so calling this on every start-up is safe.

        Returns:
            str: The path to the directory containing the extracted files, or
            an empty string if extraction failed (the error is logged).
        """
        directory_path = "dataset"
        try:
            with ZipFile("image_dataset.zip", "r") as archive:
                # exist_ok avoids FileExistsError on the second and later runs.
                os.makedirs(directory_path, exist_ok=True)
                archive.extractall(directory_path)
        except Exception as e:
            # Best effort: the app can still run without example images.
            logging.error("An error occurred during extraction: %s", e)
            return ""
        return directory_path

    def example_images(self) -> List[str]:
        """
        Unzips the image dataset and lists the extracted images for use as examples.

        Only regular files directly inside the dataset directory whose suffix
        is in ``IMAGE_EXTENSIONS`` are returned; each file appears once.

        Returns:
            List[str]: Sorted file paths of the example images. Empty when the
            dataset could not be extracted or read.
        """
        image_dataset_folder = self.unzip_image_data()
        if not image_dataset_folder:
            # Extraction already logged its own error.
            return []
        try:
            names = sorted(os.listdir(image_dataset_folder))
        except OSError as e:
            logging.error("An error occurred in example images: %s", e)
            return []

        examples = []
        for name in names:
            path = os.path.join(image_dataset_folder, name)
            suffix = os.path.splitext(name)[1].lower()
            if suffix in self.IMAGE_EXTENSIONS and os.path.isfile(path):
                examples.append(path)
        return examples

    def classify(self, image: Image.Image, model: Any) -> List[Dict[str, Any]]:
        """
        Classifies an image using a specified model.

        Args:
            image (Image.Image): The image to classify. The UI passes the
                uploaded file's path here, which is handed to the pipeline
                unchanged.
            model (Any): The model (name) used for classification.

        Returns:
            List[Dict[str, Any]]: The pipeline output; each entry is read
            downstream through its ``'label'`` and ``'score'`` keys.

        Raises:
            Exception: Any error from building or running the pipeline is
            logged and re-raised.
        """
        try:
            classifier = self._classifiers.get(model)
            if classifier is None:
                classifier = pipeline("image-classification", model=model)
                self._classifiers[model] = classifier
            return classifier(image)
        except Exception as e:
            logging.error("An error occurred during image classification: %s", e)
            raise

    def format_the_result(self, image: Image.Image, model: Any) -> Dict[str, float]:
        """
        Formats the classification result by retaining the highest score for each label.

        Args:
            image (Image.Image): The image to classify.
            model (Any): The model used for classification.

        Returns:
            Dict[str, float]: A dictionary with unique labels and the highest score for each label.

        Raises:
            Exception: Any classification or formatting error is logged and
            re-raised.
        """
        try:
            predictions = self.classify(image, model)
            best_scores: Dict[str, float] = {}
            for item in predictions:
                label = item['label']
                score = item['score']
                # Keep the first score seen, then only replace it with a higher one.
                if label not in best_scores or best_scores[label] < score:
                    best_scores[label] = score
            return best_scores
        except Exception as e:
            logging.error("An error occurred while formatting the results: %s", e)
            raise

    def interface(self):
        """
        Builds the Gradio UI and launches it.

        The page has a model dropdown, an image upload next to a label output,
        a button that runs ``format_the_result``, and (when any were found) a
        gallery of example images.
        """
        model_choices = ["facebook/regnet-x-040", "google/vit-large-patch16-384", "microsoft/resnet-50"]
        # NOTE(review): the second CSS rule's selector has no dots between the
        # class names and its colour has no '#', so it most likely matches
        # nothing. It is kept unchanged to avoid altering the page's look.
        with gr.Blocks(css="""
        .gradio-container {background: #314755;
        background: -webkit-linear-gradient(to right, #26a0da, #314755);
        background: linear-gradient(to right, #26a0da, #314755);}
        .block svelte-90oupt padded{background:314755;
        margin:0;
        padding:0;}""") as demo:
            gr.HTML("""
            <center><h1 style="color:#fff">Image Classification</h1></center>""")
            exam_img = self.example_images()
            with gr.Row():
                # Pre-select a real model so a click never runs without one.
                model = gr.Dropdown(model_choices, value=model_choices[0], label="Choose a model")
            with gr.Row():
                image = gr.Image(type="filepath", sources=["upload"])
                with gr.Column():
                    output = gr.Label()
            with gr.Row():
                button = gr.Button()
            button.click(self.format_the_result, [image, model], output)
            if exam_img:
                # Skip the gallery entirely when no example images are available.
                gr.Examples(
                    examples=exam_img,
                    inputs=[image],
                    outputs=output,
                    fn=self.format_the_result,
                    cache_examples=False,
                )
        demo.launch(debug=True)
if __name__ == "__main__":
    # Script entry point: build the app object and launch its Gradio UI.
    result = Image_classification().interface()