Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
| 1 |
import os
|
| 2 |
import sqlite3
|
| 3 |
import base64
|
| 4 |
-
from datetime import datetime
|
| 5 |
from typing import Optional
|
| 6 |
-
from fastapi import FastAPI, Request,
|
| 7 |
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
| 8 |
from fastapi.templating import Jinja2Templates
|
| 9 |
-
from fastapi.staticfiles import StaticFiles
|
| 10 |
from huggingface_hub import whoami
|
| 11 |
-
import secrets
|
| 12 |
import httpx
|
|
|
|
| 13 |
|
| 14 |
app = FastAPI()
|
| 15 |
templates = Jinja2Templates(directory=".")
|
|
@@ -21,6 +20,9 @@ OAUTH_SCOPES = os.getenv("OAUTH_SCOPES", "openid profile")
|
|
| 21 |
SPACE_HOST = os.getenv("SPACE_HOST", "multimodalart-krea-realtime-video.hf.space")
|
| 22 |
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
# Database setup
|
| 25 |
DB_PATH = "/data/usage.db"
|
| 26 |
|
|
@@ -70,7 +72,6 @@ def record_session(username: str, is_pro: bool):
|
|
| 70 |
)
|
| 71 |
conn.commit()
|
| 72 |
except sqlite3.IntegrityError:
|
| 73 |
-
# Already recorded today
|
| 74 |
pass
|
| 75 |
finally:
|
| 76 |
conn.close()
|
|
@@ -85,7 +86,6 @@ async def exchange_code_for_token(code: str, redirect_uri: str) -> dict:
|
|
| 85 |
"""Exchange OAuth code for access token"""
|
| 86 |
token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
|
| 87 |
|
| 88 |
-
# Prepare Basic Auth header
|
| 89 |
credentials = f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}"
|
| 90 |
b64_credentials = base64.b64encode(credentials.encode()).decode()
|
| 91 |
|
|
@@ -126,7 +126,6 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
|
|
| 126 |
"""Home page - check auth and show app or login"""
|
| 127 |
|
| 128 |
if not access_token:
|
| 129 |
-
# Not logged in - show login button
|
| 130 |
return templates.TemplateResponse("index.html", {
|
| 131 |
"request": request,
|
| 132 |
"authenticated": False,
|
|
@@ -135,11 +134,9 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
|
|
| 135 |
"space_host": SPACE_HOST
|
| 136 |
})
|
| 137 |
|
| 138 |
-
# Verify token and get user info
|
| 139 |
try:
|
| 140 |
user_info = await get_user_info(access_token)
|
| 141 |
except:
|
| 142 |
-
# Invalid token, clear it
|
| 143 |
response = templates.TemplateResponse("index.html", {
|
| 144 |
"request": request,
|
| 145 |
"authenticated": False,
|
|
@@ -151,7 +148,6 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
|
|
| 151 |
response.delete_cookie("access_token")
|
| 152 |
return response
|
| 153 |
|
| 154 |
-
# Check session limits
|
| 155 |
can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
|
| 156 |
|
| 157 |
return templates.TemplateResponse("index.html", {
|
|
@@ -172,14 +168,12 @@ async def oauth_callback(code: str, state: Optional[str] = None):
|
|
| 172 |
redirect_uri = f"https://{SPACE_HOST}/oauth/callback"
|
| 173 |
|
| 174 |
try:
|
| 175 |
-
# Exchange code for token
|
| 176 |
token_data = await exchange_code_for_token(code, redirect_uri)
|
| 177 |
access_token = token_data.get("access_token")
|
| 178 |
|
| 179 |
if not access_token:
|
| 180 |
raise HTTPException(status_code=400, detail="No access token received")
|
| 181 |
|
| 182 |
-
# Redirect to home with token as cookie
|
| 183 |
response = RedirectResponse(url="/", status_code=302)
|
| 184 |
response.set_cookie(
|
| 185 |
key="access_token",
|
|
@@ -187,7 +181,7 @@ async def oauth_callback(code: str, state: Optional[str] = None):
|
|
| 187 |
httponly=True,
|
| 188 |
secure=True,
|
| 189 |
samesite="lax",
|
| 190 |
-
max_age=30 * 24 * 60 * 60
|
| 191 |
)
|
| 192 |
|
| 193 |
return response
|
|
@@ -215,7 +209,6 @@ async def start_session(access_token: Optional[str] = Cookie(None)):
|
|
| 215 |
detail=f"Daily limit reached. You've used {used}/{limit} sessions today."
|
| 216 |
)
|
| 217 |
|
| 218 |
-
# Record the session
|
| 219 |
record_session(user_info["username"], user_info["is_pro"])
|
| 220 |
|
| 221 |
return {
|
|
@@ -254,4 +247,98 @@ async def logout():
|
|
| 254 |
@app.get("/health")
|
| 255 |
async def health():
|
| 256 |
"""Health check endpoint"""
|
| 257 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sqlite3
|
| 3 |
import base64
|
| 4 |
+
from datetime import datetime
|
| 5 |
from typing import Optional
|
| 6 |
+
from fastapi import FastAPI, Request, Cookie, HTTPException, WebSocket
|
| 7 |
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
| 8 |
from fastapi.templating import Jinja2Templates
|
|
|
|
| 9 |
from huggingface_hub import whoami
|
|
|
|
| 10 |
import httpx
|
| 11 |
+
import websockets
|
| 12 |
|
| 13 |
app = FastAPI()
|
| 14 |
templates = Jinja2Templates(directory=".")
|
|
|
|
| 20 |
SPACE_HOST = os.getenv("SPACE_HOST", "multimodalart-krea-realtime-video.hf.space")
|
| 21 |
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
|
| 22 |
|
| 23 |
+
# FAL API Key from environment
|
| 24 |
+
FAL_API_KEY = os.getenv("FAL_API_KEY", "")
|
| 25 |
+
|
| 26 |
# Database setup
|
| 27 |
DB_PATH = "/data/usage.db"
|
| 28 |
|
|
|
|
| 72 |
)
|
| 73 |
conn.commit()
|
| 74 |
except sqlite3.IntegrityError:
|
|
|
|
| 75 |
pass
|
| 76 |
finally:
|
| 77 |
conn.close()
|
|
|
|
| 86 |
"""Exchange OAuth code for access token"""
|
| 87 |
token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
|
| 88 |
|
|
|
|
| 89 |
credentials = f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}"
|
| 90 |
b64_credentials = base64.b64encode(credentials.encode()).decode()
|
| 91 |
|
|
|
|
| 126 |
"""Home page - check auth and show app or login"""
|
| 127 |
|
| 128 |
if not access_token:
|
|
|
|
| 129 |
return templates.TemplateResponse("index.html", {
|
| 130 |
"request": request,
|
| 131 |
"authenticated": False,
|
|
|
|
| 134 |
"space_host": SPACE_HOST
|
| 135 |
})
|
| 136 |
|
|
|
|
| 137 |
try:
|
| 138 |
user_info = await get_user_info(access_token)
|
| 139 |
except:
|
|
|
|
| 140 |
response = templates.TemplateResponse("index.html", {
|
| 141 |
"request": request,
|
| 142 |
"authenticated": False,
|
|
|
|
| 148 |
response.delete_cookie("access_token")
|
| 149 |
return response
|
| 150 |
|
|
|
|
| 151 |
can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
|
| 152 |
|
| 153 |
return templates.TemplateResponse("index.html", {
|
|
|
|
| 168 |
redirect_uri = f"https://{SPACE_HOST}/oauth/callback"
|
| 169 |
|
| 170 |
try:
|
|
|
|
| 171 |
token_data = await exchange_code_for_token(code, redirect_uri)
|
| 172 |
access_token = token_data.get("access_token")
|
| 173 |
|
| 174 |
if not access_token:
|
| 175 |
raise HTTPException(status_code=400, detail="No access token received")
|
| 176 |
|
|
|
|
| 177 |
response = RedirectResponse(url="/", status_code=302)
|
| 178 |
response.set_cookie(
|
| 179 |
key="access_token",
|
|
|
|
| 181 |
httponly=True,
|
| 182 |
secure=True,
|
| 183 |
samesite="lax",
|
| 184 |
+
max_age=30 * 24 * 60 * 60
|
| 185 |
)
|
| 186 |
|
| 187 |
return response
|
|
|
|
| 209 |
detail=f"Daily limit reached. You've used {used}/{limit} sessions today."
|
| 210 |
)
|
| 211 |
|
|
|
|
| 212 |
record_session(user_info["username"], user_info["is_pro"])
|
| 213 |
|
| 214 |
return {
|
|
|
|
| 247 |
@app.get("/health")
|
| 248 |
async def health():
|
| 249 |
"""Health check endpoint"""
|
| 250 |
+
return {
|
| 251 |
+
"status": "ok",
|
| 252 |
+
"oauth_enabled": bool(OAUTH_CLIENT_ID),
|
| 253 |
+
"fal_api_key_configured": bool(FAL_API_KEY)
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
@app.websocket("/ws/video-gen")
|
| 257 |
+
async def websocket_video_gen(websocket: WebSocket):
|
| 258 |
+
"""WebSocket proxy to FAL API - keeps API key secret"""
|
| 259 |
+
from fastapi import WebSocket
|
| 260 |
+
import websockets
|
| 261 |
+
import json
|
| 262 |
+
|
| 263 |
+
await websocket.accept()
|
| 264 |
+
|
| 265 |
+
# Get user from cookie
|
| 266 |
+
access_token = websocket.cookies.get("access_token")
|
| 267 |
+
if not access_token:
|
| 268 |
+
await websocket.close(code=1008, reason="Not authenticated")
|
| 269 |
+
return
|
| 270 |
+
|
| 271 |
+
try:
|
| 272 |
+
user_info = await get_user_info(access_token)
|
| 273 |
+
except:
|
| 274 |
+
await websocket.close(code=1008, reason="Invalid session")
|
| 275 |
+
return
|
| 276 |
+
|
| 277 |
+
# Check if user can start session
|
| 278 |
+
can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
|
| 279 |
+
if not can_start:
|
| 280 |
+
await websocket.close(code=1008, reason=f"Daily limit reached ({used}/{limit})")
|
| 281 |
+
return
|
| 282 |
+
|
| 283 |
+
if not FAL_API_KEY:
|
| 284 |
+
await websocket.close(code=1011, reason="FAL API key not configured")
|
| 285 |
+
return
|
| 286 |
+
|
| 287 |
+
# Fetch temporary FAL token
|
| 288 |
+
try:
|
| 289 |
+
async with httpx.AsyncClient() as client:
|
| 290 |
+
response = await client.post(
|
| 291 |
+
"https://rest.alpha.fal.ai/tokens/",
|
| 292 |
+
headers={
|
| 293 |
+
"Content-Type": "application/json",
|
| 294 |
+
"Authorization": f"Key {FAL_API_KEY}"
|
| 295 |
+
},
|
| 296 |
+
json={
|
| 297 |
+
"allowed_apps": ["krea-wan-14b"],
|
| 298 |
+
"token_expiration": 5000
|
| 299 |
+
}
|
| 300 |
+
)
|
| 301 |
+
response.raise_for_status()
|
| 302 |
+
fal_token = response.json()
|
| 303 |
+
except Exception as e:
|
| 304 |
+
await websocket.close(code=1011, reason=f"Failed to get FAL token: {str(e)}")
|
| 305 |
+
return
|
| 306 |
+
|
| 307 |
+
# Connect to FAL WebSocket
|
| 308 |
+
fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
|
| 309 |
+
|
| 310 |
+
try:
|
| 311 |
+
async with websockets.connect(fal_ws_url) as fal_ws:
|
| 312 |
+
# Relay messages between client and FAL
|
| 313 |
+
async def client_to_fal():
|
| 314 |
+
try:
|
| 315 |
+
while True:
|
| 316 |
+
# Receive from client
|
| 317 |
+
data = await websocket.receive_bytes()
|
| 318 |
+
# Forward to FAL
|
| 319 |
+
await fal_ws.send(data)
|
| 320 |
+
except Exception as e:
|
| 321 |
+
print(f"Client to FAL error: {e}")
|
| 322 |
+
|
| 323 |
+
async def fal_to_client():
|
| 324 |
+
try:
|
| 325 |
+
while True:
|
| 326 |
+
# Receive from FAL
|
| 327 |
+
message = await fal_ws.recv()
|
| 328 |
+
# Forward to client
|
| 329 |
+
if isinstance(message, str):
|
| 330 |
+
await websocket.send_text(message)
|
| 331 |
+
else:
|
| 332 |
+
await websocket.send_bytes(message)
|
| 333 |
+
except Exception as e:
|
| 334 |
+
print(f"FAL to client error: {e}")
|
| 335 |
+
|
| 336 |
+
# Run both directions concurrently
|
| 337 |
+
import asyncio
|
| 338 |
+
await asyncio.gather(
|
| 339 |
+
client_to_fal(),
|
| 340 |
+
fal_to_client()
|
| 341 |
+
)
|
| 342 |
+
except Exception as e:
|
| 343 |
+
print(f"WebSocket proxy error: {e}")
|
| 344 |
+
await websocket.close(code=1011, reason=str(e))
|