# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import logging | |
| import subprocess | |
| from dataclasses import dataclass | |
| from datetime import timedelta | |
| from typing import Optional | |
| from transformers import pipeline, MarianMTModel, MarianTokenizer | |
| import numpy as np | |
| import sherpa_onnx | |
| from model import sample_rate | |
@dataclass
class Segment:
    """A single speech segment detected by the VAD, rendered as one SRT cue.

    ``__str__`` produces the cue body:
    ``HH:MM:SS,mmm --> HH:MM:SS,mmm\\n<text>\\n<cn_text>``.
    """

    # Times are in seconds, relative to the start of the input file.
    start: float
    duration: float
    text: str = ""  # recognized (English) transcript
    cn_text: str = ""  # Chinese translation of ``text``

    @property
    def end(self) -> float:
        """End time of the segment in seconds."""
        return self.start + self.duration

    @staticmethod
    def _format_srt_time(seconds: float) -> str:
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        SRT uses a comma as the decimal separator and exactly three
        millisecond digits.  Microseconds are truncated (not rounded)
        to milliseconds.
        """
        td = timedelta(seconds=seconds)
        hours, rem = divmod(td.days * 86400 + td.seconds, 3600)
        minutes, secs = divmod(rem, 60)
        millis = td.microseconds // 1000
        return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

    def __str__(self) -> str:
        s = self._format_srt_time(self.start)
        s += " --> "
        s += self._format_srt_time(self.end)
        s += "\n"
        s += self.text
        s += "\n"
        s += self.cn_text
        return s
def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> "tuple[str, str]":
    """Transcribe and translate an audio/video file into subtitles.

    Decodes ``filename`` with ffmpeg to 16-bit mono PCM at ``sample_rate``,
    splits the stream into speech segments with ``vad``, recognizes each
    segment with ``recognizer``, and translates each recognized text to
    Chinese via the module-level ``_llm_translator``.

    Args:
      recognizer: offline ASR recognizer used for each speech segment.
      vad: voice-activity detector fed in fixed 512-sample windows.
      punct: optional punctuation model applied to each segment's text and
        to the concatenated full text; skipped when ``None``.
      filename: path of the input media file; any format ffmpeg can read.

    Returns:
      A 2-tuple ``(srt_body, all_text)`` where ``srt_body`` is the
      SRT-formatted subtitle content (numbered cues separated by blank
      lines) and ``all_text`` is the whole recognized transcript as one
      string.  (NOTE(review): the original annotation said ``-> str`` but
      two values are returned.)
    """
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]
    # Stream raw PCM from ffmpeg's stdout; ffmpeg's own logging is discarded.
    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )
    frames_per_read = int(sample_rate * 100)  # read 100 seconds of audio at a time
    window_size = 512  # VAD window length in samples
    buffer = []
    segment_list = []
    logging.info("Started!")
    all_text = []
    is_last = False
    while True:
        # *2 because int16_t has two bytes
        data = process.stdout.read(frames_per_read * 2)
        if not data:
            if is_last:
                break
            # EOF: feed one second of silence so the VAD flushes any
            # speech segment still pending, then do one final pass.
            is_last = True
            data = np.zeros(sample_rate, dtype=np.int16)
        samples = np.frombuffer(data, dtype=np.int16)
        # Normalize int16 PCM to float32 in [-1, 1).
        samples = samples.astype(np.float32) / 32768
        buffer = np.concatenate([buffer, samples])
        # Feed the VAD in fixed-size windows; keep the remainder buffered.
        while len(buffer) > window_size:
            vad.accept_waveform(buffer[:window_size])
            buffer = buffer[window_size:]
        streams = []
        segments = []
        # Drain every speech segment the VAD has completed so far and
        # create one recognizer stream per segment.
        while not vad.empty():
            segment = Segment(
                start=vad.front.start / sample_rate,
                duration=len(vad.front.samples) / sample_rate,
            )
            segments.append(segment)
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, vad.front.samples)
            streams.append(stream)
            vad.pop()
        for s in streams:
            recognizer.decode_stream(s)
        for seg, stream in zip(segments, streams):
            en_text = stream.result.text.strip()
            seg.text = en_text
            if len(seg.text) == 0:
                logging.info("Skip empty segment")
                continue
            seg.cn_text = _llm_translator.translate(en_text)
            if len(all_text) == 0:
                all_text.append(seg.text)
            # A 1-byte UTF-8 first char means ASCII/Latin text, which needs a
            # space between adjacent pieces (CJK text does not).
            # NOTE(review): this inspects the FIRST character of the previous
            # piece (`all_text[-1][0]`); checking its LAST character
            # (`all_text[-1][-1]`) looks like the intent — confirm.
            elif len(all_text[-1][0].encode()) == 1 and len(seg.text[0].encode()) == 1:
                all_text.append(" ")
                all_text.append(seg.text)
            else:
                all_text.append(seg.text)
            if punct is not None:
                seg.text = punct.add_punctuation(seg.text)
            segment_list.append(seg)
    all_text = "".join(all_text)
    if punct is not None:
        all_text = punct.add_punctuation(all_text)
    # Cues are 1-based per the SRT convention.
    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text
def translate_en_to_cn(src_text: str) -> str:
    """Translate English text to Chinese with the opus-mt-en-zh MarianMT model.

    Note: this loads the tokenizer and model weights on every call; for
    repeated use prefer the module-level ``_llm_translator`` singleton.

    Args:
      src_text: the English source text.

    Returns:
      The Chinese translation as a single string.  (Fix: the original
      returned the list of decoded sequences despite the ``-> str``
      annotation; the sequences are now joined, matching
      ``LLMTranslator.translate``.)
    """
    model_name = "Helsinki-NLP/opus-mt-en-zh"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    translated = model.generate(
        **tokenizer(src_text, return_tensors="pt", padding=True)
    )
    return "".join(
        tokenizer.decode(t, skip_special_tokens=True) for t in translated
    )
class LLMTranslator:
    """English-to-Chinese translator backed by the Helsinki-NLP
    opus-mt-en-zh MarianMT model.

    The tokenizer and model weights are loaded once, at construction time.
    """

    _tokenizer: MarianTokenizer
    _model: MarianMTModel

    def __init__(self):
        checkpoint = "Helsinki-NLP/opus-mt-en-zh"
        self._tokenizer = MarianTokenizer.from_pretrained(checkpoint)
        self._model = MarianMTModel.from_pretrained(checkpoint)

    def translate(self, src_text: str) -> str:
        """Return the Chinese translation of ``src_text`` as one string."""
        encoded = self._tokenizer(src_text, return_tensors="pt", padding=True)
        output_ids = self._model.generate(**encoded)
        pieces = []
        for ids in output_ids:
            decoded = self._tokenizer.decode(ids, skip_special_tokens=True)
            pieces.append(str(decoded))
        return "".join(pieces)
| _llm_translator = LLMTranslator() |