# app/middleware.py
# CORS, gzip, per-IP rate limiting, trace-ID propagation, and structured request
# logging for the FastAPI app, with graceful handling of client disconnects on
# streaming (SSE) endpoints.
from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Callable, Optional

from anyio import EndOfStream
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import ClientDisconnect
from starlette.responses import JSONResponse, Response

# Optional: python-json-logger for structured logs; fall back to a minimal JSON formatter.
try:
    from pythonjsonlogger import jsonlogger  # type: ignore
    _HAS_PY_JSON_LOGGER = True
except Exception:
    _HAS_PY_JSON_LOGGER = False

from .core.logging import add_trace_id
from .core.rate_limit import RateLimiter
from .deps import get_settings


class _SimpleJsonFormatter(logging.Formatter):
    """Fallback formatter that emits one JSON object per log record."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "asctime": self.formatTime(record, "%Y-%m-%d %H:%M:%S"),
            "name": record.name,
            "levelname": record.levelname,
            "message": record.getMessage(),
            "trace_id": getattr(record, "trace_id", None),
        }
        try:
            return json.dumps(payload, ensure_ascii=False)
        except Exception:
            # Last resort: plain-text line if the payload is not JSON-serializable.
            return (
                f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
                f'{payload["message"]} trace_id={payload["trace_id"]}'
            )


_logger = logging.getLogger("matrix-ai")
if not _logger.handlers:
    _logger.setLevel(logging.INFO)
    _handler = logging.StreamHandler()
    if _HAS_PY_JSON_LOGGER:
        _formatter = jsonlogger.JsonFormatter(
            "%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
        )
    else:
        _formatter = _SimpleJsonFormatter()
        logging.getLogger("uvicorn.error").warning(
            "python-json-logger not found; using a minimal JSON formatter."
        )
    _handler.setFormatter(_formatter)
    _logger.addHandler(_handler)

_rate_limiter = RateLimiter()

_SSE_PATH_SUFFIXES = ("/chat/stream", "/v1/chat/stream")
_HEALTH_PATHS = ("/health", "/livez", "/readyz")


def _client_ip(request: Request) -> str:
    """Best-effort client IP: first hop of X-Forwarded-For, else the socket peer."""
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    return request.client.host if request.client else "unknown"


def _is_sse(request: Request, response: Optional[Response] = None) -> bool:
    """Detect Server-Sent Events by path suffix, response content type, or Accept header."""
    path = request.url.path
    if path.endswith(_SSE_PATH_SUFFIXES):
        return True
    if response is not None:
        ctype = response.headers.get("content-type", "")
        if ctype.startswith("text/event-stream"):
            return True
    accept = request.headers.get("accept", "")
    return "text/event-stream" in accept


def attach_middlewares(app: FastAPI) -> None:
    app.add_middleware(GZipMiddleware, minimum_size=512)
    # NOTE: wildcard origins combined with credentials is very permissive;
    # restrict allow_origins for production deployments.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Trace-Id", "X-Process-Time-Ms", "Server-Timing"],
    )

    @app.middleware("http")
    async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
        add_trace_id(request)
        trace_id = getattr(request.state, "trace_id", "N/A")
        path = request.url.path
        method = request.method
        ua = request.headers.get("user-agent", "-")
        ip = _client_ip(request)

        # Health probes bypass rate limiting and request logging.
        if path in _HEALTH_PATHS:
            try:
                response = await call_next(request)
            except Exception:
                return JSONResponse({"status": "unhealthy"}, status_code=500)
            response.headers.setdefault("X-Trace-Id", str(trace_id))
            return response

        settings = get_settings()
        if not _rate_limiter.allow(ip, path, settings.limits.rate_per_min):
            _logger.warning(
                "429 Too Many Requests from %s on %s",
                ip,
                path,
                extra={"trace_id": trace_id},
            )
            return JSONResponse(
                {"detail": "Too Many Requests"},
                status_code=429,
                headers={"X-Trace-Id": str(trace_id)},
            )

        t0 = time.time()
        try:
            response = await call_next(request)
        # Treat client disconnects during streaming as benign and answer with 204.
        except (EndOfStream, ClientDisconnect, asyncio.CancelledError):
            _logger.info(
                "Client disconnected from stream. Path: %s, IP: %s",
                path,
                ip,
                extra={"trace_id": trace_id},
            )
            resp = Response(status_code=204)
            resp.headers.setdefault("X-Trace-Id", str(trace_id))
            return resp
        except RuntimeError as e:
            # Starlette sometimes surfaces a streaming disconnect as this RuntimeError.
            if str(e) == "No response returned.":
                _logger.info(
                    "Downstream produced no response (likely streaming disconnect). "
                    "Path: %s, IP: %s",
                    path,
                    ip,
                    extra={"trace_id": trace_id},
                )
                resp = Response(status_code=204)
                resp.headers.setdefault("X-Trace-Id", str(trace_id))
                return resp
            # Not a disconnect → re-raise and let it propagate to the server's
            # error handling (a later except clause of this try does not catch it).
            raise
        except Exception as e:
            _logger.exception(
                "Unhandled error while processing %s %s: %s",
                method,
                path,
                e,
                extra={"trace_id": trace_id},
            )
            dur_ms = (time.time() - t0) * 1000.0
            return JSONResponse(
                {"detail": "Internal Server Error"},
                status_code=500,
                headers={
                    "X-Trace-Id": str(trace_id),
                    "X-Process-Time-Ms": f"{dur_ms:.2f}",
                    "Server-Timing": f"app;dur={dur_ms:.2f}",
                },
            )

        if not isinstance(response, Response):
            _logger.error(
                "Downstream returned no Response object for %s",
                path,
                extra={"trace_id": trace_id},
            )
            return JSONResponse(
                {"detail": "Internal Server Error"},
                status_code=500,
                headers={"X-Trace-Id": str(trace_id)},
            )

        sse = _is_sse(request, response)
        dur_ms = (time.time() - t0) * 1000.0
        response.headers.setdefault("X-Trace-Id", str(trace_id))
        response.headers.setdefault("X-Process-Time-Ms", f"{dur_ms:.2f}")
        response.headers.setdefault("Server-Timing", f"app;dur={dur_ms:.2f}")

        if sse:
            response.headers.setdefault("Cache-Control", "no-cache")
            _logger.info(
                '"%s %s" %s (SSE) ip=%s ua="%s" %.2fms',
                method,
                path,
                response.status_code,
                ip,
                ua,
                dur_ms,
                extra={"trace_id": trace_id},
            )
            return response

        _logger.info(
            '"%s %s" %s ip=%s ua="%s" %.2fms',
            method,
            path,
            response.status_code,
            ip,
            ua,
            dur_ms,
            extra={"trace_id": trace_id},
        )
        return response
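

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): attach_middlewares() is meant
# to be called once while building the FastAPI application. The module layout,
# factory name `create_app`, and app title below are illustrative assumptions.
#
#     from fastapi import FastAPI
#     from app.middleware import attach_middlewares
#
#     def create_app() -> FastAPI:
#         app = FastAPI(title="matrix-ai")
#         attach_middlewares(app)
#         return app
#
#     app = create_app()  # e.g. served with: uvicorn app.main:app
# ---------------------------------------------------------------------------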