Spaces:
Running
Running
| import os | |
| import gc | |
| import random | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import tokenizers | |
| import transformers | |
| from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSeq2SeqLM | |
| import sentencepiece | |
| from rdkit import Chem | |
| import rdkit | |
| import streamlit as st | |
class CFG:
    """Static settings for the demo, read as class attributes (never instantiated here)."""

    # Reaction string typed by the user; the widget is created when this
    # class body is evaluated.
    input_data = st.text_area(
        'enter chemical reaction (e.g. REACTANT:NCCO.O=C1COCC(=O)O1CATALYST: REAGENT: SOLVENT:c1ccncc1)'
    )
    # Checkpoint identifier and model family label.
    model_name_or_path = 'sagawa/ZINC-t5'
    model = 't5'
    # Beam-search width and the number of sequences requested from generate().
    num_beams = 5
    num_return_sequences = 5
    # Seed handed to seed_everything().
    seed = 42


# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and CUDA),
    stores the seed in the ``PYTHONHASHSEED`` environment variable, and
    switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed=CFG.seed)

# Load the tokenizer and model named in CFG. They are used below but were
# never created anywhere in this script, which raised NameError.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path)
if CFG.model == 't5':
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path)
else:
    # NOTE(review): assumes every non-T5 checkpoint is an encoder-decoder
    # pair loadable with EncoderDecoderModel — confirm.
    model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path)
model = model.to(device)

input_compound = CFG.input_data

# NOTE(review): min(..., 0) can never be positive, so this value only ever
# shortens max_length below; max(..., 0) may have been intended. Left
# unchanged because switching would alter what the model generates — confirm.
min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)

inp = tokenizer(input_compound, return_tensors='pt').to(device)
generated = model.generate(
    **inp,
    min_length=min_length,
    max_length=min_length + 50,
    num_beams=CFG.num_beams,
    num_return_sequences=CFG.num_return_sequences,
    return_dict_in_generate=True,
    output_scores=True,
)
scores = generated['sequences_scores'].tolist()
candidates = [
    tokenizer.decode(seq, skip_special_tokens=True).replace('. ', '.').rstrip('.')
    for seq in generated['sequences']
]

# Pick the first candidate (in the order generate() returned them) that RDKit
# can parse. MolFromSmiles returns None for an unparsable string, so when no
# candidate parses both values stay None and the row keeps its full width.
valid_compound = None
valid_score = None
for candidate, score in zip(candidates, scores):
    if Chem.MolFromSmiles(candidate) is not None:
        valid_compound, valid_score = candidate, score
        break

# Size the column labels from the sequences actually returned so the labels
# and values always line up, even if num_beams != num_return_sequences.
n_candidates = len(candidates)
columns = (
    ['input']
    + [f'{i}th' for i in range(n_candidates)]
    + ['valid compound']
    + [f'{i}th score' for i in range(n_candidates)]
    + ['valid compound score']
)

# Single-row layout: input, candidates, first valid candidate, then the
# matching scores. Passing the list directly keeps scores numeric instead of
# coercing everything to strings through a NumPy array.
output = [input_compound] + candidates + [valid_compound] + scores + [valid_score]
output_df = pd.DataFrame([output], columns=columns)
print(output_df)