import os
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

import esm


class EmbeddingToSequenceConverter:
    """
    Convert ESM embeddings back to amino acid sequences using real ESM2 token embeddings.
    """
    def __init__(self, device='cuda'):
        # Fall back to CPU when CUDA is not available.
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
        self.device = device

        print("Loading ESM model for sequence decoding...")
        self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.model = self.model.to(device)
        self.model.eval()

        # Standard tokens from the ESM alphabet, with special tokens removed.
        self.vocab = self.alphabet.standard_toks
        self.vocab_list = [token for token in self.vocab
                           if token not in ('<cls>', '<eos>', '<unk>', '<pad>', '<mask>')]

        self._precompute_token_embeddings()

        print("✓ ESM model loaded for sequence decoding")

    def _precompute_token_embeddings(self):
        """
        Pre-compute embeddings for all 20 standard amino acids using real ESM2 embeddings.
        """
        print("Pre-computing token embeddings from ESM2 model...")

        standard_aas = 'ACDEFGHIKLMNPQRSTVWY'
        self.token_list = list(standard_aas)

        # Sanity check: warn if any amino acid is missing from the ESM alphabet
        # (get_idx() silently falls back to <unk>, so check membership directly).
        for aa in standard_aas:
            if aa not in self.alphabet.tok_to_idx:
                print(f"Warning: Could not find token for amino acid {aa}")

        # Run each amino acid through the model as a length-1 sequence and keep
        # its final-layer (33) representation. The batch converter prepends
        # <cls>, so the residue itself sits at position 1.
        dummy_sequences = [(f"aa_{i}", aa) for i, aa in enumerate(standard_aas)]

        converter = self.alphabet.get_batch_converter()
        _, _, tokens = converter(dummy_sequences)
        tokens = tokens.to(self.device)

        with torch.no_grad():
            out = self.model(tokens, repr_layers=[33], return_contacts=False)
            reps = out['representations'][33]

        token_embeddings = []
        for i, (_, seq) in enumerate(dummy_sequences):
            L = len(seq)
            emb = reps[i, 1:1 + L, :]  # drop <cls>; L is 1 here
            token_embeddings.append(emb[0])

        self.token_embeddings = torch.stack(token_embeddings)

        print(f"✓ Pre-computed embeddings for {len(self.token_embeddings)} tokens")
        print(f"  Embedding shape: {self.token_embeddings.shape}")

    def embedding_to_sequence(self, embedding, method='diverse', temperature=0.5):
        """
        Convert a single embedding back to an amino acid sequence.

        Args:
            embedding: [seq_len, embed_dim] tensor
            method: 'diverse', 'nearest_neighbor', or 'random'
            temperature: Temperature for diverse sampling (higher = more diverse;
                as it approaches zero, sampling approaches nearest-neighbor decoding)

        Returns:
            sequence: string of amino acids
        """
        if method == 'diverse':
            return self._diverse_decode(embedding, temperature)
        elif method == 'nearest_neighbor':
            return self._nearest_neighbor_decode(embedding)
        elif method == 'random':
            return self._random_decode(embedding)
        else:
            raise ValueError(f"Unknown method: {method}")
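
    # Usage sketch (hypothetical tensor): 1280 is the final-layer width of
    # esm2_t33_650M_UR50D, and the random tensor stands in for a real
    # generated embedding.
    #
    #   emb = torch.randn(20, 1280)
    #   converter.embedding_to_sequence(emb, method='nearest_neighbor')  # deterministic
    #   converter.embedding_to_sequence(emb, method='diverse', temperature=0.5)
    #   converter.embedding_to_sequence(emb, method='random')  # uniform baseline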

    def _diverse_decode(self, embedding, temperature=0.5):
        """
        Decode using diverse sampling with temperature control.
        """
        embedding = embedding.to(self.device)
        token_embeddings = self.token_embeddings.to(self.device)

        # Cosine similarity between every position and every amino acid embedding.
        embedding_norm = F.normalize(embedding, dim=-1)
        token_embeddings_norm = F.normalize(token_embeddings, dim=-1)
        similarities = torch.mm(embedding_norm, token_embeddings_norm.t())  # [seq_len, 20]

        # Temperature-scaled softmax over the 20 amino acids, then sample one
        # token per position.
        similarities = similarities / temperature
        probs = F.softmax(similarities, dim=-1)
        sampled_indices = torch.multinomial(probs, 1).squeeze(-1)

        sequence = ''.join(self.token_list[idx] for idx in sampled_indices.cpu().numpy())

        return sequence
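
    # For intuition, with illustrative cosine similarities s = (0.95, 0.90)
    # for the two best tokens at one position, the sampling odds between them
    # are exp((0.95 - 0.90) / T):
    #   T = 0.5 -> exp(0.1) ≈ 1.11  (near-uniform between the two)
    #   T = 0.1 -> exp(0.5) ≈ 1.65  (more peaked)
    # so lower temperatures approach nearest-neighbor decoding.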

    def _nearest_neighbor_decode(self, embedding):
        """
        Decode using nearest neighbor search in token embedding space.
        """
        embedding = embedding.to(self.device)
        token_embeddings = self.token_embeddings.to(self.device)

        # Cosine similarity, then pick the best-matching amino acid per position.
        embedding_norm = F.normalize(embedding, dim=-1)
        token_embeddings_norm = F.normalize(token_embeddings, dim=-1)
        similarities = torch.mm(embedding_norm, token_embeddings_norm.t())
        nearest_indices = torch.argmax(similarities, dim=-1)

        sequence = ''.join(self.token_list[idx] for idx in nearest_indices.cpu().numpy())

        return sequence

    def _random_decode(self, embedding):
        """
        Decode using random sampling (fallback method).
        """
        seq_len = embedding.shape[0]
        sequence = ''.join(np.random.choice(self.token_list, seq_len))
        return sequence

    def batch_embedding_to_sequences(self, embeddings, method='diverse', temperature=0.5):
        """
        Convert a batch of embeddings to sequences.

        Args:
            embeddings: [batch_size, seq_len, embed_dim] tensor
            method: decoding method
            temperature: Temperature for diverse sampling

        Returns:
            sequences: list of strings
        """
        sequences = []

        for i in tqdm(range(len(embeddings)), desc="Converting embeddings to sequences"):
            embedding = embeddings[i]
            sequence = self.embedding_to_sequence(embedding, method=method, temperature=temperature)
            sequences.append(sequence)

        return sequences
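
    # Batch usage sketch (hypothetical tensor shapes):
    #   embs = torch.randn(8, 20, 1280)  # [batch, seq_len, embed_dim] stand-in
    #   seqs = converter.batch_embedding_to_sequences(embs, temperature=0.5)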

    def validate_sequence(self, sequence):
        """
        Validate that a sequence contains only valid amino acids.
        """
        valid_aas = set('ACDEFGHIKLMNPQRSTVWY')
        return all(aa in valid_aas for aa in sequence)

    def filter_valid_sequences(self, sequences):
        """
        Filter out sequences with invalid amino acids.
        """
        valid_sequences = []
        for seq in sequences:
            if self.validate_sequence(seq):
                valid_sequences.append(seq)
            else:
                print(f"Warning: Invalid sequence found: {seq}")

        return valid_sequences
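

# Minimal smoke test, independent of main() below: the random tensor is a
# stand-in for a real generated embedding, and 1280 matches the final-layer
# width of esm2_t33_650M_UR50D. A sketch for quick checks, not part of the
# decoding pipeline.
def _demo_decode(converter, seq_len=16, embed_dim=1280):
    dummy = torch.randn(seq_len, embed_dim)
    for method in ('nearest_neighbor', 'diverse', 'random'):
        seq = converter.embedding_to_sequence(dummy, method=method)
        print(f"{method:>16}: {seq}")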


def main():
    """
    Decode all CFG-generated peptide embeddings to sequences and analyze their distribution.
    Uses the best trained model (loss: 0.017183, step: 53).
    """
    print("=== CFG-Generated Peptide Sequence Decoder (Best Model) ===")

    converter = EmbeddingToSequenceConverter()

    today = datetime.now().strftime('%Y%m%d')

    # Generated-embedding files for each classifier-free guidance scale.
    cfg_files = {
        'No CFG (0.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_no_cfg_{today}.pt',
        'Weak CFG (3.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_weak_cfg_{today}.pt',
        'Strong CFG (7.5)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_strong_cfg_{today}.pt',
        'Very Strong CFG (15.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_very_strong_cfg_{today}.pt',
    }

    all_results = {}

    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'=' * 50}")
        print(f"Processing {cfg_name}...")
        print(f"Loading: {file_path}")

        try:
            embeddings = torch.load(file_path, map_location='cpu')
            print(f"✓ Loaded {len(embeddings)} embeddings, shape: {embeddings.shape}")

            print("Decoding sequences...")
            sequences = converter.batch_embedding_to_sequences(
                embeddings, method='diverse', temperature=0.5)

            valid_sequences = converter.filter_valid_sequences(sequences)
            print(f"✓ Valid sequences: {len(valid_sequences)}/{len(sequences)}")

            all_results[cfg_name] = {
                'sequences': valid_sequences,
                'total': len(sequences),
                'valid': len(valid_sequences),
            }

            print(f"\nSample sequences ({cfg_name}):")
            for i, seq in enumerate(valid_sequences[:5]):
                print(f"  {i + 1}: {seq}")

        except Exception as e:
            print(f"❌ Error processing {file_path}: {e}")
            all_results[cfg_name] = {'sequences': [], 'total': 0, 'valid': 0}

    print(f"\n{'=' * 60}")
    print("CFG ANALYSIS SUMMARY")
    print(f"{'=' * 60}")

    for cfg_name, results in all_results.items():
        sequences = results['sequences']
        if not sequences:
            continue

        # Length statistics.
        lengths = [len(seq) for seq in sequences]
        avg_length = np.mean(lengths)
        std_length = np.std(lengths)

        # Amino acid composition across all decoded sequences.
        all_aas = ''.join(sequences)
        aa_counts = {aa: all_aas.count(aa) for aa in 'ACDEFGHIKLMNPQRSTVWY'}

        # Diversity: fraction of sequences that are unique.
        unique_sequences = len(set(sequences))
        diversity_ratio = unique_sequences / len(sequences)

        print(f"\n{cfg_name}:")
        print(f"  Total sequences: {results['total']}")
        print(f"  Valid sequences: {results['valid']}")
        print(f"  Unique sequences: {unique_sequences}")
        print(f"  Diversity ratio: {diversity_ratio:.3f}")
        print(f"  Avg length: {avg_length:.1f} ± {std_length:.1f}")
        print(f"  Length range: {min(lengths)}-{max(lengths)}")

        sorted_aas = sorted(aa_counts.items(), key=lambda x: x[1], reverse=True)
        print(f"  Top 5 AAs: {', '.join(f'{aa}({count})' for aa, count in sorted_aas[:5])}")

        # Save decoded sequences for this CFG setting.
        output_dir = '/data2/edwardsun/decoded_sequences'
        os.makedirs(output_dir, exist_ok=True)

        cfg_slug = (cfg_name.lower().replace(' ', '_')
                    .replace('(', '').replace(')', '').replace('.', ''))
        output_file = os.path.join(output_dir, f"decoded_sequences_{cfg_slug}_{today}.txt")
        with open(output_file, 'w') as f:
            f.write(f"# Decoded sequences from {cfg_name}\n")
            f.write(f"# Total: {results['total']}, Valid: {results['valid']}, Unique: {unique_sequences}\n")
            f.write("# Generated from best model (loss: 0.017183, step: 53)\n\n")
            for i, seq in enumerate(sequences):
                f.write(f"seq_{i + 1:03d}\t{seq}\n")
        print(f"  ✓ Saved to: {output_file}")

    print(f"\n{'=' * 60}")
    print("OVERALL COMPARISON")
    print(f"{'=' * 60}")

    cfg_names = list(all_results.keys())
    valid_counts = [all_results[name]['valid'] for name in cfg_names]
    unique_counts = [len(set(all_results[name]['sequences'])) for name in cfg_names]

    print(f"Valid sequences: {dict(zip(cfg_names, valid_counts))}")
    print(f"Unique sequences: {dict(zip(cfg_names, unique_counts))}")

    # Compare diversity across CFG scales (guard against division by zero when
    # a setting produced no valid sequences).
    if all(valid_counts):
        diversity_ratios = [unique_counts[i] / valid_counts[i] for i in range(len(valid_counts))]
        most_diverse = cfg_names[diversity_ratios.index(max(diversity_ratios))]
        least_diverse = cfg_names[diversity_ratios.index(min(diversity_ratios))]

        print(f"\nMost diverse: {most_diverse} (ratio: {max(diversity_ratios):.3f})")
        print(f"Least diverse: {least_diverse} (ratio: {min(diversity_ratios):.3f})")

    print("\n✓ Decoding complete! Check the output files for detailed sequences.")


if __name__ == "__main__":
    main()