Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os | |
| import sqlite3 | |
| import base64 | |
| import json | |
| from datetime import datetime | |
| from typing import Optional | |
| from fastapi import FastAPI, Request, Cookie, HTTPException, WebSocket, Header | |
| from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from huggingface_hub import whoami | |
| import httpx | |
| import websockets | |
# ASGI application object and the template loader; templates are looked up
# in the current working directory.
app = FastAPI()
templates = Jinja2Templates(directory=".")

# Registry of WebSocket connections that are currently open.
active_websockets = set()

# OAuth settings, read from the environment of the Hugging Face Space.
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES", "openid profile")
SPACE_HOST = os.environ.get("SPACE_HOST", "localhost:7860")
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")

# Server-side FAL API key; an empty string means "not configured".
FAL_API_KEY = os.environ.get("FAL_API_KEY", "")

# Location of the SQLite file that stores per-user generation records.
DB_PATH = "/data/usage.db"
def init_db():
    """Create the usage-tracking schema in the SQLite file at DB_PATH.

    Creates the ``generations`` table and the ``(username, generation_date)``
    index if they do not exist yet, so calling this repeatedly is harmless.

    Raises:
        sqlite3.Error: if the database cannot be opened or the DDL fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS generations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                is_pro BOOLEAN NOT NULL,
                generation_date DATE NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Index matching the per-user, per-day lookup done by get_daily_usage.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_username_date
            ON generations(username, generation_date)
        """)
        conn.commit()
    finally:
        # Close the connection even when a statement above raises.
        conn.close()


# Ensure the schema exists as soon as the module is imported.
init_db()
def get_daily_usage(username: str, date: Optional[str] = None) -> int:
    """Return how many generations ``username`` has recorded on ``date``.

    Args:
        username: Account name used as the key in the ``generations`` table.
        date: ISO date string (``YYYY-MM-DD``). Defaults to today's date as
            reported by ``datetime.now()`` (server local time).

    Returns:
        The number of matching rows.

    Raises:
        sqlite3.Error: if the database cannot be opened or queried.
    """
    if date is None:
        date = datetime.now().date().isoformat()
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) FROM generations WHERE username = ? AND generation_date = ?",
            (username, date),
        )
        # COUNT(*) always yields exactly one row.
        return cursor.fetchone()[0]
    finally:
        # Close the connection even when the query raises.
        conn.close()
def record_generation(username: str, is_pro: bool):
    """Insert one generation record for ``username`` dated today.

    Args:
        username: Account name stored with the record.
        is_pro: Whether the account was a PRO account at the time of the call.

    Raises:
        sqlite3.Error: if the database cannot be opened or the insert fails.
    """
    # Same date convention as get_daily_usage: server-local ISO date.
    date = datetime.now().date().isoformat()
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO generations (username, is_pro, generation_date) VALUES (?, ?, ?)",
            (username, is_pro, date),
        )
        conn.commit()
    finally:
        # Close the connection even when the insert or commit raises.
        conn.close()
def can_start_generation(username: str, is_pro: bool) -> tuple[bool, int, int]:
    """Report whether ``username`` still has quota left for today.

    Returns:
        A ``(can_start, used, limit)`` tuple, where ``limit`` is 15 for PRO
        accounts and 2 otherwise, and ``used`` is today's recorded count.
    """
    if is_pro:
        daily_limit = 15
    else:
        daily_limit = 2
    used_today = get_daily_usage(username)
    has_quota = used_today < daily_limit
    return has_quota, used_today, daily_limit
def get_origin_from_request(request: Request) -> str:
    """Build the ``scheme://host`` origin under which the client reached us.

    The host is taken from ``X-Forwarded-Host`` or, failing that, ``Host``.
    HTTPS is chosen when the proxy headers say so or when the host looks like
    a Hugging Face domain; otherwise the scheme of the request URL is used.
    Without any host header the origin falls back to ``https://SPACE_HOST``.
    """
    headers = request.headers
    host = headers.get("X-Forwarded-Host") or headers.get("Host") or ""
    if not host:
        # No usable host header: rely on the configured Space host over HTTPS.
        return f"https://{SPACE_HOST}"

    # Signals that the original client connection was HTTPS.
    proxied_over_tls = (
        headers.get("X-Forwarded-Proto", "") == "https"
        or headers.get("X-Forwarded-Ssl", "") == "on"
    )
    # Hugging Face domains are treated as always served over HTTPS.
    on_hf_domain = ".hf.space" in host or "huggingface.co" in host

    if proxied_over_tls or on_hf_domain:
        scheme = "https"
    else:
        scheme = request.url.scheme or "https"
    return f"{scheme}://{host}"
def get_token_from_request(cookie_token: Optional[str], auth_header: Optional[str]) -> Optional[str]:
    """Pick the access token from the Authorization header or the cookie.

    A header of the exact form ``Bearer <token>`` (scheme matched
    case-insensitively) wins; any other header value, or no header at all,
    yields ``cookie_token`` unchanged.
    """
    fields = auth_header.split() if auth_header else []
    is_bearer = len(fields) == 2 and fields[0].lower() == "bearer"
    return fields[1] if is_bearer else cookie_token
async def exchange_code_for_token(code: str, redirect_uri: str) -> dict:
    """Trade an OAuth authorization code for the provider's token response.

    Posts a form-encoded ``authorization_code`` grant to
    ``{OPENID_PROVIDER_URL}/oauth/token``, authenticating with HTTP Basic
    using the client id and secret.

    Returns:
        The decoded JSON body of the token response.

    Raises:
        httpx.HTTPStatusError: if the provider answers with a 4xx/5xx status.
    """
    # HTTP Basic credentials: base64("client_id:client_secret").
    basic_auth = base64.b64encode(
        f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
    ).decode()
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": OAUTH_CLIENT_ID,
    }
    request_headers = {
        "Authorization": f"Basic {basic_auth}",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    async with httpx.AsyncClient() as http:
        reply = await http.post(
            f"{OPENID_PROVIDER_URL}/oauth/token",
            data=form,
            headers=request_headers,
        )
        reply.raise_for_status()
        return reply.json()
async def get_user_info(access_token: str) -> dict:
    """Resolve an access token to a small user-profile dict via ``whoami``.

    Args:
        access_token: Hugging Face access token to look up.

    Returns:
        A dict with the keys ``username``, ``is_pro``, ``avatar``,
        ``fullname`` (falls back to the username) and ``email``.

    Raises:
        HTTPException: 401 if the lookup fails for any reason.
    """
    import asyncio

    try:
        # whoami() is a synchronous call; run it in a worker thread so the
        # event loop keeps serving other requests while it waits.
        user_data = await asyncio.to_thread(whoami, token=access_token)
        return {
            "username": user_data.get("name"),
            "is_pro": user_data.get("isPro", False),
            "avatar": user_data.get("avatarUrl"),
            "fullname": user_data.get("fullname", user_data.get("name")),
            "email": user_data.get("email"),
        }
    except Exception as e:
        print(f"Failed to get user info: {e}")
        raise HTTPException(status_code=401, detail="Failed to get user information") from e
async def home(request: Request):
    """Render ``index.html``; authentication itself happens client-side."""
    # The template only needs the request object and the OAuth client id.
    context = {
        "request": request,
        "oauth_client_id": OAUTH_CLIENT_ID,
    }
    return templates.TemplateResponse("index.html", context)
async def auth_login(request: Request, state: Optional[str] = None):
    """Redirect the browser to the OAuth provider's authorize endpoint.

    Args:
        request: Incoming request; used to derive the callback origin.
        state: Optional caller-supplied OAuth state. When omitted, a random
            state is generated and stored in the ``hf_oauth_state`` cookie so
            the exchange endpoint can validate it.

    Returns:
        A 302 redirect to ``{OPENID_PROVIDER_URL}/oauth/authorize``.
    """
    from urllib.parse import quote, urlencode

    # Derive the callback URL from the origin the client actually used.
    origin = get_origin_from_request(request)
    redirect_uri = f"{origin}/oauth/callback"

    # os.urandom is a cryptographically secure source for the state value.
    oauth_state = state or os.urandom(16).hex()

    # Percent-encode every parameter: the scope contains a space, the
    # redirect URI contains reserved characters, and a caller-supplied state
    # must not be able to inject extra query parameters.
    query = urlencode(
        {
            "response_type": "code",
            "client_id": OAUTH_CLIENT_ID,
            "redirect_uri": redirect_uri,
            "scope": OAUTH_SCOPES,
            "state": oauth_state,
        },
        quote_via=quote,
    )
    # Use the configured provider so this matches the token exchange host.
    auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{query}"

    response = RedirectResponse(url=auth_url, status_code=302)

    # Only persist a state we generated ourselves; a caller that supplies its
    # own state is left to manage it.
    if not state:
        response.set_cookie(
            key="hf_oauth_state",
            value=oauth_state,
            httponly=True,
            samesite="none",  # Required for iframe/third-party context
            secure=True,
            max_age=300,  # 5 minutes
            path="/",
        )
    return response
async def auth_exchange(request: Request, code: str, state: str, hf_oauth_state: Optional[str] = Cookie(None)):
    """Exchange an OAuth code for an access token; called by the callback page.

    Args:
        request: Incoming request; used to rebuild the redirect URI.
        code: Authorization code returned by the provider.
        state: State value returned by the provider.
        hf_oauth_state: State previously stored in the ``hf_oauth_state`` cookie.

    Returns:
        On success, JSON ``{"token", "namespace"}`` plus an ``access_token``
        cookie. On failure after state validation, JSON ``{"detail": ...}``
        with an error status. In both cases the state cookie is cleared.

    Raises:
        HTTPException: 400 if the state is missing or does not match the cookie.
    """
    if not hf_oauth_state or state != hf_oauth_state:
        raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")

    def failure(status_code: int, detail: str) -> JSONResponse:
        # Same body shape as an HTTPException ({"detail": ...}), but returned
        # as a response object so the one-time state cookie can be cleared.
        error_response = JSONResponse({"detail": detail}, status_code=status_code)
        error_response.delete_cookie("hf_oauth_state")
        return error_response

    origin = get_origin_from_request(request)
    redirect_uri = f"{origin}/oauth/callback"

    try:
        token_data = await exchange_code_for_token(code, redirect_uri)
        access_token = token_data.get("access_token")
        if not access_token:
            return failure(400, "No access token received")
        user_info = await get_user_info(access_token)
    except HTTPException as exc:
        # Keep the status and message chosen by the helper (e.g. 401) rather
        # than rewrapping it as a generic 400.
        return failure(exc.status_code, exc.detail)
    except Exception as e:
        return failure(400, f"Token exchange failed: {str(e)}")

    response = JSONResponse({
        "token": access_token,
        "namespace": user_info["username"],
    })
    response.delete_cookie("hf_oauth_state")
    # The cookie lets the WebSocket endpoint authenticate the same user.
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        samesite="none",
        secure=True,
        max_age=30 * 24 * 60 * 60,  # 30 days
        path="/",
    )
    return response
async def oauth_callback(request: Request):
    """Serve the OAuth callback page.

    The page's script reads ``code``/``state``/``error`` from the query
    string, POSTs the code to ``/api/auth/exchange``, stores the result in
    ``localStorage`` under ``HF_AUTH_STATE`` and navigates back to ``/``.
    Text that originates from the URL or from an error is written with
    ``textContent`` so it is never interpreted as HTML.
    """
    callback_html = """
<!DOCTYPE html>
<html>
<head><title>Authenticating...</title></head>
<body style="font-family: sans-serif; padding: 40px; text-align: center;">
<h2>Authenticating...</h2>
<p>Please wait while we complete your login.</p>
<script>
(async function() {
    // Render a failure message; the message is inserted as text, not HTML.
    function showFailure(message, withReturnLink) {
        document.body.innerHTML = '<h2>Authentication failed</h2><p id="auth-error"></p>';
        document.getElementById('auth-error').textContent = message;
        if (withReturnLink) {
            const para = document.createElement('p');
            const link = document.createElement('a');
            link.href = '/';
            link.textContent = 'Return to app';
            para.appendChild(link);
            document.body.appendChild(para);
        }
    }

    const params = new URLSearchParams(window.location.search);
    const code = params.get('code');
    const state = params.get('state');
    const error = params.get('error');
    if (error) {
        showFailure(error, false);
        setTimeout(() => window.location.href = '/', 3000);
        return;
    }
    if (!code || !state) {
        showFailure('Missing authorization code', false);
        setTimeout(() => window.location.href = '/', 3000);
        return;
    }
    try {
        // Exchange code for token; encode the values so they cannot break the query string
        const exchangeUrl = '/api/auth/exchange?code=' + encodeURIComponent(code)
            + '&state=' + encodeURIComponent(state);
        const response = await fetch(exchangeUrl, {
            method: 'POST',
            credentials: 'same-origin'
        });
        if (!response.ok) {
            const data = await response.json().catch(() => ({ detail: 'Unknown error' }));
            throw new Error(data.detail || data.error || 'Failed to exchange code for token');
        }
        const data = await response.json();
        // Store in localStorage
        const authState = {
            token: data.token,
            user: { username: data.namespace }
        };
        localStorage.setItem('HF_AUTH_STATE', JSON.stringify(authState));
        // Redirect back to home
        window.location.href = '/';
    } catch (err) {
        showFailure(err.message, true);
    }
})();
</script>
</body>
</html>
"""
    return HTMLResponse(content=callback_html)
async def whoami_endpoint(authorization: Optional[str] = Header(None)):
    """Validate a Bearer token and return the user's profile and quota.

    Args:
        authorization: Value of the ``Authorization`` header; must have the
            form ``Bearer <token>``.

    Returns:
        A dict with ``username``, ``fullname``, ``is_pro``, ``avatar``,
        ``can_start``, ``sessions_used`` and ``sessions_limit``.

    Raises:
        HTTPException: 401 if the header is missing or malformed, or if the
            token cannot be resolved to a user.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="No authorization header")

    parts = authorization.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    access_token = parts[1]

    # Only the token lookup is an authentication failure; errors from the
    # usage accounting below must not be reported as "Invalid token".
    try:
        user_info = await get_user_info(access_token)
    except Exception as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") from e

    can_start, used, limit = can_start_generation(user_info["username"], user_info["is_pro"])
    return {
        "username": user_info["username"],
        "fullname": user_info["fullname"],
        "is_pro": user_info["is_pro"],
        "avatar": user_info.get("avatar"),
        "can_start": can_start,
        "sessions_used": used,
        "sessions_limit": limit,
    }
async def start_session(access_token: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
    """Start a new generation and count it against the user's daily limit.

    Args:
        access_token: Token from the ``access_token`` cookie, if any.
        authorization: ``Authorization`` header, if any; a Bearer token here
            takes precedence over the cookie.

    Returns:
        ``{"success": True, "sessions_used", "sessions_limit"}`` with the
        usage count after recording this generation.

    Raises:
        HTTPException: 401 if no token is supplied or it cannot be resolved;
            429 if the daily limit is already reached.
    """
    token = get_token_from_request(access_token, authorization)
    if not token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        user_info = await get_user_info(token)
    except Exception:
        # Deliberately not a bare except: task cancellation and interpreter
        # shutdown signals must propagate instead of becoming a 401.
        raise HTTPException(status_code=401, detail="Invalid session")

    can_start, used, limit = can_start_generation(user_info["username"], user_info["is_pro"])
    if not can_start:
        raise HTTPException(
            status_code=429,
            detail=f"Daily limit reached. You've used {used}/{limit} generations today."
        )

    # NOTE(review): the check above and the insert below are separate
    # database operations, so concurrent requests could both pass the check.
    record_generation(user_info["username"], user_info["is_pro"])

    # Re-read so the response reflects the stored count.
    new_count = get_daily_usage(user_info["username"])
    return {
        "success": True,
        "sessions_used": new_count,
        "sessions_limit": limit,
    }
async def check_limits(access_token: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
    """Report the caller's current usage against the daily limit.

    Args:
        access_token: Token from the ``access_token`` cookie, if any.
        authorization: ``Authorization`` header, if any; a Bearer token here
            takes precedence over the cookie.

    Returns:
        A dict with ``can_start``, ``sessions_used``, ``sessions_limit`` and
        ``is_pro``. Nothing is recorded.

    Raises:
        HTTPException: 401 if no token is supplied or it cannot be resolved.
    """
    token = get_token_from_request(access_token, authorization)
    if not token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        user_info = await get_user_info(token)
    except Exception:
        # Deliberately not a bare except: task cancellation and interpreter
        # shutdown signals must propagate instead of becoming a 401.
        raise HTTPException(status_code=401, detail="Invalid session")

    can_start, used, limit = can_start_generation(user_info["username"], user_info["is_pro"])
    return {
        "can_start": can_start,
        "sessions_used": used,
        "sessions_limit": limit,
        "is_pro": user_info["is_pro"],
    }
async def logout():
    """Log the user out by expiring the ``access_token`` cookie.

    Returns:
        JSON ``{"success": True}`` carrying the cookie-expiring header.
    """
    response = JSONResponse({"success": True})
    # Expire the cookie with the same attributes it was set with
    # (SameSite=None; Secure; HttpOnly). A deletion sent with the default
    # SameSite=Lax is dropped by browsers in the iframe/cross-site context
    # this app sets its cookies for.
    response.delete_cookie(
        "access_token",
        path="/",
        secure=True,
        httponly=True,
        samesite="none",
    )
    return response
async def health():
    """Liveness probe that also reports which integrations are configured."""
    oauth_ready = bool(OAUTH_CLIENT_ID)
    fal_ready = bool(FAL_API_KEY)
    return {
        "status": "ok",
        "oauth_enabled": oauth_ready,
        "fal_api_key_configured": fal_ready,
    }
async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str] = None):
    """Proxy a client WebSocket to the FAL WebSocket API.

    The client is authenticated through the ``access_token`` cookie. With the
    server's FAL key the daily limit is checked first; a user-supplied key
    skips that check. A short-lived FAL token is then fetched and frames are
    relayed in both directions until either side stops.

    Args:
        websocket: The client connection.
        user_fal_key: Optional FAL key supplied by the user.
    """
    import asyncio

    def close_reason(text: str) -> str:
        # A close frame carries at most 125 payload bytes, 2 of which are the
        # status code, so the UTF-8 reason is capped at 123 bytes.
        return text.encode("utf-8")[:123].decode("utf-8", errors="ignore")

    await websocket.accept()
    active_websockets.add(websocket)
    print(f"WebSocket connected. Active connections: {len(active_websockets)}")
    try:
        # Authenticate from the cookie set by the token exchange endpoint.
        access_token = websocket.cookies.get("access_token")
        if not access_token:
            await websocket.close(code=1008, reason="Not authenticated")
            return
        try:
            user_info = await get_user_info(access_token)
        except Exception:
            # Not a bare except, so task cancellation still propagates.
            await websocket.close(code=1008, reason="Invalid session")
            return

        if user_fal_key:
            # User pays with their own key: the daily limit does not apply.
            fal_key_to_use = user_fal_key
        else:
            can_start, used, limit = can_start_generation(user_info["username"], user_info["is_pro"])
            if not can_start:
                await websocket.close(code=1008, reason=f"Daily limit reached ({used}/{limit})")
                return
            if not FAL_API_KEY:
                await websocket.close(code=1011, reason="FAL API key not configured")
                return
            fal_key_to_use = FAL_API_KEY

        # Fetch a temporary FAL token so the real key never reaches the client.
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://rest.alpha.fal.ai/tokens/",
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Key {fal_key_to_use}"
                    },
                    json={
                        "allowed_apps": ["krea-wan-14b"],
                        "token_expiration": 5000
                    }
                )
                response.raise_for_status()
                fal_token = response.json()
        except Exception as e:
            await websocket.close(
                code=1011,
                reason=close_reason(f"Failed to get FAL token: {str(e)}"),
            )
            return

        fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
        async with websockets.connect(fal_ws_url) as fal_ws:

            async def client_to_fal():
                # Forward binary frames from the client to FAL.
                try:
                    while True:
                        data = await websocket.receive_bytes()
                        await fal_ws.send(data)
                except Exception as e:
                    print(f"Client to FAL error: {e}")
                    raise

            async def fal_to_client():
                # Forward FAL messages to the client, preserving frame type.
                try:
                    while True:
                        message = await fal_ws.recv()
                        if isinstance(message, str):
                            await websocket.send_text(message)
                        else:
                            await websocket.send_bytes(message)
                except Exception as e:
                    print(f"FAL to client error: {e}")
                    raise

            # gather() leaves the other coroutine running when one fails, so
            # wait for the first relay to finish and cancel the survivor.
            relays = [
                asyncio.create_task(client_to_fal()),
                asyncio.create_task(fal_to_client()),
            ]
            try:
                await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for relay in relays:
                    relay.cancel()
                # Collect results so no task exception goes unretrieved; a
                # relay failure ends the session and is not re-raised.
                await asyncio.gather(*relays, return_exceptions=True)
    except Exception as e:
        print(f"WebSocket proxy error: {e}")
        try:
            await websocket.close(code=1011, reason=close_reason(str(e)))
        except Exception as close_error:
            # The socket may already be closed; there is nothing left to send.
            print(f"WebSocket close failed: {close_error}")
    finally:
        # Always drop the connection from the registry.
        active_websockets.discard(websocket)
        print(f"WebSocket disconnected. Active connections: {len(active_websockets)}")