File size: 3,040 Bytes
b265364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import torch
from rfdetr import RFDETRBase
import supervision as sv
from PIL import Image
import numpy as np

# Default torch.distributed rendezvous settings for a single local process.
# ``torch.distributed`` only exposes ``is_initialized`` when PyTorch was built
# with distributed support, so check ``is_available`` first. ``setdefault``
# keeps values a launcher (e.g. torchrun) has already exported instead of
# forcing every worker to rank 0.
if torch.distributed.is_available() and not torch.distributed.is_initialized():
    for _env_key, _env_value in (
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29500"),
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("LOCAL_RANK", "0"),
    ):
        os.environ.setdefault(_env_key, _env_value)

# Dataset root; the test split (images + COCO annotations) lives under ``test``.
DATASET_DIR = "P&ID-Symbols-3/P&ID-Symbols-3"
TEST_DIR = os.path.join(DATASET_DIR, "test")

# NOTE(review): ``epochs=0`` together with ``resume`` appears to be used only to
# restore the fine-tuned weights from the checkpoint without training further —
# confirm against rfdetr; a dedicated weight-loading option may be cleaner.
# NOTE(review): Streamlit re-executes this script on each interaction, so this
# setup reruns every time; consider caching it (e.g. ``st.cache_resource``).
model = RFDETRBase()
model.train(dataset_dir=DATASET_DIR, resume="output/checkpoint0009.pth", epochs=0)

# Ground-truth test split; ``detect_symbols_and_lines`` looks images up here by
# file name to draw their reference annotations.
ds = sv.DetectionDataset.from_coco(
    images_directory_path=TEST_DIR,
    annotations_path=os.path.join(TEST_DIR, "_annotations.coco.json"),
)

import streamlit as st
import io

def _load_image(image):
    """Normalise *image* to an RGB ``PIL.Image.Image`` and recover its file name.

    Args:
        image: A ``PIL.Image.Image``, a readable file-like object (anything
            with ``read``, such as a Streamlit upload), a ``str``/``PathLike``
            filesystem path, or a NumPy image array.

    Returns:
        ``(pil_image, name)``: the RGB image and the basename of the file it
        came from, or ``None`` when no name is available.

    Raises:
        TypeError: if *image* is none of the supported types.
    """
    # Read the name before converting: file-like objects carry it as ``name``,
    # PIL images opened from a path as ``filename``, and the new images built
    # below do not necessarily keep either attribute.
    is_path = isinstance(image, (str, os.PathLike))
    name = None
    if is_path:
        name = os.fspath(image)
    else:
        for attr in ("name", "filename"):
            value = getattr(image, attr, None)
            if isinstance(value, str) and value:
                name = value
                break

    if is_path or (not isinstance(image, Image.Image) and hasattr(image, "read")):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError(f"Unsupported image type: {type(image).__name__}")

    # Uploads may be RGBA, palette or greyscale; hand the model and the
    # annotators a uniform 3-channel image.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return image, (os.path.basename(name) if name else None)


def _find_ground_truth(name):
    """Return the annotations in ``ds`` for the image whose basename is *name*.

    Returns ``None`` when *name* is ``None`` or no dataset image matches.
    """
    if name is None:
        return None
    # Scan paths only; ``ds[idx]`` also returns the image itself, so index the
    # dataset once, for the match, instead of on every iteration.
    for idx, img_path in enumerate(ds.image_paths):
        if os.path.basename(img_path) == name:
            _, _, annotations = ds[idx]
            return annotations
    return None


def detect_symbols_and_lines(image):
    """Run the detector on *image* and show predictions beside ground truth.

    Renders two Streamlit columns: ground-truth boxes from the module-level
    ``ds`` dataset (matched by file basename) and the model's predictions at a
    0.5 confidence threshold. Shows a Streamlit warning when no ground truth
    matches.

    Args:
        image: A ``PIL.Image.Image``, a readable file-like object (e.g. a
            Streamlit upload), a filesystem path, or a NumPy image array.

    Returns:
        ``(detections, annotations, classes)``: the model's ``sv.Detections``,
        the matched ground-truth ``sv.Detections`` (empty when unmatched), and
        ``ds.classes``.

    Raises:
        TypeError: if *image* is none of the supported input types.
    """
    image, name = _load_image(image)

    detections = model.predict(image, threshold=0.5)

    annotations = _find_ground_truth(name)
    if annotations is None:
        st.warning("No matching ground truth annotations found for this image.")
        annotations = sv.Detections.empty()
        annotations_labels = []
    else:
        annotations_labels = [str(ds.classes[class_id]) for class_id in annotations.class_id]

    detections_labels = [
        f"{ds.classes[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    text_scale = 0.9
    # Scale box/text stroke with the image resolution.
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)

    bbox_annotator = sv.BoxAnnotator(thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
        text_thickness=thickness,
        smart_position=True,
    )

    def render(dets, labels):
        # Draw on a copy so the input image stays clean for the other panel.
        scene = bbox_annotator.annotate(image.copy(), dets)
        return label_annotator.annotate(scene, dets, labels)

    annotation_image = render(annotations, annotations_labels)
    detections_image = render(detections, detections_labels)

    # === Display side-by-side in Streamlit ===
    col1, col2 = st.columns(2)
    with col1:
        st.image(annotation_image, caption="Ground Truth Annotations", use_column_width=True)
    with col2:
        st.image(detections_image, caption="Model Predictions", use_column_width=True)

    return detections, annotations, ds.classes