# main.py — yaghi27 ImageToBEV Space (commit 7f2005b)
import logging
import os
import shutil
import tempfile
import base64
from typing import List
# Point every cache-writing library at /tmp (the only writable path in the
# Space container) BEFORE the HF-related imports below read these variables.
for _var, _val in (
    ("HF_HUB_CACHE", "/tmp/hf_cache"),
    ("HF_HOME", "/tmp/hf_home"),
    ("HF_HUB_ENABLE_XET", "0"),  # disable the xet download backend
    ("MPLCONFIGDIR", "/tmp/mplconfig"),
):
    os.environ.setdefault(_var, _val)

# Pre-create the cache directories so the libraries never have to.
for _cache_dir in ("/tmp/hf_cache", "/tmp/hf_home", "/tmp/mplconfig"):
    os.makedirs(_cache_dir, exist_ok=True)
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import mmcv
from model.run_inference import infer_images
# Root logger at INFO so request/inference progress shows up in the Space logs.
logging.basicConfig(level=logging.INFO)

# Model identifiers accepted by /infer; compared after .strip().lower(),
# so entries here must be lowercase.
ALLOWED_MODELS = {
    "regnetx4.0gf+detr3d",
    "regnetx4.0gf+petr",
}
# Application instance; all routes below register on it.
app = FastAPI()

# Allow only the browser front-end hosted on this Space's domain to call
# the API cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yaghi27-imagetobev-lightweight.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the front-end assets (index.html, JS, CSS) from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page UI shell."""
    # Re-read on every request so edits to index.html appear without a restart.
    with open("static/index.html", "r", encoding="utf-8") as page_file:
        page = page_file.read()
    return HTMLResponse(page)
def _encode_file_b64(path: str) -> str:
    """Read the file at *path* and return its contents base64-encoded as ASCII text."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


@app.post("/infer")
async def run_inference(
    model: str = Form(...),
    images: List[UploadFile] = File(...),
):
    """Run BEV inference on a set of 6 camera images.

    Form fields:
        model:  one of ALLOWED_MODELS (leading/trailing whitespace and case
                are ignored).
        images: exactly 6 uploaded image files.

    Returns:
        JSON list with one ``{"bev_image": <base64 string>}`` entry per
        output image produced by ``infer_images``.

    Raises:
        HTTPException 400 for a bad model name, wrong image count, or an
        undecodable image. Unexpected failures are logged and returned as
        a 500 JSON body ``{"error": ...}``.
    """
    model = model.strip().lower()
    if model not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Invalid model '{model}'. Allowed: {sorted(ALLOWED_MODELS)}")
    if len(images) != 6:
        raise HTTPException(status_code=400, detail=f"Expected 6 images, received {len(images)}")

    # Private scratch directory so concurrent requests cannot collide.
    tmpdir = tempfile.mkdtemp(prefix="bev_infer_")
    img_paths = []
    try:
        for idx, upload in enumerate(images):
            # NOTE(review): each upload is read fully into memory with no
            # size cap — consider enforcing a request-size limit upstream.
            data = await upload.read()
            bgr = mmcv.imfrombytes(data, flag="color")
            if bgr is None:
                raise HTTPException(status_code=400, detail=f"File '{upload.filename}' is not a valid image.")
            # Normalize every input to a PNG on disk; infer_images takes file paths.
            out_path = os.path.join(tmpdir, f"cam_{idx}.png")
            mmcv.imwrite(bgr, out_path)
            img_paths.append(out_path)

        logging.info("Starting inference with model=%s on %d images", model, len(img_paths))
        bev_paths = infer_images(img_paths, model=model)

        output = [{"bev_image": _encode_file_b64(p)} for p in bev_paths]
        return JSONResponse(content=output)
    except HTTPException:
        # Propagate deliberate 4xx responses untouched.
        raise
    except Exception as e:
        logging.exception("inference failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup; never mask the real response/exception.
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            logging.warning("Failed to clean tmpdir %s", tmpdir)