Sophia Tang committed on
Commit
e54915d
·
1 Parent(s): 9ab0e48

model upload

dataloading_for_dynamic_batching.py ADDED
@@ -0,0 +1,156 @@
+ #!/usr/bin/env python
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from datasets import load_from_disk  # import only load_from_disk so datasets.Dataset does not shadow torch's Dataset
+ import sys
+ import lightning.pytorch as pl
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ from functools import partial
+ import re
+
+
+ class DynamicBatchingDataset(Dataset):
+     def __init__(self, dataset_dict, tokenizer):
+         print('Initializing dataset...')
+         self.dataset_dict = {
+             'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
+             'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
+             'labels': dataset_dict['labels']
+         }
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.dataset_dict['attention_mask'])
+
+     def __getitem__(self, idx):
+         if isinstance(idx, int):
+             return {
+                 'input_ids': self.dataset_dict['input_ids'][idx],
+                 'attention_mask': self.dataset_dict['attention_mask'][idx],
+                 'labels': self.dataset_dict['labels'][idx]
+             }
+         elif isinstance(idx, list):
+             return {
+                 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
+                 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
+                 'labels': [self.dataset_dict['labels'][i] for i in idx]
+             }
+         else:
+             raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
+
+ class CustomDataModule(pl.LightningDataModule):
+     def __init__(self, dataset_path, tokenizer):
+         super().__init__()
+         self.dataset = load_from_disk(dataset_path)
+         self.tokenizer = tokenizer
+
+     def peptide_bond_mask(self, smiles_list):
+         """
+         Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
+         of recognized bonds in the positions dictionary and 0 elsewhere.
+
+         Args:
+             smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+
+         Returns:
+             torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
+         """
+         # Initialize the batch mask
+         batch_size = len(smiles_list)
+         max_seq_length = 1035  # max(len(smiles) for smiles in smiles_list)  # Find the longest SMILES
+         mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)  # Mask filled with zeros
+
+         bond_patterns = [
+             (r'OC\(=O\)', 'ester'),
+             (r'N\(C\)C\(=O\)', 'n_methyl'),
+             (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
+             (r'NC\(=O\)', 'peptide'),  # Regular peptide bonds
+             (r'C\(=O\)N\(C\)', 'n_methyl'),
+             (r'C\(=O\)N[12]?', 'peptide')
+         ]
+
+         for batch_idx, smiles in enumerate(smiles_list):
+             positions = []
+             used = set()
+
+             # Identify bonds
+             for pattern, bond_type in bond_patterns:
+                 for match in re.finditer(pattern, smiles):
+                     if not any(p in range(match.start(), match.end()) for p in used):
+                         positions.append({
+                             'start': match.start(),
+                             'end': match.end(),
+                             'type': bond_type,
+                             'pattern': match.group()
+                         })
+                         used.update(range(match.start(), match.end()))
+
+             # Update the mask for the current SMILES
+             for pos in positions:
+                 mask[batch_idx, pos['start']:pos['end']] = 1
+
+         return mask
+
+     def peptide_token_mask(self, smiles_list, token_lists):
+         """
+         Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
+         where any part of the token overlaps with a peptide bond, and 0 elsewhere.
+
+         Args:
+             smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+             token_lists: List of tokenized SMILES strings (split into tokens).
+
+         Returns:
+             torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
+         """
+         # Initialize the batch mask
+         batch_size = len(smiles_list)
+         token_seq_length = max(len(tokens) for tokens in token_lists)  # Find the longest tokenized sequence
+         tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)  # Mask filled with zeros
+         atomwise_masks = self.peptide_bond_mask(smiles_list)
+
+         for batch_idx, atomwise_mask in enumerate(atomwise_masks):
+             token_seq = token_lists[batch_idx]
+             atom_idx = 0
+
+             for token_idx, token in enumerate(token_seq):
+                 if token_idx != 0 and token_idx != len(token_seq) - 1:
+                     if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
+                         tokenized_masks[batch_idx][token_idx] = 1
+                     atom_idx += len(token)
+
+         return tokenized_masks
+
+     def collate_fn(self, batch):
+         item = batch[0]
+
+         token_array = self.tokenizer.get_token_split(item['input_ids'])
+         bond_mask = self.peptide_token_mask(item['labels'], token_array)
+
+         return {
+             'input_ids': item['input_ids'],
+             'attention_mask': item['attention_mask'],
+             'bond_mask': bond_mask
+         }
+
+     def train_dataloader(self):
+         train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
+         return DataLoader(
+             train_dataset,
+             batch_size=1,
+             collate_fn=self.collate_fn,  # Use the instance method
+             shuffle=True,
+             num_workers=12,
+             pin_memory=True
+         )
+
+     def val_dataloader(self):
+         val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
+         return DataLoader(
+             val_dataset,
+             batch_size=1,
+             collate_fn=self.collate_fn,  # Use the instance method
+             num_workers=8,
+             pin_memory=True
+         )
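How the character-level bond mask behaves can be checked in isolation. A minimal sketch, assuming a hypothetical Gly-Gly SMILES string and only two of the patterns from the list above (both the string and the reduced pattern list are assumptions, not from the commit); the matched span of the backbone amide is where the mask receives 1s:

import re
import torch

smiles = 'NCC(=O)NCC(=O)O'   # hypothetical glycylglycine SMILES
bond_patterns = [(r'NC\(=O\)', 'peptide'), (r'C\(=O\)N[12]?', 'peptide')]

mask = torch.zeros(len(smiles), dtype=torch.int)
used = set()
for pattern, bond_type in bond_patterns:
    for match in re.finditer(pattern, smiles):
        if not any(p in range(match.start(), match.end()) for p in used):
            mask[match.start():match.end()] = 1
            used.update(range(match.start(), match.end()))

print(mask.tolist())   # 1s over 'C(=O)N' (the backbone amide), 0s elsewhere

peptide_token_mask then lifts this character-level mask to token positions, flagging any non-special token whose characters overlap a flagged span.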
dataset.py ADDED
@@ -0,0 +1,207 @@
+
+ import re
+ import torch
+
+ import utils
+
+ from torch.utils.data import Dataset, DataLoader
+ import lightning.pytorch as pl
+ from functools import partial
+ import sys
+
+ class CustomDataset(Dataset):
+     def __init__(self, dataset, indices):
+         self.dataset = dataset
+         self.indices = indices
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         actual_idx = int(self.indices[idx])
+         item = self.dataset[actual_idx]
+         return item
+
+
+ # for weighting losses of peptide bonds
+ def peptide_bond_mask(smiles_list):
+     """
+     Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
+     of recognized bonds in the positions dictionary and 0 elsewhere.
+
+     Args:
+         smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+
+     Returns:
+         torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
+     """
+     # Initialize the batch mask
+     batch_size = len(smiles_list)
+     max_seq_length = max(len(smiles) for smiles in smiles_list)  # Find the longest SMILES
+     mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)  # Mask filled with zeros
+
+     bond_patterns = [
+         (r'OC\(=O\)', 'ester'),
+         (r'N\(C\)C\(=O\)', 'n_methyl'),
+         (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
+         (r'NC\(=O\)', 'peptide'),  # Regular peptide bonds
+         (r'C\(=O\)N\(C\)', 'n_methyl'),
+         (r'C\(=O\)N[12]?', 'peptide')
+     ]
+
+     for batch_idx, smiles in enumerate(smiles_list):
+         positions = []
+         used = set()
+
+         # Identify bonds
+         for pattern, bond_type in bond_patterns:
+             for match in re.finditer(pattern, smiles):
+                 if not any(p in range(match.start(), match.end()) for p in used):
+                     positions.append({
+                         'start': match.start(),
+                         'end': match.end(),
+                         'type': bond_type,
+                         'pattern': match.group()
+                     })
+                     used.update(range(match.start(), match.end()))
+
+         # Update the mask for the current SMILES
+         for pos in positions:
+             mask[batch_idx, pos['start']:pos['end']] = 1
+
+     return mask
+
+ def peptide_token_mask(smiles_list, token_lists):
+     """
+     Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
+     where any part of the token overlaps with a peptide bond, and 0 elsewhere.
+
+     Args:
+         smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+         token_lists: List of tokenized SMILES strings (split into tokens).
+
+     Returns:
+         torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
+     """
+     # Initialize the batch mask
+     batch_size = len(smiles_list)
+     token_seq_length = max(len(tokens) for tokens in token_lists)  # Find the longest tokenized sequence
+     tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)  # Mask filled with zeros
+     atomwise_masks = peptide_bond_mask(smiles_list)
+
+     for batch_idx, atomwise_mask in enumerate(atomwise_masks):
+         token_seq = token_lists[batch_idx]
+         atom_idx = 0
+
+         for token_idx, token in enumerate(token_seq):
+             if token_idx != 0 and token_idx != len(token_seq) - 1:
+                 if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
+                     tokenized_masks[batch_idx][token_idx] = 1
+                 atom_idx += len(token)
+
+     return tokenized_masks
+
+ def extract_amino_acid_sequence(helm_string):
+     """
+     Extracts the amino acid sequence from a HELM peptide notation and outputs it as an array,
+     removing any brackets around each amino acid.
+
+     Args:
+         helm_string (str): The HELM notation string for a peptide.
+
+     Returns:
+         list: A list containing each amino acid in sequence without brackets.
+     """
+     # Use regex to find the pattern within `{}` brackets following "PEPTIDE" followed by a number
+     matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)
+
+     if matches:
+         # Join all matched sequences and split by dots to get individual amino acids
+         amino_acid_sequence = []
+         for match in matches:
+             sequence = match.replace('[', '').replace(']', '').split('.')
+             amino_acid_sequence.extend(sequence)
+         return amino_acid_sequence
+     else:
+         return "Invalid HELM notation or no peptide sequence found."
+
+ def helm_collate_fn(batch, tokenizer):
+     sequences = [item['HELM'] for item in batch]
+
+     max_len = 0
+     for sequence in sequences:
+         seq_len = len(extract_amino_acid_sequence(sequence))
+         if seq_len > max_len:
+             max_len = seq_len
+
+     tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)
+
+     return {
+         'input_ids': tokens['input_ids'],
+         'attention_mask': tokens['attention_mask']
+     }
+
+
+ def collate_fn(batch, tokenizer):
+     """Standard data collator that truncates/pads sequences based on max_length"""
+     valid_sequences = []
+     valid_items = []
+
+     for item in batch:
+         try:
+             test_tokens = tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
+             valid_sequences.append(item['SMILES'])
+             valid_items.append(item)
+         except Exception as e:
+             print(f"Skipping sequence due to: {str(e)}")
+             continue
+
+     #sequences = [item['SMILES'] for item in batch]
+     #max_len = max([len(seq) for seq in sequences])
+     #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
+
+     tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)
+
+     token_array = tokenizer.get_token_split(tokens['input_ids'])
+     bond_mask = peptide_token_mask(valid_sequences, token_array)
+     #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)
+
+     return {
+         'input_ids': tokens['input_ids'],
+         'attention_mask': tokens['attention_mask'],
+         'bond_mask': bond_mask
+     }
+
+
+ class CustomDataModule(pl.LightningDataModule):
+     def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
+         super().__init__()
+         self.train_dataset = train_dataset
+         self.val_dataset = val_dataset
+         #self.test_dataset = test_dataset
+         self.batch_size = batch_size
+         self.tokenizer = tokenizer
+         self.collate_fn = collate_fn
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+                           num_workers=8,
+                           pin_memory=True
+                           )
+
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+                           num_workers=8,
+                           pin_memory=True
+                           )
+
+     """def test_dataloader(self):
+         return DataLoader(self.test_dataset, batch_size=self.batch_size,
+                           collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+                           num_workers=8, pin_memory=True)"""
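A small, self-contained check of extract_amino_acid_sequence; the HELM string is a made-up example, not taken from the commit:

from dataset import extract_amino_acid_sequence

helm = 'PEPTIDE1{A.G.[dF].K}$$$$'          # hypothetical HELM notation for a 4-residue peptide
print(extract_amino_acid_sequence(helm))   # ['A', 'G', 'dF', 'K']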
diffusion.py ADDED
@@ -0,0 +1,1130 @@
1
+ import numpy as np
2
+ import sys
3
+ import itertools
4
+ import time
5
+ import torch
6
+ from torch import Tensor
7
+ import math
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import random as rd
11
+ import lightning as L
12
+ from torch.distributions.categorical import Categorical
13
+ import torchmetrics
14
+ from dataclasses import dataclass
15
+ import gc
16
+ import pickle
17
+ import utils.utils as utils
18
+
19
+ import dataset as dataloader
20
+ import models.helmgpt as helmgpt
21
+ import models.peptideclm as peptideclm
22
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
23
+ import noise_schedule
24
+ from torch.optim.lr_scheduler import _LRScheduler
25
+ import models.roformer as roformer
26
+ from utils.filter import PeptideAnalyzer
27
+
28
+
29
+ @dataclass
30
+ class Loss:
31
+ loss: torch.FloatTensor
32
+ nlls: torch.FloatTensor
33
+ attn_mask: torch.FloatTensor
34
+
35
+
36
+ class NLL(torchmetrics.aggregation.MeanMetric):
37
+ pass
38
+
39
+
40
+ class BPD(NLL):
41
+ def compute(self) -> Tensor:
42
+ """Computes the bits per dimension.
43
+
44
+ Returns:
45
+ bpd
46
+ """
47
+ return self.mean_value / self.weight / math.log(2)
48
+
49
+
50
+ class Perplexity(NLL):
51
+ def compute(self) -> Tensor:
52
+ """Computes the Perplexity.
53
+
54
+ Returns:
55
+ Perplexity
56
+ """
57
+ return torch.exp(self.mean_value / self.weight)
58
+
59
+
60
+ class Diffusion(L.LightningModule):
61
+ def __init__(self, config, tokenizer):
62
+
63
+ super().__init__()
64
+ self.config = config
65
+ #self.save_hyperparameters()
66
+
67
+ # PeptideCLM tokenizer
68
+ self.tokenizer = tokenizer
69
+ self.vocab_size = self.tokenizer.vocab_size
70
+ self.mask_token_id = self.tokenizer.mask_token_id
71
+ self.sampler = self.config.sampling.predictor
72
+ self.analyzer = PeptideAnalyzer()
73
+
74
+ # backbone LM PeptideCLM model
75
+ if self.config.backbone == 'peptideclm':
76
+ self.backbone = peptideclm.EncoderWrapper(self.tokenizer)
77
+ self.backbone.unfreeze_all_layers()
78
+ self.backbone = torch.compile(self.backbone)
79
+ elif self.config.backbone == 'helmgpt':
80
+ self.backbone = helmgpt.GPT(self.config, self.tokenizer)
81
+ #self.backbone = torch.compile(self.backbone)
82
+ elif self.config.backbone == 'roformer':
83
+ self.backbone = roformer.Roformer(self.config, self.tokenizer)
84
+ self.backbone.unfreeze_all_layers()
85
+ elif self.config.backbone == 'finetune_roformer':
86
+ self.backbone = roformer.Roformer(self.config, self.tokenizer)
87
+ self.backbone.freeze_model()
88
+ self.backbone.unfreeze_n_layers(n=8)
89
+ else:
90
+ raise ValueError('invalid backbone config')
91
+
92
+ self.neg_infinity = -1000000.0
93
+ self.T = config.T
94
+ # noise schedule for non-peptide bond tokens (default to log-linear)
95
+ self.noise = noise_schedule.get_noise(config)
96
+ # noise schedule for peptide bonds (log-polynomial)
97
+ self.bond_noise = noise_schedule.LogPolyNoise()
98
+ self.time_conditioning = self.config.time_conditioning
99
+ self.fast_forward_epochs = None
100
+ self.fast_forward_batches = None
101
+
102
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
103
+ self.gen_ppl_metric = Perplexity()
104
+
105
+ self.lr = self.config.optim.lr
106
+ self.sampling_eps = self.config.training.sampling_eps
107
+
108
+ metrics = torchmetrics.MetricCollection({
109
+ 'nll': NLL(),
110
+ 'bpd': BPD(),
111
+ 'ppl': Perplexity(),
112
+ })
113
+ metrics.set_dtype(torch.float64)
114
+ self.train_metrics = metrics.clone(prefix='trainer/')
115
+ self.valid_metrics = metrics.clone(prefix='val/')
116
+ self.test_metrics = metrics.clone(prefix='test/')
117
+
118
+
119
+ """LOSS"""
120
+
121
+ """LOSS FOR INVALID PEPTIDES"""
122
+
123
+ @torch.no_grad()
124
+ def conditional_gumbel(self, logits, D, k):
125
+ """
126
+ Outputs k samples of Q = StandardGumbel(), such that argmax(logits
127
+ + Q) is given by D (one-hot vector).
128
+
129
+ Input:
130
+ - logits: Tensor of shape (batch_size, seq_len, vocab_size)
131
+ - D: One-hot tensor of shape (batch_size, seq_len, vocab_size)
132
+ - k: Number of Gumbel samples
133
+
134
+ Output:
135
+ - Adjusted logits with shape (k, batch_size, seq_len, vocab_size)
136
+ """
137
+
138
+ # iid. exponential samples of shape (k, batch_size, seq_len, vocab_size)
139
+ E = torch.distributions.exponential.Exponential(rate=torch.ones_like(logits)).sample([k])
140
+
141
+ # E of the chosen class, shape (k, batch_size, seq_len, 1)
142
+ Ei = (D * E).sum(dim=-1, keepdim=True)
143
+
144
+ # Partition function (normalization constant), shape (batch_size, seq_len, 1)
145
+ Z = logits.exp().sum(dim=-1, keepdim=True)
146
+
147
+ # Adjusted logits for Gumbel distribution
148
+ adjusted = (
149
+ D * (-torch.log(Ei) + torch.log(Z)) +
150
+ (1 - D) * -torch.log(E / logits.exp() + Ei / Z)
151
+ )
152
+
153
+ # Adjusted logits shape: (k, batch_size, seq_len, vocab_size)
154
+ return adjusted - logits
155
+
156
+ def replace_gradient(self, value, surrogate):
157
+ """
158
+ Returns `value` but backpropagates gradients through `surrogate`.
159
+ """
160
+ return surrogate + (value - surrogate).detach()
161
+
162
+ def gumbel_rao(self, logits, k, temp=1.0, I=None):
163
+ """
164
+ Returns a categorical sample from logits (over axis=-1) as a
165
+ one-hot vector, with gumbel-rao gradient.
166
+
167
+ Input:
168
+ - logits: Tensor of shape (batch_size, seq_len, vocab_size)
169
+ - k: Number of Gumbel samples for Rao-Blackwellization
170
+ - temp: Temperature for softmax
171
+ - I: Optional, precomputed categorical sample tensor of shape (batch_size, seq_len)
172
+
173
+ Output:
174
+ - One-hot tensor of shape (batch_size, seq_len, vocab_size)
175
+ with Gumbel-Rao gradient.
176
+ """
177
+ assert logits.shape[-1] == self.tokenizer.vocab_size
178
+ vocab_size = logits.shape[-1]
179
+
180
+ if I is None:
181
+ # Sample indices for each token in the batch
182
+ I = torch.distributions.categorical.Categorical(logits=logits).sample() # (batch_size, seq_len)
183
+
184
+ # Convert indices to one-hot encodings, shape (batch_size, seq_len, vocab_size)
185
+ D = torch.nn.functional.one_hot(I, num_classes=vocab_size).float()
186
+
187
+ # Generate k different adjusted logits that all evaluate to the same sequence
188
+ adjusted = logits + self.conditional_gumbel(logits, D, k=k) # (k, batch_size, seq_len, vocab_size)
189
+
190
+ # Compute the surrogate by averaging softmax across k samples
191
+ surrogate = torch.nn.functional.softmax(adjusted / temp, dim=-1).mean(dim=0) # (batch_size, seq_len, vocab_size)
192
+
193
+ # Return one-hot representation with surrogate gradient
194
+ return self.replace_gradient(D, surrogate)
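gumbel_rao together with replace_gradient is a Rao-Blackwellized straight-through estimator: the forward pass emits the hard one-hot sample D, while gradients flow through the softmax surrogate averaged over k conditional Gumbel draws. A minimal sketch of the straight-through step alone, with made-up tensors (not part of the commit):

import torch

value = torch.tensor([0., 1., 0.])                              # hard one-hot sample (D)
surrogate = torch.tensor([0.2, 0.6, 0.2], requires_grad=True)   # differentiable surrogate
out = surrogate + (value - surrogate).detach()                  # forward value equals D exactly
out.sum().backward()
print(out.detach(), surrogate.grad)   # tensor([0., 1., 0.]) tensor([1., 1., 1.])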
195
+
196
+ def compute_invalid_loss(self, logits, k=None, temp=None):
197
+ """
198
+ Penalizes logits that produce invalid sequences using the `is_peptide` function,
199
+ scaling penalties inversely with token probabilities.
200
+
201
+ Args:
202
+ logits: Tensor of shape [batch_size, seq_len, vocab_size].
203
+ k: Number of samples for Gumbel-Rao.
204
+ temp: Temperature for softmax.
205
+
206
+ Returns:
207
+ loss: A scalar tensor representing the total loss for invalid sequences.
208
+ """
209
+
210
+ #samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)
211
+
212
+ # Convert logits to sequences using the tokenizer
213
+ batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
214
+ sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
215
+
216
+ # Check validity of each sampled sequence (not differentiable)
217
+ penalties = torch.tensor(
218
+ [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
219
+ dtype=torch.float32,
220
+ device=self.device
221
+ )
222
+ #print(penalties)
223
+
224
+ # Compute probabilities for each token (batch_size, seq_length)
225
+ sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)
226
+
227
+ # scale penalties by softmax probability of sampled tokens
228
+ scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)
229
+
230
+ return scaled_penalty.to(self.device)
231
+
232
+ """DIFFUSION LOSS"""
233
+
234
+ def sample_t(self, n, device):
235
+ """
236
+ Sample random time steps for batch training
237
+ """
238
+ # sample values uniformly at random from [0, 1)
239
+ eps_t = torch.rand(n, device=device)
240
+ # antithetic sampling: reduce variance by pairing each sample with complementary sample
241
+ if self.config.training.antithetic_sampling:
242
+ # compute interval between sampled time steps
243
+ offset = torch.arange(n, device=device) / n
244
+ # ensure that each eps value is evenly spaced between [0, 1)
245
+ eps_t = ((eps_t / n) + offset) % 1
246
+
247
+ # ensures values are not exactly 0 or 1
248
+ t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
249
+
250
+ return t
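The antithetic-sampling branch above is stratified sampling: each of the n timesteps is pushed into its own 1/n-wide slice of [0, 1), which reduces the variance of the loss estimate across a batch. A small sketch with hypothetical numbers (the final sampling_eps rescaling is omitted):

import torch

torch.manual_seed(0)
n = 4
eps_t = torch.rand(n)                  # i.i.d. uniform draws
offset = torch.arange(n) / n           # [0.00, 0.25, 0.50, 0.75]
t = ((eps_t / n) + offset) % 1         # one timestep per quarter of [0, 1)
print(t)                               # e.g. tensor([0.12, 0.44, 0.52, 0.78])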
251
+
252
+ """def mask_samples(self, x0, mask_prob):
253
+
254
+ # generate array of values in range [0, 1] uniformly at random
255
+ # will be used to determine which tokens are masked
256
+ mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
257
+
258
+ # select tokens to mask if the random value in mask_indices is less than mask_prob
259
+ # this will mask approximately the fraction of tokens indicated by mask_prob
260
+ zt = torch.where(mask_indices < mask_prob, self.mask_token_id, x0)
261
+
262
+ return zt"""
263
+
264
+ def q_xt(self, x, mask_prob):
265
+ """Computes the noisy sample xt.
266
+
267
+ Args:
268
+ x: int torch.Tensor with shape (batch_size,
269
+ diffusion_model_input_length), input.
270
+ mask_prob: float torch.Tensor with shape (batch_size, 1).
271
+ """
272
+
273
+ actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
274
+ #print(actual_seq_length)
275
+
276
+ max_mask_length = (actual_seq_length * 0.75).long()
277
+
278
+ mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob
279
+
280
+ restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)
281
+
282
+ for i in range(x.shape[0]):
283
+ true_positions = torch.where(mask_indices[i])[0]
284
+ if len(true_positions) > max_mask_length[i]:
285
+ selected_positions = true_positions[:max_mask_length[i].item()]
286
+ restricted_move_indices[i, selected_positions] = True
287
+ else:
288
+ restricted_move_indices[i] = mask_indices[i]
289
+
290
+ xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)
291
+
292
+ return xt
293
+
294
+
295
+ def sample_prior(self, *batch_dims):
296
+ """
297
+ Returns array of fully masked sequences with same shape as input
298
+ """
299
+ return self.mask_token_id * torch.ones(* batch_dims, dtype=torch.int64)
300
+
301
+
302
+ """COMPUTING LOSS"""
303
+
304
+ def compute_diffusion_loss(self, model_output, xt, x0, t):
305
+ """
306
+ Computes diffusion loss term in ELBO
307
+ (evaluates how accurately the model predicts the token probabilities at each time step)
308
+
309
+ Inputs:
310
+ - model_output: (batch_size, seq_length, vocab_size) array of log-probabilities for each token at each sequence position
311
+ - xt: corrupted (partially masked) version of the original input x0 at timestep t
312
+ - x0: original input sequence
313
+ - t: timestep
314
+ """
315
+ # compute interval between each timestep
316
+ dt = 1 / self.T
317
+
318
+ # compute vectorized alpha scaling terms for the logits at timestep s and t
319
+ alpha_t = 1 - t + torch.zeros_like(x0)
320
+ # s = t - dt
321
+ alpha_s = 1 - (t - dt) + torch.zeros_like(x0)
322
+
323
+ # gather vector of log-probabilities for each token in x0
324
+ # log<x_theta, x>
325
+ log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]) # shape (B, L, 1)
326
+ # gather log-probabilities for assigning a masked token at each position in the sequence at time t
327
+ # log<x_theta, m>
328
+ log_x_theta_at_m = model_output[:, :, self.mask_token_id]
329
+ # obtain non-log probability of assigning a masked token
330
+ # <xt, m>
331
+ x_theta_at_m = log_x_theta_at_m.exp()
332
+
333
+ # first term of diffusion loss
334
+ term_1_coef = dt / t
335
+ term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
336
+ term_1_log_denom = log_x_theta_at_x0
337
+
338
+ # second term of diffusion loss
339
+ term_2_coef = 1 - (dt / t)
340
+ term_2_log_numerator = term_1_log_numerator
341
+ term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)
342
+
343
+ L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
344
+ term_2_coef * (term_2_log_numerator - term_2_log_denom))
345
+
346
+ # multiply by <zt, m> term
347
+ L_vb = L_vb_masked * (xt == self.mask_token_id)
348
+
349
+ # scale by T and return
350
+ return self.T * L_vb
351
+
352
+ """def _forward_pass_diffusion(self, x0, attn_mask, mask=None):
353
+
354
+ print(x0)
355
+ # randomly sample time steps to start the denoising process for each x0 in batch
356
+ t = self.sample_t(x0.shape[0], x0.device)
357
+
358
+ # if we are training the intermediate transition blocks
359
+ if self.T > 0:
360
+ # scale by total timesteps T and cast to integer
361
+ t = (t * self.T).to(torch.int)
362
+ # scale down by T to get a multiple of 1/T
363
+ t = t / self.T
364
+ # add 1/T to ensure no 0 values
365
+ t += (1 / self.T)
366
+
367
+ # get noise and rate of noise at timestep t
368
+ sigma, dsigma = self.noise(t)
369
+ time_conditioning = sigma[:, None]
370
+ # get masking probabilities for all tokens for each batch
371
+ mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
372
+
373
+ # get masked samples at different timesteps
374
+ if mask is None: zt = self.q_xt(x0, mask_prob)
375
+ else: zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id))
376
+
377
+ model_output = self.forward(zt, attn_mask, time_conditioning)
378
+
379
+ utils.print_nans(model_output, 'model_output')
380
+
381
+ if self.T > 0:
382
+ # compute diffusion loss
383
+ diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
384
+ return diffusion_loss
385
+
386
+ # compute loss for the final that converts from z0 to x0
387
+ # -log(p_theta)
388
+ # get (batch_size, L) array of log-probabilities
389
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) # (B, L)
390
+
391
+
392
+ return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]"""
393
+
394
+ def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
395
+ """
396
+ Training reverse diffusion model x_theta to reconstruct samples x0
397
+
398
+ bond_mask: (batch, seq_length)
399
+ """
400
+ # randomly sample time steps to start the denoising process for each x0 in batch
401
+ t = self.sample_t(x0.shape[0], self.device)
402
+
403
+ # if we are training the intermediate transition blocks
404
+ if self.T > 0:
405
+ # scale by total timesteps T and cast to integer
406
+ t = (t * self.T).to(torch.int)
407
+ # scale down by T to get a multiple of 1/T
408
+ t = t / self.T
409
+ # add 1/T to ensure no 0 values
410
+ t += (1 / self.T)
411
+
412
+ # get noise and rate of noise at timestep t
413
+ # sigma = -log(1-t); dsigma = 1 / (1-t)
414
+ sigma, dsigma = self.noise(t)
415
+ time_conditioning = sigma[:, None]
416
+
417
+ # Get masking probabilities for all tokens for each batch
418
+ # log-linear: 1 - alpha = t
419
+ base_mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
420
+
421
+ if self.config.noise.state_dependent and (bond_mask is not None):
422
+ # log-polynomial masking schedule: alpha = 1 - t^w
423
+ # bond_sigma = -log(1-t^w) for w = 3 (default)
424
+ # bond_dsigma = -wt^(w-1) / (1-t^w)
425
+ bond_sigma, bond_dsigma = self.bond_noise(t) # scalar
426
+ # expand dimensions for broadcasting to (B, L)
427
+ bond_sigma = bond_sigma[:, None]
428
+ bond_dsigma = bond_dsigma[:, None]
429
+ sigma = sigma[:, None]
430
+ dsigma = dsigma[:, None]
431
+
432
+ # compute masking probability for peptide bonds 1 - bond_alpha = t^w
433
+ bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
434
+ # piece together (B, L) tensor with modified masking prob at peptide-bond locations
435
+ mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
436
+ #print(mask_prob)
437
+ dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
438
+ sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
439
+ else:
440
+ mask_prob = base_mask_prob.to(self.device)
441
+
442
+ # get masked samples at different timesteps
443
+ if mask is None:
444
+ zt = self.q_xt(x0, mask_prob).to(self.device)
445
+ else:
446
+ zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id)).to(self.device)
447
+
448
+ model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)
449
+
450
+ # debugging
451
+ assert not torch.isnan(model_output).any()
452
+ assert model_output.is_cuda
453
+ utils.print_nans(model_output, 'model_output')
454
+
455
+ # compute invalid loss
456
+ invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device) # (B, L)
457
+ #print(invalid_loss)
458
+
459
+ if self.T > 0:
460
+ # compute diffusion loss
461
+ diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
462
+ return diffusion_loss
463
+
464
+ # compute loss for the final step that converts z0 to x0
465
+ # -log(p_theta)
466
+ # get (batch_size, L) array of log-probabilities
467
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device) # (B, L)
468
+
469
+ if self.config.noise.state_dependent and (bond_mask is not None):
470
+ return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
471
+ else:
472
+ return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
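Under the forms stated in the comments above (sigma = -log(1 - t) for the log-linear schedule and a log-polynomial bond schedule with 1 - alpha = t**w, w = 3), the masking probabilities reduce to t for ordinary tokens and t**3 for peptide-bond tokens, so bond tokens are corrupted later in forward diffusion and committed earlier in reverse. A quick numeric check under those assumed forms, independent of the noise_schedule module:

import torch

t = torch.tensor([0.25, 0.50, 0.75])
sigma = -torch.log(1 - t)                 # log-linear schedule
bond_sigma = -torch.log(1 - t ** 3)       # log-polynomial schedule, w = 3
print(1 - torch.exp(-sigma))              # tensor([0.2500, 0.5000, 0.7500])  = t
print(1 - torch.exp(-bond_sigma))         # tensor([0.0156, 0.1250, 0.4219])  = t**3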
473
+
474
+ def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
475
+ loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
476
+
477
+ # negative log loss
478
+ nlls = loss * attn_mask
479
+
480
+ # count number of tokens
481
+ num_tokens = attn_mask.sum()
482
+
483
+ # compute batch loss
484
+ batch_nll = nlls.sum()
485
+ # compute per token loss
486
+ token_nll = batch_nll / num_tokens
487
+ # return losses
488
+ return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
489
+
490
+ def _compute_loss(self, batch, prefix, bond_mask=None):
491
+
492
+ attn_mask = batch['attention_mask'].to(self.device)
493
+
494
+ if 'mask' in batch:
495
+ mask = batch['mask'].to(self.device)
496
+ else:
497
+ mask = None
498
+
499
+ if 'bond_mask' in batch:
500
+ bond_mask = batch['bond_mask'].to(self.device)
501
+ else:
502
+ bond_mask = None
503
+
504
+ losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
505
+ loss = losses.loss
506
+
507
+ if prefix == 'train':
508
+ self.train_metrics.update(
509
+ losses.nlls.to(self.device),
510
+ losses.attn_mask.to(self.device)
511
+ )
512
+ metrics = self.train_metrics
513
+ elif prefix == 'val':
514
+ self.valid_metrics.update(
515
+ losses.nlls.to(self.device),
516
+ losses.attn_mask.to(self.device)
517
+ )
518
+ metrics = self.valid_metrics
519
+ elif prefix == 'test':
520
+ self.test_metrics.update(losses.nlls, losses.attn_mask)
521
+ metrics = self.test_metrics
522
+ else:
523
+ raise ValueError(f'Invalid prefix: {prefix}')
524
+
525
+ self.log_dict(metrics,
526
+ on_step=False,
527
+ on_epoch=True,
528
+ sync_dist=True)
529
+
530
+ return loss
531
+
532
+
533
+ """SAMPLING"""
534
+
535
+ def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
536
+ # get number of timesteps
537
+ if sample_steps is None:
538
+ sample_steps = self.config.sampling.steps
539
+
540
+ if seq_length is None:
541
+ seq_length = self.config.sampling.seq_length
542
+
543
+ # sample fully masked sequences
544
+ z = self.sample_prior(num_samples, seq_length).to(self.device)
545
+
546
+ # create vector of sample_steps timesteps
547
+ timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
548
+
549
+ # compute interval between timesteps
550
+ dt = (1 - eps) / sample_steps
551
+
552
+ for i in range(sample_steps):
553
+ t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
554
+
555
+ z = self.single_reverse_step(z, t, dt)
556
+
557
+ return z
558
+
559
+
560
+ """SAMPLING STEP"""
561
+
562
+ def single_reverse_step(self, zt, t, dt, attn_mask=None):
563
+ """
564
+ Take a single reverse diffusion step for the expansion step of the MCTS algorithm
565
+ """
566
+ # get sigma values that determine masking prob
567
+ sigma_t, _ = self.noise(t)
568
+ sigma_s, _ = self.noise(t - dt)
569
+
570
+ # reshape sigmas
571
+ if sigma_t.ndim > 1:
572
+ sigma_t = sigma_t.squeeze(-1)
573
+ if sigma_s.ndim > 1:
574
+ sigma_s = sigma_s.squeeze(-1)
575
+ assert sigma_t.ndim == 1, sigma_t.shape
576
+ assert sigma_s.ndim == 1, sigma_s.shape
577
+
578
+ # compute masking probabilities for each timestep
579
+ change_prob_t = 1 - torch.exp(-sigma_t)
580
+ change_prob_s = 1 - torch.exp(-sigma_s)
581
+
582
+ # expand dimensions
583
+ change_prob_t = change_prob_t[:, None, None]
584
+ change_prob_s = change_prob_s[:, None, None]
585
+
586
+ # get model prediction of token log-probabilities
587
+ log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
588
+
589
+ # check dimensions match
590
+ assert change_prob_t.ndim == log_p_x0.ndim
591
+
592
+ # compute reverse diffusion probability of being unmasked at timestep s
593
+ # (sigma_s - sigma_t)*x_theta
594
+ q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
595
+
596
+ # compute reverse diffusion probability of remaining masked at timestep s
597
+ # (1 - sigma_s)*m
598
+ q_zs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
599
+
600
+ # sample sequence at timestep s from categorical distribution of q_zs
601
+ z_changed = sample_categorical(q_zs)
602
+
603
+ copy_flag = (zt != self.mask_token_id).to(zt.dtype)
604
+ return (copy_flag * zt) + ((1 - copy_flag) * z_changed)
605
+
606
+ def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
607
+ assert self.config.noise.type == 'loglinear'
608
+ sigma_t, _ = self.noise(t)
609
+
610
+ if t.ndim > 1:
611
+ t = t.squeeze(-1)
612
+ assert t.ndim == 1
613
+
614
+ change_prob_t = t[:, None, None]
615
+ change_prob_s = (t - dt)[:, None, None]
616
+
617
+ assert change_prob_t.ndim == 3, change_prob_t.shape
618
+
619
+ if p_x0 is None:
620
+ p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()
621
+
622
+ assert change_prob_t.ndim == p_x0.ndim
623
+
624
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
625
+
626
+ # zero-masking probability
627
+ q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
628
+
629
+ x_changed = sample_categorical(q_xs)
630
+
631
+ copy_flag = (x != self.mask_token_id).to(x.dtype)
632
+
633
+ return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
634
+
635
+ # first step in expansion
636
+ def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
637
+ """
638
+ Generates batch_size different samples from the same starting point for the
639
+ first expansion step of MCTS
640
+
641
+ Args:
642
+ token_array (torch.Tensor): partially masked token ids of shape (seq_length,) or (1, seq_length)
+ t (torch.Tensor): current timestep for the sequence
+ dt (float): reverse-diffusion step size
+ batch_size (int): number of distinct children to sample from the same parent
+ p_x0 (torch.Tensor, optional): cached model probabilities. Defaults to None.
+ attn_mask (torch.Tensor, optional): attention mask. Defaults to None.
+
+ Returns:
+ tuple: (p_x0, tensor of shape (batch_size, seq_length) of partially unmasked token ids)
651
+ """
652
+
653
+ assert self.config.noise.type == 'loglinear'
654
+ sigma_t, _ = self.noise(t)
655
+
656
+ if t.ndim > 1:
657
+ t = t.squeeze(-1)
658
+ assert t.ndim == 1
659
+
660
+ change_prob_t = t[:, None, None]
661
+ change_prob_s = (t - dt)[:, None, None]
662
+
663
+ assert change_prob_t.ndim == 3, change_prob_t.shape
664
+
665
+ if token_array.dim() == 1:
666
+ token_array = token_array.unsqueeze(0)
667
+ #token_array = token_array.repeat(batch_size, 1)
668
+
669
+ attn_mask = torch.ones_like(token_array)
670
+
671
+ if p_x0 is None:
672
+ p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
673
+
674
+ assert change_prob_t.ndim == p_x0.ndim
675
+
676
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
677
+
678
+ # zero-masking probability
679
+ q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
680
+
681
+ # repeat the parent token along the first dimension which will be unmasked into distinct sequences
682
+ token_array = token_array.repeat(batch_size, 1)
683
+
684
+ if self.config.mcts.sampling == 0:
685
+ x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
686
+ else:
687
+ x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
688
+
689
+ copy_flag = (token_array != self.mask_token_id).to(token_array.dtype)
690
+
691
+ return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
692
+
693
+ def _process_sigma(self, sigma):
694
+ if sigma.ndim > 1:
695
+ sigma = sigma.squeeze(-1)
696
+ if not self.time_conditioning:
697
+ sigma = torch.zeros_like(sigma)
698
+ assert sigma.ndim == 1, sigma.shape
699
+ return sigma
700
+
701
+ def forward(self, zt, attn_mask, sigma):
702
+ """
703
+ Predicts the token log-probabilities from zt at time t with noise schedule sigma
704
+ """
705
+ sigma = self._process_sigma(sigma)
706
+
707
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.float32, cache_enabled=True):
708
+ logits = self.backbone(zt, attn_mask).to(self.device)
709
+
710
+ return self.subs_parameterization(logits, zt)
711
+
712
+ def subs_parameterization(self, logits, zt):
713
+ """
714
+ Updates reverse diffusion logits based on SUBS parameterization:
715
+ - zero masking probabilities: -infinity probability of being masked during reverse diffusion
716
+ - carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion
717
+
718
+ Args:
719
+ logits: vector of token probabilities for unmasking masked tokens
720
+ zt: partially unmasked sequence at current timestep
721
+ """
722
+ logits[:, :, self.mask_token_id] += self.neg_infinity # [sequence index, current token, next token]
723
+
724
+
725
+ logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)
726
+
727
+
728
+ unmasked_indices = (zt != self.mask_token_id).to(self.device) # shape: [200, seq_length]
729
+ batch_idx, seq_idx = torch.where(unmasked_indices) # Get explicit indices
730
+ batch_idx = batch_idx.to(self.device)
731
+ seq_idx = seq_idx.to(self.device)
732
+ tokens = zt[batch_idx, seq_idx].to(self.device) # Get the tokens at those positions
733
+
734
+ assert logits.is_contiguous(), "logits tensor is not contiguous"
735
+ assert unmasked_indices.shape == zt.shape, "same shape"
736
+ assert not torch.isnan(logits).any(), "NaN values found in logits"
737
+ assert tokens.max() < logits.shape[-1], "token indices out of bounds"
738
+ assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
739
+ assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
740
+ assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"
741
+
742
+ logits[batch_idx, seq_idx] = self.neg_infinity # Set everything to -inf first
743
+ logits[batch_idx, seq_idx, tokens] = 0 # Set only the specific token positions to 0
744
+ # return logits with SUBS parameterization
745
+ return logits.to(self.device)
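A toy illustration of the two SUBS rules implemented above, using a hypothetical vocabulary of 4 tokens with mask id 3: the mask token gets zero probability everywhere, and positions that are already unmasked keep their token with probability 1:

import torch

mask_id = 3
log_probs = torch.zeros(1, 2, 4) - torch.log(torch.tensor(4.0))   # uniform log-probs, (batch, seq, vocab)
zt = torch.tensor([[2, mask_id]])                                 # position 0 unmasked as token 2

log_probs[:, :, mask_id] = -1e6                                   # zero masking probability
log_probs = log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True)
log_probs[0, 0, :] = -1e6                                         # carry-over unmasking:
log_probs[0, 0, 2] = 0.0                                          # keep token 2 with probability 1
print(log_probs.exp().round(decimals=2))
# position 0 -> [0, 0, 1, 0]; position 1 -> uniform over the three non-mask tokens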
746
+
747
+ """SAMPLING"""
748
+ @torch.no_grad()
749
+ def _sample(self, num_steps=None, eps=1e-5, x_input=None):
750
+ """
751
+ Generate samples
752
+ """
753
+ batch_size_per_gpu = self.config.eval.perplexity_batch_size
754
+
755
+ if num_steps is None:
756
+ num_steps = self.config.sampling.steps
757
+
758
+ if x_input is not None:
759
+ x = x_input['input_ids'].to(self.device)
760
+ attn_mask = x_input['attention_mask'].to(self.device)
761
+ else:
762
+ x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
763
+ attn_mask = torch.ones_like(x).to(self.device)
764
+
765
+
766
+ timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
767
+ dt = (1 - eps) / num_steps
768
+ p_x0_cache = None
769
+ generation_history = [] # used to track which tokens are unmasked
770
+
771
+ for i in range(num_steps):
772
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
773
+ if self.sampler == 'ddpm':
774
+ x = self.single_reverse_step(x, t, dt).to(self.device)
775
+
776
+ elif self.sampler == 'ddpm_cache':
777
+ p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
778
+ if (not torch.allclose(x_next, x) or self.time_conditioning):
779
+ # Disable caching
780
+ p_x0_cache = None
781
+ x = x_next.to(self.device)
782
+ #print(self.tokenizer.decode(x.squeeze()))
783
+ else:
784
+ x = self._analytic_update(x, t, dt, attn_mask).to(self.device)
785
+
786
+ if self.config.sampling.noise_removal:
787
+ t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
788
+ if self.sampler == 'analytic':
789
+ x = self._denoiser_update(x, t).to(self.device)
790
+ else:
791
+ time_conditioning = self.noise(t)[0].to(self.device)
792
+ x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
793
+ #print(self.tokenizer.decode(x.squeeze()))
794
+ return x.to(self.device)
795
+
796
+
797
+ def restore_model_and_sample(self, num_steps, eps=1e-5):
798
+ """Generate samples from the model."""
799
+ self.backbone.eval()
800
+ self.noise.eval()
801
+ samples = self._sample(num_steps=num_steps, eps=eps)
802
+ self.backbone.train()
803
+ self.noise.train()
804
+ return samples
805
+
806
+ def get_score(self, zt, sigma, attn_mask=None):
807
+
808
+ # score(x, t) = p_t(y) / p_t(x)
809
+ # => log score(x, t) = log p_t(y) - log p_t(x)
810
+
811
+ # case 1: x = masked
812
+ # (i) y = unmasked
813
+ # log score(x, t) = log p_\theta(x)|_y + log k
814
+ # where k = exp(- sigma) / (1 - exp(- sigma))
815
+ # (ii) y = masked
816
+ # log score(x, t) = 0
817
+
818
+ # case 2: x = unmasked
819
+ # (i) y != masked, y != x
820
+ # log score(x_i, t) = - inf
821
+ # (ii) y = x
822
+ # log score(x_i, t) = 0
823
+ # (iii) y = masked token
824
+ # log score(x_i, t) = - log k
825
+ # where k = exp(- sigma) / (1 - exp(- sigma))
826
+
827
+ model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)
828
+
829
+ log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
830
+ assert log_k.ndim == 1
831
+
832
+ masked_score = model_output + log_k[:, None, None]
833
+ masked_score[:, :, self.mask_token_id] = 0
834
+
835
+ unmasked_score = self.neg_infinity * torch.ones_like(model_output)
836
+ unmasked_score = torch.scatter(
837
+ unmasked_score, -1,
838
+ zt[..., None],
839
+ torch.zeros_like(unmasked_score[..., :1]))
840
+
841
+ unmasked_score[:, :, self.mask_token_id] = - (log_k[:, None] * torch.ones_like(zt))
842
+
843
+ masked_indices = (zt == self.mask_token_id).to(model_output.dtype)[:, :, None]
844
+
845
+ model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))
846
+
847
+ return model_output.exp()
848
+
849
+ def _staggered_score(self, score, dsigma):
850
+ score = score.clone()
851
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
852
+ score *= dsigma.exp()[:, None]
853
+ score[..., self.mask_token_id] += extra_const
854
+ return score
855
+
856
+ def _analytic_update(self, x, t, step_size, attn_mask=None):
857
+ curr_sigma, _ = self.noise(t)
858
+ next_sigma, _ = self.noise(t - step_size)
859
+ dsigma = curr_sigma - next_sigma
860
+ score = self.get_score(x, curr_sigma, attn_mask)
861
+ stag_score = self._staggered_score(score, dsigma)
862
+ probs = stag_score * self._transp_transition(x, dsigma)
863
+ return sample_categorical(probs)
864
+
865
+ def _denoiser_update(self, x, t):
866
+ sigma, _ = self.noise(t)
867
+ score = self.get_score(x, sigma)
868
+ stag_score = self._staggered_score(score, sigma)
869
+ probs = stag_score * self._transp_transition(x, sigma)
870
+ probs[..., self.mask_token_id] = 0
871
+ samples = sample_categorical(probs)
872
+ return samples
873
+
874
+ def _transp_transition(self, i, sigma):
875
+ sigma = unsqueeze(sigma, reference=i[..., None])
876
+ edge = torch.exp(-sigma) * F.one_hot(
877
+ i, num_classes=self.vocab_size)
878
+ edge += torch.where(i == self.mask_token_id,
879
+ 1 - torch.exp(-sigma).squeeze(-1),
880
+ 0)[..., None]
881
+ return edge
882
+
883
+
884
+ """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
885
+
886
+ def on_train_epoch_start(self):
887
+ torch.cuda.empty_cache()
888
+ self.backbone.train()
889
+ self.noise.train()
890
+
891
+
892
+ def training_step(self, batch, batch_idx):
893
+ # Initialize throughput calculation
894
+ start_time = time.time()
895
+
896
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
897
+ loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
898
+ else:
899
+ loss = self._compute_loss(batch, prefix='train')
900
+
901
+ self.log(name='trainer/loss',
902
+ value=loss.item(),
903
+ on_step=True,
904
+ on_epoch=False,
905
+ sync_dist=True)
906
+
907
+ # Calculate throughput
908
+ elapsed_time = time.time() - start_time
909
+ total_tokens = batch['input_ids'].numel()
910
+ throughput = total_tokens / elapsed_time
911
+
912
+ self.log(name='trainer/throughput',
913
+ value=throughput,
914
+ on_step=True,
915
+ on_epoch=False,
916
+ sync_dist=True)
917
+
918
+ return loss
919
+
920
+
921
+ def on_load_checkpoint(self, checkpoint):
922
+ self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
923
+ self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
924
+
925
+ """VALIDATION"""
926
+ def on_validation_epoch_start(self):
927
+ gc.collect()
928
+ torch.cuda.empty_cache()
929
+ self.backbone.eval()
930
+ self.noise.eval()
931
+ assert self.valid_metrics.nll.mean_value == 0
932
+ assert self.valid_metrics.nll.weight == 0
933
+
934
+ def validation_step(self, batch, batch_idx):
935
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
936
+ loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
937
+ else:
938
+ loss = self._compute_loss(batch, prefix='val')
939
+
940
+ self.log(name='trainer/val_loss',
941
+ value=loss.item(),
942
+ on_step=True,
943
+ on_epoch=False,
944
+ prog_bar=True,
945
+ sync_dist=True)
946
+ return loss
947
+
948
+ def on_validation_epoch_end(self):
949
+ gc.collect()
950
+ torch.cuda.empty_cache()
951
+
952
+ """OPTIMIZATION"""
953
+
954
+ def optimizer_step(self, *args, **kwargs):
955
+ super().optimizer_step(*args, **kwargs)
956
+
957
+ gc.collect()
958
+ torch.cuda.empty_cache()
959
+
960
+ def configure_optimizers(self):
961
+ optimizer = torch.optim.AdamW(
962
+ itertools.chain(self.backbone.parameters(),self.noise.parameters()),
963
+ lr=self.config.optim.lr,
964
+ betas=(self.config.optim.beta1, self.config.optim.beta2),
965
+ eps=self.config.optim.eps,
966
+ weight_decay=self.config.optim.weight_decay
967
+ )
968
+
969
+ self.total_steps = self.config.trainer.max_steps
970
+ scheduler = CosineWarmup(optimizer,
971
+ warmup_steps=self.config.lr_scheduler.num_warmup_steps,
972
+ total_steps=self.total_steps)
973
+
974
+ scheduler_dict = {
975
+ 'scheduler': scheduler,
976
+ 'interval': 'step',
977
+ 'frequency': 1,
978
+ 'monitor': 'val/loss',
979
+ 'name': 'trainer/lr'
980
+ }
981
+
982
+ return [optimizer], [scheduler_dict]
983
+
984
+ @torch.no_grad()
985
+ def compute_masked_perplexity(self, generated_ids, input_ids):
986
+ """
987
+ Computes masked perplexity between array of generated token ids and masked ids that are converted to logits
988
+ """
989
+
990
+ total_nll = 0
991
+ total_tokens = 0
992
+
993
+ input_ids = torch.tensor(input_ids).to(self.device)
994
+ #print(input_ids)
995
+
996
+ for sequence in generated_ids:
997
+ # tokenize the sequence
998
+
999
+ gt_ids = torch.tensor(sequence).to(self.device)
1000
+ #print(gt_ids)
1001
+
1002
+ sys.stdout.flush()
1003
+
1004
+ # forward pass through the backbone PeptideCLM model
1005
+ attn_mask = torch.ones_like(input_ids).to(self.device)
1006
+
1007
+ # compute logits using backbone
1008
+
1009
+ if self.config.mode in ['train', 'ppl_eval']:
1010
+ outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)
1011
+ elif self.config.mode == 'sample_eval':
1012
+ outputs = self.backbone.forward(input_ids=input_ids)
1013
+
1014
+
1015
+ # get logits for each position in sequence across all tokens in vocab
1016
+ #logits = outputs[-1] # (batch_size, seq_length, vocab_size)
1017
+
1018
+ logits = outputs.view(-1, outputs.size(-1))
1019
+ gt_ids = gt_ids.view(-1)
1020
+
1021
+ #print(logits.shape)
1022
+ #print(gt_ids.shape)
1023
+
1024
+ # compute loss
1025
+ # shift_logits = logits[:, :-1, :].contiguous() # remove eos
1026
+ # shift_labels = input_ids[:, 1:].contiguous()
1027
+ # print(masked)
1028
+
1029
+ loss = F.cross_entropy(logits,
1030
+ gt_ids.where(input_ids==self.mask_token_id, torch.full_like(gt_ids, -100)).view(-1),
1031
+ reduction='sum')
1032
+
1033
+ total_nll += loss.item()
1034
+ # count all non-padding tokens
1035
+ total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # includes BOS and EOS tokens
1036
+
1037
+ # compute pseudo-perplexity
1038
+ # print(total_nll, ",;,", total_tokens)
1039
+ pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
1040
+ self.gen_ppl_metric.update(pseudo_perplexity)
1041
+
1042
+ return pseudo_perplexity.item()
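The quantity returned above is a pseudo-perplexity: exp(total NLL / total non-padding tokens). With hypothetical numbers:

import math

total_nll, total_tokens = 138.6, 60
print(math.exp(total_nll / total_tokens))   # ~10.1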
1043
+
1044
+
1045
+ def sample_categorical(categorical_probs):
1046
+ gumbel_norm = (
1047
+ 1e-10
1048
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
1049
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
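sample_categorical is the Gumbel-max trick in disguise: with E = -log(U) distributed Exponential(1), argmax(p / E) = argmax(log p - log E), and -log E is a standard Gumbel variable. A quick empirical check with hypothetical probabilities:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.3, 0.6]).expand(100_000, 3)
gumbel_norm = 1e-10 - (torch.rand_like(probs) + 1e-10).log()
samples = (probs / gumbel_norm).argmax(dim=-1)
print(torch.bincount(samples) / samples.numel())   # approximately tensor([0.1, 0.3, 0.6])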
1050
+
1051
+ def sample_batched_categorical(categorical_probs, batch_size):
1052
+ """
1053
+ Generates `m` distinct sequences sampled from categorical probabilities
1054
+ using the Gumbel distribution to ensure randomness while following probabilities
1055
+
1056
+ Args:
1057
+ categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
1058
+ representing categorical probabilities
1059
+ batch_size (int): number of distinct sequences to sample
1060
+
1061
+ Returns:
1062
+ torch.Tensor: tensor of shape (m, sequence_length), where each row is a
1063
+ distinct sequence of sampled category indices.
1064
+ """
1065
+ _, sequence_length, vocab_size = categorical_probs.shape
1066
+
1067
+ # add Gumbel noise and sample m sequences
1068
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
1069
+ noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
1070
+
1071
+ # select the highest score (most likely category after Gumbel noise)
1072
+ sampled_sequences = noisy_scores.argmax(dim=-1) # shape: (m, sequence_length)
1073
+
1074
+ return sampled_sequences
1075
+
1076
+ def sample_batched_top_k(categorical_probs, batch_size, k):
1077
+ """
1078
+ Generates `m` sequences sampled from the top-k probabilities of each token
1079
+ using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
1080
+
1081
+ Args:
1082
+ categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
1083
+ representing categorical probabilities.
1084
+ batch_size (int): Number of sequences to sample.
1085
+ k (int): Number of top probabilities to consider for sampling.
1086
+
1087
+ Returns:
1088
+ torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
1089
+ sampled sequence of category indices.
1090
+ """
1091
+ _, sequence_length, vocab_length = categorical_probs.shape
1092
+
1093
+ # Add Gumbel noise to the log probabilities
1094
+ gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
1095
+ noisy_scores = torch.log(categorical_probs[None, :, :]) + gumbel_noise # Shape: (m, sequence_length, vocab_length)
1096
+
1097
+ # Get the top-k categories based on noisy scores
1098
+ top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k)
1099
+
1100
+ # Convert top-k scores back to probabilities and normalize
1101
+ top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k)
1102
+
1103
+ # Sample randomly from the top-k probabilities
1104
+ sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
1105
+ sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
1106
+
1107
+ # Map sampled indices back to the original vocabulary indices
1108
+ sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device)
1109
+
1110
+ return sampled_sequences
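A usage sketch for the two batched samplers above, assuming the denoising model has produced a (1, seq_len, vocab) tensor of per-token probabilities (the tensor here is random, for illustration only):

    import torch

    seq_len, vocab = 8, 20
    probs = torch.softmax(torch.randn(1, seq_len, vocab), dim=-1)

    # four independent Gumbel-max draws over the full vocabulary
    seqs_full = sample_batched_categorical(probs, batch_size=4)   # (4, seq_len)

    # four draws restricted to the 5 highest-scoring tokens per position
    seqs_topk = sample_batched_top_k(probs, batch_size=4, k=5)    # (4, seq_len)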
1111
+
1112
+ def unsqueeze(x, reference):
1113
+ return x.view(* x.shape, * ((1,) * (len(reference.shape) - len(x.shape))))
1114
+
1115
+ class CosineWarmup(_LRScheduler):
1116
+ def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
1117
+ self.warmup_steps = warmup_steps
1118
+ self.total_steps = total_steps
1119
+ self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
1120
+ super(CosineWarmup, self).__init__(optimizer, last_epoch)
1121
+
1122
+ def get_lr(self):
1123
+ if self.last_epoch < self.warmup_steps:
1124
+ return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
1125
+
1126
+ progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
1127
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
1128
+ decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
1129
+
1130
+ return [decayed_lr * base_lr for base_lr in self.base_lrs]
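CosineWarmup above follows the usual warmup-then-cosine pattern (linear ramp for warmup_steps, then cosine decay toward eta_ratio * base_lr); a minimal sketch of attaching it to an optimizer, assuming the surrounding module's imports (numpy, _LRScheduler) are available; the model, learning rate, and step counts are placeholders:

    import torch

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = CosineWarmup(optimizer, warmup_steps=1_000, total_steps=100_000)

    for step in range(10):
        optimizer.step()   # loss computation and backward pass omitted
        scheduler.step()   # advances last_epoch, updating the learning rate per get_lr()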
generate_mcts.py ADDED
@@ -0,0 +1,398 @@
1
+ #!/usr/bin/env
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import random
6
+ import sys
7
+ import pandas as pd
8
+ from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
9
+ from diffusion import Diffusion
10
+ from pareto_mcts import Node, MCTS
11
+ import hydra
12
+ from tqdm import tqdm
13
+ from transformers import AutoTokenizer, AutoModel, pipeline
14
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
15
+ from helm_tokenizer.helm_tokenizer import HelmTokenizer
16
+ from utils.helm_utils import create_helm_from_aa_seq
17
+ from utils.app import PeptideAnalyzer
18
+ from new_tokenizer.ape_tokenizer import APETokenizer
19
+ import matplotlib.pyplot as plt
20
+ import os
21
+ import seaborn as sns
22
+ import pandas as pd
23
+ import numpy as np
24
+
25
+ def save_logs_to_file(config, valid_fraction_log, affinity1_log, affinity2_log, sol_log, hemo_log, nf_log, permeability_log, output_path):
26
+ """
27
+ Saves the per-iteration logs (valid fraction, binding affinities, solubility, hemolysis, nonfouling, and permeability, depending on the configured mode) to a CSV file.
28
+
29
+ Parameters:
30
+ valid_fraction_log (list): Log of valid fractions over iterations.
31
+ affinity1_log (list): Log of binding affinity over iterations.
32
+ permeability_log (list): Log of membrane permeability over iterations.
33
+ output_path (str): Path to save the log CSV file.
34
+ """
35
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
36
+
37
+ if config.mcts.perm:
38
+ # Combine logs into a DataFrame
39
+ log_data = {
40
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
41
+ "Valid Fraction": valid_fraction_log,
42
+ "Binding Affinity": affinity1_log,
43
+ "Solubility": sol_log,
44
+ "Hemolysis": hemo_log,
45
+ "Nonfouling": nf_log,
46
+ "Permeability": permeability_log
47
+ }
48
+ elif config.mcts.dual:
49
+ log_data = {
50
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
51
+ "Valid Fraction": valid_fraction_log,
52
+ "Binding Affinity 1": affinity1_log,
53
+ "Binding Affinity 2": affinity2_log,
54
+ "Solubility": sol_log,
55
+ "Hemolysis": hemo_log,
56
+ "Nonfouling": nf_log,
57
+ "Permeability": permeability_log
58
+ }
59
+ elif config.mcts.single:
60
+ log_data = {
61
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
62
+ "Valid Fraction": valid_fraction_log,
63
+ "Permeability": permeability_log
64
+ }
65
+ else:
66
+ log_data = {
67
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
68
+ "Valid Fraction": valid_fraction_log,
69
+ "Binding Affinity": affinity1_log,
70
+ "Solubility": sol_log,
71
+ "Hemolysis": hemo_log,
72
+ "Nonfouling": nf_log
73
+ }
74
+
75
+ df = pd.DataFrame(log_data)
76
+
77
+ # Save to CSV
78
+ df.to_csv(output_path, index=False)
79
+
80
+ def plot_data(log1, log2=None,
81
+ save_path=None,
82
+ label1="Log 1",
83
+ label2=None,
84
+ title="Fraction of Valid Peptides Over Iterations",
85
+ palette=None):
86
+ """
87
+ Plots one or two datasets with their mean values over iterations.
88
+
89
+ Parameters:
90
+ log1 (list): The first list of mean values for each iteration.
91
+ log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
92
+ save_path (str): Path to save the plot. Defaults to None.
93
+ label1 (str): Label for the first dataset. Defaults to "Log 1".
94
+ label2 (str, optional): Label for the second dataset. Defaults to None.
95
+ title (str): Title of the plot. Defaults to "Mean Values Over Iterations".
96
+ palette (dict, optional): A dictionary defining custom colors for datasets. Defaults to None.
97
+ """
98
+ # Prepare data for log1
99
+ data1 = pd.DataFrame({
100
+ "Iteration": range(1, len(log1) + 1),
101
+ "Fraction of Valid Peptides": log1,
102
+ "Dataset": label1
103
+ })
104
+
105
+ # Prepare data for log2 if provided
106
+ if log2 is not None:
107
+ data2 = pd.DataFrame({
108
+ "Iteration": range(1, len(log2) + 1),
109
+ "Fraction of Valid Peptides": log2,
110
+ "Dataset": label2
111
+ })
112
+ data = pd.concat([data1, data2], ignore_index=True)
113
+ else:
114
+ data = data1
115
+
116
+ palette = {
117
+ label1: "#8181ED", # Default color for log1
118
+ label2: "#D577FF" # Default color for log2 (if provided)
119
+ }
120
+
121
+ # Set Seaborn theme
122
+ sns.set_theme()
123
+ sns.set_context("paper")
124
+
125
+ # Create the plot
126
+ sns.lineplot(
127
+ data=data,
128
+ x="Iteration",
129
+ y="Fraction of Valid Peptides",
130
+ hue="Dataset",
131
+ style="Dataset",
132
+ markers=True,
133
+ dashes=False,
134
+ palette=palette
135
+ )
136
+
137
+ # Titles and labels
138
+ plt.title(title)
139
+ plt.xlabel("Iteration")
140
+ plt.ylabel("Fraction of Valid Peptides")
141
+
142
+ if save_path:
143
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
144
+ print(f"Plot saved to {save_path}")
145
+ plt.show()
146
+
147
+ def plot_data_with_distribution_seaborn(log1, log2=None,
148
+ save_path=None,
149
+ label1=None,
150
+ label2=None,
151
+ title=None):
152
+ """
153
+ Plots one or two datasets with the average values and distributions over iterations using Seaborn.
154
+
155
+ Parameters:
156
+ log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
157
+ log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
158
+ save_path (str): Path to save the plot. Defaults to None.
159
+ label1 (str): Label for the first dataset. Defaults to "Fraction of Valid Peptide SMILES".
160
+ label2 (str, optional): Label for the second dataset. Defaults to None.
161
+ title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
162
+ """
163
+ # Prepare data for log1
164
+ data1 = pd.DataFrame({
165
+ "Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
166
+ "Fraction of Valid Peptides": [score for scores in log1 for score in scores],
167
+ "Dataset": label1,
168
+ "Style": "Log1"
169
+ })
170
+
171
+ # Prepare data for log2 if provided
172
+ if log2 is not None:
173
+ data2 = pd.DataFrame({
174
+ "Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
175
+ "Fraction of Valid Peptides": [score for scores in log2 for score in scores],
176
+ "Dataset": label2,
177
+ "Style": "Log2"
178
+ })
179
+ data = pd.concat([data1, data2], ignore_index=True)
180
+ else:
181
+ data = data1
182
+
183
+ palette = {
184
+ label1: "#8181ED", # Default color for log1
185
+ label2: "#D577FF" # Default color for log2 (if provided)
186
+ }
187
+
188
+ # Set Seaborn theme
189
+ sns.set_theme()
190
+ sns.set_context("paper")
191
+
192
+ # Create the plot
193
+ sns.relplot(
194
+ data=data,
195
+ kind="line",
196
+ x="Iteration",
197
+ y="Fraction of Valid Peptides",
198
+ hue="Dataset",
199
+ style="Style",
200
+ markers=True,
201
+ dashes=True,
202
+ ci="sd", # Show standard deviation
203
+ height=5,
204
+ aspect=1.5,
205
+ palette=palette
206
+ )
207
+
208
+ # Titles and labels
209
+ plt.title(title)
210
+ plt.xlabel("Iteration")
211
+ plt.ylabel("Fraction of Valid Peptides")
212
+
213
+ if save_path:
214
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
215
+ print(f"Plot saved to {save_path}")
216
+ plt.show()
217
+
218
+ @torch.no_grad()
219
+ def generate_valid_mcts(config, mdlm, prot1=None, prot2=None, filename=None, prot_name1=None, prot_name2 = None):
220
+ tokenizer = mdlm.tokenizer
221
+ max_sequence_length = config.sampling.seq_length
222
+
223
+ # generate array of [MASK] tokens
224
+ masked_array = mask_for_de_novo(config, max_sequence_length)
225
+
226
+ if config.vocab == 'old_smiles':
227
+ # use custom encode function
228
+ inputs = tokenizer.encode(masked_array)
229
+ elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
230
+ inputs = tokenizer.encode_for_generation(masked_array)
231
+ else:
232
+ # custom HELM tokenizer
233
+ inputs = tokenizer(masked_array, return_tensors="pt")
234
+
235
+ inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}
236
+
237
+ # initialize root node
238
+ rootNode = Node(config=config, tokens=inputs, timestep=0)
239
+ # initalize tree search algorithm
240
+
241
+ if config.mcts.perm:
242
+ score_func_names = ['permeability', 'binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
243
+ num_func = [0, 50, 50, 50, 50]
244
+ elif config.mcts.dual:
245
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'binding_affinity2']
246
+ elif config.mcts.single:
247
+ score_func_names = ['permeability']
248
+ else:
249
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
250
+
251
+ if not config.mcts.time_dependent:
252
+ num_func = [0] * len(score_func_names)
253
+
254
+ if prot1 is not None and prot2 is not None:
255
+ mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1, prot2], num_func=num_func)
256
+ elif prot1 is not None:
257
+ mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1], num_func=num_func)
258
+ elif config.mcts.single:
259
+ mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
260
+ else:
261
+ mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
262
+
263
+ paretoFront = mcts.forward(rootNode)
264
+
265
+ output_log_path = f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/log_{filename}.csv'
266
+ save_logs_to_file(config, mcts.valid_fraction_log, mcts.affinity1_log, mcts.affinity2_log, mcts.sol_log, mcts.hemo_log, mcts.nf_log, mcts.permeability_log, output_log_path)
267
+
268
+ if config.mcts.single:
269
+ plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
270
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
271
+ label1="Average Permeability Score",
272
+ title="Average Permeability Score Over Iterations")
273
+ else:
274
+ plot_data(mcts.valid_fraction_log,
275
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/valid_{filename}.png')
276
+ plot_data_with_distribution_seaborn(log1=mcts.affinity1_log,
277
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding1_{filename}.png',
278
+ label1="Average Binding Affinity to TfR",
279
+ title="Average Binding Affinity to TfR Over Iterations")
280
+ if config.mcts.dual:
281
+ plot_data_with_distribution_seaborn(log1=mcts.affinity2_log,
282
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding2_{filename}.png',
283
+ label1="Average Binding Affinity to SKP2",
284
+ title="Average Binding Affinity to SKP2 Over Iterations")
285
+ plot_data_with_distribution_seaborn(log1=mcts.sol_log,
286
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/sol_{filename}.png',
287
+ label1="Average Solubility Score",
288
+ title="Average Solubility Score Over Iterations")
289
+ plot_data_with_distribution_seaborn(log1=mcts.hemo_log,
290
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/hemo_{filename}.png',
291
+ label1="Average Hemolysis Score",
292
+ title="Average Hemolysis Score Over Iterations")
293
+ plot_data_with_distribution_seaborn(log1=mcts.nf_log,
294
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/nf_{filename}.png',
295
+ label1="Average Nonfouling Score",
296
+ title="Average Nonfouling Score Over Iterations")
297
+ if config.mcts.perm:
298
+ plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
299
+ save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
300
+ label1="Average Permeability Score",
301
+ title="Average Permeability Score Over Iterations")
302
+
303
+ return paretoFront, inputs
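generate_valid_mcts returns the Pareto front as a dictionary keyed by generated SMILES string, where each value stores the objective scores and the token ids (this is the structure consumed in main below); a minimal sketch of iterating it, with placeholder protein and file names:

    paretoFront, inputs = generate_valid_mcts(config, mdlm, prot1=gfap, filename="run0", prot_name1="gfap")
    for smiles, entry in paretoFront.items():
        scores = entry['scores']          # one score per optimized objective
        token_ids = entry['token_ids']    # token tensor, e.g. for computing perplexity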
304
+
305
+
306
+ @hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
307
+ def main(config):
308
+ prot_name1 = "time_dependent"
309
+ prot_name2 = "skp2"
310
+ mode = "2"
311
+ model = "mcts"
312
+ length = "100"
313
+ epoch = "7"
314
+
315
+ filename = f'{mode}_{model}_length_{length}_epoch_{epoch}'
316
+
317
+ if config.vocab == 'new_smiles':
318
+ tokenizer = APETokenizer()
319
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
320
+ elif config.vocab == 'old_smiles':
321
+ tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
322
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
323
+ elif config.vocab == 'selfies':
324
+ tokenizer = APETokenizer()
325
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
326
+ elif config.vocab == 'helm':
327
+ tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
328
+
329
+ mdlm = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)
330
+
331
+ mdlm.eval()
332
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
333
+ mdlm.to(device)
334
+
335
+
336
+ print("loaded models...")
337
+ analyzer = PeptideAnalyzer()
338
+
339
+ # proteins
340
+ amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
341
+ tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
342
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
343
+ glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
344
+ glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
345
+ ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
346
+ cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
347
+ ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
348
+ skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
349
+
350
+ paretoFront, input_array = generate_valid_mcts(config, mdlm, gfap, None, filename, prot_name1, None)
351
+ generation_results = []
352
+
353
+ for sequence, v in paretoFront.items():
354
+ generated_array = v['token_ids'].to(mdlm.device)
355
+
356
+ # compute perplexity
357
+ perplexity = mdlm.compute_masked_perplexity(generated_array, input_array['input_ids'])
358
+ perplexity = round(perplexity, 4)
359
+
360
+ aa_seq, seq_length = analyzer.analyze_structure(sequence)
361
+ scores = v['scores']
362
+
363
+ if config.mcts.single == False:
364
+ binding1 = scores[0]
365
+ solubility = scores[1]
366
+ hemo = scores[2]
367
+ nonfouling = scores[3]
368
+
369
+ if config.mcts.perm:
370
+ permeability = scores[4]
371
+ generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling, permeability])
372
+ print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
373
+ elif config.mcts.dual:
374
+ binding2 = scores[4]
375
+ generation_results.append([sequence, perplexity, aa_seq, binding1, binding2, solubility, hemo, nonfouling])
376
+ print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity 1: {binding1} | Binding Affinity 2: {binding2} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
377
+ elif config.mcts.single:
378
+ permeability = scores[0]
+ generation_results.append([sequence, perplexity, aa_seq, permeability])
379
+ else:
380
+ generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling])
381
+ print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
382
+
383
+ sys.stdout.flush()
384
+
385
+ if config.mcts.perm:
386
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
387
+ elif config.mcts.dual:
388
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity 1', 'Binding Affinity 2', 'Solubility', 'Hemolysis', 'Nonfouling'])
389
+ elif config.mcts.single:
390
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Permeability'])
391
+ else:
392
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling'])
393
+
394
+ df.to_csv(f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/{filename}.csv', index=False)
395
+
396
+
397
+ if __name__ == "__main__":
398
+ main()
generate_unconditional.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import random
5
+ import sys
6
+ import pandas as pd
7
+ from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
8
+ from diffusion import Diffusion
9
+ import hydra
10
+ from tqdm import tqdm
11
+ from transformers import AutoTokenizer, AutoModel, pipeline
12
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
13
+ from helm_tokenizer.helm_tokenizer import HelmTokenizer
14
+ from utils.helm_utils import create_helm_from_aa_seq, get_smi_from_helms
15
+ from utils.filter import PeptideAnalyzer
16
+ from new_tokenizer.ape_tokenizer import APETokenizer
17
+ from scoring.scoring_functions import ScoringFunctions
18
+
19
+
20
+ @torch.no_grad()
21
+ def generate_sequence_unconditional(config, sequence_length: int, mdlm: Diffusion):
22
+ tokenizer = mdlm.tokenizer
23
+ # generate array of [MASK] tokens
24
+ masked_array = mask_for_de_novo(config, sequence_length)
25
+
26
+ if config.vocab == 'old_smiles':
27
+ # use custom encode function
28
+ inputs = tokenizer.encode(masked_array)
29
+ elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
30
+ inputs = tokenizer.encode_for_generation(masked_array)
31
+ else:
32
+ # custom HELM tokenizer
33
+ inputs = tokenizer(masked_array, return_tensors="pt")
34
+
35
+ # tokenized masked array
36
+ inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}
37
+ # sample unconditional array of tokens
38
+ logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
39
+
40
+ return logits, inputs
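A sketch of driving the helper above for one unconditional sample and decoding it (assumes a loaded Diffusion model and SMILES tokenizer, as set up in main below):

    tokens, inputs = generate_sequence_unconditional(config, config.sampling.seq_length, mdlm_model)
    smiles_seq = mdlm_model.tokenizer.decode(tokens.to(mdlm_model.device))
    print(smiles_seq)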
41
+
42
+
43
+ @hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
44
+ def main(config):
45
+ path = "/home/st512/peptune/scripts/peptide-mdlm-mcts"
46
+
47
+ if config.vocab == 'new_smiles':
48
+ tokenizer = APETokenizer()
49
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
50
+ elif config.vocab == 'old_smiles':
51
+ tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
52
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
53
+ elif config.vocab == 'selfies':
54
+ tokenizer = APETokenizer()
55
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
56
+ elif config.vocab == 'helm':
57
+ tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
58
+
59
+ mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)
60
+
61
+ mdlm_model.eval()
62
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
63
+ mdlm_model.to(device)
64
+
65
+ print("loaded models...")
66
+ analyzer = PeptideAnalyzer()
67
+
68
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
69
+
70
+ # scoring functions
71
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
72
+ score_functions = ScoringFunctions(score_func_names, [gfap])
73
+
74
+
75
+ max_seq_length = config.sampling.seq_length
76
+ num_sequences = config.sampling.num_sequences
77
+ generation_results = []
78
+ num_valid = 0.
79
+ num_total = 0.
80
+ while num_total < num_sequences:
81
+ num_total += 1
82
+ generated_array, input_array = generate_sequence_unconditional(config, max_seq_length, mdlm_model)
83
+
84
+ # store in device
85
+ generated_array = generated_array.to(mdlm_model.device)
86
+ print(generated_array)
87
+
88
+ # compute masked perplexity
89
+ perplexity = mdlm_model.compute_masked_perplexity(generated_array, input_array['input_ids'])
90
+ perplexity = round(perplexity, 4)
91
+
92
+ if config.vocab == 'old_smiles' or config.vocab == 'new_smiles':
93
+ smiles_seq = tokenizer.decode(generated_array)
94
+ if analyzer.is_peptide(smiles_seq):
95
+ aa_seq, seq_length = analyzer.analyze_structure(smiles_seq)
96
+ num_valid += 1
97
+ scores = score_functions(input_seqs=[smiles_seq])
98
+
99
+ binding = scores[0][0]
100
+ sol = scores[0][1]
101
+ hemo = scores[0][2]
102
+ nf = scores[0][3]
103
+ perm = scores[0][4]
104
+
105
+ generation_results.append([smiles_seq, perplexity, aa_seq, binding, sol, hemo, nf, perm])
106
+ else:
107
+ aa_seq = "not valid peptide"
108
+ seq_length = '-'
109
+ scores = "not valid peptide"
110
+ elif config.vocab == 'selfies':
111
+ smiles_seq = tokenizer.decode(generated_array)
112
+ else:
113
+ aa_seq = tokenizer.decode(generated_array)
114
+ smiles_seq = get_smi_from_helms(aa_seq)
115
+
116
+
117
+ print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {smiles_seq} | amino acid sequence: {aa_seq} | scores: {scores}")
118
+ sys.stdout.flush()
119
+
120
+ valid_frac = num_valid / num_total
121
+ print(f"fraction of synthesizable peptides: {valid_frac}")
122
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
123
+ df.to_csv(path + f'/benchmarks/unconditional/epoch-10-pretrain-gfap.csv', index=False)
124
+
125
+ if __name__ == "__main__":
126
+ main()
main.py ADDED
@@ -0,0 +1,250 @@
1
+ #!/usr/bin/env
2
+ import os
3
+ #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'
4
+ import uuid
5
+
6
+ import wandb
7
+ import fsspec
8
+ import hydra
9
+ import lightning as L
10
+ from lightning.pytorch import Trainer
11
+ from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler
12
+ import omegaconf
13
+ import rich.syntax
14
+ import rich.tree
15
+ import torch
16
+ import sys
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ sys.path.append("/home/st512/peptune/scripts/peptide-mdlm-mcts")
20
+
21
+ import dataset as dataloader
22
+ import dataloading_for_dynamic_batching as dynamic_dataloader
23
+ from diffusion import Diffusion
24
+ import utils.utils as utils
25
+ from new_tokenizer.ape_tokenizer import APETokenizer
26
+
27
+ from lightning.pytorch.strategies import DDPStrategy
28
+ from datasets import load_dataset
29
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
30
+ from helm_tokenizer.helm_tokenizer import HelmTokenizer
31
+
32
+
33
+ #wandb.login(key="5a7613c531cb58f9802f3f8e2f73bc4997b917ab")
34
+
35
+ omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
36
+ omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
37
+ omegaconf.OmegaConf.register_new_resolver('eval', eval)
38
+ omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
39
+
40
+ def _load_from_checkpoint(config, tokenizer):
41
+ if 'hf' in config.backbone:
42
+ return Diffusion(
43
+ config, tokenizer=tokenizer).to('cuda')
44
+ else:
45
+ model = Diffusion.load_from_checkpoint(
46
+ config.eval.checkpoint_path,
47
+ tokenizer=tokenizer,
48
+ config=config)
49
+
50
+ return model
51
+
52
+ @L.pytorch.utilities.rank_zero_only
53
+ def print_config(
54
+ config: omegaconf.DictConfig,
55
+ resolve: bool = True,
56
+ save_cfg: bool = True) -> None:
57
+ """
58
+ Prints content of DictConfig using Rich library and its tree structure.
59
+
60
+ Args:
61
+ config (DictConfig): Configuration composed by Hydra.
62
+ resolve (bool): Whether to resolve reference fields of DictConfig.
63
+ save_cfg (bool): Whether to save the configuration tree to a file.
64
+ """
65
+
66
+ style = 'dim'
67
+ tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
68
+
69
+ fields = config.keys()
70
+ for field in fields:
71
+ branch = tree.add(field, style=style, guide_style=style)
72
+
73
+ config_section = config.get(field)
74
+ branch_content = str(config_section)
75
+ if isinstance(config_section, omegaconf.DictConfig):
76
+ branch_content = omegaconf.OmegaConf.to_yaml(
77
+ config_section, resolve=resolve)
78
+
79
+ branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
80
+ rich.print(tree)
81
+ if save_cfg:
82
+ with fsspec.open(
83
+ '{}/config_tree.txt'.format(
84
+ config.checkpointing.save_dir), 'w') as fp:
85
+ rich.print(tree, file=fp)
86
+
87
+
88
+ @L.pytorch.utilities.rank_zero_only
89
+ def print_batch(train_ds, valid_ds, tokenizer, k=64):
90
+ #for dl_type, dl in [
91
+ #('train', train_ds), ('valid', valid_ds)]:
92
+
93
+ for dl_type, dl in [
94
+ ('train', train_ds)]:
95
+ print(f'Printing {dl_type} dataloader batch.')
96
+ batch = next(iter(dl))
97
+ print('Batch input_ids.shape', batch['input_ids'].shape)
98
+ first = batch['input_ids'][0, :k]
99
+ last = batch['input_ids'][0, -k:]
100
+ print(f'First {k} tokens:', tokenizer.decode(first))
101
+ print('ids:', first)
102
+ print(f'Last {k} tokens:', tokenizer.decode(last))
103
+ print('ids:', last)
104
+
105
+
106
+ def generate_samples(config, logger, tokenizer):
107
+ logger.info('Generating samples.')
108
+ model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
109
+ # model.gen_ppl_metric.reset()
110
+
111
+ #stride_length = config.sampling.stride_length
112
+ #num_strides = config.sampling.num_strides
113
+
114
+ for _ in range(config.sampling.num_sample_batches):
115
+ samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
116
+ peptide_sequences = model.tokenizer.batch_decode(samples)
117
+ model.compute_generative_perplexity(peptide_sequences)
118
+
119
+ print('Peptide samples:', peptide_sequences)
120
+
121
+ print('Generative perplexity:', model.compute_masked_perplexity())
122
+
123
+ return peptide_sequences
124
+
125
+
126
+ def ppl_eval(config, logger, tokenizer, data_module):
127
+ logger.info('Starting Zero Shot Eval.')
128
+
129
+ model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
130
+
131
+ wandb_logger = None
132
+ if config.get('wandb', None) is not None:
133
+ wandb_logger = L.pytorch.loggers.WandbLogger(
134
+ config=omegaconf.OmegaConf.to_object(config),
135
+ ** config.wandb)
136
+
137
+ callbacks = []
138
+
139
+ if 'callbacks' in config:
140
+ for _, callback in config.callbacks.items():
141
+ callbacks.append(hydra.utils.instantiate(callback))
142
+
143
+ trainer = hydra.utils.instantiate(
144
+ config.trainer,
145
+ default_root_dir=os.getcwd(),
146
+ callbacks=callbacks,
147
+ strategy=DDPStrategy(find_unused_parameters = True),
148
+ logger=wandb_logger)
149
+
150
+ #_, valid_ds = dataloader.get_dataloaders(config, tokenizer, skiptrain=True, valid_seed=config.seed)
151
+ trainer.test(model, data_module)
152
+
153
+
154
+ def _train(config, logger, tokenizer, data_module):
155
+ logger.info('Starting Training.')
156
+ wandb_logger = None
157
+
158
+ if config.get('wandb', None) is not None:
159
+ unique_id = str(uuid.uuid4())
160
+
161
+ config.wandb.id = f"{config.wandb.id}_{unique_id}"
162
+
163
+ wandb_logger = L.pytorch.loggers.WandbLogger(
164
+ config=omegaconf.OmegaConf.to_object(config),
165
+ ** config.wandb)
166
+
167
+ if (config.checkpointing.resume_from_ckpt
168
+ and config.checkpointing.resume_ckpt_path is not None
169
+ and utils.fsspec_exists(
170
+ config.checkpointing.resume_ckpt_path)):
171
+ ckpt_path = config.checkpointing.resume_ckpt_path
172
+ else:
173
+ ckpt_path = None
174
+
175
+ # Lightning callbacks
176
+ callbacks = []
177
+ if 'callbacks' in config:
178
+ for callback_name, callback_config in config.callbacks.items():
179
+ if callback_name == 'model_checkpoint':
180
+ model_checkpoint_config = {k: v for k, v in callback_config.items() if k != '_target_'}
181
+ callbacks.append(ModelCheckpoint(**model_checkpoint_config))
182
+ else:
183
+ callbacks.append(hydra.utils.instantiate(callback_config))
184
+
185
+ if config.training.accumulator:
186
+ accumulator = GradientAccumulationScheduler(scheduling = {1: 5, 2: 4, 3: 3, 4: 1})
187
+ callbacks.append(accumulator)
188
+
189
+ trainer = hydra.utils.instantiate(
190
+ config.trainer,
191
+ default_root_dir=os.getcwd(),
192
+ callbacks=callbacks,
193
+ accelerator='cuda',
194
+ strategy=DDPStrategy(find_unused_parameters = True),
195
+ devices=[2,3,4,5,6,7],
196
+ logger=wandb_logger)
197
+
198
+ model = Diffusion(config, tokenizer=tokenizer)
199
+
200
+ if config.backbone == 'finetune_roformer':
201
+ checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=1-step=24080.ckpt')
202
+ model.load_state_dict(checkpoint['state_dict'])
203
+
204
+ trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
205
+
206
+
207
+ @hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
208
+ def main(config):
209
+ """
210
+ Main entry point for training
211
+ """
212
+ wandb.init(project="peptune")
213
+ L.seed_everything(config.seed)
214
+
215
+ # print_config(config, resolve=True, save_cfg=True)
216
+
217
+ logger = utils.get_logger(__name__)
218
+ # load PeptideCLM tokenizer
219
+ if config.vocab == 'new_smiles':
220
+ tokenizer = APETokenizer()
221
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
222
+ elif config.vocab == 'old_smiles':
223
+ tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
224
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
225
+ elif config.vocab == 'selfies':
226
+ tokenizer = APETokenizer()
227
+ tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
228
+ elif config.vocab == 'helm':
229
+ tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
230
+
231
+ if config.backbone == 'finetune_roformer':
232
+ train_dataset = load_dataset('csv', data_files=config.data.train)
233
+ val_dataset = load_dataset('csv', data_files=config.data.valid)
234
+
235
+ train_dataset = train_dataset['train']#.select(lst)
236
+ val_dataset = val_dataset['train']#.select(lst)
237
+ data_module = dataloader.CustomDataModule(train_dataset, val_dataset, None, tokenizer, batch_size=config.loader.global_batch_size)
238
+ else:
239
+ data_module = dynamic_dataloader.CustomDataModule('/home/st512/peptune/scripts/peptide-mdlm-mcts/data/smiles/11M_smiles_old_tokenizer_no_limit', tokenizer)
240
+
241
+ if config.mode == 'sample_eval':
242
+ generate_samples(config, logger, tokenizer)
243
+ elif config.mode == 'ppl_eval':
244
+ ppl_eval(config, logger, tokenizer, data_module)
245
+ else:
246
+ _train(config, logger, tokenizer, data_module)
247
+
248
+
249
+ if __name__ == '__main__':
250
+ main()
noise_schedule.py ADDED
@@ -0,0 +1,156 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ torch._C._jit_set_profiling_mode(False)
7
+ torch._C._jit_set_profiling_executor(False)
8
+ torch._C._jit_override_can_fuse_on_cpu(True)
9
+ torch._C._jit_override_can_fuse_on_gpu(True)
10
+
11
+ """
12
+ MDLM Github Repo:
13
+
14
+ """
15
+
16
+
17
+ def get_noise(config, dtype=torch.float32):
18
+ if config.noise.type == 'geometric':
19
+ return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
20
+ elif config.noise.type == 'loglinear':
21
+ return LogLinearNoise()
22
+ elif config.noise.type == 'cosine':
23
+ return CosineNoise()
24
+ elif config.noise.type == 'cosinesqr':
25
+ return CosineSqrNoise()
26
+ elif config.noise.type == 'linear':
27
+ return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
28
+ else:
29
+ raise ValueError(f'{config.noise.type} is not a valid noise')
30
+
31
+
32
+ def binary_discretization(z):
33
+ z_hard = torch.sign(z)
34
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
35
+ return z_soft + (z_hard - z_soft).detach()
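binary_discretization above is a straight-through estimator: the forward value is the hard sign of z, while gradients flow through the normalized soft value because the hard-soft difference is detached. A quick illustration (random tensor, for demonstration only):

    import torch

    z = torch.randn(4, 8, requires_grad=True)
    out = binary_discretization(z)     # forward: equals torch.sign(z)
    out.sum().backward()               # backward: gradient follows z / ||z||, not the sign
    print(z.grad.shape)                # (4, 8); gradients exist despite the non-differentiable sign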
36
+
37
+
38
+ class Noise(abc.ABC, nn.Module):
39
+ """
40
+ Baseline forward method to get the total + rate of noise at a timestep
41
+ """
42
+ def forward(self, t):
43
+ # Assume time goes from 0 to 1
44
+ return self.total_noise(t), self.rate_noise(t)
45
+
46
+
47
+ class CosineNoise(Noise):
48
+ def __init__(self, eps=1e-3):
49
+ super().__init__()
50
+ self.eps = eps
51
+
52
+ def rate_noise(self, t):
53
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
54
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
55
+ scale = torch.pi / 2
56
+ return scale * sin / (cos + self.eps)
57
+
58
+ def total_noise(self, t):
59
+ cos = torch.cos(t * torch.pi / 2)
60
+ return - torch.log(self.eps + (1 - self.eps) * cos)
61
+
62
+
63
+ class CosineSqrNoise(Noise):
64
+ def __init__(self, eps=1e-3):
65
+ super().__init__()
66
+ self.eps = eps
67
+
68
+ def rate_noise(self, t):
69
+ cos = (1 - self.eps) * (
70
+ torch.cos(t * torch.pi / 2) ** 2)
71
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
72
+ scale = torch.pi / 2
73
+ return scale * sin / (cos + self.eps)
74
+
75
+ def total_noise(self, t):
76
+ cos = torch.cos(t * torch.pi / 2) ** 2
77
+ return - torch.log(self.eps + (1 - self.eps) * cos)
78
+
79
+
80
+ class Linear(Noise):
81
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
82
+ super().__init__()
83
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
84
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
85
+
86
+ def rate_noise(self, t): # constant rate for the linear schedule; t is accepted for interface consistency
87
+ return self.sigma_max - self.sigma_min
88
+
89
+ def total_noise(self, t):
90
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
91
+
92
+ def importance_sampling_transformation(self, t):
93
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
94
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
95
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
96
+ return (sigma_t - self.sigma_min) / (
97
+ self.sigma_max - self.sigma_min)
98
+
99
+
100
+ class GeometricNoise(Noise):
101
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
102
+ super().__init__()
103
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
104
+
105
+ def rate_noise(self, t):
106
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
107
+ self.sigmas[1].log() - self.sigmas[0].log())
108
+
109
+ def total_noise(self, t):
110
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
111
+
112
+
113
+ class LogLinearNoise(Noise):
114
+ """Log Linear noise schedule.
115
+
116
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
117
+ ~1 when t varies from 0 to 1. Total noise is
118
+ -log(1 - (1 - eps) * t), so the sigma will be
119
+ (1 - eps) * t.
120
+ """
121
+ def __init__(self, eps=1e-3):
122
+ super().__init__()
123
+ self.eps = eps
124
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
125
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
126
+
127
+ def rate_noise(self, t):
128
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
129
+
130
+ def total_noise(self, t):
131
+ return -torch.log1p(-(1 - self.eps) * t)
132
+
133
+ def importance_sampling_transformation(self, t):
134
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
135
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
136
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
137
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
138
+ return t
139
+
140
+ class LogPolyNoise(Noise):
141
+ """
142
+ Log Polynomial noise schedule for slower masking of peptide bond tokens
143
+ """
144
+ def __init__(self, eps=1e-3):
145
+ super().__init__()
146
+ self.eps = eps
147
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
148
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
149
+
150
+ def rate_noise(self, t):
151
+ # approximate derivative of -log(1 - (1 - eps) * t^3)
152
+ return ((3 * (t**2)) - self.eps) / (1 - (1 - self.eps) * (t**3))
153
+
154
+ def total_noise(self, t):
155
+ # -log(1 - (1 - eps) * t^3)
156
+ return -torch.log1p(-(1 - self.eps) * (t**3))
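In this masked-diffusion setup the total noise sigma(t) maps to a per-token masking probability of 1 - exp(-sigma(t)) (see the LogLinearNoise docstring above); a quick sketch comparing two of the schedules at a few timesteps (illustrative only):

    import torch

    t = torch.tensor([0.1, 0.5, 0.9])
    for noise in (LogLinearNoise(), LogPolyNoise()):
        sigma, dsigma = noise(t)             # total noise and its rate at time t
        p_mask = 1 - torch.exp(-sigma)       # probability that a given token is masked at t
        print(type(noise).__name__, p_mask)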
pareto_mcts.py ADDED
@@ -0,0 +1,515 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import random as rd
6
+
7
+ from diffusion import Diffusion
8
+ from scoring.scoring_functions import ScoringFunctions
9
+ from utils.filter import PeptideAnalyzer
10
+ import noise_schedule
11
+
12
+ """"
13
+ Notes: store rolled out sequence?
14
+ path of node objects or strings?
15
+ should we only select valid expandable leaf nodes?
16
+ calculate similarity between sibling nodes?
17
+ should we evaluate generated sequences?
18
+ """
19
+ class Node:
20
+ """
21
+ Node class: partially unmasked SMILES string
22
+ - parentNode: Node object at previous time step
23
+ - childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
24
+ - totalReward: vector of cumulative rewards for all K objectives
25
+ - visits: number of times the node has been visited by an iteration
26
+ - path: array of partially unmasked SMILES strings leading to the node from the completely masked root node
27
+ - timestep: the time step where the sequence was sampled
28
+ - sampleProb: probability of sampling the sequence from the diffusion model
29
+ """
30
+ def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
31
+ self.config = config
32
+ self.parentNode = parentNode
33
+ self.childNodes = childNodes if childNodes is not None else [] # avoid sharing a mutable default list across nodes
34
+ self.scoreVector = scoreVector
35
+
36
+ # initialize total rewards to the reward of the roll out unmasked sequence
37
+ if totalReward is not None:
38
+ self.totalReward = totalReward
39
+ else:
40
+ self.totalReward = np.zeros(self.config.mcts.num_objectives)
41
+
42
+ # set initial visits to 1
43
+ self.visits = 1
44
+ # array of all sequences in path from the root -> node
45
+ #self.path = path
46
+ # set timestep (value between 0 and num_steps)
47
+ self.timestep = timestep
48
+ # set the sampling probability equal to the probability from the reverse posterior
49
+ self.sampleProb = sampleProb
50
+
51
+ # dict with 'input_ids' as token array and 'attention_mask'
52
+ self.tokens = tokens
53
+
54
+ #self.sequence = sequence
55
+
56
+ def selectNode(self, num_func):
57
+ """
58
+ Selects a node to move to among the children nodes
59
+ """
60
+ # extract the status of the current node
61
+ nodeStatus = self.getExpandStatus()
62
+
63
+ # if the node is a legal non-leaf node
64
+ if (nodeStatus == 3):
65
+ # initialize array that will store select score vectors of each child node
66
+ paretoFront = {}
67
+ for childNode in self.childNodes:
68
+ childStatus = childNode.getExpandStatus()
69
+ # only append child if it is legal leaf node (expandable) or legal non-leaf node
70
+ if childStatus == 2 or childStatus == 3:
71
+ selectScore = childNode.calcSelectScore()
72
+ paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)
73
+
74
+ # randomly select a node on the Pareto front
75
+ #selected = rd.choice(paretoFront)
76
+ selected = rd.choice(list(paretoFront.keys()))
77
+ # return selected child node and status
78
+ return selected, selected.getExpandStatus()
79
+
80
+ # if node is not valid non-leaf node
81
+ return self, nodeStatus
82
+
83
+ def addChildNode(self, tokens, totalReward, prob=None):
84
+ """"
85
+ Adds a child node
86
+ """
87
+ child = Node(config=self.config,
88
+ tokens=tokens,
89
+ parentNode=self,
90
+ childNodes=[],
91
+ totalReward=totalReward,
92
+ timestep=self.timestep+1,
93
+ sampleProb=prob)
94
+
95
+ self.childNodes.append(child)
96
+ return child
97
+
98
+ def updateNode(self, rewards):
99
+ """
100
+ Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
101
+ Increments the number of visits to the node.
102
+ """
103
+ self.visits += 1
104
+ self.totalReward += rewards
105
+
106
+ def calcSelectScore(self):
107
+ """
108
+ Calculates the select score for the node from the cumulative rewards vector and number of visits.
109
+ - c: determines the degree of exploration
110
+ - minSelectScore: determines the
111
+ """
112
+ """"
113
+ if not self.parentNode:
114
+ return 0.0
115
+ """
116
+ # K-dimensional vector of normalized rewards for each objective
117
+ normRewards = self.totalReward / self.visits
118
+ if self.sampleProb is not None:
119
+ print("Sample Prob")
120
+ print(self.sampleProb)
121
+ return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(self.parentNode.visits) / self.visits) # parent visit count; Node has no .root attribute
122
+ return normRewards
123
+
124
+ def getExpandStatus(self):
125
+ """
126
+ Returns an integer indicating whether the node is a:
127
+ 1. terminal node (sequence is fully unmasked)
128
+ 2. legal leaf node (partially unmasked sequence that can be expanded)
129
+ 3. legal non-leaf node (already expanded sequence with M child nodes)
130
+ """
131
+ if self.timestep == self.config.sampling.steps:
132
+ return 1
133
+ elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
134
+ return 2
135
+ return 3
136
+
137
+ """END OF NODE CLASS"""
138
+
139
+ def updateParetoFront(paretoFront, node, scoreVector, num_func):
140
+ """
141
+ Removes sequences that are dominated by scoreVector
142
+ adds the SMILES sequence if it is non-dominated and its scoreVector
143
+ """
144
+ paretoSize = len(paretoFront)
145
+ if paretoSize == 0:
146
+ # if pareto front is empty, add sequence and scoreVector
147
+ paretoFront[node] = scoreVector
148
+ else:
149
+ # vector of boolean
150
+ # true: sequence is non-dominated by the pareto-optimal sequence
151
+ # false: sequence is completely dominated by the pareto-optimal sequence
152
+ nondominate = []
153
+ # sequences to be deleted
154
+ delete = []
155
+ for k, v in paretoFront.items():
156
+ nondominated = scoreVector >= np.asarray(v)
157
+ dominant = scoreVector > np.asarray(v)
158
+
159
+ if num_func <= len(nondominated):
160
+ attn_nondominated = nondominated[:num_func]
161
+ attn_dominant = dominant[:num_func]
162
+
163
+ # all scores are greater than or equal to v and at least one score is strictly greater than v
164
+ if attn_nondominated.all() and attn_dominant.any():
165
+ # add the dominated sequence to be deleted
166
+ delete.append(k)
167
+ # sequence is dominant
168
+ nondominate.append(True)
169
+ elif attn_nondominated.all():
170
+ # sequence is non-dominated
171
+ nondominate.append(True)
172
+ else:
173
+ # sequence is completely dominated
174
+ nondominate.append(False)
175
+
176
+ nondominate = np.asarray(nondominate)
177
+ # if sequence is either dominant or non-dominated by all sequences in pareto-front -> add to pareto front
178
+ if nondominate.all():
179
+ paretoFront[node] = scoreVector
180
+
181
+ # delete all dominated sequences
182
+ while (paretoSize > 0) and (len(delete) > 0):
183
+ #for k in delete:
184
+ del paretoFront[delete[0]]
185
+ del delete[0]
186
+ paretoSize -= 1
187
+ return paretoFront
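The dominance rule above keeps a candidate only when, against every incumbent, its scores on the tracked objectives are all greater than or equal, and it removes an incumbent when the candidate is also strictly better in at least one objective; a small numpy illustration of that rule (score vectors are made up):

    import numpy as np

    incumbent = np.array([0.6, 0.4, 0.7])
    candidate = np.array([0.6, 0.5, 0.7])

    nondominated = candidate >= incumbent      # elementwise comparison, as in updateParetoFront
    dominant = candidate > incumbent
    if nondominated.all() and dominant.any():
        print("candidate dominates; incumbent is removed from the front")
    elif nondominated.all():
        print("candidate is non-dominated; both remain on the front")
    else:
        print("candidate is dominated; it is not added")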
188
+
189
+ """BEGINNING OF MCTS CLASS"""
190
+
191
+ class MCTS:
192
+ def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=[], prot_seqs=None, num_func = []):
193
+ self.config = config
194
+ self.noise = noise_schedule.get_noise(config)
195
+ self.time_conditioning = self.config.time_conditioning
196
+ # dictionary of k (SMILES string) and v (score vector) of Pareto-optimal sequences
197
+ self.peptideParetoFront = {}
198
+ self.num_steps = config.sampling.steps
199
+ self.num_sequences = config.sampling.num_sequences
200
+
201
+ # mdlm model
202
+ self.mdlm = mdlm
203
+ self.tokenizer = mdlm.tokenizer
204
+ self.device = mdlm.device
205
+
206
+ if max_sequence_length is None:
207
+ self.sequence_length = self.config.sampling.seq_length
208
+ else:
209
+ self.sequence_length = max_sequence_length
210
+
211
+ self.num_iter = config.mcts.num_iter
212
+
213
+ self.num_child = config.mcts.num_children
214
+
215
+ # score functions
216
+ self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
217
+ self.num_func = num_func # K-dimensional vector with the iteration number to start conditioning on each of the objectives in increasing order
218
+ self.iter_num = 0
219
+ self.curr_num_func = 1
220
+ self.analyzer = PeptideAnalyzer()
221
+
222
+ # track fraction of valid peptides
223
+ self.valid_fraction_log = []
224
+ self.affinity1_log = []
225
+ self.affinity2_log = []
226
+ self.permeability_log = []
227
+ self.sol_log = []
228
+ self.hemo_log = []
229
+ self.nf_log = []
230
+
231
+ def reset(self):
232
+ self.iter_num = 0
233
+ self.valid_fraction_log = []
234
+ self.affinity1_log = []
235
+ self.affinity2_log = []
236
+ self.permeability_log = []
237
+ self.sol_log = []
238
+ self.hemo_log = []
239
+ self.nf_log = []
240
+ self.peptideParetoFront = {}
241
+
242
+ def forward(self, rootNode):
243
+ self.reset()
244
+
245
+ while (self.iter_num < self.num_iter):
246
+ self.iter_num += 1
247
+
248
+ # traverse the tree from the root node until a leaf node
249
+ leafNode, _ = self.select(rootNode)
250
+ #print(leafNode.tokens['input_ids'])
251
+
252
+ # expand leaf node into num_children partially unmasked sequences at the next timestep
253
+ self.expand(leafNode)
254
+
255
+ # return dictionary of pareto front peptides and their score vectors
256
+ return self.peptideParetoFront
257
+
258
+ # change to include more even if dominated? since there is error in the scores
259
+ def updateParetoFront(self, sequence, scoreVector, tokens):
260
+ """
261
+ Removes sequences that are dominated by scoreVector
262
+ adds the SMILES sequence if it is non-dominated and its scoreVector
263
+
264
+ num_func: index of the last objective to consider when updating the pareto front from 0 to K
265
+ """
266
+ paretoSize = len(self.peptideParetoFront)
267
+
268
+ self.curr_num_func = 1
269
+
270
+ for i in range(len(self.num_func)):
271
+ if self.iter_num >= self.num_func[i]:
272
+ self.curr_num_func = i+1
273
+
274
+ if paretoSize == 0:
275
+ # if pareto front is empty, add sequence and scoreVector
276
+ self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
277
+ # if pareto front is empty, set reward vector to 1s
278
+ rewardVector = np.ones(len(scoreVector))
279
+ else:
280
+ # vector of boolean
281
+ # true: sequence is non-dominated by the pareto-optimal sequence
282
+ # false: sequence is completely dominated by the pareto-optimal sequence
283
+ nondominate = []
284
+ # sequences to be deleted
285
+ delete = []
286
+ # initialize reward vector with zeros
287
+ rewardVector = np.zeros(len(scoreVector))
288
+ for k, v in self.peptideParetoFront.items():
289
+ # boolean vector
290
+ # true: if all metrics are equal or larger
291
+ # false: if the pareto front sequence dominates scoreVector
292
+ nondominated = scoreVector >= np.asarray(v['scores']) # [num_objectives]
293
+ dominant = scoreVector > np.asarray(v['scores'])
294
+ # add to reward vector
295
+ rewardVector += nondominated # [num_objectives]
296
+
297
+ if self.curr_num_func <= len(nondominated):
298
+ attn_nondominated = nondominated[:self.curr_num_func]
299
+ attn_dominant = dominant[:self.curr_num_func]
300
+
301
+ # only delete pareto-optimal sequence if
302
+ # all scores are greater than or equal to v and at least one score is strictly greater than v
303
+ if attn_nondominated.all() and attn_dominant.any():
304
+ # add the dominated sequence to be deleted
305
+ delete.append(k)
306
+ # sequence is dominant
307
+ nondominate.append(True)
308
+ elif attn_nondominated.all():
309
+ # sequence is non-dominated
310
+ nondominate.append(True)
311
+ else:
312
+ # sequence is completely dominated
313
+ nondominate.append(False)
314
+
315
+ assert len(nondominate) == paretoSize
316
+ nondominate = np.asarray(nondominate)
317
+ # if sequence is either dominant or non-dominated by all sequences in pareto-front -> add to pareto front
318
+ # or if the pareto front does not have enough sequences
319
+ if nondominate.all() or paretoSize < self.num_sequences:
320
+ self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
321
+
322
+ rewardVector = rewardVector / paretoSize
323
+
324
+ # delete all dominated sequences if pareto front is larger than num_sequences
325
+ while (paretoSize > self.num_sequences) and (len(delete) > 0):
326
+ #for k in delete:
327
+ del self.peptideParetoFront[delete[0]]
328
+ del delete[0]
329
+ paretoSize -= 1
330
+
331
+ return rewardVector
332
+
333
+ def isPathEnd(self, path, maxDepth):
334
+ """
335
+ Checks if the node is completely unmasked (i.e. the end of the path)
336
+ or if the path is at the max depth
337
+ """
338
+ if (path[-1] != self.config.mcts.mask_token).all():
339
+ return True
340
+ elif len(path) >= maxDepth:
341
+ return True
342
+ return False
343
+
344
+ def select(self, currNode):
345
+ """
346
+ Traverse the tree from the root node until reaching a legal leaf node
347
+ """
348
+ while True:
349
+ currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
350
+ if nodeStatus != 3:
351
+ return currNode, nodeStatus
352
+
353
+ def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
354
+ """
355
+ Sample unmasking steps from the pre-trained MDLM
356
+ adds num_children partially unmasked sequences to the children of the parentNode
357
+ """
358
+
359
+ num_children = self.config.mcts.num_children
360
+ # initialize child rewards that will be added to total rewards
361
+ allChildReward = np.zeros_like(parentNode.totalReward) # (n_objectives)
362
+
363
+
364
+ # compute number of rollout steps
365
+ # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
366
+ num_rollout_steps = self.num_steps - parentNode.timestep
367
+ # array of rollout timesteps from the timestep of parent node to 0
368
+ rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
369
+ dt = (1 - eps) / self.num_steps
370
+ p_x0_cache = None
371
+
372
+ # initialize x and attn_mask
373
+ x = parentNode.tokens['input_ids'].to(self.device)
374
+ attn_mask = parentNode.tokens['attention_mask'].to(self.device)
375
+
376
+ t = rollout_t[0] * torch.ones(num_children, 1, device = self.device)
377
+ # generate (n_children, seq_length) array of sampled children nodes
378
+ print("token array:")
379
+ print(x)
380
+ p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
381
+ t=t, dt=dt,
382
+ batch_size=num_children,
383
+ attn_mask=attn_mask)
384
+ x_rollout = x_children
385
+
386
+ for i in range(1, num_rollout_steps):
387
+ t = rollout_t[i] * torch.ones(num_children, 1, device = self.device)
388
+
389
+ p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
390
+ t=t, dt=dt, p_x0=p_x0_cache,
391
+ attn_mask=attn_mask)
392
+
393
+ if (not torch.allclose(x_next, x) or self.time_conditioning):
394
+ # Disable caching
395
+ p_x0_cache = None
396
+
397
+ x_rollout = x_next
398
+
399
+ if self.config.sampling.noise_removal:
400
+ t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)
401
+ """if self.sampler == 'analytic':
402
+ x = self.mdlm._denoiser_update(x, t)
403
+ else:"""
404
+ time_cond = self.noise(t)[0]
405
+ x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1) # (n_children, seq_length)
406
+
407
+ childSequences = self.tokenizer.batch_decode(x_rollout)
408
+
409
+ validSequences = []
410
+ maskedTokens = []
411
+ unmaskedTokens = []
412
+ for i in range(num_children):
413
+ childSeq = childSequences[i]
414
+ #scoreVector = scoreVectors[i]
415
+ rewardVector = np.zeros(self.config.mcts.num_objectives)
416
+
417
+ # check if the peptide is valid
418
+ if self.analyzer.is_peptide(childSeq):
419
+ validSequences.append(childSeq)
420
+ maskedTokens.append(x_children[i])
421
+ unmaskedTokens.append(x_rollout[i])
422
+ else:
423
+ childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
424
+ parentNode.addChildNode(tokens=childTokens,
425
+ totalReward=rewardVector)
426
+
427
+ if (len(validSequences) != 0):
428
+ scoreVectors = self.score_functions(input_seqs=validSequences)
429
+ average_scores = scoreVectors.T
430
+ if self.config.mcts.single:
431
+ self.permeability_log.append(average_scores[0])
432
+ else:
433
+ self.affinity1_log.append(average_scores[0])
434
+ self.sol_log.append(average_scores[1])
435
+ self.hemo_log.append(average_scores[2])
436
+ self.nf_log.append(average_scores[3])
437
+ if self.config.mcts.perm:
438
+ self.permeability_log.append(average_scores[4])
439
+ elif self.config.mcts.dual:
440
+ self.affinity2_log.append(average_scores[4])
441
+ else:
442
+ self.affinity1_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
443
+ self.sol_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
444
+ self.hemo_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
445
+ self.nf_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
446
+
447
+ if self.config.mcts.perm:
448
+ self.permeability_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
449
+ elif self.config.mcts.dual:
450
+ self.affinity2_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
451
+
452
+ for i, validSeq in enumerate(validSequences):
453
+ #tokens = validTokens[i]
454
+ scoreVector = scoreVectors[i]
455
+
456
+ # update pareto front
457
+ rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
458
+ print(scoreVector)
459
+ print(rewardVector)
460
+
461
+ # add to all child reward vector for backprop
462
+ allChildReward += rewardVector
463
+
464
+ # create node for sequence and add to the children node of parent
465
+ childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
466
+ parentNode.addChildNode(tokens=childTokens,
467
+ totalReward=rewardVector)
468
+
469
+ # compute fraction of invalid child sequences
470
+ invalid = (num_children - len(validSequences)) / num_children
471
+
472
+ valid_fraction = len(validSequences) / num_children
473
+ print(f"Valid fraction: {valid_fraction}")
474
+ self.valid_fraction_log.append(valid_fraction)
475
+
476
+ print(self.config.mcts.invalid_penalty)
477
+ # subtract score using fraction of invalid sequences from reward
478
+ allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)
479
+ # backpropagate all child rewards
480
+ self.backprop(parentNode, allChildReward)
481
+
482
+
483
+ def backprop(self, node, reward_vector):
484
+ # backpropagate rewards along the path from the leaf node back up to the root
485
+ while node:
486
+ node.updateNode(reward_vector)
487
+ node = node.parentNode
488
+
489
+
490
+ def getSequenceForObjective(self, objective_index, k):
491
+ """
492
+ Returns the top-k sequences in the pareto front that has the best score for
493
+ a given objective and their score vectors for all objectives
494
+ """
495
+
496
+ # dictionary of top-k peptides for the objective
497
+ topk = {}
498
+
499
+ peptides = []
500
+ objectiveScores = []
501
+ for seq, v in self.peptideParetoFront.items():
502
+ # store peptides in list
503
+ peptides.append(seq)
504
+ # store score for objective
505
+ objectiveScores.append(v['scores'][objective_index])
506
+
507
+ objectiveScores = torch.tensor(objectiveScores)
508
+ topKScores = torch.topk(objectiveScores, min(k, len(peptides)))
509
+ for index in topKScores.indices.tolist():
510
+ seq = peptides[index]
511
+
512
+ topk[seq] = self.peptideParetoFront.get(seq)
513
+
514
+ return topk
515
+
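A minimal, self-contained sketch of the dominance rule used by updateParetoFront above, applied to toy score vectors: a front member is removed only when the candidate is at least as good on every objective considered and strictly better on at least one. The peptide names and scores below are illustrative; only the comparison logic mirrors the method.

import numpy as np

def dominates(candidate, incumbent):
    # True if candidate is >= on every objective and > on at least one (maximization)
    candidate, incumbent = np.asarray(candidate), np.asarray(incumbent)
    return bool((candidate >= incumbent).all() and (candidate > incumbent).any())

front = {"pepA": np.array([0.8, 0.4, 0.6]), "pepB": np.array([0.5, 0.9, 0.2])}
new_scores = np.array([0.9, 0.5, 0.7])

# drop front members dominated by the new candidate
front = {seq: s for seq, s in front.items() if not dominates(new_scores, s)}
# keep the candidate only if no remaining member dominates it
if not any(dominates(s, new_scores) for s in front.values()):
    front["pepNew"] = new_scores
print(sorted(front))  # ['pepB', 'pepNew']: pepA is dominated, pepB trades off objective 2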
roformer.py ADDED
@@ -0,0 +1,74 @@
1
+ from transformers import RoFormerConfig, RoFormerForMaskedLM
2
+ import torch.nn as nn
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ import torch
5
+
6
+ class Roformer(nn.Module):
7
+ def __init__(self, config, tokenizer):
8
+ super(Roformer, self).__init__()
9
+
10
+ self.tokenizer = tokenizer
11
+ self.vocab_size = self.tokenizer.vocab_size
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.device = device
14
+
15
+
16
+ roformer_config = RoFormerConfig(
17
+ vocab_size=self.tokenizer.vocab_size,
18
+ embedding_size=config.roformer.hidden_size,
19
+ hidden_size=config.roformer.hidden_size,
20
+ num_hidden_layers=config.roformer.n_layers,
21
+ num_attention_heads=config.roformer.n_heads,
22
+ intermediate_size=config.roformer.hidden_size * 4,
23
+ max_position_embeddings=config.roformer.max_position_embeddings,
24
+ hidden_dropout_prob=0.1,
25
+ attention_probs_dropout_prob=0.1,
26
+ pad_token_id=0,
27
+ rotary_value=False
28
+ )
29
+
30
+ self.model = RoFormerForMaskedLM(roformer_config).to(self.device)
31
+
32
+ def freeze_model(self):
33
+ for param in self.model.parameters():
34
+ param.requires_grad = False
35
+
36
+ def unfreeze_all_layers(self):
37
+ for param in self.model.parameters():
38
+ param.requires_grad = True
39
+
40
+ def unfreeze_n_layers(self, n):
41
+ num_layers = len(self.model.roformer.encoder.layer)
42
+
43
+ for i, layer in enumerate(self.model.roformer.encoder.layer):
44
+ # finetune final n layers
45
+ if i >= num_layers - n:
46
+ # unfreeze query weights
47
+ for module in layer.attention.self.query.modules():
48
+ for param in module.parameters():
49
+ param.requires_grad = True
50
+ # unfreeze key weights
51
+ for module in layer.attention.self.key.modules():
52
+ for param in module.parameters():
53
+ param.requires_grad = True
54
+
55
+ def forward(self, input_ids, attn_mask):
56
+
57
+ input_ids = input_ids.to(self.device)
58
+ attn_mask = attn_mask.to(self.device)
59
+
60
+ # get logits embeddings
61
+ logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
62
+ # return logits
63
+ #print(logits.logits)
64
+ return logits.logits
65
+
66
+ def save_model(self, save_dir):
67
+ self.model.save_pretrained(save_dir)
68
+ self.tokenizer.save_pretrained(save_dir)
69
+
70
+ @classmethod
71
+ def load_model(cls, save_dir, config, tokenizer):
72
+ roformer = cls(config, tokenizer)
73
+ roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
74
+ return roformer
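A short usage sketch for the wrapper above, assuming the repository's SPE tokenizer files are available; the config is a bare namespace carrying only the fields read in __init__, and all dimension values and paths are illustrative rather than the settings used for the released checkpoint.

from types import SimpleNamespace
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

# namespace with just the fields Roformer.__init__ reads (values are illustrative)
config = SimpleNamespace(roformer=SimpleNamespace(
    hidden_size=768, n_layers=8, n_heads=12, max_position_embeddings=1035))

# assumed paths to the SPE vocabulary and splits files shipped with the repo
tokenizer = SMILES_SPE_Tokenizer('tokenizer/new_vocab.txt', 'tokenizer/new_splits.txt')

model = Roformer(config, tokenizer)
model.freeze_model()
model.unfreeze_n_layers(2)                  # fine-tune Q/K weights in the last two encoder layers
model.save_model('checkpoints/roformer')    # writes HF-format weights plus the tokenizer
reloaded = Roformer.load_model('checkpoints/roformer', config, tokenizer)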
scoring/__init__.py ADDED
File without changes
scoring/binary_xg.py ADDED
@@ -0,0 +1,280 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.metrics import precision_recall_curve, f1_score
6
+ import optuna
7
+ from optuna.trial import TrialState
8
+ import xgboost as xgb
9
+ import os
10
+ from datasets import load_from_disk
11
+ from lightning.pytorch import seed_everything
12
+ from rdkit import Chem, rdBase, DataStructs
13
+ from typing import List
14
+ from rdkit.Chem import AllChem
15
+ import matplotlib.pyplot as plt
16
+ from sklearn.metrics import accuracy_score, roc_auc_score
17
+ import seaborn as sns
18
+
19
+ def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
20
+ """
21
+ Saves the true and predicted values for training and validation sets, and generates binary classification plots.
22
+
23
+ Parameters:
24
+ y_true_train (array): True labels for the training set.
25
+ y_pred_train (array): Predicted probabilities for the training set.
26
+ y_true_val (array): True labels for the validation set.
27
+ y_pred_val (array): Predicted probabilities for the validation set.
28
+ threshold (float): Classification threshold for predictions.
29
+ output_path (str): Directory to save the CSV files and plots.
30
+ """
31
+ os.makedirs(output_path, exist_ok=True)
32
+
33
+ # Convert probabilities to binary predictions
34
+ y_pred_train_binary = (y_pred_train >= threshold).astype(int)
35
+ y_pred_val_binary = (y_pred_val >= threshold).astype(int)
36
+
37
+ # Save training predictions
38
+ train_df = pd.DataFrame({
39
+ 'True Label': y_true_train,
40
+ 'Predicted Probability': y_pred_train,
41
+ 'Predicted Label': y_pred_train_binary
42
+ })
43
+ train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)
44
+
45
+ # Save validation predictions
46
+ val_df = pd.DataFrame({
47
+ 'True Label': y_true_val,
48
+ 'Predicted Probability': y_pred_val,
49
+ 'Predicted Label': y_pred_val_binary
50
+ })
51
+ val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)
52
+
53
+ # Plot training predictions
54
+ plot_boxplot_with_threshold(
55
+ y_true_train,
56
+ y_pred_train,
57
+ threshold,
58
+ title="Training Set Binary Classification Plot",
59
+ output_file=os.path.join(output_path, 'train_classification_plot.png')
60
+ )
61
+
62
+ # Plot validation predictions
63
+ plot_boxplot_with_threshold(
64
+ y_true_val,
65
+ y_pred_val,
66
+ threshold,
67
+ title="Validation Set Binary Classification Plot",
68
+ output_file=os.path.join(output_path, 'val_classification_plot.png')
69
+ )
70
+
71
+ def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
72
+ # Scatter plot
73
+ plt.figure(figsize=(10, 8))
74
+ plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
75
+
76
+ # Add threshold line
77
+ plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
78
+
79
+ # Add annotations
80
+ plt.title(title)
81
+ plt.xlabel("True Labels")
82
+ plt.ylabel("Predicted Probability")
83
+ plt.legend()
84
+
85
+ # Save and show the plot
86
+ plt.tight_layout()
87
+ plt.savefig(output_file)
88
+ plt.show()
89
+
90
+ def plot_boxplot_with_threshold(y_true, y_pred, threshold, title, output_file):
91
+ """
92
+ Generates a boxplot for binary classification and includes a threshold line.
93
+
94
+ Parameters:
95
+ y_true (array): True labels.
96
+ y_pred (array): Predicted probabilities.
97
+ threshold (float): Classification threshold for predictions.
98
+ title (str): Title of the plot.
99
+ output_file (str): Path to save the plot.
100
+ """
101
+ plt.figure(figsize=(10, 8))
102
+
103
+ # Combine data into a DataFrame for seaborn
104
+ df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})
105
+
106
+ # Boxplot
107
+ sns.boxplot(x='True Label', y='Predicted Probability', data=df)
108
+
109
+ # Add threshold line
110
+ plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
111
+ plt.text(
112
+ x=0.5, y=threshold + 0.05, s=f"Threshold = {threshold}", color="red", fontsize=10
113
+ )
114
+
115
+ # Add annotations
116
+ plt.title(title)
117
+ plt.xlabel("True Label")
118
+ plt.ylabel("Predicted Probability")
119
+ plt.legend()
120
+
121
+ # Save and show the plot
122
+ plt.tight_layout()
123
+ plt.savefig(output_file)
124
+ plt.show()
125
+
126
+ def plot_boxplot(y_true, y_pred, title, output_file):
127
+ plt.figure(figsize=(10, 8))
128
+
129
+ # Combine data into a single DataFrame for seaborn
130
+ df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})
131
+ sns.boxplot(x='True Label', y='Predicted Probability', data=df)
132
+
133
+ # Add annotations
134
+ plt.title(title)
135
+ plt.xlabel("True Label")
136
+ plt.ylabel("Predicted Probability")
137
+
138
+ # Save and show the plot
139
+ plt.tight_layout()
140
+ plt.savefig(output_file)
141
+ plt.show()
142
+
143
+ def plot_binary_correlation_with_density(y_true, y_pred, threshold, title, output_file):
144
+ """
145
+ Generates a scatter plot with a density plot for binary classification and saves it to a file.
146
+ """
147
+ plt.figure(figsize=(10, 8))
148
+
149
+ # Scatter plot
150
+ plt.scatter(range(len(y_true)), y_pred, alpha=0.5, label='Predicted Probabilities', color='#BC80FF')
151
+
152
+ # Add density plot
153
+ sns.kdeplot(y_pred, color='green', fill=True, alpha=0.3, label='Probability Density')
154
+
155
+ # Add threshold line
156
+ plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
157
+
158
+ # Add annotations
159
+ plt.title(title)
160
+ plt.xlabel("Index")
161
+ plt.ylabel("Predicted Probability")
162
+ plt.legend()
163
+
164
+ # Save and show the plot
165
+ plt.tight_layout()
166
+ plt.savefig(output_file)
167
+ plt.show()
168
+
169
+ seed_everything(42)
170
+
171
+ dataset = load_from_disk('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_data')
172
+
173
+ sequences = np.stack(dataset['sequence']) # Ensure sequences are SMILES strings
174
+ labels = np.stack(dataset['labels'])
175
+ embeddings = np.stack(dataset['embedding'])
176
+
177
+ # Initialize best F1 score and model path
178
+ best_f1 = -np.inf
179
+ best_model_path = "/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_train/"
180
+
181
+ # Trial callback
182
+ def trial_info_callback(study, trial):
183
+ if study.best_trial == trial:
184
+ print(f"Trial {trial.number}:")
185
+ print(f" Weighted F1 Score: {trial.value}")
186
+
187
+
188
+ def objective(trial):
189
+ params = {
190
+ 'objective': 'binary:logistic',
191
+ 'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True),
192
+ 'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True),
193
+ 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1.0),
194
+ 'subsample': trial.suggest_float('subsample', 0.1, 1.0),
195
+ 'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3),
196
+ 'max_depth': trial.suggest_int('max_depth', 2, 30),
197
+ 'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
198
+ 'tree_method': 'hist',
199
+ 'device': 'cuda:0',
200
+ }
201
+ num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
202
+
203
+ # Split the data
204
+ train_idx, val_idx = train_test_split(
205
+ np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
206
+ )
207
+ train_subset = dataset.select(train_idx).with_format("torch")
208
+ val_subset = dataset.select(val_idx).with_format("torch")
209
+
210
+ # Extract embeddings and labels for train/validation
211
+ train_embeddings = train_subset['embedding']
212
+ valid_embeddings = val_subset['embedding']
213
+ train_labels = train_subset['labels']
214
+ valid_labels = val_subset['labels']
215
+
216
+ # Prepare training and validation sets
217
+ dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
218
+ dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)
219
+
220
+ # Train the model
221
+ model = xgb.train(
222
+ params=params,
223
+ dtrain=dtrain,
224
+ num_boost_round=num_boost_round,
225
+ evals=[(dvalid, "validation")],
226
+ early_stopping_rounds=50,
227
+ verbose_eval=False,
228
+ )
229
+
230
+ # Predict probabilities
231
+ preds_train = model.predict(dtrain)
232
+ preds_val = model.predict(dvalid)
233
+
234
+ # Perform dynamic thresholding on validation predictions
235
+ best_f1_val = -np.inf
236
+ best_threshold = 0.5
237
+
238
+ for threshold in np.arange(0.1, 1.0, 0.05): # Try thresholds from 0.1 to 1.0
239
+ preds_val_binary = (preds_val >= threshold).astype(int)
240
+ f1_temp = f1_score(valid_labels, preds_val_binary, average="weighted")
241
+ if f1_temp > best_f1_val:
242
+ best_f1_val = f1_temp
243
+ best_threshold = threshold
244
+
245
+ print(f"Best F1 Score: {best_f1_val:.3f} at Threshold: {best_threshold:.3f}")
246
+
247
+ # Calculate AUC for additional insight
248
+ auc_val = roc_auc_score(valid_labels, preds_val)
249
+ print(f"AUC: {auc_val:.3f}")
250
+
251
+ # Save the best model if the F1 score is improved
252
+ if trial.study.user_attrs.get("best_f1", -np.inf) < best_f1_val:
253
+ trial.study.set_user_attr("best_f1", best_f1_val)
254
+ trial.study.set_user_attr("best_threshold", best_threshold) # Save the best threshold
255
+ os.makedirs(best_model_path, exist_ok=True)
256
+
257
+ model.save_model(os.path.join(best_model_path, "best_model.json"))
258
+ print(f"Best model saved to {os.path.join(best_model_path, 'best_model.json')}")
259
+
260
+ # Save and plot binary predictions with the best threshold
261
+ save_and_plot_binary_predictions(
262
+ train_labels,
263
+ preds_train,
264
+ valid_labels,
265
+ preds_val,
266
+ best_threshold,
267
+ best_model_path
268
+ )
269
+
270
+ return best_f1_val
271
+
272
+ if __name__ == "__main__":
273
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
274
+ study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
275
+
276
+ print("Study statistics: ")
277
+ print(f" Number of finished trials: {len(study.trials)}")
278
+ print(f" Best F1: {study.user_attrs.get('best_f1', None)}")
279
+ for key, value in study.best_trial.params.items():
280
+ print(f" {key}: {value}")
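A minimal inference sketch for the classifier trained above: it assumes the best trial has written best_model.json and that new peptides are featurized with the same mean-pooled PeptideCLM embeddings used for training. The embedding width and the 0.5 threshold are placeholders; in practice the tuned best_threshold stored in the study's user attributes should be used.

import numpy as np
import xgboost as xgb

booster = xgb.Booster(model_file="best_model.json")       # artifact saved by the best trial

embeddings = np.random.rand(4, 768).astype(np.float32)    # stand-in for PeptideCLM embeddings
probs = booster.predict(xgb.DMatrix(embeddings))          # probability of the positive (soluble) class
threshold = 0.5                                           # replace with the tuned best_threshold
labels = (probs >= threshold).astype(int)
print(probs, labels)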
scoring/functions/binding.py ADDED
@@ -0,0 +1,202 @@
1
+ import sys
2
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
3
+ import numpy as np
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from sklearn.model_selection import train_test_split
6
+ from collections import defaultdict
7
+ import torch
8
+ import pandas as pd
9
+ import torch.nn as nn
10
+ import esm
11
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
12
+ from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModel
13
+ from peft import PeftModel, PeftConfig
14
+
15
+
16
+ class ImprovedBindingPredictor(nn.Module):
17
+ def __init__(self,
18
+ esm_dim=1280,
19
+ smiles_dim=768,
20
+ hidden_dim=512,
21
+ n_heads=8,
22
+ n_layers=3,
23
+ dropout=0.1):
24
+ super().__init__()
25
+
26
+ # Define binding thresholds
27
+ self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
28
+ self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
29
+
30
+ # Project to same dimension
31
+ self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
32
+ self.protein_projection = nn.Linear(esm_dim, hidden_dim)
33
+ self.protein_norm = nn.LayerNorm(hidden_dim)
34
+ self.smiles_norm = nn.LayerNorm(hidden_dim)
35
+
36
+ # Cross attention blocks with layer norm
37
+ self.cross_attention_layers = nn.ModuleList([
38
+ nn.ModuleDict({
39
+ 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
40
+ 'norm1': nn.LayerNorm(hidden_dim),
41
+ 'ffn': nn.Sequential(
42
+ nn.Linear(hidden_dim, hidden_dim * 4),
43
+ nn.ReLU(),
44
+ nn.Dropout(dropout),
45
+ nn.Linear(hidden_dim * 4, hidden_dim)
46
+ ),
47
+ 'norm2': nn.LayerNorm(hidden_dim)
48
+ }) for _ in range(n_layers)
49
+ ])
50
+
51
+ # Prediction heads
52
+ self.shared_head = nn.Sequential(
53
+ nn.Linear(hidden_dim * 2, hidden_dim),
54
+ nn.ReLU(),
55
+ nn.Dropout(dropout),
56
+ )
57
+
58
+ # Regression head
59
+ self.regression_head = nn.Linear(hidden_dim, 1)
60
+
61
+ # Classification head (3 classes: tight, medium, loose binding)
62
+ self.classification_head = nn.Linear(hidden_dim, 3)
63
+
64
+ def get_binding_class(self, affinity):
65
+ """Convert affinity values to class indices
66
+ 0: tight binding (>= 7.5)
67
+ 1: medium binding (6.0-7.5)
68
+ 2: weak binding (< 6.0)
69
+ """
70
+ if isinstance(affinity, torch.Tensor):
71
+ tight_mask = affinity >= self.tight_threshold
72
+ weak_mask = affinity < self.weak_threshold
73
+ medium_mask = ~(tight_mask | weak_mask)
74
+
75
+ classes = torch.zeros_like(affinity, dtype=torch.long)
76
+ classes[medium_mask] = 1
77
+ classes[weak_mask] = 2
78
+ return classes
79
+ else:
80
+ if affinity >= self.tight_threshold:
81
+ return 0 # tight binding
82
+ elif affinity < self.weak_threshold:
83
+ return 2 # weak binding
84
+ else:
85
+ return 1 # medium binding
86
+
87
+ def forward(self, protein_emb, smiles_emb):
88
+ protein = self.protein_norm(self.protein_projection(protein_emb))
89
+ smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
90
+
91
+ #protein = protein.transpose(0, 1)
92
+ #smiles = smiles.transpose(0, 1)
93
+
94
+ # Cross attention layers
95
+ for layer in self.cross_attention_layers:
96
+ # Protein attending to SMILES
97
+ attended_protein = layer['attention'](
98
+ protein, smiles, smiles
99
+ )[0]
100
+ protein = layer['norm1'](protein + attended_protein)
101
+ protein = layer['norm2'](protein + layer['ffn'](protein))
102
+
103
+ # SMILES attending to protein
104
+ attended_smiles = layer['attention'](
105
+ smiles, protein, protein
106
+ )[0]
107
+ smiles = layer['norm1'](smiles + attended_smiles)
108
+ smiles = layer['norm2'](smiles + layer['ffn'](smiles))
109
+
110
+ # Get sequence-level representations
111
+ protein_pool = torch.mean(protein, dim=0)
112
+ smiles_pool = torch.mean(smiles, dim=0)
113
+
114
+ # Concatenate both representations
115
+ combined = torch.cat([protein_pool, smiles_pool], dim=-1)
116
+
117
+ # Shared features
118
+ shared_features = self.shared_head(combined)
119
+
120
+ regression_output = self.regression_head(shared_features)
121
+ classification_logits = self.classification_head(shared_features)
122
+
123
+ return regression_output, classification_logits
124
+
125
+ class BindingAffinity:
126
+ def __init__(self, prot_seq, model_type='PeptideCLM'):
127
+ super().__init__()
128
+
129
+ if model_type == 'PepDoRA':
130
+ # peptide embeddings
131
+ model_name = "ChatterjeeLab/PepDoRA"
132
+ self.pep_tokenizer = AutoTokenizer.from_pretrained(model_name)
133
+ self.pep_model = AutoModel.from_pretrained(model_name)
134
+
135
+ self.model = ImprovedBindingPredictor(smiles_dim=384)
136
+ checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/binding/best_model_optuna1.pt')
137
+ self.model.load_state_dict(checkpoint['model_state_dict'])
138
+ else:
139
+ # peptide embeddings
140
+ self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
141
+ self.pep_tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
142
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
143
+
144
+ self.model = ImprovedBindingPredictor()
145
+ checkpoint = torch.load('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/binding/best_model.pt')
146
+ self.model.load_state_dict(checkpoint['model_state_dict'])
147
+
148
+ self.model.eval()
149
+
150
+ self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
151
+ self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer
152
+
153
+ data = [("target", prot_seq)]
154
+ # get tokenized protein
155
+ _, _, prot_tokens = self.prot_tokenizer(data)
156
+ with torch.no_grad():
157
+ results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2
158
+ prot_emb = results["representations"][33]
159
+
160
+ self.prot_emb = prot_emb[0]
161
+ self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)
162
+
163
+
164
+ def forward(self, input_seqs):
165
+ with torch.no_grad():
166
+ scores = []
167
+ for seq in input_seqs:
168
+ pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
169
+
170
+ with torch.no_grad():
171
+ emb = self.pep_model(input_ids=pep_tokens['input_ids'],
172
+ attention_mask=pep_tokens['attention_mask'],
173
+ output_hidden_states=True)
174
+
175
+ #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
176
+ pep_emb = emb.last_hidden_state.squeeze(0)
177
+ pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
178
+
179
+ score, logits = self.model.forward(self.prot_emb, pep_emb)
180
+ scores.append(score.item())
181
+ return scores
182
+
183
+ def __call__(self, input_seqs: list):
184
+ return self.forward(input_seqs)
185
+
186
+ def unittest():
187
+ amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
188
+ tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
189
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
190
+ glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
191
+ glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
192
+ ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
193
+
194
+ binding = BindingAffinity(tfr)
195
+ seq = ["CC[C@H](C)[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)O"]
196
+
197
+ scores = binding(seq)
198
+ print(scores)
199
+ print(len(scores))
200
+
201
+ if __name__ == '__main__':
202
+ unittest()
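A small worked example of the affinity-to-class mapping implemented in get_binding_class above (pKd/pKi of at least 7.5 is tight, below 6.0 is weak, anything in between is medium); the affinity values are illustrative and the snippet assumes the class definition earlier in this file is in scope.

import torch

predictor = ImprovedBindingPredictor()           # thresholds: tight >= 7.5, weak < 6.0
affinities = torch.tensor([8.2, 7.5, 6.9, 6.0, 4.3])
print(predictor.get_binding_class(affinities))   # tensor([0, 0, 1, 1, 2]) -> tight, tight, medium, medium, weak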
scoring/functions/binding_utils.py ADDED
@@ -0,0 +1,291 @@
1
+ from torch import nn
2
+ import pdb
3
+ import torch
+ import torch.nn.functional as F  # needed for F.softmax in the attention modules below
4
+ import numpy as np
5
+
6
+ def to_var(x):
7
+ if torch.cuda.is_available():
8
+ x = x.cuda()
9
+ return x
10
+
11
+ class MultiHeadAttentionSequence(nn.Module):
12
+
13
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
14
+
15
+ super().__init__()
16
+
17
+ self.n_head = n_head
18
+ self.d_model = d_model
19
+ self.d_k = d_k
20
+ self.d_v = d_v
21
+
22
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
23
+ self.W_K = nn.Linear(d_model, n_head*d_k)
24
+ self.W_V = nn.Linear(d_model, n_head*d_v)
25
+ self.W_O = nn.Linear(n_head*d_v, d_model)
26
+
27
+ self.layer_norm = nn.LayerNorm(d_model)
28
+
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ def forward(self, q, k, v):
32
+
33
+ batch, len_q, _ = q.size()
34
+ batch, len_k, _ = k.size()
35
+ batch, len_v, _ = v.size()
36
+
37
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
38
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
39
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
40
+
41
+ Q = Q.transpose(1, 2)
42
+ K = K.transpose(1, 2).transpose(2, 3)
43
+ V = V.transpose(1, 2)
44
+
45
+ attention = torch.matmul(Q, K)
46
+
47
+ attention = attention / np.sqrt(self.d_k)
48
+
49
+ attention = F.softmax(attention, dim=-1)
50
+
51
+ output = torch.matmul(attention, V)
52
+
53
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
54
+
55
+ output = self.W_O(output)
56
+
57
+ output = self.dropout(output)
58
+
59
+ output = self.layer_norm(output + q)
60
+
61
+ return output, attention
62
+
63
+ class MultiHeadAttentionReciprocal(nn.Module):
64
+
65
+
66
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
67
+
68
+ super().__init__()
69
+
70
+ self.n_head = n_head
71
+ self.d_model = d_model
72
+ self.d_k = d_k
73
+ self.d_v = d_v
74
+
75
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
76
+ self.W_K = nn.Linear(d_model, n_head*d_k)
77
+ self.W_V = nn.Linear(d_model, n_head*d_v)
78
+ self.W_O = nn.Linear(n_head*d_v, d_model)
79
+ self.W_V_2 = nn.Linear(d_model, n_head*d_v)
80
+ self.W_O_2 = nn.Linear(n_head*d_v, d_model)
81
+
82
+ self.layer_norm = nn.LayerNorm(d_model)
83
+
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ self.layer_norm_2 = nn.LayerNorm(d_model)
87
+
88
+ self.dropout_2 = nn.Dropout(dropout)
89
+
90
+ def forward(self, q, k, v, v_2):
91
+
92
+ batch, len_q, _ = q.size()
93
+ batch, len_k, _ = k.size()
94
+ batch, len_v, _ = v.size()
95
+ batch, len_v_2, _ = v_2.size()
96
+
97
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
98
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
99
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
100
+ V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
101
+
102
+ Q = Q.transpose(1, 2)
103
+ K = K.transpose(1, 2).transpose(2, 3)
104
+ V = V.transpose(1, 2)
105
+ V_2 = V_2.transpose(1,2)
106
+
107
+ attention = torch.matmul(Q, K)
108
+
109
+
110
+ attention = attention /np.sqrt(self.d_k)
111
+
112
+ attention_2 = attention.transpose(-2, -1)
113
+
114
+
115
+
116
+ attention = F.softmax(attention, dim=-1)
117
+
118
+ attention_2 = F.softmax(attention_2, dim=-1)
119
+
120
+
121
+ output = torch.matmul(attention, V)
122
+
123
+ output_2 = torch.matmul(attention_2, V_2)
124
+
125
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
126
+
127
+ output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
128
+
129
+ output = self.W_O(output)
130
+
131
+ output_2 = self.W_O_2(output_2)
132
+
133
+ output = self.dropout(output)
134
+
135
+ output = self.layer_norm(output + q)
136
+
137
+ output_2 = self.dropout(output_2)
138
+
139
+ output_2 = self.layer_norm(output_2 + k)
140
+
141
+
142
+ return output, output_2, attention, attention_2
143
+
144
+
145
+ class FFN(nn.Module):
146
+
147
+ def __init__(self, d_in, d_hid, dropout=0.1):
148
+ super().__init__()
149
+
150
+ self.layer_1 = nn.Conv1d(d_in, d_hid,1)
151
+ self.layer_2 = nn.Conv1d(d_hid, d_in,1)
152
+ self.relu = nn.ReLU()
153
+ self.layer_norm = nn.LayerNorm(d_in)
154
+
155
+ self.dropout = nn.Dropout(dropout)
156
+
157
+ def forward(self, x):
158
+
159
+ residual = x
160
+ output = self.layer_1(x.transpose(1, 2))
161
+
162
+ output = self.relu(output)
163
+
164
+ output = self.layer_2(output)
165
+
166
+ output = self.dropout(output)
167
+
168
+ output = self.layer_norm(output.transpose(1, 2)+residual)
169
+
170
+ return output
171
+
172
+ class ConvLayer(nn.Module):
173
+ def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
174
+ super(ConvLayer, self).__init__()
175
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
176
+ self.relu = nn.ReLU()
177
+
178
+ def forward(self, x):
179
+ out = self.conv(x)
180
+ out = self.relu(out)
181
+ return out
182
+
183
+
184
+ class DilatedCNN(nn.Module):
185
+ def __init__(self, d_model, d_hidden):
186
+ super(DilatedCNN, self).__init__()
187
+ self.first_ = nn.ModuleList()
188
+ self.second_ = nn.ModuleList()
189
+ self.third_ = nn.ModuleList()
190
+
191
+ dilation_tuple = (1, 2, 3)
192
+ dim_in_tuple = (d_model, d_hidden, d_hidden)
193
+ dim_out_tuple = (d_hidden, d_hidden, d_hidden)
194
+
195
+ for i, dilation_rate in enumerate(dilation_tuple):
196
+ self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
197
+ dilation=dilation_rate))
198
+
199
+ for i, dilation_rate in enumerate(dilation_tuple):
200
+ self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
201
+ dilation=dilation_rate))
202
+
203
+ for i, dilation_rate in enumerate(dilation_tuple):
204
+ self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
205
+ dilation=dilation_rate))
206
+
207
+ def forward(self, protein_seq_enc):
208
+ # pdb.set_trace()
209
+ protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
210
+
211
+ first_embedding = protein_seq_enc
212
+ second_embedding = protein_seq_enc
213
+ third_embedding = protein_seq_enc
214
+
215
+ for i in range(len(self.first_)):
216
+ first_embedding = self.first_[i](first_embedding)
217
+
218
+ for i in range(len(self.second_)):
219
+ second_embedding = self.second_[i](second_embedding)
220
+
221
+ for i in range(len(self.third_)):
222
+ third_embedding = self.third_[i](third_embedding)
223
+
224
+ # pdb.set_trace()
225
+
226
+ protein_seq_enc = first_embedding + second_embedding + third_embedding
227
+
228
+ return protein_seq_enc.transpose(1, 2)
229
+
230
+
231
+ class ReciprocalLayerwithCNN(nn.Module):
232
+
233
+ def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
234
+ super().__init__()
235
+
236
+ self.cnn = DilatedCNN(d_model, d_hidden)
237
+
238
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
239
+
240
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
241
+
242
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
243
+
244
+ self.ffn_seq = FFN(d_hidden, d_inner)
245
+
246
+ self.ffn_protein = FFN(d_hidden, d_inner)
247
+
248
+ def forward(self, sequence_enc, protein_seq_enc):
249
+ # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model
250
+ protein_seq_enc = self.cnn(protein_seq_enc)
251
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
252
+
253
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
254
+
255
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
256
+
257
+ prot_enc = self.ffn_protein(prot_enc)
258
+
259
+ seq_enc = self.ffn_seq(seq_enc)
260
+
261
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
262
+
263
+
264
+ class ReciprocalLayer(nn.Module):
265
+
266
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v):
267
+
268
+ super().__init__()
269
+
270
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
271
+
272
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
273
+
274
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
275
+
276
+ self.ffn_seq = FFN(d_model, d_inner)
277
+
278
+ self.ffn_protein = FFN(d_model, d_inner)
279
+
280
+ def forward(self, sequence_enc, protein_seq_enc):
281
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
282
+
283
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
284
+
285
+
286
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
287
+ prot_enc = self.ffn_protein(prot_enc)
288
+
289
+ seq_enc = self.ffn_seq(seq_enc)
290
+
291
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
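A quick shape check for the reciprocal attention block defined above, run on random tensors; the dimensions are illustrative, and the snippet assumes the torch.nn.functional import noted near the top of the file plus the class definitions above are in scope.

import torch

d_model, d_inner, n_head, d_k, d_v = 64, 128, 4, 16, 16
layer = ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)

peptide = torch.randn(2, 20, d_model)    # (batch, peptide_length, d_model)
protein = torch.randn(2, 50, d_model)    # (batch, protein_length, d_model)

prot_enc, seq_enc, *attention_maps = layer(peptide, protein)
print(prot_enc.shape, seq_enc.shape)     # torch.Size([2, 50, 64]) torch.Size([2, 20, 64])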
scoring/functions/nonfouling.py ADDED
@@ -0,0 +1,68 @@
1
+ import sys
2
+ import os
3
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
4
+ import xgboost as xgb
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoModelForMaskedLM
8
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
9
+ import warnings
10
+ import numpy as np
11
+ from rdkit import Chem, rdBase, DataStructs
12
+ from transformers import AutoModelForMaskedLM
13
+
14
+
15
+ rdBase.DisableLog('rdApp.error')
16
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
17
+ warnings.filterwarnings("ignore", category=UserWarning)
18
+ warnings.filterwarnings("ignore", category=FutureWarning)
19
+
20
+ class Nonfouling:
21
+
22
+ def __init__(self):
23
+ self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/nonfouling/new_data/best_model.json')
24
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
25
+ self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
26
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
27
+
28
+ def generate_embeddings(self, sequences):
29
+ embeddings = []
30
+ for sequence in sequences:
31
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
32
+ with torch.no_grad():
33
+ output = self.emb_model(**tokenized)
34
+ # Mean pooling across sequence length
35
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
36
+ embeddings.append(embedding)
37
+ return np.array(embeddings)
38
+
39
+ def get_scores(self, input_seqs: list):
40
+ scores = np.zeros(len(input_seqs))
41
+ features = self.generate_embeddings(input_seqs)
42
+
43
+ if len(features) == 0:
44
+ return scores
45
+
46
+ features = np.nan_to_num(features, nan=0.)
47
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
48
+
49
+ features = xgb.DMatrix(features)
50
+
51
+ scores = self.predictor.predict(features)
52
+ # return the probability of the peptide being non-fouling
53
+ return scores
54
+
55
+ def __call__(self, input_seqs: list):
56
+ scores = self.get_scores(input_seqs)
57
+ return scores
58
+
59
+ def unittest():
60
+ nf = Nonfouling()
61
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
62
+
63
+ scores = nf(input_seqs=seq)
64
+ print(scores)
65
+
66
+
67
+ if __name__ == '__main__':
68
+ unittest()
scoring/functions/permeability.py ADDED
@@ -0,0 +1,168 @@
1
+ import sys
2
+ import os
3
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
4
+ import xgboost as xgb
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoModelForMaskedLM
8
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
9
+ import warnings
10
+ import numpy as np
11
+ from rdkit.Chem import Descriptors, rdMolDescriptors
12
+ from rdkit import Chem, rdBase, DataStructs
13
+ from rdkit.Chem import AllChem
14
+ from typing import List
15
+ from scoring.functions.transformation import TransformFunction
16
+ from transformers import AutoModelForMaskedLM
17
+
18
+
19
+ rdBase.DisableLog('rdApp.error')
20
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ warnings.filterwarnings("ignore", category=FutureWarning)
23
+
24
+ def fingerprints_from_smiles(smiles: List, size=2048):
25
+ """ Create ECFP fingerprints of smiles, with validity check """
26
+ fps = []
27
+ valid_mask = []
28
+ for i, smile in enumerate(smiles):
29
+ mol = Chem.MolFromSmiles(smile)
30
+ valid_mask.append(int(mol is not None))
31
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
32
+ fps.append(fp)
33
+
34
+ fps = np.concatenate(fps, axis=0)
35
+ return fps, valid_mask
36
+
37
+
38
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
39
+ """ Create ECFP fingerprint of a molecule """
40
+ if hashed:
41
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
42
+ else:
43
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
44
+ fp_np = np.zeros((1,))
45
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
46
+ return fp_np.reshape(1, -1)
47
+
48
+ def getMolDescriptors(mol, missingVal=0):
49
+ """ calculate the full list of descriptors for a molecule """
50
+
51
+ values, names = [], []
52
+ for nm, fn in Descriptors._descList:
53
+ try:
54
+ val = fn(mol)
55
+ except:
56
+ val = missingVal
57
+ values.append(val)
58
+ names.append(nm)
59
+
60
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
61
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
62
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
63
+
64
+ for nm, fn in custom_descriptors.items():
65
+ try:
66
+ val = fn(mol)
67
+ except:
68
+ val = missingVal
69
+ values.append(val)
70
+ names.append(nm)
71
+ return values, names
72
+
73
+ def get_pep_dps_from_smi(smi):
74
+ try:
75
+ mol = Chem.MolFromSmiles(smi)
76
+ except:
77
+ print(f"convert smi {smi} to molecule failed!")
78
+ mol = None
79
+
80
+ dps, _ = getMolDescriptors(mol)
81
+ return np.array(dps)
82
+
83
+
84
+ def get_pep_dps(smi_list):
85
+ if len(smi_list) == 0:
86
+ return np.zeros((0, 213))
87
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
88
+
89
+ def check_smi_validity(smiles: list):
90
+ valid_smi, valid_idx = [], []
91
+ for idx, smi in enumerate(smiles):
92
+ try:
93
+ mol = Chem.MolFromSmiles(smi) if smi else None
94
+ if mol:
95
+ valid_smi.append(smi)
96
+ valid_idx.append(idx)
97
+ except Exception as e:
98
+ # logger.debug(f'Error: {e} in smiles {smi}')
99
+ pass
100
+ return valid_smi, valid_idx
101
+
102
+ class Permeability:
103
+
104
+ def __init__(self):
105
+ self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-train/best_model.json')
106
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
107
+ self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt', '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
108
+
109
+ def generate_embeddings(self, sequences):
110
+ embeddings = []
111
+ for sequence in sequences:
112
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
113
+ with torch.no_grad():
114
+ output = self.emb_model(**tokenized)
115
+ # Mean pooling across sequence length
116
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
117
+ embeddings.append(embedding)
118
+ return np.array(embeddings)
119
+
120
+ def get_features(self, input_seqs: list, dps=False, fps=False):
121
+ #valid_smiles, valid_idxes = check_smi_validity(input_seqs)
122
+
123
+
124
+ if fps:
125
+ fingerprints = fingerprints_from_smiles(input_seqs)[0]
126
+ else:
127
+ fingerprints = torch.empty((len(input_seqs), 0))
128
+
129
+ if dps:
130
+ descriptors = get_pep_dps(input_seqs)
131
+ else:
132
+ descriptors = torch.empty((len(input_seqs), 0))
133
+
134
+ embeddings = self.generate_embeddings(input_seqs)
135
+ # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')
136
+
137
+ features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
138
+
139
+ return features
140
+
141
+ def get_scores(self, input_seqs: list):
142
+ scores = -10 * np.ones(len(input_seqs))
143
+ features = self.get_features(input_seqs)
144
+
145
+ if len(features) == 0:
146
+ return scores
147
+
148
+ features = np.nan_to_num(features, nan=0.)
149
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
150
+
151
+ features = xgb.DMatrix(features)
152
+
153
+ scores = self.predictor.predict(features)
154
+ return scores
155
+
156
+ def __call__(self, input_seqs: list):
157
+ scores = self.get_scores(input_seqs)
158
+ return scores
159
+
160
+ def unittest():
161
+ permeability = Permeability()
162
+ seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
163
+ scores = permeability(input_seqs=seq)
164
+ print(scores)
165
+
166
+
167
+ if __name__ == '__main__':
168
+ unittest()
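A short sketch of the optional feature blocks used by get_features above: by default only the PeptideCLM embeddings are used, but ECFP fingerprints and RDKit descriptors can be switched on with the fps/dps flags. The SMILES string is illustrative, and the exact descriptor count depends on the installed RDKit version.

import numpy as np

smiles = ["CC(=O)N[C@@H](C)C(=O)O"]                 # illustrative SMILES (N-acetyl-alanine)
fps, valid_mask = fingerprints_from_smiles(smiles)  # ECFP bit vectors, shape (1, 2048)
descriptors = get_pep_dps(smiles)                   # RDKit descriptor vector per molecule
print(fps.shape, descriptors.shape, valid_mask)     # e.g. (1, 2048) (1, 213) [1]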
scoring/functions/permeability_xg.py ADDED
@@ -0,0 +1,176 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import optuna
4
+ from optuna.trial import TrialState
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem
7
+ from sklearn.metrics import mean_squared_error
8
+ from sklearn.model_selection import train_test_split
9
+ import xgboost as xgb
10
+ import os
11
+ from datasets import load_from_disk
12
+ from scipy.stats import spearmanr
13
+ import matplotlib.pyplot as plt
14
+
15
+
16
+ def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
17
+ os.makedirs(output_path, exist_ok=True)
18
+
19
+ # Save training predictions
20
+ train_df = pd.DataFrame({'True Permeability': y_true_train, 'Predicted Permeability': y_pred_train})
21
+ train_df.to_csv(os.path.join(output_path, 'train_predictions.csv'), index=False)
22
+
23
+ # Save validation predictions
24
+ val_df = pd.DataFrame({'True Permeability': y_true_val, 'Predicted Permeability': y_pred_val})
25
+ val_df.to_csv(os.path.join(output_path, 'val_predictions.csv'), index=False)
26
+
27
+ # Plot training predictions
28
+ plot_correlation(
29
+ y_true_train,
30
+ y_pred_train,
31
+ title="Training Set Correlation Plot",
32
+ output_file=os.path.join(output_path, 'train_correlation.png'),
33
+ )
34
+
35
+ # Plot validation predictions
36
+ plot_correlation(
37
+ y_true_val,
38
+ y_pred_val,
39
+ title="Validation Set Correlation Plot",
40
+ output_file=os.path.join(output_path, 'val_correlation.png'),
41
+ )
42
+
43
+ def plot_correlation(y_true, y_pred, title, output_file):
44
+ spearman_corr, _ = spearmanr(y_true, y_pred)
45
+
46
+ # Scatter plot
47
+ plt.figure(figsize=(10, 8))
48
+ plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
49
+ plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')
50
+
51
+ # Add annotations
52
+ plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
53
+ plt.xlabel("True Permeability (logP)")
54
+ plt.ylabel("Predicted Permeability (logP)")
55
+ plt.legend()
56
+
57
+ # Save and show the plot
58
+ plt.tight_layout()
59
+ plt.savefig(output_file)
60
+ plt.show()
61
+
62
+ # Load dataset
63
+ dataset = load_from_disk('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-data/')
64
+
65
+ # Extract sequences, labels, and embeddings
66
+ sequences = np.stack(dataset['sequence'])
67
+ labels = np.stack(dataset['labels']) # Regression labels
68
+ embeddings = np.stack(dataset['embedding']) # Pre-trained embeddings
69
+
70
+ # Function to compute Morgan fingerprints
71
+ def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
72
+ fps = []
73
+ for smiles in smiles_list:
74
+ mol = Chem.MolFromSmiles(smiles)
75
+ if mol is not None:
76
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
77
+ fps.append(np.array(fp))
78
+ else:
79
+ # If the SMILES string is invalid, use a zero vector
80
+ fps.append(np.zeros(n_bits))
81
+ print(f"Invalid SMILES: {smiles}")
82
+ return np.array(fps)
83
+
84
+ # Compute Morgan fingerprints for the sequences
85
+ #morgan_fingerprints = compute_morgan_fingerprints(sequences)
86
+
87
+ # Concatenate embeddings with Morgan fingerprints
88
+ #input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
89
+ input_features = embeddings
90
+
91
+ # Initialize global variables
92
+ best_model_path = "/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/permeability/30K-train"
93
+ os.makedirs(best_model_path, exist_ok=True)
94
+
95
+ def trial_info_callback(study, trial):
96
+ if study.best_trial == trial:
97
+ print(f"Trial {trial.number}:")
98
+ print(f" MSE: {trial.value}")
99
+
100
+ def objective(trial):
101
+ # Define hyperparameters
102
+ params = {
103
+ 'objective': 'reg:squarederror',
104
+ 'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
105
+ 'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
106
+ 'gamma': trial.suggest_float('gamma', 0, 5),
107
+ 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
108
+ 'subsample': trial.suggest_float('subsample', 0.6, 0.9),
109
+ 'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
110
+ 'max_depth': trial.suggest_int('max_depth', 2, 30),
111
+ 'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
112
+ 'tree_method': 'hist',
113
+ 'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.5, 10.0, log=True),
114
+ 'device': 'cuda:6',
115
+ }
116
+ """params = {
117
+ 'objective': 'reg:squarederror',
118
+ 'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
119
+ 'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
120
+ 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
121
+ 'subsample': trial.suggest_float('subsample', 0.6, 0.9),
122
+ 'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2),
123
+ 'max_depth': trial.suggest_int('max_depth', 4, 20),
124
+ 'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
125
+ 'tree_method': 'hist',
126
+ 'device': 'cuda:6',
127
+ }"""
128
+ num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
129
+
130
+ # Train-validation split
131
+ X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)
132
+
133
+ # Convert data to DMatrix
134
+ dtrain = xgb.DMatrix(X_train, label=y_train)
135
+ dvalid = xgb.DMatrix(X_val, label=y_val)
136
+
137
+ # Train XGBoost
138
+ model = xgb.train(
139
+ params=params,
140
+ dtrain=dtrain,
141
+ num_boost_round=num_boost_round,
142
+ evals=[(dvalid, "validation")],
143
+ early_stopping_rounds=50,
144
+ verbose_eval=False,
145
+ )
146
+
147
+ # Predict and evaluate
148
+ preds_train = model.predict(dtrain)
149
+ preds_val = model.predict(dvalid)
150
+
151
+ mse = mean_squared_error(y_val, preds_val)
152
+
153
+ # Calculate Spearman Rank Correlation
154
+ spearman_corr, _ = spearmanr(y_val, preds_val)
155
+ print(f"Spearman Rank Correlation: {spearman_corr}")
156
+
157
+ # Save the best model
158
+ if trial.study.user_attrs.get("best_mse", np.inf) > mse:
159
+ trial.study.set_user_attr("best_mse", mse)
160
+ trial.study.set_user_attr("best_spearman", spearman_corr) # Save the Spearman correlation
161
+ model.save_model(os.path.join(best_model_path, "best_model.json"))
162
+ save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
163
+
164
+ return mse
165
+
166
+ if __name__ == "__main__":
167
+ study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
168
+ study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
169
+
170
+ # Print study statistics
171
+ print("Study statistics: ")
172
+ print(f" Number of finished trials: {len(study.trials)}")
173
+ print(f" Best trial value (MSE): {study.best_trial.value}")
174
+ print(f" Best Spearman Correlation: {study.user_attrs.get('best_spearman', None)}") # Print the best Spearman correlation
175
+ for key, value in study.best_trial.params.items():
176
+ print(f" {key}: {value}")
scoring/functions/scoring_utils.py ADDED
@@ -0,0 +1,111 @@
1
+ import warnings
2
+ import numpy as np
3
+ from loguru import logger
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ from rdkit.Chem import Descriptors, rdMolDescriptors
6
+ import joblib
7
+ from transformation import TransformFunction
8
+ from rdkit import Chem, rdBase, DataStructs
9
+ from rdkit.Chem import AllChem
10
+ from typing import List
11
+
12
+
13
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
14
+ """
15
+ Create ECFP fingerprint of a molecule
16
+ """
17
+ if hashed:
18
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
19
+ else:
20
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
21
+ fp_np = np.zeros((1,))
22
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
23
+ return fp_np.reshape(1, -1)
24
+
25
+
26
+ def fingerprints_from_smiles(smiles: List, size=2048):
27
+ """ Create ECFP fingerprints of smiles, with validity check """
28
+ fps = []
29
+ valid_mask = []
30
+ for i, smile in enumerate(smiles):
31
+ mol = Chem.MolFromSmiles(smile)
32
+ valid_mask.append(int(mol is not None))
33
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
34
+ fps.append(fp)
35
+
36
+ fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
37
+ return fps, valid_mask
38
+
39
+
40
+ def getMolDescriptors(mol, missingVal=0):
41
+ """ calculate the full list of descriptors for a molecule """
42
+
43
+ values, names = [], []
44
+ for nm, fn in Descriptors._descList:
45
+ try:
46
+ val = fn(mol)
47
+ except:
48
+ val = missingVal
49
+ values.append(val)
50
+ names.append(nm)
51
+
52
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
53
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
54
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
55
+
56
+ for nm, fn in custom_descriptors.items():
57
+ try:
58
+ val = fn(mol)
59
+ except:
60
+ val = missingVal
61
+ values.append(val)
62
+ names.append(nm)
63
+ return values, names
64
+
65
+
66
+ def get_pep_dps_from_smi(smi):
67
+ try:
68
+ mol = Chem.MolFromSmiles(smi)
69
+ except:
70
+ print(f"convert smi {smi} to molecule failed!")
71
+ mol = None
72
+
73
+ dps, _ = getMolDescriptors(mol)
74
+ return np.array(dps)
75
+
76
+
77
+ def get_pep_dps(smi_list):
78
+ if len(smi_list) == 0:
79
+ return np.zeros((0, 211))
80
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
81
+
82
+
83
+ """def get_smi_from_helms(helm_seqs: list):
84
+ valid_idxes = []
85
+ valid_smiles = []
86
+
87
+ for idx, helm in enumerate(helm_seqs):
88
+ # Ignore helm which cannot converted into molecules
89
+ try:
90
+ smi = get_cycpep_smi_from_helm(helm)
91
+ if smi:
92
+ valid_idxes.append(idx)
93
+ valid_smiles.append(smi)
94
+ except Exception as e:
95
+ # logger.debug(f'Error: {e} in helm {helm}')
96
+ pass
97
+ return valid_smiles, valid_idxes"""
98
+
99
+
100
+ def check_smi_validity(smiles: list):
101
+ valid_smi, valid_idx = [], []
102
+ for idx, smi in enumerate(smiles):
103
+ try:
104
+ mol = Chem.MolFromSmiles(smi) if smi else None
105
+ if mol:
106
+ valid_smi.append(smi)
107
+ valid_idx.append(idx)
108
+ except Exception as e:
109
+ # logger.debug(f'Error: {e} in smiles {smi}')
110
+ pass
111
+ return valid_smi, valid_idx
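
A short usage sketch for the helpers above. The module name `scoring_utils` in the import is illustrative and assumes the file is importable from the Python path; the SMILES strings are toy inputs.

```python
import numpy as np
from scoring_utils import check_smi_validity, fingerprints_from_smiles  # illustrative module name

smiles = ["CCO", "not_a_smiles", "NCC(=O)O"]  # ethanol, an invalid string, glycine

# check_smi_validity drops invalid entries and returns their original indices
valid_smi, valid_idx = check_smi_validity(smiles)
print(valid_smi)   # ['CCO', 'NCC(=O)O']
print(valid_idx)   # [0, 2]

# fingerprints_from_smiles keeps the original ordering, substitutes a zero
# vector for invalid entries, and returns a parallel 0/1 validity mask
fps, mask = fingerprints_from_smiles(smiles, size=2048)
print(fps.shape)   # (3, 2048)
print(mask)        # [1, 0, 1]
```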
scoring/functions/solubility.py ADDED
@@ -0,0 +1,68 @@
1
+ import sys
2
+ import os
3
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
4
+ import xgboost as xgb
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoModelForMaskedLM
8
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
9
+ import warnings
10
+ import numpy as np
11
+ from rdkit.Chem import Descriptors, rdMolDescriptors
12
+ from rdkit import Chem, rdBase, DataStructs
13
+ from rdkit.Chem import AllChem
14
+ from typing import List
15
+ from scoring.functions.transformation import TransformFunction
16
+ from transformers import AutoModelForMaskedLM
17
+
18
+
19
+ rdBase.DisableLog('rdApp.error')
20
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ warnings.filterwarnings("ignore", category=FutureWarning)
23
+
24
+ class Solubility:
25
+ def __init__(self):
26
+ self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_train/best_model.json')
27
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
28
+ self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
29
+ '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
30
+
31
+ def generate_embeddings(self, sequences):
32
+ embeddings = []
33
+ for sequence in sequences:
34
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
35
+ with torch.no_grad():
36
+ output = self.emb_model(**tokenized)
37
+ # Mean pooling across sequence length
38
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
39
+ embeddings.append(embedding)
40
+ return np.array(embeddings)
41
+
42
+ def get_scores(self, input_seqs: list):
43
+ scores = np.zeros(len(input_seqs))
44
+ features = self.generate_embeddings(input_seqs)
45
+
46
+ if len(features) == 0:
47
+ return scores
48
+
49
+ features = np.nan_to_num(features, nan=0.)
50
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
51
+
52
+ features = xgb.DMatrix(features)
53
+
54
+ scores = self.predictor.predict(features)
55
+ return scores
56
+
57
+ def __call__(self, input_seqs: list):
58
+ scores = self.get_scores(input_seqs)
59
+ return scores
60
+
61
+ def unittest():
62
+ solubility = Solubility()
63
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
64
+ scores = solubility(input_seqs=seq)
65
+ print(scores)
66
+
67
+ if __name__ == '__main__':
68
+ unittest()
scoring/hemolysis.py ADDED
@@ -0,0 +1,71 @@
1
+ import sys
2
+ import os
3
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
4
+ import xgboost as xgb
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoModelForMaskedLM
8
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
9
+ import warnings
10
+ import numpy as np
11
+ from rdkit.Chem import Descriptors, rdMolDescriptors
12
+ from rdkit import Chem, rdBase, DataStructs
13
+ from rdkit.Chem import AllChem
14
+ from typing import List
15
+ from scoring.functions.transformation import TransformFunction
16
+ from transformers import AutoModelForMaskedLM
17
+
18
+
19
+ rdBase.DisableLog('rdApp.error')
20
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ warnings.filterwarnings("ignore", category=FutureWarning)
23
+
24
+ class Hemolysis:
25
+
26
+ def __init__(self):
27
+ self.predictor = xgb.Booster(model_file='/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/hemolysis/new_train/best_model.json')
28
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
29
+ self.tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt', '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
30
+
31
+ def generate_embeddings(self, sequences):
32
+ embeddings = []
33
+ for sequence in sequences:
34
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
35
+ with torch.no_grad():
36
+ output = self.emb_model(**tokenized)
37
+ # Mean pooling across sequence length
38
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
39
+ embeddings.append(embedding)
40
+ return np.array(embeddings)
41
+
42
+ def get_scores(self, input_seqs: list):
43
+ scores = np.ones(len(input_seqs))
44
+ features = self.generate_embeddings(input_seqs)
45
+
46
+ if len(features) == 0:
47
+ return scores
48
+
49
+ features = np.nan_to_num(features, nan=0.)
50
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
51
+
52
+ features = xgb.DMatrix(features)
53
+
54
+ probs = self.predictor.predict(features)
55
+ # return the probability of it being not hemolytic
56
+ return scores - probs
57
+
58
+ def __call__(self, input_seqs: list):
59
+ scores = self.get_scores(input_seqs)
60
+ return scores
61
+
62
+ def unittest():
63
+ hemo = Hemolysis()
64
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
65
+
66
+ scores = hemo(input_seqs=seq)
67
+ print(scores)
68
+
69
+
70
+ if __name__ == '__main__':
71
+ unittest()
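
A toy numeric illustration (not part of the scorer) of the `scores - probs` convention in `get_scores` above: the booster predicts the probability that a peptide is hemolytic, and the scorer reports the complement so that larger values mean safer peptides, matching the orientation of the other scoring functions.

```python
import numpy as np

probs = np.array([0.05, 0.60, 0.92])   # hypothetical classifier outputs P(hemolytic)
scores = np.ones_like(probs) - probs   # same as `scores - probs` in get_scores
print(scores)                          # [0.95 0.4  0.08]  -> higher is better
```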
scoring/scoring_functions.py ADDED
@@ -0,0 +1,104 @@
1
+ import sys
2
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
3
+ import io
4
+ import subprocess
5
+ import warnings
6
+ import numpy as np
7
+ import pandas as pd
8
+ from typing import List
9
+ from loguru import logger
10
+ from tqdm import tqdm
11
+ from rdkit import Chem, rdBase, DataStructs
12
+ from rdkit.Chem import AllChem
13
+ import torch
14
+ from scoring.functions.binding.binding import BindingAffinity
15
+ from scoring.functions.permeability.permeability import Permeability
16
+ from scoring.functions.solubility.solubility import Solubility
17
+ from scoring.functions.hemolysis.hemolysis import Hemolysis
18
+ from scoring.functions.nonfouling.nonfouling import Nonfouling
19
+
20
+ class ScoringFunctions:
21
+ def __init__(self, score_func_names=None, prot_seqs=[]):
22
+ """
23
+ Class for generating score vectors given generated sequence
24
+
25
+ Args:
26
+ score_func_names: list of scoring function names to be evaluated
27
+ prot_seqs: list of target protein sequences (one or two) used to
28
+ construct the binding-affinity scoring functions
29
+ """
30
+ if score_func_names is None:
31
+ # just do unmasking based on validity of peptide bonds
32
+ self.score_func_names = []
33
+ else:
34
+ self.score_func_names = score_func_names
35
+
36
+ # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)
37
+
38
+ # binding affinities
39
+ self.target_protein = prot_seqs
40
+ print(len(prot_seqs))
41
+
42
+ if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
43
+ binding_affinity1 = BindingAffinity(prot_seqs[0])
44
+ binding_affinity2 = None
45
+ elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
46
+ binding_affinity1 = BindingAffinity(prot_seqs[0])
47
+ binding_affinity2 = BindingAffinity(prot_seqs[1])
48
+ else:
49
+ print("here")
50
+ binding_affinity1 = None
51
+ binding_affinity2 = None
52
+
53
+ permeability = Permeability()
54
+ sol = Solubility()
55
+ nonfouling = Nonfouling()
56
+ hemo = Hemolysis()
57
+
58
+ self.all_funcs = {'binding_affinity1': binding_affinity1,
59
+ 'binding_affinity2': binding_affinity2,
60
+ 'permeability': permeability,
61
+ 'nonfouling': nonfouling,
62
+ 'solubility': sol,
63
+ 'hemolysis': hemo
64
+ }
65
+
66
+ def forward(self, input_seqs):
67
+ scores = []
68
+
69
+ for i, score_func in enumerate(self.score_func_names):
70
+ score = self.all_funcs[score_func](input_seqs = input_seqs)
71
+
72
+ scores.append(score)
73
+
74
+ # convert to numpy arrays with shape (num_sequences, num_functions)
75
+ scores = np.float32(scores).T
76
+
77
+ return scores
78
+
79
+ def __call__(self, input_seqs: list):
80
+ return self.forward(input_seqs)
81
+
82
+
83
+ def unittest():
84
+ amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
85
+ tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
86
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
87
+ glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
88
+ glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
89
+ ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
90
+ cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
91
+
92
+ num_iter = 0
93
+ score_func_times = [0, 1, 2, 3, 4, 5]
94
+
95
+ scoring = ScoringFunctions(score_func_names=['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'], prot_seqs=[tfr])
96
+
97
+ smiles = ['N2[C@H](CC(C)C)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](Cc1ccccc1C(F)(F)F)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](CC(=O)N)C2(=O)']
98
+
99
+ scores = scoring(input_seqs=smiles)
100
+ print(scores)
101
+ print(len(scores))
102
+
103
+ if __name__ == '__main__':
104
+ unittest()
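
A hedged usage sketch for the class above. The placeholder target sequence is not a real protein, and it is assumed that the hard-coded model and tokenizer paths inside the individual scorers resolve on the current machine.

```python
from scoring.scoring_functions import ScoringFunctions

MY_TARGET_SEQ = "MTEYKLVVVG"  # placeholder; substitute a real target protein sequence

scoring = ScoringFunctions(
    score_func_names=['binding_affinity1', 'solubility', 'permeability'],
    prot_seqs=[MY_TARGET_SEQ],
)

peptide_smiles = ["NCC(=O)N[C@H](CS)C(=O)O"]  # one short example peptide SMILES
scores = scoring(peptide_smiles)

# forward() stacks one array per scoring function and transposes, so the
# result has shape (num_sequences, num_functions) = (1, 3) here
print(scores.shape)
```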
tokenizer/__init__.py ADDED
File without changes
tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,441 @@
1
+ import collections
2
+ import logging
3
+ import os
4
+ import re
5
+ import codecs
6
+ import unicodedata
7
+ from typing import List, Optional
8
+ from transformers import PreTrainedTokenizer
9
+ from SmilesPE.tokenizer import SPE_Tokenizer
10
+ import torch
11
+
+ # `logger` and `VOCAB_FILES_NAMES` are referenced by save_vocabulary() below;
+ # define them at module level so that method does not raise a NameError
+ logger = logging.getLogger(__name__)
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
12
+ def load_vocab(vocab_file):
13
+ """Loads a vocabulary file into a dictionary."""
14
+ vocab = collections.OrderedDict()
15
+ with open(vocab_file, "r", encoding="utf-8") as reader:
16
+ tokens = reader.readlines()
17
+ for index, token in enumerate(tokens):
18
+ token = token.rstrip("\n")
19
+ vocab[token] = index
20
+ return vocab
21
+
22
+ class Atomwise_Tokenizer(object):
23
+ """Run atom-level SMILES tokenization"""
24
+
25
+ def __init__(self):
26
+ """ Constructs a atom-level Tokenizer.
27
+ """
28
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
29
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
30
+
31
+ self.regex = re.compile(self.regex_pattern)
32
+
33
+ def tokenize(self, text):
34
+ """ Basic Tokenization of a SMILES.
35
+ """
36
+ tokens = [token for token in self.regex.findall(text)]
37
+ return tokens
38
+
39
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
40
+ r"""
41
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
42
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
43
+ should refer to the superclass for more information regarding methods.
44
+ Args:
45
+ vocab_file (:obj:`string`):
46
+ File containing the vocabulary.
47
+ spe_file (:obj:`string`):
48
+ File containing the trained SMILES Pair Encoding vocabulary.
49
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
50
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
51
+ token instead.
52
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
53
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
54
+ for sequence classification or for a text and a question for question answering.
55
+ It is also used as the last token of a sequence built with special tokens.
56
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
57
+ The token used for padding, for example when batching sequences of different lengths.
58
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
59
+ The classifier token which is used when doing sequence classification (classification of the whole
60
+ sequence instead of per-token classification). It is the first token of the sequence when built with
61
+ special tokens.
62
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
63
+ The token used for masking values. This is the token used when training this model with masked language
64
+ modeling. This is the token which the model will try to predict.
65
+ """
66
+
67
+ def __init__(self, vocab_file, spe_file,
68
+ unk_token="[UNK]",
69
+ sep_token="[SEP]",
70
+ pad_token="[PAD]",
71
+ cls_token="[CLS]",
72
+ mask_token="[MASK]",
73
+ **kwargs):
74
+ if not os.path.isfile(vocab_file):
75
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
76
+ if not os.path.isfile(spe_file):
77
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
78
+
79
+ self.vocab = load_vocab(vocab_file)
80
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
81
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
82
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
83
+
84
+ super().__init__(
85
+ unk_token=unk_token,
86
+ sep_token=sep_token,
87
+ pad_token=pad_token,
88
+ cls_token=cls_token,
89
+ mask_token=mask_token,
90
+ **kwargs)
91
+
92
+ @property
93
+ def vocab_size(self):
94
+ return len(self.vocab)
95
+
96
+ def get_vocab(self):
97
+ return dict(self.vocab, **self.added_tokens_encoder)
98
+
99
+ def _tokenize(self, text):
100
+ return self.spe_tokenizer.tokenize(text).split(' ')
101
+
102
+ def _convert_token_to_id(self, token):
103
+ """ Converts a token (str) in an id using the vocab. """
104
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
105
+
106
+ # changed encode and decode functions
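+ # the hard-coded ids 2 and 3 below are the [CLS] and [SEP] entries of
+ # tokenizer/new_vocab.txt ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3), so encode()
+ # wraps every sequence as [CLS] tokens [SEP] and decode() stops at [SEP]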
107
+ def encode(self, token_array):
108
+ token_ids = []
109
+ token_ids.append(2)
110
+ for token in token_array:
111
+ id = self._convert_token_to_id(token)
112
+ token_ids.append(id)
113
+ token_ids.append(3)
114
+ token_ids = torch.tensor([token_ids])
115
+ attn_mask = torch.ones_like(token_ids)
116
+ return {'input_ids': token_ids, 'attention_mask': attn_mask}
117
+
118
+ def decode(self, token_ids, skip_special_tokens=True):
119
+ token_ids = token_ids.squeeze(0).cpu().tolist()
120
+ token_array = []
121
+ for idx in token_ids:
122
+ if idx == 3: # Stop decoding when token ID 3 is encountered
123
+ break
124
+ if skip_special_tokens and idx in self.all_special_ids:
125
+ continue
126
+ token = self._convert_id_to_token(idx)
127
+ token_array.append(token)
128
+ sequence = "".join(token_array)
129
+ return sequence
130
+
131
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
132
+ sequences = []
133
+ for token_ids in batch_token_ids:
134
+ sequences.append(self.decode(token_ids))
135
+ return sequences
136
+
137
+ def get_token_split(self, token_ids):
138
+ if isinstance(token_ids, torch.Tensor):
139
+ token_ids = token_ids.cpu().tolist()
140
+
141
+ token_array = []
142
+ for seq_ids in token_ids:
143
+ seq_array = []
144
+ for id in seq_ids:
145
+ token = self._convert_id_to_token(id)
146
+ seq_array.append(token)
147
+ token_array.append(seq_array)
148
+
149
+ return token_array
150
+
151
+ def _convert_id_to_token(self, index):
152
+ """Converts an index (integer) in a token (str) using the vocab."""
153
+ return self.ids_to_tokens.get(index, self.unk_token)
154
+
155
+ def convert_tokens_to_string(self, tokens):
156
+ """ Converts a sequence of tokens (string) in a single string. """
157
+ out_string = " ".join(tokens).replace(" ##", "").strip()
158
+ return out_string
159
+
160
+ def build_inputs_with_special_tokens(
161
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
162
+ ) -> List[int]:
163
+ """
164
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
165
+ by concatenating and adding special tokens.
166
+ A BERT sequence has the following format:
167
+ - single sequence: ``[CLS] X [SEP]``
168
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
169
+ Args:
170
+ token_ids_0 (:obj:`List[int]`):
171
+ List of IDs to which the special tokens will be added
172
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
173
+ Optional second list of IDs for sequence pairs.
174
+ Returns:
175
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
176
+ """
177
+ if token_ids_1 is None:
178
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
179
+ cls = [self.cls_token_id]
180
+ sep = [self.sep_token_id]
181
+ return cls + token_ids_0 + sep + token_ids_1 + sep
182
+
183
+ def get_special_tokens_mask(
184
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
185
+ ) -> List[int]:
186
+ """
187
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
188
+ special tokens using the tokenizer ``prepare_for_model`` method.
189
+ Args:
190
+ token_ids_0 (:obj:`List[int]`):
191
+ List of ids.
192
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
193
+ Optional second list of IDs for sequence pairs.
194
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
195
+ Set to True if the token list is already formatted with special tokens for the model
196
+ Returns:
197
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
198
+ """
199
+
200
+ if already_has_special_tokens:
201
+ if token_ids_1 is not None:
202
+ raise ValueError(
203
+ "You should not supply a second sequence if the provided sequence of "
204
+ "ids is already formated with special tokens for the model."
205
+ )
206
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
207
+
208
+ if token_ids_1 is not None:
209
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
210
+ return [1] + ([0] * len(token_ids_0)) + [1]
211
+
212
+ def create_token_type_ids_from_sequences(
213
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
214
+ ) -> List[int]:
215
+ """
216
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
217
+ A BERT sequence pair mask has the following format:
218
+ ::
219
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
220
+ | first sequence | second sequence |
221
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
222
+ Args:
223
+ token_ids_0 (:obj:`List[int]`):
224
+ List of ids.
225
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
226
+ Optional second list of IDs for sequence pairs.
227
+ Returns:
228
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
229
+ sequence(s).
230
+ """
231
+ sep = [self.sep_token_id]
232
+ cls = [self.cls_token_id]
233
+ if token_ids_1 is None:
234
+ return len(cls + token_ids_0 + sep) * [0]
235
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
236
+
237
+ def save_vocabulary(self, vocab_path):
238
+ """
239
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
240
+ Args:
241
+ vocab_path (:obj:`str`):
242
+ The directory in which to save the vocabulary.
243
+ Returns:
244
+ :obj:`Tuple(str)`: Paths to the files saved.
245
+ """
246
+ index = 0
247
+ if os.path.isdir(vocab_path):
248
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
249
+ else:
250
+ vocab_file = vocab_path
251
+ with open(vocab_file, "w", encoding="utf-8") as writer:
252
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
253
+ if index != token_index:
254
+ logger.warning(
255
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
256
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
257
+ )
258
+ index = token_index
259
+ writer.write(token + "\n")
260
+ index += 1
261
+ return (vocab_file,)
262
+
263
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
264
+ r"""
265
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
266
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
267
+ should refer to the superclass for more information regarding methods.
268
+ Args:
269
+ vocab_file (:obj:`string`):
270
+ File containing the vocabulary.
271
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
272
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
273
+ token instead.
274
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
275
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
276
+ for sequence classification or for a text and a question for question answering.
277
+ It is also used as the last token of a sequence built with special tokens.
278
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
279
+ The token used for padding, for example when batching sequences of different lengths.
280
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
281
+ The classifier token which is used when doing sequence classification (classification of the whole
282
+ sequence instead of per-token classification). It is the first token of the sequence when built with
283
+ special tokens.
284
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
285
+ The token used for masking values. This is the token used when training this model with masked language
286
+ modeling. This is the token which the model will try to predict.
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ vocab_file,
292
+ unk_token="[UNK]",
293
+ sep_token="[SEP]",
294
+ pad_token="[PAD]",
295
+ cls_token="[CLS]",
296
+ mask_token="[MASK]",
297
+ **kwargs
298
+ ):
299
+ super().__init__(
300
+ unk_token=unk_token,
301
+ sep_token=sep_token,
302
+ pad_token=pad_token,
303
+ cls_token=cls_token,
304
+ mask_token=mask_token,
305
+ **kwargs,
306
+ )
307
+
308
+ if not os.path.isfile(vocab_file):
309
+ raise ValueError(
310
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
311
+ )
312
+ self.vocab = load_vocab(vocab_file)
313
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
314
+ self.tokenizer = Atomwise_Tokenizer()
315
+
316
+ @property
317
+ def vocab_size(self):
318
+ return len(self.vocab)
319
+
320
+ def get_vocab(self):
321
+ return dict(self.vocab, **self.added_tokens_encoder)
322
+
323
+
324
+ def _tokenize(self, text):
325
+ return self.tokenizer.tokenize(text)
326
+
327
+ def _convert_token_to_id(self, token):
328
+ """ Converts a token (str) in an id using the vocab. """
329
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
330
+
331
+ def _convert_id_to_token(self, index):
332
+ """Converts an index (integer) in a token (str) using the vocab."""
333
+ return self.ids_to_tokens.get(index, self.unk_token)
334
+
335
+ def convert_tokens_to_string(self, tokens):
336
+ """ Converts a sequence of tokens (string) in a single string. """
337
+ out_string = " ".join(tokens).replace(" ##", "").strip()
338
+ return out_string
339
+
340
+ def build_inputs_with_special_tokens(
341
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
342
+ ) -> List[int]:
343
+ """
344
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
345
+ by concatenating and adding special tokens.
346
+ A BERT sequence has the following format:
347
+ - single sequence: ``[CLS] X [SEP]``
348
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
349
+ Args:
350
+ token_ids_0 (:obj:`List[int]`):
351
+ List of IDs to which the special tokens will be added
352
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
353
+ Optional second list of IDs for sequence pairs.
354
+ Returns:
355
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
356
+ """
357
+ if token_ids_1 is None:
358
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
359
+ cls = [self.cls_token_id]
360
+ sep = [self.sep_token_id]
361
+ return cls + token_ids_0 + sep + token_ids_1 + sep
362
+
363
+ def get_special_tokens_mask(
364
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
365
+ ) -> List[int]:
366
+ """
367
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
368
+ special tokens using the tokenizer ``prepare_for_model`` method.
369
+ Args:
370
+ token_ids_0 (:obj:`List[int]`):
371
+ List of ids.
372
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
373
+ Optional second list of IDs for sequence pairs.
374
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
375
+ Set to True if the token list is already formatted with special tokens for the model
376
+ Returns:
377
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
378
+ """
379
+
380
+ if already_has_special_tokens:
381
+ if token_ids_1 is not None:
382
+ raise ValueError(
383
+ "You should not supply a second sequence if the provided sequence of "
384
+ "ids is already formated with special tokens for the model."
385
+ )
386
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
387
+
388
+ if token_ids_1 is not None:
389
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
390
+ return [1] + ([0] * len(token_ids_0)) + [1]
391
+
392
+ def create_token_type_ids_from_sequences(
393
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
394
+ ) -> List[int]:
395
+ """
396
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
397
+ A BERT sequence pair mask has the following format:
398
+ ::
399
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
400
+ | first sequence | second sequence |
401
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
402
+ Args:
403
+ token_ids_0 (:obj:`List[int]`):
404
+ List of ids.
405
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
406
+ Optional second list of IDs for sequence pairs.
407
+ Returns:
408
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
409
+ sequence(s).
410
+ """
411
+ sep = [self.sep_token_id]
412
+ cls = [self.cls_token_id]
413
+ if token_ids_1 is None:
414
+ return len(cls + token_ids_0 + sep) * [0]
415
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
416
+
417
+ def save_vocabulary(self, vocab_path):
418
+ """
419
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
420
+ Args:
421
+ vocab_path (:obj:`str`):
422
+ The directory in which to save the vocabulary.
423
+ Returns:
424
+ :obj:`Tuple(str)`: Paths to the files saved.
425
+ """
426
+ index = 0
427
+ if os.path.isdir(vocab_path):
428
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
429
+ else:
430
+ vocab_file = vocab_path
431
+ with open(vocab_file, "w", encoding="utf-8") as writer:
432
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
433
+ if index != token_index:
434
+ logger.warning(
435
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
436
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
437
+ )
438
+ index = token_index
439
+ writer.write(token + "\n")
440
+ index += 1
441
+ return (vocab_file,)
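
A round-trip sketch for the customized `encode`/`decode` pair above, assuming the vocab and SPE split files added in this commit are used; the relative paths are illustrative.

```python
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

# paths are placeholders for the files added in this commit
tokenizer = SMILES_SPE_Tokenizer('tokenizer/new_vocab.txt', 'tokenizer/new_splits.txt')

smiles = "NCC(=O)N[C@H](CS)C(=O)O"

# standard PreTrainedTokenizer __call__ still works through _tokenize/_convert_token_to_id
encoded = tokenizer(smiles, return_tensors='pt')
print(encoded['input_ids'].shape)

# the overridden encode() takes an SPE token array, wraps it with the hard-coded
# [CLS]/[SEP] ids, and decode() stops at the first [SEP] and re-joins the tokens,
# which should reproduce the original SMILES string
custom = tokenizer.encode(tokenizer._tokenize(smiles))
print(tokenizer.decode(custom['input_ids']))
```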
tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
utils/__pycache__/app.cpython-39.pyc ADDED
Binary file (26.3 kB). View file
 
utils/__pycache__/filter.cpython-39.pyc ADDED
Binary file (7.35 kB). View file
 
utils/__pycache__/generate_utils.cpython-39.pyc ADDED
Binary file (2.78 kB). View file
 
utils/__pycache__/helm_utils.cpython-39.pyc ADDED
Binary file (12.4 kB). View file
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.48 kB). View file
 
utils/app.py ADDED
@@ -0,0 +1,1255 @@
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from io import StringIO
5
+ import rdkit
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, Draw
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from io import BytesIO
13
+ import tempfile
14
+ from rdkit import Chem
15
+
16
+ class PeptideAnalyzer:
17
+ def __init__(self):
18
+ self.bond_patterns = [
19
+ (r'OC\(=O\)', 'ester'), # Ester bond
20
+ (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
21
+ (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
22
+ (r'NC\(=O\)', 'peptide'), # Standard peptide bond
23
+ (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
24
+ (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
25
+ ]
26
+ # Three to one letter code mapping
27
+ self.three_to_one = {
28
+ 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
29
+ 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
30
+ 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
31
+ 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
32
+ 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
33
+ }
34
+
35
+ def is_peptide(self, smiles):
36
+ """Check if the SMILES represents a peptide structure"""
37
+ mol = Chem.MolFromSmiles(smiles)
38
+ if mol is None:
39
+ return False
40
+
41
+ # Look for peptide bonds: NC(=O) pattern
42
+ peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
43
+ if mol.HasSubstructMatch(peptide_bond_pattern):
44
+ return True
45
+
46
+ # Look for N-methylated peptide bonds: N(C)C(=O) pattern
47
+ n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
48
+ if mol.HasSubstructMatch(n_methyl_pattern):
49
+ return True
50
+
51
+ return False
52
+
53
+ def is_cyclic(self, smiles):
54
+ """Improved cyclic peptide detection"""
55
+ # Check for C-terminal carboxyl
56
+ if smiles.endswith('C(=O)O'):
57
+ return False, [], []
58
+
59
+ # Find all numbers used in ring closures
60
+ ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
61
+
62
+ # Find aromatic ring numbers
63
+ aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
64
+ aromatic_cycles = []
65
+ for match in aromatic_matches:
66
+ numbers = re.findall(r'[0-9]', match)
67
+ aromatic_cycles.extend(numbers)
68
+
69
+ # Numbers that aren't part of aromatic rings are peptide cycles
70
+ peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
71
+
72
+ is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
73
+ return is_cyclic, peptide_cycles, aromatic_cycles
74
+
75
+ def split_on_bonds(self, smiles):
76
+ """Split SMILES into segments with simplified Pro handling"""
77
+ positions = []
78
+ used = set()
79
+
80
+ # Find Gly pattern first
81
+ gly_pattern = r'NCC\(=O\)'
82
+ for match in re.finditer(gly_pattern, smiles):
83
+ if not any(p in range(match.start(), match.end()) for p in used):
84
+ positions.append({
85
+ 'start': match.start(),
86
+ 'end': match.end(),
87
+ 'type': 'gly',
88
+ 'pattern': match.group()
89
+ })
90
+ used.update(range(match.start(), match.end()))
91
+
92
+ for pattern, bond_type in self.bond_patterns:
93
+ for match in re.finditer(pattern, smiles):
94
+ if not any(p in range(match.start(), match.end()) for p in used):
95
+ positions.append({
96
+ 'start': match.start(),
97
+ 'end': match.end(),
98
+ 'type': bond_type,
99
+ 'pattern': match.group()
100
+ })
101
+ used.update(range(match.start(), match.end()))
102
+
103
+ # Sort by position
104
+ positions.sort(key=lambda x: x['start'])
105
+
106
+ # Create segments
107
+ segments = []
108
+
109
+ if positions:
110
+ # First segment
111
+ if positions[0]['start'] > 0:
112
+ segments.append({
113
+ 'content': smiles[0:positions[0]['start']],
114
+ 'bond_after': positions[0]['pattern']
115
+ })
116
+
117
+ # Process segments
118
+ for i in range(len(positions)-1):
119
+ current = positions[i]
120
+ next_pos = positions[i+1]
121
+
122
+ if current['type'] == 'gly':
123
+ segments.append({
124
+ 'content': 'NCC(=O)',
125
+ 'bond_before': positions[i-1]['pattern'] if i > 0 else None,
126
+ 'bond_after': next_pos['pattern']
127
+ })
128
+ else:
129
+ content = smiles[current['end']:next_pos['start']]
130
+ if content:
131
+ segments.append({
132
+ 'content': content,
133
+ 'bond_before': current['pattern'],
134
+ 'bond_after': next_pos['pattern']
135
+ })
136
+
137
+ # Last segment
138
+ if positions[-1]['end'] < len(smiles):
139
+ segments.append({
140
+ 'content': smiles[positions[-1]['end']:],
141
+ 'bond_before': positions[-1]['pattern']
142
+ })
143
+
144
+ return segments
145
+
146
+ def clean_terminal_carboxyl(self, segment):
147
+ """Remove C-terminal carboxyl only if it's the true terminus"""
148
+ content = segment['content']
149
+
150
+ # Only clean if:
151
+ # 1. Contains C(=O)O
152
+ # 2. No bond_after exists (meaning it's the last segment)
153
+ # 3. C(=O)O is at the end of the content
154
+ if 'C(=O)O' in content and not segment.get('bond_after'):
155
+ print('recognized?')
156
+ # Remove C(=O)O pattern regardless of position
157
+ cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
158
+ # Remove any leftover empty parentheses
159
+ cleaned = re.sub(r'\(\)', '', cleaned)
160
+ print(cleaned)
161
+ return cleaned
162
+ return content
163
+
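Illustrative only, with a hypothetical terminal segment: the parenthesized C(=O)O group is stripped because the segment carries no bond_after entry.

    terminal_segment = {'content': 'C[C@@H](C(=O)O)', 'bond_before': 'NC(=O)'}
    print(analyzer.clean_terminal_carboxyl(terminal_segment))  # expected: 'C[C@@H]'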
164
+ def identify_residue(self, segment):
165
+ """Identify residue with Pro reconstruction"""
166
+ # Only clean terminal carboxyl if this is the last segment
167
+ content = self.clean_terminal_carboxyl(segment)
168
+ mods = self.get_modifications(segment)
169
+
170
+ # UAA pattern matching section - before regular residues
171
+ # Phenylglycine and derivatives
172
+ if 'c1ccccc1' in content:
173
+ if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
174
+ return '4', mods # Base phenylglycine
175
+
176
+ # 4-substituted phenylalanines
177
+ if 'Cc1ccc' in content:
178
+ if 'OMe' in content or 'OCc1ccc' in content:
179
+ return '0A1', mods # 4-methoxy-Phenylalanine
180
+ elif 'Clc1ccc' in content:
181
+ return '200', mods # 4-chloro-Phenylalanine
182
+ elif 'Brc1ccc' in content:
183
+ return '4BF', mods # 4-Bromo-phenylalanine
184
+ elif 'C#Nc1ccc' in content:
185
+ return '4CF', mods # 4-cyano-phenylalanine
186
+ elif 'Ic1ccc' in content:
187
+ return 'PHI', mods # 4-Iodo-phenylalanine
188
+ elif 'Fc1ccc' in content:
189
+ return 'PFF', mods # 4-Fluoro-phenylalanine
190
+
191
+ # Modified tryptophans
192
+ if 'c[nH]c2' in content:
193
+ if 'Oc2cccc2' in content:
194
+ return '0AF', mods # 7-hydroxy-tryptophan
195
+ elif 'Fc2cccc2' in content:
196
+ return '4FW', mods # 4-fluoro-tryptophan
197
+ elif 'Clc2cccc2' in content:
198
+ return '6CW', mods # 6-chloro-tryptophan
199
+ elif 'Brc2cccc2' in content:
200
+ return 'BTR', mods # 6-bromo-tryptophan
201
+ elif 'COc2cccc2' in content:
202
+ return 'MOT5', mods # 5-Methoxy-tryptophan
203
+ elif 'Cc2cccc2' in content:
204
+ return 'MTR5', mods # 5-Methyl-tryptophan
205
+
206
+ # Special amino acids
207
+ if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
208
+ return 'BUG', mods # tert-Leucine
209
+
210
+ if 'CCCNC(=N)N' in content:
211
+ return 'CIR', mods # Citrulline
212
+
213
+ if '[SeH]' in content:
214
+ return 'CSE', mods # Selenocysteine
215
+
216
+ if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
217
+ return 'DAB', mods # Diaminobutyric acid
218
+
219
+ if 'C1CCCCC1' in content:
220
+ if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
221
+ return 'CHG', mods # Cyclohexylglycine
222
+ elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
223
+ return 'ALC', mods # 3-cyclohexyl-alanine
224
+
225
+ # Naphthalene derivatives
226
+ if 'c1cccc2c1cccc2' in content:
227
+ if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
228
+ return 'NAL', mods # 2-Naphthyl-alanine
229
+
230
+ # Heteroaromatic derivatives
231
+ if 'c1cncc' in content:
232
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
233
+ if 'c1cscc' in content:
234
+ return 'THA3', mods # 3-(3-thienyl)-alanine
235
+ if 'c1nnc' in content:
236
+ return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
237
+
238
+ # Modified serines and threonines
239
+ if 'OP(O)(O)O' in content:
240
+ if '[C@@H](COP' in content or '[C@H](COP' in content:
241
+ return 'SEP', mods # phosphoserine
242
+ elif '[C@@H](OP' in content or '[C@H](OP' in content:
243
+ return 'TPO', mods # phosphothreonine
244
+
245
+ # Specialized ring systems
246
+ if 'c1c2ccccc2cc2c1cccc2' in content:
247
+ return 'ANTH', mods # 3-(9-anthryl)-alanine
248
+ if 'c1csc2c1cccc2' in content:
249
+ return 'BTH3', mods # 3-(3-benzothienyl)-alanine
250
+ if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
251
+ return 'ADAM', mods # Adamantane
252
+
253
+ # Fluorinated derivatives
254
+ if 'FC(F)(F)' in content:
255
+ if 'CC(F)(F)F' in content:
256
+ return 'FLA', mods # Trifluoro-alanine
257
+ if 'C(F)(F)F)c1' in content:
258
+ if 'c1ccccc1C(F)(F)F' in content:
259
+ return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
260
+ if 'c1cccc(c1)C(F)(F)F' in content:
261
+ return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
262
+ if 'c1ccc(cc1)C(F)(F)F' in content:
263
+ return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
264
+
265
+ # Multiple halogen patterns
266
+ if 'F' in content and 'c1' in content:
267
+ if 'c1ccc(c(c1)F)F' in content:
268
+ return 'F2F', mods # 3,4-Difluoro-phenylalanine
269
+ if 'cc(F)cc(c1)F' in content:
270
+ return 'WFP', mods # 3,5-Difluoro-phenylalanine
271
+ if 'Cl' in content and 'c1' in content:
272
+ if 'c1ccc(cc1Cl)Cl' in content:
273
+ return 'CP24', mods # 2,4-dichloro-phenylalanine
274
+ if 'c1ccc(c(c1)Cl)Cl' in content:
275
+ return 'CP34', mods # 3,4-dichloro-phenylalanine
276
+
277
+ # Hydroxy and amino derivatives
278
+ if 'O' in content and 'c1' in content:
279
+ if 'c1cc(O)cc(c1)O' in content:
280
+ return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
281
+ if 'c1ccc(c(c1)O)O' in content:
282
+ return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
283
+
284
+ # Cyclic amino acids
285
+ if 'C1CCCC1' in content:
286
+ return 'CPA3', mods # 3-Cyclopentyl-alanine
287
+ if 'C1CCCCC1' in content:
288
+ if 'CC1CCCCC1' in content:
289
+ return 'ALC', mods # 3-cyclohexyl-alanine
290
+ else:
291
+ return 'CHG', mods # Cyclohexylglycine
292
+
293
+ # Chain-length variants
294
+ if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
295
+ return 'NLE', mods # Norleucine
296
+ if 'CC[C@@H]' in content or 'CC[C@H]' in content:
297
+ if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
298
+ return 'ABA', mods # 2-Aminobutyric acid
299
+
300
+ # Modified histidines
301
+ if 'c1cnc' in content:
302
+ if '[C@@H]1CN[C@@H](N1)F' in content:
303
+ return '2HF', mods # 2-fluoro-l-histidine
304
+ if 'c1cnc([nH]1)F' in content:
305
+ return '2HF1', mods # 2-fluoro-l-histidine variant
306
+ if 'c1c[nH]c(n1)F' in content:
307
+ return '2HF2', mods # 2-fluoro-l-histidine variant
308
+
309
+ # Sulfur and selenium containing
310
+ if '[SeH]' in content:
311
+ return 'CSE', mods # Selenocysteine
312
+ if 'S' in content:
313
+ if 'CSCc1ccccc1' in content:
314
+ return 'BCS', mods # benzylcysteine
315
+ if 'CCSC' in content:
316
+ return 'ESC', mods # Ethionine
317
+ if 'CCS' in content:
318
+ return 'HCS', mods # homocysteine
319
+
320
+ # Additional modifications
321
+ if 'CN=[N]=N' in content:
322
+ return 'AZDA', mods # azido-alanine
323
+ if '[NH]=[C](=[NH2])=[NH2]' in content:
324
+ if 'CCC[NH]=' in content:
325
+ return 'AGM', mods # 5-methyl-arginine
326
+ if 'CC[NH]=' in content:
327
+ return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
328
+
329
+ if 'CCON' in content:
330
+ return 'CAN', mods # canaline
331
+ if '[C@@H]1C=C[C@@H](C=C1)' in content:
332
+ return 'ACZ', mods # cis-amiclenomycin
333
+ if 'CCC(=O)[NH3]' in content:
334
+ return 'ONL', mods # 5-oxo-l-norleucine
335
+ if 'c1ccncc1' in content:
336
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
337
+ if 'c1ccco1' in content:
338
+ return 'FUA2', mods # (2-furyl)-alanine
339
+
340
+ if 'c1ccc' in content:
341
+ if 'c1ccc(cc1)c1ccccc1' in content:
342
+ return 'BIF', mods # 4,4-biphenylalanine
343
+ if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
344
+ return 'PBF', mods # 4-benzoyl-phenylalanine
345
+ if 'c1ccc(cc1)C(C)(C)C' in content:
346
+ return 'TBP4', mods # 4-tert-butyl-phenylalanine
347
+ if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
348
+ return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
349
+ if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
350
+ return 'APM', mods # m-amidinophenyl-3-alanine
351
+
352
+ # Multiple hydroxy patterns
353
+ if 'O' in content:
354
+ if '[C@H]([C@H](C)O)O' in content:
355
+ return 'ILX', mods # 4,5-dihydroxy-isoleucine
356
+ if '[C@H]([C@@H](C)O)O' in content:
357
+ return 'ALO', mods # Allo-threonine
358
+ if '[C@H](COP(O)(O)O)' in content:
359
+ return 'SEP', mods # phosphoserine
360
+ if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
361
+ return 'TPO', mods # phosphothreonine
362
+ if '[C@H](c1ccc(O)cc1)O' in content:
363
+ return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
364
+ if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
365
+ return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
366
+
367
+ # Heterocyclic patterns
368
+ if 'n1' in content:
369
+ if 'n1cccn1' in content:
370
+ return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
371
+ if 'n1nncn1' in content:
372
+ return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
373
+ if 'c2c(n1)cccc2' in content:
374
+ return 'QU32', mods # 3-(2-Quinolyl)-alanine
375
+ if 'c1cnc2c(c1)cccc2' in content:
376
+ return 'QU33', mods # 3-(3-quinolyl)-alanine
377
+ if 'c1ccnc2c1cccc2' in content:
378
+ return 'QU34', mods # 3-(4-quinolyl)-alanine
379
+ if 'c1ccc2c(c1)nccc2' in content:
380
+ return 'QU35', mods # 3-(5-Quinolyl)-alanine
381
+ if 'c1ccc2c(c1)cncc2' in content:
382
+ return 'QU36', mods # 3-(6-Quinolyl)-alanine
383
+ if 'c1cnc2c(n1)cccc2' in content:
384
+ return 'QX32', mods # 3-(2-quinoxalyl)-alanine
385
+
386
+ # Multiple nitrogen patterns
387
+ if 'N' in content:
388
+ if '[NH3]CC[C@@H]' in content:
389
+ return 'DAB', mods # Diaminobutyric acid
390
+ if '[NH3]C[C@@H]' in content:
391
+ return 'DPP', mods # 2,3-Diaminopropanoic acid
392
+ if '[NH3]CCCCCC[C@@H]' in content:
393
+ return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
394
+ if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
395
+ return 'GBUT', mods # 2-Amino-4-guanidinobutyric acid
396
+ if '[NH]=[C](=S)=[NH2]' in content:
397
+ return 'THIC', mods # Thio-citrulline
398
+
399
+ # Chain modified amino acids
400
+ if 'CC' in content:
401
+ if 'CCCC[C@@H]' in content:
402
+ return 'AHP', mods # 2-Aminoheptanoic acid
403
+ if 'CCC([C@@H])(C)C' in content:
404
+ return 'I2M', mods # 3-methyl-l-alloisoleucine
405
+ if 'CC[C@H]([C@@H])C' in content:
406
+ return 'IIL', mods # Allo-Isoleucine
407
+ if '[C@H](CCC(C)C)' in content:
408
+ return 'HLEU', mods # Homoleucine
409
+ if '[C@@H]([C@@H](C)O)C' in content:
410
+ return 'HLU', mods # beta-hydroxyleucine
411
+
412
+ # Modified glutamate/aspartate patterns
413
+ if '[C@@H]' in content:
414
+ if '[C@@H](C[C@@H](F))' in content:
415
+ return 'FGA4', mods # 4-Fluoro-glutamic acid
416
+ if '[C@@H](C[C@@H](O))' in content:
417
+ return '3GL', mods # 4-hydroxy-glutamic-acid
418
+ if '[C@@H](C[C@H](C))' in content:
419
+ return 'LME', mods # (3r)-3-methyl-l-glutamic acid
420
+ if '[C@@H](CC[C@H](C))' in content:
421
+ return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
422
+
423
+ # Sulfur and selenium modifications
424
+ if 'S' in content:
425
+ if 'SCC[C@@H]' in content:
426
+ return 'HSER', mods # homoserine
427
+ if 'SCCN' in content:
428
+ return 'SLZ', mods # thialysine
429
+ if 'SC(=O)' in content:
430
+ return 'CSA', mods # s-acetonylcysteine
431
+ if '[S@@](=O)' in content:
432
+ return 'SME', mods # Methionine sulfoxide
433
+ if 'S(=O)(=O)' in content:
434
+ return 'OMT', mods # Methionine sulfone
435
+
436
+ # Double bond containing
437
+ if 'C=' in content:
438
+ if 'C=CC[C@@H]' in content or 'C=CC[C@H]' in content:
439
+ return '2AG', mods # 2-Allyl-glycine (allyl side chain)
440
+ if 'C=C[C@@H]' in content or 'C=C[C@H]' in content:
441
+ return 'LVG', mods # vinylglycine
442
+ if 'C=Cc1ccccc1' in content:
443
+ return 'STYA', mods # Styrylalanine
444
+
445
+ # Special cases
446
+ if '[C@@H]1Cc2c(C1)cccc2' in content:
447
+ return 'IGL', mods # alpha-amino-2-indanacetic acid
448
+ if '[C](=[C](=O)=O)=O' in content:
449
+ return '26P', mods # 2-amino-6-oxopimelic acid
450
+ if '[C](=[C](=O)=O)=C' in content:
451
+ return '2NP', mods # l-2-amino-6-methylene-pimelic acid
452
+ if 'c2cnc[nH]2' in content:
453
+ return 'HIS', mods # histidine core
454
+ if 'c1cccc2c1cc(O)cc2' in content:
455
+ return 'NAO1', mods # 5-hydroxy-1-naphthalene
456
+ if 'c1ccc2c(c1)cc(O)cc2' in content:
457
+ return 'NAO2', mods # 6-hydroxy-2-naphthalene
458
+
459
+ # Proline (P) - flexible ring numbers
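+ # Three heuristics are tried (any one suffices): (1) the following bond starts with N{n}C(=O)
+ # and the content has a CCC fragment plus a ring-numbered stereocenter, (2) the content holds
+ # CCCN{n} and ends in =O, or (3) the content is exactly CCC[C@H]{n} / CCC[C@@H]{n} (or an
+ # N-terminal N{n}CCC[C@H]{n}) with a preceding C(=O)N{n} bond.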
460
+ if any([
461
+ # Check for any ring number in bond patterns
462
+ (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
463
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
464
+ for n in '123456789'
465
+ ]) or any([
466
+ # Check ending patterns with any ring number
467
+ (f'CCCN{n}' in content and content.endswith('=O') and
468
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
469
+ for n in '123456789'
470
+ ]) or any([
471
+ # Handle CCC[C@H]n patterns
472
+ (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
473
+ (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
474
+ # N-terminal Pro with any ring number
475
+ (f'N{n}CCC[C@H]{n}' in content) or
476
+ (f'N{n}CCC[C@@H]{n}' in content)
477
+ for n in '123456789'
478
+ ]):
479
+ return 'Pro', mods
480
+
481
+ # Tryptophan (W) - more specific indole pattern
482
+ if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
483
+ 'c[nH]c' in content.replace(' ', ''):
484
+ return 'Trp', mods
485
+
486
+ # Lysine (K) - both patterns
487
+ if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
488
+ return 'Lys', mods
489
+
490
+ # Arginine (R) - both patterns
491
+ if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
492
+ return 'Arg', mods
493
+
494
+ if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
495
+ return 'Nle', mods
496
+
497
+ # Ornithine (Orn) - 3-carbon chain with NH2
498
+ if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
499
+ return 'Orn', mods
500
+
501
+ # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
502
+ if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
503
+ return '2Nal', mods
504
+
505
+ # Cyclohexylalanine (Cha)
506
+ if 'N2CCCCC2' in content or 'CCCCC2' in content:
507
+ return 'Cha', mods
508
+
509
+ # Aminobutyric acid (Abu) - 2-carbon chain
510
+ if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
511
+ return 'Abu', mods
512
+
513
+ # Pipecolic acid (Pip) - 6-membered ring like Pro
514
+ if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
515
+ return 'Pip', mods
516
+
517
+ # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
518
+ if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
519
+ return 'Chg', mods
520
+
521
+ # 4-Fluorophenylalanine (4F-Phe)
522
+ if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
523
+ return '4F-Phe', mods
524
+
525
+ # Regular residue identification
526
+ if ('NCC(=O)' in content) or (content == 'C'):
527
+ # Middle case - between bonds
528
+ if segment.get('bond_before') and segment.get('bond_after'):
529
+ if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
530
+ return 'Gly', mods
531
+ # Terminal case - at the end
532
+ elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
533
+ return 'Gly', mods
534
+
535
+ if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
536
+ return 'Leu', mods
537
+ if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
538
+ return 'Leu', mods
539
+
540
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
541
+ return 'Thr', mods
542
+
543
+ if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
544
+ return 'Phe', mods
545
+
546
+ if ('[C@H](C(C)C)' in content or # With outer parentheses
547
+ '[C@@H](C(C)C)' in content or # With outer parentheses
548
+ '[C@H]C(C)C' in content or # Without outer parentheses
549
+ '[C@@H]C(C)C' in content): # Without outer parentheses
550
+ if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
551
+ return 'Val', mods
552
+
553
+ if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
554
+ return 'O-tBu', mods
555
+
556
+ if any([
557
+ 'CC[C@H](C)' in content,
558
+ 'CC[C@@H](C)' in content,
559
+ 'C(C)C[C@H]' in content and 'CC(C)C' not in content,
560
+ 'C(C)C[C@@H]' in content and 'CC(C)C' not in content
561
+ ]):
562
+ return 'Ile', mods
563
+
564
+ if ('[C@H](C)' in content or '[C@@H](C)' in content):
565
+ if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
566
+ return 'Ala', mods
567
+
568
+ # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
569
+ if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
570
+ return 'Tyr', mods
571
+
572
+
573
+ # Serine (Ser) - Hydroxymethyl side chain
574
+ if '[C@H](CO)' in content or '[C@@H](CO)' in content:
575
+ if not ('C(C)O' in content or 'COC' in content):
576
+ return 'Ser', mods
577
+
578
+ # Threonine (Thr) - 1-hydroxyethyl side chain
579
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
580
+ return 'Thr', mods
581
+
582
+ # Cysteine (Cys) - Thiol side chain
583
+ if '[C@H](CS)' in content or '[C@@H](CS)' in content:
584
+ return 'Cys', mods
585
+
586
+ # Methionine (Met) - Methylthioethyl side chain
587
+ if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
588
+ return 'Met', mods
589
+
590
+ # Asparagine (Asn) - Carbamoylmethyl side chain
591
+ if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
592
+ return 'Asn', mods
593
+
594
+ # Glutamine (Gln) - Carbamoylethyl side chain
595
+ if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
596
+ return 'Gln', mods
597
+
598
+ # Aspartic acid (Asp) - Carboxymethyl side chain
599
+ if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
600
+ return 'Asp', mods
601
+
602
+ # Glutamic acid (Glu) - Carboxyethyl side chain
603
+ if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
604
+ return 'Glu', mods
605
+
606
+ # Arginine (Arg) - 3-guanidinopropyl side chain
607
+ if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
608
+ return 'Arg', mods
609
+
610
+ # Histidine (His) - Imidazole side chain
611
+ if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
612
+ return 'His', mods
613
+
614
+ return None, mods
615
+
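A hedged sketch of calling the classifier on a single segment (segment dictionaries shaped as produced by split_on_bonds above); the expected output is illustrative:

    residue, mods = analyzer.identify_residue(segments[0])
    print(residue or 'unidentified', mods)  # e.g. 'Ala' with an empty modification list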
616
+ def get_modifications(self, segment):
617
+ """Get modifications based on bond types"""
618
+ mods = []
619
+ if segment.get('bond_after'):
620
+ if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
621
+ mods.append('N-Me')
622
+ if 'OC(=O)' in segment['bond_after']:
623
+ mods.append('O-linked')
624
+ return mods
625
+
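Illustrative calls with hypothetical bond dictionaries: an N-methylated amide after the residue yields 'N-Me', an ester linkage yields 'O-linked'.

    print(analyzer.get_modifications({'bond_after': 'C(=O)N(C)'}))  # ['N-Me']
    print(analyzer.get_modifications({'bond_after': 'OC(=O)'}))     # ['O-linked']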
626
+ def analyze_structure(self, smiles):
627
+ """Main analysis function with debug output"""
628
+ print("\nAnalyzing structure:", smiles)
629
+
630
+ # Split into segments
631
+ segments = self.split_on_bonds(smiles)
632
+
633
+ print("\nSegment Analysis:")
634
+ sequence = []
635
+ for i, segment in enumerate(segments):
636
+ print(f"\nSegment {i}:")
637
+ print(f"Content: {segment['content']}")
638
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
639
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
640
+
641
+ residue, mods = self.identify_residue(segment)
642
+ if residue:
643
+ if mods:
644
+ sequence.append(f"{residue}({','.join(mods)})")
645
+ else:
646
+ sequence.append(residue)
647
+ print(f"Identified as: {residue}")
648
+ print(f"Modifications: {mods}")
649
+ else:
650
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
651
+
652
+ # Check if cyclic
653
+ is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
654
+ three_letter = '-'.join(sequence)
655
+ one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)
656
+
657
+ if is_cyclic:
658
+ three_letter = f"cyclo({three_letter})"
659
+ one_letter = f"cyclo({one_letter})"
660
+
661
+ print(f"\nFinal sequence: {three_letter}")
662
+ print(f"One-letter code: {one_letter}")
663
+ print(f"Is cyclic: {is_cyclic}")
664
+ #print(f"Peptide cycles: {peptide_cycles}")
665
+ #print(f"Aromatic cycles: {aromatic_cycles}")
666
+
667
+ return three_letter, len(segments)
668
+ """return {
669
+ 'three_letter': three_letter,
670
+ #'one_letter': one_letter,
671
+ 'is_cyclic': is_cyclic
672
+ }"""
673
+
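End-to-end sketch (illustrative): analyze_structure prints its own debug trace and returns the hyphen-joined three-letter sequence together with the segment count.

    three_letter, n_segments = analyzer.analyze_structure(smiles)
    print(three_letter, n_segments)  # e.g. 'cyclo(Ala-Leu-Pro)' and 3 for a hypothetical cyclic tripeptide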
674
+ def return_sequence(self, smiles):
675
+ """Main analysis function with debug output"""
676
+ print("\nAnalyzing structure:", smiles)
677
+
678
+ # Split into segments
679
+ segments = self.split_on_bonds(smiles)
680
+
681
+ print("\nSegment Analysis:")
682
+ sequence = []
683
+ for i, segment in enumerate(segments):
684
+ print(f"\nSegment {i}:")
685
+ print(f"Content: {segment['content']}")
686
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
687
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
688
+
689
+ residue, mods = self.identify_residue(segment)
690
+ if residue:
691
+ if mods:
692
+ sequence.append(f"{residue}({','.join(mods)})")
693
+ else:
694
+ sequence.append(residue)
695
+ print(f"Identified as: {residue}")
696
+ print(f"Modifications: {mods}")
697
+ else:
698
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
699
+
700
+ return sequence
701
+
702
+ """
703
+ def annotate_cyclic_structure(mol, sequence):
704
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
705
+ # Generate 2D coordinates
706
+ AllChem.Compute2DCoords(mol)
708
+
709
+ # Create drawer with larger size for annotations
710
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
711
+
712
+ # Get residue list and reverse it to match structural representation
713
+ if sequence.startswith('cyclo('):
714
+ residues = sequence[6:-1].split('-')
715
+ else:
716
+ residues = sequence.split('-')
717
+ residues = list(reversed(residues)) # Reverse the sequence
718
+
719
+ # Draw molecule first to get its bounds
720
+ drawer.drawOptions().addAtomIndices = False
721
+ drawer.DrawMolecule(mol)
722
+ drawer.FinishDrawing()
723
+
724
+ # Convert to PIL Image
725
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
726
+ draw = ImageDraw.Draw(img)
727
+
728
+ try:
729
+ # Try to use DejaVuSans as it's commonly available on Linux systems
730
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
731
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
732
+ except OSError:
733
+ try:
734
+ # Fallback to Arial if available (common on Windows)
735
+ font = ImageFont.truetype("arial.ttf", 60)
736
+ small_font = ImageFont.truetype("arial.ttf", 60)
737
+ except OSError:
738
+ # If no TrueType fonts are available, fall back to default
739
+ print("Warning: TrueType fonts not available, using default font")
740
+ font = ImageFont.load_default()
741
+ small_font = ImageFont.load_default()
742
+ # Get molecule bounds
743
+ conf = mol.GetConformer()
744
+ positions = []
745
+ for i in range(mol.GetNumAtoms()):
746
+ pos = conf.GetAtomPosition(i)
747
+ positions.append((pos.x, pos.y))
748
+
749
+ x_coords = [p[0] for p in positions]
750
+ y_coords = [p[1] for p in positions]
751
+ min_x, max_x = min(x_coords), max(x_coords)
752
+ min_y, max_y = min(y_coords), max(y_coords)
753
+
754
+ # Calculate scaling factors
755
+ scale = 150 # Increased scale factor
756
+ center_x = 1000 # Image center
757
+ center_y = 1000
758
+
759
+ # Add residue labels in a circular arrangement around the structure
760
+ n_residues = len(residues)
761
+ radius = 700 # Distance of labels from center
762
+
763
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
764
+ # Offset by -3 positions to align with structure
765
+ offset = 0 # Adjust this value to match the structure alignment
766
+ for i, residue in enumerate(residues):
767
+ # Calculate position in a circle around the structure
768
+ # Start from 0 (3 o'clock) and go counterclockwise
769
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
770
+
771
+ # Calculate label position
772
+ label_x = center_x + radius * np.cos(angle)
773
+ label_y = center_y + radius * np.sin(angle)
774
+
775
+ # Draw residue label
776
+ text = f"{i+1}. {residue}"
777
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
778
+ padding = 10
779
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
780
+ bbox[2]+padding, bbox[3]+padding],
781
+ fill='white', outline='white')
782
+ draw.text((label_x, label_y), text,
783
+ font=font, fill='black', anchor="mm")
784
+
785
+ # Add sequence at the top with white background
786
+ seq_text = f"Sequence: {sequence}"
787
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
788
+ padding = 10
789
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
790
+ bbox[2]+padding, bbox[3]+padding],
791
+ fill='white', outline='white')
792
+ draw.text((center_x, 100), seq_text,
793
+ font=small_font, fill='black', anchor="mm")
794
+
795
+ return img
796
+
797
+ """
798
+ def annotate_cyclic_structure(mol, sequence):
799
+ """Create structure visualization with just the sequence header"""
800
+ # Generate 2D coordinates
801
+ AllChem.Compute2DCoords(mol)
802
+
803
+ # Create drawer with larger size for annotations
804
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
805
+
806
+ # Draw molecule first
807
+ drawer.drawOptions().addAtomIndices = False
808
+ drawer.DrawMolecule(mol)
809
+ drawer.FinishDrawing()
810
+
811
+ # Convert to PIL Image
812
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
813
+ draw = ImageDraw.Draw(img)
814
+ try:
815
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
816
+ except OSError:
817
+ try:
818
+ small_font = ImageFont.truetype("arial.ttf", 60)
819
+ except OSError:
820
+ print("Warning: TrueType fonts not available, using default font")
821
+ small_font = ImageFont.load_default()
822
+
823
+ # Add just the sequence header at the top
824
+ seq_text = f"Sequence: {sequence}"
825
+ bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
826
+ padding = 10
827
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
828
+ bbox[2]+padding, bbox[3]+padding],
829
+ fill='white', outline='white')
830
+ draw.text((1000, 100), seq_text,
831
+ font=small_font, fill='black', anchor="mm")
832
+
833
+ return img
834
+
835
+ def create_enhanced_linear_viz(sequence, smiles):
836
+ """Create an enhanced linear representation using PeptideAnalyzer"""
837
+ analyzer = PeptideAnalyzer() # Create analyzer instance
838
+
839
+ # Create figure with two subplots
840
+ fig = plt.figure(figsize=(15, 10))
841
+ gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
842
+ ax_struct = fig.add_subplot(gs[0])
843
+ ax_detail = fig.add_subplot(gs[1])
844
+
845
+ # Parse sequence and get residues
846
+ if sequence.startswith('cyclo('):
847
+ residues = sequence[6:-1].split('-')
848
+ else:
849
+ residues = sequence.split('-')
850
+
851
+ # Get segments using analyzer
852
+ segments = analyzer.split_on_bonds(smiles)
853
+
854
+ # Debug print
855
+ print(f"Number of residues: {len(residues)}")
856
+ print(f"Number of segments: {len(segments)}")
857
+
858
+ # Top subplot - Basic structure
859
+ ax_struct.set_xlim(0, 10)
860
+ ax_struct.set_ylim(0, 2)
861
+
862
+ num_residues = len(residues)
863
+ spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0
864
+
865
+ # Draw basic structure
866
+ y_pos = 1.5
867
+ for i in range(num_residues):
868
+ x_pos = 0.5 + i * spacing
869
+
870
+ # Draw amino acid box
871
+ rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
872
+ facecolor='lightblue', edgecolor='black')
873
+ ax_struct.add_patch(rect)
874
+
875
+ # Draw connecting bonds if not the last residue
876
+ if i < num_residues - 1:
877
+ segment = segments[i] if i < len(segments) else None
878
+ if segment:
879
+ # Determine bond type from segment info
880
+ bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
881
+ is_n_methylated = 'N-Me' in segment.get('bond_after', '')
882
+
883
+ bond_color = 'red' if bond_type == 'ester' else 'black'
884
+ linestyle = '--' if bond_type == 'ester' else '-'
885
+
886
+ # Draw bond line
887
+ ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
888
+ color=bond_color, linestyle=linestyle, linewidth=2)
889
+
890
+ # Add bond type label
891
+ mid_x = x_pos + spacing/2
892
+ bond_label = f"{bond_type}"
893
+ if is_n_methylated:
894
+ bond_label += "\n(N-Me)"
895
+ ax_struct.text(mid_x, y_pos+0.1, bond_label,
896
+ ha='center', va='bottom', fontsize=10,
897
+ color=bond_color)
898
+
899
+ # Add residue label
900
+ ax_struct.text(x_pos, y_pos-0.5, residues[i],
901
+ ha='center', va='top', fontsize=14)
902
+
903
+ # Bottom subplot - Detailed breakdown
904
+ ax_detail.set_ylim(0, len(segments)+1)
905
+ ax_detail.set_xlim(0, 1)
906
+
907
+ # Create detailed breakdown
908
+ segment_y = len(segments) # Start from top
909
+ for i, segment in enumerate(segments):
910
+ y = segment_y - i
911
+
912
+ # Check if this is a bond or residue
913
+ residue, mods = analyzer.identify_residue(segment)
914
+ if residue:
915
+ text = f"Residue {i+1}: {residue}"
916
+ if mods:
917
+ text += f" ({', '.join(mods)})"
918
+ color = 'blue'
919
+ else:
920
+ # Must be a bond
921
+ text = f"Bond {i}: "
922
+ if 'O-linked' in segment.get('bond_after', ''):
923
+ text += "ester"
924
+ elif 'N-Me' in segment.get('bond_after', ''):
925
+ text += "peptide (N-methylated)"
926
+ else:
927
+ text += "peptide"
928
+ color = 'red'
929
+
930
+ # Add segment analysis
931
+ ax_detail.text(0.05, y, text, fontsize=12, color=color)
932
+ ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')
933
+
934
+ # If cyclic, add connection indicator
935
+ if sequence.startswith('cyclo('):
936
+ ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
937
+ arrowprops=dict(arrowstyle='<->', color='red', lw=2))
938
+ ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
939
+ ha='center', color='red', fontsize=14)
940
+
941
+ # Add titles and adjust layout
942
+ ax_struct.set_title("Peptide Structure Overview", pad=20)
943
+ ax_detail.set_title("Segment Analysis Breakdown", pad=20)
944
+
945
+ # Remove axes
946
+ for ax in [ax_struct, ax_detail]:
947
+ ax.set_xticks([])
948
+ ax.set_yticks([])
949
+ ax.axis('off')
950
+
951
+ plt.tight_layout()
952
+ return fig
953
+
954
+ class PeptideStructureGenerator:
955
+ """A class to generate 3D structures of peptides using different embedding methods"""
956
+
957
+ @staticmethod
958
+ def prepare_molecule(smiles):
959
+ """Prepare molecule with proper hydrogen handling"""
960
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
961
+ if mol is None:
962
+ raise ValueError("Failed to create molecule from SMILES")
963
+
964
+ # Calculate valence for each atom
965
+ for atom in mol.GetAtoms():
966
+ atom.UpdatePropertyCache(strict=False)
967
+
968
+ # Sanitize with reduced requirements
969
+ Chem.SanitizeMol(mol,
970
+ sanitizeOps=Chem.SANITIZE_FINDRADICALS|
971
+ Chem.SANITIZE_KEKULIZE|
972
+ Chem.SANITIZE_SETAROMATICITY|
973
+ Chem.SANITIZE_SETCONJUGATION|
974
+ Chem.SANITIZE_SETHYBRIDIZATION|
975
+ Chem.SANITIZE_CLEANUPCHIRALITY)
976
+
977
+ mol = Chem.AddHs(mol)
978
+ return mol
979
+
980
+ @staticmethod
981
+ def get_etkdg_params(attempt=0):
982
+ """Get ETKDG parameters with optional modifications based on attempt number"""
983
+ params = AllChem.ETKDGv3()
984
+ params.randomSeed = -1
985
+ params.maxIterations = 200
986
+ params.numThreads = 4 # Reduced for web interface
987
+ params.useBasicKnowledge = True
988
+ params.enforceChirality = True
989
+ params.useExpTorsionAnglePrefs = True
990
+ params.useSmallRingTorsions = True
991
+ params.useMacrocycleTorsions = True
992
+ params.ETversion = 2
993
+ params.pruneRmsThresh = -1
994
+ params.embedRmsThresh = 0.5
995
+
996
+ if attempt > 10:
997
+ params.bondLength = 1.5 + (attempt - 10) * 0.02
998
+ params.useExpTorsionAnglePrefs = False
999
+
1000
+ return params
1001
+
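A short sketch of how these attempt-dependent parameters are meant to be consumed; after attempt 10 the bond length is stretched and experimental torsion preferences are dropped:

    for attempt in range(20):
        params = PeptideStructureGenerator.get_etkdg_params(attempt)
        # pass params to AllChem.EmbedMolecule(mol, params) and stop at the first attempt that returns 0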
1002
+ def generate_structure_etkdg(self, smiles, max_attempts=20):
1003
+ """Generate 3D structure using ETKDG without UFF optimization"""
1004
+ success = False
1005
+ mol = None
1006
+
1007
+ for attempt in range(max_attempts):
1008
+ try:
1009
+ mol = self.prepare_molecule(smiles)
1010
+ params = self.get_etkdg_params(attempt)
1011
+
1012
+ if AllChem.EmbedMolecule(mol, params) == 0:
1013
+ success = True
1014
+ break
1015
+ except Exception as e:
1016
+ continue
1017
+
1018
+ if not success:
1019
+ raise ValueError("Failed to generate structure with ETKDG")
1020
+
1021
+ return mol
1022
+
1023
+ def generate_structure_uff(self, smiles, max_attempts=20):
1024
+ """Generate 3D structure using ETKDG followed by UFF optimization"""
1025
+ best_mol = None
1026
+ lowest_energy = float('inf')
1027
+
1028
+ for attempt in range(max_attempts):
1029
+ try:
1030
+ test_mol = self.prepare_molecule(smiles)
1031
+ params = self.get_etkdg_params(attempt)
1032
+
1033
+ if AllChem.EmbedMolecule(test_mol, params) == 0:
1034
+ res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
1035
+ vdwThresh=10.0, confId=0,
1036
+ ignoreInterfragInteractions=True)
1037
+
1038
+ if res == 0:
1039
+ ff = AllChem.UFFGetMoleculeForceField(test_mol)
1040
+ if ff:
1041
+ current_energy = ff.CalcEnergy()
1042
+ if current_energy < lowest_energy:
1043
+ lowest_energy = current_energy
1044
+ best_mol = Chem.Mol(test_mol)
1045
+ except Exception:
1046
+ continue
1047
+
1048
+ if best_mol is None:
1049
+ raise ValueError("Failed to generate optimized structure")
1050
+
1051
+ return best_mol
1052
+
1053
+ @staticmethod
1054
+ def mol_to_sdf_bytes(mol):
1055
+ """Convert RDKit molecule to SDF file bytes"""
1056
+ # First write to StringIO in text mode
1057
+ sio = StringIO()
1058
+ writer = Chem.SDWriter(sio)
1059
+ writer.write(mol)
1060
+ writer.close()
1061
+
1062
+ # Convert the string to bytes
1063
+ return sio.getvalue().encode('utf-8')
1064
+
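A hedged end-to-end sketch of the generator class; the SMILES variable is a placeholder and the output filename is arbitrary:

    generator = PeptideStructureGenerator()
    mol3d = generator.generate_structure_etkdg(smiles)            # ETKDG-only embedding
    sdf_bytes = PeptideStructureGenerator.mol_to_sdf_bytes(mol3d)
    with open('structure_etkdg.sdf', 'wb') as fh:
        fh.write(sdf_bytes)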
1065
+ def process_input(smiles_input=None, file_obj=None, show_linear=False,
1066
+ show_segment_details=False, generate_3d=False, use_uff=False):
1067
+ """Process input and create visualizations using PeptideAnalyzer"""
1068
+ analyzer = PeptideAnalyzer()
1069
+ temp_dir = tempfile.mkdtemp() if generate_3d else None
1070
+ structure_files = []
1071
+
1072
+ # Handle direct SMILES input
1073
+ if smiles_input:
1074
+ smiles = smiles_input.strip()
1075
+
1076
+ # First check if it's a peptide using analyzer's method
1077
+ if not analyzer.is_peptide(smiles):
1078
+ return "Error: Input SMILES does not appear to be a peptide structure.", None, None
1079
+
1080
+ try:
1081
+ # Create molecule
1082
+ mol = Chem.MolFromSmiles(smiles)
1083
+ if mol is None:
1084
+ return "Error: Invalid SMILES notation.", None, None
1085
+
1086
+ # Generate 3D structures if requested
1087
+ if generate_3d:
1088
+ generator = PeptideStructureGenerator()
1089
+
1090
+ try:
1091
+ # Generate ETKDG structure
1092
+ mol_etkdg = generator.generate_structure_etkdg(smiles)
1093
+ etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
1094
+ writer = Chem.SDWriter(etkdg_path)
1095
+ writer.write(mol_etkdg)
1096
+ writer.close()
1097
+ structure_files.append(etkdg_path)
1098
+
1099
+ # Generate UFF structure if requested
1100
+ if use_uff:
1101
+ mol_uff = generator.generate_structure_uff(smiles)
1102
+ uff_path = os.path.join(temp_dir, "structure_uff.sdf")
1103
+ writer = Chem.SDWriter(uff_path)
1104
+ writer.write(mol_uff)
1105
+ writer.close()
1106
+ structure_files.append(uff_path)
1107
+
1108
+ except Exception as e:
1109
+ return f"Error generating 3D structures: {str(e)}", None, None, None
1110
+
1111
+ # Use analyzer to get sequence
1112
+ segments = analyzer.split_on_bonds(smiles)
1113
+
1114
+ # Process segments and build sequence
1115
+ sequence_parts = []
1116
+ output_text = ""
1117
+
1118
+ # Only include segment analysis in output if requested
1119
+ if show_segment_details:
1120
+ output_text += "Segment Analysis:\n"
1121
+ for i, segment in enumerate(segments):
1122
+ output_text += f"\nSegment {i}:\n"
1123
+ output_text += f"Content: {segment['content']}\n"
1124
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1125
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1126
+
1127
+ residue, mods = analyzer.identify_residue(segment)
1128
+ if residue:
1129
+ if mods:
1130
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1131
+ else:
1132
+ sequence_parts.append(residue)
1133
+ output_text += f"Identified as: {residue}\n"
1134
+ output_text += f"Modifications: {mods}\n"
1135
+ else:
1136
+ output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
1137
+ output_text += "\n"
1138
+ else:
1139
+ # Just build sequence without detailed analysis in output
1140
+ for segment in segments:
1141
+ residue, mods = analyzer.identify_residue(segment)
1142
+ if residue:
1143
+ if mods:
1144
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1145
+ else:
1146
+ sequence_parts.append(residue)
1147
+
1148
+ # Check if cyclic using analyzer's method
1149
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1150
+ three_letter = '-'.join(sequence_parts)
1151
+ one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)
1152
+
1153
+ if is_cyclic:
1154
+ three_letter = f"cyclo({three_letter})"
1155
+ one_letter = f"cyclo({one_letter})"
1156
+
1157
+ # Create cyclic structure visualization
1158
+ img_cyclic = annotate_cyclic_structure(mol, three_letter)
1159
+
1160
+ # Create linear representation if requested
1161
+ img_linear = None
1162
+ if show_linear:
1163
+ fig_linear = create_enhanced_linear_viz(three_letter, smiles)
1164
+ buf = BytesIO()
1165
+ fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
1166
+ buf.seek(0)
1167
+ img_linear = Image.open(buf)
1168
+ plt.close(fig_linear)
1169
+
1170
+ # Add summary to output
1171
+ summary = "Summary:\n"
1172
+ summary += f"Sequence: {three_letter}\n"
1173
+ summary += f"One-letter code: {one_letter}\n"
1174
+ summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1175
+ #if is_cyclic:
1176
+ #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1177
+ #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1178
+
1179
+ if structure_files:
1180
+ summary += "\n3D Structures Generated:\n"
1181
+ for filepath in structure_files:
1182
+ summary += f"- {os.path.basename(filepath)}\n"
1183
+
1184
+ return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None
1185
+
1186
+ except Exception as e:
1187
+ return f"Error processing SMILES: {str(e)}", None, None, None
1188
+
1189
+ # Handle file input
1190
+ if file_obj is not None:
1191
+ try:
1192
+ # Handle file content
1193
+ if hasattr(file_obj, 'name'):
1194
+ with open(file_obj.name, 'r') as f:
1195
+ content = f.read()
1196
+ else:
1197
+ content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)
1198
+
1199
+ output_text = ""
1200
+ for line in content.splitlines():
1201
+ smiles = line.strip()
1202
+ if smiles:
1203
+ # Check if it's a peptide
1204
+ if not analyzer.is_peptide(smiles):
1205
+ output_text += f"Skipping non-peptide SMILES: {smiles}\n"
1206
+ continue
1207
+
1208
+ # Process this SMILES
1209
+ segments = analyzer.split_on_bonds(smiles)
1210
+ sequence_parts = []
1211
+
1212
+ # Add segment details if requested
1213
+ if show_segment_details:
1214
+ output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
1215
+ for i, segment in enumerate(segments):
1216
+ output_text += f"\nSegment {i}:\n"
1217
+ output_text += f"Content: {segment['content']}\n"
1218
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1219
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1220
+ residue, mods = analyzer.identify_residue(segment)
1221
+ if residue:
1222
+ if mods:
1223
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1224
+ else:
1225
+ sequence_parts.append(residue)
1226
+ output_text += f"Identified as: {residue}\n"
1227
+ output_text += f"Modifications: {mods}\n"
1228
+ else:
1229
+ for segment in segments:
1230
+ residue, mods = analyzer.identify_residue(segment)
1231
+ if residue:
1232
+ if mods:
1233
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1234
+ else:
1235
+ sequence_parts.append(residue)
1236
+
1237
+ # Get cyclicity and create sequence
1238
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1239
+ sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)
1240
+
1241
+ output_text += f"\nSummary for SMILES: {smiles}\n"
1242
+ output_text += f"Sequence: {sequence}\n"
1243
+ output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1244
+ if is_cyclic:
1245
+ output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1246
+ #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1247
+ output_text += "-" * 50 + "\n"
1248
+
1249
+ return output_text, None, None, None
1250
+
1251
+ except Exception as e:
1252
+ return f"Error processing file: {str(e)}", None, None
1253
+
1254
+ return "No input provided.", None, None
1255
+
utils/generate_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import math
3
+ import sys
4
+ import pandas as pd
5
+ from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
6
+
7
+
8
+ def mask_for_de_novo(config, sequence_length):
9
+ if config.vocab == 'helm':
10
+ return "[MASK]" * sequence_length
11
+ elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
12
+ return ["<mask>"] * sequence_length
13
+ else:
14
+ return ["[MASK]"] * sequence_length
15
+
16
+ def generate_de_novo(config, sequence_length, tokenizer, model):
17
+ masked_sequence = mask_for_de_novo(config, sequence_length)
18
+ inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
19
+
20
+ with torch.no_grad():
21
+ logits = model(**inputs).logits
22
+ mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
23
+ logits_at_masks = logits[0, mask_token_indices]
24
+
25
+ pred_tokens = []
26
+ for i in range(len(mask_token_indices)):
27
+ topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
28
+ probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
29
+ predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
30
+ predicted_token_id = topk_indices[predicted_index].item()
31
+ predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
32
+ pred_tokens.append(predicted_token)
33
+
34
+ generated_sequence = ''.join(pred_tokens)
35
+ perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)
36
+
37
+ return (generated_sequence, perplexity)
38
+
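A minimal usage sketch for the masked de-novo generation above; config, tokenizer and model are assumed to be a config object with a .vocab field, a Hugging Face tokenizer with a mask token, and a masked-LM model already placed on the target device:

    sequence, ppl = generate_de_novo(config, sequence_length=20, tokenizer=tokenizer, model=model)
    print(sequence, ppl)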
39
+
40
+ def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
41
+ total_loss = 0.0
42
+ tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
43
+
44
+ for i in mask_token_indices:
45
+ masked_input = tensor_input.clone()
46
+ masked_input[0, i] = tokenizer.mask_token_id
47
+
48
+ labels = torch.full(tensor_input.shape, -100).to(model.device)
49
+ labels[0, i] = tensor_input[0, i]
50
+
51
+ with torch.no_grad():
52
+ outputs = model(masked_input, labels=labels)
53
+ total_loss += outputs.loss.item()
54
+
55
+ num_mask_tokens = len(mask_token_indices)
56
+ if num_mask_tokens == 0:
57
+ perplexity = 10000
58
+ else:
59
+ avg_loss = total_loss / num_mask_tokens
60
+ perplexity = math.exp(avg_loss)
61
+
62
+ return perplexity
63
+
64
+
65
+ def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, pepclm_model, device):
66
+ og_embeddings = pepclm_model.roformer.encoder(original_sequence)
67
+ new_embeddings = pepclm_model.roformer.encoder(generated_sequence)
68
+
69
+ sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
70
+ cosine_similarity = torch.mean(sequence_similarity).item()
71
+ return cosine_similarity
72
+
73
+
74
+ def calculate_hamming_dist(original_sequence, generated_sequence):
75
+ # Count positions at which the two sequences differ
76
+ # (assumes both sequences have the same length)
77
+ return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
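Illustrative only: the helper compares the two strings position by position, so both arguments are expected to have equal length.

    print(calculate_hamming_dist("NCCPTA", "NCCPSA"))  # 1, using hypothetical token strings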