import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import cv2
import networkx as nx


def build_graph(pil_image, detections, annotations, class_names):
    """Extract a symbol/line graph from an image and render it in Streamlit.

    Pipeline:
      1. Keep detections whose class name is "connector", "crossing" or
         "border_node"; record each one's box and box centre.
      2. Detect straight segments (Gaussian blur -> Canny -> probabilistic
         Hough), then merge near-parallel segments whose midpoints are close.
      3. Keep merged segments whose two endpoints snap to two *different*
         symbols (centre within 15 px, otherwise endpoint inside the box).
      4. Show an annotated overlay and a NetworkX drawing via Streamlit.

    Args:
        pil_image: Source PIL image. It is converted to RGB first, so
            grayscale, palette and RGBA images are accepted.
        detections: Object exposing parallel ``xyxy`` boxes and ``class_id``
            sequences (presumably ``supervision.Detections`` -- confirm with
            the caller).
        annotations: Currently unused; kept for interface compatibility.
        class_names: Indexable lookup from class id to class-name string.

    Returns:
        ``networkx.Graph`` with one node per kept symbol. Node ids have the
        form ``"<type>_<i>"``, where ``i`` is the symbol's position among the
        kept symbols. Each node carries a ``label`` attribute (the type) and a
        ``pos`` attribute (box centre, image pixels). There is one edge per
        kept segment.
    """

    def dist(p1, p2):
        return np.hypot(p1[0] - p2[0], p1[1] - p2[1])

    def angle_between(p1, p2):
        # Undirected orientation of the segment p1-p2, in degrees within [0, 180).
        return np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0])) % 180

    def lines_are_similar(line1, line2, max_distance=10, max_angle_diff=10):
        """Return True if both segments are near-parallel and their midpoints are close."""
        diff = abs(angle_between(*line1) - angle_between(*line2))
        # Orientations wrap at 180 degrees: 179 and 1 are 2 degrees apart, not 178.
        diff = min(diff, 180 - diff)
        if diff > max_angle_diff:
            return False
        (x1, y1), (x2, y2) = line1
        (x3, y3), (x4, y4) = line2
        mid1 = ((x1 + x2) / 2, (y1 + y2) / 2)
        mid2 = ((x3 + x4) / 2, (y3 + y4) / 2)
        return dist(mid1, mid2) < max_distance

    def merge_similar_lines(lines):
        """Greedily group similar segments and replace each group with one segment.

        The merged segment joins the two group endpoints that lie furthest
        apart along the group's first segment, so orientation is preserved.
        A bounding-box min/max would turn anti-diagonal segments into
        diagonal ones.
        """
        merged, used = [], set()
        for i, ref in enumerate(lines):
            if i in used:
                continue
            used.add(i)
            group = [ref]
            # Every index below i is already used: it either started a group or joined one.
            for j in range(i + 1, len(lines)):
                if j not in used and lines_are_similar(ref, lines[j]):
                    group.append(lines[j])
                    used.add(j)
            (ax, ay), (bx, by) = ref
            dx, dy = bx - ax, by - ay
            points = [p for segment in group for p in segment]
            # Scalar projection (up to a constant factor) of each endpoint onto ref's direction.
            proj = [(px - ax) * dx + (py - ay) * dy for px, py in points]
            start = points[int(np.argmin(proj))]
            end = points[int(np.argmax(proj))]
            merged.append(((int(start[0]), int(start[1])), (int(end[0]), int(end[1]))))
        return merged

    def point_inside_bbox(px, py, bbox):
        x1, y1, x2, y2 = bbox
        return x1 <= px <= x2 and y1 <= py <= y2

    def find_nearest_symbol(point, symbols, max_dist=15):
        """Return the symbol whose centre is nearest to `point` (within `max_dist`).

        If no centre is close enough, return the first symbol whose box
        contains the point. Return None if neither matches.
        """
        px, py = point
        nearest_sym, nearest_dist = None, float("inf")
        for sym in symbols:
            d = dist((px, py), sym["pos"])
            if d < nearest_dist and d <= max_dist:
                nearest_sym, nearest_dist = sym, d
        if nearest_sym is None:
            for sym in symbols:
                if point_inside_bbox(px, py, sym["bbox"]):
                    return sym
        return nearest_sym

    # PIL -> OpenCV BGR. convert("RGB") first so "L"/"P"/"RGBA" images work too.
    image_cv = cv2.cvtColor(np.array(pil_image.convert("RGB")), cv2.COLOR_RGB2BGR)

    # Keep only the symbol classes that act as graph nodes.
    allowed_types = {"connector", "crossing", "border_node"}
    symbols = []
    for box, class_id in zip(detections.xyxy, detections.class_id):
        label = class_names[class_id]
        if label not in allowed_types:
            continue
        x1, y1, x2, y2 = map(int, box)
        symbols.append({
            # Index among *kept* symbols -- the id scheme the returned graph has always used.
            "id": f"{label}_{len(symbols)}",
            "type": label,
            "pos": ((x1 + x2) // 2, (y1 + y2) // 2),
            "bbox": (x1, y1, x2, y2),
        })

    # Straight-segment detection.
    gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150, apertureSize=3)
    hough = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50,
                            minLineLength=50, maxLineGap=10)
    # HoughLinesP returns None when nothing is found. Coordinates are cast to
    # plain ints for the later arithmetic and cv2 drawing calls.
    detected_lines = [] if hough is None else [
        ((int(x1), int(y1)), (int(x2), int(y2)))
        for line in hough for x1, y1, x2, y2 in line
    ]

    # Snap each merged segment to symbols once; keep those joining two distinct symbols.
    connections = []  # (pt1, pt2, sym1, sym2)
    for pt1, pt2 in merge_similar_lines(detected_lines):
        sym1 = find_nearest_symbol(pt1, symbols)
        sym2 = find_nearest_symbol(pt2, symbols)
        if sym1 is not None and sym2 is not None and sym1["id"] != sym2["id"]:
            connections.append((pt1, pt2, sym1, sym2))

    # Overlay: symbol boxes, centres and labels, plus the kept segments.
    output = image_cv.copy()
    for sym in symbols:
        x1, y1, x2, y2 = sym["bbox"]
        cv2.rectangle(output, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.circle(output, sym["pos"], 3, (0, 255, 255), -1)
        # Clamp so labels of boxes touching the top edge stay on-canvas.
        cv2.putText(output, sym["type"], (x1, max(y1 - 5, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, .6, (255, 0, 0), 1)
    for pt1, pt2, _, _ in connections:
        cv2.line(output, pt1, pt2, (0, 100, 255), 2)

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width` -- confirm the pinned Streamlit version.
    st.image(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)),
             caption="Graph: Merged Lines + Detected Symbols", use_column_width=True)

    # Build the graph from the connections computed above.
    G = nx.Graph()
    for sym in symbols:
        G.add_node(sym["id"], label=sym["type"], pos=sym["pos"])
    for _, _, sym1, sym2 in connections:
        G.add_edge(sym1["id"], sym2["id"])

    fig, ax = plt.subplots(figsize=(8, 8))
    # Image y grows downward; negate it (for plotting only) so the drawing is not upside down.
    pos = {node: (x, -y) for node, (x, y) in G.nodes(data="pos")}
    labels = dict(G.nodes(data="label"))
    nx.draw(G, pos, labels=labels, node_size=700, node_color="lightblue",
            font_size=8, with_labels=True, ax=ax)
    ax.set_title("Extracted Graph from Detected Symbols and Lines")
    st.pyplot(fig)
    # st.pyplot does not close the figure; without this, every rerun leaks one.
    plt.close(fig)
    return G