# ImageToBEV-lightweight / model/run_inference.py
# Hugging Face file-viewer header (not part of the program): uploaded by
# yaghi27, commit e04f1e3 (verified), "Update model/run_inference.py", 6.03 kB.
import os
import re
import cv2
import torch
import numpy as np
import torch.nn as nn
import mmengine
from mmengine.logging.history_buffer import HistoryBuffer
import numpy.core.multiarray as cam
import numpy.dtypes as dtypes
# Keep a handle on the real loader so the wrapper below can delegate to it.
_orig_torch_load = torch.load


def _torch_load_full_pickle(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only=False`` by default.

    All positional and keyword arguments are forwarded unchanged, so callers
    that pass ``map_location`` positionally or set ``weights_only`` themselves
    keep working instead of hitting a ``TypeError``.

    SECURITY NOTE(review): ``weights_only=False`` unpickles arbitrary objects.
    Only load checkpoints from sources you trust.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


# Process-wide patch: every later ``torch.load`` call goes through the wrapper.
torch.load = _torch_load_full_pickle

# Types/callables registered as safe globals for ``weights_only=True`` loads,
# in case a caller explicitly requests that mode.
torch.serialization.add_safe_globals([
    cam.scalar, HistoryBuffer, cam._reconstruct, np.ndarray, np.dtype,
    dtypes.Float64DType, dtypes.Int64DType, getattr,
])
from functools import lru_cache
from typing import List, Tuple, Dict, Union
from huggingface_hub import hf_hub_download
from mmengine.config import Config
from mmdet3d.apis import init_model, inference_mono_3d_detector
from model.utils import plot_bev_detections
# Inference device: the GPU when torch reports one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# model key -> (local config path, (Hugging Face repo id, checkpoint filename))
MODEL_REGISTRY: Dict[str, Tuple[str, Tuple[str, str]]] = {
    "regnetx4.0gf+detr3d": (
        "model/DETR3D/detr3d_r101_gridmask.py",
        ("yaghi27/RegnetX4.0GF_DETR3D", "epoch_30.pth"),
    ),
    "regnetx4.0gf+petr": (
        "model/PETR/petr_vovnet_gridmask_p4_800x320.py",
        ("yaghi27/RegnetX4.0GF_PETR", "epoch_24.pth"),
    ),
}

# Camera key assigned to the i-th image once the inputs have been ordered.
CAM_ORDER = [
    "CAM_FRONT", "CAM_FRONT_LEFT", "CAM_FRONT_RIGHT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]

# Identity matrix used as a placeholder for the calibration entries.
DUMMY_4x4 = np.identity(4, dtype=np.float32)
def _download_ckpt_from_hf(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, None] = None,
    revision: str = "main",
) -> str:
    """Download a checkpoint file from the Hugging Face Hub and return its path.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        filename: Name of the file inside the repository.
        cache_dir: Directory used as the download cache. When omitted, the
            ``HF_HUB_CACHE`` environment variable is used, then
            ``/tmp/hf_cache``.
        revision: Branch, tag or commit to download from. Defaults to
            ``"main"``, which was previously hard-coded.

    Returns:
        Local filesystem path of the downloaded file.

    Raises:
        FileNotFoundError: If the path returned by the Hub client is not a
            file on disk.
    """
    # Default to a writable cache location in Spaces/containers.
    cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Treat a blank HF_TOKEN like an unset one so an empty string is never
    # forwarded as a credential.
    token = os.environ.get("HF_TOKEN") or None

    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
        token=token,
    )
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
    return ckpt_path
@lru_cache(maxsize=2)
def get_model(model_key: str):
    """Load and cache the detector registered under ``model_key``.

    The local config is loaded, any pretrained-weight references found in it
    are cleared, and the model is built from that edited config with the
    checkpoint downloaded from the Hub. Results are memoised by ``lru_cache``
    (up to two entries), so repeated calls with the same key return the same
    model instance.

    Args:
        model_key: A key of ``MODEL_REGISTRY``.

    Returns:
        The model returned by ``init_model``, switched to eval mode. If it has
        an ``img_backbone``, that backbone's ``forward`` is wrapped so that a
        3-D input tensor gets a leading batch dimension.

    Raises:
        ValueError: If ``model_key`` is not in ``MODEL_REGISTRY``.
        FileNotFoundError: If the config file is missing locally, or the
            checkpoint download did not yield a file.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_key}'. Available: {list(MODEL_REGISTRY)}")
    config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]

    # Ensure the local config exists before spending time on a download.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    # Download checkpoint from the Hub.
    ckpt_path = _download_ckpt_from_hf(repo_id, hf_file)

    # Load the config and clear pretrained-weight references. Only entries
    # that already exist are overwritten, so no new keyword is injected into
    # a model/backbone constructor that may not accept it.
    cfg = Config.fromfile(config_path)
    if hasattr(cfg, "model") and isinstance(cfg.model, dict):
        model_cfg = cfg.model
        if model_cfg.get("pretrained") is not None:
            model_cfg["pretrained"] = None
        # "img_backbone" is checked too because the built model exposes its
        # backbone under that attribute name (see the wrapper below).
        for backbone_key in ("backbone", "img_backbone"):
            backbone_cfg = model_cfg.get(backbone_key)
            if isinstance(backbone_cfg, dict) and backbone_cfg.get("init_cfg") is not None:
                backbone_cfg["init_cfg"] = None

    # Build from the edited Config object. Passing the path here would reload
    # the file and silently drop the edits made above.
    model = init_model(cfg, ckpt_path, device=DEVICE)
    model.eval()

    # Some backbones expect a batch dimension; enforce it.
    if hasattr(model, "img_backbone") and hasattr(model.img_backbone, "forward"):
        original_forward = model.img_backbone.forward

        def _ensure_batch(x):
            if x.dim() == 3:
                x = x.unsqueeze(0)
            return original_forward(x)

        model.img_backbone.forward = _ensure_batch

    return model
def _index_from_name(path: str) -> int:
    """
    Return the number sitting directly before the file extension.

    Intended for backend-written names such as `cam_0.png` (-> 0). Only the
    basename is inspected. When no digits precede the extension, or the name
    has no extension, the sentinel 9999 is returned so that such files sort
    after all indexed ones.
    """
    filename = os.path.basename(path)
    match = re.search(r'(\d+)(?=\.[^.]+$)', filename)
    if match is None:
        return 9999
    return int(match.group(1))
def infer_single(model, img_path: str, cam_key: str):
    """Run single-frame mono-3D inference with a provided camera key.

    A one-entry annotation file is written to a unique temporary path, passed
    to ``inference_mono_3d_detector`` and removed afterwards. The temporary
    path avoids writing next to the input image (which may be read-only) and
    avoids name clashes between overlapping calls on the same image.

    Args:
        model: Detector passed through to ``inference_mono_3d_detector``.
        img_path: Path of the image to run on.
        cam_key: Camera key used both as the key under ``'images'`` in the
            annotation and as ``cam_type`` for the inference call.

    Returns:
        Whatever ``inference_mono_3d_detector`` returns for this image.
    """
    import tempfile  # local import: only this function needs it

    data_info = {
        'images': {
            cam_key: {
                'img_path': img_path,
                # Placeholder calibration: identity matrices.
                'cam2img': DUMMY_4x4,
                'lidar2cam': DUMMY_4x4,
            }
        },
        'img_shape': (800, 450, 3),  # adjust if your inputs differ
        'scale_factor': 1.0,
    }

    # mkstemp creates the file atomically with a unique name. Only the path is
    # needed, so the descriptor is closed straight away.
    fd, ann_file = tempfile.mkstemp(suffix='.pkl')
    os.close(fd)
    try:
        # The dump sits inside the try so the file is cleaned up even if it
        # fails part-way through.
        mmengine.dump({'data_list': [data_info]}, ann_file)
        result = inference_mono_3d_detector(model, img_path, ann_file, cam_type=cam_key)
    finally:
        # Best-effort cleanup: a failed removal must not mask the inference
        # result or the original exception.
        try:
            os.remove(ann_file)
        except OSError:
            pass
    return result
def infer_images(img_paths: List[str], model: str) -> List[str]:
    """
    Run per-camera inference on six images and render one combined BEV plot.

    img_paths: list of 6 image paths. They are re-ordered by the numeric index
        in the file name (see `_index_from_name`) and then matched position by
        position to CAM_ORDER (front, front_left, front_right, back,
        back_left, back_right). Names without an index sort last.
    model: 'regnetx4.0gf+detr3d' or 'regnetx4.0gf+petr'
    Returns: list with path(s) to rendered BEV image(s)
    Raises: ValueError if the number of images is not 6 or the model key is
        unknown.
    """
    # Check the cheap precondition first so that a malformed request never
    # triggers a checkpoint download or model build.
    if len(img_paths) != 6:
        raise ValueError(f"Expected 6 images, got {len(img_paths)}")

    model_inst = get_model(model)
    ordered = sorted(img_paths, key=_index_from_name)

    all_boxes, all_scores, all_labels = [], [], []
    for cam_key, img_path in zip(CAM_ORDER, ordered):
        res = infer_single(model_inst, img_path, cam_key)

        # Collect detections
        preds = res.pred_instances_3d
        all_boxes.append(preds.bboxes_3d.tensor)
        all_scores.append(preds.scores_3d)
        all_labels.append(preds.labels_3d)

    # Exactly six result sets exist here, so the lists are never empty and
    # no empty-tensor fallback is needed.
    boxes = torch.cat(all_boxes, dim=0)
    scores = torch.cat(all_scores, dim=0)
    labels = torch.cat(all_labels, dim=0)

    # NOTE(review): fixed output path; overlapping calls overwrite each other.
    combined_path = os.path.join("/tmp", "combined_bev.png")
    plot_bev_detections(boxes, scores, labels, save_path=combined_path)
    return [combined_path]
if __name__ == "__main__":
    # Demo entry point: runs only when all six placeholder images exist.
    example_imgs = [f"/path/to/cam_{i}.png" for i in range(6)]
    have_all_inputs = all(os.path.isfile(p) for p in example_imgs)
    if not have_all_inputs:
        print("Provide six images named cam_0..cam_5.png to run the demo.")
    else:
        out = infer_images(example_imgs, model="regnetx4.0gf+petr")
        print("Saved:", out)