Spaces:
Running
on
Zero
Running
on
Zero
| import uuid | |
| import base64 | |
| import re | |
| import regex | |
| from typing import AsyncGenerator, Union | |
| import io | |
| from pydub import AudioSegment | |
| import torch | |
| import numpy as np | |
| from functools import lru_cache | |
| from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer | |
def random_uuid() -> str:
    """Return a fresh random UUID4 as a 32-character lowercase hex string."""
    # uuid.UUID.hex is already a str; no further conversion needed.
    return uuid.uuid4().hex
async def async_generator_wrap(first_element, gen: AsyncGenerator):
    """Yield ``first_element`` first, then every item produced by ``gen``."""
    yield first_element
    async for element in gen:
        yield element
def encode_base64_content_from_file(file_path: str) -> str:
    """Read a local file as binary and return its contents base64-encoded."""
    with open(file_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
def pcm16_to_target_format(
    np_audio: np.ndarray,
    sample_rate: int,
    bit_depth: int,
    channels: int,
    format: str,
    target_rate: int,
):
    """Package raw PCM samples into the requested audio container format.

    Builds a pydub ``AudioSegment`` from the raw sample buffer, optionally
    resamples to ``target_rate`` when it differs from ``sample_rate``, and
    exports into an in-memory buffer.

    Returns a ``BytesIO`` rewound to the start of the encoded data.
    """
    segment = AudioSegment(
        np_audio.tobytes(),
        frame_rate=sample_rate,
        sample_width=bit_depth // 8,
        channels=channels,
    )
    needs_resample = target_rate is not None and target_rate != sample_rate
    if needs_resample:
        segment = segment.set_frame_rate(target_rate)

    # Encode into an in-memory buffer and rewind it for the caller.
    out = io.BytesIO()
    segment.export(out, format=format)
    out.seek(0)
    return out
# Matches one or more CJK unified ideographs (BMP block U+4E00-U+9FFF).
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")


def contains_chinese(text: str):
    """Return True if *text* contains at least one CJK unified ideograph."""
    return chinese_char_pattern.search(text) is not None
# remove blank between chinese character
def replace_blank(text: str):
    """Drop spaces that are not flanked by ASCII text on both sides.

    A space is kept only when both of its immediate neighbours exist and are
    non-space ASCII characters; spaces adjacent to non-ASCII (e.g. Chinese)
    characters, doubled spaces, and leading/trailing spaces are removed.
    """
    out_str = []
    n = len(text)
    for i, c in enumerate(text):
        if c == " ":
            # Guard the boundaries: the unguarded original raised IndexError
            # on a trailing space (text[i + 1]) and wrapped around to the
            # last character on a leading space (text[i - 1] with i == 0).
            prev_ok = i > 0 and text[i - 1].isascii() and text[i - 1] != " "
            next_ok = i + 1 < n and text[i + 1].isascii() and text[i + 1] != " "
            if prev_ok and next_ok:
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)
def replace_corner_mark(text: str):
    """Spell out superscript square/cube marks as Chinese words."""
    replacements = {"²": "平方", "³": "立方"}
    for mark, word in replacements.items():
        text = text.replace(mark, word)
    return text
# remove meaningless symbol
def remove_bracket(text: str):
    """Strip bracket/backtick characters and turn the long dash into a space."""
    for ch in ("(", ")", "【", "】", "`", "`"):
        text = text.replace(ch, "")
    return text.replace("——", " ")
| # split paragrah logic: | |
| # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len | |
| # 2. cal sentence len according to lang | |
| # 3. split sentence according to puncatation | |
def split_paragraph(
    text: str,
    tokenize,
    lang="zh",
    token_max_n=80,
    token_min_n=60,
    merge_len=20,
    comma_split=False,
):
    """Split a paragraph into utterances on punctuation, then re-merge.

    Logic:
      1. Cut the text after every sentence-final punctuation mark; a closing
         quote immediately following the mark stays with its sentence.
      2. Greedily accumulate consecutive sentences; start a new piece only
         once the merged length exceeds ``token_max_n`` AND the accumulated
         piece already exceeds ``token_min_n``.
      3. If the final piece is shorter than ``merge_len``, append it to the
         previous piece instead of emitting it on its own.

    Length is measured in characters for ``lang == "zh"`` and in tokens
    (via ``tokenize``) otherwise.  ``comma_split=True`` additionally splits
    on commas.

    Returns a list of utterance strings; empty input yields an empty list.
    """

    def calc_utt_length(_text: str):
        # Chinese length is per character; other languages count tokens.
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    # Guard: the unpatched code indexed text[-1] and raised IndexError on "".
    if not text:
        return []

    if lang == "zh":
        pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend([",", ","])

    # Ensure the text ends with a terminator so the split loop emits the
    # final sentence.
    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                # Keep a closing quote attached to the sentence it ends.
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        # A too-short trailing piece is merged into the previous utterance.
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
def is_only_punctuation(text: str):
    """Return True if *text* is empty or contains only punctuation/symbols.

    Tests each character's Unicode general category against P* (punctuation)
    and S* (symbols) via the stdlib, matching the semantics of the regex
    classes ``\\p{P}`` and ``\\p{S}`` without the third-party ``regex``
    module.  The empty string returns True, as with the original
    ``^[\\p{P}\\p{S}]*$`` pattern.
    """
    import unicodedata

    return all(unicodedata.category(ch)[0] in ("P", "S") for ch in text)
| # spell Arabic numerals | |
def spell_out_number(text: str, inflect_parser):
    """Replace each run of consecutive digits with its spelled-out form.

    ``inflect_parser`` must expose ``number_to_words`` (e.g. an
    ``inflect.engine()`` instance).
    """
    pieces = []
    digit_start = None  # index where the current digit run began, or None
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if digit_start is None:
                digit_start = idx
        else:
            if digit_start is not None:
                pieces.append(inflect_parser.number_to_words(text[digit_start:idx]))
                digit_start = None
            pieces.append(ch)
    # Flush a digit run that extends to the end of the string.
    if digit_start is not None:
        pieces.append(inflect_parser.number_to_words(text[digit_start:]))
    return "".join(pieces)
# Compiled once at import time instead of on every call.
# Matches emojis and their modifiers:
# - supplementary planes (U+10000-U+10FFFF), which contain the emoji blocks
# - zero-width joiner (U+200D)
# - variation selectors (U+FE0F, U+FE0E)
# - skin tone modifiers (U+1F3FB-U+1F3FF; already inside the range above,
#   listed for clarity)
_EMOJI_PATTERN = re.compile(
    r"["
    r"\U00010000-\U0010FFFF"  # Standard emoji range
    r"\u200D"  # Zero-width joiner
    r"\uFE0F\uFE0E"  # Variation selectors
    r"\U0001F3FB-\U0001F3FF"  # Skin tone modifiers
    r"]+",
    flags=re.UNICODE,
)


def remove_emoji(text: str):
    """Strip emoji (and their ZWJ/variation/skin-tone modifiers) from *text*."""
    return _EMOJI_PATTERN.sub("", text)
def remove_repeated_punctuations(text, punctuations):
    """Collapse consecutive repeats of any listed punctuation mark into one."""
    if not punctuations:
        return text
    # Character class covering every punctuation mark of interest.
    char_class = "[" + re.escape("".join(punctuations)) + "]"
    return re.sub(rf"({char_class})\1+", r"\1", text)
# Full-width ASCII punctuation maps to half-width by subtracting 0xFEE0
# (U+FF01-U+FF5E mirror U+0021-U+007E).  Generating the table from the
# half-width characters avoids hand-typing the full-width string — the
# original listed an ASCII backslash where full-width ＼ (U+FF3C) belonged,
# so ＼ was never converted.  Built once at import time.
_FULL_TO_HALF_TABLE = str.maketrans(
    {ord(c) + 0xFEE0: c for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"}
)


def full_to_half_width(text: str) -> str:
    """Convert full-width punctuation to half-width in a given string."""
    return text.translate(_FULL_TO_HALF_TABLE)
def split_interleaved_delayed_audios(
    audio_data: Union[list[list[int]], torch.Tensor],
    audio_tokenizer: HiggsAudioTokenizer,
    audio_stream_eos_id: int,
) -> list[tuple[list[list[int]], torch.Tensor]]:
    """Split a concatenated stream of audio codebook tokens into per-utterance groups.

    A "separator" is a time step whose value in every codebook equals
    ``audio_stream_eos_id``; the stream is split at those steps.

    NOTE(review): the declared return type does not match what the code
    builds — it returns a list of tensors (tensor branch) or a list of lists
    of rows (list branch), not tuples.  Verify against callers.
    """
    # One EOS token per codebook marks an utterance boundary.
    separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
    # Convert separator to numpy array if audio_data is numpy array
    if isinstance(audio_data, torch.Tensor):
        # Assumes audio_data is (num_codebooks, time); transposed so each
        # row is one time step across all codebooks — TODO confirm.
        audio_data = audio_data.transpose(1, 0)
        separator = torch.tensor(separator)

        # Find the indices where the rows equal the separator
        split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
        start = 0
        groups = []
        for idx in split_indices:
            # Separator rows are excluded here (slice stops before idx) and
            # each group is transposed back to (num_codebooks, time).
            groups.append(audio_data[start:idx].transpose(1, 0))
            start = idx + 1
        if start < len(audio_data):
            groups.append(audio_data[start:].transpose(1, 0))
    else:
        groups = []
        current = []
        for row in audio_data:
            current.append(row)
            if row == separator:
                # NOTE(review): unlike the tensor branch, the separator row
                # is kept inside the group here and rows are not transposed
                # back — confirm this asymmetry is intentional.
                groups.append(current)
                current = []

        # Don't forget the last group if there's no trailing separator
        if current:
            groups.append(current)

    return groups