import torch
import torch.nn as nn
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from pipeline.main import Main
from pipeline.test import test
from pipeline.evaluate import get_full_err_scores

# Drawing styles used by run_gnn for the learned sensor graph.
anomaly_node_size = 80
default_node_size = 2
central_node_color = "yellow"
anomaly_node_color = "red"
default_node_color = "black"
anomaly_edge_color = "red"
default_edge_color = (0.35686275, 0.20392157, 0.34901961, 0.1)  # RGBA, alpha 0.1

train_config = {
    'batch': 16,
    'epoch': 100,
    'slide_win': 5,
    'dim': 64,
    'slide_stride': 1,
    'comment': '',
    'seed': 42,
    'out_layer_num': 1,
    'out_layer_inter_dim': 128,
    'decay': 0,
    'val_ratio': 0.1,
    'topk': 15,
}

env_config = {
    'save_path': '',
    'dataset': 'swat',
    'report': 'best',
    'device': 'cpu',
    'load_model_path': '',
}


def compute_graph(model: nn.Module, X: torch.Tensor) -> np.ndarray:
    """Aggregate the first GNN layer's attention weights into a node-by-node matrix.

    Runs one forward pass of ``model`` on ``X`` (without gradients). It then reads
    ``model.gnn_layers[0].att_weight_1`` (one coefficient per edge) and
    ``model.gnn_layers[0].edge_index_1`` (2 x E edge list). Each edge's weight
    is summed into ``weight_mat[src % feature_num, dst % feature_num]``.

    Args:
        model: Trained model exposing ``gnn_layers[0].att_weight_1`` and
            ``gnn_layers[0].edge_index_1`` after a forward pass.
        X: Input tensor of shape ``(n_samples, feature_num, slide_win)``.

    Returns:
        A ``(feature_num, feature_num)`` float array: the summed attention per
        node pair divided by ``n_samples``.

    Raises:
        ValueError: if the layer does not provide exactly one attention
            coefficient per edge (e.g. multi-head attention).
    """
    n_samples, feature_num, _ = X.shape
    with torch.no_grad():
        # Forward pass populates the layer's cached attention / edge tensors.
        # NOTE(review): the meaning of the second positional arg (None) depends
        # on the model's forward signature — confirm against the model class.
        model(X, None)

    layer = model.gnn_layers[0]
    coeff_weights = layer.att_weight_1.cpu().detach().numpy()
    edge_index = layer.edge_index_1.cpu().detach().numpy()

    n_edges = len(coeff_weights)
    # Flatten to one scalar per edge; reshape raises if there is more than one.
    weights = np.asarray(coeff_weights).reshape(n_edges)
    # Presumably node ids are offset per sample in the batched graph, so the
    # modulo folds every sample's edges onto the same feature_num nodes.
    src, dst = edge_index[:, :n_edges] % feature_num

    weight_mat = np.zeros((feature_num, feature_num))
    # add.at accumulates repeated (src, dst) pairs instead of overwriting them.
    np.add.at(weight_mat, (src, dst), weights)
    weight_mat /= n_samples
    return weight_mat


def _label_position(xy, center, offset_scale):
    """Return a point ``offset_scale`` away from ``xy``, directed away from ``center``."""
    x, y = xy
    dx, dy = x - center[0], y - center[1]
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6  # epsilon avoids division by zero
    return x + offset_scale * dx / norm, y + offset_scale * dy / norm


def run_gnn(
    central_node_id="auto",
    checkpoint_path="best_05_22_15_03_20.pt",
    threshold=0.1,
    n_graph_samples=100,
    layout_seed=None,
):
    """Load a trained model, build its attention graph and plot one node's neighbourhood.

    Args:
        central_node_id: ``"auto"`` to pick the node whose mean error score
            (``all_scores.mean(axis=1)``) is largest; otherwise anything
            ``int()`` accepts, used as the node index.
        checkpoint_path: Path of the state-dict checkpoint to load.
        threshold: Nodes whose attention to/from the central node exceeds this
            value are highlighted.
        n_graph_samples: Number of leading training windows fed to
            ``compute_graph``.
        layout_seed: Optional seed for ``nx.spring_layout`` (``None`` gives a
            random layout, as before).

    Returns:
        Tuple ``(fig, feature_map_dict, red_nodes, central_node, scores, G)``:
        the matplotlib figure, ``{index: feature name}``, the list of node indices
        with ``scores > threshold``, the central node index (``int``), the
        per-node max attention to/from the central node, and the networkx graph.

    Raises:
        IndexError: if the chosen central node is not a valid node index.
    """
    device = "cpu"
    main = Main(train_config, env_config, debug=False)
    model = main.model.to(device)  # Module.to returns the same module
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint)

    _, train_result = test(model, main.train_dataloader)
    _, test_result = test(model, main.test_dataloader)
    all_scores, _ = get_full_err_scores(train_result, test_result)

    X_train = main.train_dataset.x.float().to(device)
    feature_num = X_train.shape[1]
    adj_mat = compute_graph(model, X_train[:n_graph_samples])

    if central_node_id == "auto":
        # Assumes all_scores rows are features (feature_num, T) — TODO confirm.
        central_node = int(all_scores.mean(axis=1).argmax())
    else:
        central_node = int(central_node_id)
    if not 0 <= central_node < feature_num:
        raise IndexError(
            f"central node {central_node} out of range for {feature_num} nodes"
        )

    # Strongest attention in either direction between each node and the centre.
    scores = np.maximum(adj_mat[central_node], adj_mat[:, central_node])
    red_nodes = np.flatnonzero(scores > threshold).tolist()

    G = nx.from_numpy_array(adj_mat)
    # Materialise before removal so the graph is not mutated mid-iteration.
    G.remove_edges_from(list(nx.selfloop_edges(G)))

    # nx.draw colours edges in G.edges() order, so map each undirected edge to
    # its position in that order.
    edge_position = {frozenset(edge): k for k, edge in enumerate(G.edges())}
    edge_colors = [default_edge_color] * len(edge_position)
    node_colors = [default_node_color] * feature_num
    node_sizes = [default_node_size] * feature_num

    node_colors[central_node] = central_node_color
    node_sizes[central_node] = anomaly_node_size

    neighbours = [node for node in red_nodes if node != central_node]
    for node in neighbours:
        node_colors[node] = anomaly_node_color
        node_sizes[node] = anomaly_node_size
        k = edge_position.get(frozenset((node, central_node)))
        if k is not None:
            edge_colors[k] = anomaly_edge_color

    pos = nx.spring_layout(G, seed=layout_seed)
    graph_center = np.mean(np.array(list(pos.values())), axis=0)
    offset_scale = 0.3

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw(G, pos, edge_color=edge_colors, node_color=node_colors,
            node_size=node_sizes, ax=ax)

    # Central node label, pushed radially outward from the graph centre.
    x_label, y_label = _label_position(pos[central_node], graph_center, offset_scale)
    ax.text(x_label, y_label, s=main.feature_map[central_node],
            bbox=dict(facecolor=central_node_color, alpha=0.5),
            horizontalalignment='center')

    # Highlighted neighbour labels, each joined to its node by a dashed line.
    for node in neighbours:
        x, y = pos[node]
        x_label, y_label = _label_position((x, y), graph_center, offset_scale)
        ax.plot([x, x_label], [y, y_label], 'k--', linewidth=0.8)
        ax.text(x_label, y_label, s=main.feature_map[node],
                bbox=dict(facecolor=anomaly_node_color, alpha=0.5),
                horizontalalignment='center')

    fig.tight_layout()

    # Convert feature map from list to dict for Streamlit compatibility.
    feature_map_dict = dict(enumerate(main.feature_map))
    return fig, feature_map_dict, red_nodes, central_node, scores, G