#!/usr/bin/env python3
"""
NVIDIA Nemo Codec Test - Gradio App

Equivalent to snac_test.py but for the NVIDIA Nemo codec used in Kani TTS based models.
Allows testing encode/decode cycles with the
nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps model.
"""

import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import time

# Attempt to import Nemo
try:
    from nemo.collections.tts.models import AudioCodecModel
    from nemo.utils.nemo_logging import Logger

    # Suppress Nemo logging
    nemo_logger = Logger()
    nemo_logger.remove_stream_handlers()
    print("Nemo modules imported successfully.")
except ImportError as e:
    print(f"Error importing Nemo: {e}")
    raise ImportError("Could not import Nemo. Make sure 'nemo_toolkit[tts]' is installed correctly.") from e

# --- Configuration ---
TARGET_SR = 22050  # Nemo codec operates at 22kHz
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"

print(f"Using device: {DEVICE}")

# --- Load Model (Load once globally) ---
# Failure is deliberately non-fatal here: `nemo_codec` stays None and both
# process_audio() and the __main__ guard report the problem instead.
nemo_codec = None
try:
    print(f"Loading Nemo codec model: {MODEL_NAME}...")
    start_time = time.time()
    nemo_codec = AudioCodecModel.from_pretrained(MODEL_NAME)
    nemo_codec = nemo_codec.to(DEVICE)
    nemo_codec.eval()  # Set model to evaluation mode
    end_time = time.time()
    print(f"Nemo codec loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    print(f"FATAL: Error loading Nemo codec: {e}")
    print(traceback.format_exc())


# --- Processing helpers ---
def _load_first_channel(audio_filepath, logs):
    """Load an audio file as float32 and keep only its first channel.

    Args:
        audio_filepath: Path handed over by the Gradio ``Audio`` input.
        logs: List of log lines; progress messages are appended to it.

    Returns:
        ``(waveform, sample_rate)``. ``waveform`` keeps a leading channel
        dimension of size 1 (torchaudio.load returns ``[channels, samples]``).
    """
    logs.append(f"Loading audio file: {audio_filepath}")
    load_start = time.time()
    waveform, sample_rate = torchaudio.load(audio_filepath)
    load_end = time.time()
    logs.append(f"Audio loaded. Original SR: {sample_rate} Hz, Shape: {waveform.shape}, Time: {load_end - load_start:.2f}s")

    # Ensure float32
    waveform = waveform.to(dtype=torch.float32)

    # Handle multi-channel audio: Use the first channel
    if waveform.shape[0] > 1:
        logs.append(f"Warning: Input audio has {waveform.shape[0]} channels. Using only the first channel.")
        waveform = waveform[0:1, :]  # Keep channel dim for consistency
    return waveform, sample_rate


def _resample_to_target(waveform, sample_rate, logs):
    """Return ``waveform`` resampled to TARGET_SR (unchanged if already there)."""
    resample_start = time.time()
    if sample_rate != TARGET_SR:
        logs.append(f"Resampling waveform from {sample_rate} Hz to {TARGET_SR} Hz...")
        resampler = T.Resample(orig_freq=sample_rate, new_freq=TARGET_SR).to(waveform.device)
        resampled = resampler(waveform)
        logs.append(f"Resampling complete. New Shape: {resampled.shape}")
    else:
        logs.append("Waveform is already at the target sample rate (22kHz).")
        resampled = waveform
    resample_end = time.time()
    logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")
    return resampled


def _encode(waveform, logs):
    """Encode a ``[1, samples]`` waveform with the Nemo codec.

    Returns:
        ``(encoded_tokens, tokens_len)`` exactly as returned by
        ``nemo_codec.encode``.

    Raises:
        ValueError: if the codec returns no tokens.
    """
    # Nemo expects [batch, samples] format
    if waveform.dim() == 2 and waveform.shape[0] == 1:
        waveform_batch = waveform  # [1, samples]
    else:
        waveform_batch = waveform.unsqueeze(0)  # Add batch dimension
    waveform_batch = waveform_batch.to(DEVICE)

    # Per-item lengths (a batch of one) for the codec
    audio_len = torch.tensor([waveform_batch.shape[-1]], dtype=torch.int64, device=DEVICE)
    logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Audio length: {audio_len.item()}, Device: {DEVICE}")

    logs.append("Encoding audio with Nemo codec...")
    encode_start = time.time()
    with torch.inference_mode():
        encoded_tokens, tokens_len = nemo_codec.encode(audio=waveform_batch, audio_len=audio_len)
    encode_end = time.time()

    if encoded_tokens is None:
        log_msg = "Encoding failed: encoded_tokens is None"
        logs.append(log_msg)
        raise ValueError(log_msg)

    logs.append(f"Encoding complete. Time: {encode_end - encode_start:.2f}s")
    logs.append(f"Encoded tokens shape: {encoded_tokens.shape}, tokens_len: {tokens_len}")
    logs.append(f"Encoded tokens device: {encoded_tokens.device}")

    # Log some statistics about the tokens.
    # NOTE(review): assumes a 3-D result is [batch, codebooks, frames] — confirm
    # against the NeMo AudioCodecModel.encode docs.
    if encoded_tokens.dim() >= 2:
        logs.append(f"Number of codebooks: {encoded_tokens.shape[1] if encoded_tokens.dim() >= 3 else 'N/A'}")
        logs.append(f"Sequence length: {encoded_tokens.shape[-1]}")
        # min()/max() raise on an empty tensor, so only report a range when tokens exist.
        if encoded_tokens.numel() > 0:
            logs.append(f"Token range: [{encoded_tokens.min().item():.0f}, {encoded_tokens.max().item():.0f}]")
    return encoded_tokens, tokens_len


def _decode(encoded_tokens, tokens_len, logs):
    """Decode codec tokens back to a waveform tensor (on the codec's device)."""
    logs.append("Decoding the generated tokens with Nemo codec...")
    decode_start = time.time()
    with torch.inference_mode():
        reconstructed_waveform, _ = nemo_codec.decode(tokens=encoded_tokens, tokens_len=tokens_len)
    decode_end = time.time()
    logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")
    return reconstructed_waveform


def _log_quality_metrics(reference, reconstructed, logs):
    """Append the MSE between the codec input and its reconstruction to ``logs``.

    Codec output length can differ from the input length, so both signals are
    compared over their common prefix only.
    """
    # ravel() also turns a 0-d array (single-sample audio after squeeze) into 1-D.
    reference = np.ravel(reference)
    reconstructed = np.ravel(reconstructed)
    min_len = min(len(reference), len(reconstructed))
    mse = np.mean((reference[:min_len] - reconstructed[:min_len]) ** 2)

    if len(reference) != len(reconstructed):
        logs.append(f"Audio length difference: Original {len(reference)}, Reconstructed {len(reconstructed)}")
        logs.append(f"MSE (first {min_len} samples): {mse:.6f}")
    else:
        logs.append(f"MSE: {mse:.6f}")


# --- Main Processing Function ---
def process_audio(audio_filepath):
    """
    Load, resample, encode, and decode audio with the Nemo codec.

    Args:
        audio_filepath: Path of the uploaded file, or None if nothing was uploaded.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``.
        Each of the first three is a ``(sample_rate, numpy_array)`` pair for
        ``gr.Audio``. On failure, the first three entries are None and
        ``log_text`` explains the error.
    """
    if nemo_codec is None:
        return None, None, None, "Error: Nemo codec could not be loaded. Cannot process audio."
    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing with Nemo Codec ---"]
    try:
        # 1. Load Audio
        original_waveform, original_sr = _load_first_channel(audio_filepath, logs)
        if original_waveform.shape[-1] == 0:
            # Fail with a clear message instead of an obscure error deep inside the codec.
            raise ValueError("Input audio contains no samples.")

        # --- Prepare Original for Playback ---
        original_audio_playback = (original_sr, original_waveform.squeeze().numpy())
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        waveform_to_encode = _resample_to_target(original_waveform, original_sr, logs)

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
        logs.append("Prepared resampled audio for playback.")

        # 3./4. Prepare and encode audio using Nemo
        encoded_tokens, tokens_len = _encode(waveform_to_encode, logs)

        # 5. Decode the tokens using Nemo
        reconstructed_waveform = _decode(encoded_tokens, tokens_len, logs)

        # 6. Prepare reconstructed audio for playback: move to CPU, drop size-1 dims.
        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        # 7. Calculate quality metrics
        original_for_comparison = waveform_to_encode.squeeze().numpy()
        _log_quality_metrics(original_for_comparison, reconstructed_audio_np, logs)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        # Input samples divided by the number of token values. This is not a
        # bit-level compression ratio, because samples and tokens have different bit widths.
        token_count = max(encoded_tokens.numel(), 1)
        logs.append(f"Audio samples per codec token: ~{original_for_comparison.size / token_count:.1f}")

        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        # Top-level UI boundary: report every failure in the log box instead of crashing the app.
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        return None, None, None, "\n".join(logs)


# --- Gradio Interface ---
DESCRIPTION = """
This app demonstrates the **NVIDIA Nemo Codec** model (`nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps`) used in Kani TTS.

**How it works:**
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 22kHz if needed.
3. The 22kHz audio is encoded into discrete tokens by the Nemo codec.
4. These tokens are then decoded back into audio by the Nemo codec.
5. You can listen to the original, the 22kHz version (if resampled), and the final reconstructed audio.

**Technical details:**
- Sample rate: 22kHz
- Bitrate: ~0.6 kbps
- Frame rate: 12.5fps
- 4 codebook levels per frame

**Note:** Processing happens locally. Larger files will take longer.
If the input is stereo, only the first channel is processed.
"""

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (22kHz Input to Nemo)"),
        gr.Audio(label="Reconstructed Audio (Output from Nemo Codec)"),
        gr.Textbox(label="Log Output", lines=20),
    ],
    title="NVIDIA Nemo Codec Demo (22kHz)",
    description=DESCRIPTION,
    examples=[
        # later I might add some samples
        # ["examples/example1.wav"],
        # ["examples/example2.wav"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    if nemo_codec is None:
        print("Cannot launch Gradio interface because Nemo codec failed to load.")
    else:
        print("Launching Gradio Interface...")
        print(f"Model: {MODEL_NAME}")
        print(f"Target sample rate: {TARGET_SR} Hz")
        print(f"Device: {DEVICE}")
        # NOTE: share=True also creates a temporary public Gradio share link,
        # so the app is reachable from outside this machine.
        iface.launch(share=True)