Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from train import init, inference_file | |
| import tempfile | |
# ===== Basic config =====
# Use the GPU whenever PyTorch reports one is available.
USE_CUDA = torch.cuda.is_available()
# Inference batch size handed to init(); override with the BATCH_SIZE env var.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "12"))

# Hub repo and checkpoint filename, overridable via MODEL_REPO_ID / MODEL_FILENAME.
REPO_ID = os.environ.get("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.environ.get("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Download & load weights =====
# hf_hub_download returns a local filesystem path to the checkpoint, which is
# then handed to the project's init() to build the model used by inference().
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference =====
def inference(audio_path: str) -> str:
    """Enhance one audio file and return the path of the enhanced result.

    Args:
        audio_path: Local path of the input clip, as supplied by the
            ``gr.Audio(type="filepath")`` input component. It may be
            ``None``/empty when **Enhance** is clicked with no audio loaded.

    Returns:
        Path of the enhanced file, written inside a fresh temporary
        directory (not cleaned up here).

    Raises:
        gr.Error: If no input audio was provided; Gradio shows the message
            to the user instead of a generic error.
    """
    if not audio_path:
        raise gr.Error("Please upload or record an audio clip first.")

    # Insert "_enhanced" before the extension. A plain str.replace('.wav', ...)
    # would leave non-.wav names (e.g. .mp3, .WAV) unchanged and would also
    # rewrite any '.wav' appearing earlier in the filename.
    stem, ext = os.path.splitext(os.path.basename(audio_path))

    # One directory per request, so two uploads sharing a filename cannot
    # overwrite each other's result before it has been served.
    output_dir = tempfile.mkdtemp(prefix="enhanced_")
    output_path = os.path.join(output_dir, f"{stem}_enhanced{ext}")

    # NOTE(review): the input path is passed twice; what the 4th argument
    # means depends on train.inference_file's signature — confirm there.
    inference_file(system, audio_path, output_path, audio_path)
    return output_path
# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🎧 DCCRN Speech Enhancement (Demo)
**How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.
**Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
"""
    )

    with gr.Row():
        # type="filepath" makes inference() receive a local file path.
        inp = gr.Audio(
            sources=["upload", "microphone"],  # drag & drop supported by default
            type="filepath",
            label="Input: noisy speech (drag & drop or upload / record)",
        )
        out = gr.Audio(
            label="Output: enhanced speech (downloadable)",
            show_download_button=True,
        )

    enhance_btn = gr.Button("Enhance")

    # On-page sample clips (paths are relative to the working directory).
    # Only files that actually exist are kept, so a missing sample is skipped
    # instead of being handed to gr.Examples.
    sample_files = [
        path
        for path in (
            "examples/noisy_1.wav",
            "examples/noisy_2.wav",
            "examples/noisy_3.wav",
        )
        if os.path.isfile(path)
    ]
    if sample_files:
        # No fn/outputs: clicking a sample only fills the input component.
        gr.Examples(
            examples=[[path] for path in sample_files],
            inputs=inp,
            label="Sample audio",
            examples_per_page=3,
        )

    # Gradio ≥4.44: set concurrency on the event listener.
    # concurrency_limit=1 runs at most one enhancement at a time.
    enhance_btn.click(inference, inputs=inp, outputs=out, concurrency_limit=1)
# Bound the request queue to 16 waiting jobs (a small queue helps avoid OOM),
# then start the web server; queue() returns the same Blocks app.
demo.queue(max_size=16).launch()