File size: 5,440 Bytes
9f787c6
 
 
fe47004
9f787c6
8a7cae1
9f787c6
 
 
 
 
 
c2f3ea7
9f787c6
8a7cae1
9f787c6
 
 
 
 
 
0228637
 
8a7cae1
9f787c6
726db2f
9f787c6
8a7cae1
9f787c6
 
0228637
8a7cae1
0228637
726db2f
0228637
8a7cae1
0228637
 
fe47004
0228637
8a7cae1
0228637
8a7cae1
0228637
 
 
8a7cae1
 
0228637
 
8a7cae1
 
0228637
9f787c6
 
 
 
726db2f
0228637
8a7cae1
 
 
9f787c6
 
 
 
0228637
8a7cae1
0228637
 
 
8a7cae1
0228637
 
8a7cae1
0228637
 
 
8a7cae1
0228637
 
8a7cae1
0228637
 
 
8a7cae1
0228637
 
 
8a7cae1
0228637
8a7cae1
 
0228637
 
 
726db2f
0228637
 
8a7cae1
0228637
8a7cae1
0228637
 
 
8a7cae1
0228637
8a7cae1
 
0228637
 
9f787c6
8a7cae1
0228637
8a7cae1
0228637
 
 
 
 
 
 
 
 
 
9f787c6
8a7cae1
0228637
8a7cae1
 
0228637
 
8a7cae1
0228637
 
 
726db2f
0228637
 
 
 
74c4f1d
8a7cae1
74c4f1d
 
 
9a2a578
74c4f1d
 
8a7cae1
74c4f1d
 
0228637
9f787c6
8a7cae1
9f787c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text
import tempfile

# ===== Basic config =====
USE_CUDA = torch.cuda.is_available()
# Inference batch size; overridable through the BATCH_SIZE environment variable.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "12"))

# Point the project code at the local ./config directory. This is set before
# init() runs below (presumably init reads it — TODO confirm in train.py).
os.environ["CONFIG_ROOT"] = "./config"

# Checkpoint location on the Hugging Face Hub; both parts can be overridden
# through environment variables without editing this file.
REPO_ID = os.environ.get("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.environ.get("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Download & load weights =====
# hf_hub_download fetches the file (or reuses the local cache) and returns its
# local path, which is then handed to the project's init() loader.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)

# ===== Inference functions =====
def _output_path_for(audio_path: str) -> str:
    """Return a fresh, collision-free path for the result of processing *audio_path*.

    The path is ``<tmp>/<unique dir>/<stem>_inference.wav``:

    * the source file's stem is kept so the downloaded result has a
      recognizable name;
    * the extension is always ``.wav``, even for ``.mp3`` or other sources.
      Plain ``str.replace('.wav', ...)`` left those names unchanged and also
      rewrote every ``.wav`` occurrence inside the name;
    * every call gets its own ``tempfile.mkdtemp`` directory, so two requests
      whose sources share a basename never overwrite each other's output.
    """
    stem, _ext = os.path.splitext(os.path.basename(audio_path))
    out_dir = tempfile.mkdtemp(prefix="lab_inference_")
    return os.path.join(out_dir, f"{stem}_inference.wav")

def inference_audio(audio_path: str, query_audio_path: str) -> str:
    """Perform inference using audio as query.

    Args:
        audio_path: Path to the source audio file.
        query_audio_path: Path to the query audio clip.

    Returns:
        Path of the output file written by ``inference_file``.
    """
    output_path = _output_path_for(audio_path)
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path

def inference_text(audio_path: str, query_text: str) -> str:
    """Perform inference using text as query.

    Args:
        audio_path: Path to the source audio file.
        query_text: Free-text description used as the query.

    Returns:
        Path of the output file written by ``inference_file_text``.
    """
    output_path = _output_path_for(audio_path)
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path

@spaces.GPU
def inference(audio_path: str, query_input, query_type: str):
    """Unified inference entry point that selects the method based on query type.

    Decorated with ``spaces.GPU`` so that, on Hugging Face Spaces, a GPU is
    requested for the duration of the call.

    Args:
        audio_path: Path to the source audio; falsy when nothing was uploaded.
        query_input: Path to the query audio when ``query_type == "audio"``,
            otherwise the query text.
        query_type: ``"audio"`` selects the audio-query path; any other value
            is treated as a text query.

    Returns:
        A ``(output_path_or_None, status_message)`` tuple. ``None`` is returned
        together with a hint message when a required input is missing.
    """
    if not audio_path:
        return None, "Please upload the source audio file first"

    if query_type == "audio":
        if not query_input:
            return None, "Please upload a query audio file"
        return inference_audio(audio_path, query_input), "✅ Processing with audio query completed!"

    # Text query: whitespace-only input is as good as empty, so reject it
    # instead of sending a blank prompt to the model.
    query_text = (query_input or "").strip()
    if not query_text:
        return None, "Please enter query text"
    return inference_text(audio_path, query_text), "✅ Processing with text query completed!"

# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🎡 Language-Audio Banquet (Demo)
**How to use:** 
1. Upload the source audio file
2. Select query type (audio or text) and provide query content
3. Click **Inference** → Listen to and download the result
        """
    )

    with gr.Row():
        with gr.Column():
            # Source audio input
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Source Audio Input (Required)"
            )

            # Query type selection
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query Type"
            )

            # Query audio input (visible by default, matching value="audio" above)
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query Audio Input",
                visible=True
            )

            # Query text input (hidden until the "text" query type is selected)
            inp_query_text = gr.Textbox(
                label="Query Text Input",
                placeholder="Enter description text...",
                visible=False
            )

            inference_btn = gr.Button("Inference", variant="primary")

        with gr.Column():
            # Output audio
            out_audio = gr.Audio(
                label="Separated Audio Output",
                show_download_button=True
            )

            # Status display
            status = gr.Textbox(
                label="Processing Status",
                value="Ready",
                interactive=False
            )

    # Query type toggle interaction
    def toggle_query_inputs(selected_type):
        """Show the query widget matching *selected_type* and hide the other.

        Returns visibility updates for (inp_query_audio, inp_query_text), in
        the order listed in the ``outputs`` of the ``change`` listener below.
        """
        show_audio = selected_type == "audio"
        return gr.Audio(visible=show_audio), gr.Textbox(visible=not show_audio)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text]
    )

    # Button click event
    def process_inference(audio_path, selected_type, query_audio, query_text):
        """Handle an Inference click by forwarding to `inference`.

        Both query widgets are always passed in. Only the one that matches the
        selected query type is forwarded as the query. The
        ``(output_path, status_message)`` tuple is returned unchanged for the
        output audio and status components.
        """
        query_input = query_audio if selected_type == "audio" else query_text
        return inference(audio_path, query_input, selected_type)

    inference_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status]
    )

    # Examples section: each row fills
    # (source audio, query type, query audio, query text).
    gr.Examples(
        examples=[
            ["examples/forget.mp3", "audio", "examples/forget_bass.mp3", ""],
            ["examples/forget.mp3", "text", None, "vocal"],
        ],
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        label="Example Audios",
        examples_per_page=2,
    )

# Queue: cap the number of waiting requests (max_size=8) so a burst of traffic
# does not pile up work on the model (the original intent: avoid OOM).
demo.queue(max_size=8)
demo.launch()