# FlowAMP / sequence_decoder.py
# Repository history note: renamed final_sequence_decoder.py to sequence_decoder.py
# (commit c51e07f, uploaded by esunAI).
import torch
import torch.nn.functional as F
import numpy as np
import esm
from tqdm import tqdm
import os
from datetime import datetime
class EmbeddingToSequenceConverter:
    """
    Convert ESM embeddings back to amino acid sequences using real ESM2 token embeddings.

    One reference vector is built per standard amino acid by running the ESM2
    model on a one-residue sequence and taking that residue's layer-33
    representation. A per-residue embedding is then decoded by comparing it
    (cosine similarity) against those 20 reference vectors.

    NOTE(review): the references come from isolated single-residue sequences,
    so they may differ from representations of the same residue inside a
    longer peptide -- confirm decoding quality on known sequences.
    """

    # The 20 standard amino acids; position i here corresponds to row i of
    # ``token_embeddings`` and entry i of ``token_list``.
    AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
    # Built once instead of on every validate_sequence() call.
    _VALID_AAS = frozenset(AMINO_ACIDS)
    # Representation layer requested from the model for the reference vectors.
    REPR_LAYER = 33

    def __init__(self, device='cuda'):
        """
        Load the ESM2 model and pre-compute the per-amino-acid reference vectors.

        Args:
            device: device the model and reference vectors live on. If a CUDA
                device is requested but CUDA is unavailable, CPU is used
                instead (with a printed warning) rather than failing in
                ``model.to()``.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print(f"Warning: CUDA device '{device}' requested but CUDA is not available; using CPU")
            device = 'cpu'
        self.device = device
        # Load ESM model
        print("Loading ESM model for sequence decoding...")
        self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.model = self.model.to(device)
        self.model.eval()
        # Get vocabulary (kept as public attributes; decoding itself only uses
        # token_list, which is set in _precompute_token_embeddings)
        self.vocab = self.alphabet.standard_toks
        self.vocab_list = [token for token in self.vocab if token not in ['<cls>', '<eos>', '<unk>', '<pad>', '<mask>']]
        # Pre-compute token embeddings for nearest neighbor search
        self._precompute_token_embeddings()
        print("✓ ESM model loaded for sequence decoding")

    def _precompute_token_embeddings(self):
        """
        Pre-compute one reference embedding per standard amino acid.

        Sets ``self.token_list`` (list of 20 one-letter codes) and
        ``self.token_embeddings`` (one row per entry of ``token_list``).
        """
        print("Pre-computing token embeddings from ESM2 model...")
        self.token_list = list(self.AMINO_ACIDS)
        # One single-residue "sequence" per amino acid, tokenized the same way
        # as real input via the alphabet's batch converter.
        dummy_sequences = [(f"aa_{i}", aa) for i, aa in enumerate(self.token_list)]
        batch_converter = self.alphabet.get_batch_converter()
        _, _, tokens = batch_converter(dummy_sequences)
        tokens = tokens.to(self.device)
        with torch.no_grad():
            out = self.model(tokens, repr_layers=[self.REPR_LAYER], return_contacts=False)
        reps = out['representations'][self.REPR_LAYER]
        # Index 1 along the sequence axis is the residue itself: index 0 is
        # skipped as the prepended start token (the original code slices
        # 1:1+L to drop the CLS/EOS positions). clone() detaches the result
        # from the full representation tensor.
        self.token_embeddings = reps[:, 1, :].clone()
        print(f"✓ Pre-computed embeddings for {len(self.token_embeddings)} tokens")
        print(f" Embedding shape: {self.token_embeddings.shape}")

    def _token_similarities(self, embedding):
        """
        Return cosine similarities between each row of ``embedding`` and each
        reference vector, as a [seq_len, vocab_size] tensor.

        The input is moved to ``self.device`` and cast to the reference
        vectors' dtype so half/double precision inputs do not make
        ``torch.mm`` fail on a dtype mismatch.
        """
        token_embeddings = self.token_embeddings.to(self.device)
        embedding = embedding.to(device=self.device, dtype=token_embeddings.dtype)
        embedding_norm = F.normalize(embedding, dim=-1)  # [seq_len, embed_dim]
        token_embeddings_norm = F.normalize(token_embeddings, dim=-1)  # [vocab_size, embed_dim]
        return torch.mm(embedding_norm, token_embeddings_norm.t())

    def _indices_to_sequence(self, indices):
        """Map a 1-D tensor of ``token_list`` indices to an amino acid string."""
        return ''.join(self.token_list[idx] for idx in indices.tolist())

    def embedding_to_sequence(self, embedding, method='diverse', temperature=0.5):
        """
        Convert a single embedding back to amino acid sequence.

        Args:
            embedding: [seq_len, embed_dim] tensor
            method: 'diverse', 'nearest_neighbor', or 'random'
            temperature: softmax temperature for the 'diverse' method; must be
                positive. Lower values sharpen the distribution (closer to
                nearest-neighbor decoding); higher values flatten it and give
                more diverse samples. Ignored by the other methods.

        Returns:
            sequence: string of amino acids

        Raises:
            ValueError: if ``method`` is unknown, or if ``temperature`` is not
                positive when ``method`` is 'diverse'.
        """
        if method == 'diverse':
            return self._diverse_decode(embedding, temperature)
        if method == 'nearest_neighbor':
            return self._nearest_neighbor_decode(embedding)
        if method == 'random':
            return self._random_decode(embedding)
        raise ValueError(f"Unknown method: {method}")

    def _diverse_decode(self, embedding, temperature=0.5):
        """
        Decode by sampling each residue from softmax(similarity / temperature).

        Raises:
            ValueError: if ``temperature`` is not positive (zero would divide
                by zero; a negative value would invert the similarity ranking).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if embedding.shape[0] == 0:
            # Nothing to sample; avoids calling multinomial on an empty batch.
            return ''
        logits = self._token_similarities(embedding) / temperature  # [seq_len, vocab_size]
        probs = F.softmax(logits, dim=-1)
        sampled_indices = torch.multinomial(probs, 1).squeeze(-1)  # [seq_len]
        return self._indices_to_sequence(sampled_indices)

    def _nearest_neighbor_decode(self, embedding):
        """
        Decode each position to the amino acid with the highest cosine similarity.
        """
        if embedding.shape[0] == 0:
            return ''
        similarities = self._token_similarities(embedding)  # [seq_len, vocab_size]
        nearest_indices = torch.argmax(similarities, dim=-1)  # [seq_len]
        return self._indices_to_sequence(nearest_indices)

    def _random_decode(self, embedding):
        """
        Decode using random sampling (fallback method).

        Only the first dimension of ``embedding`` (the sequence length) is used;
        residues are drawn uniformly from ``token_list`` via NumPy's global RNG.
        """
        seq_len = embedding.shape[0]
        return ''.join(np.random.choice(self.token_list, seq_len))

    def batch_embedding_to_sequences(self, embeddings, method='diverse', temperature=0.5):
        """
        Convert batch of embeddings to sequences.

        Args:
            embeddings: [batch_size, seq_len, embed_dim] tensor
            method: decoding method
            temperature: Temperature for diverse sampling

        Returns:
            sequences: list of strings, one per batch entry, in input order
        """
        return [
            self.embedding_to_sequence(embedding, method=method, temperature=temperature)
            for embedding in tqdm(embeddings, desc="Converting embeddings to sequences")
        ]

    def validate_sequence(self, sequence):
        """
        Return True if every character of ``sequence`` is a standard amino acid.

        An empty sequence is considered valid (there is no invalid character).
        """
        return all(aa in self._VALID_AAS for aa in sequence)

    def filter_valid_sequences(self, sequences):
        """
        Return the sequences that pass validate_sequence(), preserving order.

        A warning is printed for each sequence that is dropped.
        """
        valid_sequences = []
        for seq in sequences:
            if self.validate_sequence(seq):
                valid_sequences.append(seq)
            else:
                print(f"Warning: Invalid sequence found: {seq}")
        return valid_sequences
def _cfg_slug(cfg_name):
    """Turn a display name such as 'No CFG (0.0)' into a filename part ('no_cfg_00')."""
    # One translate pass removes '(', ')' and '.' instead of three chained replaces.
    return cfg_name.lower().replace(' ', '_').translate(str.maketrans('', '', '().'))


def _report_and_save(cfg_name, results, today, output_dir):
    """
    Print summary statistics for one CFG setting and write its sequences to disk.

    Args:
        cfg_name: display name of the CFG setting.
        results: dict with 'sequences' (non-empty list of str), 'total', 'valid'.
        today: date tag (YYYYMMDD) embedded in the output filename.
        output_dir: directory for the output file; created if missing.
    """
    sequences = results['sequences']
    # Sequence length statistics
    lengths = [len(seq) for seq in sequences]
    avg_length = np.mean(lengths)
    std_length = np.std(lengths)
    # Amino acid composition over all sequences; keys stay in alphabet order so
    # the stable sort below breaks count ties the same way every run.
    all_aas = ''.join(sequences)
    aa_counts = {aa: all_aas.count(aa) for aa in 'ACDEFGHIKLMNPQRSTVWY'}
    # Diversity = fraction of sequences that are distinct
    unique_sequences = len(set(sequences))
    diversity_ratio = unique_sequences / len(sequences)
    print(f"\n{cfg_name}:")
    print(f" Total sequences: {results['total']}")
    print(f" Valid sequences: {results['valid']}")
    print(f" Unique sequences: {unique_sequences}")
    print(f" Diversity ratio: {diversity_ratio:.3f}")
    print(f" Avg length: {avg_length:.1f} ± {std_length:.1f}")
    print(f" Length range: {min(lengths)}-{max(lengths)}")
    sorted_aas = sorted(aa_counts.items(), key=lambda x: x[1], reverse=True)
    top_aas = ', '.join(f'{aa}({count})' for aa, count in sorted_aas[:5])
    print(f" Top 5 AAs: {top_aas}")
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"decoded_sequences_{_cfg_slug(cfg_name)}_{today}.txt")
    # Explicit encoding so the result does not depend on the platform default.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"# Decoded sequences from {cfg_name}\n")
        f.write(f"# Total: {results['total']}, Valid: {results['valid']}, Unique: {unique_sequences}\n")
        f.write("# Generated from best model (loss: 0.017183, step: 53)\n\n")
        f.writelines(f"seq_{i:03d}\t{seq}\n" for i, seq in enumerate(sequences, start=1))
    print(f" ✓ Saved to: {output_file}")


def main():
    """
    Decode all CFG-generated peptide embeddings to sequences and analyze distribution.
    Uses the best trained model (loss: 0.017183, step: 53).

    For each CFG setting the embeddings file stamped with TODAY's date is
    loaded, decoded with the 'diverse' method, filtered, summarized on stdout
    and written to a text file. A setting whose file cannot be processed is
    reported and recorded with zero sequences; the run continues.
    """
    print("=== CFG-Generated Peptide Sequence Decoder (Best Model) ===")
    # Initialize converter
    converter = EmbeddingToSequenceConverter()
    # Input and output filenames both carry today's date, so the inputs must
    # have been generated on the same calendar day this script runs.
    today = datetime.now().strftime('%Y%m%d')
    # Load all CFG-generated embeddings (using best model)
    cfg_files = {
        'No CFG (0.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_no_cfg_{today}.pt',
        'Weak CFG (3.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_weak_cfg_{today}.pt',
        'Strong CFG (7.5)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_strong_cfg_{today}.pt',
        'Very Strong CFG (15.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_very_strong_cfg_{today}.pt',
    }
    all_results = {}
    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'='*50}")
        print(f"Processing {cfg_name}...")
        print(f"Loading: {file_path}")
        try:
            # SECURITY NOTE: torch.load unpickles the file; only load files
            # from a trusted source (consider weights_only=True where the
            # installed torch version supports it).
            embeddings = torch.load(file_path, map_location='cpu')
            print(f"✓ Loaded {len(embeddings)} embeddings, shape: {embeddings.shape}")
            # Decode to sequences using diverse method
            print("Decoding sequences...")
            sequences = converter.batch_embedding_to_sequences(embeddings, method='diverse', temperature=0.5)
            # Filter valid sequences
            valid_sequences = converter.filter_valid_sequences(sequences)
            print(f"✓ Valid sequences: {len(valid_sequences)}/{len(sequences)}")
            # Store results
            all_results[cfg_name] = {
                'sequences': valid_sequences,
                'total': len(sequences),
                'valid': len(valid_sequences),
            }
            # Show sample sequences
            print(f"\nSample sequences ({cfg_name}):")
            for i, seq in enumerate(valid_sequences[:5]):
                print(f" {i+1}: {seq}")
        except Exception as e:
            # Deliberate best-effort at the script boundary: report the failure
            # and keep going with the remaining CFG settings.
            print(f"❌ Error processing {file_path}: {e}")
            all_results[cfg_name] = {'sequences': [], 'total': 0, 'valid': 0}
    # Analysis and comparison
    print(f"\n{'='*60}")
    print("CFG ANALYSIS SUMMARY")
    print(f"{'='*60}")
    output_dir = '/data2/edwardsun/decoded_sequences'
    for cfg_name, results in all_results.items():
        # Settings with no valid sequences are skipped (nothing to summarize or save).
        if results['sequences']:
            _report_and_save(cfg_name, results, today, output_dir)
    # Overall comparison
    print(f"\n{'='*60}")
    print("OVERALL COMPARISON")
    print(f"{'='*60}")
    cfg_names = list(all_results.keys())
    valid_counts = [all_results[name]['valid'] for name in cfg_names]
    unique_counts = [len(set(all_results[name]['sequences'])) for name in cfg_names]
    print(f"Valid sequences: {dict(zip(cfg_names, valid_counts))}")
    print(f"Unique sequences: {dict(zip(cfg_names, unique_counts))}")
    # Rank by diversity only when every setting produced at least one valid
    # sequence; otherwise a ratio would divide by zero.
    if all(valid_counts):
        diversity_ratios = [unique / valid for unique, valid in zip(unique_counts, valid_counts)]
        most_diverse = cfg_names[diversity_ratios.index(max(diversity_ratios))]
        least_diverse = cfg_names[diversity_ratios.index(min(diversity_ratios))]
        print(f"\nMost diverse: {most_diverse} (ratio: {max(diversity_ratios):.3f})")
        print(f"Least diverse: {least_diverse} (ratio: {min(diversity_ratios):.3f})")
    print("\n✓ Decoding complete! Check the output files for detailed sequences.")
# Run the decoding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()