Spaces:
Runtime error
Runtime error
File size: 6,029 Bytes
ac046a6 37ea0ab 52e1cb2 378efa2 37ea0ab 52e1cb2 37ea0ab 52e1cb2 378efa2 37ea0ab 52e1cb2 ac046a6 52e1cb2 378efa2 4739b3b 378efa2 52e1cb2 4739b3b 378efa2 52e1cb2 378efa2 52e1cb2 1230385 e04f1e3 378efa2 1230385 378efa2 e04f1e3 1230385 378efa2 52e1cb2 ac046a6 378efa2 ac046a6 378efa2 ac046a6 378efa2 52e1cb2 378efa2 52e1cb2 378efa2 52e1cb2 378efa2 52e1cb2 378efa2 ac046a6 1230385 52e1cb2 ac046a6 378efa2 ac046a6 52e1cb2 378efa2 52e1cb2 378efa2 37ea0ab ac046a6 378efa2 37ea0ab ac046a6 37ea0ab 52e1cb2 ac046a6 37ea0ab 378efa2 52e1cb2 ac046a6 378efa2 ac046a6 52e1cb2 ac046a6 52e1cb2 37ea0ab ac046a6 52e1cb2 ac046a6 37ea0ab 378efa2 37ea0ab 378efa2 37ea0ab 378efa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import os
import re
import cv2
import torch
import numpy as np
import torch.nn as nn
import mmengine
from mmengine.logging.history_buffer import HistoryBuffer
import numpy.core.multiarray as cam
import numpy.dtypes as dtypes
# NOTE(security): weights_only=False unpickles arbitrary objects, so loading a
# checkpoint can execute code embedded in it. Only load checkpoints from
# repositories you trust (see MODEL_REGISTRY).
_orig_torch_load = torch.load


def _torch_load_allow_pickle(f, *args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    Recent PyTorch releases default to ``weights_only=True``, which rejects the
    non-tensor objects (numpy scalars/arrays, mmengine ``HistoryBuffer``) that
    the allow-list below indicates these checkpoints contain.

    Extra positional arguments (e.g. ``map_location``) are forwarded
    unchanged. An explicit ``weights_only`` passed by the caller is respected
    instead of colliding with the default and raising ``TypeError``.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(f, *args, **kwargs)


torch.load = _torch_load_allow_pickle

# Allow-list the same types for loaders that still run with weights_only=True
# (e.g. callers that pass it explicitly).
# NOTE(security): allow-listing ``getattr`` lets a restricted load reach
# arbitrary attributes, largely defeating the weights_only unpickler — review.
torch.serialization.add_safe_globals([
    cam.scalar, HistoryBuffer, cam._reconstruct, np.ndarray, np.dtype,
    dtypes.Float64DType, dtypes.Int64DType, getattr,
])
from functools import lru_cache
from typing import List, Tuple, Dict, Union
from huggingface_hub import hf_hub_download
from mmengine.config import Config
from mmdet3d.apis import init_model, inference_mono_3d_detector
from model.utils import plot_bev_detections
# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Selectable detectors, keyed by the name exposed to callers:
#   key -> (local mmdet3d config path, (Hub repo id, checkpoint filename))
# NOTE(review): the config filenames mention ResNet-101 / VoVNet while the keys
# and Hub repos say RegNetX-4.0GF; presumably the configs were edited to use
# that backbone — confirm.
MODEL_REGISTRY: Dict[str, Tuple[str, Tuple[str, str]]] = {
    "regnetx4.0gf+detr3d": (
        "model/DETR3D/detr3d_r101_gridmask.py",
        ("yaghi27/RegnetX4.0GF_DETR3D", "epoch_30.pth"),
    ),
    "regnetx4.0gf+petr": (
        "model/PETR/petr_vovnet_gridmask_p4_800x320.py",
        ("yaghi27/RegnetX4.0GF_PETR", "epoch_24.pth"),
    ),
}

# Camera name for each view. After sorting uploads by the index in their
# filename, the i-th image is treated as camera CAM_ORDER[i].
CAM_ORDER = [
    "CAM_FRONT",        # index 0
    "CAM_FRONT_LEFT",   # index 1
    "CAM_FRONT_RIGHT",  # index 2
    "CAM_BACK",         # index 3
    "CAM_BACK_LEFT",    # index 4
    "CAM_BACK_RIGHT",   # index 5
]

# 4x4 float32 identity used as a stand-in for the per-camera cam2img and
# lidar2cam matrices.
# NOTE(review): identity placeholders, not real calibration — confirm intended.
DUMMY_4x4 = np.eye(4, dtype=np.float32)
def _download_ckpt_from_hf(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, None] = None,
    revision: str = "main",
) -> str:
    """Download ``filename`` from the Hugging Face Hub repository ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"yaghi27/RegnetX4.0GF_PETR"``.
        filename: Path of the file inside the repository.
        cache_dir: Download cache directory. Falls back to ``$HF_HUB_CACHE``
            and then to ``/tmp/hf_cache`` (writable in Spaces/containers).
            The directory is created if it does not exist.
        revision: Branch, tag or commit hash to download (default ``"main"``).

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        FileNotFoundError: If the returned path is not a regular file. Errors
            raised by ``hf_hub_download`` itself propagate unchanged.
    """
    cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
        # Treat an empty HF_TOKEN like an unset one rather than sending "".
        token=os.environ.get("HF_TOKEN") or None,
    )
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
    return ckpt_path
@lru_cache(maxsize=2)
def get_model(model_key: str):
    """Build the detector registered under ``model_key`` and load its checkpoint.

    Calls are memoised by ``functools.lru_cache`` (up to two distinct keys).
    Repeated requests for a cached key therefore reuse the loaded model instead
    of rebuilding it.

    Args:
        model_key: A key of ``MODEL_REGISTRY``.

    Returns:
        The detector returned by ``mmdet3d.apis.init_model`` on ``DEVICE``,
        switched to eval mode.

    Raises:
        ValueError: If ``model_key`` is not in ``MODEL_REGISTRY``.
        FileNotFoundError: If the local config file does not exist or the
            checkpoint download did not yield a file.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_key}'. Available: {list(MODEL_REGISTRY)}")
    config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]

    # Check the local config before spending time on the checkpoint download.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    ckpt_path = _download_ckpt_from_hf(repo_id, hf_file)

    # Load the config ourselves to discourage pretrained-weight initialisation;
    # the checkpoint loaded below provides the weights.
    cfg = Config.fromfile(config_path)
    if hasattr(cfg, "model") and isinstance(cfg.model, dict):
        # setdefault only fills missing keys; values already set in the config
        # file are kept.
        cfg.model.setdefault("pretrained", None)
        # NOTE(review): these detectors expose their image backbone as
        # ``img_backbone`` (see the patch below), so this branch may never fire.
        if "backbone" in cfg.model and isinstance(cfg.model["backbone"], dict):
            cfg.model["backbone"].setdefault("init_cfg", None)

    # Pass the edited Config object, not ``config_path``: given a path,
    # init_model re-reads the file and the overrides above would be lost.
    model = init_model(cfg, ckpt_path, device=DEVICE)
    model.eval()

    # Let the image backbone accept an unbatched image tensor by prepending a
    # batch dimension to 3-D inputs.
    backbone = getattr(model, "img_backbone", None)
    if backbone is not None and hasattr(backbone, "forward"):
        original_forward = backbone.forward

        def _ensure_batch(x, *args, **kwargs):
            # Only touch 3-D tensors; anything else is forwarded untouched.
            if isinstance(x, torch.Tensor) and x.dim() == 3:
                x = x.unsqueeze(0)
            return original_forward(x, *args, **kwargs)

        backbone.forward = _ensure_batch

    return model
_TRAILING_INDEX_RE = re.compile(r'(\d+)(?=\.[^.]+$)')


def _index_from_name(path: str) -> int:
    """Return the camera index encoded in an image filename.

    The index is the run of digits sitting directly before the final file
    extension of the basename, e.g. ``/uploads/cam_3.png`` -> 3.

    Names without such digits map to 9999. With a stable sort, they therefore
    come after every indexed file and keep their relative input order.
    """
    found = _TRAILING_INDEX_RE.search(os.path.basename(path))
    if found is None:
        return 9999
    return int(found.group(1))
def infer_single(model, img_path: str, cam_key: str):
    """Run single-frame monocular 3D inference on one camera image.

    ``inference_mono_3d_detector`` reads per-image metadata from an annotation
    file. A one-entry annotation is therefore dumped to a temporary ``.pkl``
    file for the duration of the call and removed afterwards.

    Args:
        model: Detector returned by ``get_model``.
        img_path: Path of the image to run on.
        cam_key: Camera name the image is registered under (see ``CAM_ORDER``).

    Returns:
        The single-image result of ``inference_mono_3d_detector``.
        ``infer_images`` reads ``pred_instances_3d`` from it.
    """
    import tempfile

    data_info = {
        'images': {
            cam_key: {
                'img_path': img_path,
                # NOTE(review): identity placeholders rather than real
                # calibration — confirm this is intended.
                'cam2img': DUMMY_4x4,
                'lidar2cam': DUMMY_4x4,
            }
        },
        'img_shape': (800, 450, 3),  # adjust if your inputs differ
        'scale_factor': 1.0,
    }

    # Use a unique temp file instead of ``img_path + '.pkl'``. This works when
    # the image directory is read-only, and stops concurrent calls on the same
    # image from overwriting or deleting each other's annotation file.
    fd, ann_file = tempfile.mkstemp(suffix='.pkl')
    os.close(fd)
    try:
        # Dump inside the try so a failed write is still cleaned up.
        mmengine.dump({'data_list': [data_info]}, ann_file)
        return inference_mono_3d_detector(model, img_path, ann_file, cam_type=cam_key)
    finally:
        try:
            os.remove(ann_file)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the
            # inference result or an in-flight exception.
            pass
def infer_images(img_paths: List[str], model: str, save_path: Union[str, None] = None) -> List[str]:
    """Run the selected detector on six surround-view images and plot a BEV.

    Args:
        img_paths: The six camera images. They are ordered by the index in
            their filenames (``cam_0.png`` ... ``cam_5.png``, see
            ``_index_from_name``). Position i is then treated as camera
            ``CAM_ORDER[i]``: front, front_left, front_right, back, back_left,
            back_right.
        model: A ``MODEL_REGISTRY`` key, e.g. ``'regnetx4.0gf+detr3d'`` or
            ``'regnetx4.0gf+petr'``.
        save_path: Where to write the rendered BEV image. Defaults to
            ``/tmp/combined_bev.png``. That default is shared, so pass a unique
            path when several requests may run concurrently.

    Returns:
        A one-element list holding the path of the rendered BEV image.

    Raises:
        ValueError: If the number of images differs from ``len(CAM_ORDER)``
            or ``model`` is not a registered key.
    """
    ordered = sorted(img_paths, key=_index_from_name)
    if len(ordered) != len(CAM_ORDER):
        raise ValueError(f"Expected {len(CAM_ORDER)} images, got {len(ordered)}")

    # Validate the request before loading the model, so a bad upload does not
    # trigger a checkpoint download / model build.
    model_inst = get_model(model)

    all_boxes, all_scores, all_labels = [], [], []
    for cam_key, img_path in zip(CAM_ORDER, ordered):
        preds = infer_single(model_inst, img_path, cam_key).pred_instances_3d
        # Move results to host memory so concatenation and plotting do not
        # depend on DEVICE.
        all_boxes.append(preds.bboxes_3d.tensor.cpu())
        all_scores.append(preds.scores_3d.cpu())
        all_labels.append(preds.labels_3d.cpu())

    # NOTE(review): every view is inferred with identity cam2img/lidar2cam, so
    # these boxes are presumably in per-camera coordinates and are merged
    # without a transform into a shared frame — confirm for the BEV plot.
    # One result per camera was appended, so torch.cat never sees an empty list.
    boxes = torch.cat(all_boxes, dim=0)
    scores = torch.cat(all_scores, dim=0)
    labels = torch.cat(all_labels, dim=0)

    if save_path is None:
        save_path = os.path.join("/tmp", "combined_bev.png")
    plot_bev_detections(boxes, scores, labels, save_path=save_path)
    return [save_path]
if __name__ == "__main__":
    import sys

    # Usage: python <this file> cam_0.png cam_1.png ... cam_5.png
    # Without arguments, fall back to placeholder paths (edit them to real files).
    example_imgs = sys.argv[1:] or [f"/path/to/cam_{i}.png" for i in range(6)]
    if len(example_imgs) == len(CAM_ORDER) and all(os.path.isfile(p) for p in example_imgs):
        out = infer_images(example_imgs, model="regnetx4.0gf+petr")
        print("Saved:", out)
    else:
        print("Provide six images named cam_0..cam_5.png (as command-line arguments) to run the demo.")
|