import os
import re
import cv2
import torch
import numpy as np
import torch.nn as nn
import mmengine
from mmengine.logging.history_buffer import HistoryBuffer
import numpy.core.multiarray as cam
import numpy.dtypes as dtypes

# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects the numpy
# and mmengine objects pickled inside these mmdet3d checkpoints; restore the old default.
_orig_torch_load = torch.load
torch.load = lambda f, *args, **kwargs: _orig_torch_load(
    f, *args, **{"weights_only": False, **kwargs}
)
torch.serialization.add_safe_globals([
    cam.scalar, HistoryBuffer, cam._reconstruct, np.ndarray, np.dtype,
    dtypes.Float64DType, dtypes.Int64DType, getattr,
])

from functools import lru_cache
from typing import List, Tuple, Dict, Optional
from huggingface_hub import hf_hub_download
from mmengine.config import Config
from mmdet3d.apis import init_model, inference_mono_3d_detector
from model.utils import plot_bev_detections

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# model key -> (local config path, (HF Hub repo, checkpoint filename))
MODEL_REGISTRY: Dict[str, Tuple[str, Tuple[str, str]]] = {
    "regnetx4.0gf+detr3d": (
        "model/DETR3D/detr3d_r101_gridmask.py",
        ("yaghi27/RegnetX4.0GF_DETR3D", "epoch_30.pth"),
    ),
    "regnetx4.0gf+petr": (
        "model/PETR/petr_vovnet_gridmask_p4_800x320.py",
        ("yaghi27/RegnetX4.0GF_PETR", "epoch_24.pth"),
    ),
}

CAM_ORDER = [
    "CAM_FRONT",
    "CAM_FRONT_LEFT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]

DUMMY_4x4 = np.eye(4, dtype=np.float32)
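
# NOTE (assumption): the upstream backend is expected to write the six views as
# cam_0.png .. cam_5.png in the CAM_ORDER above; _index_from_name() below
# recovers that ordering after upload.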

def _download_ckpt_from_hf(repo_id: str, filename: str, cache_dir: Optional[str] = None) -> str:
    """Download a checkpoint from the Hugging Face Hub and return its local path."""
    # Default to a writable cache location in Spaces/containers
    cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision="main",
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
    )
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
    return ckpt_path

@lru_cache(maxsize=None)
def get_model(model_key: str):
    """Load and cache a detector for the selected model."""
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_key}'. Available: {list(MODEL_REGISTRY)}")
    config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]

    # Ensure the local config exists
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    # Download the checkpoint from the Hub
    ckpt_path = _download_ckpt_from_hf(repo_id, hf_file)

    # Load the config and disable automatic pretrained-weight downloads
    cfg = Config.fromfile(config_path)
    if hasattr(cfg, "model") and isinstance(cfg.model, dict):
        cfg.model.setdefault("pretrained", None)
        if "backbone" in cfg.model and isinstance(cfg.model["backbone"], dict):
            cfg.model["backbone"].setdefault("init_cfg", None)

    # Build from the modified config (not the raw path), so the overrides above take effect
    model = init_model(cfg, ckpt_path, device=DEVICE)
    model.eval()

    # Some backbones expect a batch dimension; enforce it
    if hasattr(model, "img_backbone") and hasattr(model.img_backbone, "forward"):
        original_forward = model.img_backbone.forward

        def _ensure_batch(x):
            if x.dim() == 3:
                x = x.unsqueeze(0)
            return original_forward(x)

        model.img_backbone.forward = _ensure_batch

    return model
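
# Example usage (keys come from MODEL_REGISTRY above):
#   model = get_model("regnetx4.0gf+petr")  # downloaded and built once, then cached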

def _index_from_name(path: str) -> int:
    """
    Extract 0..5 from filenames like `cam_0.png` written by the backend.
    Unmatched names sort last and keep their relative input order.
    """
    m = re.search(r'(\d+)(?=\.[^.]+$)', os.path.basename(path))
    return int(m.group(1)) if m else 9999
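
# Example (hypothetical filenames): "cam_3.png" -> 3, "frame_12.jpg" -> 12,
# "front.png" -> 9999 (no trailing index, so it sorts after every indexed file).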

def infer_single(model, img_path: str, cam_key: str):
    """Run single-frame mono-3D inference with a provided camera key."""
    # NOTE: identity matrices are placeholders; supply real cam2img / lidar2cam
    # calibration for metrically meaningful boxes.
    data_info = {
        'images': {
            cam_key: {
                'img_path': img_path,
                'cam2img': DUMMY_4x4,
                'lidar2cam': DUMMY_4x4,
            }
        },
        'img_shape': (800, 450, 3),  # adjust if your inputs differ
        'scale_factor': 1.0,
    }
    # inference_mono_3d_detector reads annotations from a file on disk; write a
    # temporary one next to the image and clean it up afterwards.
    ann_file = img_path + '.pkl'
    mmengine.dump({'data_list': [data_info]}, ann_file)
    try:
        result = inference_mono_3d_detector(model, img_path, ann_file, cam_type=cam_key)
    finally:
        try:
            os.remove(ann_file)
        except OSError:
            pass
    return result

def infer_images(img_paths: List[str], model: str) -> List[str]:
    """
    img_paths: list of 6 images (front, front_left, front_right, back, back_left, back_right)
    model: 'regnetx4.0gf+detr3d' or 'regnetx4.0gf+petr'
    Returns: list with path(s) to the rendered BEV image(s)
    """
    model_inst = get_model(model)

    ordered = sorted(img_paths, key=_index_from_name)
    if len(ordered) != 6:
        raise ValueError(f"Expected 6 images, got {len(ordered)}")

    all_boxes, all_scores, all_labels = [], [], []
    for idx, img_path in enumerate(ordered):
        cam_key = CAM_ORDER[idx]
        res = infer_single(model_inst, img_path, cam_key)
        # Collect detections from this view
        all_boxes.append(res.pred_instances_3d.bboxes_3d.tensor)
        all_scores.append(res.pred_instances_3d.scores_3d)
        all_labels.append(res.pred_instances_3d.labels_3d)

    boxes = torch.cat(all_boxes, dim=0) if len(all_boxes) else torch.empty((0, 7))
    scores = torch.cat(all_scores, dim=0) if len(all_scores) else torch.empty((0,))
    labels = torch.cat(all_labels, dim=0) if len(all_labels) else torch.empty((0,), dtype=torch.long)

    combined_path = os.path.join("/tmp", "combined_bev.png")
    plot_bev_detections(boxes, scores, labels, save_path=combined_path)
    return [combined_path]

if __name__ == "__main__":
    example_imgs = [f"/path/to/cam_{i}.png" for i in range(6)]
    if all(os.path.isfile(p) for p in example_imgs):
        out = infer_images(example_imgs, model="regnetx4.0gf+petr")
        print("Saved:", out)
    else:
        print("Provide six images named cam_0..cam_5.png to run the demo.")