| import numpy as np | |
| import scipy.spatial as ss | |
| import torch | |
| import torch.nn.functional as F | |
| from torch_geometric.utils import to_undirected | |
| from torch_sparse import coalesce | |
# Atom-type index -> element symbol; used as the allowable set for protein node
# features in prot_df_to_graph. The last entry ('UNK') is the position that
# one_of_k_encoding_unk_indices uses for values not found in the mapping.
atom_mapping = dict(enumerate(['H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK']))

# Residue-type index -> three-letter residue name ('UNK' last).
residue_mapping = dict(enumerate([
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'UNK',
]))

# Ligand ``element`` value -> one-hot position (0..23); default allowable set of
# mol_df_to_graph_for_qm.
# NOTE(review): keys look like atomic numbers (8=O, 16=S, 6=C, ...) — confirm
# against the data loader.
ligand_atoms_mapping = {
    element: position
    for position, element in enumerate([
        8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33,
        26, 14, 34, 44, 12, 23, 77, 27, 52, 30, 4, 45,
    ])
}
def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with PyTorch-Geometric, where each node is an atom.

    Edges connect every pair of atoms closer than ``edge_dist_cutoff``.

    :param item: Record the structure came from; only ``item['id']`` is read, to identify the structure in the error message.
    :type item: dict-like
    :param df: Protein structure in dataframe format; must contain ``x``, ``y``, ``z`` columns and ``feat_col``.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance cutoff (in Angstroms) to define an edge between two atoms.
    :type edge_dist_cutoff: float
    :param feat_col: Column of dataframe holding node feature values, defaults to ``"element"``.
        Each value ``e`` is one-hot encoded at index ``e - 1`` over the keys of ``atom_mapping``;
        values whose shifted index is not a key go to the last bin.
    :type feat_col: str, optional
    :return: tuple containing
        - node_feats (torch.FloatTensor): One-hot features, one row of length ``len(atom_mapping)`` per atom.
        - edges (torch.LongTensor): Undirected edges in COO format, shape ``(2, num_edges)``;
          ``(2, 0)`` when no atom pair lies within the cutoff.
        - edge_weights (torch.FloatTensor): Edge weights of shape ``(num_edges,)``, defined as
          :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance
          between node :math:`i` and node :math:`j`.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node, shape ``(num_atoms, 3)``.
    :rtype: Tuple
    :raises Exception: Any error raised while building positions/edges is re-raised after the
        offending PDB id is printed.
    """
    allowable_feats = atom_mapping
    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos.numpy())
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edges = to_undirected(torch.LongTensor(edge_tuples).t().contiguous())
        else:
            # No pair within the cutoff (e.g. a single atom): torch.LongTensor([]) would be
            # 1-D and break to_undirected, so build the empty (2, 0) COO tensor directly.
            edges = torch.empty((2, 0), dtype=torch.long)
    except Exception:
        # Identify the offending structure, then propagate. Continuing here would only fail
        # later with a NameError on node_pos/edges, hiding the real cause.
        print(f"Problem with PDB Id is {item['id']}")
        raise
    # view(-1, K) keeps a 2-D (0, K) shape for an empty dataframe; K matches the encoding
    # length produced by one_of_k_encoding_unk_indices (len(allowable_set)).
    node_feats = torch.FloatTensor(
        [one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]]
    ).view(-1, len(allowable_feats))
    src, dst = edges
    # Vectorized inverse distance; the epsilon guards against coincident atoms.
    edge_weights = 1.0 / ((node_pos[src] - node_pos[dst]).norm(dim=1) + 1e-5)
    return node_feats, edges, edge_weights, node_pos
def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with PyTorch-Geometric.

    If ``bonds`` is given, edges come from the bond list (each bond added in both directions);
    otherwise every pair of atoms closer than ``edge_dist_cutoff`` is connected.

    :param df: Molecule structure in dataframe format; must contain ``x``, ``y``, ``z`` and ``element`` columns.
    :type df: pandas.DataFrame
    :param bonds: Bond list, one row per bond: the first two columns are atom indices and the last
        column is the bond order (1.0, 2.0, 3.0 or 1.5). Converted with ``torch.FloatTensor``.
    :type bonds: 2-D array-like, optional
    :param allowable_atoms: Mapping from ``element`` value to one-hot position; values not in the
        mapping go to an extra final "unknown" bin. Defaults to ``ligand_atoms_mapping``.
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Distance cutoff (same units as the coordinates) used only when
        ``bonds`` is None, defaults to 4.5.
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: When ``bonds`` is given, encode bond orders as 4-class one-hot vectors
        (duplicate edges coalesced with ``torch_sparse.coalesce``) instead of raw bond-order values,
        defaults to True.
    :type onehot_edges: bool, optional
    :return: Tuple containing \n
        - node_feats (torch.FloatTensor): One-hot features per node, length ``len(allowable_atoms) + 1``.
        - edge_index (torch.LongTensor): Edges in COO format, shape ``(2, num_edges)``.
        - edge_attr (torch.FloatTensor): With ``bonds`` and ``onehot_edges``: one-hot bond type of shape
          ``(num_edges, 4)`` (Single, Double, Triple, Aromatic). With ``bonds`` and not ``onehot_edges``:
          raw bond order of shape ``(num_edges,)``. Without ``bonds``: inverse distance
          ``1 / (d + 1e-5)`` of shape ``(num_edges, 1)``.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.
    """
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping
    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
    if bonds is not None:
        num_nodes = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        bond_data = torch.FloatTensor(bonds)
        pairs = bond_data[:, :2]
        # Forward copies of every bond followed by the reversed copies.
        edge_index = torch.cat((pairs, torch.flip(pairs, dims=(1,))), dim=0).t().long().contiguous()
        bond_orders = bond_data[:, -1]
        if onehot_edges:
            # Duplicated to line up with the forward + reversed halves of edge_index.
            bond_idx = [bond_mapping[order] for order in bond_orders.tolist()] * 2
            # Explicit long dtype: torch.tensor([]) would be float and break one_hot on no bonds.
            edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long), num_classes=4).to(torch.float)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, num_nodes)
        else:
            edge_attr = torch.cat((bond_orders, bond_orders), dim=0)
    else:
        kd_tree = ss.KDTree(node_pos.numpy())
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edge_index = to_undirected(torch.LongTensor(edge_tuples).t().contiguous())
        else:
            # No pair within the cutoff: torch.LongTensor([]) would be 1-D and make
            # to_undirected raise, so build the empty (2, 0) COO tensor directly.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        src, dst = edge_index
        edge_attr = (1.0 / ((node_pos[src] - node_pos[dst]).norm(dim=1) + 1e-5)).unsqueeze(1)
    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])
    return node_feats, edge_index, edge_attr, node_pos
def one_of_k_encoding_unk_indices(x, allowable_set):
    """Return a one-hot list of length ``len(allowable_set)``.

    ``x`` itself is used as the hot position when it is a member of
    ``allowable_set`` (so members are expected to be valid list indices);
    anything else sets the last position, which acts as the "unknown" bin.
    """
    encoding = [0 for _ in range(len(allowable_set))]
    hot = x if x in allowable_set else -1
    encoding[hot] = 1
    return encoding
def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """Return a one-hot list of length ``len(allowable_set) + 1``.

    ``allowable_set`` maps values to positions: a member ``x`` sets position
    ``allowable_set[x]``; any other value sets the extra final position, which
    acts as the "unknown" bin.
    """
    vector = [0 for _ in range(len(allowable_set) + 1)]
    vector[allowable_set[x] if x in allowable_set else -1] = 1
    return vector