multimodalart HF Staff commited on
Commit
bf5b392
·
verified ·
1 Parent(s): 06ec34e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -15
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import os
2
  import sqlite3
3
  import base64
4
- from datetime import datetime, timedelta
5
  from typing import Optional
6
- from fastapi import FastAPI, Request, Response, Cookie, HTTPException
7
  from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
8
  from fastapi.templating import Jinja2Templates
9
- from fastapi.staticfiles import StaticFiles
10
  from huggingface_hub import whoami
11
- import secrets
12
  import httpx
 
13
 
14
  app = FastAPI()
15
  templates = Jinja2Templates(directory=".")
@@ -21,6 +20,9 @@ OAUTH_SCOPES = os.getenv("OAUTH_SCOPES", "openid profile")
21
  SPACE_HOST = os.getenv("SPACE_HOST", "multimodalart-krea-realtime-video.hf.space")
22
  OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
23
 
 
 
 
24
  # Database setup
25
  DB_PATH = "/data/usage.db"
26
 
@@ -70,7 +72,6 @@ def record_session(username: str, is_pro: bool):
70
  )
71
  conn.commit()
72
  except sqlite3.IntegrityError:
73
- # Already recorded today
74
  pass
75
  finally:
76
  conn.close()
@@ -85,7 +86,6 @@ async def exchange_code_for_token(code: str, redirect_uri: str) -> dict:
85
  """Exchange OAuth code for access token"""
86
  token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
87
 
88
- # Prepare Basic Auth header
89
  credentials = f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}"
90
  b64_credentials = base64.b64encode(credentials.encode()).decode()
91
 
@@ -126,7 +126,6 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
126
  """Home page - check auth and show app or login"""
127
 
128
  if not access_token:
129
- # Not logged in - show login button
130
  return templates.TemplateResponse("index.html", {
131
  "request": request,
132
  "authenticated": False,
@@ -135,11 +134,9 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
135
  "space_host": SPACE_HOST
136
  })
137
 
138
- # Verify token and get user info
139
  try:
140
  user_info = await get_user_info(access_token)
141
  except:
142
- # Invalid token, clear it
143
  response = templates.TemplateResponse("index.html", {
144
  "request": request,
145
  "authenticated": False,
@@ -151,7 +148,6 @@ async def home(request: Request, access_token: Optional[str] = Cookie(None)):
151
  response.delete_cookie("access_token")
152
  return response
153
 
154
- # Check session limits
155
  can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
156
 
157
  return templates.TemplateResponse("index.html", {
@@ -172,14 +168,12 @@ async def oauth_callback(code: str, state: Optional[str] = None):
172
  redirect_uri = f"https://{SPACE_HOST}/oauth/callback"
173
 
174
  try:
175
- # Exchange code for token
176
  token_data = await exchange_code_for_token(code, redirect_uri)
177
  access_token = token_data.get("access_token")
178
 
179
  if not access_token:
180
  raise HTTPException(status_code=400, detail="No access token received")
181
 
182
- # Redirect to home with token as cookie
183
  response = RedirectResponse(url="/", status_code=302)
184
  response.set_cookie(
185
  key="access_token",
@@ -187,7 +181,7 @@ async def oauth_callback(code: str, state: Optional[str] = None):
187
  httponly=True,
188
  secure=True,
189
  samesite="lax",
190
- max_age=30 * 24 * 60 * 60 # 30 days
191
  )
192
 
193
  return response
@@ -215,7 +209,6 @@ async def start_session(access_token: Optional[str] = Cookie(None)):
215
  detail=f"Daily limit reached. You've used {used}/{limit} sessions today."
216
  )
217
 
218
- # Record the session
219
  record_session(user_info["username"], user_info["is_pro"])
220
 
221
  return {
@@ -254,4 +247,98 @@ async def logout():
254
  @app.get("/health")
255
  async def health():
256
  """Health check endpoint"""
257
- return {"status": "ok", "oauth_enabled": bool(OAUTH_CLIENT_ID)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sqlite3
3
  import base64
4
+ from datetime import datetime
5
  from typing import Optional
6
+ from fastapi import FastAPI, Request, Cookie, HTTPException, WebSocket
7
  from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
8
  from fastapi.templating import Jinja2Templates
 
9
  from huggingface_hub import whoami
 
10
  import httpx
11
+ import websockets
12
 
13
  app = FastAPI()
14
  templates = Jinja2Templates(directory=".")
 
20
  SPACE_HOST = os.getenv("SPACE_HOST", "multimodalart-krea-realtime-video.hf.space")
21
  OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
22
 
23
+ # FAL API Key from environment
24
+ FAL_API_KEY = os.getenv("FAL_API_KEY", "")
25
+
26
  # Database setup
27
  DB_PATH = "/data/usage.db"
28
 
 
72
  )
73
  conn.commit()
74
  except sqlite3.IntegrityError:
 
75
  pass
76
  finally:
77
  conn.close()
 
86
  """Exchange OAuth code for access token"""
87
  token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
88
 
 
89
  credentials = f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}"
90
  b64_credentials = base64.b64encode(credentials.encode()).decode()
91
 
 
126
  """Home page - check auth and show app or login"""
127
 
128
  if not access_token:
 
129
  return templates.TemplateResponse("index.html", {
130
  "request": request,
131
  "authenticated": False,
 
134
  "space_host": SPACE_HOST
135
  })
136
 
 
137
  try:
138
  user_info = await get_user_info(access_token)
139
  except:
 
140
  response = templates.TemplateResponse("index.html", {
141
  "request": request,
142
  "authenticated": False,
 
148
  response.delete_cookie("access_token")
149
  return response
150
 
 
151
  can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
152
 
153
  return templates.TemplateResponse("index.html", {
 
168
  redirect_uri = f"https://{SPACE_HOST}/oauth/callback"
169
 
170
  try:
 
171
  token_data = await exchange_code_for_token(code, redirect_uri)
172
  access_token = token_data.get("access_token")
173
 
174
  if not access_token:
175
  raise HTTPException(status_code=400, detail="No access token received")
176
 
 
177
  response = RedirectResponse(url="/", status_code=302)
178
  response.set_cookie(
179
  key="access_token",
 
181
  httponly=True,
182
  secure=True,
183
  samesite="lax",
184
+ max_age=30 * 24 * 60 * 60
185
  )
186
 
187
  return response
 
209
  detail=f"Daily limit reached. You've used {used}/{limit} sessions today."
210
  )
211
 
 
212
  record_session(user_info["username"], user_info["is_pro"])
213
 
214
  return {
 
247
  @app.get("/health")
248
  async def health():
249
  """Health check endpoint"""
250
+ return {
251
+ "status": "ok",
252
+ "oauth_enabled": bool(OAUTH_CLIENT_ID),
253
+ "fal_api_key_configured": bool(FAL_API_KEY)
254
+ }
255
+
256
+ @app.websocket("/ws/video-gen")
257
+ async def websocket_video_gen(websocket: WebSocket):
258
+ """WebSocket proxy to FAL API - keeps API key secret"""
259
+ from fastapi import WebSocket
260
+ import websockets
261
+ import json
262
+
263
+ await websocket.accept()
264
+
265
+ # Get user from cookie
266
+ access_token = websocket.cookies.get("access_token")
267
+ if not access_token:
268
+ await websocket.close(code=1008, reason="Not authenticated")
269
+ return
270
+
271
+ try:
272
+ user_info = await get_user_info(access_token)
273
+ except:
274
+ await websocket.close(code=1008, reason="Invalid session")
275
+ return
276
+
277
+ # Check if user can start session
278
+ can_start, used, limit = can_start_session(user_info["username"], user_info["is_pro"])
279
+ if not can_start:
280
+ await websocket.close(code=1008, reason=f"Daily limit reached ({used}/{limit})")
281
+ return
282
+
283
+ if not FAL_API_KEY:
284
+ await websocket.close(code=1011, reason="FAL API key not configured")
285
+ return
286
+
287
+ # Fetch temporary FAL token
288
+ try:
289
+ async with httpx.AsyncClient() as client:
290
+ response = await client.post(
291
+ "https://rest.alpha.fal.ai/tokens/",
292
+ headers={
293
+ "Content-Type": "application/json",
294
+ "Authorization": f"Key {FAL_API_KEY}"
295
+ },
296
+ json={
297
+ "allowed_apps": ["krea-wan-14b"],
298
+ "token_expiration": 5000
299
+ }
300
+ )
301
+ response.raise_for_status()
302
+ fal_token = response.json()
303
+ except Exception as e:
304
+ await websocket.close(code=1011, reason=f"Failed to get FAL token: {str(e)}")
305
+ return
306
+
307
+ # Connect to FAL WebSocket
308
+ fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
309
+
310
+ try:
311
+ async with websockets.connect(fal_ws_url) as fal_ws:
312
+ # Relay messages between client and FAL
313
+ async def client_to_fal():
314
+ try:
315
+ while True:
316
+ # Receive from client
317
+ data = await websocket.receive_bytes()
318
+ # Forward to FAL
319
+ await fal_ws.send(data)
320
+ except Exception as e:
321
+ print(f"Client to FAL error: {e}")
322
+
323
+ async def fal_to_client():
324
+ try:
325
+ while True:
326
+ # Receive from FAL
327
+ message = await fal_ws.recv()
328
+ # Forward to client
329
+ if isinstance(message, str):
330
+ await websocket.send_text(message)
331
+ else:
332
+ await websocket.send_bytes(message)
333
+ except Exception as e:
334
+ print(f"FAL to client error: {e}")
335
+
336
+ # Run both directions concurrently
337
+ import asyncio
338
+ await asyncio.gather(
339
+ client_to_fal(),
340
+ fal_to_client()
341
+ )
342
+ except Exception as e:
343
+ print(f"WebSocket proxy error: {e}")
344
+ await websocket.close(code=1011, reason=str(e))