File size: 6,029 Bytes
ac046a6
 
 
 
 
 
 
 
37ea0ab
 
 
 
52e1cb2
378efa2
37ea0ab
52e1cb2
 
37ea0ab
 
52e1cb2
378efa2
 
 
37ea0ab
 
52e1cb2
ac046a6
52e1cb2
 
378efa2
 
4739b3b
378efa2
52e1cb2
 
4739b3b
378efa2
52e1cb2
 
 
378efa2
 
 
 
 
 
 
 
 
 
52e1cb2
1230385
e04f1e3
 
 
 
 
378efa2
1230385
378efa2
e04f1e3
 
 
1230385
378efa2
 
 
52e1cb2
ac046a6
378efa2
 
 
 
 
 
 
 
 
ac046a6
378efa2
 
 
 
ac046a6
378efa2
52e1cb2
 
 
 
 
 
378efa2
52e1cb2
 
 
378efa2
 
 
52e1cb2
378efa2
 
 
 
52e1cb2
378efa2
 
 
ac046a6
1230385
52e1cb2
ac046a6
378efa2
 
ac046a6
52e1cb2
 
 
378efa2
52e1cb2
378efa2
37ea0ab
ac046a6
 
 
 
 
 
 
378efa2
37ea0ab
 
ac046a6
37ea0ab
 
52e1cb2
 
 
ac046a6
 
 
 
37ea0ab
 
378efa2
52e1cb2
ac046a6
 
378efa2
ac046a6
 
52e1cb2
ac046a6
52e1cb2
 
 
 
37ea0ab
ac046a6
52e1cb2
 
 
ac046a6
37ea0ab
 
 
 
378efa2
 
 
37ea0ab
 
 
378efa2
37ea0ab
378efa2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import re
import cv2
import torch
import numpy as np
import torch.nn as nn
import mmengine

from mmengine.logging.history_buffer import HistoryBuffer
import numpy.core.multiarray as cam
import numpy.dtypes as dtypes

# Checkpoints saved through mmengine pickle non-tensor objects (the
# ``HistoryBuffer`` logging history and numpy scalars/dtypes, see the
# allow-list below). Newer PyTorch releases default ``torch.load`` to
# ``weights_only=True``, which rejects such objects, so ``torch.load`` is
# wrapped process-wide to default to ``weights_only=False``.
#
# SECURITY NOTE(review): ``weights_only=False`` runs the full pickle machinery,
# which can execute arbitrary code embedded in a checkpoint. Checkpoints here
# are downloaded from the Hugging Face Hub -- only load repositories you trust.
_orig_torch_load = torch.load


def _torch_load_default_unsafe(f, *args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only=...`` from the caller is honoured (instead of
    raising ``TypeError: got multiple values for keyword argument``), and
    positional arguments such as ``map_location`` are forwarded unchanged.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(f, *args, **kwargs)


torch.load = _torch_load_default_unsafe

# Allow-list consulted when a load *does* run with ``weights_only=True``.
# NOTE(review): allow-listing the builtin ``getattr`` weakens the protection
# that ``weights_only=True`` is meant to give -- confirm it is really needed.
torch.serialization.add_safe_globals([
    cam.scalar, HistoryBuffer, cam._reconstruct, np.ndarray, np.dtype,
    dtypes.Float64DType, dtypes.Int64DType, getattr,
])

import tempfile
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union

from huggingface_hub import hf_hub_download
from mmengine.config import Config
from mmdet3d.apis import init_model, inference_mono_3d_detector
from model.utils import plot_bev_detections

# Inference device: CUDA whenever PyTorch reports it as available, else CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Selectable models. Each key maps to
#   (local mmdet3d config path, (Hugging Face repo id, checkpoint filename)).
# NOTE(review): the config file names mention r101 / VoVNet backbones while the
# keys and repo ids say RegNetX-4.0GF -- confirm each config's backbone matches
# the checkpoint it is paired with.
MODEL_REGISTRY: Dict[str, Tuple[str, Tuple[str, str]]] = {
    "regnetx4.0gf+detr3d": (
        "model/DETR3D/detr3d_r101_gridmask.py",
        ("yaghi27/RegnetX4.0GF_DETR3D", "epoch_30.pth"),
    ),
    "regnetx4.0gf+petr": (
        "model/PETR/petr_vovnet_gridmask_p4_800x320.py",
        ("yaghi27/RegnetX4.0GF_PETR", "epoch_24.pth"),
    ),
}

# Camera names paired, position by position, with the input images once they
# are sorted by file index (index 0 -> CAM_FRONT ... index 5 -> CAM_BACK_RIGHT).
CAM_ORDER = [
    "CAM_FRONT",
    "CAM_FRONT_LEFT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]

# 4x4 float32 identity matrix, used as a placeholder calibration matrix
# (``cam2img`` / ``lidar2cam``) in the per-image annotations.
DUMMY_4x4 = np.identity(4, dtype=np.float32)


def _download_ckpt_from_hf(
    repo_id: str,
    filename: str,
    cache_dir: Optional[str] = None,
    revision: str = "main",
) -> str:
    """Download ``filename`` from the Hugging Face Hub repo ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        filename: Path of the checkpoint inside the repository.
        cache_dir: Local cache directory. Defaults to ``$HF_HUB_CACHE`` when
            set, otherwise ``/tmp/hf_cache`` (writable in Spaces/containers).
            Created if it does not exist.
        revision: Git revision (branch, tag or commit hash) to download.
            Defaults to ``"main"``; pin a commit hash for reproducible builds.

    Returns:
        Local filesystem path of the downloaded (or already cached) file.

    Raises:
        FileNotFoundError: If the path returned by the Hub client is not an
            existing file.
    """
    cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)

    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
        # Access token taken from the HF_TOKEN environment variable, if set.
        token=os.environ.get("HF_TOKEN"),
    )
    # Defensive check that the download really produced a file on disk.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
    return ckpt_path


@lru_cache(maxsize=2)
def get_model(model_key: str):
    """Build, patch and cache the detector registered under ``model_key``.

    ``lru_cache(maxsize=2)`` memoises the result per key, so repeated calls
    reuse the built model instead of re-downloading the checkpoint and
    rebuilding the network.

    Args:
        model_key: A key of ``MODEL_REGISTRY``.

    Returns:
        The model returned by ``mmdet3d.apis.init_model``, loaded with the Hub
        checkpoint on ``DEVICE`` and switched to eval mode.

    Raises:
        ValueError: If ``model_key`` is not a ``MODEL_REGISTRY`` key.
        FileNotFoundError: If the local config file does not exist, or the
            checkpoint download does not produce a file.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_key}'. Available: {list(MODEL_REGISTRY)}")

    config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]

    # Fail fast on a missing local config before downloading the checkpoint.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    ckpt_path = _download_ckpt_from_hf(repo_id, hf_file)

    # Turn off pretrained-weight initialisation: every weight comes from the
    # downloaded checkpoint. Only keys the config already defines are
    # overwritten, so no constructor argument the model does not declare gets
    # injected into the model dict.
    cfg = Config.fromfile(config_path)
    model_cfg = cfg.model if hasattr(cfg, "model") else None
    if isinstance(model_cfg, dict):
        if "pretrained" in model_cfg:
            model_cfg["pretrained"] = None
        # The built model exposes its image backbone as ``img_backbone`` (see
        # below), so that config key is handled alongside plain ``backbone``.
        for backbone_key in ("backbone", "img_backbone"):
            backbone_cfg = model_cfg.get(backbone_key)
            if isinstance(backbone_cfg, dict) and "init_cfg" in backbone_cfg:
                backbone_cfg["init_cfg"] = None

    # Build from the edited Config object; passing ``config_path`` here would
    # re-read the file and silently discard the overrides above.
    model = init_model(cfg, ckpt_path, device=DEVICE)
    model.eval()

    # Some backbones expect a batch dimension; enforce it.
    backbone = getattr(model, "img_backbone", None)
    if backbone is not None and hasattr(backbone, "forward"):
        original_forward = backbone.forward

        def _ensure_batch(x, *args, **kwargs):
            # A 3-D input is treated as a single unbatched image.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            return original_forward(x, *args, **kwargs)

        backbone.forward = _ensure_batch

    return model


def _index_from_name(path: str) -> int:
    """
    Extract 0..5 from filenames like `cam_0.png` written by the backend.
    Falls back to natural sort order if no index found.
    """
    m = re.search(r'(\d+)(?=\.[^.]+$)', os.path.basename(path))
    return int(m.group(1)) if m else 9999


def infer_single(model, img_path: str, cam_key: str):
    """Run single-frame monocular 3D detection on one image.

    ``inference_mono_3d_detector`` takes an annotation-file path, so a
    one-entry ``data_list`` describing ``img_path`` is dumped to a private
    temporary ``.pkl`` file, the detector is run, and the file is removed.

    Args:
        model: Detector, e.g. as returned by ``get_model``.
        img_path: Path of the image to run on.
        cam_key: Camera name; used as the key under ``images`` in the
            annotation and passed through as ``cam_type``.

    Returns:
        The result of ``inference_mono_3d_detector`` for this image.
    """
    data_info = {
        'images': {
            cam_key: {
                'img_path': img_path,
                # NOTE(review): identity placeholders, not real calibration;
                # metric box positions need the camera's true intrinsics and
                # extrinsics -- confirm this is acceptable here.
                'cam2img': DUMMY_4x4,
                'lidar2cam': DUMMY_4x4
            }
        },
        # NOTE(review): fixed placeholder, not read from the image itself.
        'img_shape': (800, 450, 3),  # adjust if your inputs differ
        'scale_factor': 1.0
    }

    # A unique temp file instead of ``img_path + '.pkl'``: the image's
    # directory may be read-only, and concurrent calls on the same image would
    # otherwise overwrite or delete each other's annotation file.
    fd, ann_file = tempfile.mkstemp(suffix='.pkl')
    os.close(fd)
    try:
        # Dumping inside ``try`` guarantees cleanup even if the dump fails.
        mmengine.dump({'data_list': [data_info]}, ann_file, file_format='pkl')
        result = inference_mono_3d_detector(model, img_path, ann_file, cam_type=cam_key)
    finally:
        try:
            os.remove(ann_file)
        except OSError:
            # Best-effort cleanup; never mask the inference result or error.
            pass
    return result


def infer_images(img_paths: List[str], model: str, save_path: Optional[str] = None) -> List[str]:
    """Run the selected detector on six surround-view images and plot a BEV.

    Args:
        img_paths: Six image paths. They are sorted by the integer right
            before each file's extension (``cam_0.png`` ... ``cam_5.png``)
            and paired, in that order, with ``CAM_ORDER``: front, front_left,
            front_right, back, back_left, back_right. Paths without such an
            index sort after the indexed ones, keeping their input order.
        model: ``MODEL_REGISTRY`` key, e.g. ``'regnetx4.0gf+detr3d'`` or
            ``'regnetx4.0gf+petr'``.
        save_path: Output path for the rendered BEV image. Defaults to the
            shared ``/tmp/combined_bev.png``; pass a unique path when calls
            may run concurrently so they do not overwrite each other.

    Returns:
        A one-element list holding the path of the rendered BEV image.

    Raises:
        ValueError: If ``img_paths`` does not hold exactly six paths, or if
            ``model`` is not a known key (raised by ``get_model``).
    """
    # Check the cheap precondition before ``get_model``, which may download a
    # checkpoint and build a network. ``list`` keeps iterator inputs working.
    paths = list(img_paths)
    if len(paths) != 6:
        raise ValueError(f"Expected 6 images, got {len(paths)}")

    model_inst = get_model(model)
    ordered = sorted(paths, key=_index_from_name)

    all_boxes, all_scores, all_labels = [], [], []
    for cam_key, img_path in zip(CAM_ORDER, ordered):
        preds = infer_single(model_inst, img_path, cam_key).pred_instances_3d
        # Move to CPU so the plotting helper never receives CUDA tensors.
        all_boxes.append(preds.bboxes_3d.tensor.detach().cpu())
        all_scores.append(preds.scores_3d.detach().cpu())
        all_labels.append(preds.labels_3d.detach().cpu())

    # Exactly six cameras were processed, so the lists are never empty.
    # NOTE(review): boxes from all cameras are concatenated as-is; if the
    # detector returns them in each camera's own frame they do not share a
    # common BEV frame -- confirm against ``plot_bev_detections``.
    boxes = torch.cat(all_boxes, dim=0)
    scores = torch.cat(all_scores, dim=0)
    labels = torch.cat(all_labels, dim=0)

    if save_path is None:
        save_path = os.path.join("/tmp", "combined_bev.png")
    plot_bev_detections(boxes, scores, labels, save_path=save_path)

    return [save_path]


if __name__ == "__main__":
    # Demo entry point: needs cam_0.png ... cam_5.png under /path/to/.
    demo_paths = [f"/path/to/cam_{k}.png" for k in range(6)]
    if not all(os.path.isfile(p) for p in demo_paths):
        print("Provide six images named cam_0..cam_5.png to run the demo.")
    else:
        saved = infer_images(demo_paths, model="regnetx4.0gf+petr")
        print("Saved:", saved)