| """Language grounding utilities.""" | |
| import re | |
| import nltk | |
| import torch | |
| from torch import Tensor | |
| from transformers import BatchEncoding | |
| from vis4d.common.logging import rank_zero_info, rank_zero_warn | |
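
# NOTE: find_noun_phrases below relies on nltk tokenizer/tagger data
# ("punkt" for word_tokenize and "averaged_perceptron_tagger" for
# pos_tag; exact resource names vary across nltk versions). If they are
# missing, they can be fetched once via nltk.download(...); doing so is
# a deployment concern and is not handled by this module.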


def find_noun_phrases(caption: str) -> list[str]:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list[str]: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    # NP = optional determiner, any number of adjectives, 1+ nouns.
    grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == "NP":
            noun_phrases.append(" ".join(t[0] for t in subtree.leaves()))
    return noun_phrases


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        "|", ":", ";", "@", "(", ")", "[", "]", "{", "}", "^",
        "'", '"', "’", "`", "?", "$", "%", "#", "!", "&", "*", "+", ",", ".",
    ]
    for p in punctuation:
        text = text.replace(p, "")
    return text.strip()
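
# Illustrative example (not a doctest):
#     remove_punctuation("a cat!")  # -> "a cat"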


def run_ner(caption: str) -> tuple[list[list[list[int]]], list[str]]:
    """Run NER on a caption and return the character spans and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        tuple[list, list]: A tuple containing the character spans and noun
            phrases.
            - tokens_positive (list): One list of [start, end) character
                spans per detected phrase occurrence.
            - noun_phrases (list): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != ""]
    rank_zero_info(f"noun_phrases: {noun_phrases}")

    tokens_positive = []
    for entity in noun_phrases:
        try:
            # Search all occurrences and mark them as different entities.
            # re.escape guards against regex metacharacters; note that the
            # plain substring search may still match inside longer words.
            for m in re.finditer(re.escape(entity), caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            rank_zero_warn(f"noun entities: {noun_phrases}")
            rank_zero_warn(f"entity: {entity}")
            rank_zero_warn(f"caption: {caption.lower()}")
    return tokens_positive, noun_phrases
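
# Illustrative sketch of the output (POS tags depend on the nltk tagger,
# so the exact phrases below are an assumption, not a doctest):
#     tokens_positive, phrases = run_ner("a cat and a remote")
#     # phrases -> ["a cat", "a remote"]
#     # tokens_positive -> [[[0, 5]], [[10, 18]]]  # [start, end) spans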


def create_positive_map(
    tokenized: BatchEncoding,
    tokens_positive: list[list[list[int]]],
    max_num_entities: int = 256,
) -> Tensor:
    """Construct a map marking which tokens belong to which entity.

    positive_map[i, j] is nonzero if box (entity) i is associated with
    token j.

    Args:
        tokenized (BatchEncoding): The tokenized input.
        tokens_positive (list): One list of [start, end) character spans
            per entity associated with positive boxes.
        max_num_entities (int, optional): The maximum number of entities.
            Defaults to 256.

    Returns:
        Tensor: The positive map of shape (num_entities, max_num_entities),
            with each row normalized to sum to (approximately) one.

    Raises:
        Exception: If an error occurs during char-to-token mapping.
    """
    positive_map = torch.zeros(
        (len(tokens_positive), max_num_entities), dtype=torch.float
    )
    for j, tok_list in enumerate(tokens_positive):
        for beg, end in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception as e:
                rank_zero_warn(f"beg: {beg}, end: {end}")
                rank_zero_warn(f"tokens_positive: {tokens_positive}")
                raise e

            # A boundary character may map to no token (e.g. stripped
            # whitespace); fall back to neighboring characters.
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos : end_pos + 1].fill_(1)
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
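
# Minimal end-to-end sketch (assumes a HuggingFace *fast* tokenizer so
# that char_to_token is available; "bert-base-uncased" is an illustrative
# checkpoint choice, not one mandated by this module):
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     caption = "a cat and a remote"
#     tokens_positive, phrases = run_ner(caption)
#     tokenized = tokenizer(caption, return_tensors="pt")
#     positive_map = create_positive_map(tokenized, tokens_positive)
#     # positive_map[i] is a normalized weighting over the subword tokens
#     # covered by phrase i.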


def create_positive_map_label_to_token(
    positive_map: Tensor, plus: int = 0
) -> dict[int, list[int]]:
    """Create a dictionary mapping each label to its token indices.

    Args:
        positive_map (Tensor): The positive map tensor.
        plus (int, optional): Value added to the label for indexing.
            Defaults to 0.

    Returns:
        dict[int, list[int]]: The dictionary mapping each label to the
            indices of its associated tokens.
    """
    positive_map_label_to_token = {}
    for i in range(len(positive_map)):
        positive_map_label_to_token[i + plus] = torch.nonzero(
            positive_map[i], as_tuple=True
        )[0].tolist()
    return positive_map_label_to_token
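
# Continuing the sketch above, map each phrase label to its token ids.
# GLIP-style grounding heads commonly call this with plus=1 so that label
# 0 stays reserved; treating that as the convention here is an assumption:
#     label_to_token = create_positive_map_label_to_token(
#         positive_map, plus=1
#     )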


def clean_label_name(name: str) -> str:
    """Clean a label name of parentheses, underscores and extra spaces."""
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    name = re.sub(r" +", " ", name)  # collapse repeated spaces
    return name
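
# Illustrative example:
#     clean_label_name("traffic_light(signal)")  # -> "traffic light"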


def chunks(lst: list, n: int) -> list[list]:
    """Split lst into successive n-sized chunks (the last may be shorter)."""
    all_ = []
    for i in range(0, len(lst), n):
        all_.append(lst[i : i + n])

    # Sanity check: chunking must preserve the total number of elements.
    counter = 0
    for chunk in all_:
        counter += len(chunk)
    assert counter == len(lst)
    return all_
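
# Illustrative example:
#     chunks([1, 2, 3, 4, 5], 2)  # -> [[1, 2], [3, 4], [5]]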