Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.utils import add_self_loops, degree, softmax | |
| from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set | |
| import torch.nn.functional as F | |
| from torch_scatter import scatter_add | |
| from torch_geometric.nn.inits import glorot, zeros | |
# Vocabulary sizes of the categorical atom (node) and bond (edge) features; they
# size the torch.nn.Embedding tables created by the layers and models below.
num_atom_type, num_chirality_tag = 120, 3  # atom types include the extra mask tokens
num_bond_type, num_bond_direction = 6, 3  # bond types include aromatic, self-loop and extra masked tokens
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information.

    Each message is the neighbour's feature vector plus the embedding of the
    connecting bond (bond-type embedding + bond-direction embedding). Messages
    are aggregated with ``aggr`` and passed through a two-layer MLP.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme forwarded to
            ``MessagePassing`` (default: ``"add"``).

    See https://arxiv.org/abs/1810.00826
    """

    def __init__(self, emb_dim, aggr="add"):
        # Hand ``aggr`` to the base class: newer torch_geometric releases build
        # their aggregation operator inside ``MessagePassing.__init__``, so only
        # assigning ``self.aggr`` afterwards would leave the default in effect.
        super(GINConv, self).__init__(aggr=aggr)
        # multi-layer perceptron applied to the aggregated messages
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        # kept as before: older torch_geometric releases read this attribute directly
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        """Run one GIN layer.

        Args:
            x: node features; ``x.size(0)`` is taken as the number of nodes.
            edge_index: graph connectivity, passed to ``add_self_loops``.
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.

        Returns:
            ``self.mlp`` applied to the aggregated messages.
        """
        num_nodes = x.size(0)
        # add self loops in the edge space; the result is indexed with [0]
        # below, i.e. the (edge_index, edge_weight) tuple form of add_self_loops
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        # features for the appended self-loop edges: bond type 4 (self-loop
        # edge), bond direction 0
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = 4
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        # neighbour feature plus the embedding of the connecting bond
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)
class GCNConv(MessagePassing):
    """GCN-style convolution with bond-type / bond-direction edge embeddings.

    Node features are linearly transformed, then each message is
    ``norm * (x_j + edge_embedding)``. Here
    ``norm = deg^-1/2[row] * deg^-1/2[col]`` is computed on the
    self-loop-augmented graph.

    Args:
        emb_dim (int): dimensionality of node and edge embeddings.
        aggr (str): neighbourhood aggregation scheme forwarded to
            ``MessagePassing`` (default: ``"add"``).
    """

    def __init__(self, emb_dim, aggr="add"):
        # pass aggr to the base class so newer torch_geometric releases honour it
        super(GCNConv, self).__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        # kept as before: older torch_geometric releases read this attribute directly
        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        """Symmetric degree normalisation for every edge.

        Args:
            edge_index: the value returned by ``add_self_loops``; element [0]
                holds the (already self-loop-augmented) connectivity.
            num_nodes (int): number of nodes.
            dtype: dtype of the returned coefficients.

        Returns:
            ``deg^-1/2[row] * deg^-1/2[col]`` per edge, where ``deg`` counts
            edges by ``row``.
        """
        edge_index = edge_index[0]
        row, col = edge_index
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype, device=edge_index.device)
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # zero-degree nodes give inf; zero them instead
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one GCN layer.

        Args:
            x: node features; ``x.size(0)`` is taken as the number of nodes.
            edge_index: graph connectivity, passed to ``add_self_loops``.
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.

        Returns:
            The aggregated messages (this class defines no ``update`` override).
        """
        num_nodes = x.size(0)
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        # features for the appended self-loop edges: bond type 4 (self-loop edge)
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = 4
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        norm = self.norm(edge_index, num_nodes, x.dtype)
        x = self.linear(x)
        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)
class GATConv(MessagePassing):
    """Multi-head graph attention layer with bond-type / bond-direction edge embeddings.

    Node features are projected into ``heads`` separate ``emb_dim``-wide
    spaces. Bond embeddings are added to the neighbour features, and each
    message is weighted by a per-head attention score that is softmax-normalised
    over the edges aggregated into the same node. The heads are averaged and a
    bias is added.

    Args:
        emb_dim (int): dimensionality of node and edge embeddings.
        heads (int): number of attention heads (default: 2).
        negative_slope (float): LeakyReLU slope for the attention scores
            (default: 0.2).
        aggr (str): neighbourhood aggregation scheme forwarded to
            ``MessagePassing`` (default: ``"add"``).

    See https://arxiv.org/abs/1710.10903
    """

    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        # node_dim=0: node features are laid out as (num_nodes, heads, emb_dim)
        # (see forward), so nodes live on axis 0, not torch_geometric's default
        # of -2 (which would index the heads axis). aggr is passed so newer
        # torch_geometric releases honour it.
        super(GATConv, self).__init__(aggr=aggr, node_dim=0)
        self.aggr = aggr
        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope
        self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
        # initialised in reset_parameters()
        self.att = torch.nn.Parameter(torch.empty(1, heads, 2 * emb_dim))
        self.bias = torch.nn.Parameter(torch.empty(emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.reset_parameters()

    def norm(self, edge_index, num_nodes, dtype):
        """Symmetric degree normalisation ``deg^-1/2[row] * deg^-1/2[col]``.

        ``edge_index`` is the value returned by ``add_self_loops``; element [0]
        holds the self-loop-augmented connectivity. ``forward`` does not use
        this coefficient; the method is kept for API compatibility.
        """
        edge_index = edge_index[0]
        row, col = edge_index
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype, device=edge_index.device)
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def reset_parameters(self):
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one attention layer.

        Args:
            x: node features of width ``emb_dim``; ``x.size(0)`` is taken as the
                number of nodes.
            edge_index: graph connectivity, passed to ``add_self_loops``.
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.

        Returns:
            Node features of width ``emb_dim`` (heads averaged, bias added).
        """
        num_nodes = x.size(0)
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        # features for the appended self-loop edges: bond type 4 (self-loop edge)
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = 4
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, edge_index_i, x_i, x_j, edge_attr, size_i):
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        # out-of-place so the lifted neighbour tensor is not mutated
        x_j = x_j + edge_attr
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # normalise over the edges that are aggregated into the same node
        # (edge_index_i is the index torch_geometric aggregates over)
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)
        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        # average the heads, then add the shared bias
        aggr_out = aggr_out.mean(dim=1)
        return aggr_out + self.bias
class GraphSAGEConv(MessagePassing):
    """GraphSAGE-style convolution with bond-type / bond-direction edge embeddings.

    Node features are linearly transformed. Messages ``x_j + edge_embedding``
    are aggregated with ``aggr`` (default: mean), and the result is
    L2-normalised along the feature axis.

    Args:
        emb_dim (int): dimensionality of node and edge embeddings.
        aggr (str): neighbourhood aggregation scheme forwarded to
            ``MessagePassing`` (default: ``"mean"``).
    """

    def __init__(self, emb_dim, aggr="mean"):
        # Pass aggr to the base class: newer torch_geometric releases build the
        # aggregation operator in MessagePassing.__init__. Assigning self.aggr
        # afterwards would leave the base default ("add") in effect instead of "mean".
        super(GraphSAGEConv, self).__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        # kept as before: older torch_geometric releases read this attribute directly
        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        """Symmetric degree normalisation ``deg^-1/2[row] * deg^-1/2[col]``.

        ``edge_index`` is the value returned by ``add_self_loops``; element [0]
        holds the self-loop-augmented connectivity. ``forward`` does not use
        this coefficient; the method is kept for API compatibility.
        """
        edge_index = edge_index[0]
        row, col = edge_index
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype, device=edge_index.device)
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one GraphSAGE layer.

        Args:
            x: node features; ``x.size(0)`` is taken as the number of nodes.
            edge_index: graph connectivity, passed to ``add_self_loops``.
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.
        """
        num_nodes = x.size(0)
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        # features for the appended self-loop edges: bond type 4 (self-loop edge)
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = 4
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        x = self.linear(x)
        # message() takes no norm argument, so no normalisation coefficient is computed
        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return F.normalize(aggr_out, p=2, dim=-1)
class GNN(torch.nn.Module):
    """
    Stack of message-passing layers that produces node representations.

    Args:
        num_layer (int): the number of GNN layers (must be at least 2)
        emb_dim (int): dimensionality of embeddings
        JK (str): how per-layer outputs are combined: last, concat, max or sum.
        drop_ratio (float): dropout rate
        gnn_type (str): gin, gcn, graphsage, gat

    Raises:
        ValueError: if ``num_layer < 2`` or ``gnn_type`` is not recognised.

    Output:
        node representations
    """

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # atom features: column 0 is the atom type, column 1 the chirality tag
        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)
        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)
        ### List of message-passing layers
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr="add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))
            else:
                # previously an unknown type built no layers and failed later in forward
                raise ValueError("Invalid GNN type.")
        ### List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, *argv):
        """Compute node representations.

        Accepts either ``(x, edge_index, edge_attr)`` or a single data object
        exposing ``.x``, ``.edge_index`` and ``.edge_attr``.

        Raises:
            ValueError: on any other number of arguments, or an unknown ``JK``.
        """
        if len(argv) == 3:
            x, edge_index, edge_attr = argv
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])
        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of JK; torch.stack avoids mutating h_list in place
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0).values
        elif self.JK == "sum":
            # element-wise sum over layers (previously only node 0's row was returned)
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        else:
            raise ValueError("Invalid JK type.")
        return node_representation
class GNN_graphpred(torch.nn.Module):
    """
    GNN encoder followed by a graph-level pooling readout.

    ``forward`` returns one pooled embedding per graph; this class defines no
    prediction head. ``self.mult`` records how many times wider than the node
    representation the pooled embedding is (2 for Set2Set, else 1).

    Args:
        num_layer (int): the number of GNN layers (must be at least 2)
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        drop_ratio (float): dropout rate
        graph_pooling (str): sum, mean, max, attention, or ``set2set<k>`` where
            ``<k>`` is the number of Set2Set processing steps (e.g. "set2set2")
        gnn_type: gin, gcn, graphsage, gat

    Raises:
        ValueError: if ``num_layer < 2`` or ``graph_pooling`` is not recognised.

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, graph_pooling="mean", gnn_type="gin"):
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type)

        # width of the node representations handed to the pooling layer
        pool_in_dim = (self.num_layer + 1) * emb_dim if self.JK == "concat" else emb_dim
        set2set_steps = graph_pooling[len("set2set"):]
        # Set2Set outputs twice its input width; every other readout keeps it
        self.mult = 1
        # Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Linear(pool_in_dim, 1))
        elif graph_pooling.startswith("set2set") and set2set_steps.isdecimal():
            # any step count is accepted, not just a single trailing digit
            self.pool = Set2Set(pool_in_dim, int(set2set_steps))
            self.mult = 2
        else:
            raise ValueError("Invalid graph pooling type.")

    def from_pretrained(self, model_file):
        """Load weights for the GNN encoder (``self.gnn``) from ``model_file``.

        The checkpoint is deserialised onto the CPU first, so files saved from a
        GPU also load on CPU-only hosts. ``load_state_dict`` then copies the
        values into the existing parameters on whatever device they live on.
        """
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location="cpu")
        self.gnn.load_state_dict(state_dict)

    def forward(self, *argv):
        """Return pooled graph embeddings.

        Accepts either ``(x, edge_index, edge_attr, batch)`` or a single data
        object exposing ``.x``, ``.edge_index``, ``.edge_attr`` and ``.batch``.

        Raises:
            ValueError: on any other number of arguments.
        """
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")
        node_representation = self.gnn(x, edge_index, edge_attr)
        return self.pool(node_representation, batch)
class MLP(nn.Module):
    """
    Feed-forward network whose depth is set by ``num_layers``.

    The network is ``num_layers`` hidden blocks followed by an output layer.
    Each hidden block is ``Linear`` -> optional ``LayerNorm`` -> optional
    ``BatchNorm1d`` -> activation, and the output layer is a plain ``Linear``
    projecting to ``out_dim``. All hidden blocks share one width:
    ``in_dim`` when ``out_dim < 10``, otherwise ``out_dim``.

    Inputs:
        in_dim (int): number of features in the input.
        out_dim (int): number of output features (also the hidden width when
            ``out_dim >= 10``).
        num_layers (int): number of hidden blocks; the network holds
            ``num_layers + 1`` linear layers in total.
        activation (nn.Module, optional): activation placed after every hidden
            block. The same instance is reused for all blocks of this MLP.
            Defaults to a new ``torch.nn.ReLU()`` created per MLP.
        layer_norm (bool): insert ``nn.LayerNorm`` after each hidden ``Linear``.
        batch_norm (bool): insert ``nn.BatchNorm1d`` after each hidden ``Linear``
            (after the ``LayerNorm`` when both are enabled).
    """

    def __init__(self, in_dim, out_dim, num_layers, activation=None, layer_norm=False, batch_norm=False):
        super(MLP, self).__init__()
        if activation is None:
            # a fresh module per MLP rather than one default instance shared by
            # every MLP ever constructed
            activation = torch.nn.ReLU()
        self.layers = nn.ModuleList()
        h_dim = in_dim if out_dim < 10 else out_dim
        for layer in range(num_layers):
            # the first hidden block consumes the input; the rest are h_dim -> h_dim
            self.layers.append(nn.Linear(in_dim if layer == 0 else h_dim, h_dim))
            if layer_norm:
                self.layers.append(nn.LayerNorm(h_dim))
            if batch_norm:
                self.layers.append(nn.BatchNorm1d(h_dim))
            self.layers.append(activation)
        # output layer; with no hidden blocks it reads the raw input directly
        last_dim = h_dim if num_layers > 0 else in_dim
        self.layers.append(nn.Linear(last_dim, out_dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class WeightConv(MessagePassing):
    """Message passing over a graph whose edges carry scalar weights.

    Each message is ``norm * x_j`` (see ``norm`` for the coefficients). The
    aggregated result is passed through ``self.linear``.

    Args:
        emb_dim (int): dimensionality of the node features.
        aggr (str): neighbourhood aggregation scheme forwarded to
            ``MessagePassing`` (default: ``"add"``).
    """

    def __init__(self, emb_dim, aggr="add"):
        # pass aggr to the base class so newer torch_geometric releases honour it
        super(WeightConv, self).__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        # NOTE(review): forward() never uses these bond embeddings; they are kept
        # so that existing state_dicts still load with strict=True.
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        # kept as before: older torch_geometric releases read this attribute directly
        self.aggr = aggr

    def norm(self, edge_index, num_nodes, edge_weight, dtype):
        """Per-edge message coefficients for the self-loop-augmented graph.

        Args:
            edge_index: the value returned by ``add_self_loops``; element [0]
                holds the connectivity, with the ``num_nodes`` self loops last.
            num_nodes (int): number of nodes.
            edge_weight: 1-D weight per edge (self loops included).
            dtype: unused.

        Returns:
            For the trailing ``num_nodes`` (self-loop) edges, ``edge_weight`` divided
            by the weighted degree of the edge's ``row`` node. Every other edge gets 1.
        """
        row = edge_index[0][0]
        deg = scatter_add(edge_weight.view(-1), row, dim=0, dim_size=num_nodes)
        deg_inv = deg.pow(-1.0)
        # zero-degree nodes give inf; zero them instead
        deg_inv[deg_inv == float('inf')] = 0
        norm = deg_inv[row] * edge_weight
        # NOTE(review): this overwrites every non-self-loop coefficient with 1, so
        # the supplied edge weights only affect the self-loop term (through deg).
        # Confirm this is intended rather than e.g. norm[-num_nodes:] = 1.
        norm[:-num_nodes] = 1
        return norm

    def forward(self, x, edge_index, edge_attr):
        """Run one weighted message-passing layer.

        Args:
            x: node features; ``x.size(0)`` is taken as the number of nodes.
            edge_index: graph connectivity, passed to ``add_self_loops``.
            edge_attr: 1-D scalar weight per edge. It is concatenated with a
                1-D tensor, so it must itself be 1-D.
        """
        num_nodes = x.size(0)
        # add self loops in the edge space (appended after the existing edges)
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
        # every added self-loop edge gets weight 1
        self_loop_attr = torch.ones(num_nodes, dtype=edge_attr.dtype, device=edge_attr.device)
        edge_attr = torch.cat((edge_attr, self_loop_attr))
        norm = self.norm(edge_index, num_nodes, edge_attr, x.dtype)
        return self.propagate(edge_index[0], x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # the linear map is applied after aggregation, not before propagation
        return self.linear(aggr_out)
class WeightGNN(torch.nn.Module):
    """
    Stack of ``WeightConv`` layers over a joint graph of two node sets.

    ``x_p`` (pocket prototypes) and ``x_m`` (motifs) are mapped by separate
    linear layers into a shared space. They are concatenated along the node
    axis with the ``x_p`` rows first, so ``edge_index`` indexes into that
    concatenated node list.

    Args:
        num_layer (int): number of message-passing layers (must be at least 2).
        emb_dim (int): width of the inputs, hidden layers and outputs.
        JK (str): how per-layer outputs are combined: last, concat, max or sum.
        drop_ratio (float): dropout rate.

    Raises:
        ValueError: if ``num_layer < 2``.
    """

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0):
        super(WeightGNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        ### List of message-passing layers
        self.gnns = torch.nn.ModuleList([WeightConv(emb_dim) for _ in range(num_layer)])
        ### List of batchnorms
        self.batch_norms = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(emb_dim) for _ in range(num_layer)]
        )
        self.transform_p = nn.Linear(emb_dim, emb_dim)
        self.transform_m = nn.Linear(emb_dim, emb_dim)

    def forward(self, *argv):
        """Compute representations for the concatenated ``[x_p; x_m]`` nodes.

        Expects exactly ``(x_p, x_m, edge_index, edge_attr)``, where
        ``edge_attr`` holds one scalar weight per edge.

        Raises:
            ValueError: on any other number of arguments, or an unknown ``JK``.
        """
        if len(argv) != 4:
            raise ValueError("unmatched number of arguments.")
        x_p, x_m, edge_index, edge_attr = argv

        # convert pocket prototypes and motifs to the same hidden space
        x = torch.cat([self.transform_p(x_p), self.transform_m(x_m)], dim=0)
        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of JK; torch.stack avoids mutating h_list in place
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0).values
        elif self.JK == "sum":
            # element-wise sum over layers (previously only node 0's row was returned)
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        else:
            raise ValueError("Invalid JK type.")
        return node_representation
if __name__ == "__main__":
    # This module only defines models and layers; running it directly does nothing.
    ...