import gradio as gr
import os
import pickle

from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

# Project-local helpers. Assumption: these wildcard imports provide sent_tokenize,
# get_text_from_author_id, compute_overall_score, and get_highlight_info used below.
from input_format import *
from score import *
# load document scoring model
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)

# load sentence model
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
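
# For reference, SPECTER embeds a paper by encoding "title [SEP] abstract" and taking the
# first ([CLS]) token embedding (a sketch of the model-card usage; the actual scoring is
# assumed to happen inside score.compute_overall_score):
#   title_abs = [p['title'] + tokenizer.sep_token + (p.get('abstract') or '') for p in papers]
#   inputs = tokenizer(title_abs, padding=True, truncation=True, max_length=512, return_tensors='pt')
#   embeddings = doc_model(**inputs).last_hidden_state[:, 0, :]
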
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    # split the submission abstract into sentences and cache them for later highlighting
    input_sentences = sent_tokenize(abstract_text_input)
    pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))

    # TODO handle pdf file input
    if pdf_file_input is not None:
        name = None
        papers = []
        raise ValueError('Use submission abstract instead.')
    else:
        # get the reviewer's papers from their Semantic Scholar ID
        name, papers = get_text_from_author_id(author_id_input)

    # compute document-level affinity scores for the papers
    titles, abstracts, doc_scores = compute_overall_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=30
    )

    tmp = {
        'titles': titles,
        'abstracts': abstracts,
        'doc_scores': doc_scores
    }
    pickle.dump(tmp, open('tmp_paperinfo.pkl', 'wb'))

    # select the top K papers to show
    titles = titles[:num_papers_show]
    abstracts = abstracts[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]

    # show the top-ranked paper, populate the sentence choices, and reveal the explain button
    return (
        titles[0],
        abstracts[0],
        doc_scores[0],
        gr.update(choices=input_sentences, interactive=True),
        gr.update(visible=True)
    )
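
# Note: the values returned above map positionally onto the outputs of compute_btn.click
# below: paper_title, paper_abstract, affinity, source_sentences, explain_button_row.
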
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    # compute sentence-level and phrase-level affinity scores for the selected paper
    sent_ids, sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )

    input_sentences = sent_tokenize(abstract_text_input)
    num_sents = len(input_sentences)

    # a different set of highlights for each input sentence
    word_scores = dict()
    for i in range(num_sents):
        word_scores[str(i)] = {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }

    tmp = {
        'source_sentences': input_sentences,
        'highlight': word_scores
    }
    pickle.dump(tmp, open('highlight_info.pkl', 'wb'))

    # update the visibility of the radio choices
    return gr.update(visible=True)
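
# For reference, each entry of word_scores above follows the dict format that Gradio's
# Interpretation component renders as per-word highlights (a sketch, assuming the keys
# produced by get_highlight_info):
#   {"original": "<paper abstract>",
#    "interpretation": [("word_1", 0.83), ("word_2", 0.05), ...]}
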
def update_name(author_id_input):
    # update the name of the author based on the ID input
    name, _ = get_text_from_author_id(author_id_input)
    return gr.update(value=name)


def change_output_highlight(source_sent_choice):
    # change the output highlight based on the sentence selected from the submission
    if os.path.exists('highlight_info.pkl'):
        tmp = pickle.load(open('highlight_info.pkl', 'rb'))
        source_sents = tmp['source_sentences']
        highlights = tmp['highlight']
        for i, s in enumerate(source_sents):
            print('changing highlight!')
            if source_sent_choice == s:
                return highlights[str(i)]
    else:
        return
with gr.Blocks() as demo:

    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
                author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    # with gr.Row(visible=False) as reviewer_name_info:
    #     name = gr.Textbox(label='Reviewer Author Name')
    # with gr.Row():
    #     with gr.Tabs():
    #         for tt in range(num_papers_show):
    #             with gr.TabItem('Paper %d' % (tt + 1)):
    # TODO handle multiple papers

    ### PAPER INFORMATION
    with gr.Row():
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    with gr.Row():
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2):  # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3):  # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            paper_title,
            paper_abstract,
            affinity,
            source_sentences,
            explain_button_row
        ]
    )

    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )

    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )

if __name__ == "__main__":
    demo.launch()
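
# To run locally (a sketch; the exact dependency list is an assumption, and input_format.py
# and score.py must be present alongside this file):
#   pip install gradio transformers sentence-transformers torch nltk
#   python app.py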