Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import shutil | |
| import tempfile | |
| import base64 | |
| from typing import List | |
# Point the Hugging Face and Matplotlib caches at scratch directories under
# /tmp and make sure those directories exist. setdefault() leaves any value
# already present in the environment untouched; the directory created is
# always the /tmp default listed here, whatever the variable ends up holding.
# NOTE(review): presumably this runs ahead of the third-party imports so the
# libraries see these settings when they load -- confirm.
for _env_name, _scratch_dir in (
    ("HF_HUB_CACHE", "/tmp/hf_cache"),
    ("HF_HOME", "/tmp/hf_home"),
    ("MPLCONFIGDIR", "/tmp/mplconfig"),
):
    os.environ.setdefault(_env_name, _scratch_dir)
    os.makedirs(_scratch_dir, exist_ok=True)

# Disable the xet backend unless the environment already sets this flag.
os.environ.setdefault("HF_HUB_ENABLE_XET", "0")
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import mmcv | |
| from model.run_inference import infer_images | |
# Root logger emits INFO and above.
logging.basicConfig(level=logging.INFO)

# Model identifiers a request may name; anything else is rejected with a 400.
ALLOWED_MODELS = {"regnetx4.0gf+detr3d", "regnetx4.0gf+petr"}
app = FastAPI()

# CORS policy: a single allowed origin, credentials permitted, and no
# restriction on HTTP methods or request headers.
_cors_options = {
    "allow_origins": ["https://yaghi27-imagetobev-lightweight.hf.space"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Serve the contents of the local "static" directory under the /static prefix.
# The directory path is relative, so it is resolved against the process's
# working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
async def root():
    """Return the front-end page read from ``static/index.html``.

    The file is opened as UTF-8 text on every call and its contents are
    wrapped in an ``HTMLResponse``.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm it is registered with ``app``.
    """
    with open("static/index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return HTMLResponse(html)
async def run_inference(
    model: str = Form(...),
    images: List[UploadFile] = File(...),
):
    """Run BEV inference on six uploaded camera images.

    The model name is normalised (stripped and lower-cased) and checked
    against ``ALLOWED_MODELS``, and exactly six uploads are required. Each
    upload is decoded with ``mmcv.imfrombytes`` and re-written as
    ``cam_<idx>.png`` inside a fresh temporary directory. Those paths are
    handed to ``infer_images``, and every file path it returns is read back
    and base64-encoded into the response. The temporary directory is removed
    afterwards on a best-effort basis.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm it is registered with ``app``.

    Args:
        model: Multipart form field naming the model to run.
        images: Uploaded image files; upload order determines the
            ``cam_<idx>`` index of each saved image.

    Returns:
        A ``JSONResponse`` containing a list of ``{"bev_image": <base64>}``
        objects, one per path returned by ``infer_images``. If anything
        other than an ``HTTPException`` is raised while processing, a
        status-500 ``JSONResponse`` of the form ``{"error": <message>}`` is
        returned instead.

    Raises:
        HTTPException: status 400 for an unknown model, an image count other
            than six, or an upload that is empty or does not decode as an
            image.
    """
    model = model.strip().lower()
    if model not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail=f"Invalid model '{model}'. Allowed: {sorted(ALLOWED_MODELS)}")
    if len(images) != 6:
        raise HTTPException(status_code=400, detail=f"Expected 6 images, received {len(images)}")

    tmpdir = tempfile.mkdtemp(prefix="bev_infer_")
    img_paths = []
    try:
        for idx, upload in enumerate(images):
            data = await upload.read()
            # A zero-byte upload must not reach the decoder: OpenCV-based
            # decoding raises on an empty buffer instead of returning None,
            # which would surface as a 500 rather than a client error.
            bgr = mmcv.imfrombytes(data, flag="color") if data else None
            if bgr is None:
                raise HTTPException(status_code=400, detail=f"File '{upload.filename}' is not a valid image.")
            out_path = os.path.join(tmpdir, f"cam_{idx}.png")
            mmcv.imwrite(bgr, out_path)
            img_paths.append(out_path)

        logging.info("Starting inference with model=%s on %d images", model, len(img_paths))
        # NOTE(review): infer_images is called synchronously inside this
        # coroutine, so the event loop is occupied until it returns.
        bev_paths = infer_images(img_paths, model=model)

        output = []
        for p in bev_paths:
            with open(p, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            output.append({"bev_image": b64})
        return JSONResponse(content=output)
    except HTTPException:
        # Client errors raised above keep their own status code and detail.
        raise
    except Exception as e:
        logging.exception("inference failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup: a failure here is logged, never raised.
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            logging.warning("Failed to clean tmpdir %s", tmpdir)