import argparse
import os
import subprocess
import sys
import warnings

import pandas as pd
import torch
from datasets.utils.logging import disable_progress_bar
from transformers import AutoTokenizer

# Make the parent directory importable so the shared training utilities resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from train import preprocess_df, train_loop
from utils import get_logger, seed_everything

# Suppress warnings, progress bars, and tokenizer parallelism messages.
warnings.filterwarnings("ignore")
disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Training script for the ReactionT5Yield model."
    )
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="Path to training data CSV file.",
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV file.",
    )
    parser.add_argument(
        "--similar_reaction_data_path",
        type=str,
        required=False,
        help="Path to similar reaction data CSV file.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="sagawa/CompoundT5",
        help="Pretrained model name or path.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Name or path of the model used for fine-tuning.",
    )
    parser.add_argument(
        "--download_pretrained_model",
        action="store_true",
        help="Download the pretrained model from the Hugging Face Hub and use it for fine-tuning.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode.")
    parser.add_argument(
        "--epochs", type=int, default=200, help="Number of training epochs."
    )
    parser.add_argument(
        "--patience", type=int, default=10, help="Early stopping patience."
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
    parser.add_argument(
        "--input_max_length", type=int, default=300, help="Maximum input token length."
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loading workers."
    )
    parser.add_argument(
        "--fc_dropout",
        type=float,
        default=0.0,
        help="Dropout rate after fully connected layers.",
    )
    parser.add_argument(
        "--eps", type=float, default=1e-6, help="Epsilon for the Adam optimizer."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="Weight decay for the optimizer."
    )
    parser.add_argument(
        "--max_grad_norm",
        type=int,
        default=1000,
        help="Maximum gradient norm for clipping.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of warmup steps."
    )
    parser.add_argument(
        "--batch_scheduler", action="store_true", help="Use batch scheduler."
    )
    parser.add_argument(
        "--print_freq", type=int, default=100, help="Logging frequency."
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Use automatic mixed precision for training.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the trained model.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--sampling_num",
        type=int,
        default=-1,
        help="Number of samples used for training. Set to -1 to use all samples.",
    )
    parser.add_argument(
        "--sampling_frac",
        type=float,
        default=-1.0,
        help="Fraction of samples used for training. Set to -1.0 to use all samples.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path to a checkpoint file for resuming training.",
    )
    return parser.parse_args()
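
# Example invocation (script name and file paths are illustrative placeholders,
# not prescribed by this repository):
#   python train_with_similar_data.py \
#       --train_data_path data/train.csv \
#       --valid_data_path data/valid.csv \
#       --similar_reaction_data_path data/similar_reactions.csv \
#       --output_dir output/ \
#       --epochs 200 --batch_size 32 --lr 1e-5 --use_amp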

def download_pretrained_model():
    """
    Download the pretrained model files from Hugging Face into the current directory.
    """
    base_url = "https://huggingface.co/sagawa/ReactionT5v2-yield/resolve/main"
    for file_name in [
        "CompoundT5_best.pth",
        "config.pth",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        subprocess.run(f"wget {base_url}/{file_name}", shell=True)

if __name__ == "__main__":
    CFG = parse_args()
    CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(CFG.output_dir, exist_ok=True)
    seed_everything(seed=CFG.seed)

    # Download the pretrained model files (into the current directory) and
    # point the model path at them.
    if CFG.download_pretrained_model:
        download_pretrained_model()
        CFG.model_name_or_path = "."

    # Load and preprocess the training and validation data.
    train = pd.read_csv(CFG.train_data_path).drop_duplicates().reset_index(drop=True)
    valid = pd.read_csv(CFG.valid_data_path).drop_duplicates().reset_index(drop=True)
    train = preprocess_df(train, CFG)
    valid = preprocess_df(valid, CFG)

    # Optionally subsample the training data, by count or by fraction.
    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
    elif 0 < CFG.sampling_frac < 1:
        train = train.sample(frac=CFG.sampling_frac, random_state=CFG.seed).reset_index(
            drop=True
        )
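    # Optionally augment the training set with similar reaction data; the set
    # size is printed before and after concatenation as a sanity check.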
    if CFG.similar_reaction_data_path:
        similar = preprocess_df(pd.read_csv(CFG.similar_reaction_data_path), CFG)
        print(len(train))
        train = pd.concat([train, similar], ignore_index=True)
        print(len(train))

    LOGGER = get_logger(os.path.join(CFG.output_dir, "train"))
    CFG.logger = LOGGER

    # Load the tokenizer (from a local path if it exists, otherwise from the Hub)
    # and save a copy alongside the trained model.
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )
    tokenizer.save_pretrained(CFG.output_dir)
    CFG.tokenizer = tokenizer

    train_loop(train, valid, CFG)
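    # The tokenizer saved above can later be reloaded from the output directory,
    # e.g. AutoTokenizer.from_pretrained(CFG.output_dir) (illustrative usage).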