from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
from gradio.themes import Soft
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
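
# Gradio demo for Prot2Text-V2 on a ZeroGPU Space: an ESM2 encoder embeds the
# protein sequence, a modality adapter injects those embeddings into a Llama 3.1
# decoder, and the decoder streams back a natural-language description of the
# protein's function.
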
TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
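# NOTE: the URL below is a time-limited signed link (X-Amz-Expires=3600); it
# will stop resolving once the signature expires and should be replaced with a
# stable asset link, e.g. a file committed to the Space repository.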
PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
PROTEIN_HERO = f"""
<div class="visual-card hero-card">
    <img src="{PROTEIN_VISUAL_URL}" alt="Protein rendering" class="protein-visual">
</div>
"""
DESCRIPTION = f"""\
### Prot2Text-V2 Demo

{PROTEIN_HERO}

Prot2Text-V2 treats a protein sequence as if it were another language and translates it into English. Supply a raw amino acid sequence and the model returns a clear, human-readable paragraph describing what the protein does. The paper describing Prot2Text-V2 has been accepted to the NeurIPS 2025 main conference.

- **Input**: a protein sequence in IUPAC single-letter amino acid codes (the 20 canonical amino acids).
- **Output**: a concise description of the predicted function, localization cues, and structural hints.
- **Why it matters**: accelerates protein characterization, lab annotation, and downstream hypothesis building.

**Model architecture at a glance**

- Protein language model encoder: facebook/esm2_t36_3B_UR50D.
- Modality adapter: a lightweight bridge aligning protein embeddings with the language model.
- Natural language decoder: meta-llama/Llama-3.1-8B-Instruct for articulate descriptions.

**Resources**

- [Paper (NeurIPS 2025)](https://arxiv.org/abs/2505.11194)
- [Code repository](https://github.com/ColinFX/Prot2Text-V2)
- [Training data](https://huggingface.co/datasets/habdine/Prot2Text-Data)
"""
EXAMPLE_SEQUENCES = [
    ["AEQAERYEEMVEFMEKL"],
    [
        "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
    ],
    [
        "MCYSANGNTFLIVDNTQKRIPEEKKPDFVRENVGDLDGVIFVELVDGKYFMDYYNRDGSMAAFCGNGARAFSQYLIDRGWIKEKEFTFLSRAGEIKVIVDDSIWVRMPGVSEKKEMKVDGYEGYFVVVGVPHFVMEVKGIDELDVEKLGRDLRYKTGANVDFYEVLPDRLKVRTYERGVERETKACGTGVTSVFVVYRDKTGAKEVKIQVPGGTLFLKEENGEIFLRGDVKRCSEE"
    ],
    [
        "MTQEERFEQRIAQETAIEPQDWMPDAYRKTLIRQIGQHAHSEIVGMLPEGNWITRAPTLRRKAILLAKVQDEAGHGLYLYSAAETLGCAREDIYQKMLDGRMKYSSIFNYPTLSWADIGVIGWLVDGAAIVNQVALCRTSYGPYARAMVKICKEESFHQRQGFEACMALAQGSEAQKQMLQDAINRFWWPALMMFGPNDDNSPNSARSLTWKIKRFTNDELRQRFVDNTVPQVEMLGMTVPDPDLHFDTESGHYRFGEIDWQEFNEVINGRGICNQERLDAKRKAWEEGTWVREAALAHAQKQHARKVA"
    ],
    [
        "MTTRMIILNGGSSAGKSGIVRCLQSVLPEPWLAFGVDSLIEAMPLKMQSAEGGIEFDADGGVSIGPEFRALEGAWAEGVVAMARAGARIIIDDVFLGGAAAQERWRSFVGDLDVLWVGVRCDGAVAEGRETARGDRVAGMAAKQAYVVHEGVEYDVEVDTTHKESIECAWAIAAHVVP"
    ],
]

MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

system_message = (
    "You are a scientific assistant specialized in protein function "
    "predictions. Given the sequence embeddings and other information "
    "of a protein, describe its function clearly and concisely in "
    "professional language. "
)
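# Placeholder tokens reserve slots in the Llama prompt; at generation time the
# model's modality adapter replaces them with projected ESM2 embeddings of the
# input sequence.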
placeholder = "<|reserved_special_token_1|>"

esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
llama_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    pad_token="<|reserved_special_token_0|>",
)
model = AutoModelForCausalLM.from_pretrained(
    "xiao-fei/Prot2Text-V2-11B-Instruct-hf",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
model.eval()


# On ZeroGPU Spaces the GPU is allocated per call, so functions that run on
# the GPU must be decorated with @spaces.GPU (the `spaces` import above is
# otherwise unused).
@spaces.GPU
def stream_response(
    message: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    streamer = TextIteratorStreamer(
        llama_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    # One placeholder per ESM token: len(message) residues plus the two special
    # tokens (CLS/EOS) the ESM tokenizer adds.
    user_message = "Sequence embeddings: " + placeholder * (len(message) + 2)
    tokenized_prompt = llama_tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    tokenized_sequence = esm_tokenizer(message, return_tensors="pt")
    generate_kwargs = dict(
        inputs=tokenized_prompt["input_ids"].to(model.device),
        attention_mask=tokenized_prompt["attention_mask"].to(model.device),
        protein_input_ids=tokenized_sequence["input_ids"].to(model.device),
        protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device),
        eos_token_id=128009,  # <|eot_id|> in the Llama 3.1 vocabulary
        pad_token_id=128002,  # <|reserved_special_token_0|>, set as pad above
        return_dict_in_generate=False,
        num_beams=1,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread; the streamer yields decoded text
    # chunks as they are produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
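
# Illustrative usage outside Gradio (assumes a CUDA-capable environment):
#     for partial in stream_response("AEQAERYEEMVEFMEKL", max_new_tokens=64):
#         print(partial)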

ChatHistory = List[Tuple[str, str]]


def handle_submit(
    message: str,
    history: ChatHistory,
    max_new_tokens: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    history = list(history or [])
    message = message.strip()
    if not message:
        return
    conversation = history.copy()
    conversation.append((message, ""))
    for partial_response in stream_response(
        message=message,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    ):
        conversation[-1] = (message, partial_response)
        snapshot = conversation.copy()
        # Update the chatbot display and the history state, and clear the textbox.
        yield snapshot, snapshot, gr.update(value="")
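
# Not wired to any control in the layout below; a handler that a Clear button
# could call to reset the conversation.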
def clear_conversation():
    empty_history: ChatHistory = []
    return empty_history, empty_history, gr.update(value="")


theme = Soft(
    primary_hue="slate",
    secondary_hue="stone",
    neutral_hue="gray",
)

with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=5, min_width=320):
            gr.HTML(
                f"""
                <div class="brand-header center">
                    <a href="https://www.lix.polytechnique.fr/dascim/" target="_blank" rel="noopener">
                        <img src="{TEAM_LOGO_URL}" alt="DASCIM team logo" class="team-logo">
                    </a>
                </div>
                """
            )
            gr.Markdown(DESCRIPTION)
        with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
            history_state = gr.State([])
            chatbot = gr.Chatbot(
                label="Generated Function",
                height=350,
                show_copy_button=True,
            )
            with gr.Group(elem_classes="input-card"):
                sequence_input = gr.Textbox(
                    placeholder="Paste your amino acid sequence here (e.g. MAVVLPAVVEELLSEMAAAVQESA...)",
                    label="Protein sequence",
                    lines=1,
                    max_lines=1,
                    autofocus=True,
                )
                with gr.Row(elem_classes="button-row"):
                    submit_button = gr.Button("Predict function", variant="primary", elem_classes="primary-btn")
                    stop_button = gr.Button("Stop generation", variant="stop", elem_classes="stop-btn")
            gr.Examples(
                examples=EXAMPLE_SEQUENCES,
                inputs=sequence_input,
                label="Sample sequences",
                cache_examples=False,
                run_on_click=False,
            )
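            # With run_on_click=False, clicking an example only fills the
            # textbox; generation still requires Enter or "Predict function".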
            with gr.Accordion("Generation controls", open=False):
                max_new_tokens_slider = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                do_sample_checkbox = gr.Checkbox(label="Enable sampling", value=False)
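                # Temperature, top-p, and top-k only take effect when sampling
                # is enabled; greedy decoding ignores them. The repetition
                # penalty applies in both modes.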
                temperature_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.6,
                )
                top_p_slider = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                )
                top_k_slider = gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                )
                repetition_penalty_slider = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.0,
                )

    enter_event = sequence_input.submit(
        handle_submit,
        inputs=[
            sequence_input,
            history_state,
            max_new_tokens_slider,
            do_sample_checkbox,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
        ],
        outputs=[chatbot, history_state, sequence_input],
        queue=True,
    )
    submit_event = submit_button.click(
        handle_submit,
        inputs=[
            sequence_input,
            history_state,
            max_new_tokens_slider,
            do_sample_checkbox,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
        ],
        outputs=[chatbot, history_state, sequence_input],
        queue=True,
    )
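    # Stop has no handler of its own; `cancels` aborts any in-flight streaming
    # started by either trigger above.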
    stop_button.click(
        None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, enter_event],
    )
    with gr.Accordion("Model & usage notes", open=False):
        gr.Markdown(
            "- **Model stack**: Facebook ESM2 encoder + Llama 3.1 8B instruction-tuned decoder.\n"
            "- **Token budget**: generation stops after the configured `Max new tokens`.\n"
            "- **Caveat**: outputs are predictions; validate experimentally before publication.\n"
        )
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

if __name__ == "__main__":
    demo.queue(max_size=20).launch()