# Last commit b265364 by asammoud: "Re-add large CSVs using Git LFS"
# (file size shown in the web viewer: 3.04 kB).
import os
import torch
from rfdetr import RFDETRBase
import supervision as sv
from PIL import Image
import numpy as np
# Single-process defaults for torch.distributed's environment-variable
# rendezvous. Presumably RF-DETR's training entry point (called below only to
# load a checkpoint) reads these — confirm against rfdetr.
# setdefault() keeps any values a launcher (e.g. torchrun) already exported
# instead of clobbering them. is_available() is checked first because
# is_initialized() is not defined on torch builds without distributed support.
if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")
# Root of the exported dataset (nested folder as exported) and the fine-tuned
# checkpoint to load. Kept in one place so the paths cannot drift apart.
DATASET_DIR = "P&ID-Symbols-3/P&ID-Symbols-3"
CHECKPOINT_PATH = "output/checkpoint0009.pth"

model = RFDETRBase()
# NOTE(review): train(..., epochs=0) with `resume=` is used here only to load
# the weights from CHECKPOINT_PATH; presumably no optimisation step runs because
# the resumed epoch already exceeds `epochs`, and it presumably still reads
# DATASET_DIR — confirm against rfdetr before relying on either.
model.train(dataset_dir=DATASET_DIR, resume=CHECKPOINT_PATH, epochs=0)

# COCO ground truth for the test split; detect_symbols_and_lines draws it next
# to the model's predictions.
ds = sv.DetectionDataset.from_coco(
    images_directory_path=f"{DATASET_DIR}/test",
    annotations_path=f"{DATASET_DIR}/test/_annotations.coco.json",
)
import streamlit as st
import io
def _source_name(image):
    """Return the base file name *image* was loaded from, or None if unknown.

    Paths name themselves; file-like objects (e.g. Streamlit's ``UploadedFile``
    or built-in file objects) expose ``.name``; PIL images opened from a path
    expose ``.filename``. PIL leaves ``.filename`` empty for images decoded
    from a stream, so the name has to be captured before decoding.
    """
    if isinstance(image, (str, os.PathLike)):
        candidates = (image,)
    else:
        candidates = (getattr(image, "name", None), getattr(image, "filename", None))
    for candidate in candidates:
        # Ignore non-path values such as integer file descriptors.
        if isinstance(candidate, (str, os.PathLike)):
            base = os.path.basename(os.fspath(candidate))
            if base:
                return base
    return None


def _to_pil_image(image):
    """Return *image* as an RGB ``PIL.Image.Image``.

    Accepts a PIL image, a NumPy array, or anything ``Image.open`` accepts
    (a filesystem path or a readable binary file-like object).
    """
    if isinstance(image, np.ndarray):
        # NOTE(review): assumes RGB(A)/grayscale channel order, not OpenCV BGR —
        # confirm with callers.
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        image = Image.open(image)
    if image.mode != "RGB":
        # RGBA / grayscale / palette inputs are normalised to 3 channels;
        # presumably both RF-DETR and the supervision annotators expect RGB —
        # confirm.
        image = image.convert("RGB")
    return image


def _find_ground_truth(source_name):
    """Return the test-split annotations whose image file is *source_name*, or None."""
    if not source_name:
        return None
    # Compare names via ds.image_paths instead of ds[idx]: indexing also yields
    # the image itself, which would read every test image just to compare names.
    for index, img_path in enumerate(ds.image_paths):
        if os.path.basename(img_path) == source_name:
            _, _, annotations = ds[index]
            return annotations
    return None


def _class_name(class_id):
    """Return the dataset class name for *class_id*, or the raw id if out of range."""
    index = int(class_id)
    if 0 <= index < len(ds.classes):
        return str(ds.classes[index])
    return str(index)


def _draw(image, detections, labels, box_annotator, label_annotator):
    """Return a copy of *image* with *detections* boxed and labelled."""
    canvas = box_annotator.annotate(image.copy(), detections)
    return label_annotator.annotate(canvas, detections, labels)


def detect_symbols_and_lines(image, threshold=0.5):
    """Run the detector on *image* and show predictions beside ground truth.

    The image is matched by file name against the test split held in the
    module-level ``ds``. On a match, its annotations are drawn in the left
    column. Otherwise a Streamlit warning is shown and the left column shows
    the bare image. Model predictions with confidences go in the right column.

    Args:
        image: A ``PIL.Image.Image``, a readable file-like object (e.g. a
            Streamlit upload), a filesystem path, or a NumPy array.
        threshold: Confidence threshold passed to ``model.predict``.

    Returns:
        Tuple ``(detections, annotations, classes)``: the model's predictions,
        the ground-truth detections (``sv.Detections.empty()`` when no match
        was found), and ``ds.classes``.
    """
    # The name must be read before decoding (decoded stream images lose it).
    source_name = _source_name(image)
    image = _to_pil_image(image)

    detections = model.predict(image, threshold=threshold)

    annotations = _find_ground_truth(source_name)
    if annotations is None:
        st.warning("No matching ground truth annotations found for this image.")
        annotations = sv.Detections.empty()
        annotations_labels = []
    else:
        annotations_labels = [_class_name(class_id) for class_id in annotations.class_id]

    detections_labels = [
        f"{_class_name(class_id)} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)
    box_annotator = sv.BoxAnnotator(thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_scale=0.9,
        text_thickness=thickness,
        smart_position=True,
    )
    annotation_image = _draw(image, annotations, annotations_labels, box_annotator, label_annotator)
    detections_image = _draw(image, detections, detections_labels, box_annotator, label_annotator)

    # Ground truth on the left, predictions on the right.
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; check the pinned version before switching.
    col1, col2 = st.columns(2)
    with col1:
        st.image(annotation_image, caption="Ground Truth Annotations", use_column_width=True)
    with col2:
        st.image(detections_image, caption="Model Predictions", use_column_width=True)
    return detections, annotations, ds.classes