import re
import sys
from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl

from .utils import utils
|
|
class CustomDataset(Dataset):
    """Thin wrapper that exposes a subset of `dataset` selected by `indices`."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Map the subset position back to an index into the underlying dataset.
        actual_idx = int(self.indices[idx])
        return self.dataset[actual_idx]
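
# Illustrative usage (hypothetical names): select a split of an existing dataset
# by index, e.g.
#   train_ds = CustomDataset(full_dataset, train_indices)
# so that len(train_ds) == len(train_indices) and
# train_ds[i] == full_dataset[int(train_indices[i])].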
|
|
def peptide_bond_mask(smiles_list):
    """
    Returns a mask of shape (batch_size, seq_length) that has 1 at the character
    positions of recognized peptide/ester/N-methyl bond motifs and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: A mask of shape (batch_size, max_seq_length) with 1s at bond positions.
    """
    batch_size = len(smiles_list)
    max_seq_length = max(len(smiles) for smiles in smiles_list)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # Regex patterns for bond motifs. Earlier patterns take precedence: characters
    # claimed by a match are not re-matched by later patterns.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),
        (r'NC\(=O\)', 'peptide'),
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        positions = []
        used = set()

        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                # Skip matches that overlap characters already claimed by an earlier match.
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        for pos in positions:
            mask[batch_idx, pos['start']:pos['end']] = 1

    return mask
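
# Illustrative sketch: for N-acetylglycine, SMILES "CC(=O)NCC(=O)O", the amide
# motif "C(=O)N" spans characters 1-6, so the expected character-level mask is
#   peptide_bond_mask(["CC(=O)NCC(=O)O"])
#   -> tensor([[0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)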
|
|
def peptide_token_mask(smiles_list, token_lists):
    """
    Returns a mask of shape (batch_size, num_tokens) that has 1 for tokens
    where any part of the token overlaps with a peptide bond, and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
    """
    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0

        for token_idx, token in enumerate(token_seq):
            # Skip the first and last tokens, which are assumed to be special
            # tokens (e.g. BOS/EOS) that do not consume SMILES characters.
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                # Flag the token if any character it covers lies on a bond motif.
                if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks
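
# Illustrative sketch (hypothetical token split; the real split comes from the
# tokenizer's get_token_split): with
#   smiles = ["CC(=O)NCC(=O)O"]
#   tokens = [["<bos>", "C", "C(=O)", "N", "C", "C(=O)", "O", "<eos>"]]
# the first "C(=O)" token and the "N" token overlap the amide motif, so
#   peptide_token_mask(smiles, tokens)
#   -> tensor([[0, 0, 1, 1, 0, 0, 0, 0]], dtype=torch.int32)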
|
|
def extract_amino_acid_sequence(helm_string):
    """
    Extracts the amino acid sequence from a HELM peptide notation and returns it as a list,
    removing any brackets around each amino acid.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: A list containing each amino acid in sequence without brackets, or an
        error message string if no peptide sequence is found.
    """
    matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if matches:
        amino_acid_sequence = []
        for match in matches:
            # Strip the brackets used for multi-character monomers, then split on '.'.
            sequence = match.replace('[', '').replace(']', '').split('.')
            amino_acid_sequence.extend(sequence)
        return amino_acid_sequence
    else:
        return "Invalid HELM notation or no peptide sequence found."
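
# Illustrative sketch: for a simple HELM string with one peptide polymer,
#   extract_amino_acid_sequence("PEPTIDE1{G.[dA].L}$$$$")
#   -> ['G', 'dA', 'L']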
|
|
def helm_collate_fn(batch, tokenizer):
    """Collator for HELM-notation items: tokenizes the batch with padding and truncation."""
    sequences = [item['HELM'] for item in batch]

    # Longest amino-acid sequence in the batch (currently informational only;
    # the tokenizer call below pads/truncates to its own max_length).
    max_len = 0
    for sequence in sequences:
        seq_len = len(extract_amino_acid_sequence(sequence))
        if seq_len > max_len:
            max_len = seq_len

    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
|
|
def collate_fn(batch, tokenizer):
    """Standard data collator that truncates/pads sequences based on max_length
    and attaches a token-level peptide-bond mask."""
    valid_sequences = []
    valid_items = []

    for item in batch:
        try:
            # Probe tokenization to filter out sequences the tokenizer cannot handle.
            tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
            valid_sequences.append(item['SMILES'])
            valid_items.append(item)
        except Exception as e:
            print(f"Skipping sequence due to: {str(e)}")
            continue

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    # Map character-level bond positions onto the token grid produced by the tokenizer.
    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }
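
# Illustrative usage (the tokenizer is project-specific and must provide both
# __call__ and get_token_split; `my_tokenizer` and `my_dataset` are hypothetical):
#   loader = DataLoader(my_dataset, batch_size=32,
#                       collate_fn=partial(collate_fn, tokenizer=my_tokenizer))
#   batch = next(iter(loader))
#   batch['input_ids'].shape, batch['bond_mask'].shape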
|
|
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    """def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)"""
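
# Illustrative end-to-end sketch (hypothetical objects; the tokenizer, model and the
# train/val/test datasets come from elsewhere in the project):
#   dm = CustomDataModule(train_ds, val_ds, test_ds, tokenizer=my_tokenizer,
#                         batch_size=16, collate_fn=collate_fn)
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(model, datamodule=dm)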