import heapq
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from operator import itemgetter
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
from bert_score import BERTScorer
from nltk import PorterStemmer
from spacy.tokens import Doc, Span
from toolz import itertoolz
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PaddingStrategy


class EmbeddingModel(ABC):

    @abstractmethod
    def embed(
        self,
        sents: List[Span]
    ):
        pass


class ContextualEmbedding(EmbeddingModel):

    def __init__(self, model, tokenizer_name, max_length):
        self.model = model
        self.tokenizer = SpacyHuggingfaceTokenizer(tokenizer_name, max_length)
        self._device = model.device

    def embed(
        self,
        sents: List[Span]
    ):
        encoded_input, special_tokens_masks, token_alignments = self.tokenizer.batch_encode(sents)
        encoded_input = {k: v.to(self._device) for k, v in encoded_input.items()}
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        embeddings = model_output[0].cpu()
        spacy_embs_list = []
        for embs, mask, token_alignment \
                in zip(embeddings, special_tokens_masks, token_alignments):
            mask = torch.tensor(mask)
            embs = embs[mask == 0]  # Filter embeddings at special token positions
            spacy_embs = []
            for hf_idxs in token_alignment:
                if hf_idxs is None:
                    pooled_embs = torch.zeros_like(embs[0])
                else:
                    pooled_embs = embs[hf_idxs].mean(dim=0)  # Pool embeddings that map to the same spacy token
                spacy_embs.append(pooled_embs.numpy())
            spacy_embs = np.stack(spacy_embs)
            spacy_embs = spacy_embs / np.linalg.norm(spacy_embs, axis=-1, keepdims=True)  # Normalize
            spacy_embs_list.append(spacy_embs)
        for embs, sent in zip(spacy_embs_list, sents):
            assert len(embs) == len(sent)
        return spacy_embs_list


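# Illustrative sketch, not part of the original module: one way ContextualEmbedding
# might be driven end to end. The spaCy pipeline name "en_core_web_sm" and the helper
# name _example_contextual_embedding are assumptions; "roberta-large" mirrors the
# checkpoint used by BertscoreAligner below.
def _example_contextual_embedding():
    import spacy
    from transformers import AutoModel

    nlp = spacy.load("en_core_web_sm")
    sents = list(nlp("Summaries should stay faithful to their source documents.").sents)
    model = AutoModel.from_pretrained("roberta-large")
    embedding = ContextualEmbedding(model, "roberta-large", max_length=510)
    # One (num_spacy_tokens, hidden_size) array per sentence, with subword
    # embeddings mean-pooled back onto the spaCy tokens.
    sent_embeddings = embedding.embed(sents)
    print(sent_embeddings[0].shape)

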
class StaticEmbedding(EmbeddingModel):

    def embed(
        self,
        sents: List[Span]
    ):
        return [
            np.stack([t.vector / (t.vector_norm or 1) for t in sent])
            for sent in sents
        ]


class EmbeddingAligner:

    def __init__(
        self,
        embedding: EmbeddingModel,
        threshold: float,
        top_k: int,
        baseline_val=0
    ):
        self.threshold = threshold
        self.top_k = top_k
        self.embedding = embedding
        self.baseline_val = baseline_val

    def align(
        self,
        source: Doc,
        targets: Sequence[Doc]
    ) -> List[Dict]:
        """Compute alignment from summary tokens to doc tokens with greatest semantic similarity

        Args:
            source: Source spaCy document
            targets: Target spaCy documents

        Returns: List of alignments, one for each target document
        """
        if len(source) == 0:
            return [{} for _ in targets]
        all_sents = list(source.sents) + list(itertools.chain.from_iterable(target.sents for target in targets))
        chunk_sizes = [_iter_len(source.sents)] + \
                      [_iter_len(target.sents) for target in targets]
        all_sents_token_embeddings = self.embedding.embed(all_sents)
        chunked_sents_token_embeddings = _split(all_sents_token_embeddings, chunk_sizes)
        source_sent_token_embeddings = chunked_sents_token_embeddings[0]
        source_token_embeddings = np.concatenate(source_sent_token_embeddings)
        for token_idx, token in enumerate(source):
            if token.is_stop or token.is_punct:
                source_token_embeddings[token_idx] = 0
        alignments = []
        for i, target in enumerate(targets):
            target_sent_token_embeddings = chunked_sents_token_embeddings[i + 1]
            target_token_embeddings = np.concatenate(target_sent_token_embeddings)
            for token_idx, token in enumerate(target):
                if token.is_stop or token.is_punct:
                    target_token_embeddings[token_idx] = 0
            alignment = defaultdict(list)
            for score, target_idx, source_idx in self._emb_sim_sparse(
                target_token_embeddings,
                source_token_embeddings,
            ):
                alignment[target_idx].append((source_idx, score))
            # TODO: use argpartition to get nlargest
            for j in list(alignment):
                alignment[j] = heapq.nlargest(self.top_k, alignment[j], itemgetter(1))
            alignments.append(alignment)
        return alignments

    def _emb_sim_sparse(self, embs_1, embs_2):
        sim = embs_1 @ embs_2.T
        sim = (sim - self.baseline_val) / (1 - self.baseline_val)
        keep = sim > self.threshold
        keep_idxs_1, keep_idxs_2 = np.where(keep)
        keep_scores = sim[keep]
        return list(zip(keep_scores, keep_idxs_1, keep_idxs_2))


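# Illustrative sketch, not part of the original module: the sparse-similarity step in
# isolation. Rows of the first matrix stand in for target-token embeddings and rows of
# the second for source-token embeddings (both unit-normalized); only rescaled
# similarities above the threshold come back, as (score, target_idx, source_idx)
# triples. The helper name and the threshold/top_k values are assumptions.
def _example_emb_sim_sparse():
    rng = np.random.default_rng(0)
    target_embs = rng.normal(size=(4, 8))
    source_embs = rng.normal(size=(6, 8))
    target_embs /= np.linalg.norm(target_embs, axis=-1, keepdims=True)
    source_embs /= np.linalg.norm(source_embs, axis=-1, keepdims=True)
    aligner = EmbeddingAligner(StaticEmbedding(), threshold=0.1, top_k=3)
    for score, target_idx, source_idx in aligner._emb_sim_sparse(target_embs, source_embs):
        print(f"target token {target_idx} ~ source token {source_idx}: {score:.2f}")

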
class BertscoreAligner(EmbeddingAligner):

    def __init__(
        self,
        threshold,
        top_k
    ):
        scorer = BERTScorer(lang="en", rescale_with_baseline=True)
        model = scorer._model
        embedding = ContextualEmbedding(model, "roberta-large", 510)
        baseline_val = scorer.baseline_vals[2].item()
        super(BertscoreAligner, self).__init__(
            embedding, threshold, top_k, baseline_val
        )


class StaticEmbeddingAligner(EmbeddingAligner):

    def __init__(
        self,
        threshold,
        top_k
    ):
        embedding = StaticEmbedding()
        super(StaticEmbeddingAligner, self).__init__(
            embedding, threshold, top_k
        )


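# Illustrative sketch, not part of the original module: calling StaticEmbeddingAligner
# on spaCy Docs. "en_core_web_md" (a pipeline with word vectors) and the threshold/top_k
# values are assumptions; BertscoreAligner is a drop-in alternative when contextual
# embeddings are preferred.
def _example_static_alignment():
    import spacy

    nlp = spacy.load("en_core_web_md")
    source = nlp("The quick brown fox jumps over the lazy dog.")
    summary = nlp("A fox jumps over a dog.")
    aligner = StaticEmbeddingAligner(threshold=0.7, top_k=3)
    # One alignment per target Doc: summary-token index -> top-k (source-token index, score).
    alignment = aligner.align(source, [summary])[0]
    for target_idx, matches in alignment.items():
        print(summary[target_idx].text,
              [(source[i].text, round(float(s), 2)) for i, s in matches])

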
class NGramAligner:

    def __init__(self):
        self.stemmer = PorterStemmer()

    def align(
        self,
        source: Doc,
        targets: List[Doc],
    ) -> List[Dict]:
        alignments = []
        source_ngram_spans = self._get_ngram_spans(source)
        for target in targets:
            target_ngram_spans = self._get_ngram_spans(target)
            alignments.append(
                self._align_ngrams(target_ngram_spans, source_ngram_spans)
            )
        return alignments

    def _get_ngram_spans(
        self,
        doc: Doc,
    ):
        ngrams = []
        for sent in doc.sents:
            for n in range(1, len(list(sent))):
                tokens = [t for t in sent if not (t.is_stop or t.is_punct)]
                ngrams.extend(_ngrams(tokens, n))

        def ngram_key(ngram):
            return tuple(self.stemmer.stem(token.text).lower() for token in ngram)

        key_to_ngrams = itertoolz.groupby(ngram_key, ngrams)
        key_to_spans = {}
        for k, grouped_ngrams in key_to_ngrams.items():
            key_to_spans[k] = [
                (ngram[0].i, ngram[-1].i + 1)
                for ngram in grouped_ngrams
            ]
        return key_to_spans

    def _align_ngrams(
        self,
        ngram_spans_1: Dict[Tuple[str, ...], List[Tuple[int, int]]],
        ngram_spans_2: Dict[Tuple[str, ...], List[Tuple[int, int]]]
    ) -> Dict[Tuple[int, int], List[Tuple[int, int]]]:
        """Align ngram spans between two documents

        Args:
            ngram_spans_1: Map from (normalized_token1, normalized_token2, ...) n-gram tuple to a list of token spans
                of format (start_pos, end_pos)
            ngram_spans_2: Same format as above, but for second text

        Returns: map from each (start, end) span in text 1 to list of aligned (start, end) spans in text 2
        """
        if not ngram_spans_1 or not ngram_spans_2:
            return {}
        max_span_end_1 = max(span[1] for span in itertools.chain.from_iterable(ngram_spans_1.values()))
        token_is_available_1 = [True] * max_span_end_1
        matched_keys = list(set(ngram_spans_1.keys()) & set(ngram_spans_2.keys()))  # Matched normalized ngrams between the two texts
        matched_keys.sort(key=len, reverse=True)  # Process n-grams from longest to shortest

        alignment = defaultdict(list)  # Map from each matched span in text 1 to list of aligned spans in text 2
        for key in matched_keys:
            spans_1 = ngram_spans_1[key]
            spans_2 = ngram_spans_2[key]
            available_spans_1 = [span for span in spans_1 if all(token_is_available_1[slice(*span)])]
            matched_spans_1 = []
            if available_spans_1 and spans_2:
                # The ngram can be matched to available spans in both sequences
                for span in available_spans_1:
                    # Newly matched spans may overlap one another, so check that the token positions
                    # are still available (only one span allowed per token in text 1):
                    if all(token_is_available_1[slice(*span)]):
                        matched_spans_1.append(span)
                        token_is_available_1[slice(*span)] = [False] * (span[1] - span[0])
            for span1 in matched_spans_1:
                alignment[span1] = spans_2
        return alignment


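# Illustrative sketch, not part of the original module: NGramAligner maps (start, end)
# token spans in each target Doc to the matching spans in the source Doc, greedily
# matching longer stemmed n-grams first. "en_core_web_sm" and the helper name are
# assumptions.
def _example_ngram_alignment():
    import spacy

    nlp = spacy.load("en_core_web_sm")
    source = nlp("The committee approved the budget on Tuesday.")
    summary = nlp("The budget was approved.")
    alignment = NGramAligner().align(source, [summary])[0]
    for (t_start, t_end), source_spans in alignment.items():
        print(summary[t_start:t_end].text, "->",
              [source[s:e].text for s, e in source_spans])

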
class SpacyHuggingfaceTokenizer:

    def __init__(
        self,
        model_name,
        max_length
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.max_length = max_length

    def batch_encode(
        self,
        sents: List[Span]
    ):
        token_alignments = []
        token_ids_list = []

        # Tokenize each sentence and add special tokens
        for sent in sents:
            hf_tokens, token_alignment = self.tokenize(sent)
            token_alignments.append(token_alignment)
            token_ids = self.tokenizer.convert_tokens_to_ids(hf_tokens)
            encoding = self.tokenizer.prepare_for_model(
                token_ids,
                add_special_tokens=True,
                padding=False,
            )
            token_ids_list.append(encoding['input_ids'])

        # Add padding
        max_length = max(map(len, token_ids_list))
        attention_mask = []
        input_ids = []
        special_tokens_masks = []
        for token_ids in token_ids_list:
            encoding = self.tokenizer.prepare_for_model(
                token_ids,
                padding=PaddingStrategy.MAX_LENGTH,
                max_length=max_length,
                add_special_tokens=False
            )
            input_ids.append(encoding['input_ids'])
            attention_mask.append(encoding['attention_mask'])
            special_tokens_masks.append(
                self.tokenizer.get_special_tokens_mask(
                    encoding['input_ids'],
                    already_has_special_tokens=True
                )
            )
        encoded = {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask)
        }
        return encoded, special_tokens_masks, token_alignments

    def tokenize(
        self,
        sent
    ):
        """Convert spacy sentence to huggingface tokens and compute the alignment"""
        hf_tokens = []
        token_alignment = []
        for i, token in enumerate(sent):
            # "Tokenize" each word individually, so as to track the alignment between spaCy/HF tokens
            # Prefix all tokens with a space except the first one in the sentence
            if i == 0:
                token_text = token.text
            else:
                token_text = ' ' + token.text
            start_hf_idx = len(hf_tokens)
            word_tokens = self.tokenizer.tokenize(token_text)
            end_hf_idx = len(hf_tokens) + len(word_tokens)
            if end_hf_idx < self.max_length:
                hf_tokens.extend(word_tokens)
                hf_idxs = list(range(start_hf_idx, end_hf_idx))
            else:
                hf_idxs = None
            token_alignment.append(hf_idxs)
        return hf_tokens, token_alignment


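# Illustrative sketch, not part of the original module: inspecting the spaCy-to-HuggingFace
# token alignment produced by tokenize(). "roberta-large" mirrors the checkpoint used by
# BertscoreAligner; "en_core_web_sm" and the helper name are assumptions.
def _example_spacy_hf_alignment():
    import spacy

    nlp = spacy.load("en_core_web_sm")
    sent = next(nlp("Transformers tokenize words into subword pieces.").sents)
    tokenizer = SpacyHuggingfaceTokenizer("roberta-large", max_length=510)
    hf_tokens, token_alignment = tokenizer.tokenize(sent)
    # token_alignment[i] lists the HF subword indices covering spaCy token i
    # (None if the token fell beyond max_length).
    for token, hf_idxs in zip(sent, token_alignment):
        print(token.text, [hf_tokens[j] for j in (hf_idxs or [])])

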
def _split(data, sizes):
    it = iter(data)
    return [[next(it) for _ in range(size)] for size in sizes]


def _iter_len(it):
    return sum(1 for _ in it)


# TODO set up batching
# To get top K axis and value per row: https://stackoverflow.com/questions/42832711/using-np-argpartition-to-index-values-in-a-multidimensional-array
def _ngrams(tokens, n):
    for i in range(len(tokens) - n + 1):
        yield tokens[i:i + n]


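# Illustrative sketch, not part of the original module: the behaviour of the module-level
# helpers. _split partitions a flat sequence into chunks of the given sizes, _ngrams yields
# sliding windows of length n, and _iter_len counts the items in an iterator.
def _example_helpers():
    assert _split([1, 2, 3, 4, 5], [2, 3]) == [[1, 2], [3, 4, 5]]
    assert list(_ngrams(["a", "b", "c"], 2)) == [["a", "b"], ["b", "c"]]
    assert _iter_len(iter("abcd")) == 4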