import argparse
import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, T5EncoderModel

# Make the sibling modules in the parent directory importable from this script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from generation_utils import ReactionT5Dataset
from train import preprocess_df, preprocess_USPTO
from utils import filter_out, seed_everything

warnings.filterwarnings("ignore")
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--test_data",
        type=str,
        required=False,
        help="Path to the test data. If provided, reactions that also appear in the test data are removed from the input data.",
    )
    parser.add_argument(
        "--input_max_length",
        type=int,
        default=400,
        help="Maximum token length of input.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-forward",
        help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=5, help="Batch size for prediction."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where predictions are saved.",
    )
    parser.add_argument(
        "--debug", action="store_true", default=False, help="Use debug mode."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for reproducibility."
    )
    return parser.parse_args()
def create_embedding(dataloader, model, device):
    """Return mean-pooled encoder embeddings, one array of vectors per batch."""
    outputs_mean = []
    model.eval()
    model.to(device)
    for inputs in dataloader:
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model(**inputs)
        last_hidden_states = output[0]
        # Mean-pool over the sequence dimension, ignoring padding tokens.
        input_mask_expanded = (
            inputs["attention_mask"]
            .unsqueeze(-1)
            .expand(last_hidden_states.size())
            .float()
        )
        sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-6)
        mean_embeddings = sum_embeddings / sum_mask
        outputs_mean.append(mean_embeddings.detach().cpu().numpy())
    return outputs_mean
if __name__ == "__main__":
    CFG = parse_args()
    CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.exists(CFG.output_dir):
        os.makedirs(CFG.output_dir)

    seed_everything(seed=CFG.seed)

    CFG.tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )
    model = T5EncoderModel.from_pretrained(CFG.model_name_or_path).to(CFG.device)
    model.eval()

    input_data = filter_out(pd.read_csv(CFG.input_data), ["REACTANT", "PRODUCT"])
    input_data = preprocess_df(input_data, drop_duplicates=False)

    if CFG.test_data:
        # Drop input reactions whose reactant-product pair also appears in the test data.
        input_data_copy = preprocess_USPTO(input_data.copy())
        test_data = filter_out(pd.read_csv(CFG.test_data), ["REACTANT", "PRODUCT"])
        USPTO_test = preprocess_USPTO(test_data)
        input_data = input_data[
            ~input_data_copy["pair"].isin(USPTO_test["pair"])
        ].reset_index(drop=True)

    input_data.to_csv(os.path.join(CFG.output_dir, "input_data.csv"), index=False)

    dataset = ReactionT5Dataset(CFG, input_data)
    dataloader = DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    outputs = create_embedding(dataloader, model, CFG.device)
    outputs = np.concatenate(outputs, axis=0)
    np.save(os.path.join(CFG.output_dir, "embedding_mean.npy"), outputs)
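# A minimal sketch of downstream usage, assuming ReactionT5Dataset yields one item
# per row of input_data and --output_dir is left at the default "./": because the
# DataLoader runs with shuffle=False and drop_last=False, row i of embedding_mean.npy
# corresponds to row i of input_data.csv.
#
#     import numpy as np
#     import pandas as pd
#
#     embeddings = np.load("embedding_mean.npy")  # (n_reactions, hidden_size)
#     reactions = pd.read_csv("input_data.csv")
#     assert len(reactions) == embeddings.shape[0]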