# voiceclear / app.py
# Last commit by Diggz10: "Update app.py" (b67ceda, verified), 7.14 kB
import io
import os
import tempfile
from typing import Tuple, Optional
# ---- tame noisy deprecation warnings (optional but nice) ----
import warnings
# Silence the torchaudio backend-listing deprecation notice and any
# UserWarning raised from speechbrain.* modules. Both filters are "ignore",
# so their relative precedence does not matter; they are installed in order.
for _warning_filter in (
    {"message": ".*torchaudio._backend.list_audio_backends has been deprecated.*"},
    {"module": r"speechbrain\..*", "category": UserWarning},
):
    warnings.filterwarnings("ignore", **_warning_filter)
del _warning_filter
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

# ---- SpeechBrain import: prefer new API, fall back if older version ----
try:
    # SpeechBrain >= 1.0
    from speechbrain.inference import SpectralMaskEnhancement
except Exception:  # pragma: no cover
    # Older SpeechBrain (<1.0)
    from speechbrain.pretrained import SpectralMaskEnhancement  # type: ignore
# -----------------------------
# Model: SpeechBrain MetricGAN+
# -----------------------------
# Lazily-created, process-wide enhancer instance (None until first use).
_ENHANCER: Optional[SpectralMaskEnhancement] = None
# Device string handed to SpeechBrain through run_opts.
_DEVICE = "cpu"


def _get_enhancer() -> SpectralMaskEnhancement:
    """Return the shared MetricGAN+ enhancer, creating it on the first call.

    The model is built with ``SpectralMaskEnhancement.from_hparams`` from the
    ``speechbrain/metricgan-plus-voicebank`` source, using
    ``pretrained/metricgan_plus_voicebank`` as its save directory and
    ``_DEVICE`` as the run device. Later calls return the cached instance.
    There is no lock, so concurrent first calls may each build a model.
    """
    global _ENHANCER
    if _ENHANCER is not None:
        return _ENHANCER
    _ENHANCER = SpectralMaskEnhancement.from_hparams(
        source="speechbrain/metricgan-plus-voicebank",
        savedir="pretrained/metricgan_plus_voicebank",
        run_opts={"device": _DEVICE},
    )
    return _ENHANCER
# -----------------------------
# Audio helpers
# -----------------------------
def _to_mono(wav: np.ndarray) -> np.ndarray:
"""Ensure mono shape [T] float32."""
if wav.ndim == 1:
return wav.astype(np.float32)
# [T, C] or [C, T]
if wav.shape[0] < wav.shape[1]:
return wav.mean(axis=1).astype(np.float32)
return wav.mean(axis=0).astype(np.float32)
def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
if sr_in == sr_out:
return wav
return torchaudio.functional.resample(wav, sr_in, sr_out)
def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
if cutoff_hz is None or cutoff_hz <= 0:
return wav
return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz)
def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
"""Simple presence EQ around ~4.5 kHz."""
if abs(gain_db) < 1e-6:
return wav
center = 4500.0
q = 0.707
return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db)
def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
target_amp = 10.0 ** (target_dbfs / 20.0)
peak = torch.max(torch.abs(wav)).item()
if peak > 0:
scale = min(1.0, target_amp / peak)
wav = wav * scale
return torch.clamp(wav, -1.0, 1.0)
def _enhance_numpy_audio(
    audio: Tuple[int, np.ndarray],
    presence_db: float = 3.0,
    lowcut_hz: float = 75.0,
    out_sr: Optional[int] = None,
) -> Tuple[int, np.ndarray]:
    """
    Core pipeline used by both the Gradio UI and the raw FastAPI route.

    Steps: downmix to mono -> resample to 16 kHz -> MetricGAN+ enhancement ->
    high-pass at ``lowcut_hz`` -> presence EQ of ``presence_db`` dB ->
    peak limit to -1 dBFS -> resample to ``out_sr`` (or back to the input rate).

    Args:
        audio: ``(sample_rate, samples)``; ``samples`` in any layout ``_to_mono``
            accepts (e.g. ``[T]`` or ``[T, C]``).
        presence_db: presence EQ gain in dB (0 disables it).
        lowcut_hz: high-pass cutoff in Hz (<= 0 disables it).
        out_sr: output sample rate; ``None`` or <= 0 keeps the input rate.

    Returns:
        ``(sr_out, samples)`` where ``samples`` is a mono float32 array ``[T]``.
    """
    sr_in, wav_np = audio
    wav_t = torch.from_numpy(_to_mono(wav_np)).unsqueeze(0)  # [1, T]

    # MetricGAN+ expects 16 kHz mono
    wav_16k = _resample_torch(wav_t, sr_in, 16000)
    clean = _run_enhancer_16k(wav_16k)  # [1, T]

    # Optional polish: high-pass & presence EQ + peak limit
    clean = _highpass(clean, 16000, lowcut_hz)
    clean = _presence_boost(clean, 16000, presence_db)
    clean = _limit_peak(clean, target_dbfs=-1.0)

    # Resample to requested output rate (or original)
    sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
    clean_out = (
        _resample_torch(clean, 16000, sr_out).squeeze(0).numpy().astype(np.float32)
    )
    return sr_out, clean_out


def _run_enhancer_16k(wav_16k: torch.Tensor) -> torch.Tensor:
    """Run MetricGAN+ on a ``[1, T]`` 16 kHz tensor and return a ``[1, T]`` tensor.

    The enhancer's file-path API (``enhance_file``) is used for broad
    SpeechBrain version compatibility. The input is written to a temporary
    16-bit WAV that is always removed afterwards, even if enhancement fails.
    """
    enh = _get_enhancer()

    # Close our handle before writing or reading by name: a still-open
    # NamedTemporaryFile cannot be reopened by name on Windows.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_path = tmp.name
    tmp.close()
    try:
        # Clip first: out-of-range floats are not guaranteed to saturate when
        # converted to 16-bit PCM.
        pcm_ready = torch.clamp(wav_16k, -1.0, 1.0).squeeze(0).numpy()
        sf.write(tmp_path, pcm_ready, 16000, subtype="PCM_16")
        # NOTE(review): some SpeechBrain versions may link/copy the input file
        # into an audio cache dir inside load_audio; confirm for the pinned version.
        clean = enh.enhance_file(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the enhancement result is unaffected

    # The returned tensor may be 1-D; normalise to [1, T] for the filters below.
    if clean.dim() == 1:
        clean = clean.unsqueeze(0)
    return clean.detach().cpu()
def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
    """Serialize a mono float32 signal at ``sr`` Hz into 16-bit PCM WAV bytes (in memory)."""
    with io.BytesIO() as buffer:
        sf.write(buffer, mono_f32, sr, subtype="PCM_16", format="WAV")
        return buffer.getvalue()
# -----------------------------
# FastAPI app with raw endpoint
# -----------------------------
app = FastAPI(title="Voice Clarity Booster (MetricGAN+)", version="1.0.1")


def _safe_download_stem(filename: Optional[str]) -> str:
    """Return an ASCII-safe stem of ``filename`` for a Content-Disposition header.

    Directory parts and the extension are dropped. Characters outside
    ASCII letters, digits, space, ``-``, ``_`` and ``.`` become ``_``, so
    quotes, control characters and non-latin-1 text cannot break the header.
    An empty result falls back to ``"audio"``.
    """
    base = os.path.basename((filename or "").replace("\\", "/"))
    stem = os.path.splitext(base)[0]
    cleaned = "".join(
        ch if (ch.isascii() and (ch.isalnum() or ch in " -_.")) else "_" for ch in stem
    )
    return cleaned or "audio"


@app.post("/enhance")
async def enhance_endpoint(
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg etc.)"),
    presence_db: float = Query(3.0, ge=-12.0, le=12.0, description="Presence EQ gain in dB"),
    lowcut_hz: float = Query(75.0, ge=0.0, le=200.0, description="High-pass cutoff in Hz"),
    output_sr: int = Query(0, ge=0, description="0=keep original, or set to e.g. 44100/48000"),
):
    """Raw REST endpoint. Returns enhanced audio as audio/wav bytes.

    Responds with HTTP 400 if the upload cannot be decoded by soundfile or
    contains no samples.
    """
    data = await file.read()
    try:
        wav_np, sr_in = sf.read(io.BytesIO(data), always_2d=False, dtype="float32")
    except RuntimeError as err:  # soundfile's decode errors derive from RuntimeError
        raise HTTPException(status_code=400, detail=f"Could not decode audio: {err}") from err
    if wav_np.size == 0:
        raise HTTPException(status_code=400, detail="Uploaded audio contains no samples.")

    # Inference is CPU-bound and blocking. Run it in a worker thread so the
    # event loop (which also serves the mounted Gradio UI) stays responsive.
    sr_out, enhanced = await run_in_threadpool(
        _enhance_numpy_audio,
        (sr_in, wav_np),
        presence_db=presence_db,
        lowcut_hz=lowcut_hz,
        out_sr=output_sr if output_sr > 0 else None,
    )
    wav_bytes = _wav_bytes(sr_out, enhanced)
    download_name = f"{_safe_download_stem(file.filename)}_enhanced.wav"
    headers = {"Content-Disposition": f'attachment; filename="{download_name}"'}
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers)
# -----------------------------
# Gradio UI (for quick testing)
# -----------------------------
def gradio_enhance(
    audio: Tuple[int, np.ndarray],
    presence_db: float,
    lowcut_hz: float,
    output_sr: str,
):
    """Gradio click handler: enhance ``audio`` and return ``(sr, samples)``.

    Returns ``None`` when no audio was supplied. ``output_sr`` is the
    radio-button label: only "44100" and "48000" select an explicit rate,
    and any other label (e.g. "Original") keeps the input sample rate.
    """
    if audio is None:
        return None
    explicit_rates = {"44100": 44100, "48000": 48000}
    sr_out, enhanced = _enhance_numpy_audio(
        audio,
        presence_db=float(presence_db),
        lowcut_hz=float(lowcut_hz),
        out_sr=explicit_rates.get(output_sr),
    )
    return sr_out, enhanced
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Voice Clarity Booster (MetricGAN+)")
    with gr.Row():
        # Left column: input audio plus the processing controls.
        with gr.Column():
            in_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input",
            )
            presence = gr.Slider(
                minimum=-12, maximum=12, value=3, step=0.5, label="Presence Boost (dB)"
            )
            lowcut = gr.Slider(
                minimum=0, maximum=200, value=75, step=5, label="Low-Cut (Hz)"
            )
            out_sr = gr.Radio(
                choices=["Original", "44100", "48000"],
                value="Original",
                label="Output Sample Rate",
            )
            btn = gr.Button("Enhance")
        # Right column: the enhanced result.
        with gr.Column():
            out_audio = gr.Audio(type="numpy", label="Enhanced")
    # Wire the button inside the Blocks context, as Gradio event listeners require.
    btn.click(
        fn=gradio_enhance,
        inputs=[in_audio, presence, lowcut, out_sr],
        outputs=[out_audio],
    )

# Mount Gradio at root path and keep FastAPI for /enhance
app = gr.mount_gradio_app(app, demo, path="/")