from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
from gradio.themes import Soft
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
PROTEIN_HERO = f"""
"""
DESCRIPTION = f"""\
### Prot2Text-V2 Demo
{PROTEIN_HERO}
Prot2Text-V2 treats a protein sequence as if it were another language and translates it into English. Supply a raw amino acid sequence and the model returns a clear, human-readable paragraph describing what the protein does.
The paper describing Prot2Text-V2 has been accepted to the NeurIPS 2025 main conference; this demo lets you experiment with the model directly in the browser.
- **Input**: protein sequence using IUPAC single-letter amino acid codes (20 canonical amino acids).
- **Output**: polished descriptions of predicted function, localization cues, and structural hints.
- **Why it matters**: accelerate protein characterization, lab annotations, or downstream hypothesis building.
**Model architecture at a glance**
- Protein language model encoder: facebook/esm2_t36_3B_UR50D.
- Modality adapter: lightweight bridge aligning protein embeddings with the language model.
- Natural language decoder: meta-llama/Llama-3.1-8B-Instruct for articulate descriptions.
**Resources**
- [Paper (NeurIPS 2025)](https://arxiv.org/abs/2505.11194)
- [Code repository](https://github.com/ColinFX/Prot2Text-V2)
- [Training data](https://huggingface.co/datasets/habdine/Prot2Text-Data)
"""
EXAMPLE_SEQUENCES = [
["AEQAERYEEMVEFMEKL"],
[
"MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
],
[
"MCYSANGNTFLIVDNTQKRIPEEKKPDFVRENVGDLDGVIFVELVDGKYFMDYYNRDGSMAAFCGNGARAFSQYLIDRGWIKEKEFTFLSRAGEIKVIVDDSIWVRMPGVSEKKEMKVDGYEGYFVVVGVPHFVMEVKGIDELDVEKLGRDLRYKTGANVDFYEVLPDRLKVRTYERGVERETKACGTGVTSVFVVYRDKTGAKEVKIQVPGGTLFLKEENGEIFLRGDVKRCSEE"
],
[
"MTQEERFEQRIAQETAIEPQDWMPDAYRKTLIRQIGQHAHSEIVGMLPEGNWITRAPTLRRKAILLAKVQDEAGHGLYLYSAAETLGCAREDIYQKMLDGRMKYSSIFNYPTLSWADIGVIGWLVDGAAIVNQVALCRTSYGPYARAMVKICKEESFHQRQGFEACMALAQGSEAQKQMLQDAINRFWWPALMMFGPNDDNSPNSARSLTWKIKRFTNDELRQRFVDNTVPQVEMLGMTVPDPDLHFDTESGHYRFGEIDWQEFNEVINGRGICNQERLDAKRKAWEEGTWVREAALAHAQKQHARKVA"
],
[
"MTTRMIILNGGSSAGKSGIVRCLQSVLPEPWLAFGVDSLIEAMPLKMQSAEGGIEFDADGGVSIGPEFRALEGAWAEGVVAMARAGARIIIDDVFLGGAAAQERWRSFVGDLDVLWVGVRCDGAVAEGRETARGDRVAGMAAKQAYVVHEGVEYDVEVDTTHKESIECAWAIAAHVVP"
],
]
MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
system_message = (
"You are a scientific assistant specialized in protein function "
"predictions. Given the sequence embeddings and other information "
"of a protein, describe its function clearly and concisely in "
"professional language. "
)
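# Reserved Llama 3.1 special token used as a per-residue placeholder; the
# model is expected to swap these positions for projected ESM-2 embeddings.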
placeholder = '<|reserved_special_token_1|>'
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
llama_tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
pad_token='<|reserved_special_token_0|>'
)
model = AutoModelForCausalLM.from_pretrained(
    "xiao-fei/Prot2Text-V2-11B-Instruct-hf",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
model.eval()
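# On ZeroGPU Spaces, @spaces.GPU requests a GPU for each call, with `duration`
# bounding the allocation window in seconds.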
@spaces.GPU(duration=90)
def stream_response(
message: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,  # keep the default within the UI cap
do_sample: bool = False,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
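    """Stream a generated description of the protein encoded by `message`.

    `message` is the raw amino acid sequence; the remaining keyword arguments
    are standard Hugging Face generation controls forwarded to `model.generate`.
    """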
streamer = TextIteratorStreamer(llama_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
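    # One placeholder per residue, plus two for the CLS/EOS tokens the ESM
    # tokenizer adds around the sequence.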
user_message = "Sequence embeddings: " + placeholder * (len(message)+2)
tokenized_prompt = llama_tokenizer.apply_chat_template(
[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True
)
tokenized_sequence = esm_tokenizer(
message,
return_tensors="pt"
)
generate_kwargs = dict(
inputs=tokenized_prompt["input_ids"].to(model.device),
attention_mask=tokenized_prompt["attention_mask"].to(model.device),
protein_input_ids=tokenized_sequence["input_ids"].to(model.device),
protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device),
        eos_token_id=128009,  # <|eot_id|>, the Llama 3.1 end-of-turn token
        pad_token_id=128002,  # <|reserved_special_token_0|>, set as pad above
        return_dict_in_generate=False,
        num_beams=1,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
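    # Run generation in a background thread so the streamer below can yield
    # partial text to the UI as tokens arrive.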
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
ChatHistory = List[Tuple[str, str]]
def handle_submit(
message: str,
history: ChatHistory,
max_new_tokens: int,
do_sample: bool,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
):
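    """Append the submitted sequence to the chat history and stream the model's
    response into the last turn, yielding UI snapshots as text arrives.
    """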
history = list(history or [])
message = message.strip()
if not message:
return
conversation = history.copy()
conversation.append((message, ""))
for partial_response in stream_response(
message=message,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
):
conversation[-1] = (message, partial_response)
snapshot = conversation.copy()
yield snapshot, snapshot, gr.update(value="")
def clear_conversation():
empty_history: ChatHistory = []
return empty_history, empty_history, gr.update(value="")
theme = Soft(
primary_hue="slate",
secondary_hue="stone",
neutral_hue="gray",
)
with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
with gr.Row(equal_height=True):
with gr.Column(scale=5, min_width=320):
            # The HTML block was empty in the source; the otherwise unused
            # TEAM_LOGO_URL above suggests it showed the team logo (assumption).
            gr.HTML(
                f"""
                <img src="{TEAM_LOGO_URL}" alt="DaSciM team logo"/>
                """
            )
gr.Markdown(DESCRIPTION)
with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
history_state = gr.State([])
chatbot = gr.Chatbot(
label="Generated Function",
height=350,
show_copy_button=True,
)
with gr.Group(elem_classes="input-card"):
sequence_input = gr.Textbox(
placeholder="Paste your amino acid sequence here (e.g. MAVVLPAVVEELLSEMAAAVQESA...)",
label="Protein sequence",
lines=1,
max_lines=1,
autofocus=True,
)
with gr.Row(elem_classes="button-row"):
submit_button = gr.Button("Predict function", variant="primary", elem_classes="primary-btn")
stop_button = gr.Button("Stop generation", variant="stop", elem_classes="stop-btn")
gr.Examples(
examples=EXAMPLE_SEQUENCES,
inputs=sequence_input,
label="Sample sequences",
cache_examples=False,
run_on_click=False,
)
with gr.Accordion("Generation controls", open=False):
max_new_tokens_slider = gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
)
do_sample_checkbox = gr.Checkbox(label="Enable sampling", value=False)
temperature_slider = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
)
top_p_slider = gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
)
top_k_slider = gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
)
repetition_penalty_slider = gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0,
)
enter_event = sequence_input.submit(
handle_submit,
inputs=[
sequence_input,
history_state,
max_new_tokens_slider,
do_sample_checkbox,
temperature_slider,
top_p_slider,
top_k_slider,
repetition_penalty_slider,
],
outputs=[chatbot, history_state, sequence_input],
queue=True,
)
submit_event = submit_button.click(
handle_submit,
inputs=[
sequence_input,
history_state,
max_new_tokens_slider,
do_sample_checkbox,
temperature_slider,
top_p_slider,
top_k_slider,
repetition_penalty_slider,
],
outputs=[chatbot, history_state, sequence_input],
queue=True,
)
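    # The stop button has no handler of its own; it only cancels any in-flight
    # generation triggered by Enter or the submit button.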
stop_button.click(
None,
inputs=None,
outputs=None,
cancels=[submit_event, enter_event],
)
with gr.Accordion("Model & usage notes", open=False):
gr.Markdown(
"- **Model stack**: Facebook ESM2 encoder + Llama 3.1 8B instruction-tuned decoder.\n"
"- **Token budget**: the generator truncates after the configured `Max new tokens`.\n"
"- **Attribution**: Outputs are predictions; validate experimentally before publication.\n"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
if __name__ == "__main__":
demo.queue(max_size=20).launch()