Commit a5b4e6b
Parent(s): 6bd742d
Upload 106 files

Files changed:
- app.py +2 -3
- demo_all_text_embedding_cache.pth +3 -0
- fcclip/fcclip.py +60 -21
app.py
CHANGED

@@ -27,8 +27,6 @@ from detectron2.data import MetadataCatalog
 from detectron2.projects.deeplab import add_deeplab_config
 
 
-coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
-
 # import FCCLIP project
 from fcclip import add_maskformer2_config, add_fcclip_config
 from demo.predictor import DefaultPredictor, OpenVocabVisualizer
@@ -46,6 +44,7 @@ add_maskformer2_config(cfg)
 add_fcclip_config(cfg)
 cfg.merge_from_file("configs/coco/panoptic-segmentation/fcclip/fcclip_convnext_large_eval_ade20k.yaml")
 os.system("gdown 1-91PIns86vyNaL3CzMmDD39zKGnPMtvj")
+os.system("gdown 1-91PIns86vyNaL3CzMmDD39zKGnPMtvj")
 cfg.MODEL.WEIGHTS = './fcclip_cocopan.pth'
 cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = False
 cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
@@ -160,7 +159,7 @@ def inference(image_path, vocab, label_list):
 
     im = cv2.imread(image_path)
     outputs = predictor(im)
-    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, scale=1.
+    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, scale=1.0, instance_mode=ColorMode.IMAGE)
    panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
    return Image.fromarray(np.uint8(panoptic_result)).convert('RGB')
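Side note, not part of the commit: the two unconditional gdown calls above re-download their files on every Space restart. A minimal sketch of how they could be guarded, assuming the gdown CLI and the file ID shown in the diff; fetch_if_missing is a hypothetical helper, not something in app.py:

import os

def fetch_if_missing(drive_id, dest):
    # Skip the Google Drive download when the file is already on disk,
    # so repeated app restarts do not pull the checkpoint again.
    if not os.path.exists(dest):
        os.system(f"gdown {drive_id} -O {dest}")

fetch_if_missing("1-91PIns86vyNaL3CzMmDD39zKGnPMtvj", "./fcclip_cocopan.pth")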
demo_all_text_embedding_cache.pth
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ee4c83884a03f41e1078a5b0916f6a26606258c0031e4e22e74c93c6672e9c9
+size 7848107
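The diff shows only the Git LFS pointer; the actual payload (about 7.8 MB) is a torch-serialized dict mapping class-name strings to text embeddings of shape (C,), as the fcclip.py change below loads it. A quick inspection sketch, assuming the file has been pulled through LFS:

import torch

# Load the cache on CPU and peek at one entry.
cache = torch.load("demo_all_text_embedding_cache.pth", map_location="cpu")
print(len(cache), "cached class names")
name, emb = next(iter(cache.items()))
print(name, tuple(emb.shape))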
fcclip/fcclip.py
CHANGED

@@ -18,6 +18,7 @@ from .modeling.matcher import HungarianMatcher
 
 
 from .modeling.transformer_decoder.fcclip_transformer_decoder import MaskPooling, get_classification_logits
+import os
 VILD_PROMPT = [
     "a photo of a {}.",
     "This is a photo of a {}",
@@ -35,6 +36,20 @@ VILD_PROMPT = [
     "There is a large {} in the scene.",
 ]
 
+def split_labels(x):
+    res = []
+    for x_ in x:
+        x_ = x_.replace(', ', ',')
+        x_ = x_.split(',') # there can be multiple synonyms for single class
+        res.append(x_)
+    return res
+
+def fill_all_templates_ensemble(x_=''):
+    res = []
+    for x in x_:
+        for template in VILD_PROMPT:
+            res.append(template.format(x))
+    return res, len(res) // len(VILD_PROMPT)
 
 @META_ARCH_REGISTRY.register()
 class FCCLIP(nn.Module):
@@ -129,14 +144,15 @@ class FCCLIP(nn.Module):
         _, self.train_num_templates, self.train_class_names = self.prepare_class_names_from_metadata(train_metadata, train_metadata)
         self.category_overlapping_mask, self.test_num_templates, self.test_class_names = self.prepare_class_names_from_metadata(test_metadata, train_metadata)
 
+        self.demo_all_text_embedding_cache = {}
+        # This consists of COCO, ADE20K, LVIS
+        if os.path.exists("demo_all_text_embedding_cache.pth"):
+            # key: str of class name, value: tensor in shape of C
+            self.demo_all_text_embedding_cache = torch.load("demo_all_text_embedding_cache.pth", map_location=self.device)
+            self.demo_all_text_embedding_cache = {k:v.to(self.device) for k,v in self.demo_all_text_embedding_cache.items()}
+
+
     def prepare_class_names_from_metadata(self, metadata, train_metadata):
-        def split_labels(x):
-            res = []
-            for x_ in x:
-                x_ = x_.replace(', ', ',')
-                x_ = x_.split(',') # there can be multiple synonyms for single class
-                res.append(x_)
-            return res
         # get text classifier
         try:
             class_names = split_labels(metadata.stuff_classes) # it includes both thing and stuff
@@ -152,13 +168,6 @@ class FCCLIP(nn.Module):
             category_overlapping_list.append(is_overlapping)
         category_overlapping_mask = torch.tensor(
             category_overlapping_list, dtype=torch.long)
-
-        def fill_all_templates_ensemble(x_=''):
-            res = []
-            for x in x_:
-                for template in VILD_PROMPT:
-                    res.append(template.format(x))
-            return res, len(res) // len(VILD_PROMPT)
 
         num_templates = []
         templated_class_names = []
@@ -195,17 +204,47 @@ class FCCLIP(nn.Module):
             return self.train_text_classifier, self.train_num_templates
         else:
             if self.test_text_classifier is None:
+                try:
+                    nontemplated_class_names = split_labels(self.test_metadata.stuff_classes) # it includes both thing and stuff
+                except:
+                    # this could be for insseg, where only thing_classes are available
+                    nontemplated_class_names = split_labels(self.test_metadata.thing_classes)
+
+                text2classifier = {}
+                test_class_names = []
+                uncached_class_name = []
                 text_classifier = []
+                # exclude those already in cache
+                for class_names in nontemplated_class_names:
+                    for class_name in class_names:
+                        if class_name in self.demo_all_text_embedding_cache:
+                            text2classifier[class_name] = self.demo_all_text_embedding_cache[class_name].to(self.device)
+                        else:
+                            test_class_names += fill_all_templates_ensemble([class_name])[0]
+                            uncached_class_name.append(class_name)
+                print("Uncached texts:", len(uncached_class_name), uncached_class_name, test_class_names)
                 # this is needed to avoid oom, which may happen when num of class is large
                 bs = 128
-                for idx in range(0, len(self.test_class_names), bs):
-                    text_classifier.append(self.backbone.get_text_classifier(self.test_class_names[idx:idx+bs], self.device).detach())
-                text_classifier = torch.cat(text_classifier, dim=0)
+                for idx in range(0, len(test_class_names), bs):
+                    text_classifier.append(self.backbone.get_text_classifier(test_class_names[idx:idx+bs], self.device).detach())
+
+                if len(text_classifier) > 0:
+                    text_classifier = torch.cat(text_classifier, dim=0)
+                    # average across templates and normalization.
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    assert text_classifier.shape[0] == len(uncached_class_name)
+                    for idx in range(len(uncached_class_name)):
+                        self.demo_all_text_embedding_cache[uncached_class_name[idx]] = text_classifier[idx]
+                        text2classifier[uncached_class_name[idx]] = text_classifier[idx]
+                    #torch.save({k:v for k, v in self.demo_all_text_embedding_cache.items()}, "demo_all_text_embedding_cache.pth")
 
-                # average across templates and normalization.
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
-                text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                text_classifier = []
+                for class_names in nontemplated_class_names:
+                    for text in class_names:
+                        text_classifier.append(text2classifier[text].to(self.device))
+                text_classifier = torch.stack(text_classifier, dim=0).to(self.device)
                 self.test_text_classifier = text_classifier
             return self.test_text_classifier, self.test_num_templates
 
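Taken together, the fcclip.py changes memoize per-class text embeddings for the demo: class names already present in demo_all_text_embedding_cache are reused directly, only the uncached ones are expanded with the VILD prompt ensemble and pushed through the CLIP text encoder in batches, and the final classifier is assembled from the cache. A minimal standalone sketch of that caching pattern, with encode_texts standing in for self.backbone.get_text_classifier (a hypothetical callable, not part of the commit):

import torch

VILD_PROMPT = ["a photo of a {}.", "There is a {} in the scene."]  # trimmed ensemble for illustration

def build_text_classifier(class_names, encode_texts, cache, device="cpu"):
    # class_names: list of class-name strings.
    # encode_texts: callable mapping a list of prompt strings to an (N, C) tensor.
    # cache: dict of class name -> (C,) embedding, updated in place.
    uncached = [c for c in class_names if c not in cache]
    if uncached:
        prompts = [t.format(c) for c in uncached for t in VILD_PROMPT]
        emb = encode_texts(prompts)                                     # (len(uncached) * T, C)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        emb = emb.reshape(len(uncached), len(VILD_PROMPT), -1).mean(1)  # average over templates
        emb = emb / emb.norm(dim=-1, keepdim=True)                      # re-normalize after averaging
        for name, e in zip(uncached, emb):
            cache[name] = e
    # Stack embeddings in the requested order, whether cached or freshly encoded.
    return torch.stack([cache[c] for c in class_names]).to(device)

On a second call with the same class list, encode_texts is never invoked, which is what keeps the demo responsive when the vocabulary is unchanged between runs.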