Spaces:
Running
Running
| import argparse | |
| import os | |
| import sys | |
| import warnings | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| from utils import seed_everything | |
| warnings.filterwarnings("ignore") | |
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the similarity-search script.

    Returns:
        The parsed namespace with the attributes ``input_data``,
        ``target_embedding``, ``query_embedding``, ``batch_size`` and
        ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Search for similar reactions.")
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--target_embedding",
        type=str,
        required=True,
        help="Path to the target embedding.",
    )
    parser.add_argument(
        "--query_embedding",
        type=str,
        required=True,
        # Previously a copy-paste of the target help text.
        help="Path to the query embedding.",
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where results are saved.",
    )
    return parser.parse_args()
def compute_max_similarity(query_embedding, target_embedding, batch_size=64, device=None):
    """Return, for every query row, its highest cosine similarity to any target row.

    Both inputs are converted to float32 tensors and L2-normalised along
    dim 1, so the dot product of two rows is their cosine similarity.
    Queries are processed in chunks of ``batch_size`` rows so that only a
    ``batch_size x n_targets`` similarity matrix is held at a time.

    Args:
        query_embedding: 2-D array-like of query vectors (one per row).
        target_embedding: 2-D array-like of target vectors (one per row).
        batch_size: Number of query rows compared per matmul; must be > 0.
        device: Torch device (or device string) to compute on. When ``None``,
            CUDA is used if available, otherwise the CPU.

    Returns:
        A 1-D float64 NumPy array with one score per query row (empty when
        there are no query rows).

    Raises:
        ValueError: If ``batch_size`` is not positive, or if there are query
            rows but no target rows to compare against.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if device is None:
        # Fall back to the CPU instead of crashing on machines without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    target = torch.tensor(target_embedding, dtype=torch.float32).to(device)
    query = torch.tensor(query_embedding, dtype=torch.float32).to(device)

    if query.shape[0] > 0 and target.shape[0] == 0:
        raise ValueError("target_embedding contains no rows to compare against")

    target = torch.nn.functional.normalize(target, p=2, dim=1)
    query = torch.nn.functional.normalize(query, p=2, dim=1)

    scores = []
    for start in range(0, query.shape[0], batch_size):
        print(f"Processing batch {start // batch_size}...")
        batch = query[start : start + batch_size]
        similarity = torch.matmul(batch, target.T)
        best, _ = torch.max(similarity, dim=1)
        scores.extend(best.cpu().tolist())

    # Python floats -> float64, matching the values previously written to CSV;
    # np.asarray also handles the empty case that np.concatenate([]) rejects.
    return np.asarray(scores, dtype=np.float64)


if __name__ == "__main__":
    config = parse_args()
    seed_everything(42)

    target_embedding = np.load(config.target_embedding)
    query_embedding = np.load(config.query_embedding)
    distances = compute_max_similarity(
        query_embedding, target_embedding, batch_size=config.batch_size
    )

    df = pd.read_csv(config.input_data)
    # NOTE: the column is named "distance" for compatibility with existing
    # consumers, but it holds the maximum cosine *similarity* (higher = closer).
    df["distance"] = distances

    # Create the output directory if needed; an empty string means the CWD.
    if config.output_dir:
        os.makedirs(config.output_dir, exist_ok=True)
    df.to_csv(os.path.join(config.output_dir, "distance.csv"), index=False)