Spaces:
Runtime error
| import gradio as gr | |
| import os | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import SentenceTransformer | |
| import pickle | |
| import nltk | |
| nltk.download('punkt') # tokenizer | |
| nltk.download('averaged_perceptron_tagger') # postagger | |
| import time | |
| from input_format import * | |
| from score import * | |
# load document scoring model
# Import torch explicitly: it was previously only reachable through the
# `from ... import *` lines above, which hides a real dependency.
import torch

# Force CPU execution: after this monkeypatch every later call to
# `torch.cuda.is_available()` returns False, so `device` below always
# resolves to 'cpu'.
torch.cuda.is_available = lambda: False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SPECTER document encoder and its tokenizer, used for paper-level affinity.
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)
doc_model.to(device)

# load sentence model (used for sentence/phrase-level highlights)
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
sent_model.to(device)
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    """Score a reviewer's papers against a submission abstract and populate the UI.

    Fetches the papers of ``author_id_input`` via ``get_text_from_author_id``,
    scores them against ``abstract_text_input`` with ``compute_document_score``
    and caches the full (unsliced) titles/abstracts/scores in
    ``paper_info.pkl``, where ``change_paper`` later reads them.

    Args:
        abstract_text_input: Submission abstract text.
        pdf_file_input: Uploaded submission PDF; not supported yet.
        author_id_input: Reviewer id (Semantic Scholar) whose papers are scored.
        num_papers_show: How many leading papers to list in the radio.

    Returns:
        A 5-tuple of ``gr.update`` objects, in the order the click handler's
        ``outputs`` list expects: paper choices, submission-sentence choices,
        then three "make visible" updates (title row, abstract, button row).

    Raises:
        ValueError: If a PDF file is supplied (not handled yet).
    """
    print('retrieving similar papers...')
    start = time.time()

    # TODO handle pdf file input
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    input_sentences = sent_tokenize(abstract_text_input)

    # Get author papers from id (the author name is not needed here)
    _, papers = get_text_from_author_id(author_id_input)

    # Compute Doc-level affinity scores for the Papers
    print('computing scores...')
    titles, abstracts, doc_scores = compute_document_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=50
    )

    # NOTE(review): this is a fixed path shared by all sessions, so concurrent
    # users can overwrite each other's results.
    with open('paper_info.pkl', 'wb') as f:
        pickle.dump(
            {'titles': titles, 'abstracts': abstracts, 'doc_scores': doc_scores},
            f,
        )

    # Select top K choices of papers to show (assumes compute_document_score
    # returns papers already ordered by score — confirm)
    titles = titles[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]
    # This exact format is re-created by change_paper to match the selection.
    display_title = ['[ %0.3f ] %s' % (s, t) for t, s in zip(titles, doc_scores)]

    end = time.time()
    print('retrieval complete in [%0.2f] seconds' % (end - start))
    return (
        gr.update(choices=display_title, interactive=True, visible=True),  # set of papers
        gr.update(choices=input_sentences, interactive=True),  # submission sentences
        gr.update(visible=True),  # title row
        gr.update(visible=True),  # abstract row
        gr.update(visible=True)  # button
    )
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    """Compute per-sentence word highlights of a paper abstract and cache them.

    For each sentence of the submission abstract, builds an entry shaped for
    Gradio's Interpretation component (``"original"`` text plus
    ``"interpretation"`` word/score pairs taken from ``get_highlight_info``).
    The entries are saved to ``highlight_info.pkl``, keyed by the sentence
    index as a string, for ``change_output_highlight`` to serve later.

    Args:
        abstract_text_input: Submission abstract text.
        pdf_file_input: Uploaded PDF; currently unused by this function.
        abstract: Abstract of the selected reviewer paper.
        K: Forwarded unchanged to ``get_highlight_info``.

    Returns:
        ``gr.update(visible=True)`` to reveal the submission-sentence radio.
    """
    print('obtaining highlights..')
    start = time.time()

    # Compute sent-level and phrase-level affinity scores for the paper;
    # only the per-word info is used below.
    _sent_ids, _sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )
    input_sentences = sent_tokenize(abstract_text_input)

    # different highlights for each input sentence, formatted for the
    # Gradio Interpretation component
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores'])),
        }
        for i in range(len(input_sentences))
    }

    # NOTE(review): fixed path shared by all sessions; concurrent users can
    # overwrite each other's highlights.
    with open('highlight_info.pkl', 'wb') as f:
        pickle.dump({'source_sentences': input_sentences, 'highlight': word_scores}, f)

    end = time.time()
    print('done in [%0.2f] seconds' % (end - start))
    # update the visibility of radio choices
    return gr.update(visible=True)
def update_name(author_id_input):
    """Show the reviewer name that corresponds to the typed Semantic Scholar id.

    Args:
        author_id_input: Current contents of the reviewer-id textbox.

    Returns:
        A ``gr.update`` setting the name textbox's value; an empty name when
        the id is blank.
    """
    # The textbox change event also fires when the field is cleared; don't
    # run an author lookup with an empty id.
    if not author_id_input or not author_id_input.strip():
        return gr.update(value='')
    name, _ = get_text_from_author_id(author_id_input)
    return gr.update(value=name)
def change_output_highlight(source_sent_choice):
    """Return the cached highlight for the selected submission sentence.

    Looks ``source_sent_choice`` up in the sentences cached in
    ``highlight_info.pkl`` by ``get_highlights`` and returns the matching
    Interpretation-formatted entry.

    Args:
        source_sent_choice: Sentence selected in the submission-sentence radio.

    Returns:
        The highlight dict for that sentence, or ``None`` when the cache file
        is missing or no cached sentence matches.
    """
    fname = 'highlight_info.pkl'
    if not os.path.exists(fname):
        return None
    # The cache is written by get_highlights, not uploaded by the user.
    with open(fname, 'rb') as f:
        tmp = pickle.load(f)
    highlights = tmp['highlight']
    for i, sent in enumerate(tmp['source_sentences']):
        if sent == source_sent_choice:
            return highlights[str(i)]
    return None
def change_paper(selected_papers_radio):
    """Return title, abstract and affinity score of the paper picked in the radio.

    Rebuilds each display title with the same ``'[ %0.3f ] %s'`` format used
    by ``get_similar_paper`` from the results cached in ``paper_info.pkl``,
    and returns the fields of the first one equal to the selection.

    Args:
        selected_papers_radio: Display title chosen in the papers radio.

    Returns:
        ``(title, abstract, aff_score)`` for the matching paper (title,
        abstract and affinity fields). When the cache is missing or nothing
        matches, three no-op ``gr.update()`` objects, so every output keeps
        its current value.
    """
    fname = 'paper_info.pkl'
    if os.path.exists(fname):
        # The cache is written by get_similar_paper, not uploaded by the user.
        with open(fname, 'rb') as f:
            tmp = pickle.load(f)
        for title, abstract, aff_score in zip(tmp['titles'], tmp['abstracts'], tmp['doc_scores']):
            if '[ %0.3f ] %s' % (aff_score, title) == selected_papers_radio:
                return title, abstract, aff_score
    # No match (e.g. the radio was cleared or holds a stale choice). This event
    # has three outputs and Gradio expects one value per output, so a bare None
    # would fail; leave all three components unchanged instead.
    return gr.update(), gr.update(), gr.update()
with gr.Blocks() as demo:
    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
            # NOTE(review): `.change` fires on every edit, so the author lookup
            # runs for each partial id typed.
            author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    ### PAPER INFORMATION
    # show multiple papers in radio check box to select from
    with gr.Row():
        selected_papers_radio = gr.Radio(
            choices=[],  # will be updated with the button click
            visible=False,  # also will be updated with the button click
            label='Selected Top Papers from the Reviewer'
        )
    # selected paper information; these rows stay hidden until a search completes
    with gr.Row(visible=False) as title_row:
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    # The row (not the textbox) controls visibility, like title_row and
    # explain_button_row; it was previously spelled `visibe`, so it never
    # started hidden.
    with gr.Row(visible=False) as abstract_row:
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2):  # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3):  # highlighted text from paper
            # NOTE(review): Interpretation exists in Gradio 3.x; confirm the
            # pinned Gradio version still provides it.
            highlight = gr.components.Interpretation(paper_abstract)

    ### EVENT LISTENERS
    # retrieve similar papers; outputs follow the order of get_similar_paper's
    # return tuple (paper radio, sentences, title row, abstract row, button row)
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            selected_papers_radio,
            source_sentences,
            title_row,
            abstract_row,
            explain_button_row,
        ]
    )
    # get highlights
    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )
    # change highlight based on selected sentences from submission
    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )
    # change paper to show based on selected papers
    selected_papers_radio.change(
        fn=change_paper,
        inputs=selected_papers_radio,
        outputs=[
            paper_title,
            paper_abstract,
            affinity
        ]
    )

if __name__ == "__main__":
    demo.launch()