Gapeleon commited on
Commit
c471c87
·
verified ·
1 Parent(s): 315f949

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ NVIDIA Nemo Codec Test - Gradio App
4
+ Equivalent to snac_test.py but for the NVIDIA Nemo codec used in Kani TTS based models.
5
+ Allows testing encode/decode cycles with the nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps model.
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torchaudio
11
+ import torchaudio.transforms as T
12
+ import numpy as np
13
+ import traceback
14
+ import time
15
+
16
+ # Attempt to import Nemo
17
+ try:
18
+ from nemo.collections.tts.models import AudioCodecModel
19
+ from nemo.utils.nemo_logging import Logger
20
+
21
+ # Suppress Nemo logging
22
+ nemo_logger = Logger()
23
+ nemo_logger.remove_stream_handlers()
24
+ print("Nemo modules imported successfully.")
25
+ except ImportError as e:
26
+ print(f"Error importing Nemo: {e}")
27
+ raise ImportError("Could not import Nemo. Make sure 'nemo_toolkit[tts]' is installed correctly.") from e
28
+
29
+ # --- Configuration ---
30
+ TARGET_SR = 22050 # Nemo codec operates at 22kHz
31
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ MODEL_NAME = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
33
+ print(f"Using device: {DEVICE}")
34
+
35
+ # --- Load Model (Load once globally) ---
36
+ nemo_codec = None
37
+ try:
38
+ print(f"Loading Nemo codec model: {MODEL_NAME}...")
39
+ start_time = time.time()
40
+ nemo_codec = AudioCodecModel.from_pretrained(MODEL_NAME)
41
+ nemo_codec = nemo_codec.to(DEVICE)
42
+ nemo_codec.eval() # Set model to evaluation mode
43
+ end_time = time.time()
44
+ print(f"Nemo codec loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
45
+ except Exception as e:
46
+ print(f"FATAL: Error loading Nemo codec: {e}")
47
+ print(traceback.format_exc())
48
+
49
+ # --- Main Processing Function ---
50
+ def process_audio(audio_filepath):
51
+ """
52
+ Loads, resamples, encodes, decodes audio using Nemo codec, and returns results.
53
+ """
54
+ if nemo_codec is None:
55
+ return None, None, None, "Error: Nemo codec could not be loaded. Cannot process audio."
56
+
57
+ if audio_filepath is None:
58
+ return None, None, None, "Please upload an audio file."
59
+
60
+ logs = ["--- Starting Audio Processing with Nemo Codec ---"]
61
+ try:
62
+ # 1. Load Audio
63
+ logs.append(f"Loading audio file: {audio_filepath}")
64
+ load_start = time.time()
65
+ original_waveform, original_sr = torchaudio.load(audio_filepath)
66
+ load_end = time.time()
67
+ logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")
68
+
69
+ # Ensure float32
70
+ original_waveform = original_waveform.to(dtype=torch.float32)
71
+
72
+ # Handle multi-channel audio: Use the first channel
73
+ if original_waveform.shape[0] > 1:
74
+ logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
75
+ original_waveform = original_waveform[0:1, :] # Keep channel dim for consistency
76
+
77
+ # --- Prepare Original for Playback ---
78
+ original_audio_playback = (original_sr, original_waveform.squeeze().numpy())
79
+ logs.append("Prepared original audio for playback.")
80
+
81
+ # 2. Resample if necessary
82
+ resample_start = time.time()
83
+ if original_sr != TARGET_SR:
84
+ logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
85
+ resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
86
+ waveform_to_encode = resampler(original_waveform)
87
+ logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
88
+ else:
89
+ logs.append("Waveform is already at the target sample rate (22kHz).")
90
+ waveform_to_encode = original_waveform
91
+ resample_end = time.time()
92
+ logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")
93
+
94
+ # --- Prepare Resampled for Playback ---
95
+ resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
96
+ logs.append("Prepared resampled audio for playback.")
97
+
98
+ # 3. Prepare for Nemo Encoding
99
+ # Nemo expects [batch, samples] format
100
+ if waveform_to_encode.dim() == 2 and waveform_to_encode.shape[0] == 1:
101
+ waveform_batch = waveform_to_encode # [1, samples]
102
+ else:
103
+ waveform_batch = waveform_to_encode.unsqueeze(0) # Add batch dimension
104
+
105
+ waveform_batch = waveform_batch.to(DEVICE)
106
+
107
+ # Calculate audio length for Nemo
108
+ audio_len = torch.tensor([waveform_batch.shape[-1]], dtype=torch.int64).to(DEVICE)
109
+ logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Audio length: {audio_len.item()}, Device: {DEVICE}")
110
+
111
+ # 4. Encode Audio using Nemo
112
+ logs.append("Encoding audio with Nemo codec...")
113
+ encode_start = time.time()
114
+ with torch.inference_mode():
115
+ encoded_tokens, tokens_len = nemo_codec.encode(audio=waveform_batch, audio_len=audio_len)
116
+ encode_end = time.time()
117
+
118
+ if encoded_tokens is None:
119
+ log_msg = "Encoding failed: encoded_tokens is None"
120
+ logs.append(log_msg)
121
+ raise ValueError(log_msg)
122
+
123
+ logs.append(f"Encoding complete. Time: {encode_end - encode_start:.2f}s")
124
+ logs.append(f"Encoded tokens shape: {encoded_tokens.shape}, tokens_len: {tokens_len}")
125
+ logs.append(f"Encoded tokens device: {encoded_tokens.device}")
126
+
127
+ # Log some statistics about the tokens
128
+ if encoded_tokens.dim() >= 2:
129
+ logs.append(f"Number of codebooks: {encoded_tokens.shape[1] if encoded_tokens.dim() >= 3 else 'N/A'}")
130
+ logs.append(f"Sequence length: {encoded_tokens.shape[-1]}")
131
+ logs.append(f"Token range: [{encoded_tokens.min().item():.0f}, {encoded_tokens.max().item():.0f}]")
132
+
133
+ # 5. Decode the Tokens using Nemo
134
+ logs.append("Decoding the generated tokens with Nemo codec...")
135
+ decode_start = time.time()
136
+ with torch.inference_mode():
137
+ reconstructed_waveform, _ = nemo_codec.decode(tokens=encoded_tokens, tokens_len=tokens_len)
138
+ decode_end = time.time()
139
+ logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")
140
+
141
+ # 6. Prepare Reconstructed Audio for Playback
142
+ # Output should be [batch, samples]. Move to CPU, remove batch dim, convert to NumPy.
143
+ reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()
144
+ logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
145
+ reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)
146
+
147
+ # 7. Calculate quality metrics
148
+ original_for_comparison = waveform_to_encode.squeeze().numpy()
149
+ if len(original_for_comparison) != len(reconstructed_audio_np):
150
+ # Handle length differences (common with codecs)
151
+ min_len = min(len(original_for_comparison), len(reconstructed_audio_np))
152
+ original_trimmed = original_for_comparison[:min_len]
153
+ reconstructed_trimmed = reconstructed_audio_np[:min_len]
154
+
155
+ # Simple MSE calculation
156
+ mse = np.mean((original_trimmed - reconstructed_trimmed) ** 2)
157
+ logs.append(f"Audio length difference: Original {len(original_for_comparison)}, Reconstructed {len(reconstructed_audio_np)}")
158
+ logs.append(f"MSE (first {min_len} samples): {mse:.6f}")
159
+ else:
160
+ mse = np.mean((original_for_comparison - reconstructed_audio_np) ** 2)
161
+ logs.append(f"MSE: {mse:.6f}")
162
+
163
+ logs.append("\n--- Audio Processing Completed Successfully ---")
164
+ logs.append(f"Compression ratio: ~{len(original_for_comparison) / (encoded_tokens.numel() if encoded_tokens.numel() > 0 else 1):.1f}:1")
165
+
166
+ return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)
167
+
168
+ except Exception as e:
169
+ logs.append("\n--- An Error Occurred ---")
170
+ logs.append(f"Error Type: {type(e).__name__}")
171
+ logs.append(f"Error Details: {e}")
172
+ logs.append("\n--- Traceback ---")
173
+ logs.append(traceback.format_exc())
174
+ return None, None, None, "\n".join(logs)
175
+
176
+ # --- Gradio Interface ---
177
+ DESCRIPTION = """
178
+ This app demonstrates the **NVIDIA Nemo Codec** model (`nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps`) used in Kani TTS.
179
+
180
+ **How it works:**
181
+ 1. Upload an audio file (wav, mp3, flac, etc.).
182
+ 2. The audio will be automatically resampled to 22kHz if needed.
183
+ 3. The 22kHz audio is encoded into discrete tokens by the Nemo codec.
184
+ 4. These tokens are then decoded back into audio by the Nemo codec.
185
+ 5. You can listen to the original, the 22kHz version (if resampled), and the final reconstructed audio.
186
+
187
+ **Technical details:**
188
+ - Sample rate: 22kHz
189
+ - Compression: ~0.6kbps
190
+ - Frame rate: 12.5fps
191
+ - 4 codebook levels per frame
192
+
193
+ **Note:** Processing happens locally. Larger files will take longer. If the input is stereo, only the first channel is processed.
194
+ """
195
+
196
+ iface = gr.Interface(
197
+ fn=process_audio,
198
+ inputs=gr.Audio(type="filepath", label="Upload Audio File"),
199
+ outputs=[
200
+ gr.Audio(label="Original Audio"),
201
+ gr.Audio(label="Resampled Audio (22kHz Input to Nemo)"),
202
+ gr.Audio(label="Reconstructed Audio (Output from Nemo Codec)"),
203
+ gr.Textbox(label="Log Output", lines=20)
204
+ ],
205
+ title="NVIDIA Nemo Codec Demo (22kHz)",
206
+ description=DESCRIPTION,
207
+ examples=[
208
+ # later I might add some samples
209
+ # ["examples/example1.wav"],
210
+ # ["examples/example2.wav"],
211
+ ],
212
+ cache_examples=False
213
+ )
214
+
215
+ if __name__ == "__main__":
216
+ if nemo_codec is None:
217
+ print("Cannot launch Gradio interface because Nemo codec failed to load.")
218
+ else:
219
+ print("Launching Gradio Interface...")
220
+ print(f"Model: {MODEL_NAME}")
221
+ print(f"Target sample rate: {TARGET_SR} Hz")
222
+ print(f"Device: {DEVICE}")
223
+ iface.launch(share=True)