# Hugging Face Space export header (page chrome, preserved as a comment):
# uploader: Jihuai — "to English" — commit 8a7cae1 — raw / history / blame — 5.44 kB
import os
import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text
import tempfile
# ===== Basic config =====
USE_CUDA = torch.cuda.is_available()  # run inference on GPU when one is present
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))  # inference batch size, overridable via env
os.environ['CONFIG_ROOT'] = './config'  # presumably read by train.init for its config files — TODO confirm
# Read model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")
# ===== Download & load weights =====
# Fetch the checkpoint from the Hugging Face Hub (cached locally between runs)
# and build the inference system once at import time; `system` is shared by
# every request handled by this process.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str) -> str:
    """Separate the source audio using another audio clip as the query.

    Args:
        audio_path: Path to the source (mixture) audio file.
        query_audio_path: Path to the query audio file.

    Returns:
        Path to the separated audio written to the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Use splitext so any input extension works (.mp3, .flac, ...). The old
    # str.replace('.wav', '_inference.wav') left non-.wav inputs (e.g. the
    # bundled .mp3 examples) without the "_inference" suffix, so the output
    # path reused the input's own basename.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{stem}_inference.wav")
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path
def inference_text(audio_path: str, query_text: str) -> str:
    """Separate the source audio using a text description as the query.

    Args:
        audio_path: Path to the source (mixture) audio file.
        query_text: Free-text description of the target source (e.g. "vocal").

    Returns:
        Path to the separated audio written to the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Use splitext so any input extension works (.mp3, .flac, ...). The old
    # str.replace('.wav', '_inference.wav') left non-.wav inputs (e.g. the
    # bundled .mp3 examples) without the "_inference" suffix, so the output
    # path reused the input's own basename.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{stem}_inference.wav")
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path
@spaces.GPU
def inference(audio_path: str, query_input, query_type: str):
    """Dispatch to audio- or text-query inference based on *query_type*.

    Returns a `(result_path_or_None, status_message)` pair suitable for the
    Gradio output components.
    """
    # Guard: a source mixture is always required.
    if not audio_path:
        return None, "Please upload the source audio file first"

    wants_audio_query = query_type == "audio"

    # Guard: the active query input must be provided.
    if not query_input:
        missing_msg = (
            "Please upload a query audio file"
            if wants_audio_query
            else "Please enter query text"
        )
        return None, missing_msg

    if wants_audio_query:
        return inference_audio(audio_path, query_input), "βœ… Processing with audio query completed!"
    return inference_text(audio_path, query_input), "βœ… Processing with text query completed!"
# ===== Gradio UI =====
with gr.Blocks() as demo:
    # Intro / usage instructions rendered at the top of the page.
    gr.Markdown(
        """
        # 🎡 Language-Audio Banquet (Demo)
        **How to use:**
        1. Upload the source audio file
        2. Select query type (audio or text) and provide query content
        3. Click **Inference** β†’ Listen to and download the result
        """
    )
    with gr.Row():
        with gr.Column():
            # Source audio input: the mixture to be separated.
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",  # handlers receive file paths, not arrays
                label="Source Audio Input (Required)"
            )
            # Query type selection; drives which query widget is visible below.
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query Type"
            )
            # Query audio input (visible by default, matching the radio default)
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query Audio Input",
                visible=True
            )
            # Query text input (hidden until "text" is selected)
            inp_query_text = gr.Textbox(
                label="Query Text Input",
                placeholder="Enter description text...",
                visible=False
            )
            inference_btn = gr.Button("Inference", variant="primary")
        with gr.Column():
            # Output audio: the separated result, downloadable by the user.
            out_audio = gr.Audio(
                label="Separated Audio Output",
                show_download_button=True
            )
            # Status display for validation errors and success messages.
            status = gr.Textbox(
                label="Processing Status",
                value="Ready",
                interactive=False
            )
# Query type toggle interaction
def toggle_query_inputs(query_type):
"""Show/hide appropriate input components based on selected query type"""
if query_type == "audio":
return gr.Audio(visible=True), gr.Textbox(visible=False)
else:
return gr.Audio(visible=False), gr.Textbox(visible=True)
query_type.change(
toggle_query_inputs,
inputs=query_type,
outputs=[inp_query_audio, inp_query_text]
)
# Button click event
def process_inference(audio_path, query_type, query_audio, query_text):
"""Handle inference request"""
# Get appropriate query input based on type
query_input = query_audio if query_type == "audio" else query_text
# Call inference function
result, status_msg = inference(audio_path, query_input, query_type)
return result, status_msg
inference_btn.click(
process_inference,
inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
outputs=[out_audio, status]
)
# Examples section
gr.Examples(
examples=[
["examples/forget.mp3", "audio", "examples/forget_bass.mp3", ""],
["examples/forget.mp3", "text", None, "vocal"],
],
inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
label="Example Audios",
examples_per_page=2,
)
# Queue: keep a small queue to avoid OOM
demo.queue(max_size=8)
demo.launch()