Spaces:
Runtime error
Runtime error
File size: 2,240 Bytes
9f787c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from train import init, inference_file
import tempfile
# ===== Basic config =====
# Use the GPU whenever PyTorch reports one as available.
USE_CUDA = torch.cuda.is_available()
# Batch size handed to init(); override with the BATCH_SIZE env var.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "12"))

# Hugging Face Hub location of the checkpoint. Both the repo and the file
# name can be overridden through the environment.
REPO_ID = os.environ.get("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.environ.get("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Download & load weights =====
# Resolve the checkpoint to a local path (hf_hub_download caches it), then
# build the inference system once at import time so all requests share it.
ckpt_path = hf_hub_download(filename=FILENAME, repo_id=REPO_ID)
system = init(ckpt_path, use_cuda=USE_CUDA, batch_size=BATCH_SIZE)
# ===== Inference =====
def inference(audio_path: str) -> str:
    """Enhance one audio clip and return the path of the enhanced file.

    Args:
        audio_path: Filesystem path of the input clip, as delivered by the
            ``gr.Audio(type="filepath")`` input. Gradio passes ``None`` when
            *Enhance* is clicked without any audio loaded.

    Returns:
        Path of the ``<stem>_enhanced.wav`` file written by ``inference_file``.

    Raises:
        gr.Error: If no input audio was provided (shown to the user in the UI).
    """
    if not audio_path:
        # Show a readable message instead of a TypeError from os.path.basename(None).
        raise gr.Error("Please upload or record an audio clip first.")

    # Build "<stem>_enhanced.wav". splitext touches only the final extension,
    # so .mp3/.flac/.WAV inputs no longer keep their original name (the old
    # str.replace was a no-op for them), and a ".wav" elsewhere in the name
    # is left alone.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_filename = f"{stem}_enhanced.wav"

    # A fresh directory per request keeps two requests with the same input
    # filename from overwriting each other's output in the shared temp dir.
    output_dir = tempfile.mkdtemp(prefix="enhanced_")
    output_path = os.path.join(output_dir, output_filename)

    # NOTE(review): audio_path is passed as both the 2nd and 4th argument,
    # exactly as before; the 4th parameter's meaning lives in
    # train.inference_file -- confirm.
    inference_file(system, audio_path, output_path, audio_path)
    return output_path
# ===== Gradio UI =====
_INTRO_MD = """
# 🎧 DCCRN Speech Enhancement (Demo)
**How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
**Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
"""

# On-page sample clips, relative to the working directory. Only files that
# exist are offered, so a missing sample no longer breaks app startup.
_EXAMPLE_FILES = [
    "examples/noisy_1.wav",
    "examples/noisy_2.wav",
    "examples/noisy_3.wav",
]
_available_examples = [[path] for path in _EXAMPLE_FILES if os.path.isfile(path)]
for _missing in sorted(set(_EXAMPLE_FILES) - {row[0] for row in _available_examples}):
    print(f"[app] sample audio not found, skipping: {_missing}")

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)
    # NOTE(review): the original indentation was lost; this assumes the input
    # and output share one row, with the button and samples below it.
    with gr.Row():
        inp = gr.Audio(
            sources=["upload", "microphone"],  # drag & drop supported by default
            type="filepath",
            label="Input: noisy speech (drag & drop or upload / record)",
        )
        out = gr.Audio(
            label="Output: enhanced speech (downloadable)",
            show_download_button=True,
        )
    enhance_btn = gr.Button("Enhance")

    if _available_examples:
        gr.Examples(
            examples=_available_examples,
            inputs=inp,
            label="Sample audio",
            examples_per_page=3,
        )

    # Gradio >=4.44: concurrency is configured on the event listener.
    enhance_btn.click(inference, inputs=inp, outputs=out, concurrency_limit=1)

# Keep the request queue small to avoid running out of memory.
demo.queue(max_size=16)
demo.launch()