Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| # Path to the pre-trained segmentation model (a relative path, resolved against the working directory) | |
| model_path = "saved_model" | |
| # Load the segmentation model once at import time; segment_image() reads this module-level object | |
| segmentation_model = tf.keras.models.load_model(model_path) | |
| # Size passed to cv2.resize for every input image before it is handed to the model | |
| TARGET_SHAPE = (256, 256) | |
| # Define image segmentation function | |
def segment_image(img: np.ndarray):
    """Predict a per-pixel class-label mask for ``img``.

    The image is resized to ``TARGET_SHAPE``, run through the module-level
    ``segmentation_model``, reduced to one class index per pixel with an
    arg-max over the last axis, and resized back to the input's height and
    width.

    Args:
        img: Image array, either 2-D (grayscale) or 3-D (height, width,
            channels). A 2-D image is expanded to three channels first.
            NOTE(review): pixel values are passed to the model unscaled --
            presumably the model normalises internally; confirm.

    Returns:
        A ``uint8`` array of class indices resized to the input's height
        and width.
    """
    height, width = img.shape[:2]

    # The model is fed three-channel images, so expand grayscale input.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    resized = cv2.resize(img, TARGET_SHAPE)
    batch = np.expand_dims(resized, axis=0)

    prediction = segmentation_model.predict(batch)
    scores = np.squeeze(prediction, axis=0)
    labels = np.argmax(scores, axis=-1).astype(np.uint8)

    # Nearest-neighbour keeps every pixel a valid class id. The default
    # bilinear interpolation would average neighbouring ids (e.g. a 0/2
    # boundary would produce spurious 1s).
    return cv2.resize(
        labels, (width, height), interpolation=cv2.INTER_NEAREST
    )
def overlay_mask(img, mask, alpha=0.5):
    """Blend a colour-coded segmentation mask over an image.

    Args:
        img: Image array, either 2-D (grayscale) or 3-D with three colour
            channels. A 2-D image is expanded to three channels so it can
            carry the coloured overlay.
        mask: 2-D array of class ids with the same height and width as
            ``img``. Ids without an entry in the colour table are left
            black in the overlay.
        alpha: Weight given to the overlay; the image is weighted by
            ``1 - alpha``.

    Returns:
        The blended three-channel image produced by ``cv2.addWeighted``.

    Raises:
        ValueError: If ``mask``'s shape differs from the height and width
            of ``img``.
    """
    # Colour assigned to each class id.
    colors = {
        0: [255, 0, 0],  # Class 0 - Red
        1: [0, 255, 0],  # Class 1 - Green
        2: [0, 0, 255],  # Class 2 - Blue
        # Add more colors for additional classes if needed
    }

    img = np.asarray(img)
    mask = np.asarray(mask)

    # Fail early with a clear message instead of an opaque boolean-index
    # error from the colour assignment below.
    if mask.shape != img.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape} does not match image size "
            f"{img.shape[:2]}"
        )

    # A 3-element colour cannot be written into a 2-D array, so grayscale
    # images are expanded the same way segment_image expands them.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    # Start from an all-black overlay and paint each class region.
    overlay = np.zeros_like(img)
    for class_id, color in colors.items():
        overlay[mask == class_id] = color

    return cv2.addWeighted(img, 1 - alpha, overlay, alpha, 0)
| # The main function | |
def transform(img):
    """Segment ``img`` and return it with the colour-coded mask blended in.

    Args:
        img: Input image array, or ``None`` when no image was supplied.

    Returns:
        The blended image from ``overlay_mask``, or ``None`` when ``img``
        is ``None``.
    """
    # Without an image there is nothing to segment; returning None avoids
    # an AttributeError on ``img.shape`` further down.
    if img is None:
        return None

    mask = segment_image(img)
    return overlay_mask(img, mask)
# Build the Gradio interface: one image in, the overlaid image out.
_input_image = gr.Image(label="Input Image")
_output_image = gr.Image(label="Image with Segmentation Overlay")

# Sample images offered in the UI.
_example_images = [
    "example_images/img1.jpg",
    "example_images/img2.jpg",
    "example_images/img3.jpg",
]

app = gr.Interface(
    fn=transform,
    inputs=_input_image,
    outputs=_output_image,
    title="Image Segmentation on Pet Images",
    description="Segment image of a pet animal into three classes: background, pet, and boundary.",
    examples=_example_images,
)

# Start serving the app.
app.launch()