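"""Gradio demo for Language-Audio Banquet: separate a target source from a mixture, using either an audio clip or a text description as the query."""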
import os
import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from train import init, inference_file, inference_file_text
import tempfile
# ===== Basic config =====
USE_CUDA = torch.cuda.is_available()
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))
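# train.py is assumed to resolve its configuration files relative to CONFIG_ROOT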
os.environ['CONFIG_ROOT'] = './config'
# Read model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")
# ===== Download & load weights =====
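# hf_hub_download caches the checkpoint locally and returns its path; init() then
# loads the model once at import time, so every request shares the same instance.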
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str):
"""Perform inference using audio as query"""
temp_dir = tempfile.gettempdir()
output_filename = os.path.basename(audio_path).replace('.wav', '_inference.wav')
output_path = os.path.join(temp_dir, output_filename)
inference_file(system, audio_path, output_path, query_audio_path)
return output_path
def inference_text(audio_path: str, query_text: str):
"""Perform inference using text as query"""
temp_dir = tempfile.gettempdir()
output_filename = os.path.basename(audio_path).replace('.wav', '_inference.wav')
output_path = os.path.join(temp_dir, output_filename)
inference_file_text(system, audio_path, output_path, query_text)
return output_path
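# On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of this call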
@spaces.GPU
def inference(audio_path: str, query_input, query_type: str):
"""Unified inference function that selects method based on query type"""
if not audio_path:
return None, "Please upload the source audio file first"
if query_type == "audio":
if not query_input:
return None, "Please upload a query audio file"
        return inference_audio(audio_path, query_input), "✅ Processing with audio query completed!"
else: # text
if not query_input:
return None, "Please enter query text"
        return inference_text(audio_path, query_input), "✅ Processing with text query completed!"
# ===== Gradio UI =====
with gr.Blocks() as demo:
gr.Markdown(
"""
# 🎵 Language-Audio Banquet (Demo)
**How to use:**
1. Upload the source audio file
2. Select query type (audio or text) and provide query content
3. Click **Inference** → listen to and download the result
"""
)
with gr.Row():
with gr.Column():
# Source audio input
inp_audio = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Source Audio Input (Required)"
)
# Query type selection
query_type = gr.Radio(
choices=["audio", "text"],
value="audio",
label="Query Type"
)
# Query audio input (visible by default)
inp_query_audio = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Query Audio Input",
visible=True
)
# Query text input (hidden by default)
inp_query_text = gr.Textbox(
label="Query Text Input",
placeholder="Enter description text...",
visible=False
)
inference_btn = gr.Button("Inference", variant="primary")
with gr.Column():
# Output audio
out_audio = gr.Audio(
label="Separated Audio Output",
show_download_button=True
)
# Status display
status = gr.Textbox(
label="Processing Status",
value="Ready",
interactive=False
)
# Query type toggle interaction
def toggle_query_inputs(query_type):
"""Show/hide appropriate input components based on selected query type"""
if query_type == "audio":
return gr.Audio(visible=True), gr.Textbox(visible=False)
else:
return gr.Audio(visible=False), gr.Textbox(visible=True)
query_type.change(
toggle_query_inputs,
inputs=query_type,
outputs=[inp_query_audio, inp_query_text]
)
# Button click event
def process_inference(audio_path, query_type, query_audio, query_text):
"""Handle inference request"""
# Get appropriate query input based on type
query_input = query_audio if query_type == "audio" else query_text
# Call inference function
result, status_msg = inference(audio_path, query_input, query_type)
return result, status_msg
inference_btn.click(
process_inference,
inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
outputs=[out_audio, status]
)
# Examples section
gr.Examples(
examples=[
["examples/forget.mp3", "audio", "examples/forget_bass.mp3", ""],
["examples/forget.mp3", "text", None, "vocal"],
],
inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
label="Example Audios",
examples_per_page=2,
)
# Queue: keep a small queue to avoid OOM
demo.queue(max_size=8)
demo.launch()
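# Run locally with `python app.py`; on a Hugging Face Space this file is launched automatically.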