#!/usr/bin/python
# -*- coding:utf-8 -*-
import math
import pickle  # only used by the optional attention-visualization dump (commented out below)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_softmax, scatter_mean, scatter_sum, scatter_std

from tools import _unit_edges_from_block_edges
from radial_basis import RadialBasis


def stable_norm(input, *args, **kwargs):
    '''Thin wrapper around ``torch.norm``; all arguments are forwarded unchanged.

    An earlier version clamped the magnitude of every entry to at least 1e-10 (without gradient)
    before taking the norm, presumably to keep the gradient well defined at the zero vector.
    That code sat after an unconditional ``return`` and was unreachable, so it has been removed.
    '''
    return torch.norm(input, *args, **kwargs)


class GET(nn.Module):
    '''Equivariant Adaptive Block Transformer.

    Stacks ``n_layers`` repetitions of
    ``GETLayer -> EquivariantLayerNorm -> EquivariantFFN -> EquivariantLayerNorm``
    over a graph whose nodes are units grouped into blocks, and blocks grouped into graphs.
    '''

    def __init__(self, d_hidden, d_radial, n_channel, n_rbf, cutoff=7.0, d_edge=0, n_layers=4, n_head=4,
                 act_fn=nn.SiLU(), residual=True, dropout=0.1, z_requires_grad=True, pre_norm=False, sparse_k=3):
        '''
        :param d_hidden: Number of hidden features
        :param d_radial: Number of features for calculating geometric relations
        :param n_channel: Number of channels of coordinates of each unit
        :param n_rbf: Dimension of RBF feature, 1 for not using rbf
        :param cutoff: cutoff for RBF (used by every sub-module that builds an RBF)
        :param d_edge: Number of features for the edge features (0 for no edge features)
        :param n_layers: Number of layers
        :param n_head: Number of attention heads
        :param act_fn: Non-linearity
        :param residual: Use residual connections, we recommend not changing this one
        :param dropout: probability of dropout (used inside the FFNs)
        :param z_requires_grad: if False, freeze the coordinate branch of the last FFN and the
            coordinate scale parameter of the last layer norm
        :param pre_norm: apply an EquivariantLayerNorm before the first layer
        :param sparse_k: ``k`` passed to ``_unit_edges_from_block_edges`` when unit edges are built here
        '''
        super().__init__()
        self.n_layers = n_layers
        self.pre_norm = pre_norm
        self.sparse_k = sparse_k

        if self.pre_norm:
            self.pre_layernorm = EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn)

        for i in range(n_layers):
            self.add_module(f'layer_{i}', GETLayer(
                d_hidden, d_radial, n_channel, n_rbf, cutoff, d_edge, n_head, act_fn, residual))
            self.add_module(f'layernorm0_{i}', EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn))
            self.add_module(f'ffn_{i}', EquivariantFFN(
                d_hidden, 4 * d_hidden, d_hidden, n_channel, n_rbf, act_fn, residual, dropout,
                z_requires_grad=z_requires_grad if i == n_layers - 1 else True,
                cutoff=cutoff,
            ))
            self.add_module(f'layernorm1_{i}', EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn))

        # Guarded so that n_layers == 0 does not look up a non-existent 'layernorm1_-1'.
        if not z_requires_grad and n_layers > 0:
            self._modules[f'layernorm1_{n_layers - 1}'].sigma.requires_grad = False

    def recover_scale(self, Z, block_id, batch_id, record_scale):
        '''Undo the accumulated per-graph coordinate rescaling applied by the layer norms.

        :param Z: [N, n_channel, 3] unit coordinates
        :param block_id: [N] block index of each unit
        :param batch_id: [Nb] graph index of each block
        :param record_scale: [bs, n_channel, 1] product of every ``rescale`` factor applied so far
        :return: Z with each graph's centered coordinates divided by its accumulated scale
        '''
        # NOTE(review): this runs under no_grad, so the returned coordinates carry no gradient.
        # Confirm that callers never need gradients through the predicted Z.
        with torch.no_grad():
            unit_batch_id = batch_id[block_id]
            Z_c = scatter_mean(Z, unit_batch_id, dim=0)  # [bs, n_channel, 3]
            Z_c = Z_c[unit_batch_id]  # [N, n_channel, 3]
            Z_centered = Z - Z_c
            Z = Z_c + Z_centered / record_scale[unit_batch_id]
        return Z

    def forward(self, H, Z, block_id, batch_id, edges, edge_attr=None, cached_unit_edge_info=None):
        '''
        :param H: [N, d_hidden] unit features
        :param Z: [N, n_channel, 3] unit coordinates
        :param block_id: [N] block index of each unit
        :param batch_id: [Nb] graph index of each block
        :param edges: [2, E] block-level edges (row, col)
        :param edge_attr: [E, d_edge] block-level edge features, or None when d_edge == 0
        :param cached_unit_edge_info: precomputed output of ``_unit_edges_from_block_edges``; built here if None
        :return: (H, Z), with Z mapped back to the input coordinate scale
        '''
        if cached_unit_edge_info is None:
            with torch.no_grad():
                # [Eu], Eu = \sum_{i, j \in E} n_i * n_j
                cached_unit_edge_info = _unit_edges_from_block_edges(block_id, edges.T, Z, k=self.sparse_k)

        batch_size, n_channel = int(batch_id.max()) + 1, Z.shape[1]
        # Product of all coordinate rescaling factors. It is only consumed by recover_scale, which
        # runs without gradient, so the factors are detached rather than tracked by autograd.
        record_scale = torch.ones((batch_size, n_channel, 1), dtype=torch.float, device=Z.device)

        if self.pre_norm:
            H, Z, rescale = self.pre_layernorm(H, Z, block_id, batch_id)
            record_scale.mul_(rescale.detach())

        for i in range(self.n_layers):
            # for attention visualization:
            # self._modules[f'layer_{i}'].prefix = self.prefix + f'_layer{i}'
            H, Z = self._modules[f'layer_{i}'](H, Z, block_id, edges, edge_attr, cached_unit_edge_info)
            H, Z, rescale = self._modules[f'layernorm0_{i}'](H, Z, block_id, batch_id)
            record_scale.mul_(rescale.detach())
            H, Z = self._modules[f'ffn_{i}'](H, Z, block_id)
            H, Z, rescale = self._modules[f'layernorm1_{i}'](H, Z, block_id, batch_id)
            record_scale.mul_(rescale.detach())

        Z = self.recover_scale(Z, block_id, batch_id, record_scale)
        return H, Z


# ---------------------------------------------------------------------------
# Implementation of the equivariant adaptive block message-passing mechanism
# ---------------------------------------------------------------------------


class GETLayer(nn.Module):
    '''Equivariant Adaptive Block Transformer layer.

    Uses unit-level attention (alpha) over unit pairs inside each block edge and block-level
    attention (beta) over block edges sharing a source block. Both are used to update the
    invariant features H and the equivariant coordinates Z.
    '''

    def __init__(self, d_hidden, d_radial, n_channel, n_rbf, cutoff=7.0, d_edge=0, n_head=4, act_fn=nn.SiLU(),
                 residual=True):
        super(GETLayer, self).__init__()
        self.residual = residual
        self.reci_sqrt_d = 1 / math.sqrt(d_radial)  # not used by the current forward pass
        self.epsilon = 1e-8  # not used by the current forward pass
        self.n_rbf = n_rbf
        self.cutoff = cutoff
        self.n_head = n_head
        assert d_hidden % self.n_head == 0, f'd_hidden not compatible with n_head ({d_hidden} and {self.n_head})'
        assert d_radial % self.n_head == 0, f'd_radial not compatible with n_head ({d_radial} and {self.n_head})'
        assert n_rbf % self.n_head == 0, f'n_rbf not compatible with n_head ({n_rbf} and {self.n_head})'

        d_hidden_head, d_radial_head = d_hidden // self.n_head, d_radial // self.n_head
        n_rbf_head = n_rbf // self.n_head

        self.linear_qk = nn.Linear(d_hidden_head, d_radial_head * 2, bias=False)
        self.linear_v = nn.Linear(d_hidden_head, d_radial_head)

        if n_rbf > 1:
            self.rbf = RadialBasis(num_radial=n_rbf, cutoff=cutoff)

        # Attention input per head: H_q, H_k, the projected edge features (only when d_edge != 0)
        # and the distance features. Sizing by d_edge keeps d_edge == 0 usable; previously the
        # input always assumed an edge term, so such a layer could never run.
        n_radial_terms = 3 if d_edge != 0 else 2
        self.att_mlp = nn.Sequential(
            nn.Linear(d_radial_head * n_radial_terms + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )
        self.unit_att_linear = nn.Linear(d_radial_head, 1)
        self.block_att_linear = nn.Linear(d_radial_head, 1)

        if d_edge != 0:
            self.edge_linear = nn.Linear(d_edge, d_radial)

        self.node_mlp = nn.Sequential(
            nn.Linear(d_radial, d_hidden),
            act_fn,
            nn.Linear(d_hidden, d_hidden),
            act_fn
        )
        self.node_out_linear = nn.Linear(d_hidden, d_hidden)
        self.coord_mlp = nn.Sequential(
            nn.Linear(d_radial, d_hidden),
            act_fn,
            nn.Linear(d_hidden, n_head * n_channel),
            act_fn
        )
        self.unit_msg_mlp = nn.Sequential(
            nn.Linear(d_radial_head + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )
        self.unit_msg_coord_mlp = nn.Sequential(
            nn.Linear(d_radial_head + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )
        self.unit_msg_coord_linear = nn.Linear(d_radial_head, n_channel)

    def attention(self, H, Z, edges, edge_attr, cached_unit_edge_info):
        '''Compute unit-level (alpha) and block-level (beta) attention weights.

        :return: alpha [Eu, n_head, 1], beta [Em, n_head, 1] and (D, R, dZ), where D holds the
            per-head distance features, R the unit-level attention logits and dZ = Z[row] - Z[col]
        :raises ValueError: if edge_attr is given without edge features configured (d_edge == 0),
            or is missing although the layer was built with d_edge != 0
        '''
        row, col = edges
        (unit_row, unit_col), (block_edge_id, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info

        # multi-head
        H = H.view(H.shape[0], self.n_head, -1)  # [N, n_head, hidden_size / n_head]

        # calculate attention
        H_qk = self.linear_qk(H)
        H_q, H_k = H_qk[..., 0::2][unit_row], H_qk[..., 1::2][unit_col]  # [Eu, n_head, d_radial / n_head]

        dZ = Z[unit_row] - Z[unit_col]  # [Eu, n_channel, 3]
        D = stable_norm(dZ, dim=-1)  # [Eu, n_channel]
        if self.n_rbf > 1:
            n_channel = D.shape[-1]
            D = self.rbf(D.view(-1)).view(D.shape[0], n_channel, self.n_head, -1)  # [Eu, n_channel, n_head, n_rbf / n_head]
            D = D.transpose(1, 2).reshape(D.shape[0], self.n_head, -1)  # [Eu, n_head, n_channel * n_rbf / n_head]
        else:
            D = D.unsqueeze(1).repeat(1, self.n_head, 1)  # [Eu, n_head, n_channel]

        has_edge_linear = hasattr(self, 'edge_linear')
        if edge_attr is None:
            if has_edge_linear:
                raise ValueError('edge_attr is required because this layer was built with d_edge != 0')
            R_repr = torch.concat([H_q, H_k, D], dim=-1)  # [Eu, n_head, (d_radial * 2 + n_channel * n_rbf) / n_head]
        else:
            if not has_edge_linear:
                raise ValueError('edge_attr was given but this layer was built with d_edge == 0')
            edge_attr = self.edge_linear(edge_attr).view(edge_attr.shape[0], self.n_head, -1)
            R_repr = torch.concat([H_q, H_k, D, edge_attr[block_edge_id]], dim=-1)
        R_repr = self.att_mlp(R_repr)  # [Eu, n_head, d_radial / n_head]
        R = self.unit_att_linear(R_repr).squeeze(-1)  # [Eu, n_head]
        # unit-level attention within block-level edges: softmax over unit edges sharing unit_edge_src_id
        alpha = scatter_softmax(R, unit_edge_src_id, dim=0).unsqueeze(-1)  # [Eu, n_head, 1]

        # Block-level logits come from mean-pooling the unit-attention *representations* of each block
        # edge. Pooling R directly is not reasonable because pre-softmax values have different scales
        # for different block pairs, and max(R) - min(R) or max(R) - mean(R) would be bounded below by
        # 0 instead of -inf.
        beta = self.block_att_linear(scatter_mean(R_repr, block_edge_id, dim=0)).squeeze(-1)  # [Eb, n_head]
        beta = scatter_softmax(beta, row, dim=0)  # [Eb, n_head], softmax over block edges sharing a row block

        # for attention visualization:
        # with open(f'./attention/{self.prefix}.pkl', 'wb') as fout:
        #     pickle.dump((alpha, beta, edges, (unit_row, unit_col)), fout)

        beta = beta[block_edge_id[unit_edge_src_start]].unsqueeze(-1)  # [Em, n_head, 1], Em = \sum_{i, j \in E} n_i
        return alpha, beta, (D, R, dZ)

    def invariant_update(self, H_v, H, alpha, beta, D, cached_unit_edge_info):
        '''Update invariant unit features using the two-level attention weights.

        :param H_v: [N, n_head, d_radial / n_head] per-unit value features
        :param H: [N, hidden_size] features to update
        :return: [N, hidden_size] updated features (residual if ``self.residual``)
        '''
        (unit_row, unit_col), (block_edge_id, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info
        unit_agg_row = unit_row[unit_edge_src_start]  # [Em], receiving (row) unit of each aggregated message

        H_v = self.unit_msg_mlp(torch.cat([H_v[unit_col], D], dim=-1))  # [Eu, n_head, d_radial / n_head]
        H_agg = scatter_sum(alpha * H_v, unit_edge_src_id, dim=0)  # [Em, n_head, d_radial / n_head]
        H_agg = H_agg.view(H_agg.shape[0], -1)  # [Em, d_radial]
        H_agg = self.node_mlp(H_agg)  # [Em, hidden_size]
        H_agg = H_agg.view(H_agg.shape[0], self.n_head, -1)  # [Em, n_head, hidden_size / n_head]
        H_agg = scatter_sum(beta * H_agg, unit_agg_row, dim=0, dim_size=H.shape[0])  # [N, n_head, hidden_size / n_head]
        H_agg = H_agg.view(H_agg.shape[0], -1)  # [N, hidden_size]
        H_agg = self.node_out_linear(H_agg)
        return H + H_agg if self.residual else H_agg

    def equivariant_update(self, H_v, Z, alpha, beta, D, dZ, cached_unit_edge_info):
        '''Update coordinates with attention-weighted sums of relative vectors dZ.

        :param H_v: [N, n_head, d_radial / n_head] per-unit value features
        :param Z: [N, n_channel, 3] coordinates to update
        :param dZ: [Eu, n_channel, 3] relative vectors Z[unit_row] - Z[unit_col]
        :return: [N, n_channel, 3] updated coordinates (always residual)
        '''
        (unit_row, unit_col), (block_edge_id, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info
        unit_agg_row = unit_row[unit_edge_src_start]  # [Em]

        H_v = self.unit_msg_coord_mlp(torch.cat([H_v[unit_col], D], dim=-1))  # [Eu, n_head, d_radial / n_head]
        # per-head, per-channel scalar weights on dZ, aggregated with the unit-level attention
        Z_agg = scatter_sum(
            (alpha * self.unit_msg_coord_linear(H_v)).unsqueeze(-1) * dZ.unsqueeze(1),
            unit_edge_src_id, dim=0)  # [Em, n_head, n_channel, 3]
        Z_H_agg = scatter_sum(alpha * H_v, unit_edge_src_id, dim=0)  # [Em, n_head, d_radial / n_head]
        Z_H_agg = self.coord_mlp(Z_H_agg.view(Z_H_agg.shape[0], -1))  # [Em, n_head * n_channel]
        Z_H_agg = Z_H_agg.view(Z_H_agg.shape[0], self.n_head, -1)  # [Em, n_head, n_channel]
        Z_agg = scatter_sum(
            (beta * Z_H_agg).unsqueeze(-1) * Z_agg,
            unit_agg_row, dim=0, dim_size=Z.shape[0])  # [N, n_head, n_channel, 3]
        Z_agg = Z_agg.sum(dim=1)  # [N, n_channel, 3]
        return Z + Z_agg

    def forward(self, H, Z, block_id, edges, edge_attr=None, cached_unit_edge_info=None):
        '''
        H: [N, hidden_size],
        Z: [N, n_channel, 3],
        block_id: [N],
        edges: [2, E], list of [n_row] and [n_col] where n_row == n_col == E,
            nodes from col are used to update nodes from row
        edge_attr: [E, d_edge], required iff the layer was built with d_edge != 0
        cached_unit_edge_info: unit level (row, col), (block_edge_id, unit_edge_src_start, unit_edge_src_id)
            calculated from block edges
        '''
        if cached_unit_edge_info is None:
            with torch.no_grad():
                # NOTE(review): unlike GET.forward, this fallback passes neither Z nor k to
                # _unit_edges_from_block_edges; confirm the helper's defaults are what is intended here.
                # [Eu], Eu = \sum_{i, j \in E} n_i * n_j
                cached_unit_edge_info = _unit_edges_from_block_edges(block_id, edges.T)

        alpha, beta, (D, R, dZ) = self.attention(H, Z, edges, edge_attr, cached_unit_edge_info)

        H_v = self.linear_v(H.view(H.shape[0], self.n_head, -1))  # [N, n_head, d_radial / n_head]
        H = self.invariant_update(H_v, H, alpha, beta, D, cached_unit_edge_info)
        Z = self.equivariant_update(H_v, Z, alpha, beta, D, dZ, cached_unit_edge_info)

        return H, Z


class EquivariantFFN(nn.Module):
    '''Feed-forward update of each unit conditioned on its block's mean feature and its offset from the block center.'''

    def __init__(self, d_in, d_hidden, d_out, n_channel, n_rbf=16, act_fn=nn.SiLU(),
                 residual=True, dropout=0.1, constant=1, z_requires_grad=True, cutoff=7.0) -> None:
        '''
        :param cutoff: cutoff of the radial basis. It was previously hard-coded to 7.0, which is
            still the default.
        '''
        super().__init__()
        self.constant = constant  # not used by the current forward pass
        self.residual = residual
        self.n_rbf = n_rbf

        self.mlp_h = nn.Sequential(
            nn.Linear(d_in * 2 + n_channel * n_rbf, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_out),
            nn.Dropout(dropout)
        )

        self.mlp_z = nn.Sequential(
            nn.Linear(d_in * 2 + n_channel * n_rbf, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, n_channel),
            nn.Dropout(dropout)
        )

        if not z_requires_grad:
            for param in self.mlp_z.parameters():
                param.requires_grad = False

        self.rbf = RadialBasis(num_radial=n_rbf, cutoff=cutoff)

    def forward(self, H, Z, block_id):
        '''
        :param H: [N, d_in]
        :param Z: [N, n_channel, 3]
        :param block_id: [N] block index of each unit
        :return: (H, Z) with H [N, d_out] and Z [N, n_channel, 3]
        '''
        radial, (Z_c, Z_o) = self._radial(Z, block_id)  # [N, n_channel * n_rbf], ([N, n_channel, 3], [N, n_channel, 3])
        H_c = scatter_mean(H, block_id, dim=0)[block_id]  # [N, d_in], mean feature of each unit's block
        inputs = torch.cat([H, H_c, radial], dim=-1)  # [N, d_in * 2 + n_channel * n_rbf]
        H_update = self.mlp_h(inputs)
        H = H + H_update if self.residual else H_update
        # coordinates are rebuilt as block center + learned per-channel scaling of the offset
        Z = Z_c + self.mlp_z(inputs).unsqueeze(-1) * Z_o
        return H, Z

    def _radial(self, Z, block_id):
        '''Return RBF features of each unit's distance to its block center, plus (center, offset).'''
        Z_c = scatter_mean(Z, block_id, dim=0)  # [Nb, n_channel, 3]
        Z_c = Z_c[block_id]  # [N, n_channel, 3]
        Z_o = Z - Z_c  # [N, n_channel, 3], no translation
        D = stable_norm(Z_o, dim=-1)  # [N, n_channel]
        radial = self.rbf(D.view(-1)).view(D.shape[0], -1)  # [N, n_channel * n_rbf]
        return radial, (Z_c, Z_o)


class EquivariantLayerNorm(nn.Module):
    '''LayerNorm on invariant features plus per-graph, per-channel scale normalization of coordinates.'''

    # Lower bound on the per-graph coordinate std. A zero std (e.g. a graph with a single unit)
    # would otherwise give an infinite rescale and NaN coordinates (0 * inf).
    std_eps = 1e-6

    def __init__(self, d_hidden, n_channel, n_rbf=16, cutoff=7.0, act_fn=nn.SiLU()) -> None:
        super().__init__()
        # invariant
        self.fuse_scale_ffn = nn.Sequential(
            nn.Linear(n_channel * n_rbf, d_hidden),
            act_fn,
            nn.Linear(d_hidden, d_hidden),
            act_fn
        )
        self.layernorm = nn.LayerNorm(d_hidden)

        # geometric
        sigma = torch.ones((1, n_channel, 1))
        self.sigma = nn.Parameter(sigma, requires_grad=True)
        self.rbf = RadialBasis(num_radial=n_rbf, cutoff=cutoff)

    def forward(self, H, Z, block_id, batch_id):
        '''
        :param H: [N, d_hidden]
        :param Z: [N, n_channel, 3]
        :param block_id: [N] block index of each unit
        :param batch_id: [Nb] graph index of each block
        :return: (H, Z, rescale), where rescale [bs, n_channel, 1] is the factor applied to
            each graph's centered coordinates
        '''
        with torch.no_grad():
            _, n_channel, n_axis = Z.shape
            unit_batch_id = batch_id[block_id]
            unit_axis_batch_id = unit_batch_id.unsqueeze(-1).repeat(1, n_axis).flatten()  # [N * 3]

        Z_c = scatter_mean(Z, unit_batch_id, dim=0)  # [bs, n_channel, 3]
        Z_c = Z_c[unit_batch_id]  # [N, n_channel, 3]
        Z_centered = Z - Z_c
        # std over all units and all axes of each graph, separately for each channel
        std = scatter_std(
            Z_centered.transpose(1, 2).reshape(-1, n_channel).contiguous(),
            unit_axis_batch_id, dim=0)  # [bs, n_channel]
        std = std.clamp(min=self.std_eps)
        rescale = (1 / std).unsqueeze(-1) * self.sigma  # [bs, n_channel, 1]
        Z = Z_c + Z_centered * rescale[unit_batch_id]

        rescale_rbf = self.rbf(rescale.view(-1)).view(rescale.shape[0], -1)  # [bs, n_channel * n_rbf]
        H = H + self.fuse_scale_ffn(rescale_rbf)[unit_batch_id]
        H = self.layernorm(H)

        return H, Z, rescale


class GETEncoder(nn.Module):
    '''Wraps GET and adds L2-normalized sum-pooled block and graph representations.'''

    def __init__(self, hidden_size, radial_size, n_channel, n_rbf=1, cutoff=7.0, edge_size=16, n_layers=3,
                 n_head=1, dropout=0.1, z_requires_grad=True, stable=False) -> None:
        super().__init__()
        # NOTE: ``stable`` is accepted for interface compatibility but is not used.
        self.encoder = GET(
            hidden_size, radial_size, n_channel, n_rbf, cutoff, edge_size, n_layers, n_head,
            dropout=dropout, z_requires_grad=z_requires_grad
        )

    def forward(self, H, Z, block_id, batch_id, edges, edge_attr=None):
        '''
        :return: (unit features H, block_repr [Nb, hidden], graph_repr [bs, hidden], predicted Z)
        '''
        H, pred_Z = self.encoder(H, Z, block_id, batch_id, edges, edge_attr)
        block_repr = scatter_sum(H, block_id, dim=0)  # [Nb, hidden]
        block_repr = F.normalize(block_repr, dim=-1)
        graph_repr = scatter_sum(block_repr, batch_id, dim=0)  # [bs, hidden]
        graph_repr = F.normalize(graph_repr, dim=-1)
        return H, block_repr, graph_repr, pred_Z


if __name__ == '__main__':
    # Smoke test: H should be invariant and Z equivariant under an independent rigid
    # transformation (rotation + translation) of each graph in the batch.
    d_hidden = 64
    d_radial = 16
    n_channel = 2
    d_edge = 16
    n_rbf = 16
    n_head = 4
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = GET(d_hidden, d_radial, n_channel, n_rbf, d_edge=d_edge, n_head=n_head)
    model.to(device)
    model.eval()

    block_id = torch.tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 5, 6, 6, 6, 6, 7, 7], dtype=torch.long).to(device)
    batch_id = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1], dtype=torch.long).to(device)
    src_dst = torch.tensor([[0, 1], [2, 3], [1, 3], [2, 4], [3, 0], [3, 3], [5, 7], [7, 6], [5, 6], [6, 7]],
                           dtype=torch.long).to(device)
    src_dst = src_dst.T
    edge_attr = torch.randn(len(src_dst[0]), d_edge).to(device)
    n_unit = block_id.shape[0]

    H = torch.randn(n_unit, d_hidden, device=device)
    Z = torch.randn(n_unit, n_channel, 3, device=device)

    print(_unit_edges_from_block_edges(block_id, src_dst.T, Z, k=model.sparse_k))

    H1, Z1 = model(H, Z, block_id, batch_id, src_dst, edge_attr)

    def random_rigid_transform():
        '''Return a random proper rotation matrix Q (det = +1) and a random translation t.'''
        U, _, V = torch.linalg.svd(torch.randn(3, 3, device=device, dtype=torch.float))
        if torch.linalg.det(U) * torch.linalg.det(V) < 0:
            U[:, -1] = -U[:, -1]
        return U.mm(V), torch.randn(3, device=device)

    # random rotation matrices and translations, one per graph
    Q1, t1 = random_rigid_transform()
    Q2, t2 = random_rigid_transform()

    unit_batch_id = batch_id[block_id]
    Z[unit_batch_id == 0] = torch.matmul(Z[unit_batch_id == 0], Q1) + t1
    Z[unit_batch_id == 1] = torch.matmul(Z[unit_batch_id == 1], Q2) + t2

    H2, Z2 = model(H, Z, block_id, batch_id, src_dst, edge_attr)

    print(f'invariant feature: {torch.abs(H1 - H2).sum()}')
    Z1[unit_batch_id == 0] = torch.matmul(Z1[unit_batch_id == 0], Q1) + t1
    Z1[unit_batch_id == 1] = torch.matmul(Z1[unit_batch_id == 1], Q2) + t2
    print(f'equivariant feature: {torch.abs(Z1 - Z2).sum()}')