import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
from torch_scatter import scatter_mean, scatter_add


def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Split a batched tensor into per-graph pieces.

    Args:
        x: (N, ...) per-node tensor.
        batch: (N, ) graph index of every row of ``x``.
        num_graphs: number of pieces to return; inferred as ``batch.max() + 1`` when None.
    Returns:
        List of ``num_graphs`` tensors [(N_1, ...), (N_2, ...), ..., (N_B, ...)], each
        keeping the original relative order of its rows.
    """
    if num_graphs is None:
        num_graphs = batch.max().item() + 1
    return [x[batch == i] for i in range(num_graphs)]


def concat_tensors_to_batch(x_split):
    """
    Concatenate per-graph tensors and build the matching batch vector
    (the inverse of ``split_tensor_by_batch``).

    Args:
        x_split: non-empty sequence of tensors [(N_1, ...), ..., (N_B, ...)].
    Returns:
        x: (N_1 + ... + N_B, ...) concatenation of ``x_split`` along dim 0.
        batch: (N_1 + ... + N_B, ) long tensor on ``x.device``; rows that came from
            ``x_split[i]`` get value ``i``.
    """
    x = torch.cat(x_split, dim=0)
    sizes = torch.tensor([s.size(0) for s in x_split], dtype=torch.long, device=x.device)
    batch = torch.repeat_interleave(
        torch.arange(len(x_split), dtype=torch.long, device=x.device),
        repeats=sizes,
    )
    return x, batch


def split_tensor_to_segments(x, segsize):
    """Split ``x`` along dim 0 into consecutive slices of ``segsize`` rows (the last may be shorter)."""
    num_segs = math.ceil(x.size(0) / segsize)
    return [x[i * segsize:(i + 1) * segsize] for i in range(num_segs)]


def split_tensor_by_lengths(x, lengths):
    """
    Split ``x`` along dim 0 into consecutive slices whose sizes are given by ``lengths``.

    Rows past ``sum(lengths)`` are dropped. If ``lengths`` requests more rows than ``x``
    has, the trailing slices are shorter (possibly empty).
    """
    segs = []
    for length in lengths:
        segs.append(x[:length])
        x = x[length:]
    return segs


def batch_intersection_mask(batch, batch_filter):
    """
    Return a bool mask with the length of ``batch`` that is True where the entry of
    ``batch`` occurs anywhere in ``batch_filter``.
    """
    batch_filter = batch_filter.unique()
    mask = (batch.view(-1, 1) == batch_filter.view(1, -1)).any(dim=1)
    return mask


class MeanReadout(nn.Module):
    """Mean readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).

        Parameters:
            input (Tensor): (N, ...) node representations
            batch (Tensor): (N, ) graph index of every node
            num_graphs (int): number of graphs, i.e. size of the output's first dimension

        Returns:
            Tensor: (num_graphs, ...) per-graph mean of the node representations
        """
        output = scatter_mean(input, batch, dim=0, dim_size=num_graphs)
        return output


class SumReadout(nn.Module):
    """Sum readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).

        Parameters:
            input (Tensor): (N, ...) node representations
            batch (Tensor): (N, ) graph index of every node
            num_graphs (int): number of graphs, i.e. size of the output's first dimension

        Returns:
            Tensor: (num_graphs, ...) per-graph sum of the node representations
        """
        output = scatter_add(input, batch, dim=0, dim_size=num_graphs)
        return output


class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.
    Note there is no activation or dropout in the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dims (sequence of int): output dimension of each linear layer; the last
            entry is the output dimension of the whole network
        activation (str or callable, optional): name of a function in
            ``torch.nn.functional`` (e.g. ``"relu"``) or any callable; ``None`` disables it
        dropout (float, optional): dropout rate applied after every hidden activation
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        # list() so tuples and other sequences are accepted as well as lists.
        self.dims = [input_dim] + list(hidden_dims)
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            # Keep user-supplied callables (functions or nn.Modules); None disables it.
            self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(self.dims[:-1], self.dims[1:])
        )

    def forward(self, input):
        """Apply the linear layers, with activation/dropout between (not after) them."""
        x = input
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                if self.activation is not None:
                    x = self.activation(x)
                if self.dropout is not None:
                    x = self.dropout(x)
        return x


class SmoothCrossEntropyLoss(_WeightedLoss):
    """
    Cross-entropy against label-smoothed one-hot targets.

    The true class receives ``1 - smoothing`` and every other class receives
    ``smoothing / (n_classes - 1)``.

    Args:
        weight: optional per-class weights, multiplied into the log-probabilities.
        reduction: 'mean' or 'sum'; any other value returns the per-sample loss.
        smoothing: smoothing factor in [0, 1).
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
        """Turn class indices ``targets`` of shape (N,) into (N, n_classes) smoothed one-hot rows."""
        assert 0 <= smoothing < 1
        # With a single class there is nothing to spread the smoothing mass onto;
        # avoid the division by zero.
        off_value = smoothing / (n_classes - 1) if n_classes > 1 else 0.0
        with torch.no_grad():
            smoothed = torch.empty(size=(targets.size(0), n_classes), device=targets.device) \
                .fill_(off_value) \
                .scatter_(1, targets.unsqueeze(1), 1. - smoothing)
        return smoothed

    def forward(self, inputs, targets):
        """
        Args:
            inputs: logits; the last dimension is the class dimension (presumably (N, C)).
            targets: (N, ) class indices.
        """
        targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), self.smoothing)
        lsm = F.log_softmax(inputs, -1)

        if self.weight is not None:
            lsm = lsm * self.weight.unsqueeze(0)

        loss = -(targets * lsm).sum(-1)

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss


class GaussianSmearing(nn.Module):
    """
    Expand scalar distances onto ``num_gaussians`` Gaussian basis functions whose
    centers are evenly spaced on [start, stop]. The width is derived from the
    center spacing, so ``num_gaussians`` must be at least 2.
    """

    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        """Return (dist.numel(), num_gaussians) = exp(coeff * (dist - offset)^2)."""
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(nn.Module):
    """softplus(x) - log(2), so that the output is 0 at x = 0."""

    def __init__(self):
        super().__init__()
        # Computed via float32 torch.log to keep the exact original constant.
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift


def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes into one context, grouped by graph index.

    Ordering is deterministic: within each graph, protein nodes come first, then
    ligand nodes, each in their original order (same as ``compose_context_stable``).

    Returns:
        h_ctx: (N_protein + N_ligand, ...) features.
        pos_ctx: (N_protein + N_ligand, ...) positions.
        batch_ctx: (N_protein + N_ligand, ) graph indices, sorted ascending.
    """
    h_ctx, pos_ctx, batch_ctx, _ = compose_context_stable(
        h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand
    )
    return h_ctx, pos_ctx, batch_ctx


def get_complete_graph(batch):
    """
    Build fully connected directed edges (without self-loops) inside every graph.

    Args:
        batch: (N, ) graph index of every node. Nodes of graph i are assumed to occupy
            the contiguous index range [N_0 + ... + N_{i-1}, N_0 + ... + N_i), because
            node offsets are computed as a cumulative sum of per-graph node counts.
    Returns:
        edge_index: (2, sum_i N_i * (N_i - 1)) edges; for each graph the (source, target)
            pairs are enumerated in row-major order, skipping source == target.
        num_edges: (B, ) number of edges per graph, N_i^2 - N_i.
    """
    natoms = scatter_add(torch.ones_like(batch), index=batch, dim=0)

    natoms_sqr = (natoms ** 2).long()
    num_atom_pairs = torch.sum(natoms_sqr)
    natoms_expand = torch.repeat_interleave(natoms, natoms_sqr)

    # First node index of each graph, repeated for each of its N_i^2 candidate pairs.
    index_offset = torch.cumsum(natoms, dim=0) - natoms
    index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr)

    # Position of each pair inside its own graph's N_i x N_i grid.
    index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr)

    atom_count_sqr = torch.arange(num_atom_pairs, device=num_atom_pairs.device) - index_sqr_offset

    index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand
    index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand
    edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])
    mask = torch.logical_not(index1 == index2)
    edge_index = edge_index[:, mask]

    num_edges = natoms_sqr - natoms  # Number of edges per graph

    return edge_index, num_edges


def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes into one context, grouped by graph index.

    Within each graph the protein nodes come first, then the ligand nodes, each in
    their original relative order.

    Returns:
        h_ctx: (N_protein + N_ligand, ...) features.
        pos_ctx: (N_protein + N_ligand, ...) positions.
        batch_ctx: (N_protein + N_ligand, ) graph indices, sorted ascending.
        mask_protein: (N_protein + N_ligand, ) bool, True for nodes that came from the protein.
    """
    n_protein = batch_protein.size(0)
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    n_ctx = batch_ctx.size(0)

    # Sort by (graph index, position in the concatenation). The composite key is unique,
    # so the ordering is fully determined (a stable sort by graph index): protein nodes,
    # which precede ligand nodes in the concatenation, come first within each graph.
    position = torch.arange(n_ctx, device=batch_ctx.device)
    sort_idx = torch.argsort(batch_ctx.long() * n_ctx + position)

    mask_protein = sort_idx < n_protein
    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]

    return h_ctx, pos_ctx, batch_ctx, mask_protein


def compose_external_attention(batch_protein, batch_ligand, edit_protein_mask):
    """
    Enumerate (protein node, ligand node) index pairs that belong to the same graph,
    keeping only protein nodes selected by ``edit_protein_mask``.

    Args:
        batch_protein: (N_protein, ) graph index of each protein node.
        batch_ligand: (N_ligand, ) graph index of each ligand node.
        edit_protein_mask: (N_protein, ) mask used to index protein nodes
            (presumably bool — TODO confirm against callers).
    Returns:
        [row, col]: two (E, ) tensors; ``row`` indexes protein nodes and ``col`` ligand
            nodes. Pairs are ordered by graph, then protein index, then ligand index.
    """
    device = batch_protein.device
    num_graphs = batch_protein.max().item() + 1
    protein_index = torch.arange(len(batch_protein), device=device)
    ligand_index = torch.arange(len(batch_ligand), device=device)

    row, col = [], []
    for i in range(num_graphs):
        mask_p, mask_l = (batch_protein == i), (batch_ligand == i)
        selected_protein = protein_index[mask_p][edit_protein_mask[mask_p]]
        pairs = torch.cartesian_prod(selected_protein, ligand_index[mask_l])  # (k, 2)
        row.append(pairs[:, 0])
        col.append(pairs[:, 1])

    row = torch.cat(row, dim=0)
    col = torch.cat(col, dim=0)
    return [row, col]


def pos2X(pos, mask_protein, atom2residue, max_atoms=14):
    """
    Scatter protein atom positions into a zero-padded per-residue tensor.

    Args:
        pos: positions of all context nodes; each row must fit a 3-vector slot.
        mask_protein: mask selecting the protein rows of ``pos``.
        atom2residue: residue index of each selected protein atom.
        max_atoms: number of atom slots per residue (default 14). A residue with more
            atoms than this raises a shape-mismatch error.
    Returns:
        X: (R, max_atoms, 3) with R = ``atom2residue.max() + 1``; the atoms of residue r
            fill ``X[r, :n_r]`` in their original order, other slots stay zero.
    """
    num_residues = atom2residue.max().item() + 1
    X = torch.zeros(num_residues, max_atoms, 3, device=mask_protein.device)
    pos_protein = pos[mask_protein]
    for i in range(num_residues):
        residue_pos = pos_protein[atom2residue == i]
        X[i, :len(residue_pos)] = residue_pos
    return X


def X2pos(X):
    """
    Gather the non-padding atom positions of a padded per-residue tensor
    (the inverse of ``pos2X``).

    Entries whose Euclidean norm is <= 1e-6 are treated as padding and dropped, so an
    atom located exactly at the origin does not survive a pos2X -> X2pos round trip.

    Args:
        X: (R, A, 3) tensor, or a sequence of per-residue (A_r, 3) tensors.
    Returns:
        (K, 3) positions in residue-major, slot-minor order.
    """
    if torch.is_tensor(X):
        # Boolean indexing flattens in row-major order, matching a per-residue concat.
        return X[torch.norm(X, dim=-1) > 1e-6]
    return torch.cat([x[torch.norm(x, dim=-1) > 1e-6] for x in X], dim=0)


if __name__ == '__main__':
    h_protein = torch.randn([60, 64])
    h_ligand = -torch.randn([33, 64])
    pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
    pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
    batch_protein = torch.LongTensor([0]*10 + [1]*20 + [2]*30)
    batch_ligand = torch.LongTensor([0]*11 + [1]*11 + [2]*11)

    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(
        h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand
    )

    assert (batch_ctx[mask_protein] == batch_protein).all()
    assert (batch_ctx[torch.logical_not(mask_protein)] == batch_ligand).all()

    assert torch.allclose(h_ctx[torch.logical_not(mask_protein)], h_ligand)
    assert torch.allclose(h_ctx[mask_protein], h_protein)

    assert torch.allclose(pos_ctx[torch.logical_not(mask_protein)], pos_ligand)
    assert torch.allclose(pos_ctx[mask_protein], pos_protein)