PepTune / generate_mcts.py
Sophia Tang
model upload
e54915d
#!/usr/bin/env
import torch
import torch.nn.functional as F
import math
import random
import sys
import pandas as pd
from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
from diffusion import Diffusion
from pareto_mcts import Node, MCTS
import hydra
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, pipeline
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from helm_tokenizer.helm_tokenizer import HelmTokenizer
from utils.helm_utils import create_helm_from_aa_seq
from utils.app import PeptideAnalyzer
from new_tokenizer.ape_tokenizer import APETokenizer
import matplotlib.pyplot as plt
import os
import seaborn as sns
import pandas as pd
import numpy as np
def save_logs_to_file(config, valid_fraction_log, affinity1_log, affinity2_log, sol_log, hemo_log, nf_log, permeability_log, output_path):
"""
Saves the logs (valid_fraction_log, affinity1_log, and permeability_log) to a CSV file.
Parameters:
valid_fraction_log (list): Log of valid fractions over iterations.
affinity1_log (list): Log of binding affinity over iterations.
permeability_log (list): Log of membrane permeability over iterations.
output_path (str): Path to save the log CSV file.
"""
os.makedirs(os.path.dirname(output_path), exist_ok=True)
if config.mcts.perm:
# Combine logs into a DataFrame
log_data = {
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
"Valid Fraction": valid_fraction_log,
"Binding Affinity": affinity1_log,
"Solubility": sol_log,
"Hemolysis": hemo_log,
"Nonfouling": nf_log,
"Permeability": permeability_log
}
elif config.mcts.dual:
log_data = {
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
"Valid Fraction": valid_fraction_log,
"Binding Affinity 1": affinity1_log,
"Binding Affinity 2": affinity2_log,
"Solubility": sol_log,
"Hemolysis": hemo_log,
"Nonfouling": nf_log,
"Permeability": permeability_log
}
elif config.mcts.single:
log_data = {
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
"Valid Fraction": valid_fraction_log,
"Permeability": permeability_log
}
else:
log_data = {
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
"Valid Fraction": valid_fraction_log,
"Binding Affinity": affinity1_log,
"Solubility": sol_log,
"Hemolysis": hemo_log,
"Nonfouling": nf_log
}
df = pd.DataFrame(log_data)
# Save to CSV
df.to_csv(output_path, index=False)
def plot_data(log1, log2=None,
save_path=None,
label1="Log 1",
label2=None,
title="Fraction of Valid Peptides Over Iterations",
palette=None):
"""
Plots one or two datasets with their mean values over iterations.
Parameters:
log1 (list): The first list of mean values for each iteration.
log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
save_path (str): Path to save the plot. Defaults to None.
label1 (str): Label for the first dataset. Defaults to "Log 1".
label2 (str, optional): Label for the second dataset. Defaults to None.
title (str): Title of the plot. Defaults to "Mean Values Over Iterations".
palette (dict, optional): A dictionary defining custom colors for datasets. Defaults to None.
"""
# Prepare data for log1
data1 = pd.DataFrame({
"Iteration": range(1, len(log1) + 1),
"Fraction of Valid Peptides": log1,
"Dataset": label1
})
# Prepare data for log2 if provided
if log2 is not None:
data2 = pd.DataFrame({
"Iteration": range(1, len(log2) + 1),
"Fraction of Valid Peptides": log2,
"Dataset": label2
})
data = pd.concat([data1, data2], ignore_index=True)
else:
data = data1
palette = {
label1: "#8181ED", # Default color for log1
label2: "#D577FF" # Default color for log2 (if provided)
}
# Set Seaborn theme
sns.set_theme()
sns.set_context("paper")
# Create the plot
sns.lineplot(
data=data,
x="Iteration",
y="Fraction of Valid Peptides",
hue="Dataset",
style="Dataset",
markers=True,
dashes=False,
palette=palette
)
# Titles and labels
plt.title(title)
plt.xlabel("Iteration")
plt.ylabel("Fraction of Valid Peptides")
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Plot saved to {save_path}")
plt.show()
def plot_data_with_distribution_seaborn(log1, log2=None,
save_path=None,
label1=None,
label2=None,
title=None):
"""
Plots one or two datasets with the average values and distributions over iterations using Seaborn.
Parameters:
log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
save_path (str): Path to save the plot. Defaults to None.
label1 (str): Label for the first dataset. Defaults to "Fraction of Valid Peptide SMILES".
label2 (str, optional): Label for the second dataset. Defaults to None.
title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
"""
# Prepare data for log1
data1 = pd.DataFrame({
"Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
"Fraction of Valid Peptides": [score for scores in log1 for score in scores],
"Dataset": label1,
"Style": "Log1"
})
# Prepare data for log2 if provided
if log2 is not None:
data2 = pd.DataFrame({
"Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
"Fraction of Valid Peptides": [score for scores in log2 for score in scores],
"Dataset": label2,
"Style": "Log2"
})
data = pd.concat([data1, data2], ignore_index=True)
else:
data = data1
palette = {
label1: "#8181ED", # Default color for log1
label2: "#D577FF" # Default color for log2 (if provided)
}
# Set Seaborn theme
sns.set_theme()
sns.set_context("paper")
# Create the plot
sns.relplot(
data=data,
kind="line",
x="Iteration",
y="Fraction of Valid Peptides",
hue="Dataset",
style="Style",
markers=True,
dashes=True,
ci="sd", # Show standard deviation
height=5,
aspect=1.5,
palette=palette
)
# Titles and labels
plt.title(title)
plt.xlabel("Iteration")
plt.ylabel("Fraction of Valid Peptides")
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Plot saved to {save_path}")
plt.show()
@torch.no_grad()
def generate_valid_mcts(config, mdlm, prot1=None, prot2=None, filename=None, prot_name1=None, prot_name2 = None):
tokenizer = mdlm.tokenizer
max_sequence_length = config.sampling.seq_length
# generate array of [MASK] tokens
masked_array = mask_for_de_novo(config, max_sequence_length)
if config.vocab == 'old_smiles':
# use custom encode function
inputs = tokenizer.encode(masked_array)
elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
inputs = tokenizer.encode_for_generation(masked_array)
else:
# custom HELM tokenizer
inputs = tokenizer(masked_array, return_tensors="pt")
inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}
# initialize root node
rootNode = Node(config=config, tokens=inputs, timestep=0)
# initalize tree search algorithm
if config.mcts.perm:
score_func_names = ['permeability', 'binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
num_func = [0, 50, 50, 50, 50]
elif config.mcts.dual:
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'binding_affinity2']
elif config.mcts.single:
score_func_names = ['permeability']
else:
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
if not config.mcts.time_dependent:
num_func = [0] * len(score_func_names)
if prot1 and prot2 is not None:
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1, prot2], num_func=num_func)
elif prot1 is not None:
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1], num_func=num_func)
elif config.mcts.single:
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
else:
mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
paretoFront = mcts.forward(rootNode)
output_log_path = f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/log_{filename}.csv'
save_logs_to_file(config, mcts.valid_fraction_log, mcts.affinity1_log, mcts.affinity2_log, mcts.sol_log, mcts.hemo_log, mcts.nf_log, mcts.permeability_log, output_log_path)
if config.mcts.single:
plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
label1="Average Permeability Score",
title="Average Permeability Score Over Iterations")
else:
plot_data(mcts.valid_fraction_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/valid_{filename}.png')
plot_data_with_distribution_seaborn(log1=mcts.affinity1_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding1_{filename}.png',
label1="Average Binding Affinity to TfR",
title="Average Binding Affinity to TfR Over Iterations")
if config.mcts.dual:
plot_data_with_distribution_seaborn(log1=mcts.affinity2_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/binding2_{filename}.png',
label1="Average Binding Affinity to SKP2",
title="Average Binding Affinity to SKP2 Over Iterations")
plot_data_with_distribution_seaborn(log1=mcts.sol_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/sol_{filename}.png',
label1="Average Solubility Score",
title="Average Solubility Score Over Iterations")
plot_data_with_distribution_seaborn(log1=mcts.hemo_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/hemo_{filename}.png',
label1="Average Hemolysis Score",
title="Average Hemolysis Score Over Iterations")
plot_data_with_distribution_seaborn(log1=mcts.nf_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/nf_{filename}.png',
label1="Average Nonfouling Score",
title="Average Nonfouling Score Over Iterations")
if config.mcts.perm:
plot_data_with_distribution_seaborn(log1=mcts.permeability_log,
save_path=f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/perm_{filename}.png',
label1="Average Permeability Score",
title="Average Permeability Score Over Iterations")
return paretoFront, inputs
@hydra.main(version_base=None, config_path='/home/st512/peptune/scripts/peptide-mdlm-mcts', config_name='config')
def main(config):
prot_name1 = "time_dependent"
prot_name2 = "skp2"
mode = "2"
model = "mcts"
length = "100"
epoch = "7"
filename = f'{mode}_{model}_length_{length}_epoch_{epoch}'
if config.vocab == 'new_smiles':
tokenizer = APETokenizer()
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_smiles_600_vocab.json')
elif config.vocab == 'old_smiles':
tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
'/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
elif config.vocab == 'selfies':
tokenizer = APETokenizer()
tokenizer.load_vocabulary('/home/st512/peptune/scripts/peptide-mdlm-mcts/new_tokenizer/peptide_selfies_600_vocab.json')
elif config.vocab == 'helm':
tokenizer = HelmTokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/helm_tokenizer/monomer_vocab.txt')
mdlm = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)
mdlm.eval()
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
mdlm.to(device)
print("loaded models...")
analyzer = PeptideAnalyzer()
# proteins
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
paretoFront, input_array = generate_valid_mcts(config, mdlm, gfap, None, filename, prot_name1, None)
generation_results = []
for sequence, v in paretoFront.items():
generated_array = v['token_ids'].to(mdlm.device)
# compute perplexity
perplexity = mdlm.compute_masked_perplexity(generated_array, input_array['input_ids'])
perplexity = round(perplexity, 4)
aa_seq, seq_length = analyzer.analyze_structure(sequence)
scores = v['scores']
if config.mcts.single == False:
binding1 = scores[0]
solubility = scores[1]
hemo = scores[2]
nonfouling = scores[3]
if config.mcts.perm:
permeability = scores[4]
generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling, permeability])
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
elif config.mcts.dual:
binding2 = scores[4]
generation_results.append([sequence, perplexity, aa_seq, binding1, binding2, solubility, hemo, nonfouling])
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity 1: {binding1} | Binding Affinity 2: {binding2} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
elif config.mcts.single:
permeability = scores[0]
else:
generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling])
print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
sys.stdout.flush()
if config.mcts.perm:
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
elif config.mcts.dual:
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity 1', 'Binding Affinity 2', 'Solubility', 'Hemolysis', 'Nonfouling'])
elif config.mcts.single:
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Permeability'])
else:
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling'])
df.to_csv(f'/home/st512/peptune/scripts/peptide-mdlm-mcts/benchmarks/{prot_name1}/{filename}.csv', index=False)
if __name__ == "__main__":
main()