yaghi27's picture
Rename server.py to main.py
510e9d0
raw
history blame
1.9 kB
import base64
import logging
import os
import tempfile

import mmcv
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from model.run_inference import infer_images
# ─── Configure logging ─────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
app = FastAPI()
# ─── CORS: allow only your Space’s hf.space domain ─────────────────────────────
app.add_middleware(
CORSMiddleware,
allow_origins=["https://yaghi27-regnet-detr3d.hf.space"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the front-end page.

    ``static/index.html`` (relative to the working directory) is re-read on
    every request and returned verbatim as an HTML response.
    """
    index_path = "static/index.html"
    with open(index_path, encoding="utf-8") as page:
        html = page.read()
    return HTMLResponse(html)
def _safe_upload_name(filename, fallback="image.png"):
    """Return a bare file name derived from *filename* that cannot escape its directory.

    Directory components are stripped.  Empty names, ``"."`` and ``".."`` are
    replaced by *fallback*, which is also used when *filename* is ``None``.
    """
    name = os.path.basename(filename or "")
    return fallback if name in ("", ".", "..") else name


@app.post("/infer")
async def run_inference(images: list[UploadFile] = File(...)):
    """Run the model on a batch of uploaded images and return the BEV renderings.

    Each upload is decoded as a 3-channel colour image and re-written to a
    private, per-request temporary directory.  The resulting paths are passed
    to ``infer_images`` in upload order, and every file path it returns is
    read back and base64-encoded.

    Args:
        images: The uploaded image files (multipart field ``images``).

    Returns:
        A JSONResponse with one of these outcomes:

        * 200: ``[{"bev_image": "<base64>"}, ...]``, one entry per path
          returned by ``infer_images``.
        * 400: ``{"error": ...}`` if an upload is empty or cannot be decoded.
        * 500: ``{"error": ...}`` if ``infer_images`` raises.
    """
    # Upload filenames are client-controlled.  Writing them directly under /tmp
    # allowed "../" traversal, collisions between requests, and files that were
    # never cleaned up.  Everything now lives in a directory that is removed
    # when this block exits.
    with tempfile.TemporaryDirectory(prefix="infer_") as tmp_dir:
        img_paths = []
        for idx, upload in enumerate(images):
            data = await upload.read()
            # Drop any alpha channel, force 3-channel BGR.  Empty payloads are
            # rejected up front.  With the cv2 backend, bytes that cannot be
            # decoded come back as None instead of raising.
            bgr = mmcv.imfrombytes(data, flag="color") if data else None
            if bgr is None:
                return JSONResponse(
                    status_code=400,
                    content={"error": f"could not decode image {upload.filename!r}"},
                )
            # Each upload gets its own sub-directory.  This keeps the original
            # basename (and its extension, which selects the output format)
            # while stopping identically-named uploads from overwriting each
            # other.
            upload_dir = os.path.join(tmp_dir, str(idx))
            os.makedirs(upload_dir, exist_ok=True)
            tmp = os.path.join(upload_dir, _safe_upload_name(upload.filename))
            mmcv.imwrite(bgr, tmp)
            img_paths.append(tmp)

        # NOTE(review): infer_images is a synchronous call inside an async
        # handler, so it blocks the event loop for the whole inference; consider
        # offloading it to a worker thread if concurrent requests matter.
        try:
            bev_paths = infer_images(img_paths)
        except Exception as e:
            logging.exception("❌ inference failed")
            return JSONResponse(status_code=500, content={"error": str(e)})

        # Read the outputs before the temporary directory is removed, in case
        # any of them were written inside it.
        output = []
        for p in bev_paths:
            with open(p, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            output.append({"bev_image": b64})
    return JSONResponse(content=output)