File size: 5,181 Bytes
9f787c6
 
 
 
0228637
9f787c6
 
 
 
 
 
c2f3ea7
9f787c6
 
 
 
 
 
 
 
0228637
 
 
9f787c6
 
 
0228637
9f787c6
 
0228637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f787c6
 
 
 
 
0228637
 
 
 
 
 
9f787c6
 
 
 
0228637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f787c6
0228637
 
 
 
 
 
 
 
 
 
 
 
 
9f787c6
0228637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f787c6
0228637
9f787c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text  # 导入新的文本推理函数
import tempfile

# ===== Basic config =====
# Passed to `init` as use_cuda; True only when torch can see a CUDA device.
USE_CUDA = torch.cuda.is_available()
# Batch size handed to `init`; overridable via the BATCH_SIZE environment variable.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "12"))

# Tell the project's config loading where the config directory lives.
# NOTE(review): this is set after `train` has been imported; if `train` reads
# CONFIG_ROOT at import time it will not see this value — confirm.
os.environ.update(CONFIG_ROOT='./config')
# Hugging Face Hub location of the checkpoint; both parts can be overridden
# through the MODEL_REPO_ID / MODEL_FILENAME environment variables.
REPO_ID, FILENAME = (
    os.environ.get("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt"),
    os.environ.get("MODEL_FILENAME", "ev-pre-aug.ckpt"),
)

# ===== Download & load weights =====
# Resolve the checkpoint to a local file, then build the inference system once
# at import time; the module-level `system` is shared by every request.
ckpt_path = hf_hub_download(filename=FILENAME, repo_id=REPO_ID)
system = init(ckpt_path, use_cuda=USE_CUDA, batch_size=BATCH_SIZE)

# ===== Inference functions =====
def make_output_path(audio_path: str) -> str:
    """Return a fresh output path for the processed version of *audio_path*.

    The path is ``<tempdir>/enhanced_XXXX/<stem>_enhanced.wav``:

    * ``os.path.splitext`` drops whatever extension the input had, so
      non-``.wav`` or upper-case ``.WAV`` uploads still get an
      ``_enhanced.wav`` name. Only the final extension is affected.
    * Each call creates its own new directory (``tempfile.mkdtemp``), so two
      requests whose uploads share a file name never overwrite each other.
      The readable file name is kept for the download.

    NOTE(review): assumes ``inference_file``/``inference_file_text`` write WAV
    data. The original code only ever gave them ``.wav`` names for ``.wav``
    inputs — confirm for other formats.
    """
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    out_dir = tempfile.mkdtemp(prefix="enhanced_")
    return os.path.join(out_dir, f"{stem}_enhanced.wav")


def inference_audio(audio_path: str, query_audio_path: str):
    """Run inference on *audio_path* using an audio clip as the query.

    Args:
        audio_path: Path to the uploaded source audio file.
        query_audio_path: Path to the query audio file.

    Returns:
        Path of the output file written by ``inference_file``.
    """
    output_path = make_output_path(audio_path)
    # The query audio path is the 4th positional argument of inference_file.
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path


def inference_text(audio_path: str, query_text: str):
    """Run inference on *audio_path* using a text description as the query.

    Args:
        audio_path: Path to the uploaded source audio file.
        query_text: Text query forwarded to ``inference_file_text``.

    Returns:
        Path of the output file written by ``inference_file_text``.
    """
    output_path = make_output_path(audio_path)
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path

def inference(audio_path: str, query_input, query_type: str):
    """Validate a request and dispatch it to audio- or text-query inference.

    Args:
        audio_path: Path to the uploaded source audio; falsy if none was given.
        query_input: Query audio path when *query_type* is ``"audio"``,
            otherwise the query text.
        query_type: ``"audio"`` selects audio-query inference; any other value
            is treated as a text query.

    Returns:
        ``(output_path, status_message)``. ``output_path`` is ``None`` when
        validation fails, and the message then says what is missing.
    """
    if not audio_path:
        return None, "请先上传原音频文件"

    if query_type == "audio":
        if not query_input:
            return None, "请上传query音频文件"
        return inference_audio(audio_path, query_input), "✅ 使用音频query处理完成!"

    # Text query: a whitespace-only string carries no query. Reject it like an
    # empty one instead of sending it to the model.
    query_text = (query_input or "").strip()
    if not query_text:
        return None, "请输入query文本"
    return inference_text(audio_path, query_text), "✅ 使用文本query处理完成!"

# ===== Gradio UI =====
with gr.Blocks() as demo:
    # Header and usage instructions. The app defines no gr.Examples, so the
    # text must not point users at clickable samples.
    gr.Markdown(
        """
# 🎧 DCCRN Speech Enhancement (Demo)
**How to use:** 
1. 上传原音频文件
2. 选择query类型(音频或文字)并提供query内容
3. 点击 **Enhance** → 收听并下载结果
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # Source audio (required); callbacks receive it as a file path.
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="原音频输入 (必选)"
            )

            # Chooses which query widget is shown and which query is used.
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query类型"
            )

            # Query audio input (visible by default, matching value="audio").
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query音频输入",
                visible=True
            )

            # Query text input (hidden until "text" is selected).
            inp_query_text = gr.Textbox(
                label="Query文本输入",
                placeholder="请输入描述文本...",
                visible=False
            )

            enhance_btn = gr.Button("Enhance", variant="primary")

        # Right column: outputs.
        with gr.Column():
            # Result audio, with a download button on the player.
            out_audio = gr.Audio(
                label="增强后的音频输出",
                show_download_button=True
            )

            # Read-only status line, updated by each Enhance click.
            status = gr.Textbox(
                label="处理状态",
                value="就绪",
                interactive=False
            )

    def toggle_query_inputs(selected_type):
        """Show the query widget for *selected_type* and hide the other one.

        Returns updated (query audio, query text) components. The parameter is
        not named ``query_type``, so it does not shadow the Radio component.
        """
        show_audio = selected_type == "audio"
        return gr.Audio(visible=show_audio), gr.Textbox(visible=not show_audio)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text]
    )

    def process_inference(audio_path, selected_type, query_audio, query_text):
        """Handle an Enhance click.

        Picks the query audio path or the query text, depending on
        *selected_type*, and forwards it to ``inference``. Returns
        ``(output_path_or_None, status_message)`` for ``[out_audio, status]``.
        """
        query_input = query_audio if selected_type == "audio" else query_text
        return inference(audio_path, query_input, selected_type)

    enhance_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status]
    )

# Queue: keep a small queue to avoid OOM. At most 8 requests may wait, which
# limits how much work can pile up on the shared model.
demo.queue(max_size=8)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block in `launch()`.
if __name__ == "__main__":
    demo.launch()