|
|
import numpy as np
|
|
|
import torch
|
|
|
import matplotlib.pyplot as plt
|
|
|
import torch.nn as nn
|
|
|
import time
|
|
|
from util.time import *
|
|
|
from util.env import *
|
|
|
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
|
|
|
import math
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from .graph_layer import GraphLayer
|
|
|
|
|
|
|
|
|
def get_batch_edge_index(org_edge_index, batch_num, node_num):
    """Tile an edge index ``batch_num`` times, shifting node ids per copy.

    Copy ``i`` of the edges has every entry shifted by ``i * node_num``, so
    ``batch_num`` graphs stacked along the node dimension stay disjoint.

    Args:
        org_edge_index: edge index tensor with edges along dim 1
            (in this file it is a ``(2, edge_num)`` source/target index).
        batch_num: number of copies to produce.
        node_num: id offset between consecutive copies.

    Returns:
        A new long tensor of shape ``(org_edge_index.shape[0],
        edge_num * batch_num)``, detached from autograd. The input is not
        modified.
    """
    edge_num = org_edge_index.shape[1]

    # repeat() allocates a fresh tensor, so no explicit clone is needed.
    batch_edge_index = org_edge_index.detach().repeat(1, batch_num)

    # One offset per column: 0 for the first copy, node_num for the second,
    # ... -- broadcast over every row instead of looping per batch item.
    offsets = torch.arange(batch_num, device=org_edge_index.device).repeat_interleave(edge_num) * node_num

    return (batch_edge_index + offsets).long()
|
|
|
|
|
|
|
|
|
class OutLayer(nn.Module):
    """MLP head that maps each node's feature vector to a single value.

    The stack is ``layer_num - 1`` hidden blocks of
    Linear -> BatchNorm1d -> ReLU (width ``inter_num``), followed by a final
    ``Linear(..., 1)``. When ``layer_num < 1`` the stack is empty and
    ``forward`` returns its input unchanged.

    Args:
        in_num: width of the input's last dimension.
        node_num: accepted for signature compatibility; not used.
        layer_num: total number of Linear layers.
        inter_num: hidden width of the intermediate layers.
    """

    def __init__(self, in_num, node_num, layer_num, inter_num = 512):
        super().__init__()

        layers = []
        if layer_num > 0:
            width = in_num
            for _ in range(layer_num - 1):
                layers += [nn.Linear(width, inter_num), nn.BatchNorm1d(inter_num), nn.ReLU()]
                width = inter_num
            layers.append(nn.Linear(width, 1))

        self.mlp = nn.ModuleList(layers)

    def forward(self, x):
        """Apply the stack; ``x`` must be 3-D because of the permute below.

        The last dimension is treated as the feature dimension; it is moved to
        dim 1 around each BatchNorm1d (which normalises dim 1) and back again.
        """
        out = x
        for layer in self.mlp:
            if isinstance(layer, nn.BatchNorm1d):
                out = layer(out.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                out = layer(out)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
class GNNLayer(nn.Module):
    """Graph block: ``GraphLayer`` -> BatchNorm1d -> ReLU.

    After each forward pass the edge index and attention weights returned by
    ``GraphLayer`` are kept on the module as ``edge_index_1`` and
    ``att_weight_1``.

    Args:
        in_channel: input feature width passed to ``GraphLayer``.
        out_channel: output width of ``GraphLayer`` and of the BatchNorm.
        inter_dim: forwarded to ``GraphLayer``.
        heads: number of heads forwarded to ``GraphLayer`` (``concat=False``).
        node_num: accepted for signature compatibility; not used.
    """

    def __init__(self, in_channel, out_channel, inter_dim=0, heads=1, node_num=100):
        super().__init__()
        self.gnn = GraphLayer(in_channel, out_channel, inter_dim=inter_dim, heads=heads, concat=False)
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        # Registered but not used by forward().
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x, edge_index, embedding=None, node_num=0):
        """Run the graph layer, cache its attention outputs, then BN + ReLU.

        ``node_num`` is accepted for signature compatibility and ignored.
        """
        graph_result = self.gnn(x, edge_index, embedding, return_attention_weights=True)
        hidden, (self.edge_index_1, self.att_weight_1) = graph_result
        return self.relu(self.bn(hidden))
|
|
|
|
|
|
|
|
|
class GDN(nn.Module):
    """Graph network that learns a top-k node graph from node embeddings.

    On every forward pass each node is connected to the ``topk`` nodes whose
    learned embeddings have the highest cosine similarity to its own. One
    ``GNNLayer`` per entry of ``edge_index_sets`` runs on that learned graph.
    The result is scaled by the node embeddings, batch-normalised, and mapped
    to one scalar per node by ``OutLayer``.

    Args:
        edge_index_sets: non-empty list of edge index tensors. Its length sets
            the number of GNN layers and the OutLayer input width.
        node_num: number of nodes (rows of the embedding table).
        dim: embedding width and GNN hidden width.
        out_layer_inter_dim: hidden width inside ``OutLayer``.
        input_dim: per-node input feature width fed to each ``GNNLayer``.
        out_layer_num: number of Linear layers in ``OutLayer``.
        topk: neighbours kept per node in the learned graph (capped at the
            number of nodes seen in ``forward``).

    Raises:
        ValueError: if ``edge_index_sets`` is empty.
    """

    def __init__(self, edge_index_sets, node_num, dim=64, out_layer_inter_dim=256, input_dim=10, out_layer_num=1, topk=20):
        super(GDN, self).__init__()

        if len(edge_index_sets) == 0:
            raise ValueError("edge_index_sets must contain at least one edge index")

        self.edge_index_sets = edge_index_sets

        # NOTE(review): the returned device is unused here (forward() uses
        # data.device); the call is kept in case util.env relies on it.
        get_device()

        embed_dim = dim
        self.embedding = nn.Embedding(node_num, embed_dim)
        self.bn_outlayer_in = nn.BatchNorm1d(embed_dim)

        edge_set_num = len(edge_index_sets)
        self.gnn_layers = nn.ModuleList([
            GNNLayer(input_dim, dim, inter_dim=dim + embed_dim, heads=1)
            for _ in range(edge_set_num)
        ])

        self.node_embedding = None
        self.topk = topk
        # Top-k neighbour indices per node from the most recent forward pass.
        self.learned_graph = None

        self.out_layer = OutLayer(dim * edge_set_num, node_num, out_layer_num, inter_num=out_layer_inter_dim)

        self.cache_edge_index_sets = [None] * edge_set_num
        self.cache_embed_index = None

        self.dp = nn.Dropout(0.2)

        self.init_params()

    def init_params(self):
        """Re-initialise the node embedding table (Kaiming-uniform, a=sqrt(5))."""
        nn.init.kaiming_uniform_(self.embedding.weight, a=math.sqrt(5))

    def _learned_graph_edges(self, batch_num, node_num, device):
        """Build the batched top-k cosine-similarity graph from the embeddings.

        Side effect: stores the per-node top-k neighbour indices in
        ``self.learned_graph``.

        Returns:
            ``(all_embeddings, batch_gated_edge_index)``. The first is the
            embedding rows repeated once per batch item. The second is the
            learned edges tiled across the batch.
        """
        node_ids = torch.arange(node_num, device=device)
        all_embeddings = self.embedding(node_ids)

        # Graph structure is chosen from a detached copy: no gradient flows
        # through the neighbour selection.
        weights = all_embeddings.detach().clone().view(node_num, -1)
        norms = weights.norm(dim=-1)
        cos_ji_mat = torch.matmul(weights, weights.T) / torch.matmul(norms.view(-1, 1), norms.view(1, -1))

        # torch.topk raises when k exceeds the row length, so cap k.
        topk_num = min(self.topk, node_num)
        topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]
        self.learned_graph = topk_indices_ji

        # Edge j -> i for each of node i's top-k neighbours j.
        gated_i = node_ids.unsqueeze(1).repeat(1, topk_num).flatten().unsqueeze(0)
        gated_j = topk_indices_ji.flatten().unsqueeze(0)
        gated_edge_index = torch.cat((gated_j, gated_i), dim=0)

        batch_gated_edge_index = get_batch_edge_index(gated_edge_index, batch_num, node_num).to(device)
        return all_embeddings.repeat(batch_num, 1), batch_gated_edge_index

    def _refresh_batch_edge_cache(self, i, edge_index, batch_num, node_num, device):
        """Rebuild the batched copy of static edge set ``i`` if its size changed.

        NOTE(review): forward() never consumes this cache (the GNN layers run
        on the learned graph); it is maintained only to preserve the
        ``cache_edge_index_sets`` state.
        """
        cached = self.cache_edge_index_sets[i]
        if cached is None or cached.shape[1] != edge_index.shape[1] * batch_num:
            self.cache_edge_index_sets[i] = get_batch_edge_index(edge_index, batch_num, node_num).to(device)

    def forward(self, data, org_edge_index):
        """Predict one value per node.

        Args:
            data: tensor unpacked as ``(batch_num, node_num, features)``.
            org_edge_index: unused; kept for caller compatibility.

        Returns:
            Tensor reshaped to ``(-1, node_num)``.
        """
        x = data.clone().detach()
        edge_index_sets = self.edge_index_sets
        device = data.device

        batch_num, node_num, all_feature = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(-1, all_feature).contiguous()

        # The learned graph does not depend on which edge set is processed,
        # so it is built once per forward pass rather than once per set.
        all_embeddings, batch_gated_edge_index = self._learned_graph_edges(batch_num, node_num, device)

        gcn_outs = []
        for i, edge_index in enumerate(edge_index_sets):
            self._refresh_batch_edge_cache(i, edge_index, batch_num, node_num, device)
            gcn_out = self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num * batch_num, embedding=all_embeddings)
            gcn_outs.append(gcn_out)

        x = torch.cat(gcn_outs, dim=1)
        x = x.view(batch_num, node_num, -1)

        # Scale each node's features element-wise by its own embedding.
        # NOTE(review): assuming each GNN layer emits `dim` features, this
        # (and bn_outlayer_in) only broadcasts when there is one edge set.
        indexes = torch.arange(node_num, device=device)
        out = torch.mul(x, self.embedding(indexes))

        # BatchNorm1d normalises dim 1, so move features there and back.
        out = out.permute(0, 2, 1)
        out = F.relu(self.bn_outlayer_in(out))
        out = out.permute(0, 2, 1)

        out = self.dp(out)
        out = self.out_layer(out)
        out = out.view(-1, node_num)

        return out
|
|
|
|