# Hugging Face Spaces demo app.
import os
import tempfile

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Project-local inference entry points (audio-query and text-query variants).
from train import init, inference_file, inference_file_text
# ===== Basic config =====
# Point the project's config loader at the bundled config directory
# before the model system is constructed.
os.environ['CONFIG_ROOT'] = './config'

USE_CUDA = torch.cuda.is_available()
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))

# Checkpoint location on the Hugging Face Hub; both parts can be
# overridden through environment variables for easy redeployment.
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Download & load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
| # ===== Inference functions ===== | |
def inference_audio(audio_path: str, query_audio_path: str):
    """Run inference using an audio clip as the query.

    Args:
        audio_path: Path to the input audio file.
        query_audio_path: Path to the query audio file.

    Returns:
        Path to the enhanced output file in the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Use splitext so any input extension (.mp3, .flac, ...) works.
    # The previous str.replace('.wav', ...) produced an output name
    # identical to the input for non-.wav files and also replaced every
    # '.wav' occurrence inside the name.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    output_path = os.path.join(temp_dir, f"{stem}_enhanced.wav")
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path
def inference_text(audio_path: str, query_text: str):
    """Run inference using free text as the query.

    Args:
        audio_path: Path to the input audio file.
        query_text: Natural-language description used as the query.

    Returns:
        Path to the enhanced output file in the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Use splitext so any input extension (.mp3, .flac, ...) works.
    # The previous str.replace('.wav', ...) produced an output name
    # identical to the input for non-.wav files and also replaced every
    # '.wav' occurrence inside the name.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    output_path = os.path.join(temp_dir, f"{stem}_enhanced.wav")
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path
def inference(audio_path: str, query_input, query_type: str):
    """Dispatch to the audio- or text-query pipeline.

    Returns:
        A ``(output_path, status_message)`` pair; ``output_path`` is
        ``None`` when a required input is missing.
    """
    # Guard: the source audio is mandatory regardless of query type.
    if not audio_path:
        return None, "请先上传原音频文件"

    use_audio_query = query_type == "audio"

    # Guard: the active query input must be provided.
    if not query_input:
        missing_msg = "请上传query音频文件" if use_audio_query else "请输入query文本"
        return None, missing_msg

    if use_audio_query:
        return inference_audio(audio_path, query_input), "✅ 使用音频query处理完成!"
    return inference_text(audio_path, query_input), "✅ 使用文本query处理完成!"
| # ===== Gradio UI ===== | |
# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🎧 DCCRN Speech Enhancement (Demo)
**How to use:**
1. 上传原音频文件
2. 选择query类型(音频或文字)并提供query内容
3. 点击 **Enhance** → 收听并下载结果
**Sample audio:** 点击下面的样例自动填充输入
"""
    )

    with gr.Row():
        with gr.Column():
            # Mandatory source audio (upload or microphone).
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="原音频输入 (必选)",
            )
            # Selector that decides which query widget is active.
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query类型",
            )
            # Audio query widget — visible by default.
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query音频输入",
                visible=True,
            )
            # Text query widget — hidden until "text" is selected.
            inp_query_text = gr.Textbox(
                label="Query文本输入",
                placeholder="请输入描述文本...",
                visible=False,
            )
            enhance_btn = gr.Button("Enhance", variant="primary")

        with gr.Column():
            # Enhanced result, downloadable by the user.
            out_audio = gr.Audio(
                label="增强后的音频输出",
                show_download_button=True,
            )
            # Read-only status line for progress / validation messages.
            status = gr.Textbox(
                label="处理状态",
                value="就绪",
                interactive=False,
            )

    def toggle_query_inputs(selected_type):
        """Show the query widget matching the selected type; hide the other."""
        show_audio = selected_type == "audio"
        return gr.Audio(visible=show_audio), gr.Textbox(visible=not show_audio)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text],
    )

    def process_inference(audio_path, selected_type, query_audio, query_text):
        """Pick the active query input and delegate to `inference`."""
        active_query = query_audio if selected_type == "audio" else query_text
        return inference(audio_path, active_query, selected_type)

    enhance_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status],
    )
# ===== Launch =====
# A small request queue caps concurrent jobs so the model does not OOM.
demo.queue(max_size=8)
demo.launch()