import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
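# Vocabulary sizes for the categorical atom/bond features consumed by the embedding
# layers below (the extra slots cover mask tokens and self-loop edges).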
num_atom_type = 120        # including the extra mask token
num_chirality_tag = 3
num_bond_type = 6          # including the aromatic and self-loop edge types, plus the extra mask token
num_bond_direction = 3
class GINConv(MessagePassing):
"""
Extension of GIN aggregation to incorporate edge information by concatenation.
Args:
emb_dim (int): dimensionality of embeddings for nodes and edges.
embed_input (bool): whether to embed input or not.
See https://arxiv.org/abs/1810.00826
"""
def __init__(self, emb_dim, aggr = "add"):
super(GINConv, self).__init__()
#multi-layer perceptron
self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.aggr = aggr
    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space (add_self_loops returns (edge_index, edge_attr))
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # add features corresponding to self-loop edges (bond type 4 = self-loop)
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
def message(self, x_j, edge_attr):
return x_j + edge_attr
def update(self, aggr_out):
return self.mlp(aggr_out)
class GCNConv(MessagePassing):
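    """
    GCN-style convolution with symmetric degree normalization, extended with
    additive bond-type and bond-direction edge embeddings.
    """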
def __init__(self, emb_dim, aggr = "add"):
super(GCNConv, self).__init__()
self.emb_dim = emb_dim
self.linear = torch.nn.Linear(emb_dim, emb_dim)
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.aggr = aggr
    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have already been added to edge_index
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space (add_self_loops returns (edge_index, edge_attr))
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # add features corresponding to self-loop edges (bond type 4 = self-loop)
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        norm = self.norm(edge_index, x.size(0), x.dtype)
        x = self.linear(x)
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)
def message(self, x_j, edge_attr, norm):
return norm.view(-1, 1) * (x_j + edge_attr)
class GATConv(MessagePassing):
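    """
    GAT-style convolution with multi-head additive attention; bond-type and
    bond-direction edge embeddings are added to the neighbor messages, and the
    head outputs are averaged in update().
    """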
def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr = "add"):
super(GATConv, self).__init__()
self.aggr = aggr
self.emb_dim = emb_dim
self.heads = heads
self.negative_slope = negative_slope
self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))
self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.reset_parameters()
    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have already been added to edge_index
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def reset_parameters(self):
glorot(self.att)
zeros(self.bias)
    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space (add_self_loops returns (edge_index, edge_attr))
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        norm = self.norm(edge_index, x.size(0), x.dtype)
        # add features corresponding to self-loop edges (bond type 4 = self-loop)
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)
    def message(self, edge_index, x_i, x_j, edge_attr):
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        x_j = x_j + edge_attr  # avoid modifying the gathered messages in place
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # normalize attention over the incoming edges of each target node
        alpha = softmax(alpha, edge_index[1])
        return x_j * alpha.view(-1, self.heads, 1)
def update(self, aggr_out):
aggr_out = aggr_out.mean(dim=1)
aggr_out = aggr_out + self.bias
return aggr_out
class GraphSAGEConv(MessagePassing):
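    """
    GraphSAGE-style convolution (mean aggregation by default) with additive edge
    embeddings; the aggregated output is L2-normalized in update().
    """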
def __init__(self, emb_dim, aggr = "mean"):
super(GraphSAGEConv, self).__init__()
self.emb_dim = emb_dim
self.linear = torch.nn.Linear(emb_dim, emb_dim)
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.aggr = aggr
    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have already been added to edge_index
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space (add_self_loops returns (edge_index, edge_attr))
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # add features corresponding to self-loop edges (bond type 4 = self-loop)
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        norm = self.norm(edge_index, x.size(0), x.dtype)
        x = self.linear(x)
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)
def message(self, x_j, edge_attr):
return x_j + edge_attr
def update(self, aggr_out):
return F.normalize(aggr_out, p = 2, dim = -1)
class GNN(torch.nn.Module):
"""
Args:
num_layer (int): the number of GNN layers
emb_dim (int): dimensionality of embeddings
JK (str): last, concat, max or sum.
max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
drop_ratio (float): dropout rate
gnn_type: gin, gcn, graphsage, gat
Output:
node representations
"""
def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
super(GNN, self).__init__()
self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)
torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)
###List of MLPs
self.gnns = torch.nn.ModuleList()
for layer in range(num_layer):
if gnn_type == "gin":
self.gnns.append(GINConv(emb_dim, aggr = "add"))
elif gnn_type == "gcn":
self.gnns.append(GCNConv(emb_dim))
elif gnn_type == "gat":
self.gnns.append(GATConv(emb_dim))
elif gnn_type == "graphsage":
self.gnns.append(GraphSAGEConv(emb_dim))
###List of batchnorms
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layer):
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
#def forward(self, x, edge_index, edge_attr):
def forward(self, *argv):
if len(argv) == 3:
x, edge_index, edge_attr = argv[0], argv[1], argv[2]
elif len(argv) == 1:
data = argv[0]
x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
else:
raise ValueError("unmatched number of arguments.")
x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])
h_list = [x]
for layer in range(self.num_layer):
h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layer - 1:
#remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
h_list.append(h)
### Different implementations of Jk-concat
if self.JK == "concat":
node_representation = torch.cat(h_list, dim = 1)
elif self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "max":
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
elif self.JK == "sum":
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]
return node_representation
class GNN_graphpred(torch.nn.Module):
"""
Extension of GIN to incorporate edge information by concatenation.
Args:
num_layer (int): the number of GNN layers
emb_dim (int): dimensionality of embeddings
num_tasks (int): number of tasks in multi-task learning scenario
drop_ratio (float): dropout rate
JK (str): last, concat, max or sum.
graph_pooling (str): sum, mean, max, attention, set2set
gnn_type: gin, gcn, graphsage, gat
See https://arxiv.org/abs/1810.00826
JK-net: https://arxiv.org/abs/1806.03536
"""
def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"):
super(GNN_graphpred, self).__init__()
self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)
#Different kind of graph pooling
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
if self.JK == "concat":
self.pool = GlobalAttention(gate_nn = torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
else:
self.pool = GlobalAttention(gate_nn = torch.nn.Linear(emb_dim, 1))
elif graph_pooling[:-1] == "set2set":
set2set_iter = int(graph_pooling[-1])
if self.JK == "concat":
self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter)
else:
self.pool = Set2Set(emb_dim, set2set_iter)
else:
raise ValueError("Invalid graph pooling type.")
        # Set2Set pooling doubles the representation dimension; keep the multiplier
        # for any prediction head built on top of the pooled graph representation.
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1
    def from_pretrained(self, model_file):
        # load pretrained weights for the node encoder (map to CPU if no GPU is available)
        self.gnn.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
def forward(self, *argv):
if len(argv) == 4:
x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
elif len(argv) == 1:
data = argv[0]
x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
else:
raise ValueError("unmatched number of arguments.")
node_representation = self.gnn(x, edge_index, edge_attr)
return self.pool(node_representation, batch)
class MLP(nn.Module):
"""
Creates a NN using nn.ModuleList to automatically adjust the number of layers.
For each hidden layer, the number of inputs and outputs is constant.
Inputs:
in_dim (int): number of features contained in the input layer.
out_dim (int): number of features input and output from each hidden layer,
including the output layer.
num_layers (int): number of layers in the network
activation (torch function): activation function to be used during the hidden layers
"""
def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), layer_norm=False, batch_norm=False):
super(MLP, self).__init__()
self.layers = nn.ModuleList()
        # hidden width: reuse in_dim when out_dim is small (e.g. a scalar head),
        # otherwise match the output dimension
        h_dim = in_dim if out_dim < 10 else out_dim
        # hidden layers: Linear (+ optional norms) followed by the activation
for layer in range(num_layers):
if layer == 0:
self.layers.append(nn.Linear(in_dim, h_dim))
else:
self.layers.append(nn.Linear(h_dim, h_dim))
if layer_norm: self.layers.append(nn.LayerNorm(h_dim))
if batch_norm: self.layers.append(nn.BatchNorm1d(h_dim))
self.layers.append(activation)
self.layers.append(nn.Linear(h_dim, out_dim))
def forward(self, x):
for i in range(len(self.layers)):
x = self.layers[i](x)
return x
class WeightConv(MessagePassing):
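    """
    Edge-weighted propagation layer used by WeightGNN: neighbor features are scaled
    by the edge weights produced by norm() and passed through a linear transform in
    update().
    """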
def __init__(self, emb_dim, aggr = "add"):
super(WeightConv, self).__init__()
self.emb_dim = emb_dim
self.linear = torch.nn.Linear(emb_dim, emb_dim)
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.aggr = aggr
    def norm(self, edge_index, num_nodes, edge_weight, dtype):
        ### assuming that self-loops have already been added to edge_index
        row, col = edge_index
        deg = scatter_add(edge_weight.view(-1), row, dim=0, dim_size=num_nodes)
        deg_inv = deg.pow(-1.0)
        deg_inv[deg_inv == float('inf')] = 0
        norm = deg_inv[row] * edge_weight
        # keep the original edges at weight 1; only the appended self-loop edges
        # carry the degree-normalized weight
        norm[:-num_nodes] = 1
        return norm
    def forward(self, x, edge_index, edge_attr):
        # add self-loops in the edge space (add_self_loops returns (edge_index, edge_attr))
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # self-loop edges get weight 1 before normalization
        self_loop_attr = torch.ones(x.size(0))
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr))
        norm = self.norm(edge_index, x.size(0), edge_attr, x.dtype)
        # the linear transform is applied in update() after aggregation
        return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return self.linear(aggr_out)
class WeightGNN(torch.nn.Module):
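    """
    GNN over pocket prototypes and motifs: both inputs are first projected into a
    shared hidden space (transform_p / transform_m), then propagated through stacked
    WeightConv layers with jumping-knowledge (JK) aggregation over layer outputs.
    """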
def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0):
super(WeightGNN, self).__init__()
self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
###List of MLPs
self.gnns = torch.nn.ModuleList()
for layer in range(num_layer):
self.gnns.append(WeightConv(emb_dim))
###List of batchnorms
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layer):
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
self.transform_p = nn.Linear(emb_dim, emb_dim)
self.transform_m = nn.Linear(emb_dim, emb_dim)
# def forward(self, x, edge_index, edge_attr):
def forward(self, *argv):
if len(argv) == 4:
x_p, x_m, edge_index, edge_attr = argv[0], argv[1], argv[2], argv[3]
else:
raise ValueError("unmatched number of arguments.")
# convert pocket prototypes and motifs to the same hidden space
x = torch.cat([self.transform_p(x_p), self.transform_m(x_m)], dim=0)
h_list = [x]
for layer in range(self.num_layer):
h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layer - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
h_list.append(h)
### Different implementations of Jk-concat
if self.JK == "concat":
node_representation = torch.cat(h_list, dim=1)
elif self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "max":
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
elif self.JK == "sum":
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]
return node_representation
if __name__ == "__main__":
pass
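    # Minimal smoke test (a sketch, not part of the original training pipeline):
    # it builds a small random molecular graph with the categorical node/edge
    # features expected by the embedding tables above and runs it through
    # GNN_graphpred. Graph size, feature ranges, and hyperparameters here are
    # illustrative assumptions, not values from the original code.
    from torch_geometric.data import Data

    num_nodes, num_edges = 6, 10
    x = torch.stack([
        torch.randint(0, num_atom_type, (num_nodes,)),
        torch.randint(0, num_chirality_tag, (num_nodes,)),
    ], dim=1)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.stack([
        torch.randint(0, 4, (num_edges,)),                    # bond type (4 is reserved for self-loops)
        torch.randint(0, num_bond_direction, (num_edges,)),   # bond direction
    ], dim=1)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                batch=torch.zeros(num_nodes, dtype=torch.long))

    model = GNN_graphpred(num_layer=2, emb_dim=32, graph_pooling="mean", gnn_type="gin")
    model.eval()
    with torch.no_grad():
        graph_repr = model(data)
    print(graph_repr.shape)  # expected: torch.Size([1, 32])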