Spaces:
Runtime error
Runtime error
File size: 5,181 Bytes
9f787c6 0228637 9f787c6 c2f3ea7 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 0228637 9f787c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text  # import the new text-based inference entry point
import tempfile
# ===== Basic config =====
# Module-level setup: runs once at import time (downloads the checkpoint and
# builds the model), so the `system` handle below is shared by all requests.
USE_CUDA = torch.cuda.is_available()
# Inference batch size; overridable through the BATCH_SIZE env var.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))
# Must be set before init() — presumably init() resolves its model configs
# relative to CONFIG_ROOT; TODO confirm against train.init.
os.environ['CONFIG_ROOT'] = './config'
# Read model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")
# ===== Download & load weights =====
# hf_hub_download caches locally, so repeated restarts skip the download.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
# Global model handle used by the inference functions below.
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str) -> str:
    """Run enhancement on `audio_path` using a reference audio clip as the query.

    Args:
        audio_path: Path of the input audio file to enhance.
        query_audio_path: Path of the query/reference audio file.

    Returns:
        Path of the enhanced audio written into the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Build "<stem>_enhanced.wav" via splitext. The previous
    # str.replace('.wav', '_enhanced.wav') was a no-op for non-.wav uploads
    # (e.g. browser-recorded files), yielding an output name identical to the
    # input, and it replaced the first '.wav' occurrence anywhere in the name.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, stem + '_enhanced.wav')
    # Second positional argument after the output is the query path.
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path
def inference_text(audio_path: str, query_text: str) -> str:
    """Run enhancement on `audio_path` using a free-text description as the query.

    Args:
        audio_path: Path of the input audio file to enhance.
        query_text: Natural-language query describing the target audio.

    Returns:
        Path of the enhanced audio written into the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Build "<stem>_enhanced.wav" via splitext. The previous
    # str.replace('.wav', '_enhanced.wav') was a no-op for non-.wav uploads,
    # yielding an output name identical to the input, and replaced the first
    # '.wav' occurrence anywhere in the name.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, stem + '_enhanced.wav')
    # Text-query variant of the inference entry point.
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path
def inference(audio_path: str, query_input, query_type: str):
    """Unified inference entry point that dispatches on the query type.

    Args:
        audio_path: Path of the source audio to enhance; required.
        query_input: Either a query-audio filepath or a query text string,
            depending on `query_type`.
        query_type: "audio" for an audio query, anything else means text.

    Returns:
        A (result_path_or_None, status_message) pair for the UI.
    """
    # Guard: the source audio is mandatory.
    if not audio_path:
        return None, "请先上传原音频文件"
    is_audio_query = query_type == "audio"
    # Guard: the query payload matching the selected type is mandatory.
    if not query_input:
        return None, ("请上传query音频文件" if is_audio_query else "请输入query文本")
    # Dispatch to the matching backend and report success.
    run = inference_audio if is_audio_query else inference_text
    done_msg = "✅ 使用音频query处理完成!" if is_audio_query else "✅ 使用文本query处理完成!"
    return run(audio_path, query_input), done_msg
# ===== Gradio UI =====
# Layout: left column = inputs (source audio, query type, query payload),
# right column = outputs (enhanced audio, status text).
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🎧 DCCRN Speech Enhancement (Demo)
        **How to use:**
        1. 上传原音频文件
        2. 选择query类型(音频或文字)并提供query内容
        3. 点击 **Enhance** → 收听并下载结果
        **Sample audio:** 点击下面的样例自动填充输入
        """
    )
    with gr.Row():
        with gr.Column():
            # Source audio input (required)
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="原音频输入 (必选)"
            )
            # Query type selector: reference audio vs. free text
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query类型"
            )
            # Query audio input (visible by default, matching the Radio default)
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query音频输入",
                visible=True
            )
            # Query text input (hidden by default)
            inp_query_text = gr.Textbox(
                label="Query文本输入",
                placeholder="请输入描述文本...",
                visible=False
            )
            enhance_btn = gr.Button("Enhance", variant="primary")
        with gr.Column():
            # Enhanced audio output
            out_audio = gr.Audio(
                label="增强后的音频输出",
                show_download_button=True
            )
            # Status display (read-only)
            status = gr.Textbox(
                label="处理状态",
                value="就绪",
                interactive=False
            )

    # Interaction logic for switching the query type
    def toggle_query_inputs(query_type):
        """Show/hide the input component matching the selected query type."""
        if query_type == "audio":
            return gr.Audio(visible=True), gr.Textbox(visible=False)
        else:
            return gr.Audio(visible=False), gr.Textbox(visible=True)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text]
    )

    # Enhance-button click handler
    def process_inference(audio_path, query_type, query_audio, query_text):
        """Handle one inference request from the UI."""
        # Pick the query payload that matches the selected query type
        query_input = query_audio if query_type == "audio" else query_text
        # Delegate to the unified inference dispatcher
        result, status_msg = inference(audio_path, query_input, query_type)
        return result, status_msg

    enhance_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status]
    )
# Queue: keep a small queue to avoid OOM
# (reduced size — text/audio-query inference is resource-heavy).
demo.queue(max_size=8)
# NOTE: the stray trailing "|" after launch() in the pasted source was scrape
# residue and would be a SyntaxError; removed.
demo.launch()