Spaces:
Running
Running
| import argparse | |
| import glob | |
| import logging | |
| import os | |
| import sys | |
| import warnings | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from datasets.utils.logging import disable_progress_bar | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
| # Suppress warnings and logging | |
| warnings.filterwarnings("ignore") | |
| logging.disable(logging.WARNING) | |
| disable_progress_bar() | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Append the utils module path | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| from finetune import download_pretrained_model | |
| from generation_utils import ReactionT5Dataset | |
| from models import ReactionT5Yield | |
| from train import preprocess_df | |
| from utils import seed_everything | |
def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv (list[str] | None): Argument strings to parse. When None (the
            default) argparse reads them from ``sys.argv[1:]``, so existing
            zero-argument calls behave exactly as before. Passing a list
            makes the parser usable from other code and from tests.

    Returns:
        argparse.Namespace: Parsed options. ``input_data`` is the only
        required one; ``model_name_or_path`` is None when not given.
    """
    parser = argparse.ArgumentParser(
        description="Prediction script for ReactionT5Yield model."
    )
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Data as a CSV file that contains an 'input' column. The format of the contents of the column are like 'REACTANT:{reactants of the reaction}PRODUCT:{products of the reaction}'. If there are multiple reactants, concatenate them with '.'.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.",
    )
    parser.add_argument(
        "--download_pretrained_model",
        action="store_true",
        help="Download finetuned model from hugging face hub and use it for prediction.",
    )
    parser.add_argument("--debug", action="store_true", help="Use debug mode.")
    parser.add_argument(
        "--input_max_length",
        type=int,
        default=300,
        help="Maximum token length of input.",
    )
    # An option with a default is optional already; no explicit
    # required=False is needed.
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size.")
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loading workers."
    )
    parser.add_argument(
        "--fc_dropout",
        type=float,
        default=0.0,
        help="Dropout rate after fully connected layers.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where predictions are saved.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    return parser.parse_args(argv)
def inference_fn(test_loader, model, cfg):
    """
    Run the model over every batch of a loader and collect its outputs.

    Args:
        test_loader (DataLoader): Loader for the test data. Each batch must
            be a mapping whose values have a ``.to(device)`` method.
        model (nn.Module): Model for inference; it is called with the whole
            batch mapping as its single argument.
        cfg (argparse.Namespace): Configuration object. Only ``cfg.device``
            is read here.

    Returns:
        np.ndarray: Per-batch outputs concatenated along the first axis, in
        loader order. An empty loader yields an empty array instead of the
        ValueError that ``np.concatenate`` raises for an empty list.
    """
    model.eval()
    model.to(cfg.device)
    preds = []
    # One no_grad context for the whole pass instead of re-entering it for
    # every batch; no autograd graph is needed for prediction.
    with torch.no_grad():
        for inputs in tqdm(test_loader, total=len(test_loader)):
            inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
            y_preds = model(inputs)
            preds.append(y_preds.cpu().numpy())
    if not preds:
        # Nothing was predicted (e.g. an input file with no rows).
        # NOTE(review): float32 is torch's default dtype; the model's real
        # output dtype and trailing shape cannot be known without a batch.
        return np.empty((0,), dtype=np.float32)
    return np.concatenate(preds)
if __name__ == "__main__":
    # Script entry point: load the finetuned model, predict on the rows of
    # the input CSV and write them to <output_dir>/yield_prediction_output.csv.
    CFG = parse_args()
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CFG.device = device

    # exist_ok replaces the separate exists() check and its
    # check-then-create race.
    os.makedirs(CFG.output_dir, exist_ok=True)

    seed_everything(seed=CFG.seed)

    # With no explicit model location, fall back to downloading the model.
    if CFG.model_name_or_path is None:
        CFG.download_pretrained_model = True
    if CFG.download_pretrained_model:
        download_pretrained_model()
        # After the download the model files are looked up in the current
        # directory (presumably where download_pretrained_model puts them —
        # TODO confirm against finetune.py).
        CFG.model_name_or_path = "."

    # A path that exists locally is passed in absolute form; anything else
    # is handed over unchanged (e.g. a Hugging Face hub name).
    if os.path.exists(CFG.model_name_or_path):
        tokenizer_source = os.path.abspath(CFG.model_name_or_path)
    else:
        tokenizer_source = CFG.model_name_or_path
    CFG.tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        return_tensors="pt",
    )

    model = ReactionT5Yield(
        CFG,
        config_path=os.path.join(CFG.model_name_or_path, "config.pth"),
        pretrained=False,
    )

    # The weights file name is not known in advance, so every *.pth file in
    # the model directory is tried until one loads. The glob also matches
    # config.pth, which is why a failed attempt is skipped, not fatal.
    pth_files = glob.glob(os.path.join(CFG.model_name_or_path, "*.pth"))
    weights_file = None
    load_failures = []
    for pth_file in pth_files:
        # NOTE(security): torch.load unpickles the file; only point this
        # script at checkpoints from a trusted source.
        state = torch.load(
            pth_file,
            map_location=torch.device("cpu"),
        )
        try:
            model.load_state_dict(state)
        except Exception as exc:
            # Not a matching state dict. Unlike the former bare ``except``,
            # this lets KeyboardInterrupt/SystemExit propagate and keeps the
            # reason for the final error message.
            load_failures.append(f"{pth_file} ({type(exc).__name__})")
            continue
        weights_file = pth_file
        break
    if weights_file is None:
        # Going on would silently predict with an uninitialised model.
        tried = ", ".join(load_failures) if load_failures else "no *.pth files found"
        raise RuntimeError(
            f"Could not load model weights from '{CFG.model_name_or_path}': {tried}"
        )

    test_ds = pd.read_csv(CFG.input_data)
    test_ds = preprocess_df(test_ds, CFG, drop_duplicates=False)
    test_dataset = ReactionT5Dataset(CFG, test_ds)
    # shuffle=False and drop_last=False keep one prediction per input row,
    # in input order, so the result can be assigned back as a column.
    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    prediction = inference_fn(test_loader, model, CFG)
    # Scale by 100 and clamp to [0, 100].
    # NOTE(review): presumably the model predicts yield as a fraction and the
    # output is meant as a percentage — confirm against the training code.
    test_ds["prediction"] = prediction * 100
    test_ds["prediction"] = test_ds["prediction"].clip(0, 100)
    test_ds.to_csv(
        os.path.join(CFG.output_dir, "yield_prediction_output.csv"), index=False
    )