# Hugging Face Spaces demo app (page-status lines from the web scrape removed).
| import os | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| from train import init, inference_file, inference_file_text | |
| import tempfile | |
# ===== Basic config =====
# Run on GPU when one is available; otherwise inference falls back to CPU.
USE_CUDA = torch.cuda.is_available()
# Inference batch size; overridable via env var to fit available memory.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "12"))
# NOTE(review): presumably train.init()/the model code resolves its config
# files relative to CONFIG_ROOT — confirm against train.py.
os.environ['CONFIG_ROOT'] = './config'
# Read model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "chenxie95/Language-Audio-Banquet-ckpt")
FILENAME = os.getenv("MODEL_FILENAME", "ev-pre-aug.ckpt")
# ===== Download & load weights =====
# Fetch the checkpoint from the HF Hub (locally cached) and build the
# inference system once at import time; it is shared by all requests below.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
system = init(ckpt_path, batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
# ===== Inference functions =====
def inference_audio(audio_path: str, query_audio_path: str) -> str:
    """Separate the source audio using another audio clip as the query.

    Args:
        audio_path: Path to the source mixture audio file.
        query_audio_path: Path to the query audio file.

    Returns:
        Path of the separated audio written into the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # Use splitext so any input extension works (.mp3, .flac, ...).
    # The previous str.replace('.wav', ...) silently did nothing for
    # non-.wav inputs — including the bundled .mp3 examples — leaving the
    # output with no '_inference' suffix and a misleading extension.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{stem}_inference.wav")
    inference_file(system, audio_path, output_path, query_audio_path)
    return output_path
def inference_text(audio_path: str, query_text: str) -> str:
    """Separate the source audio using a text description as the query.

    Args:
        audio_path: Path to the source mixture audio file.
        query_text: Natural-language description of the target source.

    Returns:
        Path of the separated audio written into the system temp directory.
    """
    temp_dir = tempfile.gettempdir()
    # splitext handles any extension; the old .replace('.wav', ...) failed
    # silently for non-.wav inputs such as the bundled .mp3 examples.
    stem, _ = os.path.splitext(os.path.basename(audio_path))
    output_path = os.path.join(temp_dir, f"{stem}_inference.wav")
    inference_file_text(system, audio_path, output_path, query_text)
    return output_path
def inference(audio_path: str, query_input, query_type: str):
    """Validate inputs, then dispatch to audio- or text-query inference.

    Returns a ``(result_path, status_message)`` pair; ``result_path`` is
    ``None`` when validation fails.
    """
    # Guard: a source mixture is always required.
    if not audio_path:
        return None, "Please upload the source audio file first"

    if query_type == "audio":
        if query_input:
            return inference_audio(audio_path, query_input), "β Processing with audio query completed!"
        return None, "Please upload a query audio file"

    # Any non-"audio" query type is treated as a text query.
    if query_input:
        return inference_text(audio_path, query_input), "β Processing with text query completed!"
    return None, "Please enter query text"
# ===== Gradio UI =====
# Builds the demo page: source-audio input, an audio/text query selector that
# toggles between two query widgets, and an output panel with status text.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # π΅ Language-Audio Banquet (Demo)
        **How to use:**
        1. Upload the source audio file
        2. Select query type (audio or text) and provide query content
        3. Click **Inference** β Listen to and download the result
        """
    )
    with gr.Row():
        with gr.Column():
            # Source audio input
            inp_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Source Audio Input (Required)"
            )
            # Query type selection
            query_type = gr.Radio(
                choices=["audio", "text"],
                value="audio",
                label="Query Type"
            )
            # Query audio input (visible by default)
            inp_query_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Query Audio Input",
                visible=True
            )
            # Query text input (hidden by default)
            inp_query_text = gr.Textbox(
                label="Query Text Input",
                placeholder="Enter description text...",
                visible=False
            )
            inference_btn = gr.Button("Inference", variant="primary")
        with gr.Column():
            # Output audio
            out_audio = gr.Audio(
                label="Separated Audio Output",
                show_download_button=True
            )
            # Status display
            status = gr.Textbox(
                label="Processing Status",
                value="Ready",
                interactive=False
            )

    # Query type toggle interaction
    def toggle_query_inputs(query_type):
        """Show/hide appropriate input components based on selected query type."""
        if query_type == "audio":
            return gr.Audio(visible=True), gr.Textbox(visible=False)
        else:
            return gr.Audio(visible=False), gr.Textbox(visible=True)

    query_type.change(
        toggle_query_inputs,
        inputs=query_type,
        outputs=[inp_query_audio, inp_query_text]
    )

    # Button click event
    def process_inference(audio_path, query_type, query_audio, query_text):
        """Handle inference request."""
        # Get appropriate query input based on type; only one of the two
        # query widgets is relevant for any given request.
        query_input = query_audio if query_type == "audio" else query_text
        # Call inference function
        result, status_msg = inference(audio_path, query_input, query_type)
        return result, status_msg

    inference_btn.click(
        process_inference,
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        outputs=[out_audio, status]
    )

    # Examples section
    # NOTE(review): examples use .mp3 files and are assumed to exist under
    # ./examples relative to the working directory — confirm in the repo.
    gr.Examples(
        examples=[
            ["examples/forget.mp3", "audio", "examples/forget_bass.mp3", ""],
            ["examples/forget.mp3", "text", None, "vocal"],
        ],
        inputs=[inp_audio, query_type, inp_query_audio, inp_query_text],
        label="Example Audios",
        examples_per_page=2,
    )

# Queue: keep a small queue to avoid OOM
demo.queue(max_size=8)
demo.launch()