Spaces:
Runtime error
Runtime error
| from sentence_transformers import util, SentenceTransformer | |
| from transformers import BertModel | |
| from nltk.tokenize import sent_tokenize | |
| from nltk import word_tokenize, pos_tag | |
| import torch | |
| import numpy as np | |
| import tqdm | |
def compute_sentencewise_scores(model, query_sents, candidate_sents, tokenizer=None):
    """
    Score every query sentence against every candidate sentence.

    Args:
        model: either a ``SentenceTransformer`` or a ``BertModel``.
        query_sents: list of query sentences.
        candidate_sents: list of candidate sentences.
        tokenizer: tokenizer matching ``model``; required for the ``BertModel``
            branch, unused for ``SentenceTransformer``.

    Returns:
        The result of ``util.cos_sim`` on the query and candidate embeddings.

    Raises:
        ValueError: if the model type is not supported, or a ``BertModel`` is
            given without a tokenizer.
    """
    if isinstance(model, SentenceTransformer):
        # SentenceTransformer-style model: it embeds the raw sentences itself
        q_v, c_v = get_embedding(model, query_sents, candidate_sents)
    elif isinstance(model, BertModel):
        # BERT-style model from the transformers library
        if tokenizer is None:
            raise ValueError('a tokenizer is required for BERT-style models')
        inputs = tokenizer(
            query_sents + candidate_sents,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        )
        inputs.to(model.device)
        # inference only: skip autograd bookkeeping, as predict_docscore does
        with torch.no_grad():
            result = model(**inputs)
        # use the hidden state at position 0 of each sequence as its embedding
        embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
        # queries were placed first in the batch, candidates after them
        q_v = embeddings[:len(query_sents)]
        c_v = embeddings[len(query_sents):]
    else:
        raise ValueError('model not supported at the time')

    # internal sanity checks on the embedding shapes
    assert q_v.shape[1] == c_v.shape[1]
    assert q_v.shape[0] == len(query_sents)
    assert c_v.shape[0] == len(candidate_sents)
    return util.cos_sim(q_v, c_v)
def get_embedding(model, query_sents, candidate_sents):
    """
    Encode both sentence lists with the same model.

    ``model.encode`` is called first on the query sentences and then on the
    candidate sentences; the two results are returned as a
    ``(query_embeddings, candidate_embeddings)`` pair.
    """
    return model.encode(query_sents), model.encode(candidate_sents)
def get_top_k(score_mat, K=3):
    """
    Rank the columns of each row by score and keep the K best.

    Returns a pair ``(indices, scores)``: for every row of ``score_mat`` the
    column indices of the K highest scores and the scores themselves, both
    ordered from highest to lowest.
    """
    # sorting the negated matrix in ascending order ranks the original
    # scores from highest to lowest
    neg_sorted, order = torch.sort(-score_mat, dim=1)
    top_ids = order[:, :K]
    top_scores = -neg_sorted[:, :K]
    return top_ids, top_scores
def get_words(sent):
    """
    Tokenize a list of sentences into words.

    Input: list of sentences
    Output:
        words: list of word lists, one per sentence
        all_words: every word of every sentence in one flat list
        sent_start_id: for each sentence, the index in ``all_words`` of its
            first word

    An empty input yields three empty lists.
    """
    # tokenize each sentence exactly once
    words = [word_tokenize(x) for x in sent]
    all_words = [item for sublist in words for item in sublist]

    # running offset: each sentence starts where the previous one ended
    sent_start_id = []
    offset = 0
    for sentence_words in words:
        sent_start_id.append(offset)
        offset += len(sentence_words)
    return words, all_words, sent_start_id
def get_match_phrase(w1, w2, method='pos'):
    """
    Find words shared between a query and a candidate text.

    Input: list of words for query (``w1``) and candidate (``w2``) text
    Output: ``(mask1, mask2)``, float arrays of length ``len(w1)`` and
        ``len(w2)`` holding 1 at matched positions and 0 elsewhere.

    With ``method='pos'`` a candidate word matches when its POS tag is in the
    ``include`` set below and it equals a query word ignoring case; every
    query position holding that word is marked too. Only the candidate side
    is POS-filtered. Any other ``method`` returns all-zero masks.
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    if method == 'pos':
        # POS tags that should be considered for matching phrase
        include = {
            'NN',
            'NNS',
            'NNP',
            'NNPS',
            'LS',
            'SYM',
            'FW',
        }
        # lower-cased query word -> all positions where it occurs, so each
        # candidate word is resolved with one lookup instead of a full scan
        query_positions = {}
        for j, w_ in enumerate(w1):
            query_positions.setdefault(w_.lower(), []).append(j)

        # only the candidate tags are consulted, so only w2 is tagged
        for i, (w, p) in enumerate(pos_tag(w2)):
            if p not in include:
                continue
            hits = query_positions.get(w.lower())
            if hits:
                mask2[i] = 1
                mask1[hits] = 1
    return mask1, mask2
def remove_spaces(words, attrs):
    """
    Merge punctuation tokens back onto their neighbouring words.

    Makes tokenizer output more readable by removing the spacing around:
      1. parentheses
      2. single/double quotations
      3. commas, periods and similar trailing punctuation
      4. possessive quotations

    Args:
        words: list of tokens.
        attrs: per-token attribute values, same length as ``words``.

    Returns:
        ``(word_out, attr_out)``: the merged words and one attribute per
        merged word. A token glued onto the preceding word keeps that word's
        attribute; a token glued onto the following word takes the following
        word's attribute.

    A closing token with nothing before it, or an opening token with nothing
    after it, is emitted on its own instead of raising ``IndexError``.
    The tokens `` and '' are rewritten to a plain double quote.
    """
    assert len(words) == len(attrs)

    # tokens that stick to the word that appears right before
    attach_prev = {',', '.', '%', ')', ':', '?', ';', "'s", '”'}
    # tokens that stick to the word that appears right after
    attach_next = {'(', '“'}

    word_out, attr_out = [], []
    n = len(words)

    def glue_to_previous(idx, text):
        # append text to the last emitted word; emit it alone if there is none
        if word_out:
            word_out[-1] += text
        else:
            word_out.append(text)
            attr_out.append(attrs[idx])
        return idx + 1

    def glue_to_next(idx, text):
        # prepend text to the following token (consuming it); emit text alone
        # if it is the last token
        if idx + 1 < n:
            word_out.append(text + words[idx + 1])
            attr_out.append(attrs[idx + 1])
            return idx + 2
        word_out.append(text)
        attr_out.append(attrs[idx])
        return idx + 1

    idx, single_q, double_q = 0, 0, 0
    while idx < n:
        token = words[idx]
        if token in attach_prev:
            idx = glue_to_previous(idx, token)
        elif token in attach_next:
            idx = glue_to_next(idx, token)
        elif token == '"':
            double_q += 1
            if double_q == 2:
                # closing quote: stick to word before
                idx = glue_to_previous(idx, token)
                double_q = 0
            else:
                # opening quote: stick to the word after
                idx = glue_to_next(idx, token)
        elif token == "'":
            single_q += 1
            if single_q == 2:
                # closing quote: stick to word before
                idx = glue_to_previous(idx, token)
                single_q = 0
            elif idx > 0 and words[idx - 1].endswith('s'):
                # possessive quote: stick to the word before, reset counter
                idx = glue_to_previous(idx, token)
                single_q = 0
            else:
                # opening quote: stick to the word after
                idx = glue_to_next(idx, token)
        elif token == '``':
            # opening quote: stick to the word after, as a real double quote
            idx = glue_to_next(idx, '"')
        elif token == "''":
            # closing quote: stick to the word before, as a real double quote
            idx = glue_to_previous(idx, '"')
        else:
            word_out.append(token)
            attr_out.append(attrs[idx])
            idx += 1

    assert len(word_out) == len(attr_out)
    return word_out, attr_out
def scale_scores(arr, vmin=0.1, vmax=1):
    """
    Linearly rescale scores so the positive entries span [vmin, vmax].

    The smallest positive entry maps to ``vmin`` and the largest to ``vmax``;
    entries equal to 0 stay at 0. The input is not modified.

    Degenerate inputs:
      * no positive entry: a float copy of ``arr`` is returned unchanged;
      * all positive entries equal: they are all set to ``vmax`` (the linear
        map would divide by zero) and the other entries are left as they are.
    """
    arr = np.asarray(arr)
    positive_mask = arr > 0
    if not positive_mask.any():
        return arr.astype(float)

    positives = arr[positive_mask]
    pos_max, pos_min = np.max(positives), np.min(positives)
    if pos_max == pos_min:
        out = arr.astype(float)
        out[positive_mask] = vmax
    else:
        out = (arr - pos_min) / (pos_max - pos_min) * (vmax - vmin) + vmin
    # keep exact zeros at zero; a boolean mask works for any array rank
    out[arr == 0.0] = 0.0
    return out
def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Mark the candidate words to highlight, both by sentence and by phrase.

    Args:
        query_sents: list of query sentences.
        words: candidate words grouped per sentence.
        all_words: flat list of all candidate words.
        sent_start_id: index in ``all_words`` where each candidate sentence
            starts.
        sent_ids: per query sentence, the ranked candidate sentence ids.
        sent_scores: per query sentence, the scores matching ``sent_ids``.

    Returns:
        dict with ``'all_words'`` and ``'words_by_sentence'``, plus for each
        query index ``i`` a dict mapping ``j`` (number of top candidate
        sentences selected, 1-based) to arrays over ``all_words``:
        ``'is_selected_sent'``, ``'is_selected_phrase'`` and ``'scores'``.
    """
    num_query_sent = sent_ids.shape[0]
    num_cand_sent = sent_ids.shape[1]
    num_words = len(all_words)

    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words

    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        # tokenize the query once; it does not depend on j
        query_words = word_tokenize(query_sents[i])

        # The selection for j sentences is the selection for j-1 plus the
        # j-th ranked sentence, so the arrays are grown incrementally and
        # each candidate sentence is matched against the query only once.
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        raw_scores = np.zeros(num_words)

        output[i] = dict()
        for j in range(1, num_cand_sent + 1):
            sid = int(sent_ids[i][j - 1])
            start = sent_start_id[sid]
            # the last sentence runs to the end of the word list
            if sid + 1 < len(sent_start_id):
                end = sent_start_id[sid + 1]
            else:
                end = num_words

            is_selected_sent[start:end] = 1
            raw_scores[start:end] = float(sent_scores[i][j - 1])
            _, is_selected_phrase[start:end] = \
                get_match_phrase(query_words, all_words[start:end])

            # scale the word scores: maximum value gets the darkest,
            # minimum value gets the lightest color
            if j > 1:
                word_scores = scale_scores(raw_scores)
            else:
                word_scores = raw_scores.copy()

            # matched phrases inside selected sentences get a negative score
            # (per the original note, this shows as a different color in gradio)
            word_scores[is_selected_sent + is_selected_phrase == 2] = -0.5

            # store snapshots so later iterations do not alter earlier entries
            output[i][j] = {
                'is_selected_sent': is_selected_sent.copy(),
                'is_selected_phrase': is_selected_phrase.copy(),
                'scores': word_scores
            }
    return output
def get_highlight_info(model, tokenizer, text1, text2, K=None, top_pair_num=5):
    """
    Get highlight information from two texts.

    Args:
        model: sentence-scoring model passed to ``compute_sentencewise_scores``.
        tokenizer: tokenizer passed along with the model.
        text1: query text.
        text2: candidate text.
        K: number of top candidate sentences kept per query sentence; when
            ``None`` all candidate sentences are kept.
        top_pair_num: number of best (query, candidate) sentence pairs to
            report; zero or a negative value reports none.

    Returns:
        ``(sent_ids, sent_scores, info, top_pairs_info)`` where ``info`` is
        the output of ``mark_words`` and ``top_pairs_info`` maps a rank
        (0 = best) to a dict with keys ``'query'``, ``'candidate'``,
        ``'score'`` and ``'sent_idx'``.
    """
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    score_mat = compute_sentencewise_scores(model, sent1, sent2, tokenizer=tokenizer)

    if K is None:  # if K is not set, get all information
        K = score_mat.shape[1]
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)

    # Rank all (query, kept candidate) pairs from best to worst and keep the
    # first top_pair_num. Slicing after the reversal (rather than taking
    # [-top_pair_num:]) makes top_pair_num == 0 select nothing instead of
    # everything.
    flat_scores = np.array(sent_scores).ravel()
    best_first = np.argsort(flat_scores)[::-1][:max(top_pair_num, 0)]
    rows, cols = np.unravel_index(best_first, sent_scores.shape)

    # convert the top pairs to the highlight format used by the
    # Gradio interpretation component
    top_pairs_info = dict()
    sc = 0.5
    for rank, (i, j) in enumerate(zip(rows, cols)):
        score = sent_scores[i, j].item()
        sidq, sidc = i, sent_ids[i, j].item()  # (sent_id_query, sent_id_candidate)

        q_sent = sent1[sidq]
        c_sent = sent2[sidc]
        q_words = word_tokenize(q_sent)
        c_words = word_tokenize(c_sent)

        mask1, mask2 = get_match_phrase(q_words, c_words)
        mask1 *= -sc  # mark matching phrases as blue (-1: darkest)
        mask2 *= -sc  # mark matching phrases as blue
        assert len(mask1) == len(q_words) and len(mask2) == len(c_words)

        # spacing
        q_words, mask1 = remove_spaces(q_words, mask1)
        c_words, mask2 = remove_spaces(c_words, mask2)

        top_pairs_info[rank] = {
            'query': {
                'original': q_sent,
                'interpretation': list(zip(q_words, mask1))
            },
            'candidate': {
                'original': c_sent,
                'interpretation': list(zip(c_words, mask2))
            },
            'score': score,
            'sent_idx': (sidq, sidc)
        }
    return sent_ids, sent_scores, info, top_pairs_info
| ### Document-level operations | |
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Compute a similarity score between the query and each paper.

    Each paper is represented as ``title + ' [SEP] ' + abstract``; papers
    whose title or abstract is ``None`` are skipped. Papers are processed in
    chunks of ``batch``; the query is prepended to every chunk so that it is
    embedded together with the papers it is compared to.

    Returns a list with one cosine-similarity score per kept paper, in input
    order.
    """
    # title + abstract, for the papers that have both
    title_abs = [
        t + ' [SEP] ' + a
        for t, a in zip(titles, abstracts)
        if t is not None and a is not None
    ]
    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))

    scores = []
    with torch.no_grad():
        for step in tqdm.tqdm(range(no_iter)):
            chunk = title_abs[step * batch:(step + 1) * batch]
            # row 0 of the batch is the query, the remaining rows are papers
            inputs = tokenizer(
                [query] + chunk,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs.to(doc_model.device)
            hidden = doc_model(**inputs).last_hidden_state
            # the hidden state at position 0 of each sequence is its embedding
            embeddings = hidden[:, 0, :].detach().cpu().numpy()
            q_emb, p_emb = embeddings[0, :], embeddings[1:, :]
            # cosine similarity = dot product over the product of the norms
            denom = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / denom)

    assert len(scores) == num_docs
    return scores
def compute_document_score(doc_model, tokenizer, query_title, query_abs, papers, batch=5):
    """
    Score papers against a query and return them sorted by score.

    Args:
        doc_model: document model passed to ``predict_docscore``.
        tokenizer: tokenizer passed to ``predict_docscore``.
        query_title: title of the submission; when empty or ``None`` only the
            abstract is used as the query.
        query_abs: abstract of the submission.
        papers: iterable of dicts with keys ``'title'``, ``'abstract'``,
            ``'url'``, ``'year'`` and ``'citationCount'``; papers whose title
            or abstract is ``None`` are dropped.
        batch: batch size passed to ``predict_docscore``.

    Returns:
        ``(titles, abstracts, urls, scores, years, citations)``, six lists
        ordered from the highest score to the lowest.
    """
    # keep only the papers that have both a title and an abstract
    usable = [
        p for p in papers
        if p['title'] is not None and p['abstract'] is not None
    ]
    titles = [p['title'] for p in usable]
    abstracts = [p['abstract'] for p in usable]
    urls = [p['url'] for p in usable]
    years = [p['year'] for p in usable]
    citations = [p['citationCount'] for p in usable]

    if not query_title:
        # no title given (empty or None): query with the abstract alone
        query = query_abs
    else:
        query = query_title + ' [SEP] ' + query_abs  # feed in submission title and abstract

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert len(scores) == len(abstracts)

    # indices of the papers from highest to lowest score
    idx_sorted = np.argsort(scores)[::-1]

    def reorder(values):
        return [values[x] for x in idx_sorted]

    return (
        reorder(titles),
        reorder(abstracts),
        reorder(urls),
        reorder(scores),
        reorder(years),
        reorder(citations),
    )