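"""
Evaluation script for ReactionT5 retrosynthesis predictions: compares the
beam-search candidate columns ("0th" ... "4th") against the ground-truth
targets and reports top-1/2/3/5 accuracy plus the share of invalid SMILES.
"""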
import argparse
import os
import sys
import warnings

import pandas as pd
import rdkit
from rdkit import Chem
from transformers import AutoTokenizer

rdkit.RDLogger.DisableLog("rdApp.*")
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import canonicalize, seed_everything

warnings.filterwarnings("ignore")
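# `canonicalize` and `seed_everything` come from the repository's local `utils`
# module, which is not shown here. `canonicalize` is assumed to return the RDKit
# canonical SMILES of its argument (and a sentinel such as None for unparsable
# strings); a minimal sketch under that assumption:
#
#     def canonicalize(smiles):
#         mol = Chem.MolFromSmiles(smiles)
#         return Chem.MolToSmiles(mol) if mol is not None else None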


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Score retrosynthesis predictions: top-k accuracy and SMILES invalidity."
    )
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the prediction CSV (expects an 'input' column and beam columns '0th' ... '4th').",
    )
    parser.add_argument(
        "--target_data",
        type=str,
        required=True,
        help="Path to the CSV containing the ground-truth targets.",
    )
    parser.add_argument(
        "--target_col",
        type=str,
        required=True,
        help="Name of the target column in the target data.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-retrosynthesis",
        help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.",
    )
    parser.add_argument(
        "--num_beams", type=int, default=5, help="Number of beams used for beam search."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for reproducibility."
    )
    return parser.parse_args()
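

# Example invocation (script and CSV names are illustrative, not part of this file):
#
#   python calculate_accuracy.py \
#       --input_data output/retrosynthesis_predictions.csv \
#       --target_data data/test.csv \
#       --target_col REACTANT \
#       --num_beams 5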


def remove_space(row):
    """Remove whitespace from the five beam-candidate columns ("0th" ... "4th")."""
    for i in range(5):
        row[f"{i}th"] = row[f"{i}th"].replace(" ", "")
    return row


if __name__ == "__main__":
    CFG = parse_args()
    seed_everything(seed=CFG.seed)

    # The tokenizer is loaded from the given local path or Hugging Face model id;
    # it is not used in the metric computation below.
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )

    # Load predictions, replace missing beam candidates (NaN) with blanks, attach
    # the ground-truth column, and strip whitespace from every candidate string.
    df = pd.read_csv(CFG.input_data)
    df[[f"{i}th" for i in range(CFG.num_beams)]] = df[
        [f"{i}th" for i in range(CFG.num_beams)]
    ].fillna(" ")
    df["target"] = pd.read_csv(CFG.target_data)[CFG.target_col].values
    df = df.apply(remove_space, axis=1)

    top_k_invalidity = CFG.num_beams
    top1, top2, top3, top5 = [], [], [], []
    invalidity = []
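
    # A row counts toward top-k accuracy if any of its first k beam candidates
    # canonicalizes to the same SMILES as the ground-truth target.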
    for idx, row in df.iterrows():
        target = canonicalize(row["target"])
        if canonicalize(row["0th"]) == target:
            top1.append(1)
            top2.append(1)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["1th"]) == target:
            top1.append(0)
            top2.append(1)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["2th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(1)
            top5.append(1)
        elif canonicalize(row["3th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(1)
        elif canonicalize(row["4th"]) == target:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(1)
        else:
            top1.append(0)
            top2.append(0)
            top3.append(0)
            top5.append(0)
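
        # Invalidity: count how many of the row's beam candidates RDKit fails to
        # parse as SMILES (trailing "." characters are stripped before parsing).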
        input_compound = row["input"]
        output = [row[f"{i}th"] for i in range(top_k_invalidity)]
        inval_score = 0
        for out in output:
            mol = Chem.MolFromSmiles(out.rstrip("."))
            if not isinstance(mol, Chem.rdchem.Mol):
                inval_score += 1
        invalidity.append(inval_score)

    print(CFG.input_data)
    print(f"Top 1 accuracy: {sum(top1) / len(top1)}")
    print(f"Top 2 accuracy: {sum(top2) / len(top2)}")
    print(f"Top 3 accuracy: {sum(top3) / len(top3)}")
    print(f"Top 5 accuracy: {sum(top5) / len(top5)}")
    print(
        f"Top {top_k_invalidity} Invalidity: {sum(invalidity) / (len(invalidity) * top_k_invalidity) * 100}"
    )