import os
import tempfile

import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download

from train import init, inference_file, inference_file_text

# ===== Basic config =====
USE_CUDA = torch.cuda.is_available()
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))
os.environ['CONFIG_ROOT'] = './config'

# Read model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")

# ===== Download & load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)


# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str):
    """Perform inference using audio as query"""
    temp_dir = tempfile.gettempdir()
    # Handle non-.wav inputs (e.g. the .mp3 examples) by swapping the extension
    base_name, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{base_name}_inference.wav")
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path


def inference_text(audio_path: str, query_text: str):
    """Perform inference using text as query"""
    temp_dir = tempfile.gettempdir()
    base_name, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{base_name}_inference.wav")
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path


@spaces.GPU
def inference(audio_path: str, query_input, query_type: str):
    """Unified inference function that selects method based on query type"""
    if not audio_path:
        return None, "Please upload the source audio file first"
    if query_type == "audio":
        if not query_input:
            return None, "Please upload a query audio file"
        return inference_audio(audio_path, query_input), "✅ Processing with audio query completed!"
    else:  # text
        if not query_input:
            return None, "Please enter query text"
        return inference_text(audio_path, query_input), "✅ Processing with text query completed!"


# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🎵 Language-Audio Banquet (Demo)

        **How to use:**
        1. Upload the source audio file
        2. Select a query type (audio or text) and provide the query content
        3. Click **Inference** → listen to and download the result
        """
    )

    with gr.Row():
        with gr.Column():
            # Source audio input
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Source Audio Input (Required)"
            )

            # Query type selection
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query Type"
            )

            # Query audio input (visible by default)
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query Audio Input",
                visible=True
            )

            # Query text input (hidden by default)
            inp_query_text = gr.Textbox(
                label="Query Text Input",
                placeholder="Enter description text...",
                info=(
                    "Examples: 'bass', 'vocal', 'drum kit'. A more detailed description "
                    "such as 'Lead Female Vox' is also accepted, but the result may be "
                    "less accurate and varies greatly with the choice of words."
                ),
                visible=False
            )

            inference_btn = gr.Button("Inference", variant="primary")

        with gr.Column():
            # Output audio
            out_audio = gr.Audio(
                label="Separated Audio Output",
                show_download_button=True
            )

            # Status display
            status = gr.Textbox(
                label="Processing Status",
                value="Ready",
                interactive=False
            )

    # Query type toggle interaction
    def toggle_query_inputs(query_type):
        """Show/hide the appropriate input component based on the selected query type"""
        if query_type == "audio":
            return gr.Audio(visible=True), gr.Textbox(visible=False)
        else:
            return gr.Audio(visible=False), gr.Textbox(visible=True)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text]
    )

    # Button click event
    def process_inference(audio_path, query_type, query_audio, query_text):
        """Handle an inference request from the UI"""
        # Pick the appropriate query input based on the selected type
        query_input = query_audio if query_type == "audio" else query_text
        # Call the unified inference function
        result, status_msg = inference(audio_path, query_input, query_type)
        return result, status_msg

    inference_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status]
    )

    # Examples section
    gr.Examples(
        examples=[
            ["examples/forget.mp3", "audio", "examples/forget_bass.mp3", ""],
            ["examples/forget.mp3", "text", None, "vocal"],
            ["examples/sonata.mp3", "audio", "examples/sonata_violin.mp3", ""],
            ["examples/sonata.mp3", "text", None, "grand piano"],
        ],
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        label="Example Audios",
        examples_per_page=4,
    )

# Queue: keep a small queue to avoid OOM
demo.queue(max_size=8)
demo.launch()
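
# --- Optional: quick local smoke test without the UI (illustrative sketch) ---
# Assumption: the bundled example files exist and `system` was loaded above;
# the path and the "vocal" prompt are placeholders, not part of the Space.
# If used, place the calls before `demo.launch()`, which blocks the process.
#
# result_path = inference_text("examples/forget.mp3", "vocal")
# print(f"Separated stem written to: {result_path}")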