Diggz10 commited on
Commit
b67ceda
·
verified ·
1 Parent(s): 4497e6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -28
app.py CHANGED
@@ -3,14 +3,34 @@ import os
3
  import tempfile
4
  from typing import Tuple, Optional
5
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gradio as gr
7
  import numpy as np
8
  import soundfile as sf
9
  import torch
10
  import torchaudio
11
- from fastapi import FastAPI, File, UploadFile, Query, Response
12
  from fastapi.responses import StreamingResponse
13
- from speechbrain.pretrained import SpectralMaskEnhancement
 
 
 
 
 
 
 
 
14
 
15
  # -----------------------------
16
  # Model: SpeechBrain MetricGAN+
@@ -22,7 +42,6 @@ _DEVICE = "cpu"
22
  def _get_enhancer() -> SpectralMaskEnhancement:
23
  global _ENHANCER
24
  if _ENHANCER is None:
25
- # Downloads once and caches in the Space
26
  _ENHANCER = SpectralMaskEnhancement.from_hparams(
27
  source="speechbrain/metricgan-plus-voicebank",
28
  savedir="pretrained/metricgan_plus_voicebank",
@@ -35,16 +54,13 @@ def _get_enhancer() -> SpectralMaskEnhancement:
35
  # Audio helpers
36
  # -----------------------------
37
  def _to_mono(wav: np.ndarray) -> np.ndarray:
38
- """Ensure mono shape [T]."""
39
  if wav.ndim == 1:
40
  return wav.astype(np.float32)
41
- # shape [T, C] or [C, T]
42
  if wav.shape[0] < wav.shape[1]:
43
- # likely [T, C]
44
  return wav.mean(axis=1).astype(np.float32)
45
- else:
46
- # likely [C, T]
47
- return wav.mean(axis=0).astype(np.float32)
48
 
49
 
50
  def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
@@ -56,21 +72,19 @@ def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
56
  def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
57
  if cutoff_hz is None or cutoff_hz <= 0:
58
  return wav
59
- # 2nd-order Butterworth-ish highpass via biquad
60
  return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz)
61
 
62
 
63
  def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
64
- """Simple presence (peaking) EQ around 4.5 kHz."""
65
  if abs(gain_db) < 1e-6:
66
  return wav
67
- center = 4500.0 # presence band
68
- q = 0.707 # wide-ish
69
  return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db)
70
 
71
 
72
  def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
73
- """Peak-normalize to target dBFS (default -1 dB)."""
74
  target_amp = 10.0 ** (target_dbfs / 20.0)
75
  peak = torch.max(torch.abs(wav)).item()
76
  if peak > 0:
@@ -98,35 +112,33 @@ def _enhance_numpy_audio(
98
  enh = _get_enhancer()
99
  wav_16k = _resample_torch(wav_t, sr_in, 16000)
100
 
101
- # Enhance via file path API for maximum compatibility
102
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
103
  sf.write(tmp_in.name, wav_16k.squeeze(0).numpy(), 16000, subtype="PCM_16")
104
  tmp_in.flush()
105
- # Enhance; returns torch.Tensor [1, T]
106
- clean = enh.enhance_file(tmp_in.name)
107
  try:
108
  os.remove(tmp_in.name)
109
  except Exception:
110
  pass
111
 
112
- # Optional polish: high-pass & presence EQ
113
  clean = _highpass(clean, 16000, lowcut_hz)
114
  clean = _presence_boost(clean, 16000, presence_db)
115
-
116
- # Peak limiting to avoid inter-sample clip
117
  clean = _limit_peak(clean, target_dbfs=-1.0)
118
 
119
- # Resample back
120
  sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
121
- clean_out = _resample_torch(clean, 16000, sr_out).squeeze(0).numpy().astype(
122
- np.float32
123
  )
124
 
125
  return sr_out, clean_out
126
 
127
 
128
  def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
129
- """Encode a mono float32 array as 16-bit PCM WAV into bytes."""
130
  buf = io.BytesIO()
131
  sf.write(buf, mono_f32, sr, subtype="PCM_16", format="WAV")
132
  buf.seek(0)
@@ -136,7 +148,7 @@ def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
136
  # -----------------------------
137
  # FastAPI app with raw endpoint
138
  # -----------------------------
139
- app = FastAPI(title="Voice Clarity Booster (MetricGAN+)", version="1.0.0")
140
 
141
 
142
  @app.post("/enhance")
@@ -148,7 +160,6 @@ async def enhance_endpoint(
148
  ):
149
  """Raw REST endpoint. Returns enhanced audio as audio/wav bytes."""
150
  data = await file.read()
151
- # Decode with soundfile
152
  wav_np, sr_in = sf.read(io.BytesIO(data), always_2d=False, dtype="float32")
153
  sr_out, enhanced = _enhance_numpy_audio(
154
  (sr_in, wav_np),
@@ -157,7 +168,9 @@ async def enhance_endpoint(
157
  out_sr=output_sr if output_sr > 0 else None,
158
  )
159
  wav_bytes = _wav_bytes(sr_out, enhanced)
160
- headers = {"Content-Disposition": f'attachment; filename="{os.path.splitext(file.filename or "audio")[0]}_enhanced.wav"'}
 
 
161
  return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers)
162
 
163
 
@@ -175,7 +188,6 @@ def gradio_enhance(
175
  out_sr = None
176
  if output_sr in {"44100", "48000"}:
177
  out_sr = int(output_sr)
178
- # "Original" -> None
179
  sr_out, enhanced = _enhance_numpy_audio(
180
  audio, presence_db=float(presence_db), lowcut_hz=float(lowcut_hz), out_sr=out_sr
181
  )
 
3
  import tempfile
4
  from typing import Tuple, Optional
5
 
6
+ # ---- tame noisy deprecation warnings (optional but nice) ----
7
+ import warnings
8
+ warnings.filterwarnings(
9
+ "ignore",
10
+ message=".*torchaudio._backend.list_audio_backends has been deprecated.*",
11
+ )
12
+ warnings.filterwarnings(
13
+ "ignore",
14
+ module=r"speechbrain\..*",
15
+ category=UserWarning,
16
+ )
17
+
18
  import gradio as gr
19
  import numpy as np
20
  import soundfile as sf
21
  import torch
22
  import torchaudio
23
+ from fastapi import FastAPI, File, UploadFile, Query
24
  from fastapi.responses import StreamingResponse
25
+
26
+ # ---- SpeechBrain import: prefer new API, fall back if older version ----
27
+ try:
28
+ # SpeechBrain >= 1.0
29
+ from speechbrain.inference import SpectralMaskEnhancement
30
+ except Exception: # pragma: no cover
31
+ # Older SpeechBrain (<1.0)
32
+ from speechbrain.pretrained import SpectralMaskEnhancement # type: ignore
33
+
34
 
35
  # -----------------------------
36
  # Model: SpeechBrain MetricGAN+
 
42
  def _get_enhancer() -> SpectralMaskEnhancement:
43
  global _ENHANCER
44
  if _ENHANCER is None:
 
45
  _ENHANCER = SpectralMaskEnhancement.from_hparams(
46
  source="speechbrain/metricgan-plus-voicebank",
47
  savedir="pretrained/metricgan_plus_voicebank",
 
54
  # Audio helpers
55
  # -----------------------------
56
  def _to_mono(wav: np.ndarray) -> np.ndarray:
57
+ """Ensure mono shape [T] float32."""
58
  if wav.ndim == 1:
59
  return wav.astype(np.float32)
60
+ # [T, C] or [C, T]
61
  if wav.shape[0] < wav.shape[1]:
 
62
  return wav.mean(axis=1).astype(np.float32)
63
+ return wav.mean(axis=0).astype(np.float32)
 
 
64
 
65
 
66
  def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
 
72
  def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
73
  if cutoff_hz is None or cutoff_hz <= 0:
74
  return wav
 
75
  return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz)
76
 
77
 
78
  def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
79
+ """Simple presence EQ around ~4.5 kHz."""
80
  if abs(gain_db) < 1e-6:
81
  return wav
82
+ center = 4500.0
83
+ q = 0.707
84
  return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db)
85
 
86
 
87
  def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
 
88
  target_amp = 10.0 ** (target_dbfs / 20.0)
89
  peak = torch.max(torch.abs(wav)).item()
90
  if peak > 0:
 
112
  enh = _get_enhancer()
113
  wav_16k = _resample_torch(wav_t, sr_in, 16000)
114
 
115
+ # Enhance via file path API for broad compatibility
116
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
117
  sf.write(tmp_in.name, wav_16k.squeeze(0).numpy(), 16000, subtype="PCM_16")
118
  tmp_in.flush()
119
+ clean = enh.enhance_file(tmp_in.name) # torch.Tensor [1, T]
120
+
121
  try:
122
  os.remove(tmp_in.name)
123
  except Exception:
124
  pass
125
 
126
+ # Optional polish: high-pass & presence EQ + peak limit
127
  clean = _highpass(clean, 16000, lowcut_hz)
128
  clean = _presence_boost(clean, 16000, presence_db)
 
 
129
  clean = _limit_peak(clean, target_dbfs=-1.0)
130
 
131
+ # Resample to requested output rate (or original)
132
  sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
133
+ clean_out = (
134
+ _resample_torch(clean, 16000, sr_out).squeeze(0).numpy().astype(np.float32)
135
  )
136
 
137
  return sr_out, clean_out
138
 
139
 
140
  def _wav_bytes(sr: int, mono_f32: np.ndarray) -> bytes:
141
+ """Encode mono float32 array as 16-bit PCM WAV bytes."""
142
  buf = io.BytesIO()
143
  sf.write(buf, mono_f32, sr, subtype="PCM_16", format="WAV")
144
  buf.seek(0)
 
148
  # -----------------------------
149
  # FastAPI app with raw endpoint
150
  # -----------------------------
151
+ app = FastAPI(title="Voice Clarity Booster (MetricGAN+)", version="1.0.1")
152
 
153
 
154
  @app.post("/enhance")
 
160
  ):
161
  """Raw REST endpoint. Returns enhanced audio as audio/wav bytes."""
162
  data = await file.read()
 
163
  wav_np, sr_in = sf.read(io.BytesIO(data), always_2d=False, dtype="float32")
164
  sr_out, enhanced = _enhance_numpy_audio(
165
  (sr_in, wav_np),
 
168
  out_sr=output_sr if output_sr > 0 else None,
169
  )
170
  wav_bytes = _wav_bytes(sr_out, enhanced)
171
+ headers = {
172
+ "Content-Disposition": f'attachment; filename="{os.path.splitext(file.filename or "audio")[0]}_enhanced.wav"'
173
+ }
174
  return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers)
175
 
176
 
 
188
  out_sr = None
189
  if output_sr in {"44100", "48000"}:
190
  out_sr = int(output_sr)
 
191
  sr_out, enhanced = _enhance_numpy_audio(
192
  audio, presence_db=float(presence_db), lowcut_hz=float(lowcut_hz), out_sr=out_sr
193
  )