Spaces:
Running
Running
| import io | |
| import os | |
| import tempfile | |
| from typing import Tuple, Optional | |
# ---- silence known-noisy third-party warnings (cosmetic only) ----
import warnings

# Each entry becomes keyword arguments to warnings.filterwarnings("ignore", ...),
# applied in this order so the resulting filter list matches two explicit calls.
_SILENCED_WARNINGS = (
    # torchaudio's deprecation notice about list_audio_backends
    {"message": ".*torchaudio._backend.list_audio_backends has been deprecated.*"},
    # UserWarnings attributed to speechbrain.* modules
    {"module": r"speechbrain\..*", "category": UserWarning},
)
for _filter_kwargs in _SILENCED_WARNINGS:
    warnings.filterwarnings("ignore", **_filter_kwargs)
del _filter_kwargs
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| from fastapi import FastAPI, File, UploadFile, Query | |
| from fastapi.responses import StreamingResponse | |
| # ---- SpeechBrain import: prefer new API, fall back if older version ---- | |
| try: | |
| # SpeechBrain >= 1.0 | |
| from speechbrain.inference import SpectralMaskEnhancement | |
| except Exception: # pragma: no cover | |
| # Older SpeechBrain (<1.0) | |
| from speechbrain.pretrained import SpectralMaskEnhancement # type: ignore | |
| # ----------------------------- | |
| # Model: SpeechBrain MetricGAN+ | |
| # ----------------------------- | |
# Lazily created enhancer shared by every request; filled in by _get_enhancer().
_ENHANCER: Optional[SpectralMaskEnhancement] = None
# Device string handed to SpeechBrain through run_opts.
_DEVICE = "cpu"


def _get_enhancer() -> SpectralMaskEnhancement:
    """Return the shared MetricGAN+ enhancer, loading it on first use.

    The pretrained model is fetched from ``speechbrain/metricgan-plus-voicebank``
    into ``pretrained/metricgan_plus_voicebank`` and cached in ``_ENHANCER``.
    There is no locking, so concurrent first calls may each load the model.
    """
    global _ENHANCER
    if _ENHANCER is not None:
        return _ENHANCER
    _ENHANCER = SpectralMaskEnhancement.from_hparams(
        source="speechbrain/metricgan-plus-voicebank",
        savedir="pretrained/metricgan_plus_voicebank",
        run_opts={"device": _DEVICE},
    )
    return _ENHANCER
| # ----------------------------- | |
| # Audio helpers | |
| # ----------------------------- | |
| def _to_mono(wav: np.ndarray) -> np.ndarray: | |
| """Ensure mono shape [T] float32.""" | |
| if wav.ndim == 1: | |
| return wav.astype(np.float32) | |
| # [T, C] or [C, T] | |
| if wav.shape[0] < wav.shape[1]: | |
| return wav.mean(axis=1).astype(np.float32) | |
| return wav.mean(axis=0).astype(np.float32) | |
| def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor: | |
| if sr_in == sr_out: | |
| return wav | |
| return torchaudio.functional.resample(wav, sr_in, sr_out) | |
| def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor: | |
| if cutoff_hz is None or cutoff_hz <= 0: | |
| return wav | |
| return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz) | |
| def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor: | |
| """Simple presence EQ around ~4.5 kHz.""" | |
| if abs(gain_db) < 1e-6: | |
| return wav | |
| center = 4500.0 | |
| q = 0.707 | |
| return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db) | |
| def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor: | |
| target_amp = 10.0 ** (target_dbfs / 20.0) | |
| peak = torch.max(torch.abs(wav)).item() | |
| if peak > 0: | |
| scale = min(1.0, target_amp / peak) | |
| wav = wav * scale | |
| return torch.clamp(wav, -1.0, 1.0) | |
# Sample rate the MetricGAN+ model is fed at (it expects 16 kHz mono input).
_MODEL_SR = 16000


def _run_metricgan(wav_16k: torch.Tensor) -> torch.Tensor:
    """Enhance a ``[1, T]`` clip at ``_MODEL_SR`` with MetricGAN+ and return it on CPU.

    The clip is round-tripped through a temporary 16-bit WAV so SpeechBrain's
    file-path API (``enhance_file``) can be used for broad version
    compatibility. The temporary file is removed even if enhancement raises.
    """
    enhancer = _get_enhancer()
    # mkstemp + close instead of an open NamedTemporaryFile: reopening a
    # still-open named temp file by path does not work on every platform.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(tmp_path, wav_16k.squeeze(0).numpy(), _MODEL_SR, subtype="PCM_16")
        clean = enhancer.enhance_file(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; never mask the real result or error
    return clean.detach().cpu()


def _enhance_numpy_audio(
    audio: Tuple[int, np.ndarray],
    presence_db: float = 3.0,
    lowcut_hz: float = 75.0,
    out_sr: Optional[int] = None,
) -> Tuple[int, np.ndarray]:
    """
    Core pipeline used by both Gradio UI and raw FastAPI route.

    Steps:
        1. Downmix to mono.
        2. Resample to 16 kHz.
        3. Run MetricGAN+.
        4. High-pass at ``lowcut_hz``.
        5. Presence EQ of ``presence_db``.
        6. Peak-limit at -1 dBFS.
        7. Resample to ``out_sr``, or back to the input rate when
           ``out_sr`` is None or <= 0.

    Input: (sr, array [T] or 2-D multi-channel)
    Returns: (sr_out, np.float32 [T])

    Raises:
        ValueError: if the input contains no samples.
    """
    sr_in, wav_np = audio
    sr_in = int(sr_in)
    wav_mono = _to_mono(wav_np)
    if wav_mono.size == 0:
        raise ValueError("input audio contains no samples")

    wav_t = torch.from_numpy(wav_mono).unsqueeze(0)  # [1, T]
    wav_16k = _resample_torch(wav_t, sr_in, _MODEL_SR)
    clean = _run_metricgan(wav_16k)

    # Optional polish: high-pass & presence EQ + peak limit
    clean = _highpass(clean, _MODEL_SR, lowcut_hz)
    clean = _presence_boost(clean, _MODEL_SR, presence_db)
    clean = _limit_peak(clean, target_dbfs=-1.0)

    sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
    # reshape(-1) instead of squeeze(0): the enhancer output may be [T] or
    # [1, T] depending on the SpeechBrain version (TODO confirm). Always
    # return a 1-D array, even for a single-sample result.
    clean_out = (
        _resample_torch(clean, _MODEL_SR, sr_out)
        .reshape(-1)
        .numpy()
        .astype(np.float32)
    )
    return sr_out, clean_out
def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
    """Serialise ``mono_f32`` at ``sr`` Hz into an in-memory 16-bit PCM WAV file.

    Returns the complete WAV file contents (header + samples) as bytes.
    """
    with io.BytesIO() as buffer:
        sf.write(buffer, mono_f32, sr, subtype="PCM_16", format="WAV")
        return buffer.getvalue()
| # ----------------------------- | |
| # FastAPI app with raw endpoint | |
| # ----------------------------- | |
app = FastAPI(title="Voice Clarity Booster (MetricGAN+)", version="1.0.1")


def _content_disposition(upload_name: Optional[str]) -> str:
    """Return an ``attachment`` Content-Disposition for ``<stem>_enhanced.wav``.

    The plain ``filename=`` value is reduced to printable ASCII: quotes,
    backslashes, control and non-ASCII characters become ``_``. Starlette
    encodes header values as latin-1, and CR/LF must never reach a header.
    The exact UTF-8 name is carried in the RFC 5987 ``filename*=`` parameter.
    """
    from urllib.parse import quote

    stem = os.path.splitext(os.path.basename(upload_name or "audio"))[0] or "audio"
    out_name = f"{stem}_enhanced.wav"
    ascii_name = "".join(
        ch if 32 <= ord(ch) < 127 and ch not in '"\\' else "_" for ch in out_name
    )
    return (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(out_name, safe='')}"
    )


@app.post("/enhance")
async def enhance_endpoint(
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg etc.)"),
    presence_db: float = Query(3.0, ge=-12.0, le=12.0, description="Presence EQ gain in dB"),
    lowcut_hz: float = Query(75.0, ge=0.0, le=200.0, description="High-pass cutoff in Hz"),
    output_sr: int = Query(0, ge=0, description="0=keep original, or set to e.g. 44100/48000"),
):
    """Raw REST endpoint. Returns enhanced audio as audio/wav bytes.

    Registered as ``POST /enhance``. The upload is decoded with soundfile,
    run through ``_enhance_numpy_audio``, and returned as a 16-bit PCM WAV
    attachment.

    Raises:
        HTTPException: 400 if soundfile cannot decode the upload or it
            contains no samples.
    """
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    data = await file.read()
    try:
        wav_np, sr_in = sf.read(io.BytesIO(data), always_2d=False, dtype="float32")
    except RuntimeError as err:  # soundfile reports unreadable input as RuntimeError(s)
        raise HTTPException(status_code=400, detail=f"Could not decode audio: {err}") from err
    if wav_np.size == 0:
        raise HTTPException(status_code=400, detail="Uploaded audio contains no samples.")

    # Inference is synchronous and CPU-bound. Run it in the threadpool so the
    # event loop (and the Gradio UI mounted on the same app) stays responsive.
    sr_out, enhanced = await run_in_threadpool(
        _enhance_numpy_audio,
        (sr_in, wav_np),
        presence_db=presence_db,
        lowcut_hz=lowcut_hz,
        out_sr=output_sr if output_sr > 0 else None,
    )
    wav_bytes = _wav_bytes(sr_out, enhanced)
    headers = {"Content-Disposition": _content_disposition(file.filename)}
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers)
| # ----------------------------- | |
| # Gradio UI (for quick testing) | |
| # ----------------------------- | |
def gradio_enhance(
    audio: Tuple[int, np.ndarray],
    presence_db: float,
    lowcut_hz: float,
    output_sr: str,
):
    """Click handler for the Gradio "Enhance" button.

    Returns None when no audio was supplied. Otherwise returns the
    ``(sample_rate, samples)`` pair produced by ``_enhance_numpy_audio``.
    Only the radio values "44100" and "48000" override the output rate;
    anything else (e.g. "Original") keeps the input rate.
    """
    if audio is None:
        return None
    requested_sr = int(output_sr) if output_sr in {"44100", "48000"} else None
    rate, samples = _enhance_numpy_audio(
        audio,
        presence_db=float(presence_db),
        lowcut_hz=float(lowcut_hz),
        out_sr=requested_sr,
    )
    return rate, samples
# Two-column test UI: input audio and controls on the left, result on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Voice Clarity Booster (MetricGAN+)")
    with gr.Row():
        with gr.Column():
            # type="numpy" hands gradio_enhance a (sample_rate, array) tuple.
            in_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Input")
            # Slider ranges mirror the REST endpoint's Query bounds
            # (-12..12 dB presence, 0..200 Hz low-cut).
            presence = gr.Slider(-12, 12, value=3, step=0.5, label="Presence Boost (dB)")
            lowcut = gr.Slider(0, 200, value=75, step=5, label="Low-Cut (Hz)")
            # "Original" keeps the input rate; gradio_enhance only converts
            # the strings "44100" / "48000" into an output rate.
            out_sr = gr.Radio(
                choices=["Original", "44100", "48000"],
                value="Original",
                label="Output Sample Rate",
            )
            btn = gr.Button("Enhance")
        with gr.Column():
            out_audio = gr.Audio(type="numpy", label="Enhanced")
    # Inputs are passed positionally to gradio_enhance in this order.
    btn.click(gradio_enhance, inputs=[in_audio, presence, lowcut, out_sr], outputs=[out_audio])
# Mount Gradio at root path and keep FastAPI for /enhance
app = gr.mount_gradio_app(app, demo, path="/")