Spaces:
Running
Running
| """Script to replicate results from the DGEB paper.""" | |
| import torch | |
| import dgeb | |
| from functools import partial | |
# Indices of every CUDA device torch reports for this process.
ALL_DEVICES = [*range(torch.cuda.device_count())]
DEFAULT_BATCH_SIZE = 64
DEFAULT_SEQ_LEN = 1024

# Keyword defaults pre-bound onto dgeb.get_model. Because they are bound as
# keywords, a call site can replace any of them by passing the same keyword.
_MODEL_DEFAULTS = {
    "devices": ALL_DEVICES,
    "batch_size": DEFAULT_BATCH_SIZE,
    "max_seq_length": DEFAULT_SEQ_LEN,
}
get_model = partial(dgeb.get_model, **_MODEL_DEFAULTS)
def main():
    """Run the protein and DNA DGEB task sets against every listed model.

    For each modality this fetches the tasks with
    ``dgeb.get_tasks_by_modality``, wraps them in a ``dgeb.DGEB`` evaluation,
    and then, for each model in listing order, loads it through the
    module-level ``get_model`` partial and hands it to ``run``.

    Each table entry is ``(model_name, overrides)``. ``overrides`` are keyword
    arguments that replace the defaults bound in the ``get_model`` partial
    (devices, batch size, maximum sequence length) for that one model.
    """
    # 131k will OOM so we use half this length.
    evo_131k_max_seq_len = 131072 // 2

    protein_models = [
        # ESM models.
        ("facebook/esm2_t6_8M_UR50D", {}),
        ("facebook/esm2_t12_35M_UR50D", {}),
        ("facebook/esm2_t30_150M_UR50D", {}),
        ("facebook/esm2_t33_650M_UR50D", {"batch_size": 32}),
        ("facebook/esm2_t36_3B_UR50D", {"batch_size": 1}),
        # ESM3 models.
        ("esm3_sm_open_v1", {"batch_size": 1, "devices": [0]}),
        # ProtT5 models.
        ("Rostlab/prot_t5_xl_uniref50", {"batch_size": 32}),
        ("Rostlab/prot_t5_xl_bfd", {"batch_size": 32}),
        # ProGen2 models.
        ("hugohrban/progen2-small", {}),
        ("hugohrban/progen2-medium", {"batch_size": 32}),
        ("hugohrban/progen2-large", {"batch_size": 1}),
        ("hugohrban/progen2-xlarge", {"batch_size": 1}),
    ]

    # The sequence-length override must be spelled `max_seq_length`, the same
    # keyword the `get_model` partial binds; any other spelling leaves the
    # partial's default sequence length in place.
    dna_models = [
        # Evo models.
        (
            "togethercomputer/evo-1-8k-base",
            {"batch_size": 1, "max_seq_length": 8192, "devices": [0]},
        ),
        (
            "togethercomputer/evo-1-131k-base",
            {
                "batch_size": 1,
                "max_seq_length": evo_131k_max_seq_len,
                "devices": [0],
            },
        ),
        # Nucleotide Transformer models.
        ("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-100m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-250m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", {}),
        (
            "InstaDeepAI/nucleotide-transformer-2.5b-multi-species",
            {"batch_size": 1},
        ),
    ]

    ######################### Protein Models #########################
    protein_tasks = dgeb.get_tasks_by_modality(dgeb.Modality.PROTEIN)
    protein_evaluation = dgeb.DGEB(tasks=protein_tasks)
    for model_name, overrides in protein_models:
        # Load inside the loop so only one model is referenced at a time.
        protein_evaluation.run(get_model(model_name, **overrides))

    ######################### DNA Models #########################
    dna_tasks = dgeb.get_tasks_by_modality(dgeb.Modality.DNA)
    dna_evaluation = dgeb.DGEB(tasks=dna_tasks)
    for model_name, overrides in dna_models:
        dna_evaluation.run(get_model(model_name, **overrides))
# Start the evaluations only when this file is executed as a script;
# importing it as a module does not call main().
if __name__ == "__main__":
    main()