Sophia Tang
committed on
Commit · e54915d · 1 Parent(s): 9ab0e48
model upload
Browse files

- dataloading_for_dynamic_batching.py +156 -0
- dataset.py +207 -0
- diffusion.py +1130 -0
- generate_mcts.py +398 -0
- generate_unconditional.py +126 -0
- main.py +250 -0
- noise_schedule.py +156 -0
- pareto_mcts.py +515 -0
- roformer.py +74 -0
- scoring/__init__.py +0 -0
- scoring/binary_xg.py +280 -0
- scoring/functions/binding.py +202 -0
- scoring/functions/binding_utils.py +291 -0
- scoring/functions/nonfouling.py +68 -0
- scoring/functions/permeability.py +168 -0
- scoring/functions/permeability_xg.py +176 -0
- scoring/functions/scoring_utils.py +111 -0
- scoring/functions/solubility.py +68 -0
- scoring/hemolysis.py +71 -0
- scoring/scoring_functions.py +104 -0
- tokenizer/__init__.py +0 -0
- tokenizer/my_tokenizers.py +441 -0
- tokenizer/new_splits.txt +159 -0
- tokenizer/new_vocab.txt +587 -0
- utils/__pycache__/app.cpython-39.pyc +0 -0
- utils/__pycache__/filter.cpython-39.pyc +0 -0
- utils/__pycache__/generate_utils.cpython-39.pyc +0 -0
- utils/__pycache__/helm_utils.cpython-39.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/app.py +1255 -0
- utils/generate_utils.py +77 -0
dataloading_for_dynamic_batching.py
ADDED
@@ -0,0 +1,156 @@
#!/usr/bin/env python
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk  # importing datasets.Dataset here would shadow torch's Dataset above
import sys
import lightning.pytorch as pl
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from functools import partial
import re


class DynamicBatchingDataset(Dataset):
    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        self.dataset_dict = {
            'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
            'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
            'labels': dataset_dict['labels']
        }
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return {
                'input_ids': self.dataset_dict['input_ids'][idx],
                'attention_mask': self.dataset_dict['attention_mask'][idx],
                'labels': self.dataset_dict['labels'][idx]
            }
        elif isinstance(idx, list):
            return {
                'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
                'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
                'labels': [self.dataset_dict['labels'][i] for i in idx]
            }
        else:
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer

    def peptide_bond_mask(self, smiles_list):
        """
        Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
        of recognized bonds in the positions dictionary and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
        """
        # Initialize the batch mask
        batch_size = len(smiles_list)
        max_seq_length = 1035  # fixed model length; max(len(smiles) for smiles in smiles_list) would give the longest SMILES instead
        mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)  # mask filled with zeros

        bond_patterns = [
            (r'OC\(=O\)', 'ester'),
            (r'N\(C\)C\(=O\)', 'n_methyl'),
            (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
            (r'NC\(=O\)', 'peptide'),      # Regular peptide bonds
            (r'C\(=O\)N\(C\)', 'n_methyl'),
            (r'C\(=O\)N[12]?', 'peptide')
        ]

        for batch_idx, smiles in enumerate(smiles_list):
            positions = []
            used = set()

            # Identify bonds
            for pattern, bond_type in bond_patterns:
                for match in re.finditer(pattern, smiles):
                    if not any(p in range(match.start(), match.end()) for p in used):
                        positions.append({
                            'start': match.start(),
                            'end': match.end(),
                            'type': bond_type,
                            'pattern': match.group()
                        })
                        used.update(range(match.start(), match.end()))

            # Update the mask for the current SMILES
            for pos in positions:
                mask[batch_idx, pos['start']:pos['end']] = 1

        return mask

    def peptide_token_mask(self, smiles_list, token_lists):
        """
        Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
        where any part of the token overlaps with a peptide bond, and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
        """
        # Initialize the batch mask
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists)  # find the longest tokenized sequence
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)  # mask filled with zeros
        atomwise_masks = self.peptide_bond_mask(smiles_list)

        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            atom_idx = 0

            for token_idx, token in enumerate(token_seq):
                # skip the special tokens at the start and end of the sequence
                if token_idx != 0 and token_idx != len(token_seq) - 1:
                    if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                        tokenized_masks[batch_idx][token_idx] = 1
                    atom_idx += len(token)

        return tokenized_masks

    def collate_fn(self, batch):
        item = batch[0]

        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def train_dataloader(self):
        train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
        return DataLoader(
            train_dataset,
            batch_size=1,  # each item is already a pre-batched group
            collate_fn=self.collate_fn,  # use the instance method
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,  # use the instance method
            num_workers=8,
            pin_memory=True
        )
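
A hedged usage sketch for the module above (not part of the commit): the vocab and splits paths point at files added in this commit, but the SMILES_SPE_Tokenizer constructor signature and the dataset path are assumptions for illustration.

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from dataloading_for_dynamic_batching import CustomDataModule

# Hypothetical wiring: the tokenizer files exist in this commit; the dataset
# path is a placeholder for a saved HF dataset with 'train'/'val' splits of
# pre-batched 'input_ids' / 'attention_mask' / 'labels'.
tokenizer = SMILES_SPE_Tokenizer('tokenizer/new_vocab.txt',
                                 'tokenizer/new_splits.txt')
dm = CustomDataModule('data/peptide_dataset', tokenizer)  # placeholder path

batch = next(iter(dm.train_dataloader()))
print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'bond_mask'])
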
dataset.py
ADDED
@@ -0,0 +1,207 @@

import re
import torch

import utils

from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from functools import partial
import sys

class CustomDataset(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = int(self.indices[idx])
        item = self.dataset[actual_idx]
        return item


# for weighting losses of peptide bonds
def peptide_bond_mask(smiles_list):
    """
    Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
    of recognized bonds in the positions dictionary and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
    """
    # Initialize the batch mask
    batch_size = len(smiles_list)
    max_seq_length = max(len(smiles) for smiles in smiles_list)  # find the longest SMILES
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)  # mask filled with zeros

    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),      # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        positions = []
        used = set()

        # Identify bonds
        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        # Update the mask for the current SMILES
        for pos in positions:
            mask[batch_idx, pos['start']:pos['end']] = 1

    return mask

def peptide_token_mask(smiles_list, token_lists):
    """
    Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
    where any part of the token overlaps with a peptide bond, and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
    """
    # Initialize the batch mask
    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists)  # find the longest tokenized sequence
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)  # mask filled with zeros
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0

        for token_idx, token in enumerate(token_seq):
            # skip the special tokens at the start and end of the sequence
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks
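
# ----------------------------------------------------------------------
# Illustrative check (not part of the original file): for glycylglycine,
# 'NCC(=O)NCC(=O)O', the only bond-pattern hit is the backbone amide
# 'C(=O)N' (pattern r'C\(=O\)N[12]?') at character indices 2-7, so
#
#     peptide_bond_mask(['NCC(=O)NCC(=O)O'])[0].tolist()
#     == [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
#
# and any token whose characters overlap indices 2-7 gets a 1 in the
# token-level mask produced by peptide_token_mask.
# ----------------------------------------------------------------------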

def extract_amino_acid_sequence(helm_string):
    """
    Extracts the amino acid sequence from a HELM peptide notation and outputs it as an array,
    removing any brackets around each amino acid.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: A list containing each amino acid in sequence without brackets.
    """
    # Use regex to find the pattern within `{}` brackets following "PEPTIDE" followed by a number
    matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if matches:
        # Join all matched sequences and split by dots to get individual amino acids
        amino_acid_sequence = []
        for match in matches:
            sequence = match.replace('[', '').replace(']', '').split('.')
            amino_acid_sequence.extend(sequence)
        return amino_acid_sequence
    else:
        return "Invalid HELM notation or no peptide sequence found."

def helm_collate_fn(batch, tokenizer):
    sequences = [item['HELM'] for item in batch]

    # max_len is computed but not used below; padding is handled by the tokenizer
    max_len = 0
    for sequence in sequences:
        seq_len = len(extract_amino_acid_sequence(sequence))
        if seq_len > max_len:
            max_len = seq_len

    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }


def collate_fn(batch, tokenizer):
    """Standard data collator that truncates/pads sequences based on max_length."""
    valid_sequences = []
    valid_items = []

    for item in batch:
        try:
            test_tokens = tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
            valid_sequences.append(item['SMILES'])
            valid_items.append(item)
        except Exception as e:
            print(f"Skipping sequence due to: {str(e)}")
            continue

    #sequences = [item['SMILES'] for item in batch]
    #max_len = max([len(seq) for seq in sequences])
    #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)
    #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        #self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    """def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)"""
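
A quick illustration of extract_amino_acid_sequence from dataset.py above, on a toy HELM string (not from the repo's data): brackets around nonstandard residues are stripped and the captured sequence is split on dots.

from dataset import extract_amino_acid_sequence

print(extract_amino_acid_sequence('PEPTIDE1{A.[dA].G.K}$$$$'))
# ['A', 'dA', 'G', 'K']
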
diffusion.py
ADDED
@@ -0,0 +1,1130 @@
| 1 |
+
import numpy as np
|
| 2 |
+
import sys
|
| 3 |
+
import itertools
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import math
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
import random as rd
|
| 11 |
+
import lightning as L
|
| 12 |
+
from torch.distributions.categorical import Categorical
|
| 13 |
+
import torchmetrics
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
import gc
|
| 16 |
+
import pickle
|
| 17 |
+
import utils.utils as utils
|
| 18 |
+
|
| 19 |
+
import dataset as dataloader
|
| 20 |
+
import models.helmgpt as helmgpt
|
| 21 |
+
import models.peptideclm as peptideclm
|
| 22 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 23 |
+
import noise_schedule
|
| 24 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 25 |
+
import models.roformer as roformer
|
| 26 |
+
from utils.filter import PeptideAnalyzer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class Loss:
|
| 31 |
+
loss: torch.FloatTensor
|
| 32 |
+
nlls: torch.FloatTensor
|
| 33 |
+
attn_mask: torch.FloatTensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class BPD(NLL):
|
| 41 |
+
def compute(self) -> Tensor:
|
| 42 |
+
"""Computes the bits per dimension.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
bpd
|
| 46 |
+
"""
|
| 47 |
+
return self.mean_value / self.weight / math.log(2)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Perplexity(NLL):
|
| 51 |
+
def compute(self) -> Tensor:
|
| 52 |
+
"""Computes the Perplexity.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Perplexity
|
| 56 |
+
"""
|
| 57 |
+
return torch.exp(self.mean_value / self.weight)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Diffusion(L.LightningModule):
|
| 61 |
+
def __init__(self, config, tokenizer):
|
| 62 |
+
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.config = config
|
| 65 |
+
#self.save_hyperparameters()
|
| 66 |
+
|
| 67 |
+
# PeptideCLM tokenizer
|
| 68 |
+
self.tokenizer = tokenizer
|
| 69 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 70 |
+
self.mask_token_id = self.tokenizer.mask_token_id
|
| 71 |
+
self.sampler = self.config.sampling.predictor
|
| 72 |
+
self.analyzer = PeptideAnalyzer()
|
| 73 |
+
|
| 74 |
+
# backbone LM PeptideCLM model
|
| 75 |
+
if self.config.backbone == 'peptideclm':
|
| 76 |
+
self.backbone = peptideclm.EncoderWrapper(self.tokenizer)
|
| 77 |
+
self.backbone.unfreeze_all_layers()
|
| 78 |
+
self.backbone = torch.compile(self.backbone)
|
| 79 |
+
elif self.config.backbone == 'helmgpt':
|
| 80 |
+
self.backbone = helmgpt.GPT(self.config, self.tokenizer)
|
| 81 |
+
#self.backbone = torch.compile(self.backbone)
|
| 82 |
+
elif self.config.backbone == 'roformer':
|
| 83 |
+
self.backbone = roformer.Roformer(self.config, self.tokenizer)
|
| 84 |
+
self.backbone.unfreeze_all_layers()
|
| 85 |
+
elif self.config.backbone == 'finetune_roformer':
|
| 86 |
+
self.backbone = roformer.Roformer(self.config, self.tokenizer)
|
| 87 |
+
self.backbone.freeze_model()
|
| 88 |
+
self.backbone.unfreeze_n_layers(n=8)
|
| 89 |
+
else:
|
| 90 |
+
Exception('invalid backbone config')
|
| 91 |
+
|
| 92 |
+
self.neg_infinity = -1000000.0
|
| 93 |
+
self.T = config.T
|
| 94 |
+
# noise schedule for non-peptide bond tokens (default to log-linear)
|
| 95 |
+
self.noise = noise_schedule.get_noise(config)
|
| 96 |
+
# noise schedule for peptide bonds (log-polynomial)
|
| 97 |
+
self.bond_noise = noise_schedule.LogPolyNoise()
|
| 98 |
+
self.time_conditioning = self.config.time_conditioning
|
| 99 |
+
self.fast_forward_epochs = None
|
| 100 |
+
self.fast_forward_batches = None
|
| 101 |
+
|
| 102 |
+
self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
|
| 103 |
+
self.gen_ppl_metric = Perplexity()
|
| 104 |
+
|
| 105 |
+
self.lr = self.config.optim.lr
|
| 106 |
+
self.sampling_eps = self.config.training.sampling_eps
|
| 107 |
+
|
| 108 |
+
metrics = torchmetrics.MetricCollection({
|
| 109 |
+
'nll': NLL(),
|
| 110 |
+
'bpd': BPD(),
|
| 111 |
+
'ppl': Perplexity(),
|
| 112 |
+
})
|
| 113 |
+
metrics.set_dtype(torch.float64)
|
| 114 |
+
self.train_metrics = metrics.clone(prefix='trainer/')
|
| 115 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
| 116 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
"""LOSS"""
|
| 120 |
+
|
| 121 |
+
"""LOSS FOR INVALID PEPTIDES"""
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def conditional_gumbel(self, logits, D, k):
|
| 125 |
+
"""
|
| 126 |
+
Outputs k samples of Q = StandardGumbel(), such that argmax(logits
|
| 127 |
+
+ Q) is given by D (one-hot vector).
|
| 128 |
+
|
| 129 |
+
Input:
|
| 130 |
+
- logits: Tensor of shape (batch_size, seq_len, vocab_size)
|
| 131 |
+
- D: One-hot tensor of shape (batch_size, seq_len, vocab_size)
|
| 132 |
+
- k: Number of Gumbel samples
|
| 133 |
+
|
| 134 |
+
Output:
|
| 135 |
+
- Adjusted logits with shape (k, batch_size, seq_len, vocab_size)
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
# iid. exponential samples of shape (k, batch_size, seq_len, vocab_size)
|
| 139 |
+
E = torch.distributions.exponential.Exponential(rate=torch.ones_like(logits)).sample([k])
|
| 140 |
+
|
| 141 |
+
# E of the chosen class, shape (k, batch_size, seq_len, 1)
|
| 142 |
+
Ei = (D * E).sum(dim=-1, keepdim=True)
|
| 143 |
+
|
| 144 |
+
# Partition function (normalization constant), shape (batch_size, seq_len, 1)
|
| 145 |
+
Z = logits.exp().sum(dim=-1, keepdim=True)
|
| 146 |
+
|
| 147 |
+
# Adjusted logits for Gumbel distribution
|
| 148 |
+
adjusted = (
|
| 149 |
+
D * (-torch.log(Ei) + torch.log(Z)) +
|
| 150 |
+
(1 - D) * -torch.log(E / logits.exp() + Ei / Z)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Adjusted logits shape: (k, batch_size, seq_len, vocab_size)
|
| 154 |
+
return adjusted - logits
|
| 155 |
+
|
| 156 |
+
def replace_gradient(self, value, surrogate):
|
| 157 |
+
"""
|
| 158 |
+
Returns `value` but backpropagates gradients through `surrogate`.
|
| 159 |
+
"""
|
| 160 |
+
return surrogate + (value - surrogate).detach()
|
| 161 |
+
|
| 162 |
+
def gumbel_rao(self, logits, k, temp=1.0, I=None):
|
| 163 |
+
"""
|
| 164 |
+
Returns a categorical sample from logits (over axis=-1) as a
|
| 165 |
+
one-hot vector, with gumbel-rao gradient.
|
| 166 |
+
|
| 167 |
+
Input:
|
| 168 |
+
- logits: Tensor of shape (batch_size, seq_len, vocab_size)
|
| 169 |
+
- k: Number of Gumbel samples for Rao-Blackwellization
|
| 170 |
+
- temp: Temperature for softmax
|
| 171 |
+
- I: Optional, precomputed categorical sample tensor of shape (batch_size, seq_len)
|
| 172 |
+
|
| 173 |
+
Output:
|
| 174 |
+
- One-hot tensor of shape (batch_size, seq_len, vocab_size)
|
| 175 |
+
with Gumbel-Rao gradient.
|
| 176 |
+
"""
|
| 177 |
+
assert logits.shape[-1] == self.tokenizer.vocab_size
|
| 178 |
+
vocab_size = logits.shape[-1]
|
| 179 |
+
|
| 180 |
+
if I is None:
|
| 181 |
+
# Sample indices for each token in the batch
|
| 182 |
+
I = torch.distributions.categorical.Categorical(logits=logits).sample() # (batch_size, seq_len)
|
| 183 |
+
|
| 184 |
+
# Convert indices to one-hot encodings, shape (batch_size, seq_len, vocab_size)
|
| 185 |
+
D = torch.nn.functional.one_hot(I, num_classes=vocab_size).float()
|
| 186 |
+
|
| 187 |
+
# Generate k different adjusted logits that all evaluate to the same sequence
|
| 188 |
+
adjusted = logits + self.conditional_gumbel(logits, D, k=k) # (k, batch_size, seq_len, vocab_size)
|
| 189 |
+
|
| 190 |
+
# Compute the surrogate by averaging softmax across k samples
|
| 191 |
+
surrogate = torch.nn.functional.softmax(adjusted / temp, dim=-1).mean(dim=0) # (batch_size, seq_len, vocab_size)
|
| 192 |
+
|
| 193 |
+
# Return one-hot representation with surrogate gradient
|
| 194 |
+
return self.replace_gradient(D, surrogate)
|
| 195 |
+
|
| 196 |
+
def compute_invalid_loss(self, logits, k=None, temp=None):
|
| 197 |
+
"""
|
| 198 |
+
Penalizes logits that produce invalid sequences using the `is_peptide` function,
|
| 199 |
+
scaling penalties inversely with token probabilities.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
logits: Tensor of shape [batch_size, seq_len, vocab_size].
|
| 203 |
+
k: Number of samples for Gumbel-Rao.
|
| 204 |
+
temp: Temperature for softmax.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
loss: A scalar tensor representing the total loss for invalid sequences.
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
#samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)
|
| 211 |
+
|
| 212 |
+
# Convert logits to sequences using the tokenizer
|
| 213 |
+
batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
|
| 214 |
+
sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
|
| 215 |
+
|
| 216 |
+
# Check validity of each sampled sequence (not differentiable)
|
| 217 |
+
penalties = torch.tensor(
|
| 218 |
+
[1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
|
| 219 |
+
dtype=torch.float32,
|
| 220 |
+
device=self.device
|
| 221 |
+
)
|
| 222 |
+
#print(penalties)
|
| 223 |
+
|
| 224 |
+
# Compute probabilities for each token (batch_size, seq_length)
|
| 225 |
+
sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)
|
| 226 |
+
|
| 227 |
+
# scale penalties by softmax probability of sampled tokens
|
| 228 |
+
scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)
|
| 229 |
+
|
| 230 |
+
return scaled_penalty.to(self.device)
|
| 231 |
+
|
| 232 |
+
"""DIFFUSION LOSS"""
|
| 233 |
+
|
| 234 |
+
def sample_t(self, n, device):
|
| 235 |
+
"""
|
| 236 |
+
Sample random time steps for batch training
|
| 237 |
+
"""
|
| 238 |
+
# sample values uniformly at random from [0, 1)
|
| 239 |
+
eps_t = torch.rand(n, device=device)
|
| 240 |
+
# antithetic sampling: reduce variance by pairing each sample with complementary sample
|
| 241 |
+
if self.config.training.antithetic_sampling:
|
| 242 |
+
# compute interval between sampled time steps
|
| 243 |
+
offset = torch.arange(n, device=device) / n
|
| 244 |
+
# ensure that each eps value is evenly spaced between [0, 1)
|
| 245 |
+
eps_t = ((eps_t / n) + offset) % 1
|
| 246 |
+
|
| 247 |
+
# ensures values are not exactly 0 or 1
|
| 248 |
+
t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
|
| 249 |
+
|
| 250 |
+
return t
|
| 251 |
+
|
| 252 |
+
"""def mask_samples(self, x0, mask_prob):
|
| 253 |
+
|
| 254 |
+
# generate array of values in range [0, 1] uniformly at random
|
| 255 |
+
# will be used to determine which tokens are masked
|
| 256 |
+
mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
|
| 257 |
+
|
| 258 |
+
# select tokens to mask if the random value in mask_indices is less than mask_prob
|
| 259 |
+
# this will mask approximately the fraction of tokens indicated by mask_prob
|
| 260 |
+
zt = torch.where(mask_indices < mask_prob, self.mask_token_id, x0)
|
| 261 |
+
|
| 262 |
+
return zt"""
|
| 263 |
+
|
| 264 |
+
def q_xt(self, x, mask_prob):
|
| 265 |
+
"""Computes the noisy sample xt.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
x: int torch.Tensor with shape (batch_size,
|
| 269 |
+
diffusion_model_input_length), input.
|
| 270 |
+
move_chance: float torch.Tensor with shape (batch_size, 1).
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
|
| 274 |
+
#print(actual_seq_length)
|
| 275 |
+
|
| 276 |
+
max_mask_length = (actual_seq_length * 0.75).long()
|
| 277 |
+
|
| 278 |
+
mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob
|
| 279 |
+
|
| 280 |
+
restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)
|
| 281 |
+
|
| 282 |
+
for i in range(x.shape[0]):
|
| 283 |
+
true_positions = torch.where(mask_indices[i])[0]
|
| 284 |
+
if len(true_positions) > max_mask_length[i]:
|
| 285 |
+
selected_positions = true_positions[:max_mask_length[i].item()]
|
| 286 |
+
restricted_move_indices[i, selected_positions] = True
|
| 287 |
+
else:
|
| 288 |
+
restricted_move_indices[i] = mask_indices[i]
|
| 289 |
+
|
| 290 |
+
xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)
|
| 291 |
+
|
| 292 |
+
return xt
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def sample_prior(self, *batch_dims):
|
| 296 |
+
"""
|
| 297 |
+
Returns array of fully masked sequences with same shape as input
|
| 298 |
+
"""
|
| 299 |
+
return self.mask_token_id * torch.ones(* batch_dims, dtype=torch.int64)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
"""COMPUTING LOSS"""
|
| 303 |
+
|
| 304 |
+
def compute_diffusion_loss(self, model_output, xt, x0, t):
|
| 305 |
+
"""
|
| 306 |
+
Computes diffusion loss term in ELBO
|
| 307 |
+
(evaluates how accurately the model predicts the token probabilities at each time step)
|
| 308 |
+
|
| 309 |
+
Inputs:
|
| 310 |
+
- model_output: [sequence length, vocab size, vocab size] array of logits for each token at each sequence position
|
| 311 |
+
- zt: corrupted version of original input x0 at timestep t
|
| 312 |
+
- x0: original input sequence
|
| 313 |
+
- t: timestep
|
| 314 |
+
"""
|
| 315 |
+
# compute interval between each timestep
|
| 316 |
+
dt = 1 / self.T
|
| 317 |
+
|
| 318 |
+
# compute vectorized alpha scaling terms for the logits at timestep s and t
|
| 319 |
+
alpha_t = 1 - t + torch.zeros_like(x0)
|
| 320 |
+
# s = t - dt
|
| 321 |
+
alpha_s = 1 - (t - dt) + torch.zeros_like(x0)
|
| 322 |
+
|
| 323 |
+
# gather vector of log-probabilities for each token in x0
|
| 324 |
+
# log<x_theta, x>
|
| 325 |
+
log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]) # shape (B, L, vocab_size)
|
| 326 |
+
# gather log-probabillities for assigning a masked token at each position in the sequence at time t
|
| 327 |
+
# log<x_theta, m>
|
| 328 |
+
log_x_theta_at_m = model_output[:, :, self.mask_token_id]
|
| 329 |
+
# obtain non-log probability of assigning a masked token
|
| 330 |
+
# <xt, m>
|
| 331 |
+
x_theta_at_m = log_x_theta_at_m.exp()
|
| 332 |
+
|
| 333 |
+
# first term of diffusion loss
|
| 334 |
+
term_1_coef = dt / t
|
| 335 |
+
term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
|
| 336 |
+
term_1_log_denom = log_x_theta_at_x0
|
| 337 |
+
|
| 338 |
+
# second term of diffusion loss
|
| 339 |
+
term_2_coef = 1 - (dt / t)
|
| 340 |
+
term_2_log_numerator = term_1_log_numerator
|
| 341 |
+
term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)
|
| 342 |
+
|
| 343 |
+
L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
|
| 344 |
+
term_2_coef * (term_2_log_numerator - term_2_log_denom))
|
| 345 |
+
|
| 346 |
+
# multiply by <zt, m> term
|
| 347 |
+
L_vb = L_vb_masked * (xt == self.mask_token_id)
|
| 348 |
+
|
| 349 |
+
# scale by T and return
|
| 350 |
+
return self.T * L_vb
|
| 351 |
+
|
| 352 |
+
"""def _forward_pass_diffusion(self, x0, attn_mask, mask=None):
|
| 353 |
+
|
| 354 |
+
print(x0)
|
| 355 |
+
# randomly sample time steps to start the denoising process for each x0 in batch
|
| 356 |
+
t = self.sample_t(x0.shape[0], x0.device)
|
| 357 |
+
|
| 358 |
+
# if we are training the intermediate transition blocks
|
| 359 |
+
if self.T > 0:
|
| 360 |
+
# scale by total timesteps T and cast to integer
|
| 361 |
+
t = (t * self.T).to(torch.int)
|
| 362 |
+
# scale down by T to get a multiple of 1/T
|
| 363 |
+
t = t / self.T
|
| 364 |
+
# add 1/T to ensure no 0 values
|
| 365 |
+
t += (1 / self.T)
|
| 366 |
+
|
| 367 |
+
# get noise and rate of noise at timestep t
|
| 368 |
+
sigma, dsigma = self.noise(t)
|
| 369 |
+
time_conditioning = sigma[:, None]
|
| 370 |
+
# get masking probabilities for all tokens for each batch
|
| 371 |
+
mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
|
| 372 |
+
|
| 373 |
+
# get masked samples at different timesteps
|
| 374 |
+
if mask is None: zt = self.q_xt(x0, mask_prob)
|
| 375 |
+
else: zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id))
|
| 376 |
+
|
| 377 |
+
model_output = self.forward(zt, attn_mask, time_conditioning)
|
| 378 |
+
|
| 379 |
+
utils.print_nans(model_output, 'model_output')
|
| 380 |
+
|
| 381 |
+
if self.T > 0:
|
| 382 |
+
# compute diffusion loss
|
| 383 |
+
diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
|
| 384 |
+
return diffusion_loss
|
| 385 |
+
|
| 386 |
+
# compute loss for the final that converts from z0 to x0
|
| 387 |
+
# -log(p_theta)
|
| 388 |
+
# get (batch_size, L) array of log-probabilities
|
| 389 |
+
log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) # (B, L)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]"""
|
| 393 |
+
|
| 394 |
+
def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
|
| 395 |
+
"""
|
| 396 |
+
Training reverse diffusion model x_theta to reconstruct samples x0
|
| 397 |
+
|
| 398 |
+
bond_mask: (batch, seq_length)
|
| 399 |
+
"""
|
| 400 |
+
# randomly sample time steps to start the denoising process for each x0 in batch
|
| 401 |
+
t = self.sample_t(x0.shape[0], self.device)
|
| 402 |
+
|
| 403 |
+
# if we are training the intermediate transition blocks
|
| 404 |
+
if self.T > 0:
|
| 405 |
+
# scale by total timesteps T and cast to integer
|
| 406 |
+
t = (t * self.T).to(torch.int)
|
| 407 |
+
# scale down by T to get a multiple of 1/T
|
| 408 |
+
t = t / self.T
|
| 409 |
+
# add 1/T to ensure no 0 values
|
| 410 |
+
t += (1 / self.T)
|
| 411 |
+
|
| 412 |
+
# get noise and rate of noise at timestep t
|
| 413 |
+
# sigma = -log(1-t); dsigma = 1 / (1-t)
|
| 414 |
+
sigma, dsigma = self.noise(t)
|
| 415 |
+
time_conditioning = sigma[:, None]
|
| 416 |
+
|
| 417 |
+
# Get masking probabilities for all tokens for each batch
|
| 418 |
+
# log-linear: 1 - alpha = t
|
| 419 |
+
base_mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
|
| 420 |
+
|
| 421 |
+
if self.config.noise.state_dependent and (bond_mask is not None):
|
| 422 |
+
# log-polynomial masking schedule: alpha = 1 - t^w
|
| 423 |
+
# bond_sigma = -log(1-t^w) for w = 3 (default)
|
| 424 |
+
# bond_dsigma = -wt^(w-1) / (1-t^w)
|
| 425 |
+
bond_sigma, bond_dsigma = self.bond_noise(t) # scalar
|
| 426 |
+
# expand dimensions for broadcasting to (B, L)
|
| 427 |
+
bond_sigma = bond_sigma[:, None]
|
| 428 |
+
bond_dsigma = bond_dsigma[:, None]
|
| 429 |
+
sigma = sigma[:, None]
|
| 430 |
+
dsigma = dsigma[:, None]
|
| 431 |
+
|
| 432 |
+
# compute masking probability for peptide bonds 1 - bond_alpha = t^w
|
| 433 |
+
bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
|
| 434 |
+
# piece together (B, L) tensor with modified masking prob at peptide-bond locations
|
| 435 |
+
mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
|
| 436 |
+
#print(mask_prob)
|
| 437 |
+
dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
|
| 438 |
+
sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
|
| 439 |
+
else:
|
| 440 |
+
mask_prob = base_mask_prob.to(self.device)
|
| 441 |
+
|
| 442 |
+
# get masked samples at different timesteps
|
| 443 |
+
if mask is None:
|
| 444 |
+
zt = self.q_xt(x0, mask_prob).to(self.device)
|
| 445 |
+
else:
|
| 446 |
+
zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id)).to(self.device)
|
| 447 |
+
|
| 448 |
+
model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)
|
| 449 |
+
|
| 450 |
+
# debugging
|
| 451 |
+
assert not torch.isnan(model_output).any()
|
| 452 |
+
assert model_output.is_cuda
|
| 453 |
+
utils.print_nans(model_output, 'model_output')
|
| 454 |
+
|
| 455 |
+
# compute invalid loss
|
| 456 |
+
invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device) # (B, L)
|
| 457 |
+
#print(invalid_loss)
|
| 458 |
+
|
| 459 |
+
if self.T > 0:
|
| 460 |
+
# compute diffusion loss
|
| 461 |
+
diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
|
| 462 |
+
return diffusion_loss
|
| 463 |
+
|
| 464 |
+
# compute loss for the final that converts from z0 to x0
|
| 465 |
+
# -log(p_theta)
|
| 466 |
+
# get (batch_size, L) array of log-probabilities
|
| 467 |
+
log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device) # (B, L)
|
| 468 |
+
|
| 469 |
+
if self.config.noise.state_dependent and (bond_mask is not None):
|
| 470 |
+
return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
|
| 471 |
+
else:
|
| 472 |
+
return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
|
| 473 |
+
|
| 474 |
+
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
|
| 475 |
+
loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
|
| 476 |
+
|
| 477 |
+
# negative log loss
|
| 478 |
+
nlls = loss * attn_mask
|
| 479 |
+
|
| 480 |
+
# count number of tokens
|
| 481 |
+
num_tokens = attn_mask.sum()
|
| 482 |
+
|
| 483 |
+
# compute batch loss
|
| 484 |
+
batch_nll = nlls.sum()
|
| 485 |
+
# compute per token loss
|
| 486 |
+
token_nll = batch_nll / num_tokens
|
| 487 |
+
# return losses
|
| 488 |
+
return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
|
| 489 |
+
|
| 490 |
+
def _compute_loss(self, batch, prefix, bond_mask=None):
|
| 491 |
+
|
| 492 |
+
attn_mask = batch['attention_mask'].to(self.device)
|
| 493 |
+
|
| 494 |
+
if 'mask' in batch:
|
| 495 |
+
mask = batch['mask'].to(self.device)
|
| 496 |
+
else:
|
| 497 |
+
mask = None
|
| 498 |
+
|
| 499 |
+
if 'bond_mask' in batch:
|
| 500 |
+
bond_mask = batch['bond_mask'].to(self.device)
|
| 501 |
+
else:
|
| 502 |
+
bond_mask = None
|
| 503 |
+
|
| 504 |
+
losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
|
| 505 |
+
loss = losses.loss
|
| 506 |
+
|
| 507 |
+
if prefix == 'train':
|
| 508 |
+
self.train_metrics.update(
|
| 509 |
+
losses.nlls.to(self.device),
|
| 510 |
+
losses.attn_mask.to(self.device)
|
| 511 |
+
)
|
| 512 |
+
metrics = self.train_metrics
|
| 513 |
+
elif prefix == 'val':
|
| 514 |
+
self.valid_metrics.update(
|
| 515 |
+
losses.nlls.to(self.device),
|
| 516 |
+
losses.attn_mask.to(self.device)
|
| 517 |
+
)
|
| 518 |
+
metrics = self.valid_metrics
|
| 519 |
+
elif prefix == 'test':
|
| 520 |
+
self.test_metrics.update(losses.nlls, losses.attn_mask)
|
| 521 |
+
metrics = self.test_metrics
|
| 522 |
+
else:
|
| 523 |
+
raise ValueError(f'Invalid prefix: {prefix}')
|
| 524 |
+
|
| 525 |
+
self.log_dict(metrics,
|
| 526 |
+
on_step=False,
|
| 527 |
+
on_epoch=True,
|
| 528 |
+
sync_dist=True)
|
| 529 |
+
|
| 530 |
+
return loss
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
"""SAMPLING"""
|
| 534 |
+
|
| 535 |
+
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
|
| 536 |
+
# get number of timesteps
|
| 537 |
+
if sample_steps is None:
|
| 538 |
+
sample_steps = self.config.sampling.steps
|
| 539 |
+
|
| 540 |
+
if seq_length is None:
|
| 541 |
+
seq_length = self.config.sampling.seq_length
|
| 542 |
+
|
| 543 |
+
# sample fully masked sequences
|
| 544 |
+
z = self.sample_prior(num_samples, seq_length).to(self.device)
|
| 545 |
+
|
| 546 |
+
# create vector of sample_steps timesteps
|
| 547 |
+
timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
|
| 548 |
+
|
| 549 |
+
# compute interval between timesteps
|
| 550 |
+
dt = (1 - eps) / sample_steps
|
| 551 |
+
|
| 552 |
+
for i in range(sample_steps):
|
| 553 |
+
t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
|
| 554 |
+
|
| 555 |
+
z = self.single_reverse_step(z, t, dt)
|
| 556 |
+
|
| 557 |
+
return z
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
"""SAMPLING STEP"""
|
| 561 |
+
|
| 562 |
+
def single_reverse_step(self, zt, t, dt, attn_mask=None):
|
| 563 |
+
"""
|
| 564 |
+
Take a single reverse diffusion step for the expansion step of the MCTS algorithm
|
| 565 |
+
"""
|
| 566 |
+
# get sigma values that determine masking prob
|
| 567 |
+
sigma_t, _ = self.noise(t)
|
| 568 |
+
sigma_s, _ = self.noise(t - dt)
|
| 569 |
+
|
| 570 |
+
# reshape sigmas
|
| 571 |
+
if sigma_t.ndim > 1:
|
| 572 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 573 |
+
if sigma_s.ndim > 1:
|
| 574 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 575 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 576 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 577 |
+
|
| 578 |
+
# compute masking probabilities for each timestep
|
| 579 |
+
change_prob_t = 1 - torch.exp(-sigma_t)
|
| 580 |
+
change_prob_s = 1 - torch.exp(-sigma_s)
|
| 581 |
+
|
| 582 |
+
# expand dimensions
|
| 583 |
+
change_prob_t = change_prob_t[:, None, None]
|
| 584 |
+
change_prob_s = change_prob_s[:, None, None]
|
| 585 |
+
|
| 586 |
+
# get prodiction model that outputs token probabilities
|
| 587 |
+
log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
|
| 588 |
+
|
| 589 |
+
# check dimensions match
|
| 590 |
+
assert change_prob_t.ndim == log_p_x0.ndim
|
| 591 |
+
|
| 592 |
+
# compute reverse diffusion probability of being unmasked at timestep s
|
| 593 |
+
# (sigma_s - sigma_t)*x_theta
|
| 594 |
+
q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
|
| 595 |
+
|
| 596 |
+
# compute reverse diffusion probability of remaining masked at timestep s
|
| 597 |
+
# (1 - sigma_s)*m
|
| 598 |
+
q_zs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
|
| 599 |
+
|
| 600 |
+
# sample sequence at timestep s from categorical distribution of q_zs
|
| 601 |
+
z_changed = sample_categorical(q_zs)
|
| 602 |
+
|
| 603 |
+
copy_flag = (zt != self.mask_token_id).to(zt.dtype)
|
| 604 |
+
return (copy_flag * zt) + ((1 - copy_flag) * z_changed)
|
| 605 |
+
|
| 606 |
+
def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
|
| 607 |
+
assert self.config.noise.type == 'loglinear'
|
| 608 |
+
sigma_t, _ = self.noise(t)
|
| 609 |
+
|
| 610 |
+
if t.ndim > 1:
|
| 611 |
+
t = t.squeeze(-1)
|
| 612 |
+
assert t.ndim == 1
|
| 613 |
+
|
| 614 |
+
change_prob_t = t[:, None, None]
|
| 615 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 616 |
+
|
| 617 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 618 |
+
|
| 619 |
+
if p_x0 is None:
|
| 620 |
+
p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()
|
| 621 |
+
|
| 622 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 623 |
+
|
| 624 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 625 |
+
|
| 626 |
+
# zero-masking probability
|
| 627 |
+
q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
|
| 628 |
+
|
| 629 |
+
x_changed = sample_categorical(q_xs)
|
| 630 |
+
|
| 631 |
+
copy_flag = (x != self.mask_token_id).to(x.dtype)
|
| 632 |
+
|
| 633 |
+
return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
|
| 634 |
+
|
| 635 |
+
# first step in expansion
|
| 636 |
+
def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
|
| 637 |
+
"""
|
| 638 |
+
Generates batch_size different samples from the same starting point for the
|
| 639 |
+
first expansion step of MCTS
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
x (_type_): _description_
|
| 643 |
+
t (_type_): _description_
|
| 644 |
+
dt (_type_): _description_
|
| 645 |
+
batch_size (_type_): _description_
|
| 646 |
+
p_x0 (_type_, optional): _description_. Defaults to None.
|
| 647 |
+
attn_mask (_type_, optional): _description_. Defaults to None.
|
| 648 |
+
|
| 649 |
+
Returns:
|
| 650 |
+
_type_: _description_
|
| 651 |
+
"""
|
| 652 |
+
|
| 653 |
+
assert self.config.noise.type == 'loglinear'
|
| 654 |
+
sigma_t, _ = self.noise(t)
|
| 655 |
+
|
| 656 |
+
if t.ndim > 1:
|
| 657 |
+
t = t.squeeze(-1)
|
| 658 |
+
assert t.ndim == 1
|
| 659 |
+
|
| 660 |
+
change_prob_t = t[:, None, None]
|
| 661 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 662 |
+
|
| 663 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 664 |
+
|
| 665 |
+
if token_array.dim() == 1:
|
| 666 |
+
token_array = token_array.unsqueeze(0)
|
| 667 |
+
#token_array = token_array.repeat(batch_size, 1)
|
| 668 |
+
|
| 669 |
+
attn_mask = torch.ones_like(token_array)
|
| 670 |
+
|
| 671 |
+
if p_x0 is None:
|
| 672 |
+
p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
|
| 673 |
+
|
| 674 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 675 |
+
|
| 676 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 677 |
+
|
| 678 |
+
# zero-masking probability
|
| 679 |
+
q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
|
| 680 |
+
|
| 681 |
+
# repeat the parent token along the first dimension which will be unmasked into distinct sequences
|
| 682 |
+
token_array = token_array.repeat(batch_size, 1)
|
| 683 |
+
|
| 684 |
+
if self.config.mcts.sampling == 0:
|
| 685 |
+
x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
|
| 686 |
+
else:
|
| 687 |
+
x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
|
| 688 |
+
|
| 689 |
+
copy_flag = (token_array != self.mask_token_id).to(token_array.dtype)
|
| 690 |
+
|
| 691 |
+
return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
|
| 692 |
+
|
| 693 |
+
def _process_sigma(self, sigma):
|
| 694 |
+
if sigma.ndim > 1:
|
| 695 |
+
sigma = sigma.squeeze(-1)
|
| 696 |
+
if not self.time_conditioning:
|
| 697 |
+
sigma = torch.zeros_like(sigma)
|
| 698 |
+
assert sigma.ndim == 1, sigma.shape
|
| 699 |
+
return sigma
|
| 700 |
+
|
| 701 |
+
def forward(self, zt, attn_mask, sigma):
|
| 702 |
+
"""
|
| 703 |
+
Predicts the token log-probabilities from zt at time t with noise schedule sigma
|
| 704 |
+
"""
|
| 705 |
+
sigma = self._process_sigma(sigma)
|
| 706 |
+
|
| 707 |
+
with torch.amp.autocast("cuda", enabled=True, dtype=torch.float32, cache_enabled=True):
|
| 708 |
+
logits = self.backbone(zt, attn_mask).to(self.device)
|
| 709 |
+
|
| 710 |
+
return self.subs_parameterization(logits, zt)
|
| 711 |
+
|
| 712 |
+
def subs_parameterization(self, logits, zt):
|
| 713 |
+
"""
|
| 714 |
+
Updates reverse diffusion logits based on SUBS parameterization:
|
| 715 |
+
- zero masking probabilities: -infinity probability of being masked during reverse diffusion
|
| 716 |
+
- carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion
|
| 717 |
+
|
| 718 |
+
Args:
|
| 719 |
+
logits: vector of token probabilities for unmasking masked tokens
|
| 720 |
+
zt: partially unmasked sequence at current timestep
|
| 721 |
+
"""
|
| 722 |
+
logits[:, :, self.mask_token_id] += self.neg_infinity # [sequence index, current token, next token]
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
unmasked_indices = (zt != self.mask_token_id).to(self.device) # shape: [200, seq_length]
|
| 729 |
+
batch_idx, seq_idx = torch.where(unmasked_indices) # Get explicit indices
|
| 730 |
+
batch_idx = batch_idx.to(self.device)
|
| 731 |
+
seq_idx = seq_idx.to(self.device)
|
| 732 |
+
tokens = zt[batch_idx, seq_idx].to(self.device) # Get the tokens at those positions
|
| 733 |
+
|
| 734 |
+
assert logits.is_contiguous(), "logits tensor is not contiguous"
|
| 735 |
+
assert unmasked_indices.shape == zt.shape, "same shape"
|
| 736 |
+
assert not torch.isnan(logits).any(), "NaN values found in logits"
|
| 737 |
+
assert tokens.max() < logits.shape[-1], "token indices out of bounds"
|
| 738 |
+
assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
|
| 739 |
+
assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
|
| 740 |
+
assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"
|
| 741 |
+
|
| 742 |
+
logits[batch_idx, seq_idx] = self.neg_infinity # Set everything to -inf first
|
| 743 |
+
logits[batch_idx, seq_idx, tokens] = 0 # Set only the specific token positions to 0
|
| 744 |
+
# return logits with SUBS parameterization
|
| 745 |
+
return logits.to(self.device)
|
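A standalone toy check of the two SUBS properties enforced above, using a hypothetical 4-token vocabulary with mask_token_id = 3 and -1e9 standing in for self.neg_infinity:

import torch

neg_infinity = -1e9   # large negative stand-in for self.neg_infinity (assumed)
mask_token_id = 3     # hypothetical vocabulary: {0, 1, 2, [MASK]=3}
logits = torch.zeros(1, 2, 4)   # (batch, seq_len, vocab)
zt = torch.tensor([[2, 3]])     # position 0 unmasked (token 2), position 1 masked

logits[:, :, mask_token_id] += neg_infinity                      # zero masking
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # normalize

batch_idx, seq_idx = torch.where(zt != mask_token_id)            # carry-over
tokens = zt[batch_idx, seq_idx]
logits[batch_idx, seq_idx] = neg_infinity
logits[batch_idx, seq_idx, tokens] = 0

print(logits[0, 0].exp())  # ~[0, 0, 1, 0]: the unmasked token is copied through
print(logits[0, 1].exp())  # ~[1/3, 1/3, 1/3, 0]: [MASK] can never be sampled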
| 746 |
+
|
| 747 |
+
"""SAMPLING"""
|
| 748 |
+
@torch.no_grad()
|
| 749 |
+
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
|
| 750 |
+
"""
|
| 751 |
+
Generate samples
|
| 752 |
+
"""
|
| 753 |
+
batch_size_per_gpu = self.config.eval.perplexity_batch_size
|
| 754 |
+
|
| 755 |
+
if num_steps is None:
|
| 756 |
+
num_steps = self.config.sampling.steps
|
| 757 |
+
|
| 758 |
+
if x_input is not None:
|
| 759 |
+
x = x_input['input_ids'].to(self.device)
|
| 760 |
+
attn_mask = x_input['attention_mask'].to(self.device)
|
| 761 |
+
else:
|
| 762 |
+
x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
|
| 763 |
+
attn_mask = torch.ones_like(x).to(self.device)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
|
| 767 |
+
dt = (1 - eps) / num_steps
|
| 768 |
+
p_x0_cache = None
|
| 769 |
+
generation_history = [] # used to track which tokens are unmasked
|
| 770 |
+
|
| 771 |
+
for i in range(num_steps):
|
| 772 |
+
t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
|
| 773 |
+
if self.sampler == 'ddpm':
|
| 774 |
+
x = self.single_reverse_step(x, t, dt).to(self.device)
|
| 775 |
+
|
| 776 |
+
elif self.sampler == 'ddpm_cache':
|
| 777 |
+
p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
|
| 778 |
+
if (not torch.allclose(x_next, x) or self.time_conditioning):
|
| 779 |
+
# Disable caching
|
| 780 |
+
p_x0_cache = None
|
| 781 |
+
x = x_next.to(self.device)
|
| 782 |
+
#print(self.tokenizer.decode(x.squeeze()))
|
| 783 |
+
else:
|
| 784 |
+
x = self._analytic_update(x, t, dt, attn_mask).to(self.device)
|
| 785 |
+
|
| 786 |
+
if self.config.sampling.noise_removal:
|
| 787 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
|
| 788 |
+
if self.sampler == 'analytic':
|
| 789 |
+
x = self._denoiser_update(x, t).to(self.device)
|
| 790 |
+
else:
|
| 791 |
+
time_conditioning = self.noise(t)[0].to(self.device)
|
| 792 |
+
x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
|
| 793 |
+
#print(self.tokenizer.decode(x.squeeze()))
|
| 794 |
+
return x.to(self.device)
|
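The loop above walks a uniform time grid from t = 1 down to t = eps; a quick standalone illustration of the discretization:

import torch

num_steps, eps = 4, 1e-5
timesteps = torch.linspace(1, eps, num_steps + 1)
dt = (1 - eps) / num_steps
print(timesteps)  # ~[1.00, 0.75, 0.50, 0.25, 0.00]; step i uses t = timesteps[i]
print(dt)         # ~0.25; each reverse update moves the state from t to t - dt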
| 795 |
+
|
| 796 |
+
|
| 797 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
|
| 798 |
+
"""Generate samples from the model."""
|
| 799 |
+
self.backbone.eval()
|
| 800 |
+
self.noise.eval()
|
| 801 |
+
samples = self._sample(num_steps=num_steps, eps=eps)
|
| 802 |
+
self.backbone.train()
|
| 803 |
+
self.noise.train()
|
| 804 |
+
return samples
|
| 805 |
+
|
| 806 |
+
def get_score(self, zt, sigma, attn_mask=None):
|
| 807 |
+
|
| 808 |
+
# score(x, t) = p_t(y) / p_t(x)
|
| 809 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
| 810 |
+
|
| 811 |
+
# case 1: x = masked
|
| 812 |
+
# (i) y = unmasked
|
| 813 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
| 814 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 815 |
+
# (ii) y = masked
|
| 816 |
+
# log score(x, t) = 0
|
| 817 |
+
|
| 818 |
+
# case 2: x = unmasked
|
| 819 |
+
# (i) y != masked, y != x
|
| 820 |
+
# log score(x_i, t) = - inf
|
| 821 |
+
# (ii) y = x
|
| 822 |
+
# log score(x_i, t) = 0
|
| 823 |
+
# (iii) y = masked token
|
| 824 |
+
# log score(x_i, t) = - log k
|
| 825 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 826 |
+
|
| 827 |
+
model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)
|
| 828 |
+
|
| 829 |
+
log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
|
| 830 |
+
assert log_k.ndim == 1
|
| 831 |
+
|
| 832 |
+
masked_score = model_output + log_k[:, None, None]
|
| 833 |
+
masked_score[:, :, self.mask_token_id] = 0
|
| 834 |
+
|
| 835 |
+
unmasked_score = self.neg_infinity * torch.ones_like(model_output)
|
| 836 |
+
unmasked_score = torch.scatter(
|
| 837 |
+
unmasked_score, -1,
|
| 838 |
+
zt[..., None],
|
| 839 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
| 840 |
+
|
| 841 |
+
unmasked_score[:, :, self.mask_token_id] = - (log_k[:, None] * torch.ones_like(zt))
|
| 842 |
+
|
| 843 |
+
masked_indices = (zt == self.mask_token_id).to(model_output.dtype)[:, :, None]
|
| 844 |
+
|
| 845 |
+
model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))
|
| 846 |
+
|
| 847 |
+
return model_output.exp()
|
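The constant k in the comment block above is k = exp(-sigma) / (1 - exp(-sigma)) = 1 / (e^sigma - 1), which is why the code computes log k as -log(expm1(sigma)); expm1 is the numerically stable form for small sigma. A standalone numerical sanity check:

import torch

sigma = torch.tensor([0.1, 0.5, 1.0, 2.0])
k_direct = torch.exp(-sigma) / (1 - torch.exp(-sigma))  # k as written in the comment
log_k = -torch.log(torch.expm1(sigma))                  # the form used in get_score
print(torch.allclose(log_k, torch.log(k_direct)))       # True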
| 848 |
+
|
| 849 |
+
def _staggered_score(self, score, dsigma):
|
| 850 |
+
score = score.clone()
|
| 851 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 852 |
+
score *= dsigma.exp()[:, None]
|
| 853 |
+
score[..., self.mask_token_id] += extra_const
|
| 854 |
+
return score
|
| 855 |
+
|
| 856 |
+
def _analytic_update(self, x, t, step_size, attn_mask=None):
|
| 857 |
+
curr_sigma, _ = self.noise(t)
|
| 858 |
+
next_sigma, _ = self.noise(t - step_size)
|
| 859 |
+
dsigma = curr_sigma - next_sigma
|
| 860 |
+
score = self.get_score(x, curr_sigma, attn_mask) # argument order must match the get_score(zt, sigma, attn_mask) signature
|
| 861 |
+
stag_score = self._staggered_score(score, dsigma)
|
| 862 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
| 863 |
+
return sample_categorical(probs)
|
| 864 |
+
|
| 865 |
+
def _denoiser_update(self, x, t):
|
| 866 |
+
sigma, _ = self.noise(t)
|
| 867 |
+
score = self.get_score(x, sigma)
|
| 868 |
+
stag_score = self._staggered_score(score, sigma)
|
| 869 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
| 870 |
+
probs[..., self.mask_token_id] = 0
|
| 871 |
+
samples = sample_categorical(probs)
|
| 872 |
+
return samples
|
| 873 |
+
|
| 874 |
+
def _transp_transition(self, i, sigma):
|
| 875 |
+
sigma = unsqueeze(sigma, reference=i[..., None])
|
| 876 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
| 877 |
+
i, num_classes=self.vocab_size)
|
| 878 |
+
edge += torch.where(i == self.mask_token_id,
|
| 879 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
| 880 |
+
0)[..., None]
|
| 881 |
+
return edge
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
"""TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
|
| 885 |
+
|
| 886 |
+
def on_train_epoch_start(self):
|
| 887 |
+
torch.cuda.empty_cache()
|
| 888 |
+
self.backbone.train()
|
| 889 |
+
self.noise.train()
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def training_step(self, batch, batch_idx):
|
| 893 |
+
# Initialize throughput calculation
|
| 894 |
+
start_time = time.time()
|
| 895 |
+
|
| 896 |
+
if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
|
| 897 |
+
loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
|
| 898 |
+
else:
|
| 899 |
+
loss = self._compute_loss(batch, prefix='train')
|
| 900 |
+
|
| 901 |
+
self.log(name='trainer/loss',
|
| 902 |
+
value=loss.item(),
|
| 903 |
+
on_step=True,
|
| 904 |
+
on_epoch=False,
|
| 905 |
+
sync_dist=True)
|
| 906 |
+
|
| 907 |
+
# Calculate throughput
|
| 908 |
+
elapsed_time = time.time() - start_time
|
| 909 |
+
total_tokens = batch['input_ids'].numel()
|
| 910 |
+
throughput = total_tokens / elapsed_time
|
| 911 |
+
|
| 912 |
+
self.log(name='trainer/throughput',
|
| 913 |
+
value=throughput,
|
| 914 |
+
on_step=True,
|
| 915 |
+
on_epoch=False,
|
| 916 |
+
sync_dist=True)
|
| 917 |
+
|
| 918 |
+
return loss
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def on_load_checkpoint(self, checkpoint):
|
| 922 |
+
self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
|
| 923 |
+
self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
|
| 924 |
+
|
| 925 |
+
"""VALIDATION"""
|
| 926 |
+
def on_validation_epoch_start(self):
|
| 927 |
+
gc.collect()
|
| 928 |
+
torch.cuda.empty_cache()
|
| 929 |
+
self.backbone.eval()
|
| 930 |
+
self.noise.eval()
|
| 931 |
+
assert self.valid_metrics.nll.mean_value == 0
|
| 932 |
+
assert self.valid_metrics.nll.weight == 0
|
| 933 |
+
|
| 934 |
+
def validation_step(self, batch, batch_idx):
|
| 935 |
+
if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
|
| 936 |
+
loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
|
| 937 |
+
else:
|
| 938 |
+
loss = self._compute_loss(batch, prefix='val')
|
| 939 |
+
|
| 940 |
+
self.log(name='trainer/val_loss',
|
| 941 |
+
value=loss.item(),
|
| 942 |
+
on_step=True,
|
| 943 |
+
on_epoch=False,
|
| 944 |
+
prog_bar=True,
|
| 945 |
+
sync_dist=True)
|
| 946 |
+
return loss
|
| 947 |
+
|
| 948 |
+
def on_validation_epoch_end(self):
|
| 949 |
+
gc.collect()
|
| 950 |
+
torch.cuda.empty_cache()
|
| 951 |
+
|
| 952 |
+
"""OPTIMIZATION"""
|
| 953 |
+
|
| 954 |
+
def optimizer_step(self, *args, **kwargs):
|
| 955 |
+
super().optimizer_step(*args, **kwargs)
|
| 956 |
+
|
| 957 |
+
gc.collect()
|
| 958 |
+
torch.cuda.empty_cache()
|
| 959 |
+
|
| 960 |
+
def configure_optimizers(self):
|
| 961 |
+
optimizer = torch.optim.AdamW(
|
| 962 |
+
itertools.chain(self.backbone.parameters(), self.noise.parameters()),
|
| 963 |
+
lr=self.config.optim.lr,
|
| 964 |
+
betas=(self.config.optim.beta1, self.config.optim.beta2),
|
| 965 |
+
eps=self.config.optim.eps,
|
| 966 |
+
weight_decay=self.config.optim.weight_decay
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
self.total_steps = self.config.trainer.max_steps
|
| 970 |
+
scheduler = CosineWarmup(optimizer,
|
| 971 |
+
warmup_steps=self.config.lr_scheduler.num_warmup_steps,
|
| 972 |
+
total_steps=self.total_steps)
|
| 973 |
+
|
| 974 |
+
scheduler_dict = {
|
| 975 |
+
'scheduler': scheduler,
|
| 976 |
+
'interval': 'step',
|
| 977 |
+
'frequency': 1,
|
| 978 |
+
'monitor': 'trainer/val_loss',
|
| 979 |
+
'name': 'trainer/lr'
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
return [optimizer], [scheduler_dict]
|
| 983 |
+
|
| 984 |
+
@torch.no_grad()
|
| 985 |
+
def compute_masked_perplexity(self, generated_ids, input_ids):
|
| 986 |
+
"""
|
| 987 |
+
Computes masked perplexity between array of generated token ids and masked ids that are converted to logits
|
| 988 |
+
"""
|
| 989 |
+
|
| 990 |
+
total_nll = 0
|
| 991 |
+
total_tokens = 0
|
| 992 |
+
|
| 993 |
+
input_ids = torch.tensor(input_ids).to(self.device)
|
| 994 |
+
#print(input_ids)
|
| 995 |
+
|
| 996 |
+
for sequence in generated_ids:
|
| 997 |
+
# tokenize the sequence
|
| 998 |
+
|
| 999 |
+
gt_ids = torch.tensor(sequence).to(self.device)
|
| 1000 |
+
#print(gt_ids)
|
| 1001 |
+
|
| 1002 |
+
sys.stdout.flush()
|
| 1003 |
+
|
| 1004 |
+
# forward pass through the backbone PeptideCLM model
|
| 1005 |
+
attn_mask = torch.ones_like(input_ids).to(self.device)
|
| 1006 |
+
|
| 1007 |
+
# compute logits using backbone
|
| 1008 |
+
|
| 1009 |
+
if self.config.mode in ['train', 'ppl_eval']:
|
| 1010 |
+
outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)
|
| 1011 |
+
elif self.config.mode == 'sample_eval':
|
| 1012 |
+
outputs = self.backbone.forward(input_ids=input_ids)
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
# get logits for each position in sequence across all tokens in vocab
|
| 1016 |
+
#logits = outputs[-1] # (batch_size, seq_length, vocab_size)
|
| 1017 |
+
|
| 1018 |
+
logits = outputs.view(-1, outputs.size(-1))
|
| 1019 |
+
gt_ids = gt_ids.view(-1)
|
| 1020 |
+
|
| 1021 |
+
#print(logits.shape)
|
| 1022 |
+
#print(gt_ids.shape)
|
| 1023 |
+
|
| 1024 |
+
# compute loss
|
| 1025 |
+
# shift_logits = logits[:, :-1, :].contiguous() # remove eos
|
| 1026 |
+
# shift_labels = input_ids[:, 1:].contiguous()
|
| 1027 |
+
# print(masked)
|
| 1028 |
+
|
| 1029 |
+
loss = F.cross_entropy(logits,
|
| 1030 |
+
gt_ids.where(input_ids==self.mask_token_id, torch.full_like(gt_ids, -100)).view(-1),
|
| 1031 |
+
reduction='sum')
|
| 1032 |
+
|
| 1033 |
+
total_nll += loss.item()
|
| 1034 |
+
# count all non-padding tokens
|
| 1035 |
+
total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # includes BOS and EOS tokens
|
| 1036 |
+
|
| 1037 |
+
# compute pseudo-perplexity
|
| 1038 |
+
# print(total_nll, ",;,", total_tokens)
|
| 1039 |
+
pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
|
| 1040 |
+
self.gen_ppl_metric.update(pseudo_perplexity)
|
| 1041 |
+
|
| 1042 |
+
return pseudo_perplexity.item()
|
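The returned value is exp(total NLL / total tokens). A one-line worked example with hypothetical accumulated values:

import torch

total_nll, total_tokens = 46.0, 20  # hypothetical sums accumulated over sequences
ppl = torch.exp(torch.tensor(total_nll / total_tokens))
print(round(ppl.item(), 4))  # 9.9742, i.e. exp(2.3)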
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
def sample_categorical(categorical_probs):
|
| 1046 |
+
gumbel_norm = (
|
| 1047 |
+
1e-10
|
| 1048 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 1049 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
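sample_categorical is the Gumbel-max trick in disguise: gumbel_norm is i.i.d. Exponential(1) noise, and argmax(p / E) draws exactly from the categorical distribution p. A standalone empirical check:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3]).expand(100_000, 3)
gumbel_norm = 1e-10 - (torch.rand_like(probs) + 1e-10).log()
samples = (probs / gumbel_norm).argmax(dim=-1)
print(torch.bincount(samples, minlength=3) / 100_000)  # ~tensor([0.1, 0.6, 0.3])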
| 1050 |
+
|
| 1051 |
+
def sample_batched_categorical(categorical_probs, batch_size):
|
| 1052 |
+
"""
|
| 1053 |
+
Generates `batch_size` distinct sequences sampled from categorical probabilities
|
| 1054 |
+
using the Gumbel distribution to ensure randomness while following probabilities
|
| 1055 |
+
|
| 1056 |
+
Args:
|
| 1057 |
+
categorical_probs (torch.Tensor): tensor of shape (1, sequence_length, vocab_size)
|
| 1058 |
+
representing categorical probabilities
|
| 1059 |
+
batch_size (int): number of distinct sequences to sample
|
| 1060 |
+
|
| 1061 |
+
Returns:
|
| 1062 |
+
torch.Tensor: tensor of shape (batch_size, sequence_length), where each row is a
|
| 1063 |
+
distinct sequence of sampled category indices.
|
| 1064 |
+
"""
|
| 1065 |
+
_, sequence_length, vocab_size = categorical_probs.shape
|
| 1066 |
+
|
| 1067 |
+
# add Gumbel noise and sample m sequences
|
| 1068 |
+
gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
|
| 1069 |
+
noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
|
| 1070 |
+
|
| 1071 |
+
# select the highest score (most likely category after Gumbel noise)
|
| 1072 |
+
sampled_sequences = noisy_scores.argmax(dim=-1) # shape: (m, sequence_length)
|
| 1073 |
+
|
| 1074 |
+
return sampled_sequences
|
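Usage sketch (run inside this module): draw several candidate unmaskings of a single parent distribution:

import torch

q = torch.softmax(torch.randn(1, 5, 12), dim=-1)  # (1, seq_len, vocab_size)
candidates = sample_batched_categorical(q, batch_size=8)
print(candidates.shape)  # torch.Size([8, 5]): eight independently sampled sequences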
| 1075 |
+
|
| 1076 |
+
def sample_batched_top_k(categorical_probs, batch_size, k):
|
| 1077 |
+
"""
|
| 1078 |
+
Generates `batch_size` sequences sampled from the top-k probabilities of each token
|
| 1079 |
+
using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
|
| 1080 |
+
|
| 1081 |
+
Args:
|
| 1082 |
+
categorical_probs (torch.Tensor): A tensor of shape (1, sequence_length, vocab_size)
|
| 1083 |
+
representing categorical probabilities.
|
| 1084 |
+
batch_size (int): Number of sequences to sample.
|
| 1085 |
+
k (int): Number of top probabilities to consider for sampling.
|
| 1086 |
+
|
| 1087 |
+
Returns:
|
| 1088 |
+
torch.Tensor: A tensor of shape (batch_size, sequence_length), where each row is a
|
| 1089 |
+
sampled sequence of category indices.
|
| 1090 |
+
"""
|
| 1091 |
+
_, sequence_length, vocab_length = categorical_probs.shape
|
| 1092 |
+
|
| 1093 |
+
# Add Gumbel noise to the log probabilities
|
| 1094 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
|
| 1095 |
+
noisy_scores = torch.log(categorical_probs) + gumbel_noise # broadcast (1, seq, vocab) against (batch_size, seq, vocab); an extra [None] axis here would break the gather below
|
| 1096 |
+
|
| 1097 |
+
# Get the top-k categories based on noisy scores
|
| 1098 |
+
top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k)
|
| 1099 |
+
|
| 1100 |
+
# Convert top-k scores back to probabilities and normalize
|
| 1101 |
+
top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k)
|
| 1102 |
+
|
| 1103 |
+
# Sample randomly from the top-k probabilities
|
| 1104 |
+
sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
|
| 1105 |
+
sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
|
| 1106 |
+
|
| 1107 |
+
# Map sampled indices back to the original vocabulary indices
|
| 1108 |
+
sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device)
|
| 1109 |
+
|
| 1110 |
+
return sampled_sequences
|
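Usage sketch for the top-k variant (run inside this module, after the broadcasting fix above): each position samples only among its k most promising tokens:

import torch

torch.manual_seed(0)
q = torch.softmax(torch.randn(1, 6, 10), dim=-1)    # (1, seq_len, vocab_size)
batch = sample_batched_top_k(q, batch_size=4, k=3)  # restrict sampling to the top 3
print(batch.shape)  # torch.Size([4, 6])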
| 1111 |
+
|
| 1112 |
+
def unsqueeze(x, reference):
|
| 1113 |
+
return x.view(* x.shape, * ((1,) * (len(reference.shape) - len(x.shape))))
|
| 1114 |
+
|
| 1115 |
+
class CosineWarmup(_LRScheduler):
|
| 1116 |
+
def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
|
| 1117 |
+
self.warmup_steps = warmup_steps
|
| 1118 |
+
self.total_steps = total_steps
|
| 1119 |
+
self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
|
| 1120 |
+
super(CosineWarmup, self).__init__(optimizer, last_epoch)
|
| 1121 |
+
|
| 1122 |
+
def get_lr(self):
|
| 1123 |
+
if self.last_epoch < self.warmup_steps:
|
| 1124 |
+
return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
|
| 1125 |
+
|
| 1126 |
+
progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
| 1127 |
+
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
|
| 1128 |
+
decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
|
| 1129 |
+
|
| 1130 |
+
return [decayed_lr * base_lr for base_lr in self.base_lrs]
|
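A standalone sketch of the resulting schedule (a dummy one-parameter optimizer; assumes this module's imports, numpy included, are available): the learning rate ramps linearly over warmup_steps, then cosine-decays to eta_ratio times the base rate:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-3)
sched = CosineWarmup(opt, warmup_steps=10, total_steps=100)
lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(lrs[0], lrs[9], lrs[-1])  # ~1e-4 (ramping), 1e-3 (end of warmup), 1e-4 (= eta_ratio * base)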
generate_mcts.py
ADDED
|
@@ -0,0 +1,398 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
|
| 9 |
+
from diffusion import Diffusion
|
| 10 |
+
from pareto_mcts import Node, MCTS
|
| 11 |
+
import hydra
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
| 14 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 15 |
+
from helm_tokenizer.helm_tokenizer import HelmTokenizer
|
| 16 |
+
from utils.helm_utils import create_helm_from_aa_seq
|
| 17 |
+
from utils.app import PeptideAnalyzer
|
| 18 |
+
from new_tokenizer.ape_tokenizer import APETokenizer
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import os
|
| 21 |
+
import seaborn as sns
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
def save_logs_to_file(config, valid_fraction_log, affinity1_log, affinity2_log, sol_log, hemo_log, nf_log, permeability_log, output_path):
|
| 26 |
+
"""
|
| 27 |
+
Saves the logs (valid_fraction_log, affinity1_log, and permeability_log) to a CSV file.
|
| 28 |
+
|
| 29 |
+
Parameters:
|
| 30 |
+
valid_fraction_log (list): Log of valid fractions over iterations.
|
| 31 |
+
affinity1_log (list): Log of binding affinity over iterations.
|
| 32 |
+
permeability_log (list): Log of membrane permeability over iterations.
|
| 33 |
+
output_path (str): Path to save the log CSV file.
|
| 34 |
+
"""
|
| 35 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 36 |
+
|
| 37 |
+
if config.mcts.perm:
|
| 38 |
+
# Combine logs into a DataFrame
|
| 39 |
+
log_data = {
|
| 40 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 41 |
+
"Valid Fraction": valid_fraction_log,
|
| 42 |
+
"Binding Affinity": affinity1_log,
|
| 43 |
+
"Solubility": sol_log,
|
| 44 |
+
"Hemolysis": hemo_log,
|
| 45 |
+
"Nonfouling": nf_log,
|
| 46 |
+
"Permeability": permeability_log
|
| 47 |
+
}
|
| 48 |
+
elif config.mcts.dual:
|
| 49 |
+
log_data = {
|
| 50 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 51 |
+
"Valid Fraction": valid_fraction_log,
|
| 52 |
+
"Binding Affinity 1": affinity1_log,
|
| 53 |
+
"Binding Affinity 2": affinity2_log,
|
| 54 |
+
"Solubility": sol_log,
|
| 55 |
+
"Hemolysis": hemo_log,
|
| 56 |
+
"Nonfouling": nf_log,
|
| 57 |
+
"Permeability": permeability_log
|
| 58 |
+
}
|
| 59 |
+
elif config.mcts.single:
|
| 60 |
+
log_data = {
|
| 61 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 62 |
+
"Valid Fraction": valid_fraction_log,
|
| 63 |
+
"Permeability": permeability_log
|
| 64 |
+
}
|
| 65 |
+
else:
|
| 66 |
+
log_data = {
|
| 67 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 68 |
+
"Valid Fraction": valid_fraction_log,
|
| 69 |
+
"Binding Affinity": affinity1_log,
|
| 70 |
+
"Solubility": sol_log,
|
| 71 |
+
"Hemolysis": hemo_log,
|
| 72 |
+
"Nonfouling": nf_log
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
df = pd.DataFrame(log_data)
|
| 76 |
+
|
| 77 |
+
# Save to CSV
|
| 78 |
+
df.to_csv(output_path, index=False)
|
| 79 |
+
|
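Hypothetical usage (mirroring the call in generate_valid_mcts below, assuming config.mcts.perm is set; the log values here are made up):

save_logs_to_file(config,
                  valid_fraction_log=[0.42, 0.57],
                  affinity1_log=[0.30, 0.35],
                  affinity2_log=None,
                  sol_log=[0.80, 0.82],
                  hemo_log=[0.90, 0.91],
                  nf_log=[0.70, 0.72],
                  permeability_log=[0.50, 0.55],
                  output_path='benchmarks/demo/log_demo.csv')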
| 80 |
+
def plot_data(log1, log2=None,
|
| 81 |
+
save_path=None,
|
| 82 |
+
label1="Log 1",
|
| 83 |
+
label2=None,
|
| 84 |
+
title="Fraction of Valid Peptides Over Iterations",
|
| 85 |
+
palette=None):
|
| 86 |
+
"""
|
| 87 |
+
Plots one or two datasets with their mean values over iterations.
|
| 88 |
+
|
| 89 |
+
Parameters:
|
| 90 |
+
log1 (list): The first list of mean values for each iteration.
|
| 91 |
+
log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
|
| 92 |
+
save_path (str): Path to save the plot. Defaults to None.
|
| 93 |
+
label1 (str): Label for the first dataset. Defaults to "Log 1".
|
| 94 |
+
label2 (str, optional): Label for the second dataset. Defaults to None.
|
| 95 |
+
title (str): Title of the plot. Defaults to "Mean Values Over Iterations".
|
| 96 |
+
palette (dict, optional): A dictionary defining custom colors for datasets. Defaults to None.
|
| 97 |
+
"""
|
| 98 |
+
# Prepare data for log1
|
| 99 |
+
data1 = pd.DataFrame({
|
| 100 |
+
"Iteration": range(1, len(log1) + 1),
|
| 101 |
+
"Fraction of Valid Peptides": log1,
|
| 102 |
+
"Dataset": label1
|
| 103 |
+
})
|
| 104 |
+
|
| 105 |
+
# Prepare data for log2 if provided
|
| 106 |
+
if log2 is not None:
|
| 107 |
+
data2 = pd.DataFrame({
|
| 108 |
+
"Iteration": range(1, len(log2) + 1),
|
| 109 |
+
"Fraction of Valid Peptides": log2,
|
| 110 |
+
"Dataset": label2
|
| 111 |
+
})
|
| 112 |
+
data = pd.concat([data1, data2], ignore_index=True)
|
| 113 |
+
else:
|
| 114 |
+
data = data1
|
| 115 |
+
|
| 116 |
+
palette = {
|
| 117 |
+
label1: "#8181ED", # Default color for log1
|
| 118 |
+
label2: "#D577FF" # Default color for log2 (if provided)
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
# Set Seaborn theme
|
| 122 |
+
sns.set_theme()
|
| 123 |
+
sns.set_context("paper")
|
| 124 |
+
|
| 125 |
+
# Create the plot
|
| 126 |
+
sns.lineplot(
|
| 127 |
+
data=data,
|
| 128 |
+
x="Iteration",
|
| 129 |
+
y="Fraction of Valid Peptides",
|
| 130 |
+
hue="Dataset",
|
| 131 |
+
style="Dataset",
|
| 132 |
+
markers=True,
|
| 133 |
+
dashes=False,
|
| 134 |
+
palette=palette
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# Titles and labels
|
| 138 |
+
plt.title(title)
|
| 139 |
+
plt.xlabel("Iteration")
|
| 140 |
+
plt.ylabel("Fraction of Valid Peptides")
|
| 141 |
+
|
| 142 |
+
if save_path:
|
| 143 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 144 |
+
print(f"Plot saved to {save_path}")
|
| 145 |
+
plt.show()
|
| 146 |
+
|
| 147 |
+
def plot_data_with_distribution_seaborn(log1, log2=None,
|
| 148 |
+
save_path=None,
|
| 149 |
+
label1=None,
|
| 150 |
+
label2=None,
|
| 151 |
+
title=None):
|
| 152 |
+
"""
|
| 153 |
+
Plots one or two datasets with the average values and distributions over iterations using Seaborn.
|
| 154 |
+
|
| 155 |
+
Parameters:
|
| 156 |
+
log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
|
| 157 |
+
log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
|
| 158 |
+
save_path (str): Path to save the plot. Defaults to None.
|
| 159 |
+
label1 (str): Label for the first dataset. Defaults to "Fraction of Valid Peptide SMILES".
|
| 160 |
+
label2 (str, optional): Label for the second dataset. Defaults to None.
|
| 161 |
+
title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
|
| 162 |
+
"""
|
| 163 |
+
# Prepare data for log1
|
| 164 |
+
data1 = pd.DataFrame({
|
| 165 |
+
"Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
|
| 166 |
+
"Fraction of Valid Peptides": [score for scores in log1 for score in scores],
|
| 167 |
+
"Dataset": label1,
|
| 168 |
+
"Style": "Log1"
|
| 169 |
+
})
|
| 170 |
+
|
| 171 |
+
# Prepare data for log2 if provided
|
| 172 |
+
if log2 is not None:
|
| 173 |
+
data2 = pd.DataFrame({
|
| 174 |
+
"Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
|
| 175 |
+
"Fraction of Valid Peptides": [score for scores in log2 for score in scores],
|
| 176 |
+
"Dataset": label2,
|
| 177 |
+
"Style": "Log2"
|
| 178 |
+
})
|
| 179 |
+
data = pd.concat([data1, data2], ignore_index=True)
|
| 180 |
+
else:
|
| 181 |
+
data = data1
|
| 182 |
+
|
| 183 |
+
palette = {
|
| 184 |
+
label1: "#8181ED", # Default color for log1
|
| 185 |
+
label2: "#D577FF" # Default color for log2 (if provided)
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
# Set Seaborn theme
|
| 189 |
+
sns.set_theme()
|
| 190 |
+
sns.set_context("paper")
|
| 191 |
+
|
| 192 |
+
# Create the plot
|
| 193 |
+
sns.relplot(
|
| 194 |
+
data=data,
|
| 195 |
+
kind="line",
|
| 196 |
+
x="Iteration",
|
| 197 |
+
y="Fraction of Valid Peptides",
|
| 198 |
+
hue="Dataset",
|
| 199 |
+
style="Style",
|
| 200 |
+
markers=True,
|
| 201 |
+
dashes=True,
|
| 202 |
+
ci="sd", # Show standard deviation
|
| 203 |
+
height=5,
|
| 204 |
+
aspect=1.5,
|
| 205 |
+
palette=palette
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# Titles and labels
|
| 209 |
+
plt.title(title)
|
| 210 |
+
plt.xlabel("Iteration")
|
| 211 |
+
plt.ylabel("Fraction of Valid Peptides")
|
| 212 |
+
|
| 213 |
+
if save_path:
|
| 214 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 215 |
+
print(f"Plot saved to {save_path}")
|
| 216 |
+
plt.show()
|
| 217 |
+
|
| 218 |
+
@torch.no_grad()
|
| 219 |
+
def generate_valid_mcts(config, mdlm, prot1=None, prot2=None, filename=None, prot_name1=None, prot_name2=None):
|
| 220 |
+
tokenizer = mdlm.tokenizer
|
| 221 |
+
max_sequence_length = config.sampling.seq_length
|
| 222 |
+
|
| 223 |
+
# generate array of [MASK] tokens
|
| 224 |
+
masked_array = mask_for_de_novo(config, max_sequence_length)
|
| 225 |
+
|
| 226 |
+
if config.vocab == 'old_smiles':
|
| 227 |
+
# use custom encode function
|
| 228 |
+
inputs = tokenizer.encode(masked_array)
|
| 229 |
+
elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
|
| 230 |
+
inputs = tokenizer.encode_for_generation(masked_array)
|
| 231 |
+
else:
|
| 232 |
+
# custom HELM tokenizer
|
| 233 |
+
inputs = tokenizer(masked_array, return_tensors="pt")
|
| 234 |
+
|
| 235 |
+
inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}
|
| 236 |
+
|
| 237 |
+
# initialize root node
|
| 238 |
+
rootNode = Node(config=config, tokens=inputs, timestep=0)
|
| 239 |
+
# initalize tree search algorithm
|
| 240 |
+
|
| 241 |
+
if config.mcts.perm:
|
| 242 |
+
score_func_names = ['permeability', 'binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
|
| 243 |
+
num_func = [0, 50, 50, 50, 50]
|
| 244 |
+
elif config.mcts.dual:
|
| 245 |
+
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'binding_affinity2']
|
| 246 |
+
elif config.mcts.single:
|
| 247 |
+
score_func_names = ['permeability']
|
| 248 |
+
else:
|
| 249 |
+
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
|
| 250 |
+
|
| 251 |
+
if not (config.mcts.time_dependent and config.mcts.perm): # num_func is only set in the perm branch above; default to all-zeros otherwise
|
| 252 |
+
num_func = [0] * len(score_func_names)
|
| 253 |
+
|
| 254 |
+
if prot1 is not None and prot2 is not None:
|
| 255 |
+
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1, prot2], num_func=num_func)
|
| 256 |
+
elif prot1 is not None:
|
| 257 |
+
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1], num_func=num_func)
|
| 258 |
+
elif config.mcts.single:
|
| 259 |
+
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
|
| 260 |
+
else:
|
| 261 |
+
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
|
| 262 |
+
|
| 263 |
+
paretoFront = mcts.forward(rootNode)
|
| 264 |
+
|
| 265 |
+
output_log_path = f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/log_{filename}.csv'
|
| 266 |
+
save_logs_to_file(config, mcts.valid_fraction_log, mcts.affinity1_log, mcts.affinity2_log, mcts.sol_log, mcts.hemo_log, mcts.nf_log, mcts.permeability_log, output_log_path)
|
| 267 |
+
|
| 268 |
+
if config.mcts.single:
|
| 269 |
+
plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
|
| 270 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
|
| 271 |
+
label1="Average Permeability Score",
|
| 272 |
+
title="Average Permeability Score Over Iterations")
|
| 273 |
+
else:
|
| 274 |
+
plot_data(mcts.valid_fraction_log,
|
| 275 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/valid_{filename}.png')
|
| 276 |
+
plot_data_with_distribution_seaborn(log1=mcts.affinity1_log,
|
| 277 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding1_{filename}.png',
|
| 278 |
+
label1="Average Binding Affinity to TfR",
|
| 279 |
+
title="Average Binding Affinity to TfR Over Iterations")
|
| 280 |
+
if config.mcts.dual:
|
| 281 |
+
plot_data_with_distribution_seaborn(log1=mcts.affinity2_log,
|
| 282 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding2_{filename}.png',
|
| 283 |
+
label1="Average Binding Affinity to SKP2",
|
| 284 |
+
title="Average Binding Affinity to SKP2 Over Iterations")
|
| 285 |
+
plot_data_with_distribution_seaborn(log1=mcts.sol_log,
|
| 286 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/sol_{filename}.png',
|
| 287 |
+
label1="Average Solubility Score",
|
| 288 |
+
title="Average Solubility Score Over Iterations")
|
| 289 |
+
plot_data_with_distribution_seaborn(log1=mcts.hemo_log,
|
| 290 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/hemo_{filename}.png',
|
| 291 |
+
label1="Average Hemolysis Score",
|
| 292 |
+
title="Average Hemolysis Score Over Iterations")
|
| 293 |
+
plot_data_with_distribution_seaborn(log1=mcts.nf_log,
|
| 294 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/nf_{filename}.png',
|
| 295 |
+
label1="Average Nonfouling Score",
|
| 296 |
+
title="Average Nonfouling Score Over Iterations")
|
| 297 |
+
if config.mcts.perm:
|
| 298 |
+
plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
|
| 299 |
+
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
|
| 300 |
+
label1="Average Permeability Score",
|
| 301 |
+
title="Average Permeability Score Over Iterations")
|
| 302 |
+
|
| 303 |
+
return paretoFront, inputs
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
|
| 307 |
+
def main(config):
|
| 308 |
+
prot_name1 = "time_dependent"
|
| 309 |
+
prot_name2 = "skp2"
|
| 310 |
+
mode = "2"
|
| 311 |
+
model = "mcts"
|
| 312 |
+
length = "100"
|
| 313 |
+
epoch = "7"
|
| 314 |
+
|
| 315 |
+
filename = f'{mode}_{model}_length_{length}_epoch_{epoch}'
|
| 316 |
+
|
| 317 |
+
if config.vocab == 'new_smiles':
|
| 318 |
+
tokenizer = APETokenizer()
|
| 319 |
+
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
|
| 320 |
+
elif config.vocab == 'old_smiles':
|
| 321 |
+
tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
|
| 322 |
+
'/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
|
| 323 |
+
elif config.vocab == 'selfies':
|
| 324 |
+
tokenizer = APETokenizer()
|
| 325 |
+
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
|
| 326 |
+
elif config.vocab == 'helm':
|
| 327 |
+
tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
|
| 328 |
+
|
| 329 |
+
mdlm = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)
|
| 330 |
+
|
| 331 |
+
mdlm.eval()
|
| 332 |
+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
| 333 |
+
mdlm.to(device)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
print("loaded models...")
|
| 337 |
+
analyzer = PeptideAnalyzer()
|
| 338 |
+
|
| 339 |
+
# proteins
|
| 340 |
+
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
|
| 341 |
+
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
|
| 342 |
+
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
|
| 343 |
+
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
|
| 344 |
+
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
|
| 345 |
+
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
|
| 346 |
+
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
|
| 347 |
+
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
|
| 348 |
+
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
|
| 349 |
+
|
| 350 |
+
paretoFront, input_array = generate_valid_mcts(config, mdlm, gfap, None, filename, prot_name1, None)
|
| 351 |
+
generation_results = []
|
| 352 |
+
|
| 353 |
+
for sequence, v in paretoFront.items():
|
| 354 |
+
generated_array = v['token_ids'].to(mdlm.device)
|
| 355 |
+
|
| 356 |
+
# compute perplexity
|
| 357 |
+
perplexity = mdlm.compute_masked_perplexity(generated_array, input_array['input_ids'])
|
| 358 |
+
perplexity = round(perplexity, 4)
|
| 359 |
+
|
| 360 |
+
aa_seq, seq_length = analyzer.analyze_structure(sequence)
|
| 361 |
+
scores = v['scores']
|
| 362 |
+
|
| 363 |
+
if not config.mcts.single:
|
| 364 |
+
binding1 = scores[0]
|
| 365 |
+
solubility = scores[1]
|
| 366 |
+
hemo = scores[2]
|
| 367 |
+
nonfouling = scores[3]
|
| 368 |
+
|
| 369 |
+
if config.mcts.perm:
|
| 370 |
+
permeability = scores[4]
|
| 371 |
+
generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling, permeability])
|
| 372 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
|
| 373 |
+
elif config.mcts.dual:
|
| 374 |
+
binding2 = scores[4]
|
| 375 |
+
generation_results.append([sequence, perplexity, aa_seq, binding1, binding2, solubility, hemo, nonfouling])
|
| 376 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity 1: {binding1} | Binding Affinity 2: {binding2} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
|
| 377 |
+
elif config.mcts.single:
|
| 378 |
+
permeability = scores[0]
generation_results.append([sequence, perplexity, aa_seq, permeability])
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Permeability: {permeability}")
|
| 379 |
+
else:
|
| 380 |
+
generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling])
|
| 381 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
|
| 382 |
+
|
| 383 |
+
sys.stdout.flush()
|
| 384 |
+
|
| 385 |
+
if config.mcts.perm:
|
| 386 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
|
| 387 |
+
elif config.mcts.dual:
|
| 388 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity 1', 'Binding Affinity 2', 'Solubility', 'Hemolysis', 'Nonfouling'])
|
| 389 |
+
elif config.mcts.single:
|
| 390 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Permeability'])
|
| 391 |
+
else:
|
| 392 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling'])
|
| 393 |
+
|
| 394 |
+
df.to_csv(f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/{filename}.csv', index=False)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
main()
|
generate_unconditional.py
ADDED
|
@@ -0,0 +1,126 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
|
| 8 |
+
from diffusion import Diffusion
|
| 9 |
+
import hydra
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
| 12 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 13 |
+
from helm_tokenizer.helm_tokenizer import HelmTokenizer
|
| 14 |
+
from utils.helm_utils import create_helm_from_aa_seq, get_smi_from_helms
|
| 15 |
+
from utils.filter import PeptideAnalyzer
|
| 16 |
+
from new_tokenizer.ape_tokenizer import APETokenizer
|
| 17 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@torch.no_grad()
|
| 21 |
+
def generate_sequence_unconditional(config, sequence_length: int, mdlm: Diffusion):
|
| 22 |
+
tokenizer = mdlm.tokenizer
|
| 23 |
+
# generate array of [MASK] tokens
|
| 24 |
+
masked_array = mask_for_de_novo(config, sequence_length)
|
| 25 |
+
|
| 26 |
+
if config.vocab == 'old_smiles':
|
| 27 |
+
# use custom encode function
|
| 28 |
+
inputs = tokenizer.encode(masked_array)
|
| 29 |
+
elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
|
| 30 |
+
inputs = tokenizer.encode_for_generation(masked_array)
|
| 31 |
+
else:
|
| 32 |
+
# custom HELM tokenizer
|
| 33 |
+
inputs = tokenizer(masked_array, return_tensors="pt")
|
| 34 |
+
|
| 35 |
+
# tokenized masked array
|
| 36 |
+
inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}
|
| 37 |
+
# sample unconditional array of tokens
|
| 38 |
+
logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
|
| 39 |
+
|
| 40 |
+
return logits, inputs
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
|
| 44 |
+
def main(config):
|
| 45 |
+
path = "/home/st512/peptune/scripts/peptide-mdlm-mcts"
|
| 46 |
+
|
| 47 |
+
if config.vocab == 'new_smiles':
|
| 48 |
+
tokenizer = APETokenizer()
|
| 49 |
+
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
|
| 50 |
+
elif config.vocab == 'old_smiles':
|
| 51 |
+
tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
|
| 52 |
+
'/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
|
| 53 |
+
elif config.vocab == 'selfies':
|
| 54 |
+
tokenizer = APETokenizer()
|
| 55 |
+
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
|
| 56 |
+
elif config.vocab == 'helm':
|
| 57 |
+
tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
|
| 58 |
+
|
| 59 |
+
mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)
|
| 60 |
+
|
| 61 |
+
mdlm_model.eval()
|
| 62 |
+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
| 63 |
+
mdlm_model.to(device)
|
| 64 |
+
|
| 65 |
+
print("loaded models...")
|
| 66 |
+
analyzer = PeptideAnalyzer()
|
| 67 |
+
|
| 68 |
+
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
|
| 69 |
+
|
| 70 |
+
# scoring functions
|
| 71 |
+
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
|
| 72 |
+
score_functions = ScoringFunctions(score_func_names, [gfap])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
max_seq_length = config.sampling.seq_length
|
| 76 |
+
num_sequences = config.sampling.num_sequences
|
| 77 |
+
generation_results = []
|
| 78 |
+
num_valid = 0.
|
| 79 |
+
num_total = 0.
|
| 80 |
+
while num_total < num_sequences:
|
| 81 |
+
num_total += 1
|
| 82 |
+
generated_array, input_array = generate_sequence_unconditional(config, max_seq_length, mdlm_model)
|
| 83 |
+
|
| 84 |
+
# store in device
|
| 85 |
+
generated_array = generated_array.to(mdlm_model.device)
|
| 86 |
+
print(generated_array)
|
| 87 |
+
|
| 88 |
+
# compute masked perplexity
|
| 89 |
+
perplexity = mdlm_model.compute_masked_perplexity(generated_array, input_array['input_ids'])
|
| 90 |
+
perplexity = round(perplexity, 4)
|
| 91 |
+
|
| 92 |
+
if config.vocab == 'old_smiles' or config.vocab == 'new_smiles':
|
| 93 |
+
smiles_seq = tokenizer.decode(generated_array)
|
| 94 |
+
if analyzer.is_peptide(smiles_seq):
|
| 95 |
+
aa_seq, seq_length = analyzer.analyze_structure(smiles_seq)
|
| 96 |
+
num_valid += 1
|
| 97 |
+
scores = score_functions(input_seqs=[smiles_seq])
|
| 98 |
+
|
| 99 |
+
binding = scores[0][0]
|
| 100 |
+
sol = scores[0][1]
|
| 101 |
+
hemo = scores[0][2]
|
| 102 |
+
nf = scores[0][3]
|
| 103 |
+
perm = scores[0][4]
|
| 104 |
+
|
| 105 |
+
generation_results.append([smiles_seq, perplexity, aa_seq, binding, sol, hemo, nf, perm])
|
| 106 |
+
else:
|
| 107 |
+
aa_seq = "not valid peptide"
|
| 108 |
+
seq_length = '-'
|
| 109 |
+
scores = "not valid peptide"
|
| 110 |
+
elif config.vocab == 'selfies':
|
| 111 |
+
smiles_seq = tokenizer.decode(generated_array)
|
| 112 |
+
else:
|
| 113 |
+
aa_seq = tokenizer.decode(generated_array)
|
| 114 |
+
smiles_seq = get_smi_from_helms(aa_seq)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {smiles_seq} | amino acid sequence: {aa_seq} | scores: {scores}")
|
| 118 |
+
sys.stdout.flush()
|
| 119 |
+
|
| 120 |
+
valid_frac = num_valid / num_total
|
| 121 |
+
print(f"fraction of synthesizable peptides: {valid_frac}")
|
| 122 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
|
| 123 |
+
df.to_csv(path + f'/benchmarks/unconditional/epoch-10-pretrain-gfap.csv', index=False)
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
main()
|
main.py
ADDED
|
@@ -0,0 +1,250 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import os
|
| 3 |
+
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
import wandb
|
| 7 |
+
import fsspec
|
| 8 |
+
import hydra
|
| 9 |
+
import lightning as L
|
| 10 |
+
from lightning.pytorch import Trainer
|
| 11 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler
|
| 12 |
+
import omegaconf
|
| 13 |
+
import rich.syntax
|
| 14 |
+
import rich.tree
|
| 15 |
+
import torch
|
| 16 |
+
import sys
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
sys.path.append("/home/st512/peptune/scripts/peptide-mdlm-mcts")
|
| 20 |
+
|
| 21 |
+
import dataset as dataloader
|
| 22 |
+
import dataloading_for_dynamic_batching as dynamic_dataloader
|
| 23 |
+
from diffusion import Diffusion
|
| 24 |
+
import utils.utils as utils
|
| 25 |
+
from new_tokenizer.ape_tokenizer import APETokenizer
|
| 26 |
+
|
| 27 |
+
from lightning.pytorch.strategies import DDPStrategy
|
| 28 |
+
from datasets import load_dataset
|
| 29 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 30 |
+
from helm_tokenizer.helm_tokenizer import HelmTokenizer
|
# wandb.login(key="5a7613c531cb58f9802f3f8e2f73bc4997b917ab")

omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver('eval', eval)
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)


def _load_from_checkpoint(config, tokenizer):
    if 'hf' in config.backbone:
        return Diffusion(
            config, tokenizer=tokenizer).to('cuda')
    else:
        model = Diffusion.load_from_checkpoint(
            config.eval.checkpoint_path,
            tokenizer=tokenizer,
            config=config)

        return model


@L.pytorch.utilities.rank_zero_only
def print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True,
        save_cfg: bool = True) -> None:
    """
    Prints the content of a DictConfig as a tree structure using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
        save_cfg (bool): Whether to save the configuration tree to a file.
    """

    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

    fields = config.keys()
    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, omegaconf.DictConfig):
            branch_content = omegaconf.OmegaConf.to_yaml(
                config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
    rich.print(tree)
    if save_cfg:
        with fsspec.open(
                '{}/config_tree.txt'.format(
                    config.checkpointing.save_dir), 'w') as fp:
            rich.print(tree, file=fp)


@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
    # for dl_type, dl in [
    #         ('train', train_ds), ('valid', valid_ds)]:
    for dl_type, dl in [
            ('train', train_ds)]:
        print(f'Printing {dl_type} dataloader batch.')
        batch = next(iter(dl))
        print('Batch input_ids.shape', batch['input_ids'].shape)
        first = batch['input_ids'][0, :k]
        last = batch['input_ids'][0, -k:]
        print(f'First {k} tokens:', tokenizer.decode(first))
        print('ids:', first)
        print(f'Last {k} tokens:', tokenizer.decode(last))
        print('ids:', last)


def generate_samples(config, logger, tokenizer):
    logger.info('Generating samples.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    # model.gen_ppl_metric.reset()

    # stride_length = config.sampling.stride_length
    # num_strides = config.sampling.num_strides

    for _ in range(config.sampling.num_sample_batches):
        samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
        peptide_sequences = model.tokenizer.batch_decode(samples)
        model.compute_generative_perplexity(peptide_sequences)

        print('Peptide samples:', peptide_sequences)

    print('Generative perplexity:', model.compute_masked_perplexity())

    return peptide_sequences


def ppl_eval(config, logger, tokenizer, data_module):
    logger.info('Starting Zero Shot Eval.')

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    wandb_logger = None
    if config.get('wandb', None) is not None:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    callbacks = []

    if 'callbacks' in config:
        for _, callback in config.callbacks.items():
            callbacks.append(hydra.utils.instantiate(callback))

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    # _, valid_ds = dataloader.get_dataloaders(config, tokenizer, skiptrain=True, valid_seed=config.seed)
    trainer.test(model, data_module)


def _train(config, logger, tokenizer, data_module):
    logger.info('Starting Training.')
    wandb_logger = None

    if config.get('wandb', None) is not None:
        unique_id = str(uuid.uuid4())

        config.wandb.id = f"{config.wandb.id}_{unique_id}"

        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    if (config.checkpointing.resume_from_ckpt
            and config.checkpointing.resume_ckpt_path is not None
            and utils.fsspec_exists(
                config.checkpointing.resume_ckpt_path)):
        ckpt_path = config.checkpointing.resume_ckpt_path
    else:
        ckpt_path = None

    # Lightning callbacks
    callbacks = []
    if 'callbacks' in config:
        for callback_name, callback_config in config.callbacks.items():
            if callback_name == 'model_checkpoint':
                model_checkpoint_config = {k: v for k, v in callback_config.items() if k != '_target_'}
                callbacks.append(ModelCheckpoint(**model_checkpoint_config))
            else:
                callbacks.append(hydra.utils.instantiate(callback_config))

    if config.training.accumulator:
        accumulator = GradientAccumulationScheduler(scheduling={1: 5, 2: 4, 3: 3, 4: 1})
        callbacks.append(accumulator)

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        accelerator='cuda',
        strategy=DDPStrategy(find_unused_parameters=True),
        devices=[2, 3, 4, 5, 6, 7],
        logger=wandb_logger)

    model = Diffusion(config, tokenizer=tokenizer)

    if config.backbone == 'finetune_roformer':
        checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=1-step=24080.ckpt')
        model.load_state_dict(checkpoint['state_dict'])

    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)


@hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
def main(config):
    """
    Main entry point for training.
    """
    wandb.init(project="peptune")
    L.seed_everything(config.seed)

    # print_config(config, resolve=True, save_cfg=True)

    logger = utils.get_logger(__name__)
    # load PeptideCLM tokenizer
    if config.vocab == 'new_smiles':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
    elif config.vocab == 'old_smiles':
        tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                         '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
    elif config.vocab == 'selfies':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
    elif config.vocab == 'helm':
        tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')

    if config.backbone == 'finetune_roformer':
        train_dataset = load_dataset('csv', data_files=config.data.train)
        val_dataset = load_dataset('csv', data_files=config.data.valid)

        train_dataset = train_dataset['train']  # .select(lst)
        val_dataset = val_dataset['train']  # .select(lst)
        data_module = dataloader.CustomDataModule(train_dataset, val_dataset, None, tokenizer, batch_size=config.loader.global_batch_size)
    else:
        data_module = dynamic_dataloader.CustomDataModule('/home/st512/peptune/scripts/peptide-mdlm-mcts/data/smiles/11M_smiles_old_tokenizer_no_limit', tokenizer)

    if config.mode == 'sample_eval':
        generate_samples(config, logger, tokenizer)
    elif config.mode == 'ppl_eval':
        ppl_eval(config, logger, tokenizer, data_module)
    else:
        _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
    main()
noise_schedule.py
ADDED
@@ -0,0 +1,156 @@
import abc

import torch
import torch.nn as nn

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

"""
Adapted from the MDLM GitHub repo.
"""


def get_noise(config, dtype=torch.float32):
    if config.noise.type == 'geometric':
        return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
    elif config.noise.type == 'loglinear':
        return LogLinearNoise()
    elif config.noise.type == 'cosine':
        return CosineNoise()
    elif config.noise.type == 'cosinesqr':
        return CosineSqrNoise()
    elif config.noise.type == 'linear':
        return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
    else:
        raise ValueError(f'{config.noise.type} is not a valid noise')


def binary_discretization(z):
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
    """
    Baseline forward method that returns the total noise and the rate of noise at a timestep.
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)


class CosineNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        # constant rate; t is accepted so Noise.forward can call this uniformly
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and
    ~1 when t varies from 0 to 1. Total noise is
    -log(1 - (1 - eps) * t), so the sigma will be
    (1 - eps) * t.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t


class LogPolyNoise(Noise):
    """
    Log polynomial noise schedule for slower masking of peptide-bond tokens.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # approximate derivative of -log(1 - t^3)
        return ((3 * (t**2)) - self.eps) / (1 - (1 - self.eps) * (t**3))

    def total_noise(self, t):
        # -log(1 - t^3)
        return -torch.log1p(-(1 - self.eps) * (t**3))
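
A quick sanity-check sketch of how these schedules behave (standalone, not part of the committed file; the printed values assume the default eps=1e-3):

import torch
from noise_schedule import LogLinearNoise

noise = LogLinearNoise()
t = torch.linspace(0.0, 1.0, 5)
total, rate = noise(t)  # Noise.forward returns (total_noise(t), rate_noise(t))

# Under MDLM, the masking probability at time t is 1 - exp(-total_noise(t));
# for the log-linear schedule this reduces to (1 - eps) * t.
mask_prob = 1 - torch.exp(-total)
print(mask_prob)  # approximately [0.000, 0.250, 0.500, 0.749, 0.999]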
pareto_mcts.py
ADDED
@@ -0,0 +1,515 @@
import numpy as np
import torch
import torch.nn.functional as F
import random as rd

from diffusion import Diffusion
from scoring.scoring_functions import ScoringFunctions
from utils.filter import PeptideAnalyzer
import noise_schedule

"""
Notes: store rolled-out sequence?
       path of node objects or strings?
       should we only select valid expandable leaf nodes?
       calculate similarity between sibling nodes?
       should we evaluate generated sequences?
"""


class Node:
    """
    Node class: partially unmasked SMILES string
    - parentNode: Node object at the previous time step
    - childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration
    - path: array of partially unmasked SMILES strings leading to the node from the completely masked root node
    - timestep: the time step where the sequence was sampled
    - sampleProb: probability of sampling the sequence from the diffusion model
    """
    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # avoid sharing a mutable default list across Node instances
        self.childNodes = childNodes if childNodes is not None else []
        self.scoreVector = scoreVector

        # initialize total rewards to the reward of the rolled-out unmasked sequence
        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        # set initial visits to 1
        self.visits = 1
        # array of all sequences in path from the root -> node
        # self.path = path
        # set timestep (value between 0 and num_steps)
        self.timestep = timestep
        # set the sampling probability equal to the probability from the reverse posterior
        self.sampleProb = sampleProb

        # dict with 'input_ids' as token array and 'attention_mask'
        self.tokens = tokens

        # self.sequence = sequence

    def selectNode(self, num_func):
        """
        Selects a node to move to among the children nodes.
        """
        # extract the status of the current node
        nodeStatus = self.getExpandStatus()

        # if the node is a legal non-leaf node
        if nodeStatus == 3:
            # initialize dict that will store the select score vectors of each child node
            paretoFront = {}
            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # only append child if it is a legal leaf node (expandable) or a legal non-leaf node
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

            # randomly select a node on the Pareto front
            # selected = rd.choice(paretoFront)
            selected = rd.choice(list(paretoFront.keys()))
            # return selected child node and status
            return selected, selected.getExpandStatus()

        # if node is not a valid non-leaf node
        return self, nodeStatus

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Adds a child node.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1,
                     sampleProb=prob)

        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a descendant leaf node.
        Increments the number of visits to the node.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative rewards vector and number of visits.
        - c: determines the degree of exploration
        - minSelectScore: determines the
        """
        """
        if not self.parentNode:
            return 0.0
        """
        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits
        if self.sampleProb is not None:
            print("Sample Prob")
            print(self.sampleProb)
            # UCT-style exploration bonus scaled by the parent's visit count
            return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(self.parentNode.visits) / self.visits)
        return normRewards

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with M child nodes)
        """
        if self.timestep == self.config.sampling.steps:
            return 1
        elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
            return 2
        return 3

"""END OF NODE CLASS"""

def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Removes sequences that are dominated by scoreVector and
    adds the SMILES sequence with its scoreVector if it is non-dominated.
    """
    paretoSize = len(paretoFront)
    if paretoSize == 0:
        # if the Pareto front is empty, add sequence and scoreVector
        paretoFront[node] = scoreVector
    else:
        # vector of booleans
        # true: sequence is non-dominated by the Pareto-optimal sequence
        # false: sequence is completely dominated by the Pareto-optimal sequence
        nondominate = []
        # sequences to be deleted
        delete = []
        for k, v in paretoFront.items():
            nondominated = scoreVector >= np.asarray(v)
            dominant = scoreVector > np.asarray(v)

            if num_func <= len(nondominated):
                attn_nondominated = nondominated[:num_func]
                attn_dominant = dominant[:num_func]

                # all scores are greater than or equal to v and at least one score is strictly greater than v
                if attn_nondominated.all() and attn_dominant.any():
                    # add the dominated sequence to be deleted
                    delete.append(k)
                    # sequence is dominant
                    nondominate.append(True)
                elif attn_nondominated.all():
                    # sequence is non-dominated
                    nondominate.append(True)
                else:
                    # sequence is completely dominated
                    nondominate.append(False)

        nondominate = np.asarray(nondominate)
        # if the sequence is either dominant or non-dominated by all sequences in the Pareto front -> add to Pareto front
        if nondominate.all():
            paretoFront[node] = scoreVector

        # delete all dominated sequences
        while (paretoSize > 0) and (len(delete) > 0):
            # for k in delete:
            del paretoFront[delete[0]]
            del delete[0]
            paretoSize -= 1
    return paretoFront

"""BEGINNING OF MCTS CLASS"""

class MCTS:
    def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=[], prot_seqs=None, num_func=[]):
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = self.config.time_conditioning
        # dictionary of k (SMILES string) and v (score vector) of Pareto-optimal sequences
        self.peptideParetoFront = {}
        self.num_steps = config.sampling.steps
        self.num_sequences = config.sampling.num_sequences

        # mdlm model
        self.mdlm = mdlm
        self.tokenizer = mdlm.tokenizer
        self.device = mdlm.device

        if max_sequence_length is None:
            self.sequence_length = self.config.sampling.seq_length
        else:
            self.sequence_length = max_sequence_length

        self.num_iter = config.mcts.num_iter

        self.num_child = config.mcts.num_children

        # score functions
        self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
        self.num_func = num_func  # K-dimensional vector with the iteration number to start conditioning on each objective, in increasing order
        self.iter_num = 0
        self.curr_num_func = 1
        self.analyzer = PeptideAnalyzer()

        # track fraction of valid peptides
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

    def reset(self):
        self.iter_num = 0
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []
        self.peptideParetoFront = {}

    def forward(self, rootNode):
        self.reset()

        while self.iter_num < self.num_iter:
            self.iter_num += 1

            # traverse the tree from the root node until a leaf node
            leafNode, _ = self.select(rootNode)
            # print(leafNode.tokens['input_ids'])

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            self.expand(leafNode)

        # return dictionary of Pareto front peptides and their score vectors
        return self.peptideParetoFront

    # change to include more even if dominated? since there is error in the scores
    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Removes sequences that are dominated by scoreVector and
        adds the SMILES sequence with its scoreVector if it is non-dominated.

        num_func: index of the last objective to consider when updating the Pareto front, from 0 to K
        """
        paretoSize = len(self.peptideParetoFront)

        self.curr_num_func = 1

        for i in range(len(self.num_func)):
            if self.iter_num >= self.num_func[i]:
                self.curr_num_func = i + 1

        if paretoSize == 0:
            # if the Pareto front is empty, add sequence and scoreVector
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            # if the Pareto front is empty, set reward vector to 1s
            rewardVector = np.ones(len(scoreVector))
        else:
            # vector of booleans
            # true: sequence is non-dominated by the Pareto-optimal sequence
            # false: sequence is completely dominated by the Pareto-optimal sequence
            nondominate = []
            # sequences to be deleted
            delete = []
            # initialize reward vector with zeros
            rewardVector = np.zeros(len(scoreVector))
            for k, v in self.peptideParetoFront.items():
                # boolean vector
                # true: if all metrics are equal or larger
                # false: if the Pareto front sequence dominates scoreVector
                nondominated = scoreVector >= np.asarray(v['scores'])  # [num_objectives]
                dominant = scoreVector > np.asarray(v['scores'])
                # add to reward vector
                rewardVector += nondominated  # [num_objectives]

                if self.curr_num_func <= len(nondominated):
                    attn_nondominated = nondominated[:self.curr_num_func]
                    attn_dominant = dominant[:self.curr_num_func]

                    # only delete a Pareto-optimal sequence if
                    # all scores are greater than or equal to v and at least one score is strictly greater than v
                    if attn_nondominated.all() and attn_dominant.any():
                        # add the dominated sequence to be deleted
                        delete.append(k)
                        # sequence is dominant
                        nondominate.append(True)
                    elif attn_nondominated.all():
                        # sequence is non-dominated
                        nondominate.append(True)
                    else:
                        # sequence is completely dominated
                        nondominate.append(False)

            assert len(nondominate) == paretoSize
            nondominate = np.asarray(nondominate)
            # if the sequence is either dominant or non-dominated by all sequences in the Pareto front -> add to Pareto front
            # or if the Pareto front does not have enough sequences
            if nondominate.all() or paretoSize < self.num_sequences:
                self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

            rewardVector = rewardVector / paretoSize

            # delete all dominated sequences if the Pareto front is larger than num_sequences
            while (paretoSize > self.num_sequences) and (len(delete) > 0):
                # for k in delete:
                del self.peptideParetoFront[delete[0]]
                del delete[0]
                paretoSize -= 1

        return rewardVector

    def isPathEnd(self, path, maxDepth):
        """
        Checks if the node is completely unmasked (i.e. end of path)
        or if the path is at the max depth.
        """
        if (path[-1] != self.config.mcts.mask_token).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode):
        """
        Traverse the tree from the root node until reaching a legal leaf node.
        """
        while True:
            currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5, checkSimilarity=True):
        """
        Samples unmasking steps from the pre-trained MDLM and
        adds num_children partially unmasked sequences to the children of the parentNode.
        """

        num_children = self.config.mcts.num_children
        # initialize child rewards that will be added to total rewards
        allChildReward = np.zeros_like(parentNode.totalReward)  # (n_objectives)

        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of the parent node to 0
        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        dt = (1 - eps) / self.num_steps
        p_x0_cache = None

        # initialize x and attn_mask
        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device=self.device)
        # generate (n_children, seq_length) array of sampled children nodes
        print("token array:")
        print(x)
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                     t=t, dt=dt,
                                                                     batch_size=num_children,
                                                                     attn_mask=attn_mask)
        x_rollout = x_children

        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device=self.device)

            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)

            if (not torch.allclose(x_next, x) or self.time_conditioning):
                # Disable caching
                p_x0_cache = None

            x_rollout = x_next

        if self.config.sampling.noise_removal:
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)
            """if self.sampler == 'analytic':
                x = self.mdlm._denoiser_update(x, t)
            else:"""
            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)  # (n_children, seq_length)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []
        unmaskedTokens = []
        for i in range(num_children):
            childSeq = childSequences[i]
            # scoreVector = scoreVectors[i]
            rewardVector = np.zeros(self.config.mcts.num_objectives)

            # check if the peptide is valid
            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if len(validSequences) != 0:
            scoreVectors = self.score_functions(input_seqs=validSequences)
            average_scores = scoreVectors.T
            if self.config.mcts.single:
                self.permeability_log.append(average_scores[0])
            else:
                self.affinity1_log.append(average_scores[0])
                self.sol_log.append(average_scores[1])
                self.hemo_log.append(average_scores[2])
                self.nf_log.append(average_scores[3])
                if self.config.mcts.perm:
                    self.permeability_log.append(average_scores[4])
                elif self.config.mcts.dual:
                    self.affinity2_log.append(average_scores[4])
        else:
            self.affinity1_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.sol_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.hemo_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.nf_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

            if self.config.mcts.perm:
                self.permeability_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            elif self.config.mcts.dual:
                self.affinity2_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

        for i, validSeq in enumerate(validSequences):
            # tokens = validTokens[i]
            scoreVector = scoreVectors[i]

            # update Pareto front
            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)

            # add to the all-child reward vector for backprop
            allChildReward += rewardVector

            # create node for the sequence and add it to the children of the parent node
            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        # compute fraction of invalid child sequences
        invalid = (num_children - len(validSequences)) / num_children

        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)
        # subtract a penalty proportional to the fraction of invalid sequences from the reward
        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)
        # backpropagate all child rewards
        self.backprop(parentNode, allChildReward)

    def backprop(self, node, reward_vector):
        # backpropagate rewards through the path leading from the leaf node to the root
        while node:
            node.updateNode(reward_vector)
            node = node.parentNode

    def getSequenceForObjective(self, objective_index, k):
        """
        Returns the top-k sequences in the Pareto front with the best score for
        a given objective, along with their score vectors for all objectives.
        """

        # dictionary of top-k peptides for the objective
        topk = {}

        peptides = []
        objectiveScores = []
        for seq, v in self.peptideParetoFront.items():
            # store peptides in list
            peptides.append(seq)
            # store score for the objective
            objectiveScores.append(v['scores'][objective_index])

        objectiveScores = torch.tensor(objectiveScores)
        topKScores = torch.topk(objectiveScores, k)
        for index in topKScores.indices:
            seq = peptides[index]

            topk[seq] = self.peptideParetoFront.get(seq)

        return topk
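
The dominance bookkeeping used by the Pareto-front updates above can be tried in isolation. A minimal sketch of the same rule (the dominates helper, the toy score vectors, and the pep_* names are illustrative, not part of the file):

import numpy as np

def dominates(a, b):
    """a Pareto-dominates b if a >= b in every objective and > in at least one."""
    a, b = np.asarray(a), np.asarray(b)
    return bool((a >= b).all() and (a > b).any())

front = {'pep_A': np.array([0.2, 0.9]), 'pep_B': np.array([0.6, 0.4])}
candidate = np.array([0.7, 0.5])

# drop members the candidate dominates, then add it if no member dominates it
front = {k: v for k, v in front.items() if not dominates(candidate, v)}
if not any(dominates(v, candidate) for v in front.values()):
    front['pep_C'] = candidate

print(sorted(front))  # ['pep_A', 'pep_C'] -- pep_B was dominated by the candidate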
roformer.py
ADDED
@@ -0,0 +1,74 @@
from transformers import RoFormerConfig, RoFormerForMaskedLM
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch

class Roformer(nn.Module):
    def __init__(self, config, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        # total number of encoder layers (derived from the model rather than hardcoded)
        num_layers = len(self.model.roformer.encoder.layer)

        for i, layer in enumerate(self.model.roformer.encoder.layer):
            # finetune the final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                # unfreeze key weights
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # get logits embeddings
        logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
        # return logits
        # print(logits.logits)
        return logits.logits

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
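
A minimal usage sketch for the wrapper above (standalone; the hyperparameter values are assumptions for illustration, not the training config -- the real values come from the Hydra config):

from types import SimpleNamespace
import torch
from roformer import Roformer
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

# hypothetical hyperparameters for illustration only
config = SimpleNamespace(roformer=SimpleNamespace(
    hidden_size=768, n_layers=8, n_heads=12, max_position_embeddings=1024))
tokenizer = SMILES_SPE_Tokenizer('tokenizer/new_vocab.txt', 'tokenizer/new_splits.txt')

model = Roformer(config, tokenizer)
model.freeze_model()
model.unfreeze_n_layers(2)  # only the Q/K projections of the last 2 layers stay trainable

ids = torch.randint(0, tokenizer.vocab_size, (1, 16))
mask = torch.ones_like(ids)
logits = model(ids, mask)  # shape (1, 16, vocab_size)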
scoring/__init__.py
ADDED
File without changes
scoring/binary_xg.py
ADDED
@@ -0,0 +1,280 @@
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score
import optuna
from optuna.trial import TrialState
import xgboost as xgb
import os
from datasets import load_from_disk
from lightning.pytorch import seed_everything
from rdkit import Chem, rdBase, DataStructs
from typing import List
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score
import seaborn as sns

def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Saves the true and predicted values for the training and validation sets, and generates binary classification plots.

    Parameters:
        y_true_train (array): True labels for the training set.
        y_pred_train (array): Predicted probabilities for the training set.
        y_true_val (array): True labels for the validation set.
        y_pred_val (array): Predicted probabilities for the validation set.
        threshold (float): Classification threshold for predictions.
        output_path (str): Directory to save the CSV files and plots.
    """
    os.makedirs(output_path, exist_ok=True)

    # Convert probabilities to binary predictions
    y_pred_train_binary = (y_pred_train >= threshold).astype(int)
    y_pred_val_binary = (y_pred_val >= threshold).astype(int)

    # Save training predictions
    train_df = pd.DataFrame({
        'True Label': y_true_train,
        'Predicted Probability': y_pred_train,
        'Predicted Label': y_pred_train_binary
    })
    train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)

    # Save validation predictions
    val_df = pd.DataFrame({
        'True Label': y_true_val,
        'Predicted Probability': y_pred_val,
        'Predicted Label': y_pred_val_binary
    })
    val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)

    # Plot training predictions
    plot_boxplot_with_threshold(
        y_true_train,
        y_pred_train,
        threshold,
        title="Training Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'train_classification_plot.png')
    )

    # Plot validation predictions
    plot_boxplot_with_threshold(
        y_true_val,
        y_pred_val,
        threshold,
        title="Validation Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'val_classification_plot.png')
    )

def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')

    # Add threshold line
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    # Add annotations
    plt.title(title)
    plt.xlabel("True Labels")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

def plot_boxplot_with_threshold(y_true, y_pred, threshold, title, output_file):
    """
    Generates a boxplot for binary classification and includes a threshold line.

    Parameters:
        y_true (array): True labels.
        y_pred (array): Predicted probabilities.
        threshold (float): Classification threshold for predictions.
        title (str): Title of the plot.
        output_file (str): Path to save the plot.
    """
    plt.figure(figsize=(10, 8))

    # Combine data into a DataFrame for seaborn
    df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})

    # Boxplot
    sns.boxplot(x='True Label', y='Predicted Probability', data=df)

    # Add threshold line
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
    plt.text(
        x=0.5, y=threshold + 0.05, s=f"Threshold = {threshold}", color="red", fontsize=10
    )

    # Add annotations
    plt.title(title)
    plt.xlabel("True Label")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

def plot_boxplot(y_true, y_pred, title, output_file):
    plt.figure(figsize=(10, 8))

    # Combine data into a single DataFrame for seaborn
    df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})
    sns.boxplot(x='True Label', y='Predicted Probability', data=df)

    # Add annotations
    plt.title(title)
    plt.xlabel("True Label")
    plt.ylabel("Predicted Probability")

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

def plot_binary_correlation_with_density(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot with a density plot for binary classification and saves it to a file.
    """
    plt.figure(figsize=(10, 8))

    # Scatter plot
    plt.scatter(range(len(y_true)), y_pred, alpha=0.5, label='Predicted Probabilities', color='#BC80FF')

    # Add density plot
    sns.kdeplot(y_pred, color='green', fill=True, alpha=0.3, label='Probability Density')

    # Add threshold line
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    # Add annotations
    plt.title(title)
    plt.xlabel("Index")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

seed_everything(42)

dataset = load_from_disk('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_data')

sequences = np.stack(dataset['sequence'])  # Ensure sequences are SMILES strings
labels = np.stack(dataset['labels'])
embeddings = np.stack(dataset['embedding'])

# Initialize best F1 score and model path
best_f1 = -np.inf
best_model_path = "/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_train/"

# Trial callback
def trial_info_callback(study, trial):
    if study.best_trial == trial:
        print(f"Trial {trial.number}:")
        print(f"  Weighted F1 Score: {trial.value}")


def objective(trial):
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1.0),
        'subsample': trial.suggest_float('subsample', 0.1, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'device': 'cuda:0',
    }
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # Split the data
    train_idx, val_idx = train_test_split(
        np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
    )
    train_subset = dataset.select(train_idx).with_format("torch")
    val_subset = dataset.select(val_idx).with_format("torch")

    # Extract embeddings and labels for train/validation
    train_embeddings = train_subset['embedding']
    valid_embeddings = val_subset['embedding']
    train_labels = train_subset['labels']
    valid_labels = val_subset['labels']

    # Prepare training and validation sets
    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Train the model
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict probabilities
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Perform dynamic thresholding on validation predictions
    best_f1_val = -np.inf
    best_threshold = 0.5

    for threshold in np.arange(0.1, 1.0, 0.05):  # Try thresholds from 0.1 to 1.0
        preds_val_binary = (preds_val >= threshold).astype(int)
        f1_temp = f1_score(valid_labels, preds_val_binary, average="weighted")
        if f1_temp > best_f1_val:
            best_f1_val = f1_temp
            best_threshold = threshold

    print(f"Best F1 Score: {best_f1_val:.3f} at Threshold: {best_threshold:.3f}")

    # Calculate AUC for additional insight
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"AUC: {auc_val:.3f}")

    # Save the best model if the F1 score is improved
    if trial.study.user_attrs.get("best_f1", -np.inf) < best_f1_val:
        trial.study.set_user_attr("best_f1", best_f1_val)
        trial.study.set_user_attr("best_threshold", best_threshold)  # Save the best threshold
        os.makedirs(best_model_path, exist_ok=True)

        model.save_model(os.path.join(best_model_path, "best_model.json"))
        print(f"Best model saved to {os.path.join(best_model_path, 'best_model.json')}")

        # Save and plot binary predictions with the best threshold
        save_and_plot_binary_predictions(
            train_labels,
            preds_train,
            valid_labels,
            preds_val,
            best_threshold,
            best_model_path
        )

    return best_f1_val

if __name__ == "__main__":
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    print("Study statistics: ")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Best weighted F1: {study.user_attrs.get('best_f1', None)}")
    for key, value in study.best_trial.params.items():
        print(f"  {key}: {value}")
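
At scoring time, the saved booster can be reloaded through the standard XGBoost API. A sketch (the relative path mirrors best_model_path above; the embedding dimensionality and the 0.5 cutoff are placeholders -- in practice the tuned best_threshold would be used):

import numpy as np
import xgboost as xgb

booster = xgb.Booster()
booster.load_model("scoring/functions/solubility/new_train/best_model.json")

emb = np.random.rand(4, 768).astype(np.float32)  # placeholder peptide embeddings
probs = booster.predict(xgb.DMatrix(emb))
labels = (probs >= 0.5).astype(int)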
scoring/functions/binding.py
ADDED
@@ -0,0 +1,202 @@
import sys
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict
import torch
import pandas as pd
import torch.nn as nn
import esm
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModel
from peft import PeftModel, PeftConfig


class ImprovedBindingPredictor(nn.Module):
    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Define binding thresholds
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM

        # Project to same dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        #protein = protein.transpose(0, 1)
        #smiles = smiles.transpose(0, 1)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES attending to protein
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Get sequence-level representations
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits


class BindingAffinity:
    def __init__(self, prot_seq, model_type='PeptideCLM'):
        super().__init__()

        if model_type == 'PepDoRA':
            # peptide embeddings
            model_name = "ChatterjeeLab/PepDoRA"
            self.pep_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.pep_model = AutoModel.from_pretrained(model_name)

            self.model = ImprovedBindingPredictor(smiles_dim=384)
            checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/binding/best_model_optuna1.pt')
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # peptide embeddings
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
            self.pep_tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                                      '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')

            self.model = ImprovedBindingPredictor()
            checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/binding/best_model.pt')
            self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
        self.prot_tokenizer = alphabet.get_batch_converter()  # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])  # Example with ESM-2
            prot_emb = results["representations"][33]

        self.prot_emb = prot_emb[0]
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)

    def forward(self, input_seqs):
        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)

                emb = self.pep_model(input_ids=pep_tokens['input_ids'],
                                     attention_mask=pep_tokens['attention_mask'],
                                     output_hidden_states=True)

                #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
                pep_emb = emb.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)


def unittest():
    amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
    tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
    glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
    glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
    ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'

    binding = BindingAffinity(tfr)
    seq = ["CC[C@H](C)[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)O"]

    scores = binding(seq)
    print(scores)
    print(len(scores))

if __name__ == '__main__':
    unittest()
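The tight/medium/weak cutoffs (7.5 and 6.0, which on the -log10(M) scale implied by the file's comments correspond to ~30 nM and 1 μM) are what label the classification head. A standalone sketch of the same bucketing, independent of the repo (the function name is illustrative):

import torch

TIGHT, WEAK = 7.5, 6.0  # same cutoffs as ImprovedBindingPredictor

def binding_class(pkd: torch.Tensor) -> torch.Tensor:
    # 0 = tight (>= 7.5), 1 = medium (6.0-7.5), 2 = weak (< 6.0)
    classes = torch.ones_like(pkd, dtype=torch.long)
    classes[pkd >= TIGHT] = 0
    classes[pkd < WEAK] = 2
    return classes

print(binding_class(torch.tensor([8.1, 6.9, 5.2])))  # tensor([0, 1, 2])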
scoring/functions/binding_utils.py
ADDED
@@ -0,0 +1,291 @@
from torch import nn
import torch.nn.functional as F  # required by the F.softmax calls below; missing in the original upload
import pdb
import torch
import numpy as np

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x

class MultiHeadAttentionSequence(nn.Module):

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)
        attention = F.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output = self.W_O(output)
        output = self.dropout(output)
        output = self.layer_norm(output + q)

        return output, attention

class MultiHeadAttentionReciprocal(nn.Module):

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)
        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
        self.W_O_2 = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)
        V_2 = V_2.transpose(1, 2)

        # one score matrix is shared by both directions: softmax over rows for q->k, over its transpose for k->q
        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)
        attention_2 = attention.transpose(-2, -1)

        attention = F.softmax(attention, dim=-1)
        attention_2 = F.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)
        output = self.layer_norm(output + q)
        output_2 = self.dropout(output_2)
        output_2 = self.layer_norm(output_2 + k)

        return output, output_2, attention, attention_2


class FFN(nn.Module):

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = self.layer_1(x.transpose(1, 2))
        output = self.relu(output)
        output = self.layer_2(output)
        output = self.dropout(output)
        output = self.layer_norm(output.transpose(1, 2) + residual)
        return output

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out


class DilatedCNN(nn.Module):
    def __init__(self, d_model, d_hidden):
        super(DilatedCNN, self).__init__()
        self.first_ = nn.ModuleList()
        self.second_ = nn.ModuleList()
        self.third_ = nn.ModuleList()

        dilation_tuple = (1, 2, 3)
        dim_in_tuple = (d_model, d_hidden, d_hidden)
        dim_out_tuple = (d_hidden, d_hidden, d_hidden)

        for i, dilation_rate in enumerate(dilation_tuple):
            self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
                                         dilation=dilation_rate))

        for i, dilation_rate in enumerate(dilation_tuple):
            self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
                                          dilation=dilation_rate))

        for i, dilation_rate in enumerate(dilation_tuple):
            self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
                                         dilation=dilation_rate))

    def forward(self, protein_seq_enc):
        # pdb.set_trace()
        protein_seq_enc = protein_seq_enc.transpose(1, 2)  # protein_seq_enc's shape: B*L*d_model -> B*d_model*L

        first_embedding = protein_seq_enc
        second_embedding = protein_seq_enc
        third_embedding = protein_seq_enc

        for i in range(len(self.first_)):
            first_embedding = self.first_[i](first_embedding)

        for i in range(len(self.second_)):
            second_embedding = self.second_[i](second_embedding)

        for i in range(len(self.third_)):
            third_embedding = self.third_[i](third_embedding)

        # pdb.set_trace()
        protein_seq_enc = first_embedding + second_embedding + third_embedding

        return protein_seq_enc.transpose(1, 2)


class ReciprocalLayerwithCNN(nn.Module):

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        self.cnn = DilatedCNN(d_model, d_hidden)
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        # pdb.set_trace()  # protein_seq_enc.shape = B * L * d_model
        protein_seq_enc = self.cnn(protein_seq_enc)
        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)

        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention


class ReciprocalLayer(nn.Module):

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)

        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
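MultiHeadAttentionReciprocal computes the q-to-k score matrix once and reuses its transpose for the k-to-q direction, so the two softmaxes normalize different axes of the same logits. A tiny numeric sketch of that weight sharing (shapes and seeds are illustrative only):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 4
Q = torch.randn(1, 1, 3, d_k)   # (batch, head, len_q, d_k)
K = torch.randn(1, 1, 5, d_k)   # (batch, head, len_k, d_k)

logits = Q @ K.transpose(-2, -1) / d_k ** 0.5               # (1, 1, 3, 5), computed once
attn_q_to_k = F.softmax(logits, dim=-1)                     # each row sums to 1 over len_k
attn_k_to_q = F.softmax(logits.transpose(-2, -1), dim=-1)   # each row sums to 1 over len_q

print(attn_q_to_k.sum(-1))  # all ones, shape (1, 1, 3)
print(attn_k_to_q.sum(-1))  # all ones, shape (1, 1, 5)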
scoring/functions/nonfouling.py
ADDED
@@ -0,0 +1,68 @@
import sys
import os
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import warnings
from rdkit import Chem, rdBase, DataStructs


rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Nonfouling:

    def __init__(self):
        self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/nonfouling/new_data/best_model.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                              '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        # return the predicted probability of the peptide being nonfouling
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    nf = Nonfouling()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
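Nonfouling shares one scaffold with the Solubility and Hemolysis scorers later in this commit: mean-pool a frozen PeptideCLM encoder over the SMILES tokens, sanitize the features, and feed them to a pre-trained xgb.Booster. A minimal sketch of just the boosting half, with random features standing in for real embeddings (the helper name, checkpoint path, and 768-wide embedding are assumptions for illustration):

import numpy as np
import xgboost as xgb

def score_embeddings(embeddings: np.ndarray, booster: xgb.Booster) -> np.ndarray:
    """Sanitize pooled embeddings and return booster predictions."""
    feats = np.nan_to_num(embeddings, nan=0.)
    feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
    return booster.predict(xgb.DMatrix(feats))

# usage (hypothetical checkpoint path):
# booster = xgb.Booster(model_file="best_model.json")
# print(score_embeddings(np.random.randn(4, 768).astype(np.float32), booster))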
scoring/functions/permeability.py
ADDED
@@ -0,0 +1,168 @@
import sys
import os
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import warnings
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
from typing import List
from scoring.functions.transformation import TransformFunction


rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def fingerprints_from_smiles(smiles: List, size=2048):
    """ Create ECFP fingerprints of smiles, with validity check """
    fps = []
    valid_mask = []
    for i, smile in enumerate(smiles):
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    fps = np.concatenate(fps, axis=0)
    return fps, valid_mask


def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """ Create ECFP fingerprint of a molecule """
    if hashed:
        fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)

def getMolDescriptors(mol, missingVal=0):
    """ calculate the full list of descriptors for a molecule """

    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except:
            val = missingVal
        values.append(val)
        names.append(nm)

    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except:
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names

def get_pep_dps_from_smi(smi):
    try:
        mol = Chem.MolFromSmiles(smi)
    except:
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)


def get_pep_dps(smi_list):
    if len(smi_list) == 0:
        return np.zeros((0, 213))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])

def check_smi_validity(smiles: list):
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
            if mol:
                valid_smi.append(smi)
                valid_idx.append(idx)
        except Exception as e:
            # logger.debug(f'Error: {e} in smiles {smi}')
            pass
    return valid_smi, valid_idx

class Permeability:

    def __init__(self):
        self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-train/best_model.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt', '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        #valid_smiles, valid_idxes = check_smi_validity(input_seqs)

        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            fingerprints = np.empty((len(input_seqs), 0))  # was torch.empty; an empty numpy block concatenates cleanly below

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)
        # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')

        features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)

        return features

    def get_scores(self, input_seqs: list):
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    permeability = Permeability()
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
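get_features concatenates up to three blocks per peptide: 2048-bit ECFP fingerprints, RDKit descriptors, and the pooled language-model embedding; with the default dps=False, fps=False, only the embedding block survives. A quick shape check of that concatenation on dummy data (the 768-wide embedding and exact descriptor count are illustrative assumptions):

import numpy as np

n = 3
fingerprints = np.zeros((n, 2048))   # ECFP block (fps=True)
descriptors = np.zeros((n, 213))     # RDKit descriptor block (dps=True)
embeddings = np.zeros((n, 768))      # pooled PeptideCLM block (always on)

features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
print(features.shape)  # (3, 3029)

# default path: the empty (n, 0) blocks drop out of the concatenation
features_default = np.concatenate([np.empty((n, 0)), np.empty((n, 0)), embeddings], axis=1)
print(features_default.shape)  # (3, 768)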
scoring/functions/permeability_xg.py
ADDED
@@ -0,0 +1,176 @@
import pandas as pd
import numpy as np
import optuna
from optuna.trial import TrialState
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import xgboost as xgb
import os
from datasets import load_from_disk
from scipy.stats import spearmanr
import matplotlib.pyplot as plt


def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
    os.makedirs(output_path, exist_ok=True)

    # Save training predictions
    train_df = pd.DataFrame({'True Permeability': y_true_train, 'Predicted Permeability': y_pred_train})
    train_df.to_csv(os.path.join(output_path, 'train_predictions.csv'), index=False)

    # Save validation predictions
    val_df = pd.DataFrame({'True Permeability': y_true_val, 'Predicted Permeability': y_pred_val})
    val_df.to_csv(os.path.join(output_path, 'val_predictions.csv'), index=False)

    # Plot training predictions
    plot_correlation(
        y_true_train,
        y_pred_train,
        title="Training Set Correlation Plot",
        output_file=os.path.join(output_path, 'train_correlation.png'),
    )

    # Plot validation predictions
    plot_correlation(
        y_true_val,
        y_pred_val,
        title="Validation Set Correlation Plot",
        output_file=os.path.join(output_path, 'val_correlation.png'),
    )

def plot_correlation(y_true, y_pred, title, output_file):
    spearman_corr, _ = spearmanr(y_true, y_pred)

    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')

    # Add annotations
    plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
    plt.xlabel("True Permeability (logP)")
    plt.ylabel("Predicted Permeability (logP)")  # fixed label: this axis shows predicted permeability, not affinity
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

# Load dataset
dataset = load_from_disk('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-data/')

# Extract sequences, labels, and embeddings
sequences = np.stack(dataset['sequence'])
labels = np.stack(dataset['labels'])  # Regression labels
embeddings = np.stack(dataset['embedding'])  # Pre-trained embeddings

# Function to compute Morgan fingerprints
def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
    fps = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            fps.append(np.array(fp))
        else:
            # If the SMILES string is invalid, use a zero vector
            fps.append(np.zeros(n_bits))
            print(f"Invalid SMILES: {smiles}")
    return np.array(fps)

# Compute Morgan fingerprints for the sequences
#morgan_fingerprints = compute_morgan_fingerprints(sequences)

# Concatenate embeddings with Morgan fingerprints
#input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
input_features = embeddings

# Initialize global variables
best_model_path = "/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-train"
os.makedirs(best_model_path, exist_ok=True)

def trial_info_callback(study, trial):
    if study.best_trial == trial:
        print(f"Trial {trial.number}:")
        print(f"  MSE: {trial.value}")

def objective(trial):
    # Define hyperparameters
    params = {
        'objective': 'reg:squarederror',
        'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 0, 5),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.5, 10.0, log=True),  # note: scale_pos_weight is intended for binary classification objectives
        'device': 'cuda:6',
    }
    """params = {
        'objective': 'reg:squarederror',
        'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2),
        'max_depth': trial.suggest_int('max_depth', 4, 20),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'device': 'cuda:6',
    }"""
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # Train-validation split
    X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)

    # Convert data to DMatrix
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_val, label=y_val)

    # Train XGBoost
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict and evaluate
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    mse = mean_squared_error(y_val, preds_val)

    # Calculate Spearman Rank Correlation
    spearman_corr, _ = spearmanr(y_val, preds_val)
    print(f"Spearman Rank Correlation: {spearman_corr}")

    # Save the best model
    if trial.study.user_attrs.get("best_mse", np.inf) > mse:
        trial.study.set_user_attr("best_mse", mse)
        trial.study.set_user_attr("best_spearman", spearman_corr)  # Save the Spearman correlation
        model.save_model(os.path.join(best_model_path, "best_model.json"))
        save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)

    return mse

if __name__ == "__main__":
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])

    # Print study statistics
    print("Study statistics: ")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Best trial value (MSE): {study.best_trial.value}")
    print(f"  Best Spearman Correlation: {study.user_attrs.get('best_spearman', None)}")  # print the best Spearman correlation
    for key, value in study.best_trial.params.items():
        print(f"  {key}: {value}")
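Both Optuna scripts in this commit persist their running best through study.user_attrs rather than relying only on Optuna's built-in best_value, which lets them stash extra artifacts (best threshold, Spearman correlation) alongside the objective. A minimal sketch of that bookkeeping pattern with a toy objective (all names illustrative):

import optuna

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    loss = (x - 2) ** 2
    # custom bookkeeping in user_attrs, mirroring the scripts above
    if trial.study.user_attrs.get("best_mse", float("inf")) > loss:
        trial.study.set_user_attr("best_mse", loss)
        trial.study.set_user_attr("best_x", x)
    return loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)
print(study.user_attrs["best_mse"], study.user_attrs["best_x"])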
scoring/functions/scoring_utils.py
ADDED
@@ -0,0 +1,111 @@
import warnings
import numpy as np
from loguru import logger
from sklearn.ensemble import RandomForestRegressor
from rdkit.Chem import Descriptors, rdMolDescriptors
import joblib
from transformation import TransformFunction
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
from typing import List


def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """
    Create ECFP fingerprint of a molecule
    """
    if hashed:
        fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)


def fingerprints_from_smiles(smiles: List, size=2048):
    """ Create ECFP fingerprints of smiles, with validity check """
    fps = []
    valid_mask = []
    for i, smile in enumerate(smiles):
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
    return fps, valid_mask


def getMolDescriptors(mol, missingVal=0):
    """ calculate the full list of descriptors for a molecule """

    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except:
            val = missingVal
        values.append(val)
        names.append(nm)

    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except:
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names


def get_pep_dps_from_smi(smi):
    try:
        mol = Chem.MolFromSmiles(smi)
    except:
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)


def get_pep_dps(smi_list):
    if len(smi_list) == 0:
        return np.zeros((0, 211))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])


"""def get_smi_from_helms(helm_seqs: list):
    valid_idxes = []
    valid_smiles = []

    for idx, helm in enumerate(helm_seqs):
        # Ignore helm which cannot converted into molecules
        try:
            smi = get_cycpep_smi_from_helm(helm)
            if smi:
                valid_idxes.append(idx)
                valid_smiles.append(smi)
        except Exception as e:
            # logger.debug(f'Error: {e} in helm {helm}')
            pass
    return valid_smiles, valid_idxes"""


def check_smi_validity(smiles: list):
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
            if mol:
                valid_smi.append(smi)
                valid_idx.append(idx)
        except Exception as e:
            # logger.debug(f'Error: {e} in smiles {smi}')
            pass
    return valid_smi, valid_idx
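Note that the hard-coded empty-input widths disagree between the two copies of get_pep_dps (211 here versus 213 in permeability.py), and len(Descriptors._descList) itself varies across RDKit releases. A small sketch that derives the width at runtime instead of hard-coding it (a suggested hardening, not the repo's code):

from rdkit.Chem import Descriptors

# getMolDescriptors appends 3 custom Lipinski/rotatable-bond descriptors
N_DESCRIPTORS = len(Descriptors._descList) + 3
print(N_DESCRIPTORS)  # e.g. 211 or 213 depending on the installed RDKit

# an empty batch can then be shaped consistently:
# return np.zeros((0, N_DESCRIPTORS))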
scoring/functions/solubility.py
ADDED
@@ -0,0 +1,68 @@
import sys
import os
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import warnings
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
from typing import List
from scoring.functions.transformation import TransformFunction


rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Solubility:
    def __init__(self):
        self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_train/best_model.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
                                              '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    solubility = Solubility()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)

if __name__ == '__main__':
    unittest()
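Solubility is the same scaffold as Nonfouling above (only the XGBoost checkpoint path differs), so the score_embeddings sketch after nonfouling.py applies here unchanged; Hemolysis below differs only in how it post-processes the booster output.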
scoring/hemolysis.py
ADDED
@@ -0,0 +1,71 @@
import sys
import os
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import xgboost as xgb
import torch
import numpy as np
import warnings
from transformers import AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
from typing import List
from scoring.functions.transformation import TransformFunction


rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Hemolysis:

    def __init__(self):
        self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/hemolysis/new_train/best_model.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt', '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # mean pooling across the sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability that each peptide is *not* hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


def unittest():
    hemo = Hemolysis()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
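generate_embeddings mean-pools over every position of one unpadded sequence at a time. If sequences were ever batched with padding, the pooled vector should be weighted by the attention mask so [PAD] positions don't dilute it. A sketch of that variant (an assumption layered on the code above, not something the file itself does):

import torch

def masked_mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
    mask = attention_mask.unsqueeze(-1).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero
    return summed / counts

For unpadded single sequences this reduces exactly to the .mean(dim=1) used above.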
scoring/scoring_functions.py
ADDED
@@ -0,0 +1,104 @@
import sys
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import io
import subprocess
import warnings
import numpy as np
import pandas as pd
from typing import List
from loguru import logger
from tqdm import tqdm
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
import torch
from scoring.functions.binding.binding import BindingAffinity
from scoring.functions.permeability.permeability import Permeability
from scoring.functions.solubility.solubility import Solubility
from scoring.functions.hemolysis.hemolysis import Hemolysis
from scoring.functions.nonfouling.nonfouling import Nonfouling


class ScoringFunctions:
    def __init__(self, score_func_names=None, prot_seqs=[]):
        """
        Class for generating score vectors for generated sequences.

        Args:
            score_func_names: list of scoring function names to be evaluated
            prot_seqs: sequences of the target binder proteins (up to two)
        """
        if score_func_names is None:
            # just do unmasking based on validity of peptide bonds
            self.score_func_names = []
        else:
            self.score_func_names = score_func_names

        # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)

        # binding affinities
        self.target_protein = prot_seqs
        print(len(prot_seqs))

        if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0])
            binding_affinity2 = None
        elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0])
            binding_affinity2 = BindingAffinity(prot_seqs[1])
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability()
        sol = Solubility()
        nonfouling = Nonfouling()
        hemo = Hemolysis()

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        scores = []

        for i, score_func in enumerate(self.score_func_names):
            score = self.all_funcs[score_func](input_seqs=input_seqs)

            scores.append(score)

        # convert to a numpy array with shape (num_sequences, num_functions)
        scores = np.float32(scores).T

        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)


def unittest():
    amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
    tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
    glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
    glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
    ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
    cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'

    num_iter = 0
    score_func_times = [0, 1, 2, 3, 4, 5]

    scoring = ScoringFunctions(score_func_names=['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'], prot_seqs=[tfr])

    smiles = ['N2[C@H](CC(C)C)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](Cc1ccccc1C(F)(F)F)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](CC(=O)N)C2(=O)']

    scores = scoring(input_seqs=smiles)
    print(scores)
    print(len(scores))


if __name__ == '__main__':
    unittest()
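forward stacks one score array per objective and transposes, so a caller receives a (num_sequences, num_functions) matrix in which every column is higher-is-better. As one illustration of how such a matrix can be consumed downstream, a Pareto filter can keep only non-dominated sequences; this sketch is illustrative only and is not the repository's own selection logic:

import numpy as np

def non_dominated_mask(scores):
    # scores: (num_sequences, num_functions), higher is better on every axis
    n = scores.shape[0]
    keep = np.ones(n, dtype=bool)
    for i in range(n):
        for j in range(n):
            # row i is dominated if some row j is >= everywhere and > somewhere
            if i != j and np.all(scores[j] >= scores[i]) and np.any(scores[j] > scores[i]):
                keep[i] = False
                break
    return keep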
tokenizer/__init__.py
ADDED
File without changes
tokenizer/my_tokenizers.py
ADDED
@@ -0,0 +1,441 @@
import collections
import logging
import os
import re
import codecs
import unicodedata
from typing import List, Optional
from transformers import PreTrainedTokenizer
from SmilesPE.tokenizer import SPE_Tokenizer
import torch

# save_vocabulary() below relies on these two names; without them it raises NameError.
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


class Atomwise_Tokenizer(object):
    """Run atom-level SMILES tokenization"""

    def __init__(self):
        """Constructs an atom-level tokenizer."""
        # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Basic tokenization of a SMILES string."""
        tokens = [token for token in self.regex.findall(text)]
        return tokens


class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        self.vocab = load_vocab(vocab_file)
        self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        token_ids = []
        token_ids.append(2)
        for token in token_array:
            id = self._convert_token_to_id(token)
            token_ids.append(id)
        token_ids.append(3)
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # stop decoding when token ID 3 ([SEP]) is encountered
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token = self._convert_id_to_token(idx)
            token_array.append(token)
        sequence = "".join(token_array)
        return sequence

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        sequences = []
        for token_ids in batch_token_ids:
            sequences.append(self.decode(token_ids))
        return sequences

    def get_token_split(self, token_ids):
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        token_array = []
        for seq_ids in token_ids:
            seq_array = []
            for id in seq_ids:
                token = self._convert_id_to_token(id)
                seq_array.append(token)
            token_array.append(seq_array)

        return token_array

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs an atom-level SMILES tokenizer. Adapted from SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
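Because encode and decode were overridden, the round trip differs from the stock PreTrainedTokenizer API: encode takes an already-split token array and returns tensors, and decode stops at the hard-coded [SEP] id. A small usage sketch against the vocab files shipped in this commit:

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

tok = SMILES_SPE_Tokenizer('tokenizer/new_vocab.txt', 'tokenizer/new_splits.txt')
tokens = tok._tokenize('CC(=O)NCC(=O)O')    # SPE tokens as a list of strings
batch = tok.encode(tokens)                  # {'input_ids': tensor([[2, ..., 3]]), 'attention_mask': ...}
smiles = tok.decode(batch['input_ids'])     # stops at id 3 ([SEP]) and re-joins the tokens
print(tokens, smiles)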
tokenizer/new_splits.txt
ADDED
@@ -0,0 +1,159 @@
c 1
c 2
c 3
c 4
c 5
c 6
c 7
c 8
c 9
( c1
( c2
c1 )
c2 )
n 1
n 2
n 3
n 4
n 5
n 6
n 7
n 8
n 9
( n1
( n2
n1 )
n2 )
O 1
O 2
O 3
O 4
O 5
O 6
O 7
O 8
O 9
( O1
( O2
O2 )
O2 )
= O
= C
= c
= N
= n
=C C
=C N
=C c
=c c
=N C
=N c
=n C
=n c
# N
# C
#N C
#C C
#C N
#N N
( C
C )
( O
O )
( N
N )
Br c
( =O
(=O )
C (=O)
C =O
C =N
C #N
C #C
C C
CC C
CC N
CC O
CC S
CC c
CC n
C N
CN C
CN c
C O
CO C
CO N
CO c
C S
CS C
CS S
CS c
C c
Cl c
C n
F c
N C
NC C
NC c
N N
N O
N c
N n
O C
OC C
OC O
OC c
O N
O O
O c
S C
SC C
SC c
S S
S c
c c
cc c
cc n
cc o
cc s
cc cc
c n
cn c
cn n
c o
co c
c s
cs c
cs n
n c
nc c
nc n
nc o
nc s
n n
nn c
nn n
n o
no c
no n
n s
ns c
ns n
o c
oc c
o n
s c
sc c
sc n
s n
N P
P N
C P
P C
N S
S N
C S
S C
S P
P S
C I
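Each line of new_splits.txt is one SMILES Pair Encoding merge rule: the two space-separated tokens are fused into a single vocabulary token, and SmilesPE applies the rules in rank order (earlier lines first). As a toy illustration of a single merge pass, independent of SmilesPE's real implementation:

def apply_merge(tokens, pair):
    # fuse every adjacent occurrence of `pair` into one token
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out

tokens = ['N', 'C', 'C', '(', '=O', ')', 'O']
for pair in [('(', '=O'), ('(=O', ')'), ('C', '(=O)')]:   # rules 66-68 above
    tokens = apply_merge(tokens, pair)
print(tokens)   # ['N', 'C', 'C(=O)', 'O']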
tokenizer/new_vocab.txt
ADDED
@@ -0,0 +1,587 @@
[PAD]
[UNK]
[CLS]
[SEP]
[MASK]
#
%
(
)
+
-
/
0
1
2
3
4
5
6
7
8
9
=
@
A
B
Br
Brc
C
CC
CCC
CCN
CCO
CCS
CCc
CCn
CN
CNC
CNc
CO
COC
CON
COc
CS
CSC
CSS
CSc
Cc
Cl
Clc
Cn
F
Fc
H
I
K
L
M
N
NC
NCC
NCc
NN
NO
Nc
Nn
O
OC
OCC
OCO
OCc
ON
OO
Oc
P
R
S
SC
SCC
SCc
SS
Sc
T
X
Z
[
\\
(/
]
a
b
c
cc
ccc
cccc
ccn
cco
ccs
cn
cnc
cnn
co
coc
cs
csc
csn
e
g
i
l
n
nc
ncc
ncn
nco
ncs
nn
nnc
nnn
no
noc
non
ns
nsc
nsn
o
oc
occ
on
p
r
s
sc
scc
scn
sn
t
c1
c2
c3
c4
c5
c6
c7
c8
c9
n1
n2
n3
n4
n5
n6
n7
n8
n9
O1
O2
O3
O4
O5
O6
O7
O8
O9
(c1
(c2
c1)
c2)
(n1
(n2
n1)
n2)
(O1
(O2
O2)
=O
=C
=c
=N
=n
=CC
=CN
=Cc
=cc
=NC
=Nc
=nC
=nc
#C
#CC
#CN
#N
#NC
#NN
(C
C)
(O
O)
(N
N)
NP
PN
CP
PC
NS
SN
SP
PS
C(=O)
(/Br)
(/C#N)
(/C)
(/C=N)
(/C=O)
(/CBr)
(/CC)
(/CCC)
(/CCF)
(/CCN)
(/CCO)
(/CCl)
(/CI)
(/CN)
(/CO)
(/CS)
(/Cl)
(/F)
(/I)
(/N)
(/NC)
(/NCC)
(/NO)
(/O)
(/OC)
(/OCC)
(/S)
(/SC)
(=C)
(=C/C)
(=C/F)
(=C/I)
(=C/N)
(=C/O)
(=CBr)
(=CC)
(=CCF)
(=CCN)
(=CCO)
(=CCl)
(=CF)
(=CI)
(=CN)
(=CO)
(=C\\C)
(=C\\F)
(=C\\I)
(=C\\N)
(=C\\O)
(=N)
(=N/C)
(=N/N)
(=N/O)
(=NBr)
(=NC)
(=NCC)
(=NCl)
(=NN)
(=NO)
(=NOC)
(=N\\C)
(=N\\N)
(=N\\O)
(=O)
(=S)
(B)
(Br)
(C#C)
(C#CC)
(C#CI)
(C#CO)
(C#N)
(C#SN)
(C)
(C=C)
(C=CF)
(C=CI)
(C=N)
(C=NN)
(C=NO)
(C=O)
(C=S)
(CBr)
(CC#C)
(CC#N)
(CC)
(CC=C)
(CC=O)
(CCBr)
(CCC)
(CCCC)
(CCCF)
(CCCI)
(CCCN)
(CCCO)
(CCCS)
(CCCl)
(CCF)
(CCI)
(CCN)
(CCNC)
(CCNN)
(CCNO)
(CCO)
(CCOC)
(CCON)
(CCS)
(CCSC)
(CCl)
(CF)
(CI)
(CN)
(CN=O)
(CNC)
(CNCC)
(CNCO)
(CNN)
(CNNC)
(CNO)
(CNOC)
(CO)
(COC)
(COCC)
(COCI)
(COCN)
(COCO)
(COF)
(CON)
(COO)
(CS)
(CSC)
(CSCC)
(CSCF)
(CSO)
(Cl)
(F)
(I)
(N)
(N=N)
(N=NO)
(N=O)
(N=S)
(NBr)
(NC#N)
(NC)
(NC=N)
(NC=O)
(NC=S)
(NCBr)
(NCC)
(NCCC)
(NCCF)
(NCCN)
(NCCO)
(NCCS)
(NCCl)
(NCNC)
(NCO)
(NCS)
(NCl)
(NN)
(NN=O)
(NNC)
(NO)
(NOC)
(O)
(OC#N)
(OC)
(OC=C)
(OC=O)
(OC=S)
(OCBr)
(OCC)
(OCCC)
(OCCF)
(OCCI)
(OCCN)
(OCCO)
(OCCS)
(OCCl)
(OCF)
(OCI)
(OCO)
(OCOC)
(OCON)
(OCSC)
(OCl)
(OI)
(ON)
(OO)
(OOC)
(OOCC)
(OOSN)
(OSC)
(P)
(S)
(SC#N)
(SC)
(SCC)
(SCCC)
(SCCF)
(SCCN)
(SCCO)
(SCCS)
(SCCl)
(SCF)
(SCN)
(SCOC)
(SCSC)
(SCl)
(SI)
(SN)
(SN=O)
(SO)
(SOC)
(SOOO)
(SS)
(SSC)
(SSCC)
([At])
([O-])
([O])
([S-])
(\\Br)
(\\C#N)
(\\C)
(\\C=N)
(\\C=O)
(\\CBr)
(\\CC)
(\\CCC)
(\\CCO)
(\\CCl)
(\\CF)
(\\CN)
(\\CNC)
(\\CO)
(\\COC)
(\\Cl)
(\\F)
(\\I)
(\\N)
(\\NC)
(\\NCC)
(\\NN)
(\\NO)
(\\NOC)
(\\O)
(\\OC)
(\\OCC)
(\\ON)
(\\S)
(\\SC)
(\\SCC)
[Ag+]
[Ag-4]
[Ag]
[Al-3]
[Al]
[As+]
[AsH3]
[AsH]
[As]
[At]
[B-]
[B@-]
[B@@-]
[BH-]
[BH2-]
[BH3-]
[B]
[Ba]
[Br+2]
[BrH]
[Br]
[C+]
[C-]
[C@@H]
[C@@]
[C@H]
[C@]
[CH-]
[CH2]
[CH3]
[CH]
[C]
[CaH2]
[Ca]
[Cl+2]
[Cl+3]
[Cl+]
[Cs]
[FH]
[F]
[H]
[He]
[I+2]
[I+3]
[I+]
[IH]
[I]
[K]
[Kr]
[Li+]
[LiH]
[MgH2]
[Mg]
[N+]
[N-]
[N@+]
[N@@+]
[N@@]
[N@]
[NH+]
[NH-]
[NH2+]
[NH3]
[NH]
[N]
[Na]
[O+]
[O-]
[OH+]
[OH2]
[OH]
[O]
[P+]
[P@+]
[P@@+]
[P@@]
[P@]
[PH2]
[PH]
[P]
[Ra]
[Rb]
[S+]
[S-]
[S@+]
[S@@+]
[S@@]
[S@]
[SH+]
[SH2]
[SH]
[S]
[Se+]
[Se-2]
[SeH2]
[SeH]
[Se]
[Si@]
[SiH2]
[SiH]
[Si]
[SrH2]
[TeH]
[Te]
[Xe]
[Zn+2]
[Zn-2]
[Zn]
[b-]
[c+]
[c-]
[cH-]
[cH]
[c]
[n+]
[n-]
[nH]
[n]
[o+]
[s+]
[se+]
[se]
[te+]
[te]
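The first five entries of new_vocab.txt are the special tokens, so with 0-indexed ids [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4; this is exactly why SMILES_SPE_Tokenizer.encode hard-codes 2 at the start and 3 at the end of every sequence. A quick consistency check:

with open('tokenizer/new_vocab.txt', encoding='utf-8') as f:
    vocab = [line.rstrip('\n') for line in f]
assert vocab[2] == '[CLS]' and vocab[3] == '[SEP]'   # ids used by encode()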
utils/__pycache__/app.cpython-39.pyc
ADDED
Binary file (26.3 kB)

utils/__pycache__/filter.cpython-39.pyc
ADDED
Binary file (7.35 kB)

utils/__pycache__/generate_utils.cpython-39.pyc
ADDED
Binary file (2.78 kB)

utils/__pycache__/helm_utils.cpython-39.pyc
ADDED
Binary file (12.4 kB)

utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (9.48 kB)
utils/app.py
ADDED
@@ -0,0 +1,1255 @@
import os
import re
import pandas as pd
from io import StringIO, BytesIO
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tempfile

class PeptideAnalyzer:
    def __init__(self):
        self.bond_patterns = [
            (r'OC\(=O\)', 'ester'),                  # Ester bond
            (r'N\(C\)C\(=O\)', 'n_methyl'),          # N-methylated peptide bond
            (r'N[0-9]C\(=O\)', 'proline'),           # Proline peptide bond
            (r'NC\(=O\)', 'peptide'),                # Standard peptide bond
            (r'C\(=O\)N\(C\)', 'n_methyl_reverse'),  # Reverse N-methylated
            (r'C\(=O\)N[12]?', 'peptide_reverse')    # Reverse peptide bond
        ]
        # Three- to one-letter code mapping
        self.three_to_one = {
            'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
            'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
            'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
            'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
            'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
        }

    def is_peptide(self, smiles):
        """Check if the SMILES represents a peptide structure"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False

        # Look for peptide bonds: NC(=O) pattern
        peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
        if mol.HasSubstructMatch(peptide_bond_pattern):
            return True

        # Look for N-methylated peptide bonds: N(C)C(=O) pattern
        n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
        if mol.HasSubstructMatch(n_methyl_pattern):
            return True

        return False

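    # Quick sanity check for the two SMARTS above (a minimal sketch; the
    # molecules below are illustrative, not taken from this repo):
    #   analyzer = PeptideAnalyzer()
    #   analyzer.is_peptide('CC(=O)NC')   # True - contains an [NH][C](=O) amide
    #   analyzer.is_peptide('CCO')        # False - no amide bond
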
    def is_cyclic(self, smiles):
        """Improved cyclic peptide detection"""
        # Check for C-terminal carboxyl
        if smiles.endswith('C(=O)O'):
            return False, [], []

        # Find all numbers used in ring closures
        ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)

        # Find aromatic ring numbers
        aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
        aromatic_cycles = []
        for match in aromatic_matches:
            numbers = re.findall(r'[0-9]', match)
            aromatic_cycles.extend(numbers)

        # Numbers that aren't part of aromatic rings are peptide cycles
        peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]

        is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
        return is_cyclic, peptide_cycles, aromatic_cycles

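    # Illustrative example (hedged - behaviour inferred from the regexes above):
    # for a diketopiperazine such as cyclo(Gly-Gly), SMILES 'O=C1CNC(=O)CN1',
    # the ring-closure digit '1' is not part of an aromatic ring, so it should
    # be counted as a peptide cycle and the structure reported as cyclic.
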
    def split_on_bonds(self, smiles):
        """Split SMILES into segments with simplified Pro handling"""
        positions = []
        used = set()

        # Find Gly pattern first
        gly_pattern = r'NCC\(=O\)'
        for match in re.finditer(gly_pattern, smiles):
            if not any(p in range(match.start(), match.end()) for p in used):
                positions.append({
                    'start': match.start(),
                    'end': match.end(),
                    'type': 'gly',
                    'pattern': match.group()
                })
                used.update(range(match.start(), match.end()))

        for pattern, bond_type in self.bond_patterns:
            for match in re.finditer(pattern, smiles):
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        # Sort by position
        positions.sort(key=lambda x: x['start'])

        # Create segments
        segments = []

        if positions:
            # First segment
            if positions[0]['start'] > 0:
                segments.append({
                    'content': smiles[0:positions[0]['start']],
                    'bond_after': positions[0]['pattern']
                })

            # Process segments
            for i in range(len(positions)-1):
                current = positions[i]
                next_pos = positions[i+1]

                if current['type'] == 'gly':
                    segments.append({
                        'content': 'NCC(=O)',
                        'bond_before': positions[i-1]['pattern'] if i > 0 else None,
                        'bond_after': next_pos['pattern']
                    })
                else:
                    content = smiles[current['end']:next_pos['start']]
                    if content:
                        segments.append({
                            'content': content,
                            'bond_before': current['pattern'],
                            'bond_after': next_pos['pattern']
                        })

            # Last segment
            if positions[-1]['end'] < len(smiles):
                segments.append({
                    'content': smiles[positions[-1]['end']:],
                    'bond_before': positions[-1]['pattern']
                })

        return segments

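    # Shape of the data this method returns (illustrative values only): each
    # segment is a plain dict holding the residue SMILES in 'content' plus the
    # backbone-bond patterns that flank it, e.g.
    #   {'content': '[C@@H](C)', 'bond_before': 'NC(=O)', 'bond_after': 'C(=O)N'}
    # 'bond_before' is absent for the first segment and 'bond_after' for the last.
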
    def clean_terminal_carboxyl(self, segment):
        """Remove C-terminal carboxyl only if it's the true terminus"""
        content = segment['content']

        # Only clean if:
        # 1. Contains C(=O)O
        # 2. No bond_after exists (meaning it's the last segment)
        # 3. C(=O)O is at the end of the content
        if 'C(=O)O' in content and not segment.get('bond_after'):
            # Strip the parenthesised (C(=O)O) group; note the regex only
            # matches the parenthesised form, despite the broader check above
            cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
            # Remove any leftover empty parentheses
            cleaned = re.sub(r'\(\)', '', cleaned)
            return cleaned
        return content

    def identify_residue(self, segment):
        """Identify residue with Pro reconstruction"""
        # Only clean terminal carboxyl if this is the last segment
        content = self.clean_terminal_carboxyl(segment)
        mods = self.get_modifications(segment)

        # UAA pattern-matching section - before regular residues
        # Phenylglycine and derivatives
        if 'c1ccccc1' in content:
            if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
                return '4', mods  # Base phenylglycine

        # 4-substituted phenylalanines
        if 'Cc1ccc' in content:
            if 'OMe' in content or 'OCc1ccc' in content:
                return '0A1', mods  # 4-methoxy-phenylalanine
            elif 'Clc1ccc' in content:
                return '200', mods  # 4-chloro-phenylalanine
            elif 'Brc1ccc' in content:
                return '4BF', mods  # 4-bromo-phenylalanine
            elif 'C#Nc1ccc' in content:
                return '4CF', mods  # 4-cyano-phenylalanine
            elif 'Ic1ccc' in content:
                return 'PHI', mods  # 4-iodo-phenylalanine
            elif 'Fc1ccc' in content:
                return 'PFF', mods  # 4-fluoro-phenylalanine

        # Modified tryptophans
        if 'c[nH]c2' in content:
            if 'Oc2cccc2' in content:
                return '0AF', mods  # 7-hydroxy-tryptophan
            elif 'Fc2cccc2' in content:
                return '4FW', mods  # 4-fluoro-tryptophan
            elif 'Clc2cccc2' in content:
                return '6CW', mods  # 6-chloro-tryptophan
            elif 'Brc2cccc2' in content:
                return 'BTR', mods  # 6-bromo-tryptophan
            elif 'COc2cccc2' in content:
                return 'MOT5', mods  # 5-methoxy-tryptophan
            elif 'Cc2cccc2' in content:
                return 'MTR5', mods  # 5-methyl-tryptophan

        # Special amino acids
        if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
            return 'BUG', mods  # tert-leucine

        if 'CCCNC(=N)N' in content:
            return 'CIR', mods  # citrulline

        if '[SeH]' in content:
            return 'CSE', mods  # selenocysteine

        if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
            return 'DAB', mods  # diaminobutyric acid

        if 'C1CCCCC1' in content:
            if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
                return 'CHG', mods  # cyclohexylglycine
            elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
                return 'ALC', mods  # 3-cyclohexyl-alanine

        # Naphthalene derivatives
        if 'c1cccc2c1cccc2' in content:
            if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
                return 'NAL', mods  # 2-naphthyl-alanine

        # Heteroaromatic derivatives
        if 'c1cncc' in content:
            return 'PYR4', mods  # 3-(4-pyridyl)-alanine
        if 'c1cscc' in content:
            return 'THA3', mods  # 3-(3-thienyl)-alanine
        if 'c1nnc' in content:
            return 'TRZ4', mods  # 3-(1,2,4-triazol-1-yl)-alanine

        # Modified serines and threonines
        if 'OP(O)(O)O' in content:
            if '[C@@H](COP' in content or '[C@H](COP' in content:
                return 'SEP', mods  # phosphoserine
            elif '[C@@H](OP' in content or '[C@H](OP' in content:
                return 'TPO', mods  # phosphothreonine

        # Specialized ring systems
        if 'c1c2ccccc2cc2c1cccc2' in content:
            return 'ANTH', mods  # 3-(9-anthryl)-alanine
        if 'c1csc2c1cccc2' in content:
            return 'BTH3', mods  # 3-(3-benzothienyl)-alanine
        if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
            return 'ADAM', mods  # adamantane

        # Fluorinated derivatives
        if 'FC(F)(F)' in content:
            if 'CC(F)(F)F' in content:
                return 'FLA', mods  # trifluoro-alanine
        if 'C(F)(F)F)c1' in content:
            if 'c1ccccc1C(F)(F)F' in content:
                return 'TFG2', mods  # 2-(trifluoromethyl)-phenylglycine
            if 'c1cccc(c1)C(F)(F)F' in content:
                return 'TFG3', mods  # 3-(trifluoromethyl)-phenylglycine
            if 'c1ccc(cc1)C(F)(F)F' in content:
                return 'TFG4', mods  # 4-(trifluoromethyl)-phenylglycine

        # Multiple halogen patterns
        if 'F' in content and 'c1' in content:
            if 'c1ccc(c(c1)F)F' in content:
                return 'F2F', mods  # 3,4-difluoro-phenylalanine
            if 'cc(F)cc(c1)F' in content:
                return 'WFP', mods  # 3,5-difluoro-phenylalanine
        if 'Cl' in content and 'c1' in content:
            if 'c1ccc(cc1Cl)Cl' in content:
                return 'CP24', mods  # 2,4-dichloro-phenylalanine
            if 'c1ccc(c(c1)Cl)Cl' in content:
                return 'CP34', mods  # 3,4-dichloro-phenylalanine

        # Hydroxy and amino derivatives
        if 'O' in content and 'c1' in content:
            if 'c1cc(O)cc(c1)O' in content:
                return '3FG', mods  # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
            if 'c1ccc(c(c1)O)O' in content:
                return 'DAH', mods  # 3,4-dihydroxy-phenylalanine

        # Cyclic amino acids
        if 'C1CCCC1' in content:
            return 'CPA3', mods  # 3-cyclopentyl-alanine
        if 'C1CCCCC1' in content:
            if 'CC1CCCCC1' in content:
                return 'ALC', mods  # 3-cyclohexyl-alanine
            else:
                return 'CHG', mods  # cyclohexylglycine

        # Chain-length variants
        if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
            return 'NLE', mods  # norleucine
        if 'CC[C@@H]' in content or 'CC[C@H]' in content:
            if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
                return 'ABA', mods  # 2-aminobutyric acid

        # Modified histidines
        if 'c1cnc' in content:
            if '[C@@H]1CN[C@@H](N1)F' in content:
                return '2HF', mods  # 2-fluoro-l-histidine
            if 'c1cnc([nH]1)F' in content:
                return '2HF1', mods  # 2-fluoro-l-histidine variant
            if 'c1c[nH]c(n1)F' in content:
                return '2HF2', mods  # 2-fluoro-l-histidine variant

        # Sulfur and selenium containing
        if '[SeH]' in content:
            return 'CSE', mods  # selenocysteine
        if 'S' in content:
            if 'CSCc1ccccc1' in content:
                return 'BCS', mods  # benzylcysteine
            if 'CCSC' in content:
                return 'ESC', mods  # ethionine
            if 'CCS' in content:
                return 'HCS', mods  # homocysteine

        # Additional modifications
        if 'CN=[N]=N' in content:
            return 'AZDA', mods  # azido-alanine
        if '[NH]=[C](=[NH2])=[NH2]' in content:
            if 'CCC[NH]=' in content:
                return 'AGM', mods  # 5-methyl-arginine
            if 'CC[NH]=' in content:
                return 'GDPR', mods  # 2-amino-3-guanidinopropionic acid

        if 'CCON' in content:
            return 'CAN', mods  # canaline
        if '[C@@H]1C=C[C@@H](C=C1)' in content:
            return 'ACZ', mods  # cis-amiclenomycin
        if 'CCC(=O)[NH3]' in content:
            return 'ONL', mods  # 5-oxo-l-norleucine
        if 'c1ccncc1' in content:
            return 'PYR4', mods  # 3-(4-pyridyl)-alanine
        if 'c1ccco1' in content:
            return 'FUA2', mods  # (2-furyl)-alanine

        if 'c1ccc' in content:
            if 'c1ccc(cc1)c1ccccc1' in content:
                return 'BIF', mods  # 4,4-biphenylalanine
            if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
                return 'PBF', mods  # 4-benzoyl-phenylalanine
            if 'c1ccc(cc1)C(C)(C)C' in content:
                return 'TBP4', mods  # 4-tert-butyl-phenylalanine
            if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
                return '0BN', mods  # 4-carbamimidoyl-l-phenylalanine
            if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
                return 'APM', mods  # m-amidinophenyl-3-alanine

        # Multiple hydroxy patterns
        if 'O' in content:
            if '[C@H]([C@H](C)O)O' in content:
                return 'ILX', mods  # 4,5-dihydroxy-isoleucine
            if '[C@H]([C@@H](C)O)O' in content:
                return 'ALO', mods  # allo-threonine
            if '[C@H](COP(O)(O)O)' in content:
                return 'SEP', mods  # phosphoserine
            if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
                return 'TPO', mods  # phosphothreonine
            if '[C@H](c1ccc(O)cc1)O' in content:
                return 'OMX', mods  # (betar)-beta-hydroxy-l-tyrosine
            if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
                return 'OMY', mods  # (betar)-3-chloro-beta-hydroxy-l-tyrosine

        # Heterocyclic patterns
        if 'n1' in content:
            if 'n1cccn1' in content:
                return 'PYZ1', mods  # 3-(1-pyrazolyl)-alanine
            if 'n1nncn1' in content:
                return 'TEZA', mods  # 3-(2-tetrazolyl)-alanine
            if 'c2c(n1)cccc2' in content:
                return 'QU32', mods  # 3-(2-quinolyl)-alanine
            if 'c1cnc2c(c1)cccc2' in content:
                return 'QU33', mods  # 3-(3-quinolyl)-alanine
            if 'c1ccnc2c1cccc2' in content:
                return 'QU34', mods  # 3-(4-quinolyl)-alanine
            if 'c1ccc2c(c1)nccc2' in content:
                return 'QU35', mods  # 3-(5-quinolyl)-alanine
            if 'c1ccc2c(c1)cncc2' in content:
                return 'QU36', mods  # 3-(6-quinolyl)-alanine
            if 'c1cnc2c(n1)cccc2' in content:
                return 'QX32', mods  # 3-(2-quinoxalyl)-alanine

        # Multiple nitrogen patterns
        if 'N' in content:
            if '[NH3]CC[C@@H]' in content:
                return 'DAB', mods  # diaminobutyric acid
            if '[NH3]C[C@@H]' in content:
                return 'DPP', mods  # 2,3-diaminopropanoic acid
            if '[NH3]CCCCCC[C@@H]' in content:
                return 'HHK', mods  # (2s)-2,8-diaminooctanoic acid
            if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
                return 'GBUT', mods  # 2-amino-4-guanidinobutyric acid
            if '[NH]=[C](=S)=[NH2]' in content:
                return 'THIC', mods  # thio-citrulline

        # Chain-modified amino acids
        if 'CC' in content:
            if 'CCCC[C@@H]' in content:
                return 'AHP', mods  # 2-aminoheptanoic acid
            if 'CCC([C@@H])(C)C' in content:
                return 'I2M', mods  # 3-methyl-l-alloisoleucine
            if 'CC[C@H]([C@@H])C' in content:
                return 'IIL', mods  # allo-isoleucine
            if '[C@H](CCC(C)C)' in content:
                return 'HLEU', mods  # homoleucine
            if '[C@@H]([C@@H](C)O)C' in content:
                return 'HLU', mods  # beta-hydroxyleucine

        # Modified glutamate/aspartate patterns
        if '[C@@H]' in content:
            if '[C@@H](C[C@@H](F))' in content:
                return 'FGA4', mods  # 4-fluoro-glutamic acid
            if '[C@@H](C[C@@H](O))' in content:
                return '3GL', mods  # 4-hydroxy-glutamic acid
            if '[C@@H](C[C@H](C))' in content:
                return 'LME', mods  # (3r)-3-methyl-l-glutamic acid
            if '[C@@H](CC[C@H](C))' in content:
                return 'MEG', mods  # (3s)-3-methyl-l-glutamic acid

        # Sulfur and selenium modifications
        if 'S' in content:
            if 'SCC[C@@H]' in content:
                return 'HSER', mods  # homoserine
            if 'SCCN' in content:
                return 'SLZ', mods  # thialysine
            if 'SC(=O)' in content:
                return 'CSA', mods  # s-acetonylcysteine
            if '[S@@](=O)' in content:
                return 'SME', mods  # methionine sulfoxide
            if 'S(=O)(=O)' in content:
                return 'OMT', mods  # methionine sulfone

        # Double-bond containing
        if 'C=' in content:
            if 'C=C[C@@H]' in content:
                return '2AG', mods  # 2-allyl-glycine
            if 'C=C[C@@H]' in content:  # NOTE: identical to the branch above, so unreachable as written
                return 'LVG', mods  # vinylglycine
            if 'C=Cc1ccccc1' in content:
                return 'STYA', mods  # styrylalanine

        # Special cases
        if '[C@@H]1Cc2c(C1)cccc2' in content:
            return 'IGL', mods  # alpha-amino-2-indanacetic acid
        if '[C](=[C](=O)=O)=O' in content:
            return '26P', mods  # 2-amino-6-oxopimelic acid
        if '[C](=[C](=O)=O)=C' in content:
            return '2NP', mods  # l-2-amino-6-methylene-pimelic acid
        if 'c2cnc[nH]2' in content:
            return 'HIS', mods  # histidine core
        if 'c1cccc2c1cc(O)cc2' in content:
            return 'NAO1', mods  # 5-hydroxy-1-naphthalene
        if 'c1ccc2c(c1)cc(O)cc2' in content:
            return 'NAO2', mods  # 6-hydroxy-2-naphthalene

        # Proline (P) - flexible ring numbers
        if any([
            # Check for any ring number in bond patterns
            (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
             any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
            for n in '123456789'
        ]) or any([
            # Check ending patterns with any ring number
            (f'CCCN{n}' in content and content.endswith('=O') and
             any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
            for n in '123456789'
        ]) or any([
            # Handle CCC[C@H]n patterns
            (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
            (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
            # N-terminal Pro with any ring number
            (f'N{n}CCC[C@H]{n}' in content) or
            (f'N{n}CCC[C@@H]{n}' in content)
            for n in '123456789'
        ]):
            return 'Pro', mods

        # Tryptophan (W) - more specific indole pattern
        if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
           'c[nH]c' in content.replace(' ', ''):
            return 'Trp', mods

        # Lysine (K) - both patterns
        if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
            return 'Lys', mods

        # Arginine (R) - both patterns
        if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
            return 'Arg', mods

        if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
            return 'Nle', mods

        # Ornithine (Orn) - 3-carbon chain with NH2
        if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
            return 'Orn', mods

        # 2-Naphthylalanine (2Nal) - distinct from the Phe pattern
        if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return '2Nal', mods

        # Cyclohexylalanine (Cha)
        if 'N2CCCCC2' in content or 'CCCCC2' in content:
            return 'Cha', mods

        # Aminobutyric acid (Abu) - 2-carbon chain
        if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
            return 'Abu', mods

        # Pipecolic acid (Pip) - 6-membered ring like Pro
        if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Pip', mods

        # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
        if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
            return 'Chg', mods

        # 4-Fluorophenylalanine (4F-Phe)
        if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return '4F-Phe', mods

        # Regular residue identification
        if ('NCC(=O)' in content) or (content == 'C'):
            # Middle case - between bonds
            if segment.get('bond_before') and segment.get('bond_after'):
                if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
                    return 'Gly', mods
            # Terminal case - at the end
            elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
                return 'Gly', mods

        if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
            return 'Leu', mods
        if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
            return 'Leu', mods

        if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
            return 'Thr', mods

        if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
            return 'Phe', mods

        if ('[C@H](C(C)C)' in content or      # With outer parentheses
            '[C@@H](C(C)C)' in content or     # With outer parentheses
            '[C@H]C(C)C' in content or        # Without outer parentheses
            '[C@@H]C(C)C' in content):        # Without outer parentheses
            if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']):  # Still check it's not Leu
                return 'Val', mods

        if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
            return 'O-tBu', mods

        if any([
            'CC[C@H](C)' in content,
            'CC[C@@H](C)' in content,
            'C(C)C[C@H]' in content and 'CC(C)C' not in content,
            'C(C)C[C@@H]' in content and 'CC(C)C' not in content
        ]):
            return 'Ile', mods

        if ('[C@H](C)' in content or '[C@@H](C)' in content):
            if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
                return 'Ala', mods

        # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
        if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
            return 'Tyr', mods

        # Serine (Ser) - hydroxymethyl side chain
        if '[C@H](CO)' in content or '[C@@H](CO)' in content:
            if not ('C(C)O' in content or 'COC' in content):
                return 'Ser', mods

        # Threonine (Thr) - 1-hydroxyethyl side chain
        if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
            return 'Thr', mods

        # Cysteine (Cys) - thiol side chain
        if '[C@H](CS)' in content or '[C@@H](CS)' in content:
            return 'Cys', mods

        # Methionine (Met) - methylthioethyl side chain
        if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
            return 'Met', mods

        # Asparagine (Asn) - carbamoylmethyl side chain
        if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Asn', mods

        # Glutamine (Gln) - carbamoylethyl side chain
        if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Gln', mods

        # Aspartic acid (Asp) - carboxymethyl side chain
        if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Asp', mods

        # Glutamic acid (Glu) - carboxyethyl side chain
        if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Glu', mods

        # Arginine (Arg) - 3-guanidinopropyl side chain
        if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'Arg', mods

        # Histidine (His) - imidazole side chain
        if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
            return 'His', mods

        return None, mods

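    # Note on ordering: identify_residue is a first-match-wins cascade, so the
    # more specific unnatural-amino-acid substrings are tested before the
    # canonical residues; reordering the checks would change the assignments.
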
    def get_modifications(self, segment):
        """Get modifications based on bond types"""
        mods = []
        if segment.get('bond_after'):
            if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
                mods.append('N-Me')
            if 'OC(=O)' in segment['bond_after']:
                mods.append('O-linked')
        return mods

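    # Example (illustrative): a segment whose bond_after is 'C(=O)N(C)' yields
    # ['N-Me']; one whose bond_after is 'OC(=O)' yields ['O-linked'].
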
    def analyze_structure(self, smiles):
        """Main analysis function with debug output"""
        print("\nAnalyzing structure:", smiles)

        # Split into segments
        segments = self.split_on_bonds(smiles)

        print("\nSegment Analysis:")
        sequence = []
        for i, segment in enumerate(segments):
            print(f"\nSegment {i}:")
            print(f"Content: {segment['content']}")
            print(f"Bond before: {segment.get('bond_before', 'None')}")
            print(f"Bond after: {segment.get('bond_after', 'None')}")

            residue, mods = self.identify_residue(segment)
            if residue:
                if mods:
                    sequence.append(f"{residue}({','.join(mods)})")
                else:
                    sequence.append(residue)
                print(f"Identified as: {residue}")
                print(f"Modifications: {mods}")
            else:
                print(f"Warning: Could not identify residue in segment: {segment['content']}")

        # Check if cyclic
        is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
        three_letter = '-'.join(sequence)
        one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)

        if is_cyclic:
            three_letter = f"cyclo({three_letter})"
            one_letter = f"cyclo({one_letter})"

        print(f"\nFinal sequence: {three_letter}")
        print(f"One-letter code: {one_letter}")
        print(f"Is cyclic: {is_cyclic}")

        return three_letter, len(segments)

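    # Example of the result (illustrative): a cyclic tripeptide would be
    # reported as three_letter = 'cyclo(Ala-Gly-Phe)' with one-letter code
    # 'cyclo(AGF)', and the method returns that string with the segment count.
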
    def return_sequence(self, smiles):
        """Like analyze_structure, but returns the residue list instead of the joined sequence"""
        print("\nAnalyzing structure:", smiles)

        # Split into segments
        segments = self.split_on_bonds(smiles)

        print("\nSegment Analysis:")
        sequence = []
        for i, segment in enumerate(segments):
            print(f"\nSegment {i}:")
            print(f"Content: {segment['content']}")
            print(f"Bond before: {segment.get('bond_before', 'None')}")
            print(f"Bond after: {segment.get('bond_after', 'None')}")

            residue, mods = self.identify_residue(segment)
            if residue:
                if mods:
                    sequence.append(f"{residue}({','.join(mods)})")
                else:
                    sequence.append(residue)
                print(f"Identified as: {residue}")
                print(f"Modifications: {mods}")
            else:
                print(f"Warning: Could not identify residue in segment: {segment['content']}")

        return sequence

"""
def annotate_cyclic_structure(mol, sequence):
    '''Create annotated 2D structure with clear, non-overlapping residue labels'''
    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    # Create drawer with larger size for annotations
    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)  # Even larger size

    # Get residue list and reverse it to match the structural representation
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')
    residues = list(reversed(residues))  # Reverse the sequence

    # Draw molecule first to get its bounds
    drawer.drawOptions().addAtomIndices = False
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()

    # Convert to PIL Image
    img = Image.open(BytesIO(drawer.GetDrawingText()))
    draw = ImageDraw.Draw(img)

    try:
        # Try to use DejaVuSans as it's commonly available on Linux systems
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
    except OSError:
        try:
            # Fall back to Arial if available (common on Windows)
            font = ImageFont.truetype("arial.ttf", 60)
            small_font = ImageFont.truetype("arial.ttf", 60)
        except OSError:
            # If no TrueType fonts are available, fall back to the default
            print("Warning: TrueType fonts not available, using default font")
            font = ImageFont.load_default()
            small_font = ImageFont.load_default()
    # Get molecule bounds
    conf = mol.GetConformer()
    positions = []
    for i in range(mol.GetNumAtoms()):
        pos = conf.GetAtomPosition(i)
        positions.append((pos.x, pos.y))

    x_coords = [p[0] for p in positions]
    y_coords = [p[1] for p in positions]
    min_x, max_x = min(x_coords), max(x_coords)
    min_y, max_y = min(y_coords), max(y_coords)

    # Calculate scaling factors
    scale = 150      # Increased scale factor
    center_x = 1000  # Image center
    center_y = 1000

    # Add residue labels in a circular arrangement around the structure
    n_residues = len(residues)
    radius = 700  # Distance of labels from center

    # Start from the rightmost point (3 o'clock position) and go counterclockwise
    offset = 0  # Adjust this value to match the structure alignment
    for i, residue in enumerate(residues):
        # Calculate position on a circle around the structure,
        # starting from 0 (3 o'clock) and going counterclockwise
        angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)

        # Calculate label position
        label_x = center_x + radius * np.cos(angle)
        label_y = center_y + radius * np.sin(angle)

        # Draw residue label
        text = f"{i+1}. {residue}"
        bbox = draw.textbbox((label_x, label_y), text, font=font)
        padding = 10
        draw.rectangle([bbox[0]-padding, bbox[1]-padding,
                        bbox[2]+padding, bbox[3]+padding],
                       fill='white', outline='white')
        draw.text((label_x, label_y), text,
                  font=font, fill='black', anchor="mm")

    # Add sequence at the top with white background
    seq_text = f"Sequence: {sequence}"
    bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
    padding = 10
    draw.rectangle([bbox[0]-padding, bbox[1]-padding,
                    bbox[2]+padding, bbox[3]+padding],
                   fill='white', outline='white')
    draw.text((center_x, 100), seq_text,
              font=small_font, fill='black', anchor="mm")

    return img

"""
def annotate_cyclic_structure(mol, sequence):
    """Create structure visualization with just the sequence header"""
    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    # Create drawer with larger size for annotations
    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)

    # Draw molecule first
    drawer.drawOptions().addAtomIndices = False
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()

    # Convert to PIL Image
    img = Image.open(BytesIO(drawer.GetDrawingText()))
    draw = ImageDraw.Draw(img)
    try:
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
    except OSError:
        try:
            small_font = ImageFont.truetype("arial.ttf", 60)
        except OSError:
            print("Warning: TrueType fonts not available, using default font")
            small_font = ImageFont.load_default()

    # Add just the sequence header at the top
    seq_text = f"Sequence: {sequence}"
    bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
    padding = 10
    draw.rectangle([bbox[0]-padding, bbox[1]-padding,
                    bbox[2]+padding, bbox[3]+padding],
                   fill='white', outline='white')
    draw.text((1000, 100), seq_text,
              font=small_font, fill='black', anchor="mm")

    return img

def create_enhanced_linear_viz(sequence, smiles):
    """Create an enhanced linear representation using PeptideAnalyzer"""
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using the analyzer
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i in range(num_residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            segment = segments[i] if i < len(segments) else None
            if segment:
                # Determine bond type from segment info
                bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
                is_n_methylated = 'N-Me' in segment.get('bond_after', '')

                bond_color = 'red' if bond_type == 'ester' else 'black'
                linestyle = '--' if bond_type == 'ester' else '-'

                # Draw bond line
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label
        ax_struct.text(x_pos, y_pos-0.5, residues[i],
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown
    segment_y = len(segments)  # Start from the top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or a residue
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            text = f"Bond {i}: "
            if 'O-linked' in segment.get('bond_after', ''):
                text += "ester"
            elif 'N-Me' in segment.get('bond_after', ''):
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add a connection indicator
    if sequence.startswith('cyclo('):
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig

class PeptideStructureGenerator:
    """A class to generate 3D structures of peptides using different embedding methods"""

    @staticmethod
    def prepare_molecule(smiles):
        """Prepare molecule with proper hydrogen handling"""
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError("Failed to create molecule from SMILES")

        # Calculate valence for each atom
        for atom in mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)

        # Sanitize with reduced requirements
        Chem.SanitizeMol(mol,
                         sanitizeOps=Chem.SANITIZE_FINDRADICALS |
                                     Chem.SANITIZE_KEKULIZE |
                                     Chem.SANITIZE_SETAROMATICITY |
                                     Chem.SANITIZE_SETCONJUGATION |
                                     Chem.SANITIZE_SETHYBRIDIZATION |
                                     Chem.SANITIZE_CLEANUPCHIRALITY)

        mol = Chem.AddHs(mol)
        return mol

    @staticmethod
    def get_etkdg_params(attempt=0):
        """Get ETKDG parameters with optional modifications based on attempt number"""
        params = AllChem.ETKDGv3()
        params.randomSeed = -1
        params.maxIterations = 200
        params.numThreads = 4  # Reduced for the web interface
        params.useBasicKnowledge = True
        params.enforceChirality = True
        params.useExpTorsionAnglePrefs = True
        params.useSmallRingTorsions = True
        params.useMacrocycleTorsions = True
        params.ETversion = 2
        params.pruneRmsThresh = -1
        params.embedRmsThresh = 0.5

        if attempt > 10:
            params.bondLength = 1.5 + (attempt - 10) * 0.02
            params.useExpTorsionAnglePrefs = False

        return params

    def generate_structure_etkdg(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG without UFF optimization"""
        success = False
        mol = None

        for attempt in range(max_attempts):
            try:
                mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(mol, params) == 0:
                    success = True
                    break
            except Exception:
                continue

        if not success:
            raise ValueError("Failed to generate structure with ETKDG")

        return mol

    def generate_structure_uff(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG followed by UFF optimization"""
        best_mol = None
        lowest_energy = float('inf')

        for attempt in range(max_attempts):
            try:
                test_mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(test_mol, params) == 0:
                    res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
                                                      vdwThresh=10.0, confId=0,
                                                      ignoreInterfragInteractions=True)

                    if res == 0:
                        ff = AllChem.UFFGetMoleculeForceField(test_mol)
                        if ff:
                            current_energy = ff.CalcEnergy()
                            if current_energy < lowest_energy:
                                lowest_energy = current_energy
                                best_mol = Chem.Mol(test_mol)
            except Exception:
                continue

        if best_mol is None:
            raise ValueError("Failed to generate optimized structure")

        return best_mol

    @staticmethod
    def mol_to_sdf_bytes(mol):
        """Convert an RDKit molecule to SDF file bytes"""
        # First write to StringIO in text mode
        sio = StringIO()
        writer = Chem.SDWriter(sio)
        writer.write(mol)
        writer.close()

        # Convert the string to bytes
        return sio.getvalue().encode('utf-8')

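# Usage sketch (hedged - the SMILES below is an illustrative N-acetyl-alanine,
# and both calls may raise ValueError if embedding fails):
#   gen = PeptideStructureGenerator()
#   mol3d = gen.generate_structure_etkdg('CC(=O)N[C@@H](C)C(=O)O')
#   sdf_bytes = PeptideStructureGenerator.mol_to_sdf_bytes(mol3d)
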
def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process input and create visualizations using PeptideAnalyzer"""
    analyzer = PeptideAnalyzer()
    temp_dir = tempfile.mkdtemp() if generate_3d else None
    structure_files = []

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check whether it's a peptide using the analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            if generate_3d:
                generator = PeptideStructureGenerator()

                try:
                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use the analyzer to get the sequence
            segments = analyzer.split_on_bonds(smiles)

            # Process segments and build the sequence
            sequence_parts = []
            output_text = ""

            # Only include segment analysis in the output if requested
            if show_segment_details:
                output_text += "Segment Analysis:\n"
                for i, segment in enumerate(segments):
                    output_text += f"\nSegment {i}:\n"
                    output_text += f"Content: {segment['content']}\n"
                    output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                    output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"

                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)
                        output_text += f"Identified as: {residue}\n"
                        output_text += f"Modifications: {mods}\n"
                    else:
                        output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
                output_text += "\n"
            else:
                # Just build the sequence without detailed analysis in the output
                for segment in segments:
                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)

            # Check if cyclic using the analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to the output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if smiles:
                    # Check whether it's a peptide
                    if not analyzer.is_peptide(smiles):
                        output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                        continue

                    # Process this SMILES
                    segments = analyzer.split_on_bonds(smiles)
                    sequence_parts = []

                    # Add segment details if requested
                    if show_segment_details:
                        output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                        for i, segment in enumerate(segments):
                            output_text += f"\nSegment {i}:\n"
                            output_text += f"Content: {segment['content']}\n"
                            output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                            output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)
                                output_text += f"Identified as: {residue}\n"
                                output_text += f"Modifications: {mods}\n"
                    else:
                        for segment in segments:
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)

                    # Get cyclicity and create the sequence
                    is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                    sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                    output_text += f"\nSummary for SMILES: {smiles}\n"
                    output_text += f"Sequence: {sequence}\n"
                    output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                    if is_cyclic:
                        output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                    output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
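
A minimal end-to-end sketch of driving utils/app.py outside the web UI (hedged: the import path assumes the repo layout above, and the SMILES is an illustrative N-acetyl-alanine, not taken from this repo; process_input returns a 4-tuple of summary text, cyclic-structure image, optional linear figure, and optional SDF paths):

    from utils.app import process_input

    summary, img_cyclic, img_linear, files = process_input(
        smiles_input='CC(=O)N[C@@H](C)C(=O)O',
        show_linear=False,
        generate_3d=False,
    )
    print(summary)
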
utils/generate_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import math
import sys
import pandas as pd
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer


def mask_for_de_novo(config, sequence_length):
    if config.vocab == 'helm':
        return "[MASK]" * sequence_length
    elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
        return ["<mask>"] * sequence_length
    else:
        return ["[MASK]"] * sequence_length


def generate_de_novo(config, sequence_length, tokenizer, model):
    # NOTE: config added to the signature so mask_for_de_novo receives the
    # argument it requires; the original call passed only sequence_length.
    masked_sequence = mask_for_de_novo(config, sequence_length)
    inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    pred_tokens = []
    # Index logits_at_masks (shape [num_masks, vocab]) by 0..num_masks-1
    # rather than by the absolute token positions stored in mask_token_indices.
    for idx in range(len(mask_token_indices)):
        topk_logits, topk_indices = logits_at_masks[idx].topk(k=3, dim=-1)
        probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
        predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
        predicted_token_id = topk_indices[predicted_index].item()
        predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
        pred_tokens.append(predicted_token)

    generated_sequence = ''.join(pred_tokens)
    # Pass the mask positions through, as calculate_perplexity requires them
    perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)

    return (generated_sequence, perplexity)


def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            outputs = model(masked_input, labels=labels)
            total_loss += outputs.loss.item()

    num_mask_tokens = len(mask_token_indices)
    if num_mask_tokens == 0:
        perplexity = 10000
    else:
        avg_loss = total_loss / num_mask_tokens
        perplexity = math.exp(avg_loss)

    return perplexity


def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, pepclm_model, device):
    og_embeddings = pepclm_model.roformer.encoder(original_sequence)
    new_embeddings = pepclm_model.roformer.encoder(generated_sequence)

    sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
    cosine_similarity = torch.mean(sequence_similarity).item()
    return cosine_similarity


def calculate_hamming_dist(original_sequence, generated_sequence):
    return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
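
A usage sketch for generate_de_novo (hedged: '<checkpoint>' is a placeholder, and the config object can be any object with a .vocab attribute matching the branches in mask_for_de_novo):

    from types import SimpleNamespace
    from transformers import AutoModelForMaskedLM, AutoTokenizer
    from utils.generate_utils import generate_de_novo

    config = SimpleNamespace(vocab='helm')
    tokenizer = AutoTokenizer.from_pretrained('<checkpoint>')     # placeholder
    model = AutoModelForMaskedLM.from_pretrained('<checkpoint>')  # placeholder

    sequence, perplexity = generate_de_novo(config, sequence_length=20,
                                            tokenizer=tokenizer, model=model)
    print(sequence, perplexity)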