import gradio as gr
import numpy as np
import pandas as pd
import pyloudnorm as pyln
import torch

from config import args
from inference import MasteringStyleTransfer
from utils import download_youtube_audio

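# The model is instantiated once at import time and shared by all of the
# Gradio callbacks defined below.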
mastering_transfer = MasteringStyleTransfer(args)

def denormalize_audio(audio, dtype=np.int16):
    """
    Denormalize the audio from the range [-1, 1] to the full range of the specified dtype.
    """
    if dtype == np.int16:
        audio = np.clip(audio, -1, 1)  # ensure the input is in the range [-1, 1]
        return (audio * 32767).astype(np.int16)
    elif dtype == np.float32:
        return audio.astype(np.float32)
    else:
        raise ValueError("Unsupported dtype. Use np.int16 or np.float32.")

def loudness_normalize(audio, sample_rate, target_loudness=-12.0):
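    """
    Normalize audio to the target integrated loudness (LUFS) using an
    ITU-R BS.1770 meter from pyloudnorm. Expects float32 samples in
    (samples, channels) layout; mono input is reshaped accordingly.
    """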
    # Ensure audio is float32
    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)
    # If audio is mono, reshape to (samples, 1)
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)
    meter = pyln.Meter(sample_rate)  # create BS.1770 meter
    loudness = meter.integrated_loudness(audio)
    loudness_normalized_audio = pyln.normalize.loudness(audio, loudness, target_loudness)
    return loudness_normalized_audio

def process_youtube_url(url):
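    """
    Download the audio track of a YouTube video.

    Returns ((sample_rate, audio), None) on success, or (None, error_message)
    on failure so callers can surface the error in the UI.
    """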
    try:
        audio, sr = download_youtube_audio(url)
        return (sr, audio), None
    except Exception as e:
        return None, f"Error processing YouTube URL: {str(e)}"

def download_youtube_audios(input_youtube_url, reference_youtube_url):
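    """
    Download input and reference audio from their YouTube URLs (either may be
    empty). Returns (input_audio, reference_audio, input_error, reference_error).
    """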
    input_audio, input_error = process_youtube_url(input_youtube_url) if input_youtube_url else (None, None)
    reference_audio, reference_error = process_youtube_url(reference_youtube_url) if reference_youtube_url else (None, None)
    return input_audio, reference_audio, input_error, reference_error

def process_audio_with_youtube(input_audio, input_youtube_url, reference_audio, reference_youtube_url):
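    """
    Resolve input and reference audio from YouTube URLs when given, falling
    back to the uploaded audio, then run the mastering style transfer.
    Returns (None, None, error_message) if a download fails or either
    source is missing.
    """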
    if input_youtube_url:
        input_audio, error = process_youtube_url(input_youtube_url)
        if error:
            return None, None, error
    if reference_youtube_url:
        reference_audio, error = process_youtube_url(reference_youtube_url)
        if error:
            return None, None, error
    if input_audio is None or reference_audio is None:
        return None, None, "Both input and reference audio are required."
    return process_audio(input_audio, reference_audio)

def to_numpy_audio(audio):
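    """
    Convert a torch tensor or numpy array to a 2D numpy array in
    (samples, channels) layout, transposing channel-first input.
    """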
    # Convert to a numpy array if it's a torch tensor
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().numpy()
    # Collapse extra singleton dimensions first, then make sure the
    # array is 2D (squeeze can leave a 1D array, e.g. for shape (1, n, 1))
    if audio.ndim > 2:
        audio = audio.squeeze()
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)
    # Ensure the audio is in the shape (samples, channels)
    if audio.shape[1] > audio.shape[0]:
        audio = audio.transpose(1, 0)
    return audio

def process_audio(input_audio, reference_audio):
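    """
    Run mastering style transfer on the input audio using the reference
    style audio. Returns the loudness-normalized int16 output as
    (sample_rate, audio), the predicted processor parameters as a string,
    and the normalized input audio as (sample_rate, audio).
    """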
    output_audio, predicted_params, sr, normalized_input = mastering_transfer.process_audio(
        input_audio, reference_audio
    )
    param_output = mastering_transfer.get_param_output_string(predicted_params)
    # Convert to numpy audio
    output_audio = to_numpy_audio(output_audio)
    normalized_input = to_numpy_audio(normalized_input)
    # Loudness-normalize the output audio
    output_audio = loudness_normalize(output_audio, sr)
    # Denormalize the audio to int16
    output_audio = denormalize_audio(output_audio, dtype=np.int16)
    return (sr, output_audio), param_output, (sr, normalized_input)

def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn):
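    """
    Run inference-time optimization (ITO) of the reference embedding.
    Falls back to the Step-1 reference audio when no separate ITO reference
    is given. Returns the last step's output audio, its parameter string,
    the step count, the accumulated log, a loss-curve DataFrame, and the
    per-step results for later browsing via the step slider.
    """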
    if ito_reference_audio is None:
        ito_reference_audio = reference_audio
    af_weights = [float(w.strip()) for w in af_weights.split(',')]
    ito_config = {
        'optimizer': optimizer,
        'learning_rate': learning_rate,
        'num_steps': num_steps,
        'af_weights': af_weights,
        'sample_rate': args.sample_rate,
        'loss_function': loss_function,
        'clap_target_type': clap_target_type,
        'clap_text_prompt': clap_text_prompt,
        'clap_distance_fn': clap_distance_fn
    }
    input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
    reference_tensor = mastering_transfer.preprocess_audio(reference_audio, args.sample_rate)
    ito_reference_tensor = mastering_transfer.preprocess_audio(ito_reference_audio, args.sample_rate)
    initial_reference_feature = mastering_transfer.get_reference_embedding(reference_tensor)
    all_results, min_loss_step = mastering_transfer.inference_time_optimization(
        input_tensor, ito_reference_tensor, ito_config, initial_reference_feature
    )
    ito_log = ""
    loss_values = []
    for result in all_results:
        ito_log += result['log']
        loss_values.append({"step": result['step'], "loss": result['loss']})
    # Return the results of the last step
    last_result = all_results[-1]
    current_output = last_result['audio']
    ito_param_output = mastering_transfer.get_param_output_string(last_result['params'])
    # Convert to numpy audio
    current_output = to_numpy_audio(current_output)
    # Loudness-normalize the output audio
    current_output = loudness_normalize(current_output, args.sample_rate)
    # Denormalize the audio to int16
    current_output = denormalize_audio(current_output, dtype=np.int16)
    return (args.sample_rate, current_output), ito_param_output, num_steps, ito_log, pd.DataFrame(loss_values), all_results

def update_ito_output(all_results, selected_step):
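    """
    Re-render the ITO outputs (audio, parameter string, log) for the step
    selected on the slider, using the per-step results stored in gr.State.
    """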
    selected_result = all_results[selected_step - 1]
    current_output = selected_result['audio']
    ito_param_output = mastering_transfer.get_param_output_string(selected_result['params'])
    # Convert to numpy audio
    current_output = to_numpy_audio(current_output)
    # Loudness-normalize the output audio
    current_output = loudness_normalize(current_output, args.sample_rate)
    # Denormalize the audio to int16
    current_output = denormalize_audio(current_output, dtype=np.int16)
    return (args.sample_rate, current_output), ito_param_output, selected_result['log']

| """ APP display """ | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# ITO-Master: Inference Time Optimization for Audio Effects Modeling of Music Mastering Processors") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.HTML(""" | |
| <!-- Load MathJax --> | |
| <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> | |
| <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> | |
| <p> | |
| Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. | |
| The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. | |
| Perform mastering style transfer with an input source audio and a reference mastering style audio. | |
| On top of this result, you can perform ITO to optimize the reference embedding $z_{ref}$ to further gain control over the output mastering style. | |
| </p> | |
| <p> | |
| <strong>🔗 GitHub Source Code:</strong> | |
| <a href="https://github.com/SonyResearch/ITO-Master" target="_blank">SonyResearch/ITO-Master</a> | |
| </p> | |
| <p> | |
| <strong>📜 Full Paper:</strong> | |
| <a href="https://arxiv.org/abs/2506.16889" target="_blank">Arxiv</a> | |
| </p> | |
| """) | |
| with gr.Column(scale=1): | |
| gr.Image("ito_new.png", width=500, height=300, label="ITO pipeline") | |
| gr.Markdown("## Step 1: Mastering Style Transfer") | |
| with gr.Tab("Upload Audio"): | |
| with gr.Row(): | |
| input_audio = gr.Audio(label="Source Audio $x_{in}$") | |
| reference_audio = gr.Audio(label="Reference Style Audio $x_{ref}$") | |
| process_button = gr.Button("Process Mastering Style Transfer") | |
| gr.Markdown('<span style="color: lightgray; font-style: italic;">all output samples are normalized to -12dB LUFS</span>') | |
| with gr.Row(): | |
| with gr.Column(): | |
| output_audio = gr.Audio(label="Output Audio y'", type='numpy') | |
| normalized_input = gr.Audio(label="Normalized Source Audio", type='numpy') | |
| param_output = gr.Textbox(label="Predicted Parameters", lines=5) | |
| process_button.click( | |
| process_audio, | |
| inputs=[input_audio, reference_audio], | |
| outputs=[output_audio, param_output, normalized_input] | |
| ) | |
| with gr.Tab("YouTube Audio"): | |
| gr.Markdown("Seems like it's currently unavailable to download YouTube clips from HuggingFace... But you could try out yourself in your environment with the available source code.") | |
| with gr.Row(): | |
| input_youtube_url = gr.Textbox(label="Input YouTube URL") | |
| reference_youtube_url = gr.Textbox(label="Reference YouTube URL") | |
| download_button = gr.Button("Download YouTube Audios") | |
| error_message_yt = gr.Textbox(label="Error Message", visible=False) | |
| with gr.Row(): | |
| input_audio_yt = gr.Audio(label="Source Audio (Do not put when using YouTube URL)") | |
| reference_audio_yt = gr.Audio(label="Reference Style Audio (Do not put when using YouTube URL)") | |
| process_button_yt = gr.Button("Process Mastering Style Transfer") | |
| gr.Markdown('<span style="color: lightgray; font-style: italic;">all output samples are normalized to -12dB LUFS</span>') | |
| with gr.Row(): | |
| with gr.Column(): | |
| output_audio_yt = gr.Audio(label="Output Audio y'", type='numpy') | |
| normalized_input_yt = gr.Audio(label="Normalized Source Audio", type='numpy') | |
| param_output_yt = gr.Textbox(label="Predicted Parameters", lines=5) | |
| def handle_download_youtube_audios(input_youtube_url, reference_youtube_url): | |
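            """
            Fill the two audio widgets from the YouTube URLs and show the
            error textbox only when a download fails.
            """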
            input_audio, reference_audio, input_error, reference_error = download_youtube_audios(input_youtube_url, reference_youtube_url)
            if input_error or reference_error:
                return None, None, gr.update(visible=True, value=input_error or reference_error)
            return input_audio, reference_audio, gr.update(visible=False, value="")

        download_button.click(
            handle_download_youtube_audios,
            inputs=[input_youtube_url, reference_youtube_url],
            outputs=[input_audio_yt, reference_audio_yt, error_message_yt]
        )
        process_button_yt.click(
            process_audio,
            inputs=[input_audio_yt, reference_audio_yt],
            outputs=[output_audio_yt, param_output_yt, normalized_input_yt]
        )

        # def process_and_handle_errors(input_audio, input_youtube_url, reference_audio, reference_youtube_url):
        #     result = process_audio_with_youtube(input_audio, input_youtube_url, reference_audio, reference_youtube_url)
        #     if len(result) == 3 and isinstance(result[2], str):  # error occurred check
        #         return None, None, None, gr.update(visible=True, value=result[2])
        #     return result[0], result[1], result[2], gr.update(visible=False, value="")
        # process_button_yt.click(
        #     process_and_handle_errors,
        #     inputs=[input_audio_yt, input_youtube_url, reference_audio_yt, reference_youtube_url],
        #     outputs=[output_audio_yt, param_output_yt, normalized_input_yt, error_message_yt]
        # )

| gr.Markdown("## Step 2: Inference Time Optimization (ITO)") | |
| with gr.Row(): | |
| ito_reference_audio = gr.Audio(label="ITO Reference Style Audio $x'_{ref}$ (optional)") | |
| with gr.Column(): | |
| num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps for additional optimization") | |
| optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer") | |
| learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate") | |
| loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss") | |
| # Audio Feature Loss weights | |
| with gr.Column(visible=True) as audio_feature_weights: | |
| af_weights = gr.Textbox( | |
| label="AudioFeatureLoss Weights (comma-separated)", | |
| value="0.1,0.001,1.0,1.0,0.1", | |
| info="RMS, Crest Factor, Stereo Width, Stereo Imbalance, Bark Spectrum" | |
| ) | |
| # CLAP Loss options | |
| with gr.Column(visible=False) as clap_options: | |
| clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio") | |
| clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False) | |
| clap_distance_fn = gr.Dropdown(["cosine", "mse", "l1"], label="CLAP Distance Function", value="cosine") | |
    def update_clap_options(loss_function):
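        """
        Toggle between the AudioFeatureLoss weight textbox and the CLAP
        options depending on the selected loss function.
        """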
| if loss_function == "CLAPFeatureLoss": | |
| return gr.update(visible=False), gr.update(visible=True) | |
| else: | |
| return gr.update(visible=True), gr.update(visible=False) | |
| loss_function.change( | |
| update_clap_options, | |
| inputs=[loss_function], | |
| outputs=[audio_feature_weights, clap_options] | |
| ) | |
    def update_clap_text_prompt(clap_target_type):
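        """
        Show the CLAP text prompt box only when the CLAP target type is Text.
        """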
        return gr.update(visible=clap_target_type == "Text")

    clap_target_type.change(
        update_clap_text_prompt,
        inputs=[clap_target_type],
        outputs=[clap_text_prompt]
    )

    ito_button = gr.Button("Perform ITO")
    gr.Markdown('<span style="color: lightgray; font-style: italic;">all output samples are normalized to -12 dB LUFS</span>')
    with gr.Row():
        with gr.Column():
            ito_output_audio = gr.Audio(label="ITO Output Audio")
            ito_step_slider = gr.Slider(minimum=1, maximum=100, step=1, label="ITO Step", interactive=True)
            ito_param_output = gr.Textbox(label="ITO Predicted Parameters", lines=15)
        with gr.Column():
            ito_loss_plot = gr.LinePlot(
                x="step",
                y="loss",
                title="ITO Loss Curve",
                x_title="Step",
                y_title="Loss",
                height=300,
                width=600,
            )
            ito_log = gr.Textbox(label="ITO Log", lines=10)

    all_results = gr.State([])
    ito_button.click(
        perform_ito,
        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn],
        outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
    ).then(
        update_ito_output,
        inputs=[all_results, ito_step_slider],
        outputs=[ito_output_audio, ito_param_output, ito_log]
    )
    ito_step_slider.change(
        update_ito_output,
        inputs=[all_results, ito_step_slider],
        outputs=[ito_output_audio, ito_param_output, ito_log]
    )

# demo.launch()
demo.launch(server_name="0.0.0.0", server_port=7860)