Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import gradio as gr | |
| import py3Dmol | |
| import io | |
| import numpy as np | |
| import os | |
| import traceback | |
| from esm.sdk import client | |
| from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError | |
| from esm.utils.structure.protein_chain import ProteinChain | |
| from Bio.Data import PDBData | |
| import biotite.structure as bs | |
| from biotite.structure.io import pdb | |
| from esm.utils import residue_constants as RC | |
| import requests | |
| from dotenv import load_dotenv | |
| import torch | |
| import json | |
| import time | |
| from Bio.PDB import PDBParser | |
| import itertools | |
# Markdown with step-by-step usage instructions; rendered by create_demo()
# via gr.Markdown(howtouse).
howtouse = """
## How to use
1. Upload a PDB file using the file uploader.
2. Adjust the number of prediction runs per frame using the slider.
3. Set the noise level to add random perturbations to the structure.
4. Choose the number of MD frames to simulate.
5. Click the "Run Prediction" button to start the process.
6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
7. The alignment result will display the best cRMSD (lower is better).
8. Total and Normalized (per atom) steric clashes (lower is better)
"""
# Markdown describing the background and approach; rendered by create_demo()
# via gr.Markdown(about).
about = """ ## Background
- 3D protein structures typically come from crystal structures, which are densely packed and lack flexibility.
- Different proteins require varying levels of noise to achieve overlap in conformational space.
- We've developed an adaptability model that predicts the appropriate noise level for each protein.
## Our Approach
1. **Adaptability Model**: Trained on Molecular Dynamics (MD) data, our model predicts flexibility at the atomic level.
2. **Correlation**: The adaptability predictions correlate well with the RMSD (Root Mean Square Deviation) from ESM3 sampling.
3. **Noise Application**: We apply noise to simulate protein flexibility, mimicking MD-like behavior.
"""
# Short markdown summary of the demo; rendered by create_demo() via
# gr.Markdown(about1).
about1 = """
## About
This demo uses the ESM3 model to predict protein structures from PDB files.
It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
"""
# Populate os.environ from a local .env file (if present) before reading the token.
load_dotenv()

API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"

# The API token is read from the environment only. Never commit a token
# literal to this file: any credential that has appeared in source control
# must be treated as leaked and revoked.
API_TOKEN = os.environ.get("ESM_API_TOKEN")
if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Module-level inference client used by run_structure_prediction().
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)
# Three-letter residue name -> one-letter amino-acid code for the 20 entries
# listed below. get_protein() keeps only residues whose name is a key of this
# table and uses the values to build the sequence string.
amino3to1 = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
}
# Covalent radius per element, keyed by the upper-cased element symbol.
# Used by get_covalent_radius() for the steric-clash check.
COVALENT_RADIUS = {
    "H": 0.31, "HE": 0.28, "LI": 1.28, "BE": 0.96, "B": 0.84, "C": 0.76, "N": 0.71, "O": 0.66, "F": 0.57, "NE": 0.58,
    "NA": 1.66, "MG": 1.41, "AL": 1.21, "SI": 1.11, "P": 1.07, "S": 1.05, "CL": 1.02, "AR": 1.06, "K": 2.03,
    "CA": 1.76, "SC": 1.7, "TI": 1.6, "V": 1.53, "CR": 1.39, "MN": 1.5, "FE": 1.42, "CO": 1.38, "NI": 1.24,
    "CU": 1.32, "ZN": 1.22, "GA": 1.22, "GE": 1.2, "AS": 1.19, "SE": 1.2, "BR": 1.2, "KR": 1.16, "RB": 2.2,
    "SR": 1.95, "Y": 1.9, "ZR": 1.75, "NB": 1.64, "MO": 1.54, "TC": 1.47, "RU": 1.46, "RH": 1.42, "PD": 1.39,
    "AG": 1.45, "CD": 1.44, "IN": 1.42, "SN": 1.39, "SB": 1.39, "TE": 1.38, "I": 1.39, "XE": 1.4, "CS": 2.44,
    "BA": 2.15, "LA": 2.07, "CE": 2.04, "PR": 2.03, "ND": 2.01, "PM": 1.99, "SM": 1.98, "EU": 1.98, "GD": 1.96,
    "TB": 1.94, "DY": 1.92, "HO": 1.92, "ER": 1.89, "TM": 1.9, "YB": 1.87, "LU": 1.87, "HF": 1.75, "TA": 1.7,
    "W": 1.62, "RE": 1.51, "OS": 1.44, "IR": 1.41, "PT": 1.36, "AU": 1.36, "HG": 1.32, "TL": 1.45, "PB": 1.46,
    "BI": 1.48, "PO": 1.4, "AT": 1.5, "RN": 1.5, "FR": 2.6, "RA": 2.21, "AC": 2.15, "TH": 2.06, "PA": 2.0,
    "U": 1.96, "NP": 1.9, "PU": 1.87, "AM": 1.8, "CM": 1.69, "BK": 2.0, "CF": 2.0, "ES": 2.0, "FM": 2.0,
    "MD": 2.0, "NO": 2.0, "LR": 2.0, "RF": 2.0, "DB": 2.0, "SG": 2.0, "BH": 2.0, "HS": 2.0, "MT": 2.0,
    "DS": 2.0, "RG": 2.0, "CN": 2.0, "UUT": 2.0, "UUQ": 2.0, "UUP": 2.0, "UUH": 2.0, "UUS": 2.0, "UUO": 2.0
}


def get_covalent_radius(atom):
    """Look up the covalent radius for ``atom``.

    Args:
        atom: Any object exposing an ``element`` string attribute; the
            symbol is upper-cased before the table lookup.

    Returns:
        The radius from ``COVALENT_RADIUS``, or 2.0 when the element symbol
        is not present in the table.
    """
    symbol = atom.element.upper()
    if symbol in COVALENT_RADIUS:
        return COVALENT_RADIUS[symbol]
    # Unknown element: fall back to a generous default radius.
    return 2.0
def calculate_clashes_for_pdb(pdb_file):
    """Count steric clashes between every pair of atoms in a PDB structure.

    A pair counts as a clash when its inter-atomic distance plus a 0.5
    tolerance is still smaller than the sum of the two covalent radii.

    Args:
        pdb_file: File name or handle passed straight to
            ``PDBParser.get_structure``.

    Returns:
        A ``(total, normalized)`` tuple of strings: the number of clashing
        pairs, and that number divided by the atom count. A structure with
        no atoms yields ``("0", "0.0")`` instead of dividing by zero.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)
    atoms = list(structure.get_atoms())
    num_atoms = len(atoms)

    # Nothing to compare (and nothing to normalize by).
    if num_atoms == 0:
        return "0", "0.0"

    # Look each radius up once per atom instead of twice per atom pair.
    radii = [get_covalent_radius(atom) for atom in atoms]

    steric_clash_count = 0
    # The pairwise scan is quadratic in the number of atoms.
    for (atom1, radius1), (atom2, radius2) in itertools.combinations(zip(atoms, radii), 2):
        distance = atom1 - atom2  # Bio.PDB atoms subtract to their distance
        if distance + 0.5 < radius1 + radius2:
            steric_clash_count += 1

    # Normalize steric clashes per number of atoms
    norm_ster_clash_count = steric_clash_count / num_atoms
    return f"{steric_clash_count}", f"{norm_ster_clash_count}"
def read_pdb_io(pdb_file):
    """Read PDB text from several kinds of input and return it twice over.

    Args:
        pdb_file: One of
            * an ``io.StringIO`` holding the PDB text,
            * an ``os.PathLike`` (e.g. ``pathlib.Path``) pointing at a file,
            * an object with a ``name`` attribute holding the file path
              (such as an uploaded temp-file wrapper),
            * a plain ``str`` file path.

    Returns:
        ``(pdb_io, pdb_content)``: a fresh ``io.StringIO`` positioned at the
        start of the text, and the text itself.

    Raises:
        ValueError: If the input type is unsupported or the content is
            empty/whitespace only.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        if isinstance(pdb_file, os.PathLike):
            # Checked before ``.name``: on a pathlib.Path, ``.name`` is only
            # the final path component, not a usable path.
            path = os.fspath(pdb_file)
        elif hasattr(pdb_file, 'name'):
            path = pdb_file.name
        elif isinstance(pdb_file, str):
            path = pdb_file
        else:
            raise ValueError("Unsupported file type")
        with open(path, 'r') as f:
            pdb_content = f.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    pdb_io = io.StringIO(pdb_content)
    return pdb_io, pdb_content
def _first_annotation(value):
    """Return the first element of an array-like annotation, or ``value`` itself if scalar."""
    if isinstance(value, (list, tuple, np.ndarray)):
        return value[0]
    return value


def get_protein(pdb_file) -> ESMProtein:
    """Build an ``ESMProtein`` from the amino-acid residues of a PDB file.

    Only residues whose name is a key of ``amino3to1`` are kept. All kept
    residues are placed in one chain with ``chain_id="A"``, regardless of
    the chain they came from in the file.

    Args:
        pdb_file: Any input accepted by ``read_pdb_io``.

    Returns:
        The protein built from a ``ProteinChain`` whose atom37 arrays hold
        the coordinates of every atom whose name is in ``RC.atom_order``.

    Raises:
        ValueError: On any failure (empty file, no atoms, no amino-acid
            residues, parse errors). The original exception is chained.
    """
    try:
        # read_pdb_io already rejects empty or whitespace-only content.
        pdb_io, _ = read_pdb_io(pdb_file)

        # Request the first model explicitly. Without ``model`` biotite
        # returns an AtomArrayStack, and iterating a stack yields one array
        # per model rather than one atom at a time -- the atom loop below
        # would then only ever store the first atom of each residue.
        structure = pdb.PDBFile.read(pdb_io).get_structure(model=1)

        # Check if the structure contains any atoms
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Filter for amino acids and create a sequence
        valid_residues = [
            res for res in bs.residue_iter(structure)
            if _first_annotation(res.res_name) in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = ''.join(
            amino3to1[_first_annotation(res.res_name)] for res in valid_residues
        )
        residue_indices = [_first_annotation(res.res_id) for res in valid_residues]

        # Fill in atom positions and mask; atoms absent from the file stay
        # NaN / False.
        num_residues = len(valid_residues)
        atom37_positions = np.full((num_residues, 37, 3), np.nan)
        atom37_mask = np.zeros((num_residues, 37), dtype=bool)
        for i, res in enumerate(valid_residues):
            for atom in res:
                atom_name = _first_annotation(atom.atom_name)
                if atom_name not in RC.atom_order:
                    continue  # e.g. hydrogens or other names outside atom37
                coord = np.asarray(atom.coord)
                if coord.ndim > 1:
                    coord = coord[0]  # Take the first coordinate set if multiple are present
                idx = RC.atom_order[atom_name]
                atom37_positions[i, idx] = coord
                atom37_mask[i, idx] = True

        # Create a ProteinChain object
        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(len(sequence), "", dtype="<U4"),
            atom37_positions=atom37_positions,
            atom37_mask=atom37_mask,
            confidence=np.ones(len(sequence), dtype=np.float32)
        )
        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e
def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Return a new protein whose coordinates carry additive Gaussian noise.

    Args:
        protein: Source protein; its ``sequence`` is reused unchanged.
        noise_level: Factor applied to standard-normal samples drawn with
            the same shape as ``protein.coordinates``.

    Returns:
        A new ``ESMProtein`` built from the original sequence and the
        perturbed coordinates. The input protein is not modified.
    """
    original = protein.coordinates
    # One standard-normal sample per coordinate value, scaled by noise_level.
    perturbation = np.random.randn(*original.shape) * noise_level
    return ESMProtein(
        sequence=protein.sequence,
        coordinates=original + perturbation,
    )
def run_structure_prediction(protein: ESMProtein, temperature: float, num_steps: int) -> ESMProtein:
    """Ask the module-level ``model`` to generate a structure for ``protein``.

    Args:
        protein: Input protein handed to ``model.generate``.
        temperature: Sampling temperature for the generation config.
        num_steps: Number of generation steps for the generation config.

    Returns:
        The generated ``ESMProtein``, or ``None`` when the model reports an
        ``ESMProteinError``, returns an unexpected type, or raises. Failures
        are printed rather than propagated.
    """
    config = GenerationConfig(
        track="structure",
        num_steps=num_steps,
        temperature=temperature,
    )
    try:
        result = model.generate(protein, config)
        if isinstance(result, ESMProtein):
            return result
        if isinstance(result, ESMProteinError):
            print(f"ESMProteinError during structure prediction: {result.error_msg}")
            return None
        # Anything else is reported through the generic handler below.
        raise ValueError(f"Unexpected response type: {type(result)}")
    except Exception as exc:
        print(f"Error during structure prediction: {str(exc)}")
        return None
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Align a predicted structure onto the reference and report the RMSD.

    Args:
        protein: Reference protein (alignment target).
        structure_prediction: Predicted protein (mobile), or ``None``.

    Returns:
        ``(aligned_protein, crmsd)`` on success. ``(None, inf)`` when the
        prediction is ``None`` or any step fails; failures are printed.
    """
    failure = (None, float('inf'))
    if structure_prediction is None:
        return failure

    try:
        predicted_chain = structure_prediction.to_protein_chain()
        reference_chain = protein.to_protein_chain()

        # Compare only the leading residues both chains have in common.
        shared_length = min(len(predicted_chain.sequence), len(reference_chain.sequence))
        shared_inds = np.arange(0, shared_length)

        superimposed = predicted_chain.align(
            reference_chain,
            mobile_inds=shared_inds,
            target_inds=shared_inds
        )
        crmsd = predicted_chain.rmsd(
            reference_chain,
            mobile_inds=shared_inds,
            target_inds=shared_inds
        )
        return ESMProtein.from_protein_chain(superimposed), crmsd
    except AttributeError as exc:
        # Extra diagnostics for the case where the prediction object lacks
        # an expected attribute.
        print(f"Error during alignment: {str(exc)}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return failure
    except Exception as exc:
        print(f"Unexpected error during alignment: {str(exc)}")
        return failure
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Create a py3Dmol view overlaying the reference and aligned structures.

    Args:
        protein: Reference protein, drawn as a light-grey cartoon.
        aligned: Aligned prediction, drawn as a light-green cartoon.

    Returns:
        The configured ``py3Dmol`` view, or ``None`` when ``aligned`` is
        ``None``.
    """
    if aligned is None:
        return None

    viewer = py3Dmol.view(width=800, height=600)

    # The reference goes in first; this style call applies to everything
    # loaded so far.
    reference_pdb = protein_to_pdb(protein)
    viewer.addModel(reference_pdb, "pdb")
    viewer.setStyle({"cartoon": {"color": "lightgrey"}})

    # The prediction goes in second; model index -1 targets the model added last.
    predicted_pdb = protein_to_pdb(aligned)
    viewer.addModel(predicted_pdb, "pdb")
    viewer.setStyle({"model": -1}, {"cartoon": {"color": "lightgreen"}})

    viewer.zoomTo()
    return viewer
def protein_to_pdb(protein: ESMProtein):
    """Serialize a protein's coordinates as fixed-column PDB ``ATOM`` records.

    Args:
        protein: Protein whose ``sequence`` (one-letter codes) and
            ``coordinates`` (indexed ``[residue][atom37 slot]`` -> xyz
            tensor) are written out.

    Returns:
        A string with one ``ATOM`` line per atom whose three coordinates are
        all finite; atoms with missing (NaN) coordinates are skipped. All
        residues are written to chain ``A``, numbered from 1. Returns ``""``
        when the protein has no sequence or no coordinates.
    """
    if protein.sequence is None or protein.coordinates is None:
        return ""

    # The PDB resName field holds three-letter names, so invert the
    # file-level three-to-one table; anything unmapped is written as UNK.
    one_to_three = {one: three for three, one in amino3to1.items()}

    records = []
    for i, (aa, coords) in enumerate(zip(protein.sequence, protein.coordinates)):
        res_name = one_to_three.get(aa, "UNK")
        for j, atom in enumerate(RC.atom_types):
            if not torch.isfinite(coords[j]).all():
                continue
            x, y, z = coords[j].tolist()
            # Column layout follows the PDB ATOM record: serial in 7-11,
            # atom name starting at 14, resName 18-20, chain 22, resSeq
            # 23-26, x/y/z in 31-54, occupancy, B-factor, element in 77-78.
            # NOTE(review): the element is taken from the first letter of
            # the atom name, which assumes every RC.atom_types name starts
            # with its one-letter element symbol -- confirm.
            records.append(
                f"ATOM  {i*37+j+1:5d}  {atom:<3s} {res_name:3s} A{i+1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {atom[0]:>2s}\n"
            )
    return "".join(records)
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int, temperature: float, num_steps: int, progress=gr.Progress()):
    """Run noisy structure predictions and visualize the best-aligned one.

    For each of ``num_frames`` frames the input protein is perturbed once
    with noise, then predicted ``num_runs`` times. Every prediction that
    aligns successfully is kept together with its cRMSD.

    Args:
        pdb_file: Uploaded PDB input, parsed with ``get_protein``.
        num_runs: Predictions per frame.
        noise_level: Noise factor passed to ``add_noise_to_coordinates``.
        num_frames: Number of noisy frames.
        temperature: Sampling temperature for each prediction.
        num_steps: Generation steps for each prediction.
        progress: Gradio progress tracker.

    Returns:
        ``(view, message)``: the py3Dmol view of the lowest-cRMSD run and a
        text summary, or ``(None, "No successful predictions")``.
    """
    protein = get_protein(pdb_file)
    successful = []  # (crmsd, aligned protein) for every usable run
    total_iterations = num_frames * num_runs

    progress(0, desc="Starting predictions")
    for frame in progress.tqdm(range(num_frames), desc="Processing frames"):
        noisy_protein = add_noise_to_coordinates(protein, noise_level)
        for i in range(num_runs):
            completed = frame * num_runs + i + 1
            progress(completed / total_iterations, desc=f"Frame {frame+1}, Run {i+1}")

            prediction = run_structure_prediction(noisy_protein, temperature, num_steps)
            aligned = None
            if prediction is not None:
                aligned, crmsd = align_after_prediction(protein, prediction)
            if aligned is not None:
                successful.append((crmsd, aligned))

            time.sleep(0.1)  # Small delay to allow for UI updates

    if not successful:
        return None, "No successful predictions"

    # Stable sort by cRMSD: the first entry is the best run.
    successful.sort(key=lambda run: run[0])
    best_crmsd, best_structure = successful[0]
    view_data = visualize_after_pred(protein, best_structure)
    return view_data, f"Best cRMSD: {best_crmsd:.4f}"
def run_prediction(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress=gr.Progress()):
    """Gradio callback: run the prediction pipeline and build the four UI outputs.

    Args:
        pdb_file: Uploaded PDB file, or ``None`` when nothing was uploaded.
        num_runs: Predictions per frame.
        noise_level: Noise factor applied to the input coordinates.
        num_frames: Number of noisy frames.
        temperature: Sampling temperature.
        num_steps: Generation steps.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple ``(html, alignment_text, clash_text, normalized_clash_text)``.
        On any exception the first element is an HTML error panel containing
        the escaped message and stack trace; nothing is raised to Gradio.
    """
    # Local import: only this handler needs HTML escaping.
    import html

    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded", "", ""

        progress(0, desc="Starting prediction")
        view, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress)
        # The clash statistics are computed on the uploaded file.
        steric_clash_text, norm_steric_clash_text = calculate_clashes_for_pdb(pdb_file)
        if view is None:
            return "No successful predictions were made. Try adjusting the parameters or check the PDB file.", crmsd_text, steric_clash_text, norm_steric_clash_text

        progress(0.9, desc="Rendering visualization")
        # Convert the py3Dmol view to HTML. It is embedded inside the iframe's
        # srcdoc attribute, so escape it as an attribute value (quotes
        # included) rather than rewriting the quotes inside the markup.
        view_html = html.escape(view._make_html(), quote=True)
        html_content = f"""
<iframe style="width: 100%; height: 600px;" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='&lt;!DOCTYPE html&gt;&lt;html&gt;{view_html}&lt;/html&gt;'></iframe>
"""
        progress(1.0, desc="Completed")
        return html_content, crmsd_text, steric_clash_text, norm_steric_clash_text
    except Exception as e:
        # Escape both strings: exception text and tracebacks routinely
        # contain "<...>" (e.g. "<module>"), which would otherwise be parsed
        # as markup.
        error_message = html.escape(str(e))
        stack_trace = html.escape(traceback.format_exc())
        return f"""
<div style='color: red;'>
<h3>Error:</h3>
<p>{error_message}</p>
<h4>Stack Trace:</h4>
<pre>{stack_trace}</pre>
</div>
""", "Error occurred", "", ""
def create_demo():
    """Build the Gradio Blocks UI and wire the prediction callback to it.

    Returns:
        The assembled ``gr.Blocks`` object; it is not queued or launched here.
    """
    # NOTE(review): the source had lost its indentation, so the nesting of
    # the rows/accordions below is reconstructed -- confirm against the
    # deployed layout.
    with gr.Blocks() as demo:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")
        # Collapsed info panel: background and usage side by side, summary below.
        with gr.Accordion(label='learn more about MISATO ESM3 conformational sampling', open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(about)
                with gr.Column():
                    gr.Markdown(howtouse)
            with gr.Row():
                gr.Markdown(about1)
        # Collapsed panel holding the presentation video.
        with gr.Accordion(label="watch presentation video", open=False):
            with gr.Row():
                gr.Video(value="demovideo/demo.mp4", label="MISATO Video Submission")
        with gr.Row():
            # Left column: file upload and prediction parameters.
            with gr.Column(scale=1):
                pdb_file = gr.File(label="Upload PDB file")
                num_runs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_level = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                num_frames = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.7, label="Temperature")
                num_steps = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Number of steps")
                run_button = gr.Button("Run Prediction")
            # Right column: the four outputs returned by run_prediction, in order.
            with gr.Column(scale=2):
                visualization = gr.HTML(label="3D Visualization")
                alignment_result = gr.Textbox(label="Alignment Result")
                steric_clash_result = gr.Textbox(label="Steric Clash Result")
                norm_steric_clash_result = gr.Textbox(label="Normalized Steric Clash Result")
        # The input order must match run_prediction's positional parameters.
        run_button.click(
            fn=run_prediction,
            inputs=[pdb_file, num_runs, noise_level, num_frames, temperature, num_steps],
            outputs=[visualization, alignment_result, steric_clash_result, norm_steric_clash_result]
        )
        # NOTE(review): only pdb_file is listed as an input here, while
        # run_prediction takes six required arguments; presumably harmless
        # while cache_examples=False keeps fn from being executed for
        # examples -- confirm before enabling caching.
        gr.Examples(
            examples=[
                ["examples/1ywi.pdb"],
                ["examples/5awl.pdb"],
                ["examples/11gs.pdb"],
            ],
            inputs=[pdb_file],
            outputs=[visualization, alignment_result, steric_clash_result, norm_steric_clash_result],
            fn=run_prediction,
            cache_examples=False,
        )
    return demo
# Script entry point: build the UI, enable request queuing, then start the
# Gradio server.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()