import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, HTTPException, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, StreamingResponse

# --- Production-Ready Configuration ---
# Every setting is read from an environment variable once, at import time.

# logging.basicConfig() raises ValueError for an unknown level name.
# logging.getLevelName() returns an int for registered names and a
# "Level <name>" string otherwise, so use it to fall back to INFO instead of
# crashing at startup on a typo.
_REQUESTED_LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
if isinstance(logging.getLevelName(_REQUESTED_LOG_LEVEL), int):
    LOG_LEVEL = _REQUESTED_LOG_LEVEL
else:
    LOG_LEVEL = "INFO"
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
if LOG_LEVEL != _REQUESTED_LOG_LEVEL:
    logging.warning(f"Invalid LOG_LEVEL '{_REQUESTED_LOG_LEVEL}'. Falling back to INFO.")

TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")

# Total number of upstream attempts per request, the first try included.
# A value below 1 would mean the request is never sent, so clamp it.
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "15"))
if MAX_RETRIES < 1:
    logging.warning(f"MAX_RETRIES={MAX_RETRIES} is below 1. Using 1 so every request is attempted once.")
    MAX_RETRIES = 1

DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)


def _parse_status_codes(spec):
    """Parse a comma-separated list of HTTP status codes into a set of ints.

    Blank entries (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: an entry is not an integer, an entry is outside 100-599,
            or the list contains no codes at all.
    """
    tokens = (part.strip() for part in spec.split(","))
    codes = {int(token) for token in tokens if token}
    if not codes:
        raise ValueError("no status codes given")
    out_of_range = sorted(code for code in codes if not 100 <= code <= 599)
    if out_of_range:
        raise ValueError(f"status codes out of range: {out_of_range}")
    return codes


try:
    RETRY_STATUS_CODES = _parse_status_codes(RETRY_CODES_STR)
    logging.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
    RETRY_STATUS_CODES = _parse_status_codes(DEFAULT_RETRY_CODES)


# --- Helper Function ---
def generate_random_ip():
    """Return a random dotted-quad IPv4 string with every octet in 1-254."""
    return ".".join(str(random.randint(1, 254)) for _ in range(4))


# --- HTTPX Client Lifecycle Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create one shared HTTPX client for the app's lifetime and close it on shutdown.

    The client is published as ``app.state.http_client`` with TARGET_URL as
    its base URL.
    """
    # timeout=None disables httpx's default timeouts entirely (presumably so
    # long streamed upstream responses are not cut off); note that a stalled
    # upstream will then hold the connection indefinitely.
    async with httpx.AsyncClient(base_url=TARGET_URL, timeout=None) as client:
        app.state.http_client = client
        yield


# Initialize the FastAPI app with the lifespan manager and disabled docs
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)

# --- API Endpoints ---

# 1. Health Check Route (Defined FIRST)
# This specific route will be matched before the catch-all proxy route.
@app.get("/")
async def health_check():
    """Report that the proxy is up, along with the configured target URL."""
    return JSONResponse({"status": "ok", "target": TARGET_URL})
# 2. Catch-All Reverse Proxy Route (Defined SECOND)
# This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
# and forward them. This eliminates any redirect issues.

# Hop-by-hop header fields (RFC 9110, section 7.6.1) describe one connection
# only, so a proxy must not forward them in either direction. The httpx client
# and the ASGI server each manage their own connection framing.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _strip_hop_by_hop(headers) -> dict:
    """Return a plain-dict copy of *headers* without hop-by-hop fields.

    *headers* is a header mapping with lowercase keys (Starlette or httpx
    ``Headers``). The copy is made with ``dict(headers)``, so repeated fields
    collapse the way ``dict()`` collapses them for that mapping type. Any field
    named in the ``Connection`` header is dropped too, as RFC 9110 requires.
    """
    copied = dict(headers)
    connection_options = {
        token.strip().lower()
        for token in copied.get("connection", "").split(",")
        if token.strip()
    }
    excluded = _HOP_BY_HOP_HEADERS | connection_options
    return {name: value for name, value in copied.items() if name.lower() not in excluded}


@app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def reverse_proxy_handler(request: Request):
    """
    A catch-all reverse proxy that forwards requests to the target URL with
    enhanced retry logic and latency logging.

    The client body is read once and re-sent on every attempt. An attempt is
    retried when the upstream answers with a status in RETRY_STATUS_CODES or
    when no connection could be established. The last attempt's response is
    streamed back whatever its status. At most MAX_RETRIES attempts are made,
    and always at least one.

    Raises:
        HTTPException: 502 when every attempt failed to connect, or when the
            upstream connection failed in any other way. Other failures are not
            retried, because the upstream may already have received the request.
    """
    start_time = time.monotonic()
    client: httpx.AsyncClient = request.app.state.http_client
    url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))

    # Starlette returns None for `request.client` when the ASGI server provides
    # no peer address; don't let the log line below crash the request.
    client_host = request.client.host if request.client is not None else "unknown"

    # The client's end-to-end headers (including `authorization`) are forwarded
    # unchanged; `host` is dropped so httpx derives it from the target URL.
    request_headers = _strip_hop_by_hop(request.headers)
    request_headers.pop("host", None)

    random_ip = generate_random_ip()
    logging.info(f"Client '{client_host}' proxied with spoofed IP: {random_ip} for path: {url.path}")
    # Replaces any client-supplied user agent and forwarding addresses.
    # NOTE(review): confirm that sending spoofed forwarding headers is permitted
    # by the target service's terms of use.
    request_headers.update(
        {
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
            "x-forwarded-for": random_ip,
            "x-real-ip": random_ip,
        }
    )

    body = await request.body()
    # A non-positive setting would otherwise skip the upstream call entirely.
    max_attempts = max(1, MAX_RETRIES)
    last_exception = None

    for attempt in range(max_attempts):
        rp_req = client.build_request(
            method=request.method, url=url, headers=request_headers, content=body
        )
        try:
            rp_resp = await client.send(rp_req, stream=True)
        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            # No connection was established, so nothing reached the upstream.
            # Retrying is always safe here.
            last_exception = e
            logging.warning(f"Attempt {attempt + 1}/{max_attempts} for {url.path} failed with connection error: {e}")
            continue
        except httpx.TransportError as e:
            # The connection broke after it was established. The upstream may
            # already have acted on the request, so report the failure instead
            # of resending it.
            duration_ms = (time.monotonic() - start_time) * 1000
            logging.error(f"Request failed, upstream transport error: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms error={e!r}")
            raise HTTPException(
                status_code=502,
                detail=f"Bad Gateway: Error while communicating with target service. {e}",
            ) from e

        is_last_attempt = attempt == max_attempts - 1
        if rp_resp.status_code not in RETRY_STATUS_CODES or is_last_attempt:
            duration_ms = (time.monotonic() - start_time) * 1000
            log_func = logging.info if rp_resp.is_success else logging.warning
            log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
            try:
                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=_strip_hop_by_hop(rp_resp.headers),
                    background=BackgroundTask(rp_resp.aclose),
                )
            except Exception:
                # The background close only runs once a response is sent, so
                # release the upstream stream here if the response can't be built.
                await rp_resp.aclose()
                raise

        logging.warning(
            f"Attempt {attempt + 1}/{max_attempts} for {url.path} failed with status {rp_resp.status_code}. Retrying..."
        )
        await rp_resp.aclose()

    # Only reached when every attempt failed to connect.
    duration_ms = (time.monotonic() - start_time) * 1000
    logging.critical(f"Request failed, cannot connect to target: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms")
    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to target service after {max_attempts} attempts. {last_exception}",
    )