File size: 2,720 Bytes
6cf18f3 104f7bd ff49b3c 566619b d5886b3 b9f09f7 69c754e 104f7bd d5886b3 b9f09f7 104f7bd 6cf18f3 5f66658 d920423 fd90ec3 d920423 5f66658 0cf8b89 ff49b3c 5f66658 6cf18f3 5fdb08e 6cf18f3 5fdb08e 104f7bd 6cf18f3 0cf8b89 6cf18f3 5fdb08e b9f09f7 6cccf9e 566619b 1cea20a 104f7bd b9f09f7 104f7bd d5886b3 ff49b3c 566619b 104f7bd 566619b 104f7bd d5886b3 104f7bd 566619b 69c754e b0d49ce 5fdb08e 104f7bd 5fdb08e df1de84 d5886b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
import asyncio
import logging
from parse_args import parse_args
from audio_processor import AudioProcessor
# Install the default root handler and message format, then raise the root
# threshold to WARNING so loggers without their own level drop INFO/DEBUG.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger is explicitly more verbose than the root.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Parsed once at import time; read by lifespan(), the /asr endpoint and the
# __main__ entry point below.
args = parse_args()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the shared ``asr``, ``tokenizer`` and ``diarization`` globals.

    Runs once when the application starts. The globals are read later by the
    ``/asr`` WebSocket endpoint. Nothing is executed after the ``yield``, so
    there is no shutdown step.
    """
    global asr, tokenizer, diarization

    if not args.transcription:
        asr, tokenizer = None, None
    else:
        asr, tokenizer = backend_factory(args)
        warmup_asr(asr, args.warmup_file)

    if not args.diarization:
        diarization = None
    else:
        # Imported only when diarization is requested.
        from diarization.diarization_online import DiartDiarization
        diarization = DiartDiarization()

    yield
# ASGI application; startup work is delegated to lifespan().
app = FastAPI(lifespan=lifespan)

# Wide-open CORS for the demo server: any origin, method and header.
# Credentials are disabled on purpose: combining wildcard origins with
# allow_credentials=True is disallowed by the FastAPI CORS documentation and
# would let any site issue credentialed cross-origin requests.
# To support cookies/auth, list explicit origins and re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Read the demo page once at import time; the root endpoint serves this string.
# The path is relative to the current working directory.
with open("web/live_transcription.html", mode="r", encoding="utf-8") as f:
    html = f.read()
@app.get("/")
async def get():
    """Return the demo page held in the module-level ``html`` string."""
    page = HTMLResponse(html)
    return page
async def handle_websocket_results(websocket, results_generator):
    """Forward every item from ``results_generator`` to the client as JSON.

    Args:
        websocket: Object exposing an awaitable ``send_json(payload)``.
        results_generator: Async iterable yielding the payloads to send.

    Any ``Exception`` raised while iterating or sending is logged as a
    warning and ends the forwarding loop; it is not re-raised.
    """
    try:
        async for payload in results_generator:
            await websocket.send_json(payload)
    except Exception as exc:
        logger.warning(f"Error in WebSocket results handler: {exc}")
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Handle one ``/asr`` WebSocket session.

    Creates an ``AudioProcessor`` for the connection, starts a background
    task that forwards the processor's results to the client, and then feeds
    every binary message received from the client into the processor until
    the client disconnects.

    Args:
        websocket: The incoming WebSocket connection.

    On exit (disconnect or error) the forwarding task is cancelled and
    awaited before the processor is cleaned up, so cleanup never overlaps
    with a send that is still in flight.
    """
    audio_processor = AudioProcessor(args, asr, tokenizer)
    await websocket.accept()
    logger.info("WebSocket connection opened.")

    results_generator = await audio_processor.create_tasks(diarization)
    websocket_task = asyncio.create_task(
        handle_websocket_results(websocket, results_generator)
    )

    try:
        while True:
            message = await websocket.receive_bytes()
            await audio_processor.process_audio(message)
    except WebSocketDisconnect:
        logger.warning("WebSocket disconnected.")
    finally:
        websocket_task.cancel()
        try:
            # Wait for the sender to actually finish. return_exceptions=True
            # turns the task's own CancelledError (or any other error) into a
            # plain result, while a cancellation of *this* coroutine still
            # propagates normally.
            await asyncio.gather(websocket_task, return_exceptions=True)
        finally:
            # Always release the processor, even if the wait above was
            # itself interrupted.
            await audio_processor.cleanup()
            logger.info("WebSocket endpoint cleaned up.")
if __name__ == "__main__":
    # uvicorn is only needed when this module is run directly as a script.
    import uvicorn

    # The app is passed as an import string, which uvicorn needs for reload.
    uvicorn.run(
        "whisper_fastapi_online_server:app",
        host=args.host,
        port=args.port,
        reload=True,
        log_level="info",
    )