import tempfile
import time
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import soundfile as sf
import torch

from dia.model import Dia

# Load Nari model and config
print("Loading Nari model...")
try:
    # Use the function from inference.py
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
except Exception as e:
    print(f"Error loading Nari model: {e}")
    raise
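
# A minimal sketch of an alternative load for GPU machines (assumption: a CUDA
# device is available and the installed dia version accepts this dtype string):
# model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")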

def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
):
    """
    Runs Nari inference using the globally loaded model and provided inputs.
    Uses a temporary file for the audio prompt for compatibility with model.generate.
    """
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    temp_txt_file_path = None
    temp_audio_prompt_path = None
    output_audio = (44100, np.zeros(1, dtype=np.float32))
    try:
        prompt_path_for_generate = None
        if audio_prompt_input is not None:
            sr, audio_data = audio_prompt_input
            # Check if audio_data is valid
            if (
                audio_data is None or audio_data.size == 0 or audio_data.max() == 0
            ):  # Check for silence/empty
                gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
            else:
                # Save prompt audio to a temporary WAV file
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".wav", delete=False
                ) as f_audio:
                    temp_audio_prompt_path = f_audio.name  # Store path for cleanup

                    # Basic audio preprocessing for consistency
                    # Convert to float32 in [-1, 1] range if integer type
                    if np.issubdtype(audio_data.dtype, np.integer):
                        max_val = np.iinfo(audio_data.dtype).max
                        audio_data = audio_data.astype(np.float32) / max_val
                    elif not np.issubdtype(audio_data.dtype, np.floating):
                        gr.Warning(
                            f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
                        )
                        # Attempt conversion, might fail for complex types
                        try:
                            audio_data = audio_data.astype(np.float32)
                        except Exception as conv_e:
                            raise gr.Error(
                                f"Failed to convert audio prompt to float32: {conv_e}"
                            )
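                    # Worked example of the scaling above: an int16 sample of
                    # 16383 maps to roughly 0.5 after division by
                    # np.iinfo(np.int16).max == 32767.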

                    # Ensure mono (average channels if stereo)
                    if audio_data.ndim > 1:
                        if audio_data.shape[0] == 2:  # Assume (2, N)
                            audio_data = np.mean(audio_data, axis=0)
                        elif audio_data.shape[1] == 2:  # Assume (N, 2)
                            audio_data = np.mean(audio_data, axis=1)
                        else:
                            gr.Warning(
                                f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
                            )
                            audio_data = (
                                audio_data[0]
                                if audio_data.shape[0] < audio_data.shape[1]
                                else audio_data[:, 0]
                            )
                        audio_data = np.ascontiguousarray(
                            audio_data
                        )  # Ensure contiguous after slicing/mean

                    # Write using soundfile
                    try:
                        sf.write(
                            temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
                        )  # Explicitly use FLOAT subtype
                        prompt_path_for_generate = temp_audio_prompt_path
                        print(
                            f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
                        )
                    except Exception as write_e:
                        print(f"Error writing temporary audio file: {write_e}")
                        raise gr.Error(f"Failed to save audio prompt: {write_e}")

        # 3. Run Generation
        start_time = time.time()

        # Use torch.inference_mode() context manager for the generation call
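        # (inference_mode disables autograd tracking, lowering memory use and
        # per-op overhead for this inference-only call)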
        with torch.inference_mode():
            output_audio_np = model.generate(
                text_input,
                max_tokens=max_new_tokens,
                cfg_scale=cfg_scale,
                temperature=temperature,
                top_p=top_p,
                cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
                use_torch_compile=False,  # Keep False for Gradio stability
                audio_prompt=prompt_path_for_generate,
            )

        end_time = time.time()
        print(f"Generation finished in {end_time - start_time:.2f} seconds.")

        # 4. Convert Codes to Audio
        if output_audio_np is not None:
            # Output sample rate of the DAC codec (44.1 kHz)
            output_sr = 44100

            # --- Slow down audio ---
            original_len = len(output_audio_np)
            # Ensure speed_factor is positive and not excessively small/large to avoid issues
            speed_factor = max(0.1, min(speed_factor, 5.0))
            target_len = int(
                original_len / speed_factor
            )  # Target length based on speed_factor
            if (
                target_len != original_len and target_len > 0
            ):  # Only interpolate if length changes and is valid
                x_original = np.arange(original_len)
                x_resampled = np.linspace(0, original_len - 1, target_len)
                resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
                output_audio = (
                    output_sr,
                    resampled_audio_np.astype(np.float32),
                )  # Use resampled audio
                print(
                    f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
                )
            else:
                output_audio = (
                    output_sr,
                    output_audio_np,
                )  # Keep original if calculation fails or no change
                print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
            # --- End slowdown ---
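            # Worked example: 1000 samples at speed_factor 0.94 stretch to
            # int(1000 / 0.94) == 1063 samples, i.e. slightly slower playback.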

            print(
                f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
            )

            # Explicitly convert to int16 to prevent Gradio warning
            if (
                output_audio[1].dtype == np.float32
                or output_audio[1].dtype == np.float64
            ):
                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
                output_audio = (output_sr, audio_for_gradio)
                print("Converted audio to int16 for Gradio output.")
        else:
            print("\nGeneration finished, but no valid tokens were produced.")
            # Return default silence
            gr.Warning("Generation produced no output.")

    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback

        traceback.print_exc()
        # Re-raise as Gradio error to display nicely in the UI
        raise gr.Error(f"Inference failed: {e}")

    finally:
        # 5. Cleanup Temporary Files defensively
        if temp_txt_file_path and Path(temp_txt_file_path).exists():
            try:
                Path(temp_txt_file_path).unlink()
                print(f"Deleted temporary text file: {temp_txt_file_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}"
                )
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                )

    return output_audio
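
# A hedged, non-UI usage sketch (argument values are illustrative, drawn from the
# slider ranges below; the returned samples are int16, which soundfile writes as PCM_16):
# sr, samples = run_inference("[S1] Hello there.", None, 3072, 3.0, 1.3, 0.95, 30, 0.94)
# sf.write("generated.wav", samples, sr)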

# --- Create Gradio Interface ---
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""

# Attempt to load default text from example.txt
default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
example_txt_path = Path("./example.txt")
if example_txt_path.exists():
    try:
        default_text = example_txt_path.read_text(encoding="utf-8").strip()
        if not default_text:  # Handle empty example file
            default_text = "Example text file was empty."
    except Exception as e:
        print(f"Warning: Could not read example.txt: {e}")

# Build Gradio UI
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nari Text-to-Speech Synthesis")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here...",
                value=default_text,
                lines=5,  # Increased lines
            )
            audio_prompt_input = gr.Audio(
                label="Audio Prompt (Optional)",
                show_label=True,
                sources=["upload", "microphone"],
                type="numpy",
            )
            with gr.Accordion("Generation Parameters", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens (Audio Length)",
                    minimum=860,
                    maximum=3072,
                    value=model.config.data.audio_length,  # Use config default if available, else fallback
                    step=50,
                    info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (Guidance Strength)",
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,  # Default from inference.py
                    step=0.1,
                    info="Higher values increase adherence to the text prompt.",
                )
                temperature = gr.Slider(
                    label="Temperature (Randomness)",
                    minimum=1.0,
                    maximum=1.5,
                    value=1.3,  # Default from inference.py
                    step=0.05,
                    info="Lower values make the output more deterministic, higher values increase randomness.",
                )
                top_p = gr.Slider(
                    label="Top P (Nucleus Sampling)",
                    minimum=0.80,
                    maximum=1.0,
                    value=0.95,  # Default from inference.py
                    step=0.01,
                    info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K",
                    minimum=15,
                    maximum=50,
                    value=30,
                    step=1,
                    info="Top k filter for CFG guidance.",
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor",
                    minimum=0.8,
                    maximum=1.0,
                    value=0.94,
                    step=0.02,
                    info="Adjusts the speed of the generated audio (1.0 = original speed).",
                )
            run_button = gr.Button("Generate Audio", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                autoplay=False,
            )

    # Link button click to function
    run_button.click(
        fn=run_inference,
        inputs=[
            text_input,
            audio_prompt_input,
            max_new_tokens,
            cfg_scale,
            temperature,
            top_p,
            cfg_filter_top_k,
            speed_factor_slider,
        ],
        outputs=[audio_output],  # Add status_output here if using it
        api_name="generate_audio",
    )
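
    # A hedged sketch of calling the endpoint above remotely with gradio_client
    # (the Space URL is a placeholder; argument order mirrors the inputs list):
    # from gradio_client import Client
    # client = Client("https://your-space-url")
    # result = client.predict(
    #     "[S1] Hello.", None, 3072, 3.0, 1.3, 0.95, 30, 0.94,
    #     api_name="/generate_audio",
    # )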

    # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist)
    example_prompt_path = "./example_prompt.mp3"  # Adjust if needed
    examples_list = [
        [
            "[S1] Oh fire! Oh my goodness! What's the procedure? What do we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if it's hot there might be a fire down the hallway! ",
            None,
            3072,
            3.0,
            1.3,
            0.95,
            35,
            0.94,
        ],
        [
            "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share with your friends. \n[S2] This was Nari Labs.",
            example_prompt_path if Path(example_prompt_path).exists() else None,
            3072,
            3.0,
            1.3,
            0.95,
            35,
            0.94,
        ],
    ]

    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=[
                text_input,
                audio_prompt_input,
                max_new_tokens,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
                speed_factor_slider,
            ],
            outputs=[audio_output],
            fn=run_inference,
            cache_examples=False,
            label="Examples (Click to Run)",
        )
    else:
        gr.Markdown("_(No examples configured or example prompt file missing)_")

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")

    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    demo.launch()
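
# Example override from a shell (assumption: this file is saved as app.py):
#   GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python app.py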