import argparse
import gc
import glob
import os
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets.utils.logging import disable_progress_bar
from sklearn.metrics import mean_squared_error, r2_score
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Append the utils module path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from generation_utils import prepare_input
from models import ReactionT5Yield
from rdkit import RDLogger
from utils import (
    AverageMeter,
    add_new_tokens,
    canonicalize,
    filter_out,
    get_logger,
    get_optimizer_params,
    seed_everything,
    space_clean,
    timeSince,
)

# Suppress warnings and logging
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")
disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Training script for ReactionT5Yield model."
    )
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="Path to training data CSV file.",
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV file.",
    )
    parser.add_argument(
        "--test_data_path",
        type=str,
        help="Path to testing data CSV file.",
    )
    parser.add_argument(
        "--CN_test_data_path",
        type=str,
        help="Path to CN testing data CSV file.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="sagawa/CompoundT5",
        help="Pretrained model name or path.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="The model's name or path used for fine-tuning.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode.")
    parser.add_argument(
        "--epochs", type=int, default=5, help="Number of training epochs."
    )
    parser.add_argument(
        "--patience", type=int, default=10, help="Early stopping patience."
    )
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size.")
    parser.add_argument(
        "--input_max_length", type=int, default=400, help="Maximum input token length."
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loading workers."
    )
    parser.add_argument(
        "--fc_dropout",
        type=float,
        default=0.0,
        help="Dropout rate after fully connected layers.",
    )
    parser.add_argument(
        "--eps", type=float, default=1e-6, help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="Weight decay for optimizer."
    )
    parser.add_argument(
        "--max_grad_norm",
        type=int,
        default=1000,
        help="Maximum gradient norm for clipping.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of warmup steps."
    )
    parser.add_argument(
        "--batch_scheduler", action="store_true", help="Use batch scheduler."
    )
    parser.add_argument(
        "--print_freq", type=int, default=100, help="Logging frequency."
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Use automatic mixed precision for training.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the trained model.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--sampling_num",
        type=int,
        default=-1,
        help="Number of samples used for training. If you want to use all samples, set -1.",
    )
    parser.add_argument(
        "--sampling_frac",
        type=float,
        default=-1.0,
        help="Ratio of samples used for training. If you want to use all samples, set -1.0.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path to the checkpoint file for resuming training.",
    )
    return parser.parse_args()
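

# Example invocation (illustrative only; the script name and data paths below
# are placeholders, not files shipped with this snippet):
#
#   python train.py \
#       --train_data_path data/train.csv \
#       --valid_data_path data/valid.csv \
#       --output_dir outputs/ \
#       --epochs 5 --batch_size 32 --use_amp
#
# Only --train_data_path and --valid_data_path are required; every other
# argument falls back to the defaults defined above.

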
def preprocess_df(df, cfg, drop_duplicates=True):
    """
    Preprocess the input DataFrame for training.
    Args:
        df (pd.DataFrame): Input DataFrame.
        cfg (argparse.Namespace): Configuration object.
        drop_duplicates (bool): Whether to drop duplicated (input, YIELD) pairs.
    Returns:
        pd.DataFrame: Preprocessed DataFrame.
    """
    if "YIELD" in df.columns:
        # If yields are reported on a 0-100 scale, clip to [0, 100] and normalize to [0, 1].
        if df["YIELD"].max() >= 100:
            df["YIELD"] = df["YIELD"].clip(0, 100) / 100
    else:
        df["YIELD"] = None

    for col in ["REACTANT", "PRODUCT", "CATALYST", "REAGENT", "SOLVENT"]:
        if col not in df.columns:
            df[col] = None
        df[col] = df[col].fillna(" ")
    df["REAGENT"] = df["CATALYST"] + "." + df["REAGENT"]

    for col in ["REAGENT", "REACTANT", "PRODUCT"]:
        df[col] = df[col].apply(lambda x: space_clean(x))
        df[col] = df[col].apply(lambda x: canonicalize(x) if x != " " else " ")
        df = df[~df[col].isna()].reset_index(drop=True)
        df[col] = df[col].apply(lambda x: ".".join(sorted(x.split("."))))

    df["input"] = (
        "REACTANT:"
        + df["REACTANT"]
        + "REAGENT:"
        + df["REAGENT"]
        + "PRODUCT:"
        + df["PRODUCT"]
    )

    if drop_duplicates:
        df = df.loc[df[["input", "YIELD"]].drop_duplicates().index].reset_index(
            drop=True
        )

    if cfg.debug:
        df = df.head(1000)

    return df
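

# Shape of the model input built above (illustrative, with made-up SMILES):
# after cleaning, a row whose REACTANT is "CCO", REAGENT is "[Na+]" and
# PRODUCT is "CC=O" becomes the single string
#   "REACTANT:CCOREAGENT:[Na+]PRODUCT:CC=O"
# and, if the YIELD column's maximum is >= 100, yields are divided by 100,
# e.g. 87.5 -> 0.875.

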
def preprocess_CN(df):
    """
    Preprocess the CN test DataFrame.
    Args:
        df (pd.DataFrame): Input DataFrame.
    Returns:
        pd.DataFrame: Preprocessed DataFrame.
    """
    df["REACTANT"] = df["REACTANT"].apply(lambda x: ".".join(sorted(x.split("."))))
    df["REAGENT"] = df["REAGENT"].apply(lambda x: ".".join(sorted(x.split("."))))
    df["PRODUCT"] = df["PRODUCT"].apply(lambda x: ".".join(sorted(x.split("."))))
    df["input"] = (
        "REACTANT:"
        + df["REACTANT"]
        + "REAGENT:"
        + df["REAGENT"]
        + "PRODUCT:"
        + df["PRODUCT"]
    )
    df["pair"] = df["input"]
    return df


class TrainDataset(Dataset):
    """
    Dataset class for training.
    """

    def __init__(self, cfg, df):
        self.cfg = cfg
        self.inputs = df["input"].values
        self.labels = df["YIELD"].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, self.inputs[item])
        label = torch.tensor(self.labels[item], dtype=torch.float)
        return inputs, label
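

# Note: each dataset item is an (inputs, label) pair. `label` is a scalar
# float tensor holding the normalized yield. `inputs` is whatever
# prepare_input (imported from generation_utils) returns for the text;
# train_fn/valid_fn below move every value of inputs.items() to cfg.device,
# so it is assumed to be a dict of tensors (e.g. input_ids, attention_mask).

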
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """
    Save model checkpoint.
    Args:
        state (dict): Checkpoint state.
        filename (str): Filename to save the checkpoint.
    """
    torch.save(state, filename)


def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, cfg):
    """
    Training function for one epoch.
    Args:
        train_loader (DataLoader): DataLoader for training data.
        model (nn.Module): Model to be trained.
        criterion (nn.Module): Loss function.
        optimizer (Optimizer): Optimizer.
        epoch (int): Current epoch.
        scheduler (Scheduler): Learning rate scheduler.
        cfg (argparse.Namespace): Configuration object.
    Returns:
        float: Average training loss.
    """
    model.train()
    scaler = torch.amp.GradScaler(enabled=cfg.use_amp)
    losses = AverageMeter()
    start = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
        labels = labels.to(cfg.device)
        batch_size = labels.size(0)
        # torch.autocast expects the device type as a string ("cuda"/"cpu"),
        # not a torch.device object.
        with torch.autocast(cfg.device.type, enabled=cfg.use_amp):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        if cfg.gradient_accumulation_steps > 1:
            loss /= cfg.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), cfg.max_grad_norm
        )
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if cfg.batch_scheduler:
                scheduler.step()
        if step % cfg.print_freq == 0 or step == (len(train_loader) - 1):
            print(
                f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(train_loader))} "
                f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
                f"Grad: {grad_norm:.4f} "
                f"LR: {scheduler.get_last_lr()[0]:.8f}"
            )
    return losses.avg
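

# Gradient accumulation arithmetic (example with illustrative values): with
# --batch_size 5 and --gradient_accumulation_steps 4, the optimizer is stepped
# every 4 batches, i.e. on an effective batch of 20 examples, and each
# per-batch loss is divided by 4 so the accumulated gradient matches the
# average over those 20 examples.

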
def valid_fn(valid_loader, model, cfg):
    """
    Validation function.
    Args:
        valid_loader (DataLoader): DataLoader for validation data.
        model (nn.Module): Model to be validated.
        cfg (argparse.Namespace): Configuration object.
    Returns:
        tuple: Validation mean squared error (MSE) and R^2 score.
    """
    model.eval()
    start = time.time()
    label_list = []
    pred_list = []
    for step, (inputs, labels) in enumerate(valid_loader):
        inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
        with torch.no_grad():
            y_preds = model(inputs)
        label_list.extend(labels.tolist())
        pred_list.extend(y_preds.tolist())
        if step % cfg.print_freq == 0 or step == (len(valid_loader) - 1):
            print(
                f"EVAL: [{step}/{len(valid_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader))} "
                f"RMSE Loss: {np.sqrt(mean_squared_error(label_list, pred_list)):.4f} "
                f"R^2 Score: {r2_score(label_list, pred_list):.4f}"
            )
    return mean_squared_error(label_list, pred_list), r2_score(label_list, pred_list)


def train_loop(train_ds, valid_ds, cfg):
    """
    Training loop.
    Args:
        train_ds (pd.DataFrame): Training data.
        valid_ds (pd.DataFrame): Validation data.
        cfg (argparse.Namespace): Configuration object.
    """
    train_dataset = TrainDataset(cfg, train_ds)
    valid_dataset = TrainDataset(cfg, valid_ds)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    if not cfg.model_name_or_path:
        model = ReactionT5Yield(cfg, config_path=None, pretrained=True)
        torch.save(model.config, os.path.join(cfg.output_dir, "config.pth"))
    else:
        model = ReactionT5Yield(
            cfg,
            config_path=os.path.join(cfg.model_name_or_path, "config.pth"),
            pretrained=False,
        )
        torch.save(model.config, os.path.join(cfg.output_dir, "config.pth"))
        # Try each saved state dict until one matches the model's parameters.
        pth_files = glob.glob(os.path.join(cfg.model_name_or_path, "*.pth"))
        for pth_file in pth_files:
            state = torch.load(
                pth_file, map_location=torch.device("cpu"), weights_only=False
            )
            try:
                model.load_state_dict(state)
                break
            except Exception:
                pass
    model.to(cfg.device)

    optimizer_parameters = get_optimizer_params(
        model, encoder_lr=cfg.lr, decoder_lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    optimizer = AdamW(optimizer_parameters, lr=cfg.lr, eps=cfg.eps, betas=(0.9, 0.999))

    num_train_steps = int(len(train_ds) / cfg.batch_size * cfg.epochs)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.num_warmup_steps,
        num_training_steps=num_train_steps,
    )

    criterion = nn.MSELoss(reduction="mean")
    best_loss = float("inf")
    start_epoch = 0
    es_count = 0

    if cfg.checkpoint:
        checkpoint = torch.load(cfg.checkpoint)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["loss"]
        start_epoch = checkpoint["epoch"] + 1
        es_count = checkpoint["es_count"]
        del checkpoint

    for epoch in range(start_epoch, cfg.epochs):
        start_time = time.time()
        avg_loss = train_fn(
            train_loader, model, criterion, optimizer, epoch, scheduler, cfg
        )
        val_loss, val_r2_score = valid_fn(valid_loader, model, cfg)
        elapsed = time.time() - start_time
        cfg.logger.info(
            f"Epoch {epoch + 1} - avg_train_loss: {avg_loss:.4f} val_mse_loss: {val_loss:.4f} val_r2_score: {val_r2_score:.4f} time: {elapsed:.0f}s"
        )
        if val_loss < best_loss:
            es_count = 0
            best_loss = val_loss
            cfg.logger.info(
                f"Epoch {epoch + 1} - Save Lowest Loss: {best_loss:.4f} Model"
            )
            torch.save(
                model.state_dict(),
                os.path.join(
                    cfg.output_dir,
                    f"{cfg.pretrained_model_name_or_path.split('/')[-1]}_best.pth",
                ),
            )
        else:
            es_count += 1
            if es_count >= cfg.patience:
                print("Early stopping")
                break
        save_checkpoint(
            {
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "loss": best_loss,
                "es_count": es_count,
            },
            filename=os.path.join(cfg.output_dir, "checkpoint.pth.tar"),
        )

    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    CFG = parse_args()
    CFG.batch_scheduler = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CFG.device = device
    if not os.path.exists(CFG.output_dir):
        os.makedirs(CFG.output_dir)
    seed_everything(seed=CFG.seed)

    train = preprocess_df(
        filter_out(pd.read_csv(CFG.train_data_path), ["YIELD", "REACTANT", "PRODUCT"]),
        CFG,
    )
    valid = preprocess_df(
        filter_out(pd.read_csv(CFG.valid_data_path), ["YIELD", "REACTANT", "PRODUCT"]),
        CFG,
    )

    if CFG.CN_test_data_path:
        # Remove training reactions that also appear in the CN test set.
        train_copy = preprocess_CN(train.copy())
        CN_test = preprocess_CN(pd.read_csv(CFG.CN_test_data_path))
        print(len(train))
        train = train[~train_copy["pair"].isin(CN_test["pair"])].reset_index(drop=True)
        print(len(train))

    train["pair"] = train["input"] + " - " + train["YIELD"].astype(str)
    valid["pair"] = valid["input"] + " - " + valid["YIELD"].astype(str)
    valid = valid[~valid["pair"].isin(train["pair"])].reset_index(drop=True)

    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
    elif CFG.sampling_frac > 0:
        train = train.sample(frac=CFG.sampling_frac, random_state=CFG.seed).reset_index(
            drop=True
        )

    train.to_csv("train.csv", index=False)
    valid.to_csv("valid.csv", index=False)

    if CFG.test_data_path:
        test = filter_out(
            pd.read_csv(CFG.test_data_path), ["YIELD", "REACTANT", "PRODUCT"]
        )
        test = preprocess_df(test, CFG)
        test["pair"] = test["input"] + " - " + test["YIELD"].astype(str)
        test = test[~test["pair"].isin(train["pair"])].reset_index(drop=True)
        test = test.drop_duplicates(subset=["pair"]).reset_index(drop=True)
        test.to_csv("test.csv", index=False)

    LOGGER = get_logger(os.path.join(CFG.output_dir, "train"))
    CFG.logger = LOGGER

    # Load the tokenizer. Fall back to the pretrained model when no fine-tuned
    # model path is given, so os.path.exists is never called on None.
    tokenizer_path = CFG.model_name_or_path or CFG.pretrained_model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(tokenizer_path)
        if os.path.exists(tokenizer_path)
        else tokenizer_path,
        return_tensors="pt",
    )
    tokenizer = add_new_tokens(
        tokenizer,
        Path(__file__).resolve().parent.parent / "data" / "additional_tokens.txt",
    )
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": tokenizer.additional_special_tokens
            + ["REACTANT:", "PRODUCT:", "REAGENT:"]
        }
    )
    tokenizer.save_pretrained(CFG.output_dir)
    CFG.tokenizer = tokenizer

    train_loop(train, valid, CFG)