Jihuai's picture
download querier from hf directly; remove examples
7c3f87e
raw
history blame
5.18 kB
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text # 导入新的文本推理函数
import tempfile
# ===== Runtime configuration =====
# Export the config directory through the environment before the model is
# built below (presumably read by train.init — confirm against train.py).
os.environ["CONFIG_ROOT"] = "./config"

# Use the GPU whenever torch reports one; batch size is overridable via env.
USE_CUDA = torch.cuda.is_available()
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "12"))

# Checkpoint location on the Hugging Face Hub; both parts overridable via env.
REPO_ID = os.environ.get("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.environ.get("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Fetch the checkpoint and build the inference system =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str) -> str:
    """Run inference on ``audio_path`` using an audio clip as the query.

    Args:
        audio_path: Path of the input audio file to process.
        query_audio_path: Path of the audio file used as the query.

    Returns:
        Path of the output file that was handed to ``inference_file``.
    """
    # splitext copes with any extension (".mp3", ".WAV", none at all) and only
    # touches the real suffix; a plain str.replace('.wav', ...) left the name
    # unchanged for non-".wav" inputs and rewrote every ".wav" occurrence.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    # A fresh directory per call keeps requests that share an upload name from
    # overwriting each other's result in the shared temp dir. The directory is
    # not removed here because the caller still needs the file afterwards.
    output_dir = tempfile.mkdtemp(prefix="enhanced_")
    output_path = os.path.join(output_dir, f"{stem}_enhanced.wav")
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path
def inference_text(audio_path: str, query_text: str) -> str:
    """Run inference on ``audio_path`` using a text description as the query.

    Args:
        audio_path: Path of the input audio file to process.
        query_text: Text used as the query.

    Returns:
        Path of the output file that was handed to ``inference_file_text``.
    """
    # splitext copes with any extension (".mp3", ".WAV", none at all) and only
    # touches the real suffix; a plain str.replace('.wav', ...) left the name
    # unchanged for non-".wav" inputs and rewrote every ".wav" occurrence.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    # A fresh directory per call keeps requests that share an upload name from
    # overwriting each other's result in the shared temp dir. The directory is
    # not removed here because the caller still needs the file afterwards.
    output_dir = tempfile.mkdtemp(prefix="enhanced_")
    output_path = os.path.join(output_dir, f"{stem}_enhanced.wav")
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path
def inference(audio_path: str, query_input, query_type: str):
    """Validate the inputs and dispatch to the audio- or text-query path.

    Args:
        audio_path: Path of the input audio file; falsy when nothing was given.
        query_input: Query audio file path when ``query_type`` is ``"audio"``,
            otherwise the query text.
        query_type: ``"audio"`` selects the audio-query path; any other value
            is handled as a text query.

    Returns:
        A ``(output_path, status_message)`` tuple. ``output_path`` is ``None``
        when validation fails and the message says which input is missing.
    """
    if not audio_path:
        return None, "请先上传原音频文件"

    if query_type == "audio":
        if not query_input:
            return None, "请上传query音频文件"
        return inference_audio(audio_path, query_input), "✅ 使用音频query处理完成!"

    # Text query: a whitespace-only string carries no description, so reject
    # it the same way as an empty one instead of sending it to the model.
    if not query_input or not str(query_input).strip():
        return None, "请输入query文本"
    return inference_text(audio_path, query_input), "✅ 使用文本query处理完成!"
# ===== Gradio UI =====
with gr.Blocks() as demo:
    # The former "Sample audio" hint was dropped from the header: this layout
    # has no examples component, so the hint pointed at nothing.
    # NOTE(review): the title still says "DCCRN Speech Enhancement" while the
    # default checkpoint repo is "Language-Audio-Banquet-ckpt" — confirm the
    # intended title.
    # NOTE(review): the text is indented for readability and relies on
    # gr.Markdown dedenting its value — confirm for the pinned Gradio version.
    gr.Markdown(
        """
        # 🎧 DCCRN Speech Enhancement (Demo)
        **How to use:**
        1. 上传原音频文件
        2. 选择query类型(音频或文字)并提供query内容
        3. 点击 **Enhance** → 收听并下载结果
        """
    )

    with gr.Row():
        with gr.Column():
            # Source audio to be processed (required).
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="原音频输入 (必选)",
            )
            # Selects which kind of query the user provides.
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query类型",
            )
            # Audio query input (shown by default).
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query音频输入",
                visible=True,
            )
            # Text query input (hidden by default).
            inp_query_text = gr.Textbox(
                label="Query文本输入",
                placeholder="请输入描述文本...",
                visible=False,
            )
            enhance_btn = gr.Button("Enhance", variant="primary")

        with gr.Column():
            # Resulting audio.
            out_audio = gr.Audio(
                label="增强后的音频输出",
                show_download_button=True,
            )
            # Read-only status line.
            status = gr.Textbox(
                label="处理状态",
                value="就绪",
                interactive=False,
            )

    def toggle_query_inputs(selected_type):
        """Show the query input matching ``selected_type`` and hide the other.

        Returns updates for ``(inp_query_audio, inp_query_text)`` in that order.
        """
        show_audio = selected_type == "audio"
        return gr.Audio(visible=show_audio), gr.Textbox(visible=not show_audio)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text],
    )

    def process_inference(audio_path, selected_type, query_audio, query_text):
        """Pick the active query value and run ``inference`` on it.

        Returns the ``(output_path, status_message)`` pair from ``inference``.
        """
        # Only the input belonging to the selected query type is forwarded.
        query_input = query_audio if selected_type == "audio" else query_text
        return inference(audio_path, query_input, selected_type)

    enhance_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status],
    )
# Queue: keep a small queue to avoid OOM
demo.queue(max_size=8)  # queue size reduced because each request now needs more resources
demo.launch()