from typing import Optional, Tuple

import gradio as gr
import numpy as np
import supervision as sv
from ultralytics import YOLO
MARKDOWN = """
# YOLO-Playground 📈
Welcome to YOLO-Playground! This demo showcases the detection capabilities of various YOLO models pre-trained on the COCO Dataset. 🚀🔍👀
A simple, just-for-fun project for on-the-go object detection. 🎉
Inspired by SkalskiP's YOLO-ARENA. 🙏
- **YOLOv8**
  
- **YOLOv9**
  
- **YOLOv10**
  
- **YOLO11**
Powered by Roboflow [Inference](https://github.com/roboflow/inference), 
[Supervision](https://github.com/roboflow/supervision) and [Ultralytics](https://github.com/ultralytics/ultralytics).🔥
"""
# Example rows for gr.Examples. Each row's columns follow the order of the
# `inputs` list handed to gr.Examples:
# [image URL, YOLOv8s conf, YOLOv9s conf, YOLOv10s conf, YOLO11s conf, IoU].
IMAGE_EXAMPLES = [
    [example_url, 0.3, 0.3, 0.3, 0.3, 0.5]
    for example_url in (
        'https://media.roboflow.com/supervision/image-examples/people-walking.png',
        'https://media.roboflow.com/supervision/image-examples/vehicles.png',
        'https://media.roboflow.com/supervision/image-examples/basketball-1.png',
    )
]
# The "s" checkpoints for each YOLO generation compared in the UI. They are
# loaded once, at import time, and reused by every call to process_image.
# NOTE(review): presumably Ultralytics fetches missing weight files on first
# use — confirm for the deployment environment.
YOLO_V8S_MODEL, YOLO_V9S_MODEL, YOLO_V10S_MODEL, YOLO_11S_MODEL = (
    YOLO(checkpoint)
    for checkpoint in ("yolov8s.pt", "yolov9s.pt", "yolov10s.pt", "yolo11s.pt")
)
# Single annotator instances shared by detect_and_annotate for all models.
BOUNDING_BOX_ANNOTATORS = sv.BoxAnnotator()
LABEL_ANNOTATORS = sv.LabelAnnotator()
def detect_and_annotate(
    model,
    input_image: np.ndarray,
    confidence_threshold: float,
    iou_threshold: float,
    class_id_mapping: Optional[dict] = None
) -> np.ndarray:
    """Run ``model`` on ``input_image`` and draw its detections on a copy.

    Args:
        model: A detection model called as ``model(image, conf=..., iou=...)``
            whose first result is converted with
            ``sv.Detections.from_ultralytics`` (the Ultralytics ``YOLO``
            instances in this file).
        input_image: Image to run detection on. Annotated as ``np.ndarray``,
            but the Gradio input in this app is declared with ``type='pil'``,
            so in practice a PIL image arrives here; only ``.copy()`` is
            required of it directly.
        confidence_threshold: Minimum detection confidence, forwarded as
            ``conf``.
        iou_threshold: NMS IoU threshold, forwarded as ``iou``.
        class_id_mapping: Optional dict mapping each predicted class id to a
            replacement id. Every detected class id must be a key, otherwise
            ``KeyError`` is raised. ``None`` or an empty dict disables remapping.

    Returns:
        A copy of ``input_image`` with bounding boxes and
        ``"<class name> (<confidence>)"`` labels drawn on it.
    """
    # A single image is passed in, so only the first result is relevant.
    result = model(
        input_image,
        conf=confidence_threshold,
        iou=iou_threshold
    )[0]
    detections = sv.Detections.from_ultralytics(result)

    if class_id_mapping:
        # dtype=int keeps class ids integral even when there are no
        # detections (a bare np.array([]) would default to float64).
        detections.class_id = np.array(
            [class_id_mapping[class_id] for class_id in detections.class_id],
            dtype=int,
        )
        # NOTE(review): the labels below still read detections['class_name'],
        # which is not remapped here — confirm whether names should follow
        # the new ids.

    labels = [
        f"{class_name} ({confidence:.2f})"
        for class_name, confidence
        in zip(detections['class_name'], detections.confidence)
    ]

    # Draw on a copy so the caller's image stays untouched (process_image
    # reuses the same input for every model, and annotators may draw in place).
    annotated_image = input_image.copy()
    annotated_image = BOUNDING_BOX_ANNOTATORS.annotate(
        scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATORS.annotate(
        scene=annotated_image, detections=detections, labels=labels)
    return annotated_image
def process_image(
    input_image: np.ndarray,
    yolo_v8_confidence_threshold: float,
    yolo_v9_confidence_threshold: float,
    yolo_v10_confidence_threshold: float,
    yolov11_confidence_threshold: float,
    iou_threshold: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run all four YOLO models on one image and return their annotated outputs.

    This is the Gradio callback for the Submit button and the examples; its
    arguments arrive in the order of the ``inputs`` component list.

    Args:
        input_image: The uploaded image, or ``None`` if nothing was uploaded.
            The input component is declared with ``type='pil'``, so in this
            app it is a PIL image despite the ``np.ndarray`` annotation.
        yolo_v8_confidence_threshold: Confidence threshold for YOLOv8s.
        yolo_v9_confidence_threshold: Confidence threshold for YOLOv9s.
        yolo_v10_confidence_threshold: Confidence threshold for YOLOv10s.
        yolov11_confidence_threshold: Confidence threshold for YOLO11s.
        iou_threshold: NMS IoU threshold shared by all four models. A missing
            or non-numeric value falls back to 0.3.

    Returns:
        Annotated images for YOLOv8s, YOLOv9s, YOLOv10s and YOLO11s, in that
        order (matching the ``outputs`` component list).

    Raises:
        gr.Error: If ``input_image`` is ``None``, so the UI shows a readable
            message instead of an error from inside the model/annotation code.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image before submitting.")

    # Accept any real number: integral slider values (e.g. 0 or 1) may arrive
    # as ``int``, and the previous float-only check silently replaced them.
    # bool is excluded because it is a subclass of int.
    if isinstance(iou_threshold, bool) or not isinstance(iou_threshold, (int, float)):
        iou_threshold = 0.3  # fallback for a missing/invalid slider value
    iou_threshold = float(iou_threshold)

    models_and_thresholds = (
        (YOLO_V8S_MODEL, yolo_v8_confidence_threshold),
        (YOLO_V9S_MODEL, yolo_v9_confidence_threshold),
        (YOLO_V10S_MODEL, yolo_v10_confidence_threshold),
        (YOLO_11S_MODEL, yolov11_confidence_threshold),
    )
    # Models run sequentially, in output order.
    return tuple(
        detect_and_annotate(model, input_image, confidence, iou_threshold)
        for model, confidence in models_and_thresholds
    )
# Help text shown under every per-model confidence slider.
_CONFIDENCE_THRESHOLD_INFO = (
    "The confidence threshold for the YOLO model. Lower the threshold to "
    "reduce false negatives, enhancing the model's sensitivity to detect "
    "sought-after objects. Conversely, increase the threshold to minimize false "
    "positives, preventing the model from identifying objects it shouldn't."
)


def _make_confidence_slider(label: str) -> "gr.Slider":
    """Create a 0-1 confidence slider (default 0.3, step 0.01) with ``label``.

    The slider is created here, outside the ``gr.Blocks`` context; the layout
    places it later by calling ``.render()`` on it.
    """
    return gr.Slider(
        minimum=0,
        maximum=1.0,
        value=0.3,
        step=0.01,
        label=label,
        info=_CONFIDENCE_THRESHOLD_INFO,
    )


# One confidence slider per model; defaults match the example rows (0.3).
yolo_v8s_confidence_threshold_component = _make_confidence_slider(
    "YOLOv8s Confidence Threshold")
yolo_v9s_confidence_threshold_component = _make_confidence_slider(
    "YOLOv9s Confidence Threshold")
yolo_v10s_confidence_threshold_component = _make_confidence_slider(
    "YOLOv10s Confidence Threshold")
yolo_11s_confidence_threshold_component = _make_confidence_slider(
    "YOLO11s Confidence Threshold")

# Single NMS IoU slider shared by all four models.
iou_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="IoU Threshold",
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. On the other hand, increase the value "
        "to allow more overlapping bounding boxes, accommodating a broader range of "
        "detections."
    ))
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Threshold controls, collapsed by default.
    with gr.Accordion("Configuration", open=False):
        with gr.Row():
            yolo_v8s_confidence_threshold_component.render()
            yolo_v9s_confidence_threshold_component.render()
            yolo_v10s_confidence_threshold_component.render()
            yolo_11s_confidence_threshold_component.render()
        iou_threshold_component.render()

    with gr.Row():
        input_image_component = gr.Image(
            type='pil',
            label='Input'
        )

    # Per-model results, laid out as a 2x2 grid.
    with gr.Row():
        yolo_v8s_output_image_component = gr.Image(
            type='pil',
            label='YOLOv8s'
        )
        yolo_v9s_output_image_component = gr.Image(
            type='pil',
            label='YOLOv9s'
        )
    with gr.Row():
        yolo_v10s_output_image_component = gr.Image(
            type='pil',
            label='YOLOv10s'
        )
        yolo_11s_output_image_component = gr.Image(
            type='pil',
            label='YOLO11s'
        )

    submit_button_component = gr.Button(
        value='Submit',
        scale=1,
        variant='primary'
    )

    # Shared by gr.Examples and the Submit button so both always wire the
    # same components, in process_image's argument / return order.
    process_image_inputs = [
        input_image_component,
        yolo_v8s_confidence_threshold_component,
        yolo_v9s_confidence_threshold_component,
        yolo_v10s_confidence_threshold_component,
        yolo_11s_confidence_threshold_component,
        iou_threshold_component
    ]
    process_image_outputs = [
        yolo_v8s_output_image_component,
        yolo_v9s_output_image_component,
        yolo_v10s_output_image_component,
        yolo_11s_output_image_component
    ]

    gr.Examples(
        fn=process_image,
        examples=IMAGE_EXAMPLES,
        inputs=process_image_inputs,
        outputs=process_image_outputs
    )
    submit_button_component.click(
        fn=process_image,
        inputs=process_image_inputs,
        outputs=process_image_outputs
    )
demo.launch(debug=False, show_error=True, max_threads=1)