Spaces:
Build error
Build error
| import asyncio | |
| import os | |
| from collections import defaultdict | |
| from datetime import datetime, timedelta | |
| from urllib.parse import urlparse | |
| from fastapi import Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | |
| from starlette.requests import Request as StarletteRequest | |
| from starlette.responses import Response | |
| from starlette.types import ASGIApp | |
class LocalhostCORSMiddleware(CORSMiddleware):
    """
    CORS middleware that falls back to trusting localhost origins.

    Permitted origins are read from the comma-separated
    ``PERMITTED_CORS_ORIGINS`` environment variable; blank entries are ignored.
    When that yields no origins and no origin regex is configured, any origin
    whose hostname is ``localhost`` or ``127.0.0.1`` is accepted regardless of
    scheme or port. In every other case the standard ``CORSMiddleware`` rules
    decide. Credentials, all methods and all headers are allowed.
    """

    # Hostnames accepted by the localhost fallback in ``is_allowed_origin``.
    _LOCAL_HOSTNAMES = frozenset({'localhost', '127.0.0.1'})

    def __init__(self, app: ASGIApp) -> None:
        raw_origins = os.getenv('PERMITTED_CORS_ORIGINS', '')
        # Drop empty entries: a trailing comma or a whitespace-only value would
        # otherwise add '' as an origin, which also makes ``allow_origins``
        # non-empty and silently disables the localhost fallback.
        stripped = (part.strip() for part in raw_origins.split(','))
        allow_origins = tuple(origin for origin in stripped if origin)
        super().__init__(
            app,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )

    def is_allowed_origin(self, origin: str) -> bool:
        """Return True if ``origin`` may make cross-origin requests.

        With no explicit origins and no origin regex configured, an origin whose
        parsed hostname is ``localhost`` or ``127.0.0.1`` is accepted. Anything
        else (including a missing or unparseable origin) is decided by
        ``CORSMiddleware.is_allowed_origin``.
        """
        if origin and not self.allow_origins and not self.allow_origin_regex:
            try:
                hostname = urlparse(origin).hostname or ''
            except ValueError:
                # The Origin header is client-controlled; malformed values such
                # as 'http://[::1' make urlparse raise. Treat them as non-local
                # instead of failing the request with a 500.
                hostname = ''
            if hostname in self._LOCAL_HOSTNAMES:
                return True
        result: bool = super().is_allowed_origin(origin)
        return result
class CacheControlMiddleware(BaseHTTPMiddleware):
    """
    Middleware that sets HTTP caching headers on every response.

    Non-error responses for paths starting with ``/assets`` are cached for 30
    days as immutable. All other responses, including error responses for
    ``/assets`` paths, are marked uncacheable.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        is_asset = request.url.path.startswith('/assets')
        # The assets directory has fingerprinted file names, so successful
        # responses can be cached aggressively. Error responses (e.g. a 404 for
        # an asset not yet deployed) must not be: marked 'immutable', they would
        # stay pinned in browser/CDN caches for the full 30 days.
        if is_asset and response.status_code < 400:
            response.headers['Cache-Control'] = 'public, max-age=2592000, immutable'
        else:
            response.headers['Cache-Control'] = (
                'no-cache, no-store, must-revalidate, max-age=0'
            )
            # Legacy headers for HTTP/1.0 caches that ignore Cache-Control.
            response.headers['Pragma'] = 'no-cache'
            response.headers['Expires'] = '0'
        return response
class InMemoryRateLimiter:
    """
    Per-client sliding-window rate limiter kept in process memory.

    Clients are keyed by ``request.client.host``. Within a window of
    ``seconds`` seconds, a client's first ``requests`` requests are allowed
    immediately. The next ``requests`` are delayed by ``sleep_seconds`` and then
    allowed, or rejected outright when ``sleep_seconds`` is 0. Anything beyond
    ``2 * requests`` is rejected. Rejected requests are still recorded, so a
    client that keeps sending requests stays limited.
    """

    history: dict[str, list[datetime]]
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        self.history = defaultdict(list)
        # Time of the last sweep of idle client keys (see _sweep_stale_keys).
        self._last_sweep = datetime.now()

    def _clean_old_requests(self, key: str) -> None:
        """Drop timestamps for ``key`` older than the current window."""
        cutoff = datetime.now() - timedelta(seconds=self.seconds)
        recent = [ts for ts in self.history[key] if ts > cutoff]
        if recent:
            self.history[key] = recent
        else:
            # Don't keep empty lists around; defaultdict recreates on demand.
            del self.history[key]

    def _sweep_stale_keys(self, now: datetime) -> None:
        """Forget clients that have made no request within the current window.

        Without this, every distinct client address ever seen would keep an
        entry, so memory would grow without bound. The sweep runs at most once
        per window to bound its cost.
        """
        window = timedelta(seconds=self.seconds)
        if now - self._last_sweep < window:
            return
        self._last_sweep = now
        cutoff = now - window
        stale = [
            key
            for key, stamps in self.history.items()
            if not any(ts > cutoff for ts in stamps)
        ]
        for key in stale:
            del self.history[key]

    async def __call__(self, request: Request) -> bool:
        """Record ``request`` and decide whether it may proceed.

        Returns True to allow the request (possibly after sleeping
        ``sleep_seconds``) and False to reject it.
        """
        # Starlette's Request.client is Optional. Requests without a peer
        # address share one bucket instead of raising AttributeError.
        client = request.client
        key = client.host if client is not None else 'unknown'
        now = datetime.now()
        self._sweep_stale_keys(now)
        self._clean_old_requests(key)
        self.history[key].append(now)
        count = len(self.history[key])
        if count > self.requests * 2:
            return False
        if count > self.requests:
            if self.sleep_seconds > 0:
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False
        return True
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware that applies a rate limiter to incoming HTTP requests.

    Requests for which ``is_rate_limited_request`` returns False bypass the
    limiter. Requests the limiter rejects receive a 429 JSON response whose
    ``Retry-After`` header reflects the limiter's window.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # Short-circuit: the limiter is only consulted (and only records the
        # request) when the request is subject to rate limiting.
        if self.is_rate_limited_request(request) and not await self.rate_limiter(
            request
        ):
            return JSONResponse(
                status_code=429,
                content={'message': 'Too many requests'},
                headers={'Retry-After': self._retry_after()},
            )
        return await call_next(request)

    def _retry_after(self) -> str:
        """Return the ``Retry-After`` value, in whole seconds, for a 429.

        The limiter forgets requests older than its ``seconds`` window, so after
        one full window the client's history is clear. Limiters without a
        ``seconds`` attribute fall back to 1, the previous fixed value.
        """
        import math

        window = getattr(self.rate_limiter, 'seconds', 1)
        return str(max(1, math.ceil(window)))

    def is_rate_limited_request(self, request: StarletteRequest) -> bool:
        """Return False for requests exempt from rate limiting (static assets)."""
        if request.url.path.startswith('/assets'):
            return False
        # Put other non-rate-limited checks here.
        return True