File size: 2,962 Bytes
d612c78
431cc3e
 
 
 
 
 
7a55677
 
 
 
 
 
 
 
 
431cc3e
e1e0b8e
22e9be5
431cc3e
7f2005b
d612c78
7a55677
b7db3a9
d612c78
 
431cc3e
0c39889
431cc3e
 
 
b7db3a9
 
 
74f32c9
b7db3a9
 
 
 
e1e0b8e
b7db3a9
431cc3e
b7db3a9
 
abbb4f7
e1e0b8e
b7db3a9
d612c78
e1e0b8e
431cc3e
 
 
 
 
 
 
 
 
 
 
 
e1e0b8e
 
a163efa
431cc3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a163efa
54180a3
d612c78
431cc3e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import logging
import os
import shutil
import tempfile
import base64
from typing import List

# Redirect Hugging Face Hub and matplotlib cache/config locations into /tmp so
# the app can run where the home directory is read-only (e.g. a HF Space
# container). setdefault keeps any values already provided by the environment.
# NOTE: this must run before the HF-dependent imports below pick the paths up.
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_cache")
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("HF_HUB_ENABLE_XET", "0")  # <-- disable xet backend
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")

# Ensure the directories exist before any library tries to write into them.
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.makedirs("/tmp/hf_home", exist_ok=True)
os.makedirs("/tmp/mplconfig", exist_ok=True)

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from fastapi.middleware.cors import CORSMiddleware
import mmcv
from model.run_inference import infer_images  

logging.basicConfig(level=logging.INFO)

# Model identifiers accepted by the /infer endpoint; submitted names are
# stripped and lower-cased before being checked against this set.
ALLOWED_MODELS = {
    "regnetx4.0gf+detr3d",
    "regnetx4.0gf+petr",
}

app = FastAPI()
# Allow the browser front-end served from the HF Space origin to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yaghi27-imagetobev-lightweight.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve front-end assets (index.html, JS, CSS) from the local ./static folder.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page front-end from static/index.html."""
    with open("static/index.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(markup)


@app.post("/infer")
async def run_inference(
    model: str = Form(...),
    images: List[UploadFile] = File(...),
):
    """Decode six uploaded camera images, run the selected BEV model on them,
    and return one base64-encoded BEV image per produced output file.

    Returns a JSON list of ``{"bev_image": <base64 str>}`` objects on success,
    a 400 for bad model names / wrong image count / undecodable files, and a
    500 JSON error payload for unexpected inference failures.
    """
    model = model.strip().lower()
    if model not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Invalid model '{model}'. Allowed: {sorted(ALLOWED_MODELS)}")

    if len(images) != 6:
        raise HTTPException(status_code=400, detail=f"Expected 6 images, received {len(images)}")

    workdir = tempfile.mkdtemp(prefix="bev_infer_")
    saved_paths = []

    try:
        # Persist every upload to disk as PNG so the inference code can read
        # the images from file paths.
        for cam_idx, upload in enumerate(images):
            raw = await upload.read()
            decoded = mmcv.imfrombytes(raw, flag="color")
            if decoded is None:
                raise HTTPException(status_code=400, detail=f"File '{upload.filename}' is not a valid image.")

            dest = os.path.join(workdir, f"cam_{cam_idx}.png")
            mmcv.imwrite(decoded, dest)
            saved_paths.append(dest)

        logging.info("Starting inference with model=%s on %d images", model, len(saved_paths))

        bev_paths = infer_images(saved_paths, model=model)

        # Base64-encode each produced BEV image for transport in the JSON body.
        payload = []
        for bev_path in bev_paths:
            with open(bev_path, "rb") as fh:
                encoded = base64.b64encode(fh.read()).decode("utf-8")
            payload.append({"bev_image": encoded})

        return JSONResponse(content=payload)

    except HTTPException:
        # Validation errors keep their original status codes; re-raise as-is.
        raise
    except Exception as e:
        logging.exception("inference failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup of the scratch directory; never mask the response.
        try:
            shutil.rmtree(workdir)
        except Exception:
            logging.warning("Failed to clean tmpdir %s", workdir)