import re
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from snac import SNAC
import gradio as gr
# =============================
# Logging
# =============================
logging.basicConfig(
    filename="tts_app.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# =============================
# Global flags
# =============================
# Enable TF32 where available (Ampere+ GPUs) for faster matmuls with minimal quality loss
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
except Exception:
    pass
# Allow faster float32 matmul kernels (e.g. TF32) where PyTorch supports them
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass
# =============================
# Device & dtype selection
# =============================
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device = "cpu"
    dtype = torch.float32  # safer on CPU
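# Note: bf16 is preferred where supported because it keeps fp32's dynamic
# range and avoids fp16 overflow issues; fp16 is the fallback for older GPUs.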
# =============================
# Model names
# =============================
voice_model_name = "webbigdata/VoiceCore"
snac_model_name = "hubertsiuzdak/snac_24khz"

# =============================
# Load models (once at startup)
# =============================
logging.info("Loading models…")
voice_model = AutoModelForCausalLM.from_pretrained(
    voice_model_name,
    torch_dtype=dtype,
    device_map="auto",
    use_cache=True,
)
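# device_map="auto" (via accelerate) places the weights on the available
# GPU(s) automatically, spilling to CPU only if VRAM runs out.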
voice_tokenizer = AutoTokenizer.from_pretrained(voice_model_name)
# Optional: torch.compile for extra speed on PyTorch 2.0+
try:
    voice_model = torch.compile(voice_model)
    logging.info("voice_model compiled with torch.compile")
except Exception as e:
    logging.info(f"torch.compile unavailable or failed: {e}")
snac_model = SNAC.from_pretrained(snac_model_name)
# Move SNAC to the same device and force eval mode; keep default dtype for safety.
snac_model = snac_model.to(device).eval()
# =============================
# Helpers
# =============================
# Security: strip control characters and cap input length
SANITIZE_RX = re.compile(r"[\x00-\x1F\x7F]")
def sanitize_text(text, max_length=500):
    # Remove non-printable / control characters
    clean_text = SANITIZE_RX.sub("", text or "")
    # Cap text length
    if len(clean_text) > max_length:
        clean_text = clean_text[:max_length]
    return clean_text.strip()
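# Example: sanitize_text("hello\x00 world\n") -> "hello world"
# (the newline is \x0A, which falls inside the stripped control range)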
# =============================
# Core generation
# =============================
def generate_voice(voice_type: str, text: str, max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9):
    # Log request
    logging.info(
        f"Request received - Voice: {voice_type}, Text length: {0 if text is None else len(text)}"
    )
    # Sanitize input and reject empty requests early
    text = sanitize_text(text)
    if not text:
        raise gr.Error("Please enter some text to synthesize.")
    chosen_voice = f"{voice_type}[neutral]"
    prompt = f"{chosen_voice}: {text}"
    # Tokenize directly to device
    input_ids = voice_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Prepend/append special tokens on-device
    start_token = torch.tensor([[128259]], dtype=torch.long, device=device)
    end_tokens = torch.tensor([[128009, 128260, 128261]], dtype=torch.long, device=device)
    input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    # Attention mask on-device
    attention_mask = torch.ones_like(input_ids)
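    # Note: the hard-coded IDs (128259 start, 128009/128260/128261 end,
    # 128258 EOS, 128257 audio-start, offset 128266) are taken verbatim from
    # the original script; they appear to follow the Orpheus-style special
    # token layout that VoiceCore uses.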
    # Faster decoding settings
    try:
        generated_ids = voice_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            eos_token_id=128258,
            use_cache=True,
        )
    except Exception as e:
        logging.error(f"Generation error: {e}")
        raise RuntimeError("Error during voice generation") from e
    # Post-process tokens: keep everything after the last occurrence of 128257
    token_to_find = 128257
    token_to_remove = 128258
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids
    processed_row = cropped_tensor[0][cropped_tensor[0] != token_to_remove]
    code_list = processed_row.tolist()
    # Truncate to a whole number of 7-token frames and shift into codebook range
    new_length = (len(code_list) // 7) * 7
    code_list = [t - 128266 for t in code_list[:new_length]]
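    # Each 7-token frame interleaves SNAC's three codebook layers
    # (1 code for layer 1, 2 for layer 2, 4 for layer 3); slot k within a
    # frame carries an extra offset of 4096 * k, which the subtractions
    # below undo to recover raw 0..4095 codebook indices.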
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(code_list) // 7):
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - 8192)
        layer_3.append(code_list[7 * i + 3] - 12288)
        layer_2.append(code_list[7 * i + 4] - 16384)
        layer_3.append(code_list[7 * i + 5] - 20480)
        layer_3.append(code_list[7 * i + 6] - 24576)
    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    # SNAC decode on the same device; inference_mode avoids autograd overhead
    with torch.inference_mode():
        audio = snac_model.decode(codes)
    # Ensure float32 on CPU for Gradio's numpy output
    audio_np = audio.detach().squeeze().float().cpu().numpy()
    # Return numpy audio directly (avoids disk I/O)
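    # 24 kHz is the native sample rate of the snac_24khz codec.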
    sample_rate = 24000
    return sample_rate, audio_np
# =============================
# Gradio UI
# =============================
voices = [
    "amitaro_female",
    "matsukaze_male",
    "naraku_female",
    "shiguu_male",
    "sayoko_female",
    "dahara1_male",
]
with gr.Blocks(title="VoiceCore TTS — Fast") as iface:
    gr.Markdown("# VoiceCore TTS — Fast Mode\nGenerate speech from text using VoiceCore + SNAC (optimized).")
    with gr.Row():
        voice_dd = gr.Dropdown(label="Voice Type", choices=voices, value="matsukaze_male")
        max_new = gr.Slider(64, 8192, value=2048, step=64, label="Max New Tokens (lower = faster)")
    with gr.Row():
        temp = gr.Slider(0.1, 1.2, value=0.6, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    text_in = gr.Textbox(label="Text", lines=4, placeholder="Type what you want the voice to say…")
    audio_out = gr.Audio(type="numpy", label="Generated Audio", streaming=False)

    def _wrap(voice, text, mx, t, p):
        return generate_voice(voice, text, int(mx), float(t), float(p))

    gen_btn = gr.Button("Generate")
    gen_btn.click(_wrap, inputs=[voice_dd, text_in, max_new, temp, top_p], outputs=[audio_out])
def _warmup():
    # Run one short generation so the first real request is snappy
    try:
        _ = generate_voice("matsukaze_male", "hello world", max_new_tokens=128)
        logging.info("Warm-up generation completed")
    except Exception as e:
        logging.info(f"Warm-up skipped/failed: {e}")
if __name__ == "__main__":
    logging.info("Starting VoiceCore TTS app (Fast Mode)")
    # Optional: warm up kernels (and torch.compile) so the first request is fast
    _warmup()
    iface.queue(max_size=32).launch()