# PepTune/dataset.py
import re
import torch
from .utils import utils
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from functools import partial
import sys
class CustomDataset(Dataset):
    """Wraps a dataset and exposes only the items at the given indices (e.g. a train/val split)."""
    def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
actual_idx = int(self.indices[idx])
item = self.dataset[actual_idx]
return item
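
# A minimal sketch of CustomDataset usage; the toy records below are illustrative
# placeholders, not entries from the actual training data:
#
#   toy = [{'SMILES': 'NCC(=O)O'}, {'SMILES': 'NC(C)C(=O)O'}, {'SMILES': 'NCC(=O)NC(C)C(=O)O'}]
#   subset = CustomDataset(toy, indices=[0, 2])
#   len(subset)   # -> 2
#   subset[1]     # -> {'SMILES': 'NCC(=O)NC(C)C(=O)O'}
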
# for weighting losses of peptide bonds
def peptide_bond_mask(smiles_list):
"""
Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
of recognized bonds in the positions dictionary and 0 elsewhere.
Args:
smiles_list: List of peptide SMILES strings (batch of SMILES strings).
Returns:
np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions.
"""
# Initialize the batch mask
batch_size = len(smiles_list)
max_seq_length = max(len(smiles) for smiles in smiles_list) # Find the longest SMILES
mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),           # ester bonds
        (r'N\(C\)C\(=O\)', 'n_methyl'),   # N-methylated amide bonds
        (r'N[12]C\(=O\)', 'peptide'),     # Pro peptide bonds (ring-closure digit)
        (r'NC\(=O\)', 'peptide'),         # regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),   # N-methylated amide bonds (reversed orientation)
        (r'C\(=O\)N[12]?', 'peptide')     # peptide bonds (reversed orientation)
    ]
for batch_idx, smiles in enumerate(smiles_list):
positions = []
used = set()
# Identify bonds
for pattern, bond_type in bond_patterns:
for match in re.finditer(pattern, smiles):
if not any(p in range(match.start(), match.end()) for p in used):
positions.append({
'start': match.start(),
'end': match.end(),
'type': bond_type,
'pattern': match.group()
})
used.update(range(match.start(), match.end()))
# Update the mask for the current SMILES
for pos in positions:
mask[batch_idx, pos['start']:pos['end']] = 1
return mask
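
# A quick illustration of the character-level mask; the Gly-Ala dipeptide SMILES below
# is a made-up example for demonstration only:
#
#   peptide_bond_mask(['NCC(=O)NC(C)C(=O)O'])
#   # -> tensor of shape (1, 18) with 1s over characters 2-7, i.e. the backbone 'C(=O)N'
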
def peptide_token_mask(smiles_list, token_lists):
"""
Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
where any part of the token overlaps with a peptide bond, and 0 elsewhere.
Args:
smiles_list: List of peptide SMILES strings (batch of SMILES strings).
token_lists: List of tokenized SMILES strings (split into tokens).
Returns:
np.ndarray: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
"""
# Initialize the batch mask
batch_size = len(smiles_list)
token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence
tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros
atomwise_masks = peptide_bond_mask(smiles_list)
for batch_idx, atomwise_mask in enumerate(atomwise_masks):
token_seq = token_lists[batch_idx]
atom_idx = 0
        for token_idx, token in enumerate(token_seq):
            # skip the first and last tokens (assumed to be special start/end tokens)
            if token_idx != 0 and token_idx != len(token_seq) - 1:
if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1:
tokenized_masks[batch_idx][token_idx] = 1
atom_idx += len(token)
return tokenized_masks
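
# A sketch of the token-level mask, using a hand-split token list in place of the
# tokenizer's actual output (token boundaries below are assumed for illustration;
# the first and last entries stand in for special tokens):
#
#   smiles = ['NCC(=O)NC(C)C(=O)O']
#   tokens = [['<s>', 'N', 'C', 'C(=O)', 'N', 'C(C)', 'C(=O)O', '</s>']]
#   peptide_token_mask(smiles, tokens)
#   # tokens whose characters overlap the backbone amide 'C(=O)N' are marked with 1
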
def extract_amino_acid_sequence(helm_string):
"""
Extracts the amino acid sequence from a HELM peptide notation and outputs it as an array,
removing any brackets around each amino acid.
Args:
helm_string (str): The HELM notation string for a peptide.
Returns:
list: A list containing each amino acid in sequence without brackets.
"""
# Use regex to find the pattern within `{}` brackets following "PEPTIDE" followed by a number
matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)
if matches:
# Join all matched sequences and split by dots to get individual amino acids
amino_acid_sequence = []
for match in matches:
sequence = match.replace('[', '').replace(']', '').split('.')
amino_acid_sequence.extend(sequence)
return amino_acid_sequence
else:
return "Invalid HELM notation or no peptide sequence found."
def helm_collate_fn(batch, tokenizer):
    """Collate HELM sequences: tokenize the batch with padding and truncation."""
    sequences = [item['HELM'] for item in batch]
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)
return {
'input_ids': tokens['input_ids'],
'attention_mask': tokens['attention_mask']
}
def collate_fn(batch, tokenizer):
    """Standard data collator that pads/truncates sequences based on max_length."""
    valid_sequences = []
    valid_items = []
    for item in batch:
        try:
            # tokenize each sequence on its own first to filter out any the tokenizer rejects
            tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
            valid_sequences.append(item['SMILES'])
            valid_items.append(item)
        except Exception as e:
            print(f"Skipping sequence due to: {str(e)}")
            continue
#sequences = [item['SMILES'] for item in batch]
#max_len = max([len(seq) for seq in sequences])
#labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)
token_array = tokenizer.get_token_split(tokens['input_ids'])
bond_mask = peptide_token_mask(valid_sequences, token_array)
#attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)
return {
'input_ids': tokens['input_ids'],
'attention_mask': tokens['attention_mask'],
'bond_mask': bond_mask
}
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module that wraps the train/val datasets with the SMILES collate_fn."""
    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
#self.test_dataset = test_dataset
self.batch_size = batch_size
self.tokenizer = tokenizer
self.collate_fn = collate_fn
def train_dataloader(self):
return DataLoader(self.train_dataset,
batch_size=self.batch_size,
collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
num_workers=8,
pin_memory=True
)
def val_dataloader(self):
return DataLoader(self.val_dataset,
batch_size=self.batch_size,
collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
num_workers=8,
pin_memory=True
)
"""def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size,
collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
num_workers=8, pin_memory=True)"""