RoyYang0714's picture
feat: Try to build everything locally.
9b33fca
"""Language grounding utilities."""
import re
import nltk
import torch
from torch import Tensor
from transformers import BatchEncoding
from vis4d.common.logging import rank_zero_info, rank_zero_warn
def find_noun_phrases(caption: str) -> list:
    """Extract noun phrases from a caption with an nltk chunk grammar.

    The caption is lower-cased, tokenized and POS-tagged with nltk. A chunk
    is reported wherever the tag sequence matches an optional ``DT`` tag,
    followed by any number of ``JJ*`` tags, followed by one or more ``NN*``
    tags.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: The matched phrases, each rebuilt by joining its tokens with a
            single space, in the order the parser reports them.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture']
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(caption.lower()))

    chunker = nltk.RegexpParser("NP: {<DT>?<JJ.*>*<NN.*>+}")
    parse_tree = chunker.parse(tagged_tokens)

    # Each leaf is a (token, tag) pair; only the token text is kept.
    return [
        " ".join(leaf[0] for leaf in np_tree.leaves())
        for np_tree in parse_tree.subtrees()
        if np_tree.label() == "NP"
    ]
# Deletion table for the punctuation characters stripped by
# ``remove_punctuation``. Built once at import time so every call is a single
# ``str.translate`` pass instead of one ``str.replace`` pass per character.
_PUNCTUATION_TABLE = str.maketrans("", "", "|:;@()[]{}^'\"’`?$%#!&*+,.")


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Every occurrence of one of the characters
    ``| : ; @ ( ) [ ] { } ^ ' " ’ ` ? $ % # ! & * + , .`` is deleted. Other
    characters (for example ``-``, ``_`` and ``/``) are left untouched.
    Leading and trailing whitespace is stripped from the result.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed and outer whitespace stripped.
    """
    return text.translate(_PUNCTUATION_TABLE).strip()
def run_ner(caption: str) -> tuple[list[list[int]], list[str]]:
    """Find noun phrases in a caption and locate them by character offset.

    Noun phrases come from ``find_noun_phrases``, are cleaned with
    ``remove_punctuation`` and dropped when the cleaning leaves them empty.
    Each remaining phrase is then searched literally in the lower-cased
    caption.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[List, List]: A tuple containing the tokens and noun phrases.
            - tokens_positive (List): One entry per occurrence found, each of
              the form ``[[start, end]]`` with ``end`` exclusive. Every
              occurrence of every phrase gets its own entry, so this list is
              not index-aligned with ``noun_phrases``.
              NOTE(review): the entries are nested one level deeper than the
              return annotation states.
            - noun_phrases (List): The cleaned, non-empty noun phrases.
    """
    noun_phrases = [
        cleaned
        for cleaned in map(remove_punctuation, find_noun_phrases(caption))
        if cleaned != ""
    ]
    # One pre-formatted message: a second positional argument would be taken
    # as a %-format argument if the helper forwards to the logging module.
    # NOTE(review): confirm how rank_zero_info treats extra positionals.
    rank_zero_info(f"noun_phrases: {noun_phrases}")

    lowered_caption = caption.lower()
    tokens_positive = []
    for phrase in noun_phrases:
        # The phrase is plain text, not a pattern: escape it so characters
        # that survive remove_punctuation (e.g. a backslash) are matched
        # literally. An escaped pattern always compiles, so no error handling
        # is needed around the search.
        # TODO: Not Robust - a phrase altered by tokenization or punctuation
        # removal may no longer occur verbatim in the caption.
        for match in re.finditer(re.escape(phrase), lowered_caption):
            # search all occurrences and mark them as different entities
            tokens_positive.append([[match.start(), match.end()]])
    return tokens_positive, noun_phrases
def _char_to_token_fallback(
    tokenized: BatchEncoding, char_indices: tuple
) -> "int | None":
    """Return the first non-None token index for ``char_indices``, else None.

    The character indices are tried in order. Any exception raised by
    ``tokenized.char_to_token`` ends the search with ``None`` (best effort).
    """
    try:
        for char_idx in char_indices:
            token_idx = tokenized.char_to_token(char_idx)
            if token_idx is not None:
                return token_idx
    except Exception:
        return None
    return None


def create_positive_map(
    tokenized: BatchEncoding,
    tokens_positive: list[list[int]],
    max_num_entities: int = 256,
) -> Tensor:
    """Build a row-normalized map from entities to their token positions.

    Row ``i`` of the result is non-zero exactly at the token positions covered
    by the character spans of entity ``i``. Each row is divided by its sum
    (plus 1e-6), so a row with matches sums to roughly one and a row without
    matches stays all zero.

    Args:
        tokenized: The tokenized input. Only ``char_to_token`` is used.
        tokens_positive (list): One item per entity, each an iterable of
            ``(beg, end)`` character spans with ``end`` exclusive.
            NOTE(review): this is one nesting level deeper than the
            annotation states.
        max_num_entities (int, optional): Number of columns of the map.
            Defaults to 256.

    Returns:
        torch.Tensor: Float tensor of shape
            ``(len(tokens_positive), max_num_entities)``.

    Raises:
        Exception: Whatever ``char_to_token`` raises for the exact span
            boundaries is re-raised after the offending span is printed.
    """
    positive_map = torch.zeros(
        (len(tokens_positive), max_num_entities), dtype=torch.float
    )
    for row, char_spans in enumerate(tokens_positive):
        for beg, end in char_spans:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception:
                print("beg:", beg, "end:", end)
                print("token_positive:", tokens_positive)
                raise

            # A boundary character may map to no token; probe up to two
            # characters further into the span before giving up.
            if beg_pos is None:
                beg_pos = _char_to_token_fallback(
                    tokenized, (beg + 1, beg + 2)
                )
            if end_pos is None:
                end_pos = _char_to_token_fallback(
                    tokenized, (end - 2, end - 3)
                )

            if beg_pos is None or end_pos is None:
                continue
            positive_map[row, beg_pos : end_pos + 1].fill_(1)

    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def create_positive_map_label_to_token(
    positive_map: Tensor, plus: int = 0
) -> dict:
    """Map every label to the token positions that are non-zero in its row.

    Args:
        positive_map (Tensor): The positive map tensor; row ``i`` belongs to
            label ``i``.
        plus (int, optional): Offset added to the row index to form the
            dictionary key. Defaults to 0.

    Returns:
        dict: ``{row_index + plus: [column indices of non-zero entries]}``,
            with one key per row of ``positive_map``.
    """
    return {
        row_idx + plus: torch.nonzero(row, as_tuple=True)[0].tolist()
        for row_idx, row in enumerate(positive_map)
    }
def clean_label_name(name: str) -> str:
    """Clean label name.

    The steps are applied in order: parenthesized text is removed,
    underscores become spaces, and any run of two or more spaces left behind
    by the first two steps is collapsed into a single space.

    Args:
        name (str): The raw label name, e.g. ``"traffic_light"``.

    Returns:
        str: The cleaned label name. Leading and trailing spaces are kept.
    """
    # NOTE: the pattern is greedy, so everything from the first "(" to the
    # last ")" is removed in one go.
    name = re.sub(r"\(.*\)", "", name)
    name = name.replace("_", " ")
    # Removing "(...)" from "a (b) c" or converting "a__b" leaves doubled
    # spaces; replacing a single space with a single space would not fix that.
    name = re.sub(r" {2,}", " ", name)
    return name
def chunks(lst: list, n: int) -> list:
    """Split ``lst`` into successive chunks of at most ``n`` items.

    Despite the generator-like name, the chunks are returned as a list.

    Args:
        lst (list): The sequence to split. It is sliced, not modified.
        n (int): Maximum chunk size. Must be positive.

    Returns:
        list: The consecutive slices of ``lst`` in order. Every chunk has
            ``n`` items except possibly the last one. An empty ``lst`` gives
            an empty list.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    # Validate explicitly: a negative step would make range() empty and
    # silently drop every element, and asserts are stripped under ``-O``.
    if n <= 0:
        raise ValueError(f"chunk size n must be positive, got {n}")
    return [lst[start : start + n] for start in range(0, len(lst), n)]