import math
import os
import pickle
import random
import time

import numpy as np
import torch
from rdkit import Chem


def seed_everything(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def space_clean(row):
    """Remove stray spaces around '.' separators and collapse double spaces."""
    row = row.replace(". ", "").replace(" .", "").replace("  ", " ")
    return row


def canonicalize(smiles):
    """Return the RDKit canonical SMILES, or None if parsing fails."""
    try:
        new_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
    except Exception:
        new_smiles = None
    return new_smiles
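
# Example usage (a sketch; output assumes a standard RDKit build):
#   canonicalize("C(C)O")      -> "CCO"
#   canonicalize("not-smiles") -> None (RDKit fails to parse)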


def canonicalize_str(smiles):
    """Try to canonicalize the molecule; return an empty string on failure."""
    if "%" in smiles:
        # Yield strings such as "90%" are passed through untouched.
        return smiles
    # canonicalize() swallows its own exceptions and returns None on failure,
    # so map None to "" here to honor the docstring.
    return canonicalize(smiles) or ""


def uncanonicalize(smiles):
    """Randomize the atom ordering of each '.'-separated fragment (SMILES augmentation)."""
    try:
        new_smiles = []
        for smiles_i in smiles.split("."):
            mol = Chem.MolFromSmiles(smiles_i)
            # Root the SMILES at a randomly chosen atom to get a
            # non-canonical but equivalent string.
            atom_indices = list(range(mol.GetNumAtoms()))
            random.shuffle(atom_indices)
            new_smiles_i = Chem.MolToSmiles(
                mol, rootedAtAtom=atom_indices[0], canonical=False
            )
            new_smiles.append(new_smiles_i)
        smiles = ".".join(new_smiles)
    except Exception:
        smiles = None
    return smiles
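
# Example usage (a sketch; the output is random by design):
#   uncanonicalize("CCO")  -> a non-canonical variant such as "C(C)O" or "OCC"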


def remove_atom_mapping(smi):
    """Strip atom-map numbers (e.g. [CH3:1]) and return the canonical SMILES."""
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol, canonical=True)
    return canonicalize(smi)
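
# Example usage (a sketch):
#   remove_atom_mapping("[CH3:1][OH:2]")  -> "CO"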


def get_logger(filename="train"):
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger

    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
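
# Example usage (a sketch): log to stdout and to "train.log" at once.
#   LOGGER = get_logger("train")
#   LOGGER.info("epoch 1 | loss 0.1234")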


class AverageMeter(object):
    """Tracks the current value, running sum, count, and average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
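
# Example usage (a sketch): keep a running average of the batch loss,
# weighting each update by its batch size.
#   losses = AverageMeter()
#   losses.update(0.50, n=32)
#   losses.update(0.40, n=32)
#   losses.avg  -> 0.45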


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    """Elapsed time since `since`, plus the remaining time estimated at `percent` progress."""
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))


def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
    """Build parameter groups: encoder params with/without weight decay, plus the rest."""
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": 0.0,
        },
        {
            "params": [p for n, p in model.named_parameters() if "model" not in n],
            "lr": decoder_lr,
            "weight_decay": 0.0,
        },
    ]
    return optimizer_parameters
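
# Example usage (a sketch; assumes `model` wraps a pretrained encoder under
# the attribute `model.model`, which is what the name-based split above
# relies on):
#   from torch.optim import AdamW
#   groups = get_optimizer_params(model, encoder_lr=2e-5, decoder_lr=1e-4,
#                                 weight_decay=0.01)
#   optimizer = AdamW(groups)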


def to_cpu(obj):
    """Recursively move tensors in nested dicts/lists/tuples/sets to the CPU."""
    if torch.is_tensor(obj):
        return obj.to("cpu")
    elif isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple, set)):
        return [to_cpu(v) for v in obj]
    else:
        return obj
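
# Example usage (a sketch; assumes a CUDA device is available):
#   batch = {"input_ids": torch.ones(2, 3, device="cuda"),
#            "masks": [torch.zeros(2, 3, device="cuda")]}
#   cpu_batch = to_cpu(batch)  # same structure, all tensors on the CPU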


def get_accuracy_score(eval_preds, cfg):
    """Exact-match accuracy on canonicalized SMILES predictions."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=True)
    # -100 marks ignored positions in the labels; map them back to pad tokens
    # so the tokenizer can decode them.
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}


def get_accuracy_score_multitask(eval_preds, cfg):
    """Exact-match accuracy for the multitask model; yield tokens like "90%" are kept."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Strip every special token except the yield tokens "0%"-"100%",
    # which are legitimate prediction targets.
    special_tokens = cfg.tokenizer.special_tokens_map
    special_tokens = [
        special_tokens["eos_token"],
        special_tokens["pad_token"],
        special_tokens["unk_token"],
    ] + list(
        set(special_tokens["additional_special_tokens"])
        - set(
            [
                "0%", "10%", "20%", "30%", "40%", "50%",
                "60%", "70%", "80%", "90%", "100%",
            ]
        )
    )
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_preds = [pred.replace(special_token, "") for pred in decoded_preds]
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_labels = [
            label.replace(special_token, "") for label in decoded_labels
        ]
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}


def preprocess_dataset(examples, cfg):
    """Tokenize inputs and targets for seq2seq training."""
    inputs = examples["input"]
    targets = examples[cfg.target_column]
    model_inputs = cfg.tokenizer(
        inputs, max_length=cfg.input_max_length, truncation=True
    )
    labels = cfg.tokenizer(targets, max_length=cfg.target_max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
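
# Example usage (a sketch; assumes `cfg` carries `tokenizer`, `target_column`,
# `input_max_length`, and `target_max_length`, and that `raw_ds` is a
# Hugging Face `datasets.Dataset` with an "input" column):
#   tokenized_ds = raw_ds.map(
#       lambda examples: preprocess_dataset(examples, cfg),
#       batched=True,
#       remove_columns=raw_ds.column_names,
#   )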


def filter_out(df, col_names):
    """Drop rows of a pandas DataFrame where any of `col_names` is NaN."""
    for col_name in col_names:
        df = df[~df[col_name].isna()].reset_index(drop=True)
    return df


def save_pickle(path: str, contents):
    """Saves contents to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(contents, f)


def load_pickle(path: str):
    """Loads contents from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def add_new_tokens(tokenizer, file_path):
    """
    Adds new tokens to the tokenizer from a file.
    The file should contain one token per line.
    """
    with open(file_path, "r") as f:
        new_tokens = [line.strip() for line in f if line.strip()]
    tokenizer.add_tokens(new_tokens)
    return tokenizer
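
# Example usage (a sketch; "new_tokens.txt" is a hypothetical vocabulary file
# with one token per line). After adding tokens, the model's embedding matrix
# must be resized to match the new vocabulary size:
#   tokenizer = add_new_tokens(tokenizer, "new_tokens.txt")
#   model.resize_token_embeddings(len(tokenizer))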