Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
e3f7f9a
1
Parent(s):
3d0c02a
add disconnect mechanism
Browse files- app.py +15 -4
- index.html +30 -0
app.py
CHANGED
|
@@ -14,6 +14,9 @@ import websockets
|
|
| 14 |
app = FastAPI()
|
| 15 |
templates = Jinja2Templates(directory=".")
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
# OAuth configuration from HF Spaces environment
|
| 18 |
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
|
| 19 |
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
|
|
@@ -420,6 +423,10 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
|
|
| 420 |
|
| 421 |
await websocket.accept()
|
| 422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
# Get user from cookie
|
| 424 |
access_token = websocket.cookies.get("access_token")
|
| 425 |
if not access_token:
|
|
@@ -470,7 +477,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
|
|
| 470 |
|
| 471 |
# Connect to FAL WebSocket
|
| 472 |
fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
|
| 473 |
-
|
| 474 |
try:
|
| 475 |
async with websockets.connect(fal_ws_url) as fal_ws:
|
| 476 |
# Relay messages between client and FAL
|
|
@@ -483,7 +490,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
|
|
| 483 |
await fal_ws.send(data)
|
| 484 |
except Exception as e:
|
| 485 |
print(f"Client to FAL error: {e}")
|
| 486 |
-
|
| 487 |
async def fal_to_client():
|
| 488 |
try:
|
| 489 |
while True:
|
|
@@ -496,7 +503,7 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
|
|
| 496 |
await websocket.send_bytes(message)
|
| 497 |
except Exception as e:
|
| 498 |
print(f"FAL to client error: {e}")
|
| 499 |
-
|
| 500 |
# Run both directions concurrently
|
| 501 |
import asyncio
|
| 502 |
await asyncio.gather(
|
|
@@ -505,4 +512,8 @@ async def websocket_video_gen(websocket: WebSocket, user_fal_key: Optional[str]
|
|
| 505 |
)
|
| 506 |
except Exception as e:
|
| 507 |
print(f"WebSocket proxy error: {e}")
|
| 508 |
-
await websocket.close(code=1011, reason=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
app = FastAPI()
|
| 15 |
templates = Jinja2Templates(directory=".")
|
| 16 |
|
| 17 |
+
# Track active WebSocket connections
|
| 18 |
+
active_websockets = set()
|
| 19 |
+
|
| 20 |
# OAuth configuration from HF Spaces environment
|
| 21 |
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
|
| 22 |
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
|
|
|
|
| 423 |
|
| 424 |
await websocket.accept()
|
| 425 |
|
| 426 |
+
# Track this connection
|
| 427 |
+
active_websockets.add(websocket)
|
| 428 |
+
print(f"WebSocket connected. Active connections: {len(active_websockets)}")
|
| 429 |
+
|
| 430 |
# Get user from cookie
|
| 431 |
access_token = websocket.cookies.get("access_token")
|
| 432 |
if not access_token:
|
|
|
|
| 477 |
|
| 478 |
# Connect to FAL WebSocket
|
| 479 |
fal_ws_url = f"wss://fal.run/fal-ai/krea-wan-14b/ws?fal_jwt_token={fal_token}"
|
| 480 |
+
|
| 481 |
try:
|
| 482 |
async with websockets.connect(fal_ws_url) as fal_ws:
|
| 483 |
# Relay messages between client and FAL
|
|
|
|
| 490 |
await fal_ws.send(data)
|
| 491 |
except Exception as e:
|
| 492 |
print(f"Client to FAL error: {e}")
|
| 493 |
+
|
| 494 |
async def fal_to_client():
|
| 495 |
try:
|
| 496 |
while True:
|
|
|
|
| 503 |
await websocket.send_bytes(message)
|
| 504 |
except Exception as e:
|
| 505 |
print(f"FAL to client error: {e}")
|
| 506 |
+
|
| 507 |
# Run both directions concurrently
|
| 508 |
import asyncio
|
| 509 |
await asyncio.gather(
|
|
|
|
| 512 |
)
|
| 513 |
except Exception as e:
|
| 514 |
print(f"WebSocket proxy error: {e}")
|
| 515 |
+
await websocket.close(code=1011, reason=str(e))
|
| 516 |
+
finally:
|
| 517 |
+
# Remove from active connections
|
| 518 |
+
active_websockets.discard(websocket)
|
| 519 |
+
print(f"WebSocket disconnected. Active connections: {len(active_websockets)}")
|
index.html
CHANGED
|
@@ -1210,6 +1210,8 @@
|
|
| 1210 |
promptUpdateTimer: null,
|
| 1211 |
pendingPromptUpdate: null,
|
| 1212 |
generationFinished: false,
|
|
|
|
|
|
|
| 1213 |
|
| 1214 |
init() {
|
| 1215 |
this.setupEventListeners();
|
|
@@ -1449,6 +1451,7 @@
|
|
| 1449 |
|
| 1450 |
this.ws.onopen = () => {
|
| 1451 |
this.showInfo('Connected! Waiting for ready signal...');
|
|
|
|
| 1452 |
};
|
| 1453 |
|
| 1454 |
this.ws.onmessage = async (event) => {
|
|
@@ -1590,6 +1593,9 @@
|
|
| 1590 |
this.frameCount++;
|
| 1591 |
document.getElementById('frameCount').textContent = this.frameCount;
|
| 1592 |
|
|
|
|
|
|
|
|
|
|
| 1593 |
if (this.frameCount === 1) {
|
| 1594 |
document.getElementById('spinner').classList.add('hidden');
|
| 1595 |
this.startPlaybackLoop();
|
|
@@ -1893,7 +1899,31 @@
|
|
| 1893 |
}
|
| 1894 |
},
|
| 1895 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1896 |
disconnect() {
|
|
|
|
|
|
|
| 1897 |
if (this.ws) {
|
| 1898 |
this.ws.close();
|
| 1899 |
this.ws = null;
|
|
|
|
| 1210 |
promptUpdateTimer: null,
|
| 1211 |
pendingPromptUpdate: null,
|
| 1212 |
generationFinished: false,
|
| 1213 |
+
idleTimeout: null,
|
| 1214 |
+
lastFrameTime: null,
|
| 1215 |
|
| 1216 |
init() {
|
| 1217 |
this.setupEventListeners();
|
|
|
|
| 1451 |
|
| 1452 |
this.ws.onopen = () => {
|
| 1453 |
this.showInfo('Connected! Waiting for ready signal...');
|
| 1454 |
+
this.startIdleTimeout();
|
| 1455 |
};
|
| 1456 |
|
| 1457 |
this.ws.onmessage = async (event) => {
|
|
|
|
| 1593 |
this.frameCount++;
|
| 1594 |
document.getElementById('frameCount').textContent = this.frameCount;
|
| 1595 |
|
| 1596 |
+
// Reset idle timeout on every frame
|
| 1597 |
+
this.startIdleTimeout();
|
| 1598 |
+
|
| 1599 |
if (this.frameCount === 1) {
|
| 1600 |
document.getElementById('spinner').classList.add('hidden');
|
| 1601 |
this.startPlaybackLoop();
|
|
|
|
| 1899 |
}
|
| 1900 |
},
|
| 1901 |
|
| 1902 |
+
startIdleTimeout() {
|
| 1903 |
+
// Clear existing timeout
|
| 1904 |
+
if (this.idleTimeout) {
|
| 1905 |
+
clearTimeout(this.idleTimeout);
|
| 1906 |
+
}
|
| 1907 |
+
|
| 1908 |
+
// Set new timeout - disconnect after 15 seconds of no frames
|
| 1909 |
+
this.idleTimeout = setTimeout(() => {
|
| 1910 |
+
if (this.isGenerating) {
|
| 1911 |
+
this.showError('The FAL API server is too busy, try again soon!');
|
| 1912 |
+
this.disconnect();
|
| 1913 |
+
}
|
| 1914 |
+
}, 15000);
|
| 1915 |
+
},
|
| 1916 |
+
|
| 1917 |
+
clearIdleTimeout() {
|
| 1918 |
+
if (this.idleTimeout) {
|
| 1919 |
+
clearTimeout(this.idleTimeout);
|
| 1920 |
+
this.idleTimeout = null;
|
| 1921 |
+
}
|
| 1922 |
+
},
|
| 1923 |
+
|
| 1924 |
disconnect() {
|
| 1925 |
+
this.clearIdleTimeout();
|
| 1926 |
+
|
| 1927 |
if (this.ws) {
|
| 1928 |
this.ws.close();
|
| 1929 |
this.ws = null;
|