Spaces:
Running
Running
| import gc | |
| import os | |
| import warnings | |
| from types import SimpleNamespace | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| # Local imports | |
| from generation_utils import ( | |
| ReactionT5Dataset, | |
| decode_output, | |
| save_multiple_predictions, | |
| ) | |
| from models import ReactionT5Yield2 | |
| from torch.utils.data import DataLoader | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from utils import seed_everything | |
# Suppress every Python warning raised after this point.
# NOTE(review): this also hides deprecation warnings from the libraries used
# below; consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")
# ------------------------------
# Page setup
# ------------------------------
# Browser tab title and a wide page layout.
st.set_page_config(
    page_title="ReactionT5",
    page_icon=None,
    layout="wide",
)
st.title("ReactionT5")
st.caption(
    "Predict reaction products, reactants, or yields from your inputs using a pretrained ReactionT5 model."
)
# Collapsed help panel describing the expected CSV input.
with st.expander("How to format your CSV", expanded=False):
    st.markdown(
        """
- Include a required `REACTANT` column.
- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_reaction_data.csv** and check its contents.
- Output contains predicted product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is most probable).
"""
    )
# ------------------------------
# CSV helpers / demo data download
# ------------------------------
import io


def parse_csv_from_bytes(file_bytes: bytes, encoding: str = "utf-8") -> pd.DataFrame:
    """Parse raw CSV bytes (e.g. an uploaded file's contents) into a DataFrame.

    Args:
        file_bytes: The complete CSV file contents.
        encoding: Text encoding of ``file_bytes``. The default (UTF-8) matches
            the previous behaviour; pass e.g. ``"latin-1"`` for files that were
            not exported as UTF-8.

    Returns:
        The parsed table.

    Raises:
        pandas.errors.EmptyDataError: if ``file_bytes`` contains no data.
        UnicodeDecodeError: if the bytes are not valid in ``encoding``.
    """
    # Wrap the bytes in an in-memory file object. The caller keeps the raw
    # bytes, so the same upload can be parsed again later.
    return pd.read_csv(io.BytesIO(file_bytes), encoding=encoding)
def load_demo_csv_as_bytes(path: str = "data/demo_reaction_data.csv") -> bytes:
    """Return the demo CSV's contents, byte-for-byte, for the download button.

    The file is read verbatim instead of being parsed with pandas and
    re-serialised. That round trip could re-format values (e.g. numeric-looking
    strings such as ``007``) and change line endings, so the downloaded file
    might differ from the shipped demo. It was also needlessly expensive for
    code executed on every script run.

    Args:
        path: Location of the demo CSV. It is relative to the current working
            directory by default, as before.

    Returns:
        The raw file contents.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    with open(path, "rb") as fh:
        return fh.read()
# Offer the bundled demo CSV so users can see the expected input format.
_demo_csv_bytes = load_demo_csv_as_bytes()
st.download_button(
    "Download demo_reaction_data.csv",
    _demo_csv_bytes,
    file_name="demo_reaction_data.csv",
    mime="text/csv",
    use_container_width=True,
)
st.divider()
# ------------------------------
# Sidebar: configuration
# ------------------------------
# Renders the sidebar and defines the module-level settings (task, model,
# lengths, seed, batching, device) read by the inference section.
with st.sidebar:
    st.header("Configuration")
    # The selected task determines the model list, the length defaults and
    # which task-specific helpers are imported below.
    task = st.selectbox(
        "Task",
        options=["product prediction", "retrosynthesis prediction", "yield prediction"],
        index=0,
        help="Choose the task to run.",
    )
    # Model options tied to task
    if task == "product prediction":
        model_options = [
            "sagawa/ReactionT5v2-forward",
            "sagawa/ReactionT5v2-forward-USPTO_MIT",
        ]
        model_help = "Recommended models for product prediction."
        input_max_length_default = 400
        output_max_length_default = 300
        # Task-specific preprocessing applied to the uploaded CSV later on.
        from task_forward.train import preprocess_df
    elif task == "retrosynthesis prediction":
        model_options = [
            "sagawa/ReactionT5v2-retrosynthesis",
            "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
        ]
        model_help = "Recommended models for retrosynthesis prediction."
        input_max_length_default = 100
        output_max_length_default = 400
        from task_retrosynthesis.train import preprocess_df
    else:  # yield prediction
        model_options = ["sagawa/ReactionT5v2-yield"]  # default as requested
        model_help = "Default model for yield prediction."
        input_max_length_default = 400
        # No output-length default here: yield prediction does not call
        # `model.generate`; the inference section uses `inference_fn` instead.
        from task_yield.train import preprocess_df
        from task_yield.prediction import inference_fn
    model_name_or_path = st.selectbox(
        "Model",
        options=model_options,
        index=0,
        help=model_help,
    )
    # NOTE(review): `num_beams` (and `output_max_length` / `output_min_length`
    # below) are only defined for the generative tasks. For yield prediction
    # these names do not exist, so every later use must be guarded by `task`.
    if task != "yield prediction":
        num_beams = st.slider(
            "Beam size",
            min_value=1,
            max_value=10,
            value=5,
            step=1,
            help="Number of beams for beam search.",
        )
    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )
    # Further settings, collapsed by default.
    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length",
            min_value=8,
            max_value=1024,
            value=input_max_length_default,
            step=8,
        )
        if task != "yield prediction":
            output_max_length = st.number_input(
                "Output max length",
                min_value=8,
                max_value=1024,
                value=output_max_length_default,
                step=8,
            )
            output_min_length = st.number_input(
                "Output min length",
                min_value=-1,
                max_value=1024,
                value=-1,
                step=1,
                help="Use -1 to let the model decide.",
            )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )
    # Use CUDA when PyTorch reports it as available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")
# ------------------------------
# Cached loaders
# ------------------------------
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load the tokenizer for ``model_ref``, cached across Streamlit reruns.

    Streamlit re-executes this script on every interaction, and this loader is
    called on every "Predict" click. Without ``st.cache_resource`` (despite the
    section name) the tokenizer was rebuilt from disk/Hub each time.

    Args:
        model_ref: A local path or a Hugging Face Hub model id. Existing local
            paths are converted to absolute paths; anything else is passed
            through unchanged.

    Returns:
        The tokenizer produced by ``AutoTokenizer.from_pretrained``.
    """
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    # NOTE(review): `return_tensors` is not a documented `from_pretrained`
    # argument; it is presumably forwarded to the tokenizer constructor.
    # Kept for compatibility — confirm whether it is needed.
    return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(model_ref: str, device_str: str, task: str):
    """Load the model for ``task`` onto ``device_str`` in eval mode, cached across reruns.

    This loader runs on every "Predict" click. Without caching, the weights
    were re-read and re-moved to the device each time. ``max_entries=1`` keeps
    only the most recently used (model, device, task) combination, so switching
    models in the sidebar does not accumulate several models in memory.

    NOTE(review): ``st.cache_resource`` objects are shared by all sessions of
    the app. The model is only used for inference here; callers must not
    mutate it.

    Args:
        model_ref: Local path or Hugging Face Hub id of the checkpoint.
            Existing local paths are converted to absolute paths.
        device_str: torch device type, e.g. ``"cuda"`` or ``"cpu"``.
        task: Task label from the sidebar. ``"yield prediction"`` loads
            ``ReactionT5Yield2``; every other task loads a seq2seq LM.

    Returns:
        The model, moved to ``device_str`` and switched to eval mode.
    """
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    if task != "yield prediction":
        model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
    else:
        model = ReactionT5Yield2.from_pretrained(resolved)
    model.to(torch.device(device_str))
    # Inference only: disable dropout and other training-time behaviour.
    model.eval()
    return model
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise ``df`` to UTF-8 encoded CSV bytes, omitting the index column."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, encoding="utf-8")
# ------------------------------
# Main interaction
# ------------------------------
# Two-column layout: upload form and input preview on the left (wider),
# usage notes on the right.
left, right = st.columns([1.4, 1.0], vertical_alignment="top")
with left:
    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help="Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
        )
        # `run` is True only on the script run triggered by pressing "Predict".
        run = st.form_submit_button("Predict", use_container_width=True)
    # Preview the first 20 rows whenever a file has been uploaded.
    if uploaded is not None:
        try:
            file_bytes = uploaded.getvalue()
            raw_df = parse_csv_from_bytes(file_bytes)
            # raw_df = pd.read_csv(uploaded)
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)
        except Exception as e:
            # Report parse failures in the UI instead of crashing the page.
            st.error(f"Failed to read CSV: {e}")
with right:
    st.subheader("Notes")
    # The sidebar only creates the beam-size slider (`num_beams`) for the
    # generative tasks. Reading it unconditionally raised NameError and broke
    # the page whenever "yield prediction" was selected.
    if task != "yield prediction":
        st.markdown(
            f"""
- Beam size: **{num_beams}**
- Approximate time: about **15 seconds per reaction** when `beam size = 5` (varies by hardware).
- Results include the **sum of log-likelihoods** per prediction and are **sorted** by that value.
"""
        )
    else:
        st.markdown(
            """
- Yield prediction does not use beam search.
- Predicted yields are clipped to the range **0–100**.
"""
        )
    st.info(
        "If you encounter CUDA OOM issues, reduce max lengths or beam size, or switch to CPU."
    )
# ------------------------------
# Inference
# ------------------------------
# Make sure both session-state slots exist before anything reads them; values
# stored in st.session_state persist across reruns of this script.
for _state_key in ("results_df", "last_error"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
del _state_key
def _abort_run(message: str) -> None:
    """Record ``message`` as the last error, display it, and end this script run.

    ``st.stop()`` halts the script immediately, so the error panel at the
    bottom of the page would not render on this run. The message is therefore
    shown here as well.
    """
    st.session_state["last_error"] = message
    st.error(message)
    st.stop()


if run:
    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    else:
        # Start every run from a clean slate. Previously an error from an
        # earlier run stayed on screen after a later success, and stale results
        # stayed downloadable after a later failure.
        st.session_state["results_df"] = None
        st.session_state["last_error"] = None

        is_yield = task == "yield prediction"
        # Build the config object expected by the dataset/utils. Generation-only
        # settings are None for yield prediction, because the sidebar does not
        # create those widgets for that task. The conditional expressions keep
        # the undefined names from being evaluated.
        CFG = SimpleNamespace(
            task=task,
            num_beams=None if is_yield else int(num_beams),
            # Return one candidate per beam by default.
            num_return_sequences=None if is_yield else int(num_beams),
            model_name_or_path=model_name_or_path,
            input_column="input",
            # "Input max length" is shown for every task (default 400 for
            # yield), so honour it for yield too instead of discarding it.
            input_max_length=int(input_max_length),
            output_max_length=None if is_yield else int(output_max_length),
            output_min_length=None if is_yield else int(output_min_length),
            seed=int(seed),
            batch_size=int(batch_size),
        )
        seed_everything(seed=CFG.seed)

        # Load model & tokenizer. A failure is handled after the status
        # container closes, so the message is not hidden in its collapsed body.
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                CFG.tokenizer = tokenizer
                model = load_model(CFG.model_name_or_path, device.type, task)
                status.update(label="Model ready.", state="complete")
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
        if load_error is not None:
            _abort_run(load_error)

        # Prepare data. Malformed input (e.g. missing required columns) is
        # reported in the UI rather than as a raw traceback.
        try:
            input_df = parse_csv_from_bytes(uploaded.getvalue())
            input_df = preprocess_df(input_df, drop_duplicates=False)
        except Exception as e:
            _abort_run(f"Failed to prepare input: {e}")
        if input_df.empty:
            _abort_run("The uploaded CSV contains no reactions.")

        # Dataset & loader
        dataset = ReactionT5Dataset(CFG, input_df)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )

        if is_yield:
            # Yield prediction uses the task's own inference function instead
            # of beam-search generation.
            try:
                prediction = inference_fn(dataloader, model, CFG)
            except Exception as e:
                _abort_run(f"Yield prediction failed: {e}")
            output_df = input_df.copy()
            output_df["prediction"] = prediction
            # Clamp predictions to the [0, 100] range.
            output_df["prediction"] = output_df["prediction"].clip(lower=0.0, upper=100.0)
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
        else:
            # Generation loop with progress
            all_sequences, all_scores = [], []
            total = len(dataloader)
            progress = st.progress(0, text="Generating predictions...")
            info_placeholder = st.empty()
            generation_error = None
            try:
                for i, inputs in enumerate(dataloader, start=1):
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    with torch.no_grad():
                        output = model.generate(
                            **inputs,
                            min_length=CFG.output_min_length,
                            max_length=CFG.output_max_length,
                            num_beams=CFG.num_beams,
                            num_return_sequences=CFG.num_return_sequences,
                            return_dict_in_generate=True,
                            output_scores=True,
                        )
                    sequences, scores = decode_output(output, CFG)
                    all_sequences.extend(sequences)
                    if scores:
                        all_scores.extend(scores)
                    # Release per-batch outputs eagerly to keep peak memory low.
                    del output
                    if device.type == "cuda":
                        torch.cuda.empty_cache()
                    gc.collect()
                    progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
                    info_placeholder.caption(f"Processed batch {i} of {total}")
            except Exception as e:
                # e.g. CUDA out-of-memory during generate().
                generation_error = f"Generation failed: {e}"
            finally:
                progress.empty()
                info_placeholder.empty()
            if generation_error is not None:
                _abort_run(generation_error)

            # Save predictions
            try:
                output_df = save_multiple_predictions(
                    input_df, all_sequences, all_scores, CFG
                )
            except Exception as e:
                _abort_run(f"Failed to assemble output: {e}")
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
# ------------------------------
# Results
# ------------------------------
# Render the results and error kept in session state, which persist across
# reruns of this script.
_results = st.session_state.get("results_df")
if _results is not None:
    st.subheader("Results preview")
    st.dataframe(_results.head(50), use_container_width=True)
    st.download_button(
        "Download predictions as CSV",
        df_to_csv_bytes(_results),
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )
_last_error = st.session_state.get("last_error")
if _last_error:
    st.error(_last_error)