Diggz10 commited on
Commit
ffb4e02
·
verified ·
1 Parent(s): 319d1d1

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +204 -0
main.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import tempfile
4
+ from typing import Tuple, Optional
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import torch
10
+ import torchaudio
11
+ from fastapi import FastAPI, File, UploadFile, Query, Response
12
+ from fastapi.responses import StreamingResponse
13
+ from speechbrain.pretrained import SpectralMaskEnhancement
14
+
15
# -----------------------------
# Model: SpeechBrain MetricGAN+
# -----------------------------
# Process-wide enhancer instance, created lazily by _get_enhancer() on first use.
_ENHANCER: Optional[SpectralMaskEnhancement] = None
# Device string passed to SpeechBrain's run_opts; model inference runs on CPU.
_DEVICE = "cpu"
20
+
21
+
22
def _get_enhancer() -> SpectralMaskEnhancement:
    """Return the shared MetricGAN+ enhancer, loading it on first use.

    The pretrained model is fetched once (and cached in the Space) under
    ``pretrained/metricgan_plus_voicebank``; every later call reuses the
    same module-level instance.
    """
    global _ENHANCER
    if _ENHANCER is not None:
        return _ENHANCER
    # Lazy singleton: defer the (slow) download/load until the first request.
    _ENHANCER = SpectralMaskEnhancement.from_hparams(
        source="speechbrain/metricgan-plus-voicebank",
        savedir="pretrained/metricgan_plus_voicebank",
        run_opts={"device": _DEVICE},
    )
    return _ENHANCER
32
+
33
+
34
+ # -----------------------------
35
+ # Audio helpers
36
+ # -----------------------------
37
+ def _to_mono(wav: np.ndarray) -> np.ndarray:
38
+ """Ensure mono shape [T]."""
39
+ if wav.ndim == 1:
40
+ return wav.astype(np.float32)
41
+ # shape [T, C] or [C, T]
42
+ if wav.shape[0] < wav.shape[1]:
43
+ # likely [T, C]
44
+ return wav.mean(axis=1).astype(np.float32)
45
+ else:
46
+ # likely [C, T]
47
+ return wav.mean(axis=0).astype(np.float32)
48
+
49
+
50
+ def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
51
+ if sr_in == sr_out:
52
+ return wav
53
+ return torchaudio.functional.resample(wav, sr_in, sr_out)
54
+
55
+
56
+ def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
57
+ if cutoff_hz is None or cutoff_hz <= 0:
58
+ return wav
59
+ # 2nd-order Butterworth-ish highpass via biquad
60
+ return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz)
61
+
62
+
63
+ def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
64
+ """Simple presence (peaking) EQ around 4.5 kHz."""
65
+ if abs(gain_db) < 1e-6:
66
+ return wav
67
+ center = 4500.0 # presence band
68
+ q = 0.707 # wide-ish
69
+ return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db)
70
+
71
+
72
+ def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
73
+ """Peak-normalize to target dBFS (default -1 dB)."""
74
+ target_amp = 10.0 ** (target_dbfs / 20.0)
75
+ peak = torch.max(torch.abs(wav)).item()
76
+ if peak > 0:
77
+ scale = min(1.0, target_amp / peak)
78
+ wav = wav * scale
79
+ return torch.clamp(wav, -1.0, 1.0)
80
+
81
+
82
def _enhance_numpy_audio(
    audio: Tuple[int, np.ndarray],
    presence_db: float = 3.0,
    lowcut_hz: float = 75.0,
    out_sr: Optional[int] = None,
) -> Tuple[int, np.ndarray]:
    """
    Core pipeline used by both Gradio UI and raw FastAPI route.
    Input: (sr, np.float32 [T] or [T,C])
    Returns: (sr_out, np.float32 [T])

    Stages: mono-fold -> resample to 16 kHz -> MetricGAN+ enhancement ->
    optional high-pass (`lowcut_hz`) and presence EQ (`presence_db`) ->
    peak limit at -1 dBFS -> resample to `out_sr` (None or <= 0 keeps
    the input rate).
    """
    sr_in, wav_np = audio
    wav_mono = _to_mono(wav_np)
    wav_t = torch.from_numpy(wav_mono).unsqueeze(0)  # [1, T]

    # MetricGAN+ expects 16 kHz mono
    enh = _get_enhancer()
    wav_16k = _resample_torch(wav_t, sr_in, 16000)

    # Enhance via file path API for maximum compatibility
    # NOTE(review): writes to tmp_in.name while the NamedTemporaryFile handle
    # is still open — fine on POSIX, can fail on Windows; confirm the target
    # platform if this ever runs outside a Linux Space.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
        sf.write(tmp_in.name, wav_16k.squeeze(0).numpy(), 16000, subtype="PCM_16")
        tmp_in.flush()
        # Enhance; returns torch.Tensor [1, T]
        clean = enh.enhance_file(tmp_in.name)
    # Best-effort cleanup of the delete=False temp file; a leftover file is
    # preferable to failing the request here.
    try:
        os.remove(tmp_in.name)
    except Exception:
        pass

    # Optional polish: high-pass & presence EQ
    clean = _highpass(clean, 16000, lowcut_hz)
    clean = _presence_boost(clean, 16000, presence_db)

    # Peak limiting to avoid inter-sample clip
    clean = _limit_peak(clean, target_dbfs=-1.0)

    # Resample back
    sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
    clean_out = _resample_torch(clean, 16000, sr_out).squeeze(0).numpy().astype(
        np.float32
    )

    return sr_out, clean_out
126
+
127
+
128
def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
    """Serialize a mono float32 signal to an in-memory 16-bit PCM WAV blob."""
    buffer = io.BytesIO()
    sf.write(buffer, mono_f32, sr, subtype="PCM_16", format="WAV")
    return buffer.getvalue()
134
+
135
+
136
# -----------------------------
# FastAPI app with raw endpoint
# -----------------------------
# Base ASGI app exposing POST /enhance; the Gradio UI is mounted onto it
# at "/" at the bottom of this file.
app = FastAPI(title="Voice Clarity Booster (MetricGAN+)", version="1.0.0")
140
+
141
+
142
@app.post("/enhance")
async def enhance_endpoint(
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg etc.)"),
    presence_db: float = Query(3.0, ge=-12.0, le=12.0, description="Presence EQ gain in dB"),
    lowcut_hz: float = Query(75.0, ge=0.0, le=200.0, description="High-pass cutoff in Hz"),
    output_sr: int = Query(0, ge=0, description="0=keep original, or set to e.g. 44100/48000"),
):
    """Raw REST endpoint. Returns enhanced audio as audio/wav bytes."""
    raw = await file.read()
    # Decode with soundfile
    decoded, input_rate = sf.read(io.BytesIO(raw), always_2d=False, dtype="float32")
    requested_sr = output_sr if output_sr > 0 else None  # 0 means "keep original"
    final_rate, processed = _enhance_numpy_audio(
        (input_rate, decoded),
        presence_db=presence_db,
        lowcut_hz=lowcut_hz,
        out_sr=requested_sr,
    )
    payload = _wav_bytes(final_rate, processed)
    # Name the download after the upload, with an "_enhanced" suffix.
    stem = os.path.splitext(file.filename or "audio")[0]
    headers = {"Content-Disposition": f'attachment; filename="{stem}_enhanced.wav"'}
    return StreamingResponse(io.BytesIO(payload), media_type="audio/wav", headers=headers)
162
+
163
+
164
+ # -----------------------------
165
+ # Gradio UI (for quick testing)
166
+ # -----------------------------
167
def gradio_enhance(
    audio: Tuple[int, np.ndarray],
    presence_db: float,
    lowcut_hz: float,
    output_sr: str,
):
    """Gradio callback: enhance the uploaded/recorded clip.

    ``output_sr`` is the radio-button label: "44100"/"48000" force that
    output rate, anything else (e.g. "Original") keeps the input rate.
    Returns ``(sample_rate, waveform)`` or None when there is no input.
    """
    if audio is None:
        # Nothing uploaded or recorded yet.
        return None
    target_sr = int(output_sr) if output_sr in {"44100", "48000"} else None
    return _enhance_numpy_audio(
        audio,
        presence_db=float(presence_db),
        lowcut_hz=float(lowcut_hz),
        out_sr=target_sr,
    )
183
+
184
+
185
# Two-column layout: controls on the left, enhanced result on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Voice Clarity Booster (MetricGAN+)")
    with gr.Row():
        with gr.Column():
            # Input audio plus the two polish controls fed to gradio_enhance.
            in_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Input")
            presence = gr.Slider(-12, 12, value=3, step=0.5, label="Presence Boost (dB)")
            lowcut = gr.Slider(0, 200, value=75, step=5, label="Low-Cut (Hz)")
            # Radio values are strings; gradio_enhance parses "44100"/"48000".
            out_sr = gr.Radio(
                choices=["Original", "44100", "48000"],
                value="Original",
                label="Output Sample Rate",
            )
            btn = gr.Button("Enhance")
        with gr.Column():
            out_audio = gr.Audio(type="numpy", label="Enhanced")

    btn.click(gradio_enhance, inputs=[in_audio, presence, lowcut, out_sr], outputs=[out_audio])
202
+
203
# Mount Gradio at root path and keep FastAPI for /enhance
# Rebinding `app` to the combined app keeps the module-level name — presumably
# the one the Space's ASGI server imports — serving both UI and REST routes.
app = gr.mount_gradio_app(app, demo, path="/")