import sys
import logging
from typing import List, Tuple, Optional

import numpy as np

from timed_objects import ASRToken, Sentence, Transcript

logger = logging.getLogger(__name__)


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.

    It holds:
      - committed_in_buffer: tokens that have been confirmed (committed)
      - buffer: the last hypothesis that is not yet committed
      - new: new tokens coming from the recognizer
    """

    def __init__(self, logfile=sys.stderr, confidence_validation=False):
        self.confidence_validation = confidence_validation
        self.committed_in_buffer: List[ASRToken] = []
        self.buffer: List[ASRToken] = []
        self.new: List[ASRToken] = []
        self.last_committed_time = 0.0
        self.last_committed_word: Optional[str] = None
        self.logfile = logfile

    def insert(self, new_tokens: List[ASRToken], offset: float):
        """
        Insert new tokens (after applying a time offset) and compare them with the
        already committed tokens. Only tokens that extend the committed hypothesis
        are added.
        """
        # Apply the offset to each token.
        new_tokens = [token.with_offset(offset) for token in new_tokens]
        # Only keep tokens that are roughly "new".
        self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]

        if self.new:
            first_token = self.new[0]
            if abs(first_token.start - self.last_committed_time) < 1:
                if self.committed_in_buffer:
                    committed_len = len(self.committed_in_buffer)
                    new_len = len(self.new)
                    # Try to match 1 to 5 consecutive tokens
                    max_ngram = min(min(committed_len, new_len), 5)
                    for i in range(1, max_ngram + 1):
                        committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
                        new_ngram = " ".join(token.text for token in self.new[:i])
                        if committed_ngram == new_ngram:
                            removed = []
                            for _ in range(i):
                                removed_token = self.new.pop(0)
                                removed.append(repr(removed_token))
                            logger.debug(f"Removing {i} overlapping words from new hypothesis: {' '.join(removed)}")
                            break
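
    # Illustrative example (not from the original source): if the committed tail is
    # "... and the cat" and the recognizer's new hypothesis, after offsetting, starts
    # with "the cat sat on", insert() above finds the 2-gram overlap "the cat" and
    # drops those two tokens from self.new, leaving only "sat on" as genuinely new.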

    def flush(self) -> List[ASRToken]:
        """
        Returns the committed chunk, defined as the longest common prefix
        between the previous hypothesis and the new tokens.
        """
        committed: List[ASRToken] = []
        while self.new:
            current_new = self.new[0]
            if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.new.pop(0)
                if self.buffer:
                    self.buffer.pop(0)
            elif not self.buffer:
                break
            elif current_new.text == self.buffer[0].text:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.buffer.pop(0)
                self.new.pop(0)
            else:
                break
        self.buffer = self.new
        self.new = []
        self.committed_in_buffer.extend(committed)
        return committed
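
    # Commit rule in short: flush() commits a token once it appears, with identical
    # text, at the same position in two consecutive hypotheses (the LocalAgreement-style
    # policy used by streaming Whisper pipelines), or immediately when
    # confidence_validation is enabled and its probability exceeds 0.95. Everything
    # after the first mismatch stays in self.buffer and waits for the next hypothesis.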

    def pop_committed(self, time: float):
        """
        Remove tokens (from the beginning) that have ended before `time`.
        """
        while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
            self.committed_in_buffer.pop(0)
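
# A minimal interaction sketch (illustrative only; tokens are shown schematically as
# (start, end, text) -- the exact ASRToken constructor is defined in timed_objects):
#
#     buf = HypothesisBuffer()
#     buf.insert([<0.0-0.4 "hello">, <0.4-0.8 "world">], offset=0.0)
#     buf.flush()   # -> []          (no previous hypothesis to agree with yet)
#     buf.insert([<0.0-0.4 "hello">, <0.4-0.9 "word">], offset=0.0)
#     buf.flush()   # -> ["hello"]   ("hello" agreed twice; "word" stays in the buffer)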

class OnlineASRProcessor:
    """
    Processes incoming audio in a streaming fashion, calling the ASR system
    periodically, and uses a hypothesis buffer to commit and trim recognized text.

    The processor supports two types of buffer trimming:
      - "sentence": trims at sentence boundaries (using a sentence tokenizer)
      - "segment": trims at fixed segment durations.
    """

    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation=False,
        logfile=sys.stderr,
    ):
        """
        asr: An ASR system object (for example, a WhisperASR instance) that
             provides a `transcribe` method, a `ts_words` method (to extract tokens),
             a `segments_end_ts` method, and a separator attribute `sep`.
        tokenize_method: A function that receives text and returns a list of sentence strings.
        buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
        """
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()

        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

        if self.buffer_trimming_way not in ["sentence", "segment"]:
            raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
        if self.buffer_trimming_sec <= 0:
            raise ValueError("buffer_trimming_sec must be positive")
        elif self.buffer_trimming_sec > 30:
            logger.warning(
                f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
            )

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing buffers."""
        self.audio_buffer = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
        self.buffer_time_offset = offset if offset is not None else 0.0
        self.transcript_buffer.last_committed_time = self.buffer_time_offset
        self.committed: List[ASRToken] = []

    def insert_audio_chunk(self, audio: np.ndarray):
        """Append an audio chunk (a numpy array) to the current audio buffer."""
        self.audio_buffer = np.append(self.audio_buffer, audio)
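
    # Note on insert_audio_chunk: the chunk is expected to be mono float32 PCM at
    # SAMPLING_RATE (16 kHz), matching the dtype of audio_buffer above; no resampling
    # or format conversion is performed here.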

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context), where:
          - prompt is a 200-character suffix of committed text that falls
            outside the current audio buffer.
          - context is the committed text within the current audio buffer.
        """
        k = len(self.committed)
        while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
            k -= 1
        prompt_tokens = self.committed[:k]
        prompt_words = [token.text for token in prompt_tokens]
        prompt_list = []
        length_count = 0
        # Use the last words until reaching 200 characters.
        while prompt_words and length_count < 200:
            word = prompt_words.pop(-1)
            length_count += len(word) + 1
            prompt_list.append(word)
        non_prompt_tokens = self.committed[k:]
        context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
        return self.asr.sep.join(prompt_list[::-1]), context_text
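
    # The prompt part is what process_iter() passes to the backend as init_prompt, so
    # the model keeps the wording of text that has already scrolled out of the audio
    # buffer; the context part is informational and is discarded by process_iter().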

    def get_buffer(self):
        """
        Return the unvalidated (not yet committed) buffer as a Transcript.
        """
        return self.concatenate_tokens(self.transcript_buffer.buffer)

    def process_iter(self) -> List[ASRToken]:
        """
        Processes the current audio buffer.

        Returns the list of newly committed ASRToken objects.
        """
        prompt_text, _ = self.prompt()
        logger.debug(
            f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
        )
        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
        tokens = self.asr.ts_words(res)  # Expecting List[ASRToken]
        self.transcript_buffer.insert(tokens, self.buffer_time_offset)
        committed_tokens = self.transcript_buffer.flush()
        self.committed.extend(committed_tokens)
        completed = self.concatenate_tokens(committed_tokens)
        logger.debug(f">>>> COMPLETE NOW: {completed.text}")
        incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
        logger.debug(f"INCOMPLETE: {incomp.text}")

        if committed_tokens and self.buffer_trimming_way == "sentence":
            if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
                self.chunk_completed_sentence()

        s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
        if len(self.audio_buffer) / self.SAMPLING_RATE > s:
            self.chunk_completed_segment(res)
            logger.debug("Chunking segment")
        logger.debug(
            f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
        )
        return committed_tokens
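
    # Trimming policy, as implemented in process_iter above: in "segment" mode the
    # buffer is cut once it exceeds buffer_trimming_sec; in "sentence" mode sentence
    # chunking is attempted past buffer_trimming_sec, and segment-based chunking still
    # applies as a fallback once the buffer exceeds 30 s (roughly Whisper's input window).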

    def chunk_completed_sentence(self):
        """
        If the committed tokens form at least two sentences, chunk the audio
        buffer at the end time of the penultimate sentence.
        """
        if not self.committed:
            return
        logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
        sentences = self.words_to_sentences(self.committed)
        for sentence in sentences:
            logger.debug(f"\tSentence: {sentence.text}")
        if len(sentences) < 2:
            return
        # Keep the last two sentences.
        while len(sentences) > 2:
            sentences.pop(0)
        chunk_time = sentences[-2].end
        logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
        self.chunk_at(chunk_time)

    def chunk_completed_segment(self, res):
        """
        Chunk the audio buffer based on segment-end timestamps reported by the ASR.
        """
        if not self.committed:
            return
        ends = self.asr.segments_end_ts(res)
        last_committed_time = self.committed[-1].end
        if len(ends) > 1:
            e = ends[-2] + self.buffer_time_offset
            while len(ends) > 2 and e > last_committed_time:
                ends.pop(-1)
                e = ends[-2] + self.buffer_time_offset
            if e <= last_committed_time:
                logger.debug(f"--- Segment chunked at {e:.2f}")
                self.chunk_at(e)
            else:
                logger.debug("--- Last segment not within committed area")
        else:
            logger.debug("--- Not enough segments to chunk")

    def chunk_at(self, time: float):
        """
        Trim both the hypothesis and audio buffer at the given time.
        """
        logger.debug(f"Chunking at {time:.2f}s")
        logger.debug(
            f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )
        self.transcript_buffer.pop_committed(time)
        cut_seconds = time - self.buffer_time_offset
        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
        self.buffer_time_offset = time
        logger.debug(
            f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )
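
    # chunk_at advances buffer_time_offset so token timestamps stay absolute:
    # process_iter() passes this offset to HypothesisBuffer.insert(), which shifts
    # every new hypothesis accordingly, keeping committed times consistent across trims.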

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Converts a list of tokens to a list of Sentence objects using the provided
        sentence tokenizer.
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)

        if self.tokenize:
            try:
                sentence_texts = self.tokenize(full_text)
            except Exception:
                # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
                try:
                    sentence_texts = self.tokenize([full_text])
                except Exception as e2:
                    raise ValueError("Tokenization failed") from e2
        else:
            sentence_texts = [full_text]

        sentences: List[Sentence] = []
        token_index = 0
        for sent_text in sentence_texts:
            sent_text = sent_text.strip()
            if not sent_text:
                continue
            sent_tokens = []
            accumulated = ""
            # Accumulate tokens until roughly matching the length of the sentence text.
            while token_index < len(tokens) and len(accumulated) < len(sent_text):
                token = tokens[token_index]
                accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
                sent_tokens.append(token)
                token_index += 1
            if sent_tokens:
                sentence = Sentence(
                    start=sent_tokens[0].start,
                    end=sent_tokens[-1].end,
                    text=" ".join(t.text for t in sent_tokens),
                )
                sentences.append(sentence)
        return sentences
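
    # Note: tokens are mapped back onto sentences by accumulating text length rather
    # than by exact string matching, so sentence boundaries (and therefore
    # sentence-mode trim points) are approximate when the tokenizer normalizes
    # whitespace or punctuation differently from the raw token stream.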

    def finish(self) -> Transcript:
        """
        Flush the remaining transcript when processing ends.
        """
        remaining_tokens = self.transcript_buffer.buffer
        final_transcript = self.concatenate_tokens(remaining_tokens)
        logger.debug(f"Final non-committed transcript: {final_transcript}")
        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
        return final_transcript

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0,
    ) -> Transcript:
        """Merge a list of tokens into a single Transcript (start, end, text, probability)."""
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        # Note: the average is taken over all tokens, including those without a probability.
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)
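
# Typical (non-VAD) usage, sketched with a WhisperASR-like backend as described in
# OnlineASRProcessor.__init__ (the backend and the audio source names are illustrative):
#
#     online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))
#     for chunk in audio_chunks():           # np.float32 mono @ 16 kHz
#         online.insert_audio_chunk(chunk)
#         committed = online.process_iter()  # newly committed ASRTokens, possibly empty
#     tail = online.finish()                 # remaining uncommitted text as a Transcript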

class VACOnlineASRProcessor:
    """
    Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).

    It receives small chunks of audio, applies VAD (e.g. with Silero),
    and when the system detects a pause in speech (or end of an utterance)
    it finalizes the utterance immediately.
    """

    SAMPLING_RATE = 16000

    def __init__(self, online_chunk_size: float, *args, **kwargs):
        self.online_chunk_size = online_chunk_size
        self.online = OnlineASRProcessor(*args, **kwargs)

        # Load a VAD model (e.g. Silero VAD)
        import torch
        model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
        from silero_vad_iterator import FixedVADIterator

        self.vac = FixedVADIterator(model)
        self.logfile = self.online.logfile
        self.init()

    def init(self):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        self.status: Optional[str] = None  # "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames
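
    # buffer_offset counts samples that have already been dropped from audio_buffer,
    # so the absolute frame indices reported by the VAD iterator can be mapped back to
    # positions in the current (local) buffer in insert_audio_chunk() below.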

    def clear_buffer(self):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio: np.ndarray):
        """
        Process an incoming small audio chunk:
          - run VAD on the chunk,
          - decide whether to send the audio to the online ASR processor immediately,
          - and/or to mark the current utterance as finished.
        """
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            # VAD returned a result; adjust the frame number
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
            elif "end" in res and "start" not in res:
                self.status = "nonvoice"
                send_audio = self.audio_buffer[:frame]
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
            else:
                beg = res["start"] - self.buffer_offset
                end = res["end"] - self.buffer_offset
                self.status = "nonvoice"
                send_audio = self.audio_buffer[beg:end]
                self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
        else:
            if self.status == "voice":
                self.online.insert_audio_chunk(self.audio_buffer)
                self.current_online_chunk_buffer_size += len(self.audio_buffer)
                self.clear_buffer()
            else:
                # Keep 1 second worth of audio in case VAD later detects voice,
                # but trim to avoid unbounded memory usage.
                self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
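
    # The three VAD outcomes handled above: "start" only -> speech began, so the online
    # processor is reset at that absolute time and audio starts streaming to it;
    # "end" only -> speech stopped, so the tail up to the end frame is flushed and the
    # utterance is marked final; both keys -> a short utterance fell entirely inside
    # this chunk, so it is sent as one piece and immediately marked final.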

    def process_iter(self) -> Transcript:
        """
        Depending on the VAD status and the amount of accumulated audio,
        process the current audio chunk. Returns a final Transcript when the
        utterance has just ended, the result of the underlying
        OnlineASRProcessor.process_iter() when enough audio has accumulated,
        or an empty Transcript when there is nothing to update.
        """
        if self.is_currently_final:
            return self.finish()
        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
            self.current_online_chunk_buffer_size = 0
            return self.online.process_iter()
        else:
            logger.debug("No online update, only VAD")
            return Transcript(None, None, "")

    def finish(self) -> Transcript:
        """Finish processing by flushing any remaining text."""
        result = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return result

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
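
# Sketch of the VAD-gated variant (illustrative only; it needs torch, the Silero VAD
# model fetched in __init__, and a real ASR backend; `microphone_chunks` is a made-up
# stand-in for any source of small float32 chunks at 16 kHz):
#
#     vac = VACOnlineASRProcessor(1.0, asr)   # 1 s online chunk; extra args go to OnlineASRProcessor
#     for chunk in microphone_chunks():
#         vac.insert_audio_chunk(chunk)
#         result = vac.process_iter()         # committed tokens, a final Transcript, or empty
#     print(vac.finish().text)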