Morgan Funtowicz committed
Commit: 8550385
Parent(s): 69fb91d
misc(sdk): use endpoint config parser
handler.py CHANGED (+43 -31)

@@ -1,11 +1,12 @@
 import asyncio
-import os
 import zlib
 from functools import lru_cache
 from io import BytesIO
+from pathlib import Path
 from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING
 
 import numpy as np
+from hfendpoints.errors.config import UnsupportedModelArchitecture
 from hfendpoints.openai import Context, run
 from hfendpoints.openai.audio import (
     AutomaticSpeechRecognitionEndpoint,
@@ -19,22 +20,25 @@ from hfendpoints.openai.audio import (
 )
 from librosa import load as load_audio, get_duration
 from loguru import logger
+from transformers import AutoConfig
 from vllm import (
     AsyncEngineArgs,
     AsyncLLMEngine,
     SamplingParams,
 )
 
-from hfendpoints import Handler
+from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
     from vllm import CompletionOutput, RequestOutput
     from vllm.sequence import SampleLogprobs
 
+SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]
+
 
 def chunk_audio_with_duration(
-
+    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
 ) -> Sequence[np.ndarray]:
     """
     Chunk a mono audio timeseries so that each chunk is as long as `maximum_duration_sec`.
@@ -63,10 +67,10 @@ def compression_ratio(text: str) -> float:
 
 
 def create_prompt(
-
-
-
-
+    audio: np.ndarray,
+    sampling_rate: int,
+    language: int,
+    timestamp_marker: int,
 ):
     """
     Generate the right prompt with the specific parameters to submit for inference over Whisper
@@ -93,7 +97,7 @@ def create_prompt(
 
 
 def create_params(
-
+    max_tokens: int, temperature: float, is_verbose: bool
 ) -> "SamplingParams":
     """
     Create sampling parameters to submit for inference through vLLM `generate`
@@ -123,12 +127,12 @@ def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
 
 
 def process_chunk(
-
-
-
-
-
-
+    tokenizer: "PreTrainedTokenizer",
+    ids: np.ndarray,
+    logprobs: "SampleLogprobs",
+    request: TranscriptionRequest,
+    segment_offset: int,
+    timestamp_offset: int,
 ) -> Generator:
     """
     Decode a single transcribed audio chunk and generates all the segments associated
@@ -198,9 +202,9 @@ def process_chunk(
 
 
 def process_chunks(
-
-
-
+    tokenizer: "PreTrainedTokenizer",
+    chunks: List["RequestOutput"],
+    request: TranscriptionRequest,
 ) -> Tuple[List[Segment], str]:
     """
     Iterate over all the audio chunk's outputs and consolidates outputs as segment(s) whether the response is verbose or not
@@ -223,7 +227,7 @@ def process_chunks(
         logprobs = generation.logprobs
 
         for segment, _is_continuation in process_chunk(
-
+            tokenizer, ids, logprobs, request, segment_offset, time_offset
         ):
             materialized_segments.append(segment)
 
@@ -258,17 +262,17 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
                 enforce_eager=False,
                 enable_prefix_caching=True,
                 max_logprobs=1,  # TODO(mfuntowicz) : Set from config?
-                disable_log_requests=True
+                disable_log_requests=True,
             )
         )
 
     async def transcribe(
-
-
-
-
-
-
+        self,
+        ctx: Context,
+        request: TranscriptionRequest,
+        tokenizer: "PreTrainedTokenizer",
+        audio_chunks: Iterable[np.ndarray],
+        params: "SamplingParams",
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
             """
@@ -319,14 +323,14 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         return segments, text
 
     async def __call__(
-
+        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
         with logger.contextualize(request_id=ctx.request_id):
             with memoryview(request) as audio:
 
                 # Check if we need to enable the verbose path
                 is_verbose = (
-
+                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                 )
 
                 # Retrieve the tokenizer and model config asynchronously while we decode audio
@@ -375,14 +379,22 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
 
 
 def entrypoint():
-
-
+    # Retrieve endpoint configuration
+    endpoint_config = EndpointConfig.from_env()
+
+    # Ensure the model is compatible is pre-downloaded
+    if (model_local_path := Path(endpoint_config.model_id)).exists():
+        if (config_local_path := (model_local_path / "config.json")).exists():
+            config = AutoConfig.from_pretrained(config_local_path)
+            ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)
 
+    # Initialize the endpoint
     endpoint = AutomaticSpeechRecognitionEndpoint(
-        WhisperHandler(
+        WhisperHandler(endpoint_config.model_id)
     )
 
-
+    # Serve the model
+    run(endpoint, endpoint_config.interface, endpoint_config.port)
 
 
 if __name__ == "__main__":
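As a reading aid, here is a minimal standalone sketch (not part of the commit) of the pre-flight check the new entrypoint() performs on pre-downloaded models. It only relies on pathlib and transformers; the helper name check_local_model and the example path in the usage comment are hypothetical, and it raises a plain ValueError where the SDK would use its own UnsupportedModelArchitecture / ensure_supported_architectures machinery.

# Standalone sketch (hypothetical helper, not part of the commit): mirrors the
# architecture guard applied before the vLLM engine is built.
from pathlib import Path
from transformers import AutoConfig

SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]


def check_local_model(model_id: str) -> None:
    # Only validate when model_id points at a local snapshot containing a
    # config.json, exactly like the new entrypoint() does.
    model_local_path = Path(model_id)
    config_local_path = model_local_path / "config.json"
    if model_local_path.exists() and config_local_path.exists():
        config = AutoConfig.from_pretrained(config_local_path)
        architectures = getattr(config, "architectures", None) or []
        if not any(arch in SUPPORTED_MODEL_ARCHITECTURES for arch in architectures):
            raise ValueError(
                f"Unsupported architecture(s) {architectures}; "
                f"expected one of {SUPPORTED_MODEL_ARCHITECTURES}"
            )


# Example usage with a hypothetical local path:
# check_local_model("/repository")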