"""Streamlit app: ReactionT5 forward (product) prediction from an uploaded CSV.

The page lets a user upload reactions (a required ``REACTANT`` column plus
optional ``REAGENT``/``SOLVENT``/``CATALYST`` columns), runs beam-search
generation with a pretrained ReactionT5 seq2seq model, and offers the ranked
predictions for preview and CSV download.
"""

import gc
import io
import os
import sys
import warnings
from types import SimpleNamespace
from typing import NoReturn

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Local imports: the task helpers live in ./task_forward next to this script,
# so that directory is appended to sys.path before importing from it.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
from generation_utils import (  # noqa: E402
    ReactionT5Dataset,
    decode_output,
    save_multiple_predictions,
)
from train import preprocess_df  # noqa: E402
from utils import seed_everything  # noqa: E402

warnings.filterwarnings("ignore")

# Column the uploaded CSV must provide (matches the formatting help below).
REQUIRED_COLUMN = "REACTANT"

# Demo file lookup order: relative to the working directory first (the
# original behavior), then relative to this script so the app also works when
# launched from a different directory.
DEMO_CSV_CANDIDATES = (
    "data/demo_reaction_data.csv",
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "data", "demo_reaction_data.csv"
    ),
)

# ------------------------------
# Page setup
# ------------------------------
st.set_page_config(
    page_title="ReactionT5 — Product Prediction",
    page_icon=None,
    layout="wide",
)

st.title("ReactionT5 — Product Prediction")
st.caption(
    "Predict reaction products from your inputs using a pretrained ReactionT5 model."
)

with st.expander("How to format your CSV", expanded=False):
    st.markdown(
        """
- Include a required `REACTANT` column.
- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_reaction_data.csv** and check its contents.
- Output contains predicted product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is most probable).
"""
    )


# ------------------------------
# Demo data download
# ------------------------------
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Parse raw uploaded CSV bytes into a DataFrame.

    Cached on the byte content so the preview and the inference pass do not
    re-parse the same upload on every Streamlit rerun. Decoding uses pandas'
    default encoding (UTF-8).
    """
    return pd.read_csv(io.BytesIO(file_bytes))


@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes() -> bytes:
    """Return the bundled demo reaction CSV, re-serialized as UTF-8 bytes.

    Each path in ``DEMO_CSV_CANDIDATES`` is tried in order.

    Raises:
        FileNotFoundError: if the demo file exists at none of the candidates.
    """
    for path in DEMO_CSV_CANDIDATES:
        if os.path.exists(path):
            demo_df = pd.read_csv(path)
            return demo_df.to_csv(index=False).encode("utf-8")
    raise FileNotFoundError(
        "demo_reaction_data.csv not found in: " + ", ".join(DEMO_CSV_CANDIDATES)
    )


# A missing demo file should not take the whole page down; degrade to a notice.
try:
    demo_csv_bytes = load_demo_csv_as_bytes()
except FileNotFoundError as demo_err:
    st.warning(f"Demo file unavailable: {demo_err}")
else:
    st.download_button(
        label="Download demo_reaction_data.csv",
        data=demo_csv_bytes,
        file_name="demo_reaction_data.csv",
        mime="text/csv",
        use_container_width=True,
    )

st.divider()

# ------------------------------
# Sidebar: configuration
# ------------------------------
with st.sidebar:
    st.header("Configuration")
    model_name_or_path = st.text_input(
        "Model",
        value="sagawa/ReactionT5v2-forward",
        help="Hugging Face model repo or a local path.",
    )
    num_beams = st.slider(
        "Beam size",
        min_value=1,
        max_value=10,
        value=5,
        step=1,
        help="Number of beams for beam search.",
    )
    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )
    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length", min_value=8, max_value=1024, value=400, step=8
        )
        output_max_length = st.number_input(
            "Output max length", min_value=8, max_value=1024, value=300, step=8
        )
        output_min_length = st.number_input(
            "Output min length",
            min_value=-1,
            max_value=1024,
            value=-1,
            step=1,
            help="Use -1 to let the model decide.",
        )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")


# ------------------------------
# Cached loaders
# ------------------------------
def _resolve_model_ref(model_ref: str) -> str:
    """Return an absolute path if *model_ref* exists locally, else it unchanged (Hub id)."""
    return os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref


@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load (once per *model_ref*) the tokenizer for a local path or Hub repo."""
    return AutoTokenizer.from_pretrained(
        _resolve_model_ref(model_ref), return_tensors="pt"
    )


@st.cache_resource(show_spinner=True)
def load_model(model_ref: str, device_str: str):
    """Load (once per model/device pair) the seq2seq model in eval mode on *device_str*."""
    model = AutoModelForSeq2SeqLM.from_pretrained(_resolve_model_ref(model_ref))
    model.to(torch.device(device_str))
    model.eval()
    return model


@st.cache_data(show_spinner=False)
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialize *df* to UTF-8 CSV bytes without the index, for download."""
    return df.to_csv(index=False).encode("utf-8")


def _fail(message: str) -> NoReturn:
    """Record *message* as the last error, show it immediately, and halt this run.

    Showing it here matters: ``st.stop()`` prevents the results section at the
    bottom of the script from rendering during this run.
    """
    st.session_state["last_error"] = message
    st.error(message)
    st.stop()


# ------------------------------
# Main interaction
# ------------------------------
left, right = st.columns([1.4, 1.0], vertical_alignment="top")

with left:
    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help="Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
        )
        run = st.form_submit_button("Predict", use_container_width=True)

    if uploaded is not None:
        try:
            raw_df = parse_csv_from_bytes(uploaded.getvalue())
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")

with right:
    st.subheader("Notes")
    st.markdown(
        f"""
- Beam size: **{num_beams}**
- Approximate time: about **15 seconds per reaction** when `beam size = 5` (varies by hardware).
- Results include the **sum of log-likelihoods** per prediction and are **sorted** by that value.
"""
    )
    st.info(
        "If you encounter CUDA OOM issues, reduce max lengths or beam size, or switch to CPU."
    )

# ------------------------------
# Inference
# ------------------------------
if "results_df" not in st.session_state:
    st.session_state["results_df"] = None
if "last_error" not in st.session_state:
    st.session_state["last_error"] = None

if run:
    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    elif int(output_min_length) > int(output_max_length):
        st.warning("Output min length cannot exceed output max length.")
    else:
        # Start each run from a clean slate so an earlier run's error or results
        # are not displayed alongside (or instead of) this run's outcome.
        st.session_state["last_error"] = None
        st.session_state["results_df"] = None

        # Config object consumed by the task_forward dataset/generation helpers.
        CFG = SimpleNamespace(
            num_beams=int(num_beams),
            num_return_sequences=int(num_beams),  # tie to beams by default
            model_name_or_path=model_name_or_path,
            input_column="input",
            input_max_length=int(input_max_length),
            output_max_length=int(output_max_length),
            output_min_length=int(output_min_length),
            model="t5",
            seed=int(seed),
            batch_size=int(batch_size),
        )
        seed_everything(seed=CFG.seed)

        # Prepare and validate data before paying for a model load.
        try:
            input_df = parse_csv_from_bytes(uploaded.getvalue())
        except Exception as e:
            _fail(f"Failed to read CSV: {e}")
        if REQUIRED_COLUMN not in input_df.columns:
            _fail(f"The uploaded CSV must contain a `{REQUIRED_COLUMN}` column.")
        if input_df.empty:
            _fail("The uploaded CSV contains no rows.")
        try:
            input_df = preprocess_df(input_df, drop_duplicates=False)
        except Exception as e:
            _fail(f"Failed to preprocess input: {e}")

        # Load model & tokenizer. The error is reported after the status block
        # closes so it is not hidden inside the collapsed status container.
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                model = load_model(CFG.model_name_or_path, device.type)
                status.update(label="Model ready.", state="complete")
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
        if load_error is not None:
            _fail(load_error)

        # The tokenizer was previously loaded but never handed to anything.
        # NOTE(review): presumably ReactionT5Dataset / decode_output read it as
        # `cfg.tokenizer` — confirm in task_forward/generation_utils.py.
        CFG.tokenizer = tokenizer

        # Dataset & loader
        dataset = ReactionT5Dataset(CFG, input_df)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )

        # Generation loop with progress
        all_sequences, all_scores = [], []
        total = len(dataloader)
        progress = st.progress(0, text="Generating predictions...")
        info_placeholder = st.empty()
        generation_error = None
        try:
            for i, inputs in enumerate(dataloader, start=1):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        min_length=CFG.output_min_length,
                        max_length=CFG.output_max_length,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output, CFG)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)

                # Release per-batch generation buffers before the next batch.
                del output
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()

                progress.progress(
                    i / total, text=f"Generating predictions... {i}/{total}"
                )
                info_placeholder.caption(f"Processed batch {i} of {total}")
        except Exception as e:  # e.g. CUDA OOM; surface it instead of a traceback
            generation_error = f"Generation failed: {e}"
            if device.type == "cuda":
                torch.cuda.empty_cache()

        progress.empty()
        info_placeholder.empty()
        if generation_error is not None:
            _fail(generation_error)

        # Save predictions
        try:
            output_df = save_multiple_predictions(
                input_df, all_sequences, all_scores, CFG
            )
        except Exception as e:
            _fail(f"Failed to assemble output: {e}")
        st.session_state["results_df"] = output_df
        st.success("Prediction complete.")

# ------------------------------
# Results
# ------------------------------
if st.session_state.get("results_df") is not None:
    st.subheader("Results preview")
    st.dataframe(st.session_state["results_df"].head(50), use_container_width=True)
    st.download_button(
        label="Download predictions as CSV",
        data=df_to_csv_bytes(st.session_state["results_df"]),
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )

if st.session_state.get("last_error"):
    st.error(st.session_state["last_error"])