apolinario commited on
Commit
e3f7f9a
·
1 Parent(s): 3d0c02a

add disconnect mechanism

Browse files
Files changed (2) hide show
  1. app.py +15 -4
  2. index.html +30 -0
app.py CHANGED
@@ -14,6 +14,9 @@ import websockets
14
  app = FastAPI()
15
  templates = Jinja2Templates(directory=".")
16
 
 
 
 
17
  # OAuth configuration from HF Spaces environment
18
  OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
19
  OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
@@ -420,6 +423,10 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
420
 
421
  await websocket.accept()
422
 
 
 
 
 
423
  # Get user from cookie
424
  access_token = websocket.cookies.get("access_token")
425
  if not access_token:
@@ -470,7 +477,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
470
 
471
  # Connect to FAL WebSocket
472
  fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
473
-
474
  try:
475
  async with websockets.connect(fal_ws_url) as fal_ws:
476
  # Relay messages between client and FAL
@@ -483,7 +490,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
483
  await fal_ws.send(data)
484
  except Exception as e:
485
  print(f"Client to FAL error: {e}")
486
-
487
  async def fal_to_client():
488
  try:
489
  while True:
@@ -496,7 +503,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
496
  await websocket.send_bytes(message)
497
  except Exception as e:
498
  print(f"FAL to client error: {e}")
499
-
500
  # Run both directions concurrently
501
  import asyncio
502
  await asyncio.gather(
@@ -505,4 +512,8 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
505
  )
506
  except Exception as e:
507
  print(f"WebSocket proxy error: {e}")
508
- await websocket.close(code=1011, reason=str(e))
 
 
 
 
 
14
  app = FastAPI()
15
  templates = Jinja2Templates(directory=".")
16
 
17
+ # Track active WebSocket connections
18
+ active_websockets = set()
19
+
20
  # OAuth configuration from HF Spaces environment
21
  OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
22
  OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
 
423
 
424
  await websocket.accept()
425
 
426
+ # Track this connection
427
+ active_websockets.add(websocket)
428
+ print(f"WebSocket connected. Active connections: {len(active_websockets)}")
429
+
430
  # Get user from cookie
431
  access_token = websocket.cookies.get("access_token")
432
  if not access_token:
 
477
 
478
  # Connect to FAL WebSocket
479
  fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
480
+
481
  try:
482
  async with websockets.connect(fal_ws_url) as fal_ws:
483
  # Relay messages between client and FAL
 
490
  await fal_ws.send(data)
491
  except Exception as e:
492
  print(f"Client to FAL error: {e}")
493
+
494
  async def fal_to_client():
495
  try:
496
  while True:
 
503
  await websocket.send_bytes(message)
504
  except Exception as e:
505
  print(f"FAL to client error: {e}")
506
+
507
  # Run both directions concurrently
508
  import asyncio
509
  await asyncio.gather(
 
512
  )
513
  except Exception as e:
514
  print(f"WebSocket proxy error: {e}")
515
+ await websocket.close(code=1011, reason=str(e))
516
+ finally:
517
+ # Remove from active connections
518
+ active_websockets.discard(websocket)
519
+ print(f"WebSocket disconnected. Active connections: {len(active_websockets)}")
index.html CHANGED
@@ -1210,6 +1210,8 @@
1210
  promptUpdateTimer: null,
1211
  pendingPromptUpdate: null,
1212
  generationFinished: false,
 
 
1213
 
1214
  init() {
1215
  this.setupEventListeners();
@@ -1449,6 +1451,7 @@
1449
 
1450
  this.ws.onopen = () => {
1451
  this.showInfo('Connected! Waiting for ready signal...');
 
1452
  };
1453
 
1454
  this.ws.onmessage = async (event) => {
@@ -1590,6 +1593,9 @@
1590
  this.frameCount++;
1591
  document.getElementById('frameCount').textContent = this.frameCount;
1592
 
 
 
 
1593
  if (this.frameCount === 1) {
1594
  document.getElementById('spinner').classList.add('hidden');
1595
  this.startPlaybackLoop();
@@ -1893,7 +1899,31 @@
1893
  }
1894
  },
1895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1896
  disconnect() {
 
 
1897
  if (this.ws) {
1898
  this.ws.close();
1899
  this.ws = null;
 
1210
  promptUpdateTimer: null,
1211
  pendingPromptUpdate: null,
1212
  generationFinished: false,
1213
+ idleTimeout: null,
1214
+ lastFrameTime: null,
1215
 
1216
  init() {
1217
  this.setupEventListeners();
 
1451
 
1452
  this.ws.onopen = () => {
1453
  this.showInfo('Connected! Waiting for ready signal...');
1454
+ this.startIdleTimeout();
1455
  };
1456
 
1457
  this.ws.onmessage = async (event) => {
 
1593
  this.frameCount++;
1594
  document.getElementById('frameCount').textContent = this.frameCount;
1595
 
1596
+ // Reset idle timeout on every frame
1597
+ this.startIdleTimeout();
1598
+
1599
  if (this.frameCount === 1) {
1600
  document.getElementById('spinner').classList.add('hidden');
1601
  this.startPlaybackLoop();
 
1899
  }
1900
  },
1901
 
1902
+ startIdleTimeout() {
1903
+ // Clear existing timeout
1904
+ if (this.idleTimeout) {
1905
+ clearTimeout(this.idleTimeout);
1906
+ }
1907
+
1908
+ // Set new timeout - disconnect after 15 seconds of no frames
1909
+ this.idleTimeout = setTimeout(() => {
1910
+ if (this.isGenerating) {
1911
+ this.showError('The FAL API server is too busy, try again soon!');
1912
+ this.disconnect();
1913
+ }
1914
+ }, 15000);
1915
+ },
1916
+
1917
+ clearIdleTimeout() {
1918
+ if (this.idleTimeout) {
1919
+ clearTimeout(this.idleTimeout);
1920
+ this.idleTimeout = null;
1921
+ }
1922
+ },
1923
+
1924
  disconnect() {
1925
+ this.clearIdleTimeout();
1926
+
1927
  if (this.ws) {
1928
  this.ws.close();
1929
  this.ws = null;