from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
import spaces
import torch
from gradio.themes import Soft
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
# NOTE: this is a time-limited presigned link (see its X-Amz-Expires/Expires
# parameters); it should be swapped for a stable asset URL before it lapses.
PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
PROTEIN_HERO = f"""
<img src="{PROTEIN_VISUAL_URL}" alt="Protein rendering" style="max-width: 100%;" />
""" DESCRIPTION = f"""\ ### Prot2Text-V2 Demo {PROTEIN_HERO} Prot2Text-V2 treats a protein sequence as if it were another language and translates it into English. Supply a raw amino acid sequence and the model returns a clear, human-readable paragraph describing what the protein does. The paper describing Prot2Text-V2 has been accepted to the NeurIPS 2025 main conference and pairs fast experimentation with explainability-minded outputs. - **Input**: protein sequence using IUPAC single-letter amino acid codes (20 canonical amino acids). - **Output**: polished descriptions of predicted function, localization cues, and structural hints. - **Why it matters**: accelerate protein characterization, lab annotations, or downstream hypothesis building. **Model architecture at a glance** - Protein language model encoder: facebook/esm2_t36_3B_UR50D. - Modality adapter: lightweight bridge aligning protein embeddings with the language model. - Natural language decoder: meta-llama/Llama-3.1-8B-Instruct for articulate descriptions. **Resources** - [Paper (NeurIPS 2025)](https://arxiv.org/abs/2505.11194) - [Code repository](https://github.com/ColinFX/Prot2Text-V2) - [Training data](https://huggingface.co/datasets/habdine/Prot2Text-Data) """ EXAMPLE_SEQUENCES = [ ["AEQAERYEEMVEFMEKL"], [ "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA" ], [ "MCYSANGNTFLIVDNTQKRIPEEKKPDFVRENVGDLDGVIFVELVDGKYFMDYYNRDGSMAAFCGNGARAFSQYLIDRGWIKEKEFTFLSRAGEIKVIVDDSIWVRMPGVSEKKEMKVDGYEGYFVVVGVPHFVMEVKGIDELDVEKLGRDLRYKTGANVDFYEVLPDRLKVRTYERGVERETKACGTGVTSVFVVYRDKTGAKEVKIQVPGGTLFLKEENGEIFLRGDVKRCSEE" ], [ "MTQEERFEQRIAQETAIEPQDWMPDAYRKTLIRQIGQHAHSEIVGMLPEGNWITRAPTLRRKAILLAKVQDEAGHGLYLYSAAETLGCAREDIYQKMLDGRMKYSSIFNYPTLSWADIGVIGWLVDGAAIVNQVALCRTSYGPYARAMVKICKEESFHQRQGFEACMALAQGSEAQKQMLQDAINRFWWPALMMFGPNDDNSPNSARSLTWKIKRFTNDELRQRFVDNTVPQVEMLGMTVPDPDLHFDTESGHYRFGEIDWQEFNEVINGRGICNQERLDAKRKAWEEGTWVREAALAHAQKQHARKVA" ], [ "MTTRMIILNGGSSAGKSGIVRCLQSVLPEPWLAFGVDSLIEAMPLKMQSAEGGIEFDADGGVSIGPEFRALEGAWAEGVVAMARAGARIIIDDVFLGGAAAQERWRSFVGDLDVLWVGVRCDGAVAEGRETARGDRVAGMAAKQAYVVHEGVEYDVEVDTTHKESIECAWAIAAHVVP" ], ] MAX_MAX_NEW_TOKENS = 256 DEFAULT_MAX_NEW_TOKENS = 100 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") system_message = ( "You are a scientific assistant specialized in protein function " "predictions. Given the sequence embeddings and other information " "of a protein, describe its function clearly and concisely in " "professional language. 
" ) placeholder = '<|reserved_special_token_1|>' esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D") llama_tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct", pad_token='<|reserved_special_token_0|>' ) model = AutoModelForCausalLM.from_pretrained('xiao-fei/Prot2Text-V2-11B-Instruct-hf', trust_remote_code=True, torch_dtype=torch.bfloat16,).to(device) model.eval() @spaces.GPU(duration=90) def stream_response( message: str, max_new_tokens: int = 1024, do_sample: bool = False, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2, ) -> Iterator[str]: streamer = TextIteratorStreamer(llama_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) user_message = "Sequence embeddings: " + placeholder * (len(message)+2) tokenized_prompt = llama_tokenizer.apply_chat_template( [ {"role": "system", "content": system_message}, {"role": "user", "content": user_message} ], add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True ) tokenized_sequence = esm_tokenizer( message, return_tensors="pt" ) model.eval() generate_kwargs = dict( inputs=tokenized_prompt["input_ids"].to(model.device), attention_mask=tokenized_prompt["attention_mask"].to(model.device), protein_input_ids=tokenized_sequence["input_ids"].to(model.device), protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device), eos_token_id=128009, pad_token_id=128002, return_dict_in_generate=False, num_beams=1, # device=device, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=do_sample, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) ChatHistory = List[Tuple[str, str]] def handle_submit( message: str, history: ChatHistory, max_new_tokens: int, do_sample: bool, temperature: float, top_p: float, top_k: int, repetition_penalty: float, ): history = list(history or []) message = message.strip() if not message: return conversation = history.copy() conversation.append((message, "")) for partial_response in stream_response( message=message, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, ): conversation[-1] = (message, partial_response) snapshot = conversation.copy() yield snapshot, snapshot, gr.update(value="") def clear_conversation(): empty_history: ChatHistory = [] return empty_history, empty_history, gr.update(value="") theme = Soft( primary_hue="slate", secondary_hue="stone", neutral_hue="gray", ) with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo: with gr.Row(equal_height=True): with gr.Column(scale=5, min_width=320): gr.HTML( f"""
""" ) gr.Markdown(DESCRIPTION) with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"): history_state = gr.State([]) chatbot = gr.Chatbot( label="Generated Function", height=350, show_copy_button=True, ) with gr.Group(elem_classes="input-card"): sequence_input = gr.Textbox( placeholder="Paste your amino acid sequence here (e.g. MAVVLPAVVEELLSEMAAAVQESA...)", label="Protein sequence", lines=1, max_lines=1, autofocus=True, ) with gr.Row(elem_classes="button-row"): submit_button = gr.Button("Predict function", variant="primary", elem_classes="primary-btn") stop_button = gr.Button("Stop generation", variant="stop", elem_classes="stop-btn") gr.Examples( examples=EXAMPLE_SEQUENCES, inputs=sequence_input, label="Sample sequences", cache_examples=False, run_on_click=False, ) with gr.Accordion("Generation controls", open=False): max_new_tokens_slider = gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, ) do_sample_checkbox = gr.Checkbox(label="Enable sampling", value=False) temperature_slider = gr.Slider( label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6, ) top_p_slider = gr.Slider( label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9, ) top_k_slider = gr.Slider( label="Top-k", minimum=1, maximum=1000, step=1, value=50, ) repetition_penalty_slider = gr.Slider( label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0, ) enter_event = sequence_input.submit( handle_submit, inputs=[ sequence_input, history_state, max_new_tokens_slider, do_sample_checkbox, temperature_slider, top_p_slider, top_k_slider, repetition_penalty_slider, ], outputs=[chatbot, history_state, sequence_input], queue=True, ) submit_event = submit_button.click( handle_submit, inputs=[ sequence_input, history_state, max_new_tokens_slider, do_sample_checkbox, temperature_slider, top_p_slider, top_k_slider, repetition_penalty_slider, ], outputs=[chatbot, history_state, sequence_input], queue=True, ) stop_button.click( None, inputs=None, outputs=None, cancels=[submit_event, enter_event], ) with gr.Accordion("Model & usage notes", open=False): gr.Markdown( "- **Model stack**: Facebook ESM2 encoder + Llama 3.1 8B instruction-tuned decoder.\n" "- **Token budget**: the generator truncates after the configured `Max new tokens`.\n" "- **Attribution**: Outputs are predictions; validate experimentally before publication.\n" ) gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button") if __name__ == "__main__": demo.queue(max_size=20).launch()