Spaces:
Runtime error
Runtime error
| import logging | |
| import torch | |
| import base64 | |
| import os | |
| from indexify_extractor_sdk import Content, Extractor, Feature | |
| from pyannote.audio import Pipeline | |
| from transformers import pipeline, AutoModelForCausalLM | |
| from .diarization_utils import diarize | |
| from huggingface_hub import HfApi | |
| from starlette.exceptions import HTTPException | |
| from pydantic import BaseModel | |
| from pydantic_settings import BaseSettings | |
| from typing import Optional, Literal, List, Union | |
| logger = logging.getLogger(__name__) | |
| token = os.getenv('HF_TOKEN') | |
| class ModelSettings(BaseSettings): | |
| asr_model: str = "openai/whisper-large-v3" | |
| assistant_model: Optional[str] = "distil-whisper/distil-large-v3" | |
| diarization_model: Optional[str] = "pyannote/speaker-diarization-3.1" | |
| hf_token: Optional[str] = token | |
| model_settings = ModelSettings() | |
| class ASRExtractorConfig(BaseModel): | |
| task: Literal["transcribe", "translate"] = "transcribe" | |
| batch_size: int = 24 | |
| assisted: bool = False | |
| chunk_length_s: int = 30 | |
| sampling_rate: int = 16000 | |
| language: Optional[str] = None | |
| num_speakers: Optional[int] = None | |
| min_speakers: Optional[int] = None | |
| max_speakers: Optional[int] = None | |
| class ASRExtractor(Extractor): | |
| name = "tensorlake/asrdiarization" | |
| description = "Powerful ASR + diarization + speculative decoding." | |
| system_dependencies = ["ffmpeg"] | |
| input_mime_types = ["audio", "audio/mpeg"] | |
| def __init__(self): | |
| super(ASRExtractor, self).__init__() | |
| device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
| logger.info(f"Using device: {device.type}") | |
| torch_dtype = torch.float32 if device.type == "cpu" else torch.float16 | |
| self.assistant_model = AutoModelForCausalLM.from_pretrained( | |
| model_settings.assistant_model, | |
| torch_dtype=torch_dtype, | |
| low_cpu_mem_usage=True, | |
| use_safetensors=True | |
| ) if model_settings.assistant_model else None | |
| if self.assistant_model: | |
| self.assistant_model.to(device) | |
| self.asr_pipeline = pipeline( | |
| "automatic-speech-recognition", | |
| model=model_settings.asr_model, | |
| torch_dtype=torch_dtype, | |
| device=device | |
| ) | |
| if model_settings.diarization_model: | |
| # diarization pipeline doesn't raise if there is no token | |
| HfApi().whoami(model_settings.hf_token) | |
| self.diarization_pipeline = Pipeline.from_pretrained( | |
| checkpoint_path=model_settings.diarization_model, | |
| use_auth_token=model_settings.hf_token, | |
| ) | |
| self.diarization_pipeline.to(device) | |
| else: | |
| self.diarization_pipeline = None | |
| def extract(self, content: Content, params: ASRExtractorConfig) -> List[Union[Feature, Content]]: | |
| file = base64.b64decode(content.data) | |
| logger.info(f"inference params: {params}") | |
| generate_kwargs = { | |
| "task": params.task, | |
| "language": params.language, | |
| "assistant_model": self.assistant_model if params.assisted else None | |
| } | |
| try: | |
| asr_outputs = self.asr_pipeline( | |
| file, | |
| chunk_length_s=params.chunk_length_s, | |
| batch_size=params.batch_size, | |
| generate_kwargs=generate_kwargs, | |
| return_timestamps=True, | |
| ) | |
| except RuntimeError as e: | |
| logger.error(f"ASR inference error: {str(e)}") | |
| raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") | |
| except Exception as e: | |
| logger.error(f"Unknown error diring ASR inference: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Unknown error diring ASR inference: {str(e)}") | |
| if self.diarization_pipeline: | |
| try: | |
| transcript = diarize(self.diarization_pipeline, file, params, asr_outputs) | |
| except RuntimeError as e: | |
| logger.error(f"Diarization inference error: {str(e)}") | |
| raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") | |
| except Exception as e: | |
| logger.error(f"Unknown error during diarization: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") | |
| else: | |
| transcript = [] | |
| feature = Feature.metadata(value={"chunks": asr_outputs["chunks"], "text": asr_outputs["text"]}) | |
| return [Content.from_text(str(transcript), features=[feature])] | |
| def sample_input(self) -> Content: | |
| filepath = "sample.mp3" | |
| with open(filepath, 'rb') as f: | |
| audio_encoded = base64.b64encode(f.read()).decode("utf-8") | |
| return Content(content_type="audio/mpeg", data=audio_encoded) | |
| if __name__ == "__main__": | |
| filepath = "sample.mp3" | |
| with open(filepath, 'rb') as f: | |
| audio_encoded = base64.b64encode(f.read()).decode("utf-8") | |
| data = Content(content_type="audio/mpeg", data=audio_encoded) | |
| params = ASRExtractorConfig(batch_size=24) | |
| extractor = ASRExtractor() | |
| results = extractor.extract(data, params=params) | |
| print(results) |