import torch
import soundfile as sf
import numpy as np
import argparse
import os
import yaml
import julius
import sys

currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.dirname(currentdir))

from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
from modules.loss import AudioFeatureLoss, Loss, CLAPFeatureLoss
from modules.data_normalization import Audio_Effects_Normalizer

def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to a new sample rate and number of audio channels."""
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    wav = convert_audio_channels(wav, to_channels)
    return wav
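
# `convert_audio` above calls `convert_audio_channels`, which is neither defined
# nor imported in this file. The helper below is only a minimal sketch of the
# assumed behavior (up/down-mixing the channel dimension of a (..., channels, time)
# tensor); the original project may provide its own implementation elsewhere.
def convert_audio_channels(wav: torch.Tensor, to_channels: int) -> torch.Tensor:
    """Up- or down-mix `wav` (..., channels, time) to `to_channels` channels."""
    src_channels = wav.shape[-2]
    if src_channels == to_channels:
        return wav
    if to_channels == 1:
        # Down-mix by averaging all source channels
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Up-mix mono by repeating the single channel
        return wav.expand(*wav.shape[:-2], to_channels, wav.shape[-1])
    if src_channels > to_channels:
        # Drop the extra channels
        return wav[..., :to_channels, :]
    raise ValueError(f"Cannot convert {src_channels} channels to {to_channels}.")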

class MasteringStyleTransfer:
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load models
        self.effects_encoder = self.load_effects_encoder()
        self.mastering_converter = self.load_mastering_converter()
        self.fx_normalizer = Audio_Effects_Normalizer(
            precomputed_feature_path=args.fx_norm_feature_path,
            STEMS=['mixture'],
            EFFECTS=['eq', 'imager', 'loudness'],
        )

        # Loss functions
        self.clap_loss = CLAPFeatureLoss()

    def load_effects_encoder(self):
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
        reload_weights(effects_encoder, self.args.encoder_path, self.device)
        effects_encoder.to(self.device)
        effects_encoder.eval()
        return effects_encoder

    def load_mastering_converter(self):
        mastering_converter = Dasp_Mastering_Style_Transfer(
            num_features=2048,
            sample_rate=self.args.sample_rate,
            tgt_fx_names=['eq', 'distortion', 'multiband_comp', 'gain', 'imager', 'limiter'],
            model_type='tcn',
            config=self.args.cfg_converter,
            batch_size=1,
        )
        reload_weights(mastering_converter, self.args.model_path, self.device)
        mastering_converter.to(self.device)
        mastering_converter.eval()
        return mastering_converter

    def get_reference_embedding(self, reference_tensor):
        with torch.no_grad():
            reference_feature = self.effects_encoder(reference_tensor)
        return reference_feature

    def mastering_style_transfer(self, input_tensor, reference_feature):
        with torch.no_grad():
            output_audio = self.mastering_converter(input_tensor, reference_feature)
            predicted_params = self.mastering_converter.get_last_predicted_params()
        return output_audio, predicted_params

    def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
        # Optimize the reference embedding itself (not the network weights) so that
        # the mastered output better matches the target under the chosen loss.
        # Expected ito_config keys: 'optimizer', 'learning_rate', 'num_steps',
        # 'loss_function', 'sample_rate', 'af_weights', 'clap_target_type',
        # 'clap_text_prompt', 'clap_distance_fn'.
        fit_embedding = torch.nn.Parameter(initial_reference_feature, requires_grad=True)
        optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])

        min_loss = float('inf')
        min_loss_step = 0
        all_results = []

        af_loss = AudioFeatureLoss(
            weights=ito_config['af_weights'],
            sample_rate=ito_config['sample_rate'],
            stem_separation=False,
            use_clap=False,
        )

        for step in range(ito_config['num_steps']):
            optimizer.zero_grad()

            output_audio = self.mastering_converter(input_tensor, fit_embedding)
            current_params = self.mastering_converter.get_last_predicted_params()

            # Compute loss
            if ito_config['loss_function'] == 'AudioFeatureLoss':
                losses = af_loss(output_audio, reference_tensor)
                total_loss = sum(losses.values())
            elif ito_config['loss_function'] == 'CLAPFeatureLoss':
                if ito_config['clap_target_type'] == 'Audio':
                    target = reference_tensor
                else:
                    target = ito_config['clap_text_prompt']
                total_loss = self.clap_loss(output_audio, target, self.args.sample_rate,
                                            distance_fn=ito_config['clap_distance_fn'])

            if total_loss < min_loss:
                min_loss = total_loss.item()
                min_loss_step = step

            # Log top 5 parameter differences relative to the first step
            if step == 0:
                initial_params = current_params
            top_5_diff = self.get_top_n_diff_string(initial_params, current_params, top_n=5)
            log_entry = f"Step {step + 1}\n Loss: {total_loss.item():.4f}\n{top_5_diff}\n"

            all_results.append({
                'step': step + 1,
                'loss': total_loss.item(),
                'audio': output_audio.detach().cpu().numpy(),
                'params': current_params,
                'log': log_entry,
            })

            total_loss.backward()
            optimizer.step()

        return all_results, min_loss_step

    def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
        sample_rate, data = audio

        # Normalize audio to the -1 to 1 range
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.float32:
            data = np.clip(data, -1.0, 1.0)
        else:
            raise ValueError(f"Unsupported audio data type: {data.dtype}")

        # Ensure stereo channels
        if data.ndim == 1:
            data = np.stack([data, data])
        elif data.ndim == 2:
            if data.shape[0] == 2:
                pass  # Already (channels, samples)
            elif data.shape[1] == 2:
                data = data.T
            else:
                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate the first channel
        else:
            raise ValueError(f"Unsupported audio shape: {data.shape}")

        # Resample if necessary
        if sample_rate != target_sample_rate:
            data = julius.resample_frac(torch.from_numpy(data), sample_rate, target_sample_rate).numpy()

        # Apply fx normalization to the input audio during mastering style transfer
        if normalize:
            data = self.fx_normalizer.normalize_audio(data.T, 'mixture').T

        # Convert to a torch tensor with a batch dimension
        data_tensor = torch.FloatTensor(data).unsqueeze(0)
        return data_tensor.to(self.device)

    def process_audio(self, input_audio, reference_audio):
        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate, normalize=True)
        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)

        reference_feature = self.get_reference_embedding(reference_tensor)
        output_audio, predicted_params = self.mastering_style_transfer(input_tensor, reference_feature)

        return output_audio, predicted_params, self.args.sample_rate, input_tensor

    def get_param_output_string(self, params):
        if params is None:
            return "No parameters available"

        # (display name, unit, min, max) for each fx parameter
        param_mapper = {
            'eq': {
                'low_shelf_gain_db': ('Low Shelf Gain', 'dB', -20, 20),
                'low_shelf_cutoff_freq': ('Low Shelf Cutoff', 'Hz', 20, 2000),
                'low_shelf_q_factor': ('Low Shelf Q', '', 0.1, 5.0),
                'band0_gain_db': ('Low-Mid Band Gain', 'dB', -20, 20),
                'band0_cutoff_freq': ('Low-Mid Band Frequency', 'Hz', 80, 2000),
                'band0_q_factor': ('Low-Mid Band Q', '', 0.1, 5.0),
                'band1_gain_db': ('Mid Band Gain', 'dB', -20, 20),
                'band1_cutoff_freq': ('Mid Band Frequency', 'Hz', 2000, 8000),
                'band1_q_factor': ('Mid Band Q', '', 0.1, 5.0),
                'band2_gain_db': ('High-Mid Band Gain', 'dB', -20, 20),
                'band2_cutoff_freq': ('High-Mid Band Frequency', 'Hz', 8000, 12000),
                'band2_q_factor': ('High-Mid Band Q', '', 0.1, 5.0),
                'band3_gain_db': ('High Band Gain', 'dB', -20, 20),
                'band3_cutoff_freq': ('High Band Frequency', 'Hz', 12000, 20000),
                'band3_q_factor': ('High Band Q', '', 0.1, 5.0),
                'high_shelf_gain_db': ('High Shelf Gain', 'dB', -20, 20),
                'high_shelf_cutoff_freq': ('High Shelf Cutoff', 'Hz', 4000, 20000),
                'high_shelf_q_factor': ('High Shelf Q', '', 0.1, 5.0),
            },
            'distortion': {
                'drive_db': ('Drive', 'dB', 0, 8),
                'parallel_weight_factor': ('Dry/Wet Mix', '%', 0, 100),
            },
            'multiband_comp': {
                'low_cutoff': ('Low/Mid Crossover', 'Hz', 20, 1000),
                'high_cutoff': ('Mid/High Crossover', 'Hz', 1000, 20000),
                'parallel_weight_factor': ('Dry/Wet Mix', '%', 0, 100),
                'low_shelf_comp_thresh': ('Low Band Comp Threshold', 'dB', -60, 0),
                'low_shelf_comp_ratio': ('Low Band Comp Ratio', ': 1', 1, 20),
                'low_shelf_exp_thresh': ('Low Band Exp Threshold', 'dB', -60, 0),
                'low_shelf_exp_ratio': ('Low Band Exp Ratio', ': 1', 1, 20),
                'low_shelf_at': ('Low Band Attack Time', 'ms', 5, 100),
                'low_shelf_rt': ('Low Band Release Time', 'ms', 5, 100),
                'mid_band_comp_thresh': ('Mid Band Comp Threshold', 'dB', -60, 0),
                'mid_band_comp_ratio': ('Mid Band Comp Ratio', ': 1', 1, 20),
                'mid_band_exp_thresh': ('Mid Band Exp Threshold', 'dB', -60, 0),
                'mid_band_exp_ratio': ('Mid Band Exp Ratio', ': 1', 0, 1),
                'mid_band_at': ('Mid Band Attack Time', 'ms', 5, 100),
                'mid_band_rt': ('Mid Band Release Time', 'ms', 5, 100),
                'high_shelf_comp_thresh': ('High Band Comp Threshold', 'dB', -60, 0),
                'high_shelf_comp_ratio': ('High Band Comp Ratio', ': 1', 1, 20),
                'high_shelf_exp_thresh': ('High Band Exp Threshold', 'dB', -60, 0),
                'high_shelf_exp_ratio': ('High Band Exp Ratio', ': 1', 1, 20),
                'high_shelf_at': ('High Band Attack Time', 'ms', 5, 100),
                'high_shelf_rt': ('High Band Release Time', 'ms', 5, 100),
            },
            'gain': {
                'gain_db': ('Output Gain', 'dB', -24, 24),
            },
            'imager': {
                'width': ('Stereo Width', '', 0, 1),
            },
            'limiter': {
                'threshold': ('Threshold', 'dB', -60, 0),
                'at': ('Attack Time', 'ms', 5, 100),
                'rt': ('Release Time', 'ms', 5, 100),
            },
        }

        output = []
        for fx_name, fx_params in params.items():
            output.append(f"{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.item()
                    if fx_name in param_mapper and param_name in param_mapper[fx_name]:
                        friendly_name, unit, min_val, max_val = param_mapper[fx_name][param_name]
                        if unit == '%':
                            param_value = param_value * 100
                        current_content = f" {friendly_name}: {param_value:.2f} {unit}"
                        if param_name == 'mid_band_exp_ratio':
                            current_content += f" (Range: {min_val}-{max_val})"
                        output.append(current_content)
                    else:
                        output.append(f" {param_name}: {param_value:.2f}")
            else:
                # Stereo imager is predicted as a single scalar rather than a dict
                width_percentage = fx_params.item() * 200
                output.append(f" Stereo Width: {width_percentage:.2f}% (Range: 0-200%)")

        return "\n".join(output)

    def get_top_n_diff_string(self, initial_params, ito_params, top_n=5):
        if initial_params is None or ito_params is None:
            return "Cannot compare parameters"

        all_diffs = []
        for fx_name in initial_params.keys():
            if isinstance(initial_params[fx_name], dict):
                for param_name in initial_params[fx_name].keys():
                    initial_value = initial_params[fx_name][param_name]
                    ito_value = ito_params[fx_name][param_name]
                    param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
                    normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))
                    all_diffs.append((fx_name, param_name, initial_value.item(), ito_value.item(), normalized_diff.item()))
            else:
                initial_value = initial_params[fx_name]
                ito_value = ito_params[fx_name]
                normalized_diff = abs(ito_value - initial_value)
                all_diffs.append((fx_name, 'width', initial_value.item(), ito_value.item(), normalized_diff.item()))

        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:top_n]

        output = [f" Top {top_n} parameter differences (initial / ITO / normalized diff):"]
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            output.append(f" {fx_name.upper()} - {param_name}: {initial_value:.2f} / {ito_value:.2f} / {normalized_diff:.2f}")

        return "\n".join(output)

def reload_weights(model, ckpt_path, device):
    checkpoint = torch.load(ckpt_path, map_location=device)

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint["model"].items():
        name = k[7:]  # remove the `module.` prefix
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict, strict=False)
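
# --- Hypothetical usage sketch (not part of the original module) -------------
# Illustrates how the class above is wired together. The checkpoint/config
# paths, YAML keys, filenames, and ito_config values below are placeholders;
# the real Space supplies these through its own launcher. Only the argparse
# attribute names (model_path, encoder_path, fx_norm_feature_path, sample_rate,
# cfg_enc, cfg_converter) are taken from the code above.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="checkpoints/mastering_converter.pt")  # placeholder
    parser.add_argument("--encoder_path", type=str, default="checkpoints/effects_encoder.pt")  # placeholder
    parser.add_argument("--fx_norm_feature_path", type=str, default="checkpoints/fx_normalization.npy")  # placeholder
    parser.add_argument("--config_path", type=str, default="configs/config.yaml")  # placeholder
    parser.add_argument("--sample_rate", type=int, default=44100)
    args = parser.parse_args()

    # Assumed config layout: one YAML file holding the encoder and converter configs.
    with open(args.config_path) as f:
        config = yaml.safe_load(f)
    args.cfg_enc = config["effects_encoder"]  # assumed key
    args.cfg_converter = config["mastering_converter"]  # assumed key

    mst = MasteringStyleTransfer(args)

    # process_audio expects Gradio-style (sample_rate, ndarray) tuples.
    input_data, input_sr = sf.read("input.wav", dtype="float32")  # placeholder file
    reference_data, reference_sr = sf.read("reference.wav", dtype="float32")  # placeholder file
    output_audio, predicted_params, sr, input_tensor = mst.process_audio(
        (input_sr, input_data), (reference_sr, reference_data)
    )
    print(mst.get_param_output_string(predicted_params))

    # Optional inference-time optimization; every value below is illustrative.
    ito_config = {
        "optimizer": "Adam",
        "learning_rate": 1e-3,
        "num_steps": 50,
        "loss_function": "AudioFeatureLoss",
        "sample_rate": args.sample_rate,
        "af_weights": None,  # placeholder; passed straight to AudioFeatureLoss
        "clap_target_type": "Audio",
        "clap_text_prompt": "",
        "clap_distance_fn": "cosine",  # placeholder distance name
    }
    reference_tensor = mst.preprocess_audio((reference_sr, reference_data), args.sample_rate)
    reference_feature = mst.get_reference_embedding(reference_tensor)
    all_results, best_step = mst.inference_time_optimization(
        input_tensor, reference_tensor, ito_config, reference_feature
    )
    # Assumes the stored audio has shape (1, channels, samples).
    sf.write("output_ito.wav", all_results[best_step]["audio"].squeeze(0).T, sr)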