"""HTTP front end for multi-camera image -> BEV inference.

Serves the static single-page UI at ``/`` and exposes ``POST /infer``, which
accepts a model name plus a fixed number of camera images, runs
``infer_images`` on them and returns the rendered outputs base64-encoded.
"""
import base64
import logging
import os
import shutil
import tempfile
from typing import List

# NOTE: these are set before the third-party imports below; presumably so that
# libraries which read them at import time see writable /tmp locations.
# Keep this ordering when editing the file.
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_cache")
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("HF_HUB_ENABLE_XET", "0")  # <-- disable xet backend
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")

os.makedirs("/tmp/hf_cache", exist_ok=True)
os.makedirs("/tmp/hf_home", exist_ok=True)
os.makedirs("/tmp/mplconfig", exist_ok=True)

from fastapi import FastAPI, UploadFile, File, Form, HTTPException  # noqa: E402
from fastapi.responses import HTMLResponse, JSONResponse  # noqa: E402
from fastapi.staticfiles import StaticFiles  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402

import mmcv  # noqa: E402

from model.run_inference import infer_images  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model identifiers accepted by /infer (compared after strip + lowercase).
ALLOWED_MODELS = {
    "regnetx4.0gf+detr3d",
    "regnetx4.0gf+petr",
}

# Number of images a single /infer request must carry.
EXPECTED_IMAGE_COUNT = 6

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yaghi27-imagetobev-lightweight.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", response_class=HTMLResponse)
async def root():
    """Return the contents of ``static/index.html`` as an HTML response."""
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())


@app.post("/infer")
async def run_inference(
    model: str = Form(...),
    images: List[UploadFile] = File(...),
):
    """Run inference on the uploaded images and return the results as JSON.

    Args:
        model: Model identifier; it is stripped and lower-cased, then must be
            one of ``ALLOWED_MODELS``.
        images: Exactly ``EXPECTED_IMAGE_COUNT`` uploaded image files.

    Returns:
        A ``JSONResponse`` whose body is a list of ``{"bev_image": <base64>}``
        objects, one per path returned by ``infer_images``. If inference
        fails with an unexpected error, a 500 ``JSONResponse`` of the form
        ``{"error": <message>}`` is returned instead.

    Raises:
        HTTPException: 400 when the model name is not allowed, the number of
            images is wrong, or an upload is empty / cannot be decoded.
    """
    model = model.strip().lower()
    if model not in ALLOWED_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model '{model}'. Allowed: {sorted(ALLOWED_MODELS)}",
        )

    if len(images) != EXPECTED_IMAGE_COUNT:
        raise HTTPException(
            status_code=400,
            detail=f"Expected {EXPECTED_IMAGE_COUNT} images, received {len(images)}",
        )

    tmpdir = tempfile.mkdtemp(prefix="bev_infer_")
    img_paths = []

    try:
        for idx, upload in enumerate(images):
            data = await upload.read()

            # Reject empty uploads explicitly so they are reported as a client
            # error (400) rather than reaching the generic 500 handler below
            # if the decoder raises on an empty buffer instead of returning
            # None.
            bgr = mmcv.imfrombytes(data, flag="color") if data else None
            if bgr is None:
                raise HTTPException(
                    status_code=400,
                    detail=f"File '{upload.filename}' is not a valid image.",
                )

            # Every upload is re-encoded as PNG under a positional name, so
            # the on-disk name does not depend on the client-supplied one.
            out_path = os.path.join(tmpdir, f"cam_{idx}.png")
            mmcv.imwrite(bgr, out_path)
            img_paths.append(out_path)

        logger.info(
            "Starting inference with model=%s on %d images", model, len(img_paths)
        )
        # NOTE: this is a synchronous call inside a coroutine, so the event
        # loop is blocked until it returns.
        bev_paths = infer_images(img_paths, model=model)

        output = []
        for p in bev_paths:
            with open(p, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            output.append({"bev_image": b64})

        return JSONResponse(content=output)

    except HTTPException:
        # Validation errors raised above must keep their own status code.
        raise
    except Exception as e:
        logger.exception("inference failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup: a failure here must not mask the response.
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            logger.warning("Failed to clean tmpdir %s", tmpdir, exc_info=True)