import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
from torch_scatter import scatter_mean, scatter_add

def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Args:
        x:      (N, ...) tensor of per-node values.
        batch:  (N, )    graph index of each node, taking values in [0, B).
    Returns:
        List of B tensors [(N_1, ...), (N_2, ...), ..., (N_B, ...)].
    """
    if num_graphs is None:
        num_graphs = batch.max().item() + 1
    x_split = []
    for i in range(num_graphs):
        mask = batch == i
        x_split.append(x[mask])
    return x_split
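
# Usage sketch (illustrative, not part of the original module): splitting a flat
# per-node tensor into per-graph chunks, assuming `batch` follows the usual
# PyTorch Geometric convention of one graph index per node.
#
#   >>> x = torch.randn(5, 4)
#   >>> batch = torch.LongTensor([0, 0, 1, 1, 1])
#   >>> [part.size(0) for part in split_tensor_by_batch(x, batch)]
#   [2, 3]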

def concat_tensors_to_batch(x_split):
    """Inverse of `split_tensor_by_batch`: concatenate a list of per-graph
    tensors and rebuild the corresponding batch index vector."""
    x = torch.cat(x_split, dim=0)
    batch = torch.repeat_interleave(
        torch.arange(len(x_split)),
        repeats=torch.LongTensor([s.size(0) for s in x_split])
    ).to(device=x.device)
    return x, batch
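
# Usage sketch (illustrative): this inverts `split_tensor_by_batch` whenever the
# original batch indices were sorted.
#
#   >>> parts = [torch.zeros(2, 3), torch.ones(4, 3)]
#   >>> x, batch = concat_tensors_to_batch(parts)
#   >>> x.shape, batch.tolist()
#   (torch.Size([6, 3]), [0, 0, 1, 1, 1, 1])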

def split_tensor_to_segments(x, segsize):
    """Split `x` along dim 0 into consecutive segments of at most `segsize` rows."""
    num_segs = math.ceil(x.size(0) / segsize)
    segs = []
    for i in range(num_segs):
        segs.append(x[i * segsize : (i + 1) * segsize])
    return segs


def split_tensor_by_lengths(x, lengths):
    """Split `x` along dim 0 into consecutive segments with the given lengths."""
    segs = []
    for l in lengths:
        segs.append(x[:l])
        x = x[l:]
    return segs
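
# Illustrative examples for the two splitting helpers above:
#
#   >>> x = torch.arange(7)
#   >>> [seg.tolist() for seg in split_tensor_to_segments(x, 3)]
#   [[0, 1, 2], [3, 4, 5], [6]]
#   >>> [seg.tolist() for seg in split_tensor_by_lengths(x, [2, 5])]
#   [[0, 1], [2, 3, 4, 5, 6]]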

def batch_intersection_mask(batch, batch_filter):
    """Boolean mask over `batch` selecting the elements whose graph index
    appears in `batch_filter`."""
    batch_filter = batch_filter.unique()
    mask = (batch.view(-1, 1) == batch_filter.view(1, -1)).any(dim=1)
    return mask
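
# Illustrative example:
#
#   >>> batch = torch.LongTensor([0, 0, 1, 2, 2])
#   >>> batch_intersection_mask(batch, torch.LongTensor([0, 2])).tolist()
#   [True, True, False, True, True]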

class MeanReadout(nn.Module):
    """Mean readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).
        Parameters:
            input (Tensor): node representations, shape (N, F)
            batch (LongTensor): graph index of each node, shape (N, )
            num_graphs (int): number of graphs B
        Returns:
            Tensor: graph representations, shape (B, F)
        """
        output = scatter_mean(input, batch, dim=0, dim_size=num_graphs)
        return output


class SumReadout(nn.Module):
    """Sum readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).
        Parameters:
            input (Tensor): node representations, shape (N, F)
            batch (LongTensor): graph index of each node, shape (N, )
            num_graphs (int): number of graphs B
        Returns:
            Tensor: graph representations, shape (B, F)
        """
        output = scatter_add(input, batch, dim=0, dim_size=num_graphs)
        return output
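
# Usage sketch (illustrative): both readouts map (N, F) node features to (B, F)
# graph features using the batch index.
#
#   >>> h = torch.ones(5, 8)
#   >>> batch = torch.LongTensor([0, 0, 1, 1, 1])
#   >>> SumReadout()(h, batch, num_graphs=2)[:, 0].tolist()
#   [2.0, 3.0]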

class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.
    Note there is no activation or dropout in the last layer.
    Parameters:
        input_dim (int): input dimension
        hidden_dims (list of int): hidden dimensions
        activation (str or function, optional): activation function
        dropout (float, optional): dropout rate
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation  # a callable, or None to disable
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

    def forward(self, input):
        """Apply the MLP; activation and dropout are skipped after the last layer."""
        x = input
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                if self.activation:
                    x = self.activation(x)
                if self.dropout:
                    x = self.dropout(x)
        return x
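
# Usage sketch (illustrative): a 3-layer MLP mapping 16-d inputs to 1-d outputs;
# the dimensions here are arbitrary.
#
#   >>> mlp = MultiLayerPerceptron(16, [32, 32, 1], activation="relu", dropout=0.1)
#   >>> mlp(torch.randn(4, 16)).shape
#   torch.Size([4, 1])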

class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy loss with label smoothing."""

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
        assert 0 <= smoothing < 1
        with torch.no_grad():
            targets = torch.empty(size=(targets.size(0), n_classes),
                                  device=targets.device) \
                .fill_(smoothing / (n_classes - 1)) \
                .scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return targets

    def forward(self, inputs, targets):
        targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
                                                         self.smoothing)
        lsm = F.log_softmax(inputs, -1)

        if self.weight is not None:
            lsm = lsm * self.weight.unsqueeze(0)

        loss = -(targets * lsm).sum(-1)

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()
        return loss
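
# Usage sketch (illustrative): with smoothing=0 this reduces to the standard
# unweighted cross entropy; with mean reduction the result is a scalar.
#
#   >>> logits = torch.randn(8, 5)
#   >>> labels = torch.randint(0, 5, (8,))
#   >>> crit = SmoothCrossEntropyLoss(smoothing=0.1)
#   >>> crit(logits, labels).shape
#   torch.Size([])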

class GaussianSmearing(nn.Module):
    """Expand scalar distances into a radial basis of `num_gaussians` Gaussians
    with centers evenly spaced in [start, stop]."""

    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))
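
# Usage sketch (illustrative): 20 Gaussian basis functions over distances in [0, 5].
#
#   >>> rbf = GaussianSmearing(start=0.0, stop=5.0, num_gaussians=20)
#   >>> rbf(torch.tensor([1.2, 3.4, 4.9])).shape
#   torch.Size([3, 20])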

class ShiftedSoftplus(nn.Module):
    """Softplus shifted by log(2) so that the activation is zero at zero."""

    def __init__(self):
        super().__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift

def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    # Merge protein and ligand nodes into one context, ordered by graph index.
    # Note: `argsort` is not guaranteed to be stable, so the relative order of
    # protein and ligand atoms within a graph may change; see
    # `compose_context_stable` below for an order-preserving variant.
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    sort_idx = batch_ctx.argsort()

    mask_protein = torch.cat([
        torch.ones([batch_protein.size(0)], device=batch_protein.device).bool(),
        torch.zeros([batch_ligand.size(0)], device=batch_ligand.device).bool(),
    ], dim=0)[sort_idx]

    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]        # (N_protein + N_ligand, H)
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]  # (N_protein + N_ligand, 3)

    return h_ctx, pos_ctx, batch_ctx

def get_complete_graph(batch):
    """
    Build a fully connected, self-loop-free edge index for each graph in the batch.
    Args:
        batch: Batch index, shape (N, ).
    Returns:
        edge_index: (2, E) with E = sum_i N_i * (N_i - 1), where N_i is the number
            of nodes of the i-th graph.
        num_edges: (B, ), number of edges per graph.
    """
    natoms = scatter_add(torch.ones_like(batch), index=batch, dim=0)

    natoms_sqr = (natoms ** 2).long()
    num_atom_pairs = torch.sum(natoms_sqr)
    natoms_expand = torch.repeat_interleave(natoms, natoms_sqr)

    index_offset = torch.cumsum(natoms, dim=0) - natoms
    index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr)

    index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr)

    atom_count_sqr = torch.arange(num_atom_pairs, device=num_atom_pairs.device) - index_sqr_offset

    index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand
    index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand
    edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])
    mask = torch.logical_not(index1 == index2)  # drop self-loops
    edge_index = edge_index[:, mask]

    num_edges = natoms_sqr - natoms  # number of edges per graph

    return edge_index, num_edges
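
# Illustrative example: two graphs with 3 and 2 nodes give 3*2 + 2*1 = 8 directed
# edges in total (complete graphs without self-loops).
#
#   >>> edge_index, num_edges = get_complete_graph(torch.LongTensor([0, 0, 0, 1, 1]))
#   >>> edge_index.size(1), num_edges.tolist()
#   (8, [6, 2])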

def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    # Order-preserving variant of `compose_context`: within each graph, protein
    # atoms come first, followed by ligand atoms.
    num_graphs = batch_protein.max().item() + 1

    batch_ctx = []
    h_ctx = []
    pos_ctx = []
    mask_protein = []

    for i in range(num_graphs):
        mask_p, mask_l = (batch_protein == i), (batch_ligand == i)
        batch_p, batch_l = batch_protein[mask_p], batch_ligand[mask_l]

        batch_ctx += [batch_p, batch_l]
        h_ctx += [h_protein[mask_p], h_ligand[mask_l]]
        pos_ctx += [pos_protein[mask_p], pos_ligand[mask_l]]
        mask_protein += [
            torch.ones([batch_p.size(0)], device=batch_p.device, dtype=torch.bool),
            torch.zeros([batch_l.size(0)], device=batch_l.device, dtype=torch.bool),
        ]

    batch_ctx = torch.cat(batch_ctx, dim=0)
    h_ctx = torch.cat(h_ctx, dim=0)
    pos_ctx = torch.cat(pos_ctx, dim=0)
    mask_protein = torch.cat(mask_protein, dim=0)

    return h_ctx, pos_ctx, batch_ctx, mask_protein

def compose_external_attention(batch_protein, batch_ligand, edit_protein_mask):
    # For each graph, build (protein, ligand) index pairs between the editable
    # protein atoms (selected by `edit_protein_mask`) and all ligand atoms.
    num_graphs = batch_protein.max().item() + 1
    row, col = [], []
    protein_index = torch.arange(len(batch_protein), device=batch_protein.device)
    ligand_index = torch.arange(len(batch_ligand), device=batch_protein.device)
    for i in range(num_graphs):
        mask_p, mask_l = (batch_protein == i), (batch_ligand == i)
        p_idx, q_idx = torch.cartesian_prod(
            protein_index[mask_p][edit_protein_mask[mask_p]],
            ligand_index[mask_l],
        ).chunk(2, dim=-1)
        p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
        row.append(p_idx)
        col.append(q_idx)
    row = torch.cat(row, dim=0).to(batch_protein.device)
    col = torch.cat(col, dim=0).to(batch_protein.device)
    return [row, col]

def pos2X(pos, mask_protein, atom2residue):
    # Pack protein atom coordinates into a dense per-residue tensor of shape
    # (num_residues, 14, 3); unused atom slots remain zero.
    num_residues = atom2residue.max().item() + 1
    X = torch.zeros(num_residues, 14, 3, device=mask_protein.device)
    pos_protein = pos[mask_protein]
    for i in range(num_residues):
        mask_p = (atom2residue == i)
        batch_pos = pos_protein[mask_p]
        X[i][:len(batch_pos)] = batch_pos
    return X


def X2pos(X):
    # Inverse of `pos2X`: drop the zero-padded atom slots and flatten back to (N, 3).
    pos = []
    for i in range(len(X)):
        mask = torch.norm(X[i], dim=-1) > 1e-6
        pos.append(X[i][mask])
    pos = torch.cat(pos, dim=0)
    return pos

if __name__ == '__main__':
    h_protein = torch.randn([60, 64])
    h_ligand = -torch.randn([33, 64])
    pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
    pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
    batch_protein = torch.LongTensor([0] * 10 + [1] * 20 + [2] * 30)
    batch_ligand = torch.LongTensor([0] * 11 + [1] * 11 + [2] * 11)

    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(
        h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand)

    assert (batch_ctx[mask_protein] == batch_protein).all()
    assert (batch_ctx[torch.logical_not(mask_protein)] == batch_ligand).all()

    assert torch.allclose(h_ctx[torch.logical_not(mask_protein)], h_ligand)
    assert torch.allclose(h_ctx[mask_protein], h_protein)

    assert torch.allclose(pos_ctx[torch.logical_not(mask_protein)], pos_ligand)
    assert torch.allclose(pos_ctx[mask_protein], pos_protein)
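
    # Extra sanity checks (illustrative additions, not part of the original tests):
    # the split/concat helpers should round-trip the composed context, and
    # get_complete_graph should yield N_i * (N_i - 1) edges per graph.
    parts = split_tensor_by_batch(h_ctx, batch_ctx)
    h_cat, batch_cat = concat_tensors_to_batch(parts)
    assert torch.allclose(h_cat, h_ctx) and (batch_cat == batch_ctx).all()

    edge_index, num_edges = get_complete_graph(batch_ctx)
    natoms = torch.LongTensor([21, 31, 41])  # 10+11, 20+11, 30+11 context atoms per graph
    assert (num_edges == natoms * (natoms - 1)).all()
    assert edge_index.size(1) == num_edges.sum().item()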