"""Language grounding utilities."""
import re
import nltk
import torch
from torch import Tensor
from transformers import BatchEncoding
from vis4d.common.logging import rank_zero_info, rank_zero_warn


def find_noun_phrases(caption: str) -> list[str]:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list[str]: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == "NP":
            noun_phrases.append(" ".join(t[0] for t in subtree.leaves()))
    return noun_phrases
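
# Usage note (a sketch; exact resource names can vary across nltk versions):
# `find_noun_phrases` needs the nltk tokenizer and POS tagger data to be
# downloaded once before first use, e.g.:
#
#     import nltk
#     nltk.download("punkt")
#     nltk.download("averaged_perceptron_tagger")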


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        "|",
        ":",
        ";",
        "@",
        "(",
        ")",
        "[",
        "]",
        "{",
        "}",
        "^",
        "'",
        '"',
        "’",
        "`",
        "?",
        "$",
        "%",
        "#",
        "!",
        "&",
        "*",
        "+",
        ",",
        ".",
    ]
    for p in punctuation:
        text = text.replace(p, "")
    return text.strip()
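
# Example (illustrative): remove_punctuation("a remote!") -> "a remote"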


def run_ner(caption: str) -> tuple[list[list[list[int]]], list[str]]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        tuple[list, list]: A tuple containing the tokens and noun phrases.
            - tokens_positive (list): One list of [start, end) character
              spans per matched noun phrase.
            - noun_phrases (list): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != ""]
    rank_zero_info(f"noun_phrases: {noun_phrases}")

    tokens_positive = []
    for entity in noun_phrases:
        try:
            # Search all occurrences and mark them as different entities.
            # The phrase is escaped so it is matched literally rather than
            # as a regex pattern.
            # TODO: Not robust, e.g. when one phrase is a substring of
            # another.
            for m in re.finditer(re.escape(entity), caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            rank_zero_warn(f"noun entities: {noun_phrases}")
            rank_zero_warn(f"entity: {entity}")
            rank_zero_warn(f"caption: {caption.lower()}")
    return tokens_positive, noun_phrases
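
# Usage sketch (values illustrative): for the caption "a cat and a dog",
# `run_ner` returns one list of [start, end) character spans per phrase,
# indexing into the lowercased caption:
#
#     tokens_positive, noun_phrases = run_ner("a cat and a dog")
#     # tokens_positive == [[[0, 5]], [[10, 15]]]
#     # noun_phrases == ["a cat", "a dog"]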


def create_positive_map(
    tokenized: BatchEncoding,
    tokens_positive: list[list[list[int]]],
    max_num_entities: int = 256,
) -> Tensor:
    """Construct a map such that positive_map[i, j] = True if box i is
    associated to token j.

    Args:
        tokenized (BatchEncoding): The tokenized input.
        tokens_positive (list): A list of character spans associated with
            positive boxes.
        max_num_entities (int, optional): The maximum number of entities.
            Defaults to 256.

    Returns:
        Tensor: The positive map.

    Raises:
        Exception: If an error occurs during token-to-char mapping.
    """
    positive_map = torch.zeros(
        (len(tokens_positive), max_num_entities), dtype=torch.float
    )

    for j, tok_list in enumerate(tokens_positive):
        for beg, end in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception as e:
                rank_zero_warn(f"beg: {beg}, end: {end}")
                rank_zero_warn(f"token_positive: {tokens_positive}")
                raise e

            # A span boundary may fall on a character (e.g. whitespace)
            # that no token maps to; probe the neighboring characters.
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos : end_pos + 1].fill_(1)

    # Normalize each row so the weights over matched tokens sum to one.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
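
# Sketch of how the pieces combine (assumes any HuggingFace fast tokenizer
# exposing `char_to_token`, e.g. "bert-base-uncased"; not executed here):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     caption = "a cat and a dog"
#     tokens_positive, _ = run_ner(caption)
#     tokenized = tokenizer(caption, return_tensors="pt")
#     positive_map = create_positive_map(tokenized, tokens_positive)
#     # positive_map[i] is a normalized indicator over the subword tokens
#     # covered by the i-th noun phrase.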


def create_positive_map_label_to_token(
    positive_map: Tensor, plus: int = 0
) -> dict[int, list[int]]:
    """Create a dictionary mapping the label to the token.

    Args:
        positive_map (Tensor): The positive map tensor.
        plus (int, optional): Value added to the label for indexing.
            Defaults to 0.

    Returns:
        dict[int, list[int]]: The dictionary mapping the label to the token.
    """
    positive_map_label_to_token = {}
    for i in range(len(positive_map)):
        positive_map_label_to_token[i + plus] = torch.nonzero(
            positive_map[i], as_tuple=True
        )[0].tolist()
    return positive_map_label_to_token
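
# Example (values illustrative): if row 0 of positive_map is nonzero at
# token positions 1 and 2, create_positive_map_label_to_token(positive_map,
# plus=1) contains {1: [1, 2]}; plus=1 shifts labels to be 1-based, e.g.
# when label 0 is reserved for background.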


def clean_label_name(name: str) -> str:
    """Clean label name."""
    name = re.sub(r"\(.*\)", "", name)  # drop parenthesized parts
    name = re.sub(r"_", " ", name)  # underscores to spaces
    name = re.sub(r" +", " ", name)  # collapse repeated spaces
    return name
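
# Example: clean_label_name("traffic_light") -> "traffic light"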


def chunks(lst: list, n: int) -> list[list]:
    """Return successive n-sized chunks from lst."""
    all_ = []
    for i in range(0, len(lst), n):
        data_index = lst[i : i + n]
        all_.append(data_index)

    counter = 0
    for i in all_:
        counter += len(i)
    assert counter == len(lst)

    return all_
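
# Example: chunks([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]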