"""Language grounding utilities.""" import re import nltk import torch from torch import Tensor from transformers import BatchEncoding from vis4d.common.logging import rank_zero_info, rank_zero_warn def find_noun_phrases(caption: str) -> list: """Find noun phrases in a caption using nltk. Args: caption (str): The caption to analyze. Returns: list: List of noun phrases found in the caption. Examples: >>> caption = 'There is two cat and a remote in the picture' >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture'] """ caption = caption.lower() tokens = nltk.word_tokenize(caption) pos_tags = nltk.pos_tag(tokens) grammar = "NP: {
?*+}" cp = nltk.RegexpParser(grammar) result = cp.parse(pos_tags) noun_phrases = [] for subtree in result.subtrees(): if subtree.label() == "NP": noun_phrases.append(" ".join(t[0] for t in subtree.leaves())) return noun_phrases def remove_punctuation(text: str) -> str: """Remove punctuation from a text. Args: text (str): The input text. Returns: str: The text with punctuation removed. """ punctuation = [ "|", ":", ";", "@", "(", ")", "[", "]", "{", "}", "^", "'", '"', "’", "`", "?", "$", "%", "#", "!", "&", "*", "+", ",", ".", ] for p in punctuation: text = text.replace(p, "") return text.strip() def run_ner(caption: str) -> tuple[list[list[int]], list[str]]: """Run NER on a caption and return the tokens and noun phrases. Args: caption (str): The input caption. Returns: Tuple[List, List]: A tuple containing the tokens and noun phrases. - tokens_positive (List): A list of token positions. - noun_phrases (List): A list of noun phrases. """ noun_phrases = find_noun_phrases(caption) noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] noun_phrases = [phrase for phrase in noun_phrases if phrase != ""] rank_zero_info("noun_phrases:", noun_phrases) relevant_phrases = noun_phrases labels = noun_phrases tokens_positive = [] for entity, label in zip(relevant_phrases, labels): try: # search all occurrences and mark them as different entities # TODO: Not Robust for m in re.finditer(entity, caption.lower()): tokens_positive.append([[m.start(), m.end()]]) except Exception: rank_zero_warn("noun entities:", noun_phrases) rank_zero_warn("entity:", entity) rank_zero_warn("caption:", caption.lower()) return tokens_positive, noun_phrases def create_positive_map( tokenized: BatchEncoding, tokens_positive: list[list[int]], max_num_entities: int = 256, ) -> Tensor: """construct a map such that positive_map[i,j] = True if box i is associated to token j Args: tokenized: The tokenized input. tokens_positive (list): A list of token ranges associated with positive boxes. max_num_entities (int, optional): The maximum number of entities. Defaults to 256. Returns: torch.Tensor: The positive map. Raises: Exception: If an error occurs during token-to-char mapping. """ positive_map = torch.zeros( (len(tokens_positive), max_num_entities), dtype=torch.float ) for j, tok_list in enumerate(tokens_positive): for beg, end in tok_list: try: beg_pos = tokenized.char_to_token(beg) end_pos = tokenized.char_to_token(end - 1) except Exception as e: print("beg:", beg, "end:", end) print("token_positive:", tokens_positive) raise e if beg_pos is None: try: beg_pos = tokenized.char_to_token(beg + 1) if beg_pos is None: beg_pos = tokenized.char_to_token(beg + 2) except Exception: beg_pos = None if end_pos is None: try: end_pos = tokenized.char_to_token(end - 2) if end_pos is None: end_pos = tokenized.char_to_token(end - 3) except Exception: end_pos = None if beg_pos is None or end_pos is None: continue assert beg_pos is not None and end_pos is not None positive_map[j, beg_pos : end_pos + 1].fill_(1) return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) def create_positive_map_label_to_token( positive_map: Tensor, plus: int = 0 ) -> dict: """Create a dictionary mapping the label to the token. Args: positive_map (Tensor): The positive map tensor. plus (int, optional): Value added to the label for indexing. Defaults to 0. Returns: dict: The dictionary mapping the label to the token. """ positive_map_label_to_token = {} for i in range(len(positive_map)): positive_map_label_to_token[i + plus] = torch.nonzero( positive_map[i], as_tuple=True )[0].tolist() return positive_map_label_to_token def clean_label_name(name: str) -> str: """Clean label name.""" name = re.sub(r"\(.*\)", "", name) name = re.sub(r"_", " ", name) name = re.sub(r" ", " ", name) return name def chunks(lst: list, n: int) -> list: """Yield successive n-sized chunks from lst.""" all_ = [] for i in range(0, len(lst), n): data_index = lst[i : i + n] all_.append(data_index) counter = 0 for i in all_: counter += len(i) assert counter == len(lst) return all_