import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # make CUDA kernel launches synchronous to ease debugging
import sys
sys.path.append("..")
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Module, Sequential, ModuleList, Linear, Conv1d
from torch_geometric.nn import radius_graph, knn_graph
from torch_geometric.utils import add_self_loops
from torch_scatter import scatter_sum, scatter_softmax, scatter_mean, scatter_std
import numpy as np
from .radial_basis import RadialBasis
import copy
from math import pi as PI
from ..common import GaussianSmearing, ShiftedSoftplus
from ..protein_features import ProteinFeatures
# Per-residue-type validity mask over the 14 heavy-atom slots (1 = slot occupied, 0 = padding), one row per residue type index.
residue_atom_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]).float()
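# A commented-out lookup sketch (toy indices are assumptions, not values from the data pipeline):
#   res_S = torch.tensor([0, 3, 18])              # residue-type indices
#   atom_mask = residue_atom_mask[res_S]          # (3, 14), 1 where the heavy-atom slot is occupied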
class AttentionInteractionBlock(Module):
def __init__(self, hidden_channels, edge_channels, key_channels, num_heads=1):
super().__init__()
assert hidden_channels % num_heads == 0
assert key_channels % num_heads == 0
self.hidden_channels = hidden_channels
self.key_channels = key_channels
self.num_heads = num_heads
self.k_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
self.q_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
self.v_lin = Conv1d(hidden_channels, hidden_channels, 1, groups=num_heads, bias=False)
self.weight_k_net = Sequential(
Linear(edge_channels, key_channels // num_heads),
ShiftedSoftplus(),
Linear(key_channels // num_heads, key_channels // num_heads),
)
self.weight_k_lin = Linear(key_channels // num_heads, key_channels // num_heads)
self.weight_v_net = Sequential(
Linear(edge_channels, hidden_channels // num_heads),
ShiftedSoftplus(),
Linear(hidden_channels // num_heads, hidden_channels // num_heads),
)
self.weight_v_lin = Linear(hidden_channels // num_heads, hidden_channels // num_heads)
self.centroid_lin = Linear(hidden_channels, hidden_channels)
self.act = ShiftedSoftplus()
self.out_transform = Linear(hidden_channels, hidden_channels)
self.layernorm_attention = nn.LayerNorm(hidden_channels)
self.layernorm_ffn = nn.LayerNorm(hidden_channels)
def forward(self, x, edge_index, edge_attr):
"""
Args:
x: Node features, (N, H).
edge_index: (2, E).
            edge_attr: Edge features, (E, edge_channels).
"""
N = x.size(0)
row, col = edge_index # (E,) , (E,)
# self-attention layer_norm
y = self.layernorm_attention(x)
# Project to multiple key, query and value spaces
h_keys = self.k_lin(y.unsqueeze(-1)).view(N, self.num_heads, -1) # (N, heads, K_per_head)
h_queries = self.q_lin(y.unsqueeze(-1)).view(N, self.num_heads, -1) # (N, heads, K_per_head)
h_values = self.v_lin(y.unsqueeze(-1)).view(N, self.num_heads, -1) # (N, heads, H_per_head)
# Compute keys and queries
W_k = self.weight_k_net(edge_attr) # (E, K_per_head)
keys_j = self.weight_k_lin(W_k.unsqueeze(1) * h_keys[col]) # (E, heads, K_per_head)
queries_i = h_queries[row] # (E, heads, K_per_head)
# Compute attention weights (alphas)
qk_ij = (queries_i * keys_j).sum(-1) # (E, heads)
alpha = scatter_softmax(qk_ij, row, dim=0)
# Compose messages
W_v = self.weight_v_net(edge_attr) # (E, H_per_head)
msg_j = self.weight_v_lin(W_v.unsqueeze(1) * h_values[col]) # (E, heads, H_per_head)
msg_j = alpha.unsqueeze(-1) * msg_j # (E, heads, H_per_head)
# Aggregate messages
aggr_msg = scatter_sum(msg_j, row, dim=0, dim_size=N).view(N, -1) # (N, heads*H_per_head)
x = aggr_msg + x
y = self.layernorm_ffn(x)
out = self.out_transform(self.act(y)) + x
return out
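# A minimal, commented-out usage sketch for AttentionInteractionBlock; the toy shapes below are
# assumptions consistent with the constructor, not values from the training pipeline:
#   block = AttentionInteractionBlock(hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4)
#   x = torch.randn(10, 128)                            # (N, hidden_channels)
#   edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])   # (2, E)
#   edge_attr = torch.randn(3, 64)                       # (E, edge_channels)
#   out = block(x, edge_index, edge_attr)                 # (N, hidden_channels)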
class CFTransformerEncoder(Module):
def __init__(self, hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=32,
cutoff=10.0):
super().__init__()
self.hidden_channels = hidden_channels
self.edge_channels = edge_channels
self.key_channels = key_channels
self.num_heads = num_heads
self.num_interactions = num_interactions
self.k = k
self.cutoff = cutoff
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
self.interactions = ModuleList()
for _ in range(num_interactions):
block = AttentionInteractionBlock(
hidden_channels=hidden_channels,
edge_channels=edge_channels,
key_channels=key_channels,
num_heads=num_heads,
)
self.interactions.append(block)
@property
def out_channels(self):
return self.hidden_channels
def forward(self, node_attr, pos, batch):
# edge_index = radius_graph(pos, self.cutoff, batch=batch, loop=False)
edge_index = knn_graph(pos, k=self.k, batch=batch, flow='target_to_source')
edge_length = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
edge_attr = self.distance_expansion(edge_length)
h = node_attr
for interaction in self.interactions:
h = h + interaction(h, edge_index, edge_attr)
return h
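# A minimal, commented-out usage sketch for CFTransformerEncoder (toy shapes are assumptions):
#   enc = CFTransformerEncoder(hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4,
#                              num_interactions=2, k=8)
#   node_attr = torch.randn(20, 128)                      # (N, hidden_channels)
#   pos = torch.randn(20, 3)                               # (N, 3) atom coordinates
#   batch = torch.zeros(20, dtype=torch.long)              # single graph
#   h = enc(node_attr, pos, batch)                          # (N, hidden_channels)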
# residue level graph transformer
class AAEmbedding(nn.Module):
def __init__(self, device):
super(AAEmbedding, self).__init__()
self.hydropathy = {'#': 0, "I": 4.5, "V": 4.2, "L": 3.8, "F": 2.8, "C": 2.5, "M": 1.9, "A": 1.8, "W": -0.9,
"G": -0.4, "T": -0.7, "S": -0.8, "Y": -1.3, "P": -1.6, "H": -3.2, "N": -3.5, "D": -3.5,
"Q": -3.5, "E": -3.5, "K": -3.9, "R": -4.5}
self.volume = {'#': 0, "G": 60.1, "A": 88.6, "S": 89.0, "C": 108.5, "D": 111.1, "P": 112.7, "N": 114.1,
"T": 116.1, "E": 138.4, "V": 140.0, "Q": 143.8, "H": 153.2, "M": 162.9, "I": 166.7, "L": 166.7,
"K": 168.6, "R": 173.4, "F": 189.9, "Y": 193.6, "W": 227.8}
self.charge = {**{'R': 1, 'K': 1, 'D': -1, 'E': -1, 'H': 0.1}, **{x: 0 for x in 'ABCFGIJLMNOPQSTUVWXYZ#'}}
self.polarity = {**{x: 1 for x in 'RNDQEHKSTY'}, **{x: 0 for x in "ACGILMFPWV#"}}
self.acceptor = {**{x: 1 for x in 'DENQHSTY'}, **{x: 0 for x in "RKWACGILMFPV#"}}
self.donor = {**{x: 1 for x in 'RKWNQHSTY'}, **{x: 0 for x in "DEACGILMFPV#"}}
ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y',
'V']
self.embedding = torch.tensor([
[self.hydropathy[aa], self.volume[aa] / 100, self.charge[aa], self.polarity[aa], self.acceptor[aa],
self.donor[aa]]
for aa in ALPHABET]).to(device)
def to_rbf(self, D, D_min, D_max, stride):
D_count = int((D_max - D_min) / stride)
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu.view(1, -1) # [1, K]
D_expand = torch.unsqueeze(D, -1) # [N, 1]
return torch.exp(-((D_expand - D_mu) / stride) ** 2)
def transform(self, aa_vecs):
return torch.cat([
self.to_rbf(aa_vecs[:, 0], -4.5, 4.5, 0.1),
self.to_rbf(aa_vecs[:, 1], 0, 2.2, 0.1),
self.to_rbf(aa_vecs[:, 2], -1.0, 1.0, 0.25),
torch.sigmoid(aa_vecs[:, 3:] * 6 - 3),
], dim=-1)
def dim(self):
return 90 + 22 + 8 + 3
def forward(self, x, raw=False):
# B, N = x.size(0), x.size(1)
# aa_vecs = self.embedding[x.view(-1)].view(B, N, -1)
aa_vecs = self.embedding[x.view(-1)]
rbf_vecs = self.transform(aa_vecs)
return aa_vecs if raw else rbf_vecs
def soft_forward(self, x):
aa_vecs = torch.matmul(x, self.embedding)
rbf_vecs = self.transform(aa_vecs)
return rbf_vecs
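# A minimal, commented-out usage sketch for AAEmbedding (indices are toy assumptions):
#   emb = AAEmbedding(device='cpu')
#   aa_idx = torch.tensor([1, 5, 20])      # indices into ALPHABET, 0 = '#' (mask)
#   feats = emb(aa_idx)                    # (3, emb.dim()) == (3, 123) RBF-expanded physico-chemical features
#   raw = emb(aa_idx, raw=True)            # (3, 6) raw [hydropathy, volume/100, charge, polarity, acceptor, donor]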
class TransformerLayer(nn.Module):
def __init__(self, num_hidden, num_heads=4, dropout=0.1):
super(TransformerLayer, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
self.dropout_attention = nn.Dropout(dropout)
self.dropout_ffn = nn.Dropout(dropout)
self.self_attention_norm = nn.LayerNorm(num_hidden)
self.ffn_norm = nn.LayerNorm(num_hidden)
self.attention = ResidueAttention(num_hidden, num_heads)
self.ffn = PositionWiseFeedForward(num_hidden, num_hidden)
def forward(self, h_V, h_E, E_idx):
""" Parallel computation of full transformer layer """
# Self-attention
y = self.self_attention_norm(h_V)
y = self.attention(y, h_E, E_idx)
h_V = h_V + self.dropout_attention(y)
# Position-wise feedforward
y = self.ffn_norm(h_V)
y = self.ffn(y)
h_V = h_V + self.dropout_ffn(y)
return h_V
class PositionWiseFeedForward(nn.Module):
def __init__(self, num_hidden, num_ff):
super(PositionWiseFeedForward, self).__init__()
self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
def forward(self, h_V):
h = F.relu(self.W_in(h_V))
h = self.W_out(h)
return h
class ResidueAttention(nn.Module):
def __init__(self, num_hidden, num_heads=4):
super(ResidueAttention, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
# Self-attention layers: {queries, keys, values, output}
self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
self.W_K = nn.Linear(num_hidden * 2, num_hidden, bias=False)
self.W_V = nn.Linear(num_hidden * 2, num_hidden, bias=False)
self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
self.act = ShiftedSoftplus()
self.layernorm = nn.LayerNorm(num_hidden)
def forward(self, h_V, h_E, edge_index):
""" Self-attention, graph-structured O(Nk)
Args:
h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_hidden]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_V: Node update
"""
# Queries, Keys, Values
n_edges = h_E.shape[0]
n_nodes = h_V.shape[0]
n_heads = self.num_heads
row, col = edge_index # (E,) , (E,)
d = int(self.num_hidden / n_heads)
Q = self.W_Q(h_V).view([n_nodes, n_heads, 1, d])
K = self.W_K(torch.cat([h_E, h_V[col]], dim=-1)).view([n_edges, n_heads, d, 1])
V = self.W_V(torch.cat([h_E, h_V[col]], dim=-1)).view([n_edges, n_heads, d])
# Attention with scaled inner product
        attend_logits = torch.matmul(Q[row], K).view([n_edges, n_heads]) / np.sqrt(d)  # (E, heads), scaled dot product
        alpha = scatter_softmax(attend_logits, row, dim=0)
# Compose messages
msg_j = alpha.unsqueeze(-1) * V # (E, heads, H_per_head)
# Aggregate messages
aggr_msg = scatter_sum(msg_j, row, dim=0, dim_size=n_nodes).view(n_nodes, -1) # (N, heads*H_per_head)
h_V_update = self.W_O(self.act(aggr_msg))
return h_V_update
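# A minimal, commented-out usage sketch for the residue-level transformer layer
# (graph-structured inputs; toy shapes are assumptions):
#   layer = TransformerLayer(num_hidden=128, num_heads=4)
#   h_V = torch.randn(10, 128)                    # residue node features
#   h_E = torch.randn(30, 128)                    # per-edge features
#   edge_index = torch.randint(0, 10, (2, 30))    # (2, E): row = target residue, col = source residue
#   h_V = layer(h_V, h_E, edge_index)             # (10, 128)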
# hierarchical graph transformer encoder
class HierEncoder(Module):
def __init__(self, hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=32,
cutoff=10.0, device='cuda:0'):
super().__init__()
self.hidden_channels = hidden_channels
self.edge_channels = edge_channels
self.key_channels = key_channels
self.num_heads = num_heads
self.num_interactions = num_interactions
self.k = k
self.cutoff = cutoff
self.device = device
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
self.interactions = ModuleList()
for _ in range(num_interactions):
block = AttentionInteractionBlock(
hidden_channels=hidden_channels,
edge_channels=edge_channels,
key_channels=key_channels,
num_heads=num_heads,
)
self.interactions.append(block)
# Residue level settings
self.residue_feat = AAEmbedding(device) # for residue node feature
self.features = ProteinFeatures(top_k=8) # for residue edge feature
self.W_v = nn.Linear(hidden_channels + self.residue_feat.dim(), hidden_channels, bias=True)
self.W_e = nn.Linear(self.features.feature_dimensions, hidden_channels, bias=True)
self.residue_encoder_layers = nn.ModuleList([TransformerLayer(hidden_channels, dropout=0.1) for _ in range(2)])
self.T_a = nn.Sequential(nn.Linear(2 * hidden_channels + edge_channels, hidden_channels), nn.ReLU(),
nn.Linear(hidden_channels, 1))
self.T_x = nn.Sequential(nn.Linear(3 * hidden_channels, hidden_channels), nn.ReLU(),
nn.Linear(hidden_channels, 14))
@property
def out_channels(self):
return self.hidden_channels
def forward(self, node_attr, pos, batch_ctx, batch, pred_res_type, mask_protein, external_index, backbone=True,
mask=True):
S_id, R = batch['res_idx'], batch['amino_acid']
residue_batch, atom2residue = batch['amino_acid_batch'], batch['atom2residue']
edit_residue, edit_atom = batch['protein_edit_residue'], batch['protein_edit_atom']
if mask:
R[batch['random_mask_residue']] = 0
R = F.one_hot(R, num_classes=21).float()
if backbone:
atom2residue, edit_atom = batch['atom2residue_backbone'], batch['protein_edit_atom_backbone']
R_edit = R[edit_residue]
R_edit[:, 1:] = pred_res_type
R[edit_residue] = R_edit
R[:, 0] = 0
edge_index = knn_graph(pos, k=self.k, batch=batch_ctx, flow='target_to_source')
edge_length = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
edge_attr = self.distance_expansion(edge_length)
h = node_attr
for interaction in self.interactions:
h = interaction(h, edge_index, edge_attr)
h_ligand_coarse = scatter_sum(h[~mask_protein], batch['ligand_atom_batch'], dim=0)
pos_ligand_coarse = scatter_sum(batch['ligand_pos'], batch['ligand_atom_batch'], dim=0)
E, residue_edge_index, residue_edge_length, edge_index_new, E_new = self.features(pos_ligand_coarse,
batch['protein_edit_residue'],
batch['residue_pos'], S_id,
residue_batch)
h_protein = h[mask_protein]
V = torch.cat([self.residue_feat.soft_forward(R), scatter_sum(h_protein, atom2residue, dim=0)], dim=-1)
h_res = self.W_v(V)
h_res = torch.cat([h_res, h_ligand_coarse])
edge_index_combined = torch.cat([residue_edge_index, edge_index_new], 1)
E = torch.cat([E, E_new], 0)
h_E = self.W_e(E)
for layer in self.residue_encoder_layers:
h_res = layer(h_res, h_E, edge_index_combined)
# update X:
h_res = h_res[:len(residue_batch)]
h_E = h_E[:residue_edge_index.size(1)]
# protein internal update
mij = torch.cat([h_res[residue_edge_index[0]], h_res[residue_edge_index[1]], h_E], dim=-1)
if backbone:
protein_pos = pos[mask_protein]
ligand_pos = pos[~mask_protein]
N = atom2residue.max() + 1
X_bb = torch.zeros(N, 4, 3).to(pos.device)
for j in range(N):
X_bb[j] = protein_pos[atom2residue == j][:4] # 4 backbone atoms [N,4,3]
            xij = X_bb[residue_edge_index[0]] - X_bb[residue_edge_index[1]]  # [E, 4, 3]
            dij = xij.norm(dim=-1) + 1e-6  # [E, 4]
            fij = torch.maximum(self.T_x(mij)[:, :4], 3.8 - dij)  # break term [E, 4]
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_res = scatter_mean(xij, residue_edge_index[0], dim=0) # [N,4,3]
X_bb[edit_residue] += f_res.clamp(min=-5.0, max=5.0)[edit_residue]
# Clash correction
for _ in range(2):
xij = X_bb[residue_edge_index[0]] - X_bb[residue_edge_index[1]] # [N,4,3]
dij = xij.norm(dim=-1) + 1e-6 # [N,4]
fij = F.relu(3.8 - dij) # repulsion term [N,4]
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_res = scatter_mean(xij, residue_edge_index[0], dim=0) # [N,4,3]
X_bb[edit_residue] += f_res.clamp(min=-5.0, max=5.0)[edit_residue]
# protein-ligand external update
protein_pos[edit_atom] = X_bb[edit_residue].view(-1, 3)
pos[mask_protein] = protein_pos
dij = torch.norm(pos[mask_protein][external_index[0]] - pos[~mask_protein][external_index[1]], dim=1) + 1e-6
mij = torch.cat(
[h[mask_protein][external_index[0]], h[~mask_protein][external_index[1]], self.distance_expansion(dij)],
dim=-1)
xij = pos[mask_protein][external_index[0]] - pos[~mask_protein][external_index[1]]
fij = torch.maximum(self.T_a(mij).squeeze(-1), 1.5 - dij)
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_atom = scatter_mean(xij, external_index[0], dim=0, dim_size=protein_pos.size(0))
protein_pos += f_atom
f_ligand_atom = scatter_mean(xij, external_index[1], dim=0, dim_size=ligand_pos.size(0))
ligand_pos -= f_ligand_atom * 0.05
else:
protein_pos = pos[mask_protein]
ligand_pos = pos[~mask_protein]
X_avg = scatter_mean(protein_pos, atom2residue, dim=0)
X = X_avg.unsqueeze(1).repeat(1, 14, 1)
N = atom2residue.max() + 1
mask = torch.zeros(N, 14, dtype=bool).to(protein_pos.device)
residue_natoms = atom2residue.bincount()
for j in range(N):
mask[j][:residue_natoms[j]] = 1 # all atoms mask [N,14]
X[j][:residue_natoms[j]] = protein_pos[atom2residue == j]
            xij = X[residue_edge_index[0]] - X_avg[residue_edge_index[1]].unsqueeze(1)  # [E, 14, 3]
            dij = xij.norm(dim=-1) + 1e-6  # [E, 14]
            fij = torch.maximum(self.T_x(mij), 3.8 - dij)  # break term [E, 14]
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_res = scatter_mean(xij, residue_edge_index[0], dim=0) # [N,14,3]
f_res[:, :4] *= 0.1
X[edit_residue] += f_res.clamp(min=-5.0, max=5.0)[edit_residue]
for _ in range(2):
protein_pos = X[mask]
X_avg = scatter_mean(protein_pos, atom2residue, dim=0)
xij = X[residue_edge_index[0]] - X_avg[residue_edge_index[1]].unsqueeze(1) # [N,14,3]
dij = xij.norm(dim=-1) + 1e-6 # [N,14]
fij = F.relu(3.8 - dij) # repulsion term [N,14]
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_res = scatter_mean(xij, residue_edge_index[0], dim=0) # [N,14,3]
X[edit_residue] += f_res.clamp(min=-5.0, max=5.0)[edit_residue]
# protein-ligand external update
protein_pos = X[mask]
pos[mask_protein] = protein_pos
dij = torch.norm(pos[mask_protein][external_index[0]] - pos[~mask_protein][external_index[1]], dim=1) + 1e-6
mij = torch.cat(
[h[mask_protein][external_index[0]], h[~mask_protein][external_index[1]], self.distance_expansion(dij)],
dim=-1)
xij = pos[mask_protein][external_index[0]] - pos[~mask_protein][external_index[1]]
fij = torch.maximum(self.T_a(mij).squeeze(-1), 1.5 - dij)
xij = xij / dij.unsqueeze(-1) * fij.unsqueeze(-1)
f_atom = scatter_mean(xij, external_index[0], dim=0, dim_size=protein_pos.size(0))
f_atom[batch['edit_backbone']] *= 0.1
protein_pos += f_atom
f_ligand_atom = scatter_mean(xij, external_index[1], dim=0, dim_size=ligand_pos.size(0))
ligand_pos -= f_ligand_atom * 0.05
return h, h_res, protein_pos, ligand_pos
# bilevel encoder
class BilevelEncoder(Module):
def __init__(self, hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=8,
cutoff=10.0, device='cuda:0'):
super().__init__()
self.hidden_channels = hidden_channels
self.edge_channels = edge_channels
self.key_channels = key_channels
self.num_heads = num_heads
self.num_interactions = num_interactions
self.k = k
self.cutoff = cutoff
self.device = device
self.esm_refine = True
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
# Residue level settings
self.atom_pos_embedding = nn.Embedding(14, self.hidden_channels)
self.residue_embedding = nn.Embedding(21, self.hidden_channels) # one embedding for mask
self.W_Q = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_K = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_V = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_K_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_V_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_O = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_O_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_O_lig1 = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.act = ShiftedSoftplus()
self.layernorm = nn.LayerNorm(hidden_channels)
self.layernorm1 = nn.LayerNorm(hidden_channels)
self.dropout = nn.Dropout(0.1)
self.T_i = nn.Sequential(nn.Linear(2 * hidden_channels + edge_channels, hidden_channels), nn.ReLU(),
nn.Linear(hidden_channels, 1))
self.T_e1 = nn.Sequential(nn.Linear(2 * hidden_channels + edge_channels, hidden_channels), nn.ReLU(),
nn.Linear(hidden_channels, 1))
self.T_e2 = nn.Sequential(nn.Linear(2 * hidden_channels + edge_channels, hidden_channels), nn.ReLU(),
nn.Linear(hidden_channels, 1))
self.residue_mlp = Linear(hidden_channels, 20)
self.sigma_D = Sequential(
Linear(edge_channels, hidden_channels),
nn.ReLU(),
Linear(hidden_channels, num_heads),
)
self.sigma_D1 = Sequential(
Linear(edge_channels, hidden_channels),
nn.ReLU(),
Linear(hidden_channels, num_heads),
)
self.residue_atom_mask = residue_atom_mask.to(device)
@property
def out_channels(self):
return self.hidden_channels
def connect_edges(self, res_X, batch):
edge_index = knn_graph(res_X[:, 1], k=self.k, batch=batch, flow='target_to_source')
edge_index, _ = add_self_loops(edge_index, num_nodes=res_X.size(0)) # add self loops
return edge_index
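    # A commented-out sketch of connect_edges (toy inputs are assumptions): residues are linked to
    # their k nearest neighbours within each batch element, using atom slot 1 (presumably C-alpha)
    # as the residue centre, and a self loop is appended for every residue:
    #   enc = BilevelEncoder(k=2, device='cpu')
    #   res_X = torch.randn(6, 14, 3)
    #   batch = torch.tensor([0, 0, 0, 1, 1, 1])
    #   edge_index = enc.connect_edges(res_X, batch)   # (2, 6*2 + 6) including the self loops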
def _forward(self, res_H, res_X, res_S, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num, residue_mask):
atom_mask = self.residue_atom_mask[res_S]
edge_index = self.connect_edges(res_X, batch)
row, col = edge_index
R_ij = torch.cdist(res_X[row], res_X[col], p=2)
dist_rep = self.distance_expansion(R_ij).view(row.shape[0], res_X.shape[1], res_X.shape[1], -1)
n_nodes = res_H.shape[0]
n_edges = edge_index.shape[1]
n_heads = self.num_heads
n_channels = res_X.shape[1]
d = int(self.hidden_channels / n_heads)
res_H = self.layernorm(res_H)
Q = self.W_Q(res_H).view([n_nodes, n_channels, n_heads, d])
K = self.W_K(res_H).view([n_nodes, n_channels, n_heads, d])
V = self.W_V(res_H).view([n_nodes, n_channels, n_heads, d])
# Attention with scaled inner product
attend_logits = torch.matmul(Q[row].transpose(1, 2), K[col].permute(0, 2, 3, 1)).view([n_edges, n_heads, n_channels, n_channels])
attend_logits /= np.sqrt(d)
attend_logits = attend_logits + self.sigma_D(dist_rep).permute(0, 3, 1, 2)
attend_mask = (atom_mask[row].unsqueeze(-1) @ atom_mask[col].unsqueeze(-2))
attend_mask = attend_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, n_heads, 1)
r_ij = atom_mask_head[row].unsqueeze(-2) @ attend_logits @ atom_mask_head[col].unsqueeze(-1)
r_ij = r_ij.squeeze() / attend_mask.sum(-1).sum(-1)
beta = scatter_softmax(r_ij, row, dim=0)
attend_logits = torch.softmax(attend_logits, dim=-1)
attend_logits = attend_logits * attend_mask
attend_logits = attend_logits/(attend_logits.norm(1, dim=-1).unsqueeze(-1)+1e-7)
alpha_ij_Vj = attend_logits @ V[col].transpose(1, 2)
# res_H update
update_H = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj, row, dim=0, dim_size=n_nodes)
update_H = update_H.transpose(1, 2).reshape(n_nodes, n_channels, -1) # (N, channels, heads*H_per_head)
res_H = res_H + self.dropout(self.W_O(self.act(update_H)))
res_H = res_H * atom_mask.unsqueeze(-1)
# res_X update
X_ij = res_X[row].unsqueeze(-2) - res_X[col].unsqueeze(1)
X_ij = X_ij/(X_ij.norm(2, dim=-1).unsqueeze(-1)+1e-5)
# Aggregate messages
Q = Q.view(n_nodes, n_channels, -1)
K = K.view(n_nodes, n_channels, -1)
p_idx, q_idx = torch.cartesian_prod(torch.arange(n_channels), torch.arange(n_channels)).chunk(2, dim=-1)
p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
input = torch.cat([Q[row][:, p_idx].view(-1, self.hidden_channels), K[col][:, q_idx].view(-1, self.hidden_channels), dist_rep.view(-1, self.edge_channels)], dim=-1)
f = self.dropout(self.T_i(input).view(n_edges, n_channels, n_channels))
#attend_mask = (atom_mask[row].unsqueeze(-1) @ atom_mask[col].unsqueeze(-2)).bool()
f = f * attend_logits.mean(1)
res_X[residue_mask] = res_X[residue_mask] + torch.clamp(scatter_sum(beta.mean(-1).unsqueeze(-1).unsqueeze(-1) * (f.unsqueeze(-1) * X_ij).sum(-2), row, dim=0, dim_size=n_nodes)[residue_mask], min=-3.0, max=3.0)
# consider ligand
batch_size = batch.max().item() + 1
lig_channel = ligand_feat.shape[1]
row1 = torch.arange(n_nodes).to(self.device)[residue_mask]
col1 = torch.repeat_interleave(torch.arange(batch_size).to(self.device), edit_residue_num)
n_edges = row1.shape[0]
Q = Q.view([n_nodes, n_channels, n_heads, d])
ligand_feat = self.layernorm1(ligand_feat)
K_lig = self.W_K_lig(ligand_feat).view([batch_size, lig_channel, n_heads, d])
V_lig = self.W_V_lig(ligand_feat).view([batch_size, lig_channel, n_heads, d])
R_ij = torch.cdist(res_X[row1], ligand_pos[col1], p=2)
dist_rep1 = self.distance_expansion(R_ij).view(row1.shape[0], res_X.shape[1], ligand_pos.shape[1], -1)
attend_logits = torch.matmul(Q[row1].transpose(1, 2), K_lig[col1].permute(0, 2, 3, 1)).view([n_edges, n_heads, n_channels, lig_channel])
attend_logits /= np.sqrt(d)
attend_logits = attend_logits + self.sigma_D1(dist_rep1).permute(0, 3, 1, 2)
attend_mask = (atom_mask[row1].unsqueeze(-1) @ ligand_mask[col1].unsqueeze(-2))
attend_mask = attend_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, n_heads, 1)
ligand_mask_head = ligand_mask.unsqueeze(1).repeat(1, n_heads, 1)
r_ij = atom_mask_head[row1].unsqueeze(-2) @ attend_logits @ ligand_mask_head[col1].unsqueeze(-1)
r_ij = r_ij.squeeze() / attend_mask.sum(-1).sum(-1)
beta = scatter_softmax(r_ij, col1, dim=0)
attend_logits_res = torch.softmax(attend_logits, dim=-1)
attend_logits_res = attend_logits_res * attend_mask
attend_logits_res = attend_logits_res / (attend_logits_res.norm(1, dim=-1).unsqueeze(-1) + 1e-7)
alpha_ij_Vj = attend_logits_res @ V_lig[col1].transpose(1, 2)
attend_logits_lig = torch.softmax(attend_logits, dim=-2) * attend_mask
        attend_logits_lig = attend_logits_lig / (attend_logits_lig.norm(1, dim=-2).unsqueeze(-2) + 1e-7)
res_H[residue_mask] = res_H[residue_mask] + self.dropout(self.W_O_lig(self.act(alpha_ij_Vj.transpose(1, 2).reshape(residue_mask.sum(), n_channels, -1))))
res_H = res_H * atom_mask.unsqueeze(-1)
alpha_ij_Vj_lig = attend_logits_lig.transpose(-1, -2) @ V[row1].transpose(1, 2)
update_lig_feat = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj_lig, col1, dim=0, dim_size=batch_size)
ligand_feat = ligand_feat + self.dropout(self.W_O_lig1(self.act(update_lig_feat.transpose(1, 2).reshape(batch_size, lig_channel, -1))))
X_ij = res_X[row1].unsqueeze(-2) - ligand_pos[col1].unsqueeze(1)
X_ij = X_ij / (X_ij.norm(2, dim=-1).unsqueeze(-1) + 1e-5)
Q = Q.view(n_nodes, n_channels, -1)
K_lig = K_lig.view(batch_size, lig_channel, -1)
p_idx, q_idx = torch.cartesian_prod(torch.arange(n_channels), torch.arange(lig_channel)).chunk(2, dim=-1)
p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
f_lig = self.dropout(self.T_e1(torch.cat([Q[row1][:, p_idx].view(-1, self.hidden_channels), K_lig[col1][:, q_idx].view(-1, self.hidden_channels), dist_rep1.view(-1, self.edge_channels)], dim=-1)).view(n_edges, n_channels, lig_channel))
f_res = self.dropout(self.T_e2(torch.cat([Q[row1][:, p_idx].view(-1, self.hidden_channels), K_lig[col1][:, q_idx].view(-1, self.hidden_channels), dist_rep1.view(-1, self.edge_channels)], dim=-1)).view(n_edges, n_channels, lig_channel))
attend_mask = (atom_mask[row1].unsqueeze(-1) @ ligand_mask[col1].unsqueeze(-2)).bool()
f_lig = f_lig * attend_logits_lig.mean(1)
f_res = f_res * attend_mask
ligand_pos = ligand_pos + scatter_sum(beta.mean(-1).unsqueeze(-1).unsqueeze(-1) * (f_lig.unsqueeze(-1) * X_ij).sum(1), col1, dim=0, dim_size=batch_size)
res_X[residue_mask] = res_X[residue_mask] + torch.clamp((f_res.unsqueeze(-1) * - X_ij).mean(-2), min=-3.0, max=3.0)
return res_H, res_X, ligand_pos, ligand_feat
def forward(self, res_H, res_X, res_S, batch, full_seq, ligand_pos, ligand_feat, ligand_mask, edit_residue_num, residue_mask, esmadapter, full_seq_mask, r10_mask):
for e in range(4):
res_H, res_X, ligand_pos, ligand_feat = self._forward(res_H, res_X, res_S, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num, residue_mask)
# predict residue types
h_residue = res_H.sum(-2)
batch_size = batch.max().item() + 1
encoder_out = {'feats': torch.zeros(batch_size, full_seq.shape[1], self.hidden_channels).to(self.device)}
encoder_out['feats'][r10_mask] = h_residue.view(-1, self.hidden_channels)
init_pred = full_seq
decode_logits = esmadapter(init_pred, encoder_out)['logits']
pred_res_type = decode_logits[full_seq_mask][:, 4:24]
return res_H, res_X, pred_res_type, ligand_pos
# bilevel layer
class GETLayer(Module):
def __init__(self, hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=8,
cutoff=10.0, device='cuda:0', sparse_k=3):
super().__init__()
self.consider_ligand = False
self.sparse_k = sparse_k
self.hidden_channels = hidden_channels
self.edge_channels = edge_channels
self.key_channels = key_channels
self.num_heads = num_heads
self.num_interactions = num_interactions
self.d = int(self.hidden_channels / self.num_heads)
self.k = k
self.cutoff = cutoff
self.device = device
self.esm_refine = True
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
# Residue level settings
self.atom_pos_embedding = nn.Embedding(14, self.hidden_channels)
self.residue_embedding = nn.Embedding(21, self.hidden_channels) # one embedding for mask
self.W_Q = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_K = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_V = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_Q_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_K_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_V_lig = nn.Linear(self.hidden_channels, self.hidden_channels, bias=False)
self.W_O = nn.Linear(self.hidden_channels, self.hidden_channels)
self.W_O_lig = nn.Linear(self.hidden_channels, self.hidden_channels)
self.W_O_lig1 = nn.Linear(self.hidden_channels, self.hidden_channels)
self.act = nn.SiLU()
self.layernorm = nn.LayerNorm(hidden_channels)
self.layernorm1 = nn.LayerNorm(hidden_channels)
self.dropout = nn.Dropout(0.1)
self.T_i = nn.Sequential(nn.Linear(3 * self.d, self.d), nn.SiLU(),
nn.Linear(self.d, 1))
self.T_e1 = nn.Sequential(nn.Linear(3 * self.d, self.d), nn.SiLU(),
nn.Linear(self.d, 1))
self.T_e2 = nn.Sequential(nn.Linear(3 * self.hidden_channels, self.hidden_channels), nn.SiLU(),
nn.Linear(self.hidden_channels, self.num_heads))
self.residue_mlp = Linear(hidden_channels, 20, bias=True)
self.sigma_D = Sequential(
Linear(edge_channels, hidden_channels),
nn.SiLU(),
Linear(hidden_channels, num_heads),
)
self.sigma_D1 = Sequential(
Linear(edge_channels, hidden_channels),
nn.SiLU(),
Linear(hidden_channels, num_heads),
)
self.sigma_v = Linear(edge_channels, hidden_channels)
self.block_mlp_invariant = nn.Sequential(nn.Linear(self.d, self.d), nn.SiLU(), nn.Linear(self.d, self.d))
self.block_mlp_equivariant = nn.Sequential(nn.Linear(self.d, self.d), nn.SiLU(), nn.Linear(self.d, self.d))
@property
def out_channels(self):
return self.hidden_channels
def connect_edges(self, res_X, batch):
edge_index = knn_graph(res_X[:, 1], k=self.k, batch=batch, flow='target_to_source')
edge_index, _ = add_self_loops(edge_index, num_nodes=res_X.size(0)) # add self loops
return edge_index
def attention(self, res_H, res_X, atom_mask, batch, residue_mask):
edge_index = self.connect_edges(res_X, batch)
row, col = edge_index
R_ij = torch.cdist(res_X[row], res_X[col], p=2)
dist_rep = self.distance_expansion(R_ij).view(row.shape[0], res_X.shape[1], res_X.shape[1], -1)
n_nodes = res_H.shape[0]
n_edges = edge_index.shape[1]
n_channels = res_X.shape[1]
Q = self.W_Q(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
K = self.W_K(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
V = self.W_V(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
attend_logits = torch.matmul(Q[row].transpose(1, 2), K[col].permute(0, 2, 3, 1)).view(
[n_edges, self.num_heads, n_channels, n_channels])
attend_logits /= np.sqrt(self.d)
attend_logits = attend_logits + self.sigma_D(dist_rep).permute(0, 3, 1, 2) # distance bias
attend_mask = (atom_mask[row].unsqueeze(-1) @ atom_mask[col].unsqueeze(-2)).unsqueeze(1).repeat(1,
self.num_heads,
1, 1)
# sparse attention, only keep top k=3
        attend_logits[torch.logical_not(attend_mask)] = -1e5  # do not select entries that are masked out
_, top_indices = torch.topk(attend_logits, self.sparse_k, dim=-1, largest=True)
sparse_mask = torch.zeros_like(attend_logits, dtype=torch.bool)
rows = torch.arange(attend_logits.size(0)).view(-1, 1, 1, 1).expand(-1, attend_logits.size(1),
attend_logits.size(2), self.sparse_k)
depth = torch.arange(attend_logits.size(1)).view(1, -1, 1, 1).expand(attend_logits.size(0), -1,
attend_logits.size(2), self.sparse_k)
height = torch.arange(attend_logits.size(2)).view(1, 1, -1, 1).expand(attend_logits.size(0),
attend_logits.size(1), -1, self.sparse_k)
sparse_mask[rows, depth, height, top_indices] = True
attend_logits = attend_logits * sparse_mask
attend_logits = attend_logits * attend_mask
# calculate beta
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
r_ij = atom_mask_head[row].unsqueeze(-2) @ attend_logits @ atom_mask_head[col].unsqueeze(-1)
        r_ij = r_ij.squeeze() / (attend_mask * sparse_mask).sum(-1).sum(-1)  # average over non-zero entries
beta = scatter_softmax(r_ij, row, dim=0) # [nnodes, num_heads]
attend_logits = torch.softmax(attend_logits, dim=-1)
attend_logits = attend_logits * attend_mask * sparse_mask
        attend_logits = attend_logits / (attend_logits.norm(1, dim=-1).unsqueeze(-1) + 1e-7)  # normalize each row
alpha_ij_Vj = attend_logits @ V[col].transpose(1, 2) # [nedges, num_heads, 14, d]
alpha_ij_Vj = self.block_mlp_invariant(alpha_ij_Vj)
# invariant res_H update
update_H = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj, row, dim=0, dim_size=n_nodes)
update_H = update_H.transpose(1, 2).reshape(n_nodes, n_channels, -1) # (nnodes, 14, heads*d)
res_H = res_H + self.W_O(update_H)
res_H = res_H * atom_mask.unsqueeze(-1) # set empty entry zeros
###################################################################################
# equivariant res_X update
X_ij = res_X[row].unsqueeze(-2) - res_X[col].unsqueeze(1) # [nedges, 14, 14, 3]
X_ij = X_ij / (X_ij.norm(2, dim=-1).unsqueeze(-1) + 1e-5)
X_ij = X_ij.unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
X_ij = X_ij * attend_mask.unsqueeze(-1)
dist_rep = self.sigma_v(dist_rep).view([n_edges, n_channels, n_channels, self.num_heads, self.d]).permute(0, 3, 1, 2, 4)
input = torch.cat(
[(Q[row].transpose(1, 2))[rows, depth, height], (K[col].transpose(1, 2))[rows, depth, top_indices],
dist_rep[rows, depth, height, top_indices]], dim=-1)
f = self.T_i(input).view(n_edges, self.num_heads, n_channels, self.sparse_k)
f = f.unsqueeze(-1) * X_ij[rows, depth, height, top_indices].view(n_edges, self.num_heads, n_channels, self.sparse_k, 3)
f = (f * attend_logits[rows, depth, height, top_indices].unsqueeze(-1)).sum(-2) # [nedges, num_heads, 14, 3]
res_X[residue_mask] = res_X[residue_mask] + torch.clamp(
scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * f, row, dim=0, dim_size=n_nodes).sum(1)[residue_mask],
min=-3.0, max=3.0)
res_X = res_X * atom_mask.unsqueeze(-1) # set empty entries zeros
return res_H, res_X
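    # A commented-out sketch of the top-k gather used in the sparse attention above (toy shapes are
    # assumptions): topk indices are turned into a boolean mask via advanced indexing with broadcast
    # index grids over the edge, head and query dimensions.
    #   logits = torch.randn(2, 4, 14, 14)              # (n_edges, num_heads, n_channels, n_channels)
    #   _, top = torch.topk(logits, 3, dim=-1)          # 3 largest keys per query atom
    #   rows = torch.arange(2).view(-1, 1, 1, 1).expand(-1, 4, 14, 3)
    #   depth = torch.arange(4).view(1, -1, 1, 1).expand(2, -1, 14, 3)
    #   height = torch.arange(14).view(1, 1, -1, 1).expand(2, 4, -1, 3)
    #   sparse = torch.zeros_like(logits, dtype=torch.bool)
    #   sparse[rows, depth, height, top] = True         # keep only the top-3 entries per row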
def attention_ligand(self, res_H, res_X, atom_mask, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num,
residue_mask):
batch_size = batch.max().item() + 1
n_nodes = res_H.shape[0]
n_channels = res_X.shape[1]
lig_channel = ligand_feat.shape[1]
row = torch.arange(n_nodes).to(self.device)[residue_mask]
col = torch.repeat_interleave(torch.arange(batch_size).to(self.device), edit_residue_num)
n_edges = row.shape[0]
Q = self.W_Q_lig(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
K_lig = self.W_K_lig(ligand_feat).view([batch_size, lig_channel, self.num_heads, self.d])
V_lig = self.W_V_lig(ligand_feat).view([batch_size, lig_channel, self.num_heads, self.d])
V = self.W_V(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
R_ij = torch.cdist(res_X[row], ligand_pos[col], p=2)
attend_mask = (atom_mask[row].unsqueeze(-1) @ ligand_mask[col].unsqueeze(-2))
R_ij = R_ij * attend_mask
dist_rep1 = self.distance_expansion(R_ij).view(row.shape[0], res_X.shape[1], ligand_pos.shape[1], -1)
attend_logits = torch.matmul(Q[row].transpose(1, 2), K_lig[col].permute(0, 2, 3, 1)).view(
[n_edges, self.num_heads, n_channels, lig_channel])
attend_logits /= np.sqrt(self.d)
attend_logits = attend_logits + self.sigma_D1(dist_rep1).permute(0, 3, 1, 2)
attend_mask = attend_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
# sparse attention, only keep top k=3
        attend_logits[torch.logical_not(attend_mask)] = -1e5  # do not select entries that are masked out
_, top_indices = torch.topk(attend_logits, self.sparse_k, dim=-1, largest=True)
sparse_mask = torch.zeros_like(attend_logits, dtype=torch.bool)
rows = torch.arange(attend_logits.size(0)).view(-1, 1, 1, 1).expand(-1, attend_logits.size(1),
attend_logits.size(2), self.sparse_k)
depth = torch.arange(attend_logits.size(1)).view(1, -1, 1, 1).expand(attend_logits.size(0), -1,
attend_logits.size(2), self.sparse_k)
height = torch.arange(attend_logits.size(2)).view(1, 1, -1, 1).expand(attend_logits.size(0),
attend_logits.size(1), -1, self.sparse_k)
sparse_mask[rows, depth, height, top_indices] = True
attend_logits = attend_logits * attend_mask
attend_logits = attend_logits * sparse_mask
# calculate beta
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
ligand_mask_head = ligand_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
r_ij = atom_mask_head[row].unsqueeze(-2) @ attend_logits @ ligand_mask_head[col].unsqueeze(-1)
        r_ij = r_ij.squeeze() / (attend_mask * sparse_mask).sum(-1).sum(-1)  # average over non-zero entries
beta = scatter_softmax(r_ij, col, dim=0)
attend_logits_lig = torch.softmax(attend_logits, dim=-2) * attend_mask
attend_logits_lig = attend_logits_lig / (attend_logits_lig.norm(1, dim=-2).unsqueeze(-2) + 1e-7)
attend_logits = attend_logits * sparse_mask
attend_logits_res = torch.softmax(attend_logits, dim=-1)
attend_logits_res = attend_logits_res * attend_mask * sparse_mask
attend_logits_res = attend_logits_res / (attend_logits_res.norm(1, dim=-1).unsqueeze(-1) + 1e-7)
alpha_ij_Vj = attend_logits_res @ V_lig[col].transpose(1, 2)
# invariant feature update
res_H[residue_mask] = res_H[residue_mask] + self.W_O_lig(
alpha_ij_Vj.transpose(1, 2).reshape(residue_mask.sum(), n_channels, -1))
res_H = res_H * atom_mask.unsqueeze(-1) # set empty entries zeros
alpha_ij_Vj_lig = attend_logits_lig.transpose(-1, -2) @ V[row].transpose(1, 2)
update_lig_feat = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj_lig, col, dim=0,
dim_size=batch_size)
ligand_feat = ligand_feat + self.W_O_lig1(update_lig_feat.transpose(1, 2).reshape(batch_size, lig_channel, -1))
ligand_feat = ligand_feat * ligand_mask.unsqueeze(-1) # set empty entries zeros
###################################################################################
# equivariant res_X update
X_ij = res_X[row].unsqueeze(-2) - ligand_pos[col].unsqueeze(1)
X_ij = X_ij / (X_ij.norm(2, dim=-1).unsqueeze(-1) + 1e-5)
X_ij = X_ij.unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
X_ij = X_ij * attend_mask.unsqueeze(-1)
        dist_rep1 = self.sigma_v(dist_rep1).view([n_edges, n_channels, lig_channel, self.num_heads, self.d]).permute(0, 3, 1, 2, 4)
input = torch.cat(
[(Q[row].transpose(1, 2))[rows, depth, height], (K_lig[col].transpose(1, 2))[rows, depth, top_indices],
dist_rep1[rows, depth, height, top_indices]], dim=-1)
f_res = self.T_e1(input).view(n_edges, self.num_heads, n_channels, self.sparse_k)
f_res = f_res.unsqueeze(-1) * X_ij[rows, depth, height, top_indices].view(n_edges, self.num_heads, n_channels,
self.sparse_k, 3)
f_res = (f_res * attend_logits[rows, depth, height, top_indices].unsqueeze(-1)).sum(-2).sum(
1) # [nedges, 14, 3]
res_X[residue_mask] = res_X[residue_mask] + torch.clamp(f_res, min=-3.0, max=3.0)
res_X = res_X * atom_mask.unsqueeze(-1) # set empty entries zeros
p_idx, q_idx = torch.cartesian_prod(torch.arange(n_channels), torch.arange(lig_channel)).chunk(2, dim=-1)
p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
dist_rep1 = dist_rep1.permute(0, 2, 3, 1, 4)
f_lig = self.T_e2(torch.cat(
[Q[row][:, p_idx].view(-1, self.hidden_channels), K_lig[col][:, q_idx].view(-1, self.hidden_channels),
dist_rep1.view(-1, self.hidden_channels)], dim=-1)).view(n_edges, n_channels, lig_channel, self.num_heads)
# attend_mask = (atom_mask[row].unsqueeze(-1) @ ligand_mask[col].unsqueeze(-2)).bool()
f_lig = f_lig.permute(0, 3, 1, 2) * attend_logits_lig
f_lig = (f_lig.unsqueeze(-1) * X_ij).sum(2) # [n_edges, num_heads, lig_channel, 3]
ligand_pos = ligand_pos + scatter_sum((beta.unsqueeze(-1).unsqueeze(-1) * f_lig).sum(1), col, dim=0,
dim_size=batch_size)
ligand_pos = ligand_pos * ligand_mask.unsqueeze(-1) # set empty entries zeros
return res_H, res_X, ligand_pos, ligand_feat
def attention_res_ligand(self, res_H, res_X, atom_mask, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num,
residue_mask):
edge_index = self.connect_edges(res_X, batch)
row, col = edge_index
R_ij = torch.cdist(res_X[row], res_X[col], p=2)
dist_rep = self.distance_expansion(R_ij).view(row.shape[0], res_X.shape[1], res_X.shape[1], -1)
n_nodes = res_H.shape[0]
n_edges = edge_index.shape[1]
n_channels = res_X.shape[1]
Q = self.W_Q(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
K = self.W_K(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
V = self.W_V(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
Q_lig = self.W_Q_lig(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
attend_logits = torch.matmul(Q[row].transpose(1, 2), K[col].permute(0, 2, 3, 1)).view(
[n_edges, self.num_heads, n_channels, n_channels])
attend_logits /= np.sqrt(self.d)
attend_logits = attend_logits + self.sigma_D(dist_rep).permute(0, 3, 1, 2) # distance bias
attend_mask = (atom_mask[row].unsqueeze(-1) @ atom_mask[col].unsqueeze(-2)).unsqueeze(1).repeat(1,
self.num_heads,
1, 1)
# sparse attention, only keep top k=3
        attend_logits[torch.logical_not(attend_mask)] = -1e5  # do not select entries that are masked out
_, top_indices = torch.topk(attend_logits, self.sparse_k, dim=-1, largest=True)
sparse_mask = torch.zeros_like(attend_logits, dtype=torch.bool)
rows = torch.arange(attend_logits.size(0)).view(-1, 1, 1, 1).expand(-1, attend_logits.size(1),
attend_logits.size(2), self.sparse_k)
depth = torch.arange(attend_logits.size(1)).view(1, -1, 1, 1).expand(attend_logits.size(0), -1,
attend_logits.size(2), self.sparse_k)
height = torch.arange(attend_logits.size(2)).view(1, 1, -1, 1).expand(attend_logits.size(0),
attend_logits.size(1), -1, self.sparse_k)
sparse_mask[rows, depth, height, top_indices] = True
attend_logits = attend_logits * sparse_mask
attend_logits = attend_logits * attend_mask
# calculate beta
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
r_ij = atom_mask_head[row].unsqueeze(-2) @ attend_logits @ atom_mask_head[col].unsqueeze(-1)
        r_ij = r_ij.squeeze() / (attend_mask * sparse_mask).sum(-1).sum(-1)  # average over non-zero entries
beta = scatter_softmax(r_ij, row, dim=0) # [nnodes, num_heads]
attend_logits = torch.softmax(attend_logits, dim=-1)
attend_logits = attend_logits * attend_mask * sparse_mask
        attend_logits = attend_logits / (attend_logits.norm(1, dim=-1).unsqueeze(-1) + 1e-7)  # normalize each row
alpha_ij_Vj = attend_logits @ V[col].transpose(1, 2) # [nedges, num_heads, 14, d]
alpha_ij_Vj = self.block_mlp_invariant(alpha_ij_Vj)
# invariant res_H update
update_H = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj, row, dim=0, dim_size=n_nodes)
update_H = update_H.transpose(1, 2).reshape(n_nodes, n_channels, -1) # (nnodes, 14, heads*d)
res_H = res_H + self.W_O(update_H)
res_H = res_H * atom_mask.unsqueeze(-1) # set empty entry zeros
###################################################################################
# equivariant res_X update
X_ij = res_X[row].unsqueeze(-2) - res_X[col].unsqueeze(1) # [nedges, 14, 14, 3]
X_ij = X_ij / (X_ij.norm(2, dim=-1).unsqueeze(-1) + 1e-5)
X_ij = X_ij.unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
X_ij = X_ij * attend_mask.unsqueeze(-1)
        dist_rep = self.sigma_v(dist_rep).view([n_edges, n_channels, n_channels, self.num_heads, self.d]).permute(0, 3, 1, 2, 4)
input = torch.cat(
[(Q[row].transpose(1, 2))[rows, depth, height], (K[col].transpose(1, 2))[rows, depth, top_indices],
dist_rep[rows, depth, height, top_indices]], dim=-1)
f = self.T_i(input).view(n_edges, self.num_heads, n_channels, self.sparse_k)
f = f.unsqueeze(-1) * X_ij[rows, depth, height, top_indices].view(n_edges, self.num_heads, n_channels,
self.sparse_k, 3)
f = (f * attend_logits[rows, depth, height, top_indices].unsqueeze(-1)).sum(-2) # [nedges, num_heads, 14, 3]
res_X[residue_mask] = res_X[residue_mask] + torch.clamp(
scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * f, row, dim=0, dim_size=n_nodes).sum(1)[residue_mask],
min=-3.0, max=3.0)
res_X = res_X * atom_mask.unsqueeze(-1) # set empty entries zeros
###############################################################################################################
batch_size = batch.max().item() + 1
n_nodes = res_H.shape[0]
n_channels = res_X.shape[1]
lig_channel = ligand_feat.shape[1]
row = torch.arange(n_nodes).to(self.device)[residue_mask]
col = torch.repeat_interleave(torch.arange(batch_size).to(self.device), edit_residue_num)
n_edges = row.shape[0]
#Q_lig = self.W_Q_lig(res_H).view([n_nodes, n_channels, self.num_heads, self.d])
#Q_lig = Q
K_lig = self.W_K_lig(ligand_feat).view([batch_size, lig_channel, self.num_heads, self.d])
V_lig = self.W_V_lig(ligand_feat).view([batch_size, lig_channel, self.num_heads, self.d])
R_ij = torch.cdist(res_X[row], ligand_pos[col], p=2)
attend_mask = (atom_mask[row].unsqueeze(-1) @ ligand_mask[col].unsqueeze(-2))
R_ij = R_ij * attend_mask
dist_rep1 = self.distance_expansion(R_ij).view(row.shape[0], res_X.shape[1], ligand_pos.shape[1], -1)
attend_logits = torch.matmul(Q_lig[row].transpose(1, 2), K_lig[col].permute(0, 2, 3, 1)).view(
[n_edges, self.num_heads, n_channels, lig_channel])
attend_logits /= np.sqrt(self.d)
attend_logits = attend_logits + self.sigma_D1(dist_rep1).permute(0, 3, 1, 2)
attend_mask = attend_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
# sparse attention, only keep top k=3
        attend_logits[torch.logical_not(attend_mask)] = -1e5  # do not select entries that are masked out
_, top_indices = torch.topk(attend_logits, self.sparse_k, dim=-1, largest=True)
sparse_mask = torch.zeros_like(attend_logits, dtype=torch.bool)
rows = torch.arange(attend_logits.size(0)).view(-1, 1, 1, 1).expand(-1, attend_logits.size(1),
attend_logits.size(2), self.sparse_k)
depth = torch.arange(attend_logits.size(1)).view(1, -1, 1, 1).expand(attend_logits.size(0), -1,
attend_logits.size(2), self.sparse_k)
height = torch.arange(attend_logits.size(2)).view(1, 1, -1, 1).expand(attend_logits.size(0),
attend_logits.size(1), -1, self.sparse_k)
sparse_mask[rows, depth, height, top_indices] = True
attend_logits = attend_logits * attend_mask
attend_logits = attend_logits * sparse_mask
        return_attend = attend_logits.detach().clone()  # snapshot of the residue-ligand attention for downstream use
# calculate beta
atom_mask_head = atom_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
ligand_mask_head = ligand_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
r_ij = atom_mask_head[row].unsqueeze(-2) @ attend_logits @ ligand_mask_head[col].unsqueeze(-1)
        r_ij = r_ij.squeeze() / (attend_mask * sparse_mask).sum(-1).sum(-1)  # average over non-zero entries
beta = scatter_softmax(r_ij, col, dim=0)
attend_logits_lig = torch.softmax(attend_logits, dim=-2) * attend_mask
attend_logits_lig = attend_logits_lig / (attend_logits_lig.norm(1, dim=-2).unsqueeze(-2) + 1e-7)
attend_logits = attend_logits * sparse_mask
attend_logits_res = torch.softmax(attend_logits, dim=-1)
attend_logits_res = attend_logits_res * attend_mask * sparse_mask
attend_logits_res = attend_logits_res / (attend_logits_res.norm(1, dim=-1).unsqueeze(-1) + 1e-7)
alpha_ij_Vj = attend_logits_res @ V_lig[col].transpose(1, 2)
# invariant feature update
res_H[residue_mask] = res_H[residue_mask] + self.W_O_lig(
alpha_ij_Vj.transpose(1, 2).reshape(residue_mask.sum(), n_channels, -1))
res_H = res_H * atom_mask.unsqueeze(-1) # set empty entries zeros
alpha_ij_Vj_lig = attend_logits_lig.transpose(-1, -2) @ V[row].transpose(1, 2)
update_lig_feat = scatter_sum(beta.unsqueeze(-1).unsqueeze(-1) * alpha_ij_Vj_lig, col, dim=0,
dim_size=batch_size)
ligand_feat = ligand_feat + self.W_O_lig1(update_lig_feat.transpose(1, 2).reshape(batch_size, lig_channel, -1))
ligand_feat = ligand_feat * ligand_mask.unsqueeze(-1) # set empty entries zeros
###################################################################################
# equivariant res_X update
X_ij = res_X[row].unsqueeze(-2) - ligand_pos[col].unsqueeze(1)
X_ij = X_ij / (X_ij.norm(2, dim=-1).unsqueeze(-1) + 1e-5)
X_ij = X_ij.unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
X_ij = X_ij * attend_mask.unsqueeze(-1)
        dist_rep1 = self.sigma_v(dist_rep1).view([n_edges, n_channels, lig_channel, self.num_heads, self.d]).permute(0, 3, 1, 2, 4)
input = torch.cat(
[(Q_lig[row].transpose(1, 2))[rows, depth, height], (K_lig[col].transpose(1, 2))[rows, depth, top_indices],
dist_rep1[rows, depth, height, top_indices]], dim=-1)
f_res = self.T_e1(input).view(n_edges, self.num_heads, n_channels, self.sparse_k)
f_res = f_res.unsqueeze(-1) * X_ij[rows, depth, height, top_indices].view(n_edges, self.num_heads, n_channels,
self.sparse_k, 3)
f_res = (f_res * attend_logits[rows, depth, height, top_indices].unsqueeze(-1)).sum(-2).sum(
1) # [nedges, 14, 3]
res_X[residue_mask] = res_X[residue_mask] + torch.clamp(f_res, min=-3.0, max=3.0)
res_X = res_X * atom_mask.unsqueeze(-1) # set empty entries zeros
p_idx, q_idx = torch.cartesian_prod(torch.arange(n_channels), torch.arange(lig_channel)).chunk(2, dim=-1)
p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
dist_rep1 = dist_rep1.permute(0, 2, 3, 1, 4)
f_lig = self.T_e2(torch.cat(
[Q_lig[row][:, p_idx].view(-1, self.hidden_channels), K_lig[col][:, q_idx].view(-1, self.hidden_channels),
dist_rep1.view(-1, self.hidden_channels)], dim=-1)).view(n_edges, n_channels, lig_channel, self.num_heads)
# attend_mask = (atom_mask[row].unsqueeze(-1) @ ligand_mask[col].unsqueeze(-2)).bool()
f_lig = f_lig.permute(0, 3, 1, 2) * attend_logits_lig
f_lig = (f_lig.unsqueeze(-1) * X_ij).sum(2) # [n_edges, num_heads, lig_channel, 3]
ligand_pos = ligand_pos + scatter_sum((beta.unsqueeze(-1).unsqueeze(-1) * f_lig).sum(1), col, dim=0,
dim_size=batch_size)
ligand_pos = ligand_pos * ligand_mask.unsqueeze(-1) # set empty entries zeros
return res_H, res_X, ligand_pos, ligand_feat, return_attend
def forward(self, res_H, res_X, atom_mask, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num, residue_mask):
#res_H, res_X = self.attention(res_H, res_X, atom_mask, batch, residue_mask)
res_H, res_X, ligand_pos, ligand_feat, return_attend = self.attention_res_ligand(res_H, res_X, atom_mask, batch, ligand_pos,
ligand_feat, ligand_mask, edit_residue_num,
residue_mask)
pred_res_type = self.residue_mlp(res_H[residue_mask].sum(-2))
return res_H, res_X, ligand_pos, ligand_feat, pred_res_type, return_attend
class EquivariantFFN(nn.Module):
def __init__(self, d_in, d_hidden, d_out, n_channel=1, n_rbf=16, act_fn=nn.SiLU(),
residual=True, dropout=0.1, constant=1, z_requires_grad=True) -> None:
super().__init__()
self.constant = constant
self.residual = residual
self.n_rbf = n_rbf
self.mlp_h = nn.Sequential(
nn.Linear(d_in * 2 + n_channel * n_rbf, d_hidden),
act_fn,
nn.Dropout(dropout),
nn.Linear(d_hidden, d_hidden),
act_fn,
nn.Dropout(dropout),
nn.Linear(d_hidden, d_out),
nn.Dropout(dropout)
)
self.mlp_z = nn.Sequential(
nn.Linear(d_in * 2 + n_channel * n_rbf, d_hidden),
act_fn,
nn.Dropout(dropout),
nn.Linear(d_hidden, d_hidden),
act_fn,
nn.Dropout(dropout),
nn.Linear(d_hidden, 1),
nn.Dropout(dropout)
)
if not z_requires_grad:
for param in self.mlp_z.parameters():
param.requires_grad = False
self.rbf = RadialBasis(n_rbf, 7.0)
def forward(self, H, X, atom_mask, residue_mask):
        '''
        :param H: atom features, [N, 14, d_in]
        :param X: atom coordinates, [N, 14, 3]
        :param atom_mask: valid-atom mask, [N, 14]
        :param residue_mask: residues whose coordinates get updated, [N]
        '''
        radial, (X_c, X_o) = self._radial(X, atom_mask)  # [N, 14, n_rbf], ([N, 3], [N, 14, 3])
# H_c = scatter_mean(H, block_id, dim=0)[block_id] # [N, d_in]
H_c = (H * atom_mask.unsqueeze(-1)).sum(-2) / atom_mask.sum(-1).unsqueeze(-1)
H_c = H_c.unsqueeze(-2).repeat(1, 14, 1)
inputs = torch.cat([H, H_c, radial], dim=-1) # [N, 14, d_in + d_in + n_rbf]
H_update = self.mlp_h(inputs)
H = H + H_update if self.residual else H_update
X_update = X_c.unsqueeze(-2) + self.mlp_z(inputs) * X_o
X[residue_mask] = X_update[residue_mask]
H, X = H * atom_mask.unsqueeze(-1), X * atom_mask.unsqueeze(-1)
return H, X
def _radial(self, X, atom_mask):
X_c = (X * atom_mask.unsqueeze(-1)).sum(-2) / atom_mask.sum(-1).unsqueeze(-1) # center
X_o = X - X_c.unsqueeze(-2) # [N, 14, 3], no translation
X_o = X_o * atom_mask.unsqueeze(-1)
D = stable_norm(X_o, dim=-1) # [N, 14]
radial = self.rbf(D.view(-1)).view(D.shape[0], D.shape[1], -1) # [N, 14, n_rbf]
return radial, (X_c, X_o)
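# A minimal, commented-out usage sketch for EquivariantFFN (toy shapes are assumptions;
# reuses the module-level residue_atom_mask defined above):
#   ffn = EquivariantFFN(d_in=128, d_hidden=256, d_out=128)
#   H = torch.randn(5, 14, 128)
#   X = torch.randn(5, 14, 3)
#   atom_mask = residue_atom_mask[torch.zeros(5, dtype=torch.long)]   # (5, 14)
#   residue_mask = torch.ones(5, dtype=torch.bool)
#   H, X = ffn(H, X, atom_mask, residue_mask)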
class InvariantLayerNorm(nn.Module):
def __init__(self, d_hidden) -> None:
super().__init__()
self.layernorm = nn.LayerNorm(d_hidden)
self.layernorm1 = nn.LayerNorm(d_hidden)
def forward(self, res_H, ligand_feat, atom_mask, ligand_mask):
res_H[atom_mask.bool()] = self.layernorm(res_H[atom_mask.bool()])
ligand_feat[ligand_mask.bool()] = self.layernorm1(ligand_feat[ligand_mask.bool()])
return res_H, ligand_feat
class GET(nn.Module):
'''Equivariant Adaptive Block Transformer'''
def __init__(self, hidden_channels=128, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=8,
cutoff=10.0, device='cuda:0', n_layers=4, pre_norm=False, sparse_k=3):
super().__init__()
        '''
        :param hidden_channels: number of hidden features
        :param edge_channels: number of features in the Gaussian distance expansion
        :param key_channels: number of key features (kept for interface compatibility)
        :param num_heads: number of attention heads
        :param k: number of nearest-neighbour residues used to build the residue graph
        :param cutoff: cutoff of the Gaussian distance expansion
        :param device: device the residue_atom_mask lookup table is placed on
        :param n_layers: number of GETLayer blocks
        :param pre_norm: apply invariant layer norm before the first block
        :param sparse_k: number of entries kept per query in the sparse attention
        '''
self.n_layers = n_layers
self.pre_norm = pre_norm
self.sparse_k = sparse_k
self.residue_atom_mask = residue_atom_mask.to(device)
if self.pre_norm:
self.pre_layernorm = InvariantLayerNorm(hidden_channels)
for i in range(0, n_layers):
self.add_module(f'layer_{i}',
GETLayer(hidden_channels, edge_channels, key_channels, num_heads, num_interactions, k,
cutoff, device))
self.add_module(f'layernorm0_{i}', InvariantLayerNorm(hidden_channels))
self.add_module(f'ffn_{i}', EquivariantFFN(
hidden_channels, 2 * hidden_channels, hidden_channels,
))
self.add_module(f'layernorm1_{i}', InvariantLayerNorm(hidden_channels))
def forward(self, res_H, res_X, res_S, batch, ligand_pos, ligand_feat, ligand_mask, edit_residue_num, residue_mask):
atom_mask = self.residue_atom_mask[res_S]
if self.pre_norm:
res_H, ligand_feat = self.pre_layernorm(res_H, ligand_feat, atom_mask, ligand_mask)
for i in range(self.n_layers):
res_H, res_X, ligand_pos, ligand_feat, pred_res_type, attend = self._modules[f'layer_{i}'](res_H, res_X, atom_mask,
batch, ligand_pos,
ligand_feat, ligand_mask,
edit_residue_num,
residue_mask)
res_H, ligand_feat = self._modules[f'layernorm0_{i}'](res_H, ligand_feat, atom_mask, ligand_mask)
res_H, res_X = self._modules[f'ffn_{i}'](res_H, res_X, atom_mask, residue_mask)
res_H, ligand_feat = self._modules[f'layernorm1_{i}'](res_H, ligand_feat, atom_mask, ligand_mask)
return res_H, res_X, ligand_pos, ligand_feat, pred_res_type, attend
def stable_norm(input, *args, **kwargs):
    # NOTE: the early return below bypasses the clamped variant, so plain torch.norm is used;
    # the remaining lines (which clamp magnitudes away from zero before the norm) are kept for reference.
    return torch.norm(input, *args, **kwargs)
    input = input.clone()
    with torch.no_grad():
        sign = torch.sign(input)
        input = torch.abs(input)
        input.clamp_(min=1e-10)
        input = sign * input
    return torch.norm(input, *args, **kwargs)
if __name__ == '__main__':
hidden_channels = 128
edge_channels = 64
key_channels = 128
num_heads = 4
device = torch.device('cuda:0')
model = GET()
model.to(device)
model.eval()
res_H = torch.rand(10, 14, hidden_channels).to(device)
res_X = torch.rand(10, 14, 3).to(device)
res_S = torch.ones(10, dtype=torch.long)
atom_mask = residue_atom_mask[res_S].to(device)
ligand_mask = torch.tensor([[1., 1, 1, 0, 0], [1, 1, 1, 1, 1]]).to(device)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], device=device)
residue_mask = torch.ones(10, device=device).bool()
edit_residue_num = torch.tensor([3, 7], device=device)
ligand_pos, ligand_feat = torch.rand(2, 5, 3).to(device), torch.rand(2, 5, hidden_channels).to(device)
ligand_pos = ligand_pos * ligand_mask.unsqueeze(-1)
res_X = res_X * atom_mask.unsqueeze(-1)
U, _, V = torch.linalg.svd(torch.randn(3, 3, device=device, dtype=torch.float))
if torch.linalg.det(U) * torch.linalg.det(V) < 0:
U[:, -1] = -U[:, -1]
Q1, t1 = U.mm(V), torch.randn(3, device=device)
U, _, V = torch.linalg.svd(torch.randn(3, 3, device=device, dtype=torch.float))
if torch.linalg.det(U) * torch.linalg.det(V) < 0:
U[:, -1] = -U[:, -1]
Q2, t2 = U.mm(V), torch.randn(3, device=device)
    res_H1, res_X1, ligand_pos1, ligand_feat1, pred_res_type1, attend1 = model(
        copy.deepcopy(res_H), res_X, res_S, batch, ligand_pos, copy.deepcopy(ligand_feat),
        ligand_mask, edit_residue_num, residue_mask)
res_X1 = copy.deepcopy(res_X1.detach())
res_X1[batch == 0] = torch.matmul(res_X1[batch == 0], Q1) + t1
res_X1[batch == 1] = torch.matmul(res_X1[batch == 1], Q2) + t2
res_X1 = res_X1 * atom_mask.unsqueeze(-1)
ligand_pos1[0] = torch.matmul(ligand_pos1[0], Q1) + t1
ligand_pos1[1] = torch.matmul(ligand_pos1[1], Q2) + t2
ligand_pos1 = ligand_pos1 * ligand_mask.unsqueeze(-1)
res_X[batch == 0] = torch.matmul(res_X[batch == 0], Q1) + t1
res_X[batch == 1] = torch.matmul(res_X[batch == 1], Q2) + t2
res_X = res_X * atom_mask.unsqueeze(-1)
ligand_pos[0] = torch.matmul(ligand_pos[0], Q1) + t1
ligand_pos[1] = torch.matmul(ligand_pos[1], Q2) + t2
ligand_pos = ligand_pos * ligand_mask.unsqueeze(-1)
    res_H2, res_X2, ligand_pos2, ligand_feat2, pred_res_type2, attend2 = model(
        copy.deepcopy(res_H), res_X, res_S, batch, ligand_pos, copy.deepcopy(ligand_feat),
        ligand_mask, edit_residue_num, residue_mask)
print((res_X1 - res_X2).norm())
print((res_H1 - res_H2).float().norm())
print((pred_res_type1 - pred_res_type2).norm())
print((ligand_pos1 - ligand_pos2).norm())
print((ligand_feat1 - ligand_feat2).norm())