Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Generate audio using JAM model | |
| Reads from filtered test set and generates audio using CFM+DiT model. | |
| """ | |
| import os | |
| import glob | |
| import time | |
| import json | |
| import random | |
| import sys | |
| from huggingface_hub import snapshot_download | |
| import torch | |
| import torchaudio | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm.auto import tqdm | |
| import accelerate | |
| import pyloudnorm as pyln | |
| from safetensors.torch import load_file | |
| from muq import MuQMuLan | |
| import numpy as np | |
| from accelerate import Accelerator | |
| from jam.dataset import enhance_webdataset_config, DiffusionWebDataset | |
| from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE | |
| # DiffRhythm imports for CFM+DiT model | |
| from jam.model import CFM, DiT | |
def get_negative_style_prompt(device, file_path):
    """Load a negative style embedding stored as a NumPy ``.npy`` file.

    Args:
        device: Torch device (or device string) the tensor is moved to.
        file_path: Path of the ``.npy`` file holding the embedding.

    Returns:
        A float16 tensor on ``device`` with the stored array's shape
        (the original author notes ``[1, 512]`` — presumably that shape; confirm).
    """
    stored = np.load(file_path)
    return torch.from_numpy(stored).to(device).half()
def normalize_audio(audio, normalize_lufs=True, *, sample_rate=44100, target_lufs=-14.0):
    """Peak-normalize audio and optionally loudness-normalize it.

    Each channel has its mean removed and is scaled so its absolute peak is
    ~1.0. When ``normalize_lufs`` is true, the result is then normalized to
    ``target_lufs`` integrated loudness with pyloudnorm.

    Args:
        audio: CPU tensor of shape ``(channels, samples)``; time is the last axis.
        normalize_lufs: Whether to apply integrated-loudness normalization.
        sample_rate: Sample rate of ``audio`` in Hz, used by the loudness meter.
        target_lufs: Target integrated loudness in LUFS.

    Returns:
        A tensor with the same shape and dtype as ``audio``. If the measured
        loudness is not finite (e.g. silent input), only peak normalization is
        applied.
    """
    audio = audio - audio.mean(-1, keepdim=True)
    # The epsilon avoids dividing by zero on an all-silent channel.
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if not normalize_lufs:
        return audio

    # pyloudnorm works on (samples, channels) arrays.
    samples_first = audio.transpose(0, 1).numpy()
    meter = pyln.Meter(rate=sample_rate)
    loudness = meter.integrated_loudness(samples_first)
    if not np.isfinite(loudness):
        # Silent or fully gated input measures -inf LUFS. Applying the
        # resulting gain would turn every sample into inf/NaN.
        return audio
    normalised = pyln.normalize.loudness(samples_first, loudness, target_lufs)
    return torch.from_numpy(normalised).transpose(0, 1).to(audio.dtype)
class FilteredTestSetDataset(Dataset):
    """Dataset that turns entries of a filtered test-set JSON into model inputs.

    Each JSON entry must provide ``id``, ``lrc_path``, ``audio_path`` and
    ``duration``. It must also provide ``prompt_path`` when
    ``use_prompt_style`` is on.

    Items are built as ``(id, fake_latent, style_embedding, lrc_data)``
    tuples and passed through ``diffusion_dataset.process_sample_safely``.
    When that succeeds, the result is tagged with a ``test_metadata`` dict.
    ``__getitem__`` returns ``None`` whenever ``process_sample_safely`` does.
    """

    # MuQ-MuLan consumes 24 kHz audio.
    MUQ_SAMPLE_RATE = 24000
    # Text prompts longer than this many characters are truncated before embedding.
    MAX_PROMPT_CHARS = 300
    # Latent frames per second. Hard-coded assumption (~21.5 fps for 44.1 kHz
    # audio, per the original comment) — TODO confirm against the VAE config.
    FRAME_RATE = 21.5

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """
        Args:
            test_set_path: Path to a JSON file containing a list of sample dicts.
            diffusion_dataset: Object exposing ``process_sample_safely(sample)``.
            muq_model: MuQ-MuLan model used to compute style embeddings.
            num_samples: If given, keep only the first ``num_samples`` entries.
            random_crop_style: Embed a random ``num_style_secs`` window of the
                reference audio instead of its beginning.
            num_style_secs: Length in seconds of the audio window used for style.
            use_prompt_style: Embed the text prompt at ``prompt_path`` instead
                of the reference audio.
        """
        with open(test_set_path, 'r', encoding='utf-8') as f:
            self.test_samples = json.load(f)
        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]
        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Build the processed sample for entry ``idx`` (``None`` if processing fails)."""
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]

        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r', encoding='utf-8') as f:
            lrc_data = json.load(f)
        # Wrap bare lyric payloads into the {'word': ...} layout used downstream.
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        audio_path = test_sample["audio_path"]
        if self.use_prompt_style:
            style_embedding = self._prompt_style_embedding(test_sample["prompt_path"], sample_id)
        else:
            style_embedding = self.generate_style_embedding(audio_path)

        duration = test_sample["duration"]
        num_frames = int(duration * self.FRAME_RATE)
        # Random placeholder latent sized to the requested duration. Presumably
        # process_sample_safely only uses its length — confirm.
        # NOTE(review): 128 channels here vs. 64 in generate_latent's cond; verify.
        fake_latent = torch.randn(128, num_frames)

        # Tuple layout matching DiffusionWebDataset samples.
        fake_sample = (sample_id, fake_latent, style_embedding, lrc_data)
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': audio_path,
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }
        return processed_sample

    def _prompt_style_embedding(self, prompt_path, sample_id):
        """Embed the text prompt stored at ``prompt_path`` with the MuQ text branch."""
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt = f.read()
        if len(prompt) > self.MAX_PROMPT_CHARS:
            print(f"Sample {sample_id} has prompt length {len(prompt)}")
            prompt = prompt[:self.MAX_PROMPT_CHARS]
        print(prompt)
        # Mirror the audio path: no autograd graph is needed for inference.
        with torch.inference_mode():
            return self.muq_model(texts=[prompt]).squeeze(0)

    def _select_style_window(self, waveform):
        """Return the ``num_style_secs`` slice of a 1-D 24 kHz waveform used for style.

        With ``random_crop_style`` the window start is drawn at random.
        Otherwise the window starts at sample 0. Waveforms shorter than the
        window are returned whole instead of raising.
        """
        target_samples = int(self.MUQ_SAMPLE_RATE * self.num_style_secs)
        start_idx = 0
        if self.random_crop_style:
            # max(0, ...) keeps randint's range valid for clips shorter than the window.
            start_idx = random.randint(0, max(0, waveform.shape[-1] - target_samples))
        return waveform[..., start_idx:start_idx + target_samples]

    def generate_style_embedding(self, audio_path):
        """Generate a style embedding for the audio file at ``audio_path`` with MuQ.

        The audio is resampled to 24 kHz and down-mixed to mono. A window of
        ``num_style_secs`` seconds is then embedded.

        Returns:
            The first row of the MuQ output, i.e. the batch dimension is dropped.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != self.MUQ_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, self.MUQ_SAMPLE_RATE)
            waveform = resampler(waveform)
        # Down-mix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Drop the channel dim: MuQ takes (batch, time).
        waveform = waveform.squeeze(0).to(self.muq_model.device)
        with torch.inference_mode():
            window = self._select_style_window(waveform)
            style_embedding = self.muq_model(wavs=window.unsqueeze(0))
        return style_embedding[0]
def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Collate a batch while carrying each sample's ``test_metadata`` through.

    ``None`` samples, i.e. ones that failed processing, are discarded first.
    The ``test_metadata`` entry is popped from every remaining sample dict
    before ``base_collate_fn`` runs. It is then re-attached to the collated
    result as a list in batch order.

    Returns:
        The collated dict, or ``None`` when no samples survive or
        ``base_collate_fn`` itself returns ``None``.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if len(kept) == 0:
        return None

    metadata = []
    for sample in kept:
        metadata.append(sample.pop('test_metadata'))

    result = base_collate_fn(kept)
    if result is None:
        return None
    result['test_metadata'] = metadata
    return result
def load_model(model_config, checkpoint_path, device):
    """Build the JAM CFM+DiT model and load weights from a safetensors file.

    Follows the infer.py pattern.

    Args:
        model_config: Mapping with a ``dit`` section (DiT constructor kwargs)
            and a ``cfm`` section (CFM constructor kwargs).
        checkpoint_path: Path to a ``.safetensors`` state dict.
        device: Device the model is moved to.

    Returns:
        The CFM model in eval mode.
    """
    dit_config = model_config["dit"].copy()
    # Fall back to DiT's default vocabulary size when the config omits it
    # (the phoneme vocabulary needs at least 64 entries).
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256

    cfm = CFM(
        transformer=DiT(**dit_config),
        **model_config["cfm"]
    )
    cfm = cfm.to(device)

    checkpoint = load_file(checkpoint_path)
    # strict=False tolerates key drift between checkpoint and model. Report the
    # drift so a partially loaded model does not go unnoticed.
    result = cfm.load_state_dict(checkpoint, strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} model keys missing from checkpoint "
              f"(e.g. {result.missing_keys[:5]})")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected checkpoint keys ignored "
              f"(e.g. {result.unexpected_keys[:5]})")
    return cfm.eval()
def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """Sample latents for a collated batch with the CFM model (follows infer.py).

    Args:
        model: CFM model exposing ``max_frames`` and ``sample(...)``. A wrapper
            exposing the real model as ``.module`` is also accepted.
        batch: Collated dict with ``lrc``, ``prompt``, ``start_time``,
            ``duration_abs`` and ``duration_rel`` tensors.
        sample_kwargs: Overrides for the sampling defaults (``cfg_strength=4``,
            ``steps=50``, ``batch_infer_num=1``). ``None`` means no overrides.
        negative_style_prompt_path: ``None`` loads
            ``public_checkpoints/vocal.npy``. ``'zeros'`` uses an all-zero
            ``(1, 512)`` embedding. Any other value is a ``.npy`` path.
        ignore_style: Condition on the negative prompt instead of the batch prompt.
        device: Device the batch tensors are moved to.

    Returns:
        The latents returned as the first element of ``model.sample``.
    """
    # Accelerate/DDP wrappers hide the model (and its .sample) behind .module.
    core_model = model if hasattr(model, "max_frames") else getattr(model, "module", model)

    with torch.inference_mode():
        batch_size = len(batch["lrc"])
        text = batch["lrc"].to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)

        # Empty conditioning latent; the whole span [0, max_frames) is predicted.
        max_frames = core_model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64, device=text.device)
        pred_frames = [(0, max_frames)]

        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1
        }
        sample_kwargs = {**default_sample_kwargs, **(sample_kwargs or {})}

        if negative_style_prompt_path == 'zeros':
            negative_style_prompt = torch.zeros(1, 512, device=text.device)
        else:
            if negative_style_prompt_path is None:
                negative_style_prompt_path = 'public_checkpoints/vocal.npy'
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

        latents, _ = core_model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs
        )
        return latents
class Jamify:
    """Single-sample JAM inference wrapper: lyrics plus a style reference in, an MP3 out.

    The constructor loads the VAE, the CFM+DiT model, the MuQ-MuLan style
    encoder and the dataset preprocessor once. ``predict`` then renders one
    song per call into ``outputs/``.
    """

    def __init__(self, device='cuda'):
        """Load ``jam_infer.yaml`` and every model onto ``device``.

        Args:
            device: Torch device used for all models and inference (default ``'cuda'``).
        """
        os.makedirs('outputs', exist_ok=True)
        self.device = device
        self.config = OmegaConf.load('jam_infer.yaml')
        OmegaConf.resolve(self.config)

        # Point the evaluation checkpoint at the downloaded weights.
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            self.vae = DiffRhythmVAE(device=device).to(device)
        else:
            self.vae = StableAudioOpenVAE().to(device)
        self.vae_type = vae_type

        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, device)
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(device).eval()

        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        # A single style embedding is generated per sample.
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id, output_dir='outputs', keep=9):
        """Prune old MP3 outputs and delete the temporary test-set JSON of ``sample_id``.

        Output names are microsecond timestamps, so lexicographic order is
        chronological while the timestamps have equal digit counts. All but
        the newest ``keep`` MP3 files are removed. Removal is best-effort, and
        a missing JSON is tolerated.

        Args:
            sample_id: Id whose ``<output_dir>/<sample_id>.json`` is deleted.
            output_dir: Directory holding generated files.
            keep: Number of most recent MP3 files to retain.
        """
        mp3_files = sorted(glob.glob(os.path.join(output_dir, "*.mp3")))
        # A plain [:-keep] would be an empty slice for keep == 0.
        stale_files = mp3_files[:-keep] if keep > 0 else mp3_files
        for old_file in stale_files:
            try:
                os.remove(old_file)
                print(f"Cleaned up old file: {old_file}")
            except OSError:
                pass  # best effort: the file may already be gone
        try:
            os.unlink(os.path.join(output_dir, f"{sample_id}.json"))
        except FileNotFoundError:
            pass

    def _load_batch(self, test_set_path):
        """Run the single-entry test set at ``test_set_path`` through the dataset pipeline."""
        eval_cfg = self.config.evaluation
        test_dataset = FilteredTestSetDataset(
            test_set_path=test_set_path,
            diffusion_dataset=self.base_dataset,
            muq_model=self.muq_model,
            num_samples=1,
            random_crop_style=eval_cfg.random_crop_style,
            num_style_secs=eval_cfg.num_style_secs,
            use_prompt_style=eval_cfg.use_prompt_style
        )
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
        )
        return next(iter(dataloader))

    def _decode_latent(self, latent):
        """Decode a ``(frames, channels)`` latent to a CPU waveform with the configured VAE."""
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        eval_cfg = self.config.evaluation
        # Chunked decoding is only supported by the DiffRhythm VAE.
        if self.vae_type == 'diffrhythm' and eval_cfg.get('use_chunked_decoding', True):
            decoded = self.vae.decode(
                latent_for_vae,
                chunked=True,
                overlap=eval_cfg.get('chunked_overlap', 32),
                chunk_size=eval_cfg.get('chunked_size', 128)
            )
        else:
            decoded = self.vae.decode(latent_for_vae)
        return decoded.sample.squeeze(0).detach().cpu()

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate one song and return the path of the written MP3.

        Args:
            reference_audio_path: Reference audio used for the style embedding
                unless the config enables ``use_prompt_style``.
            lyrics_json_path: Path to the lyrics JSON file.
            style_prompt: Path to a text file containing the style prompt (used
                when ``use_prompt_style`` is enabled). This is a path, not the
                prompt text.
            duration: Target length in seconds; the output is trimmed to it.

        Returns:
            Path of the generated ``outputs/<sample_id>.mp3``.

        Raises:
            RuntimeError: If dataset preprocessing rejects the sample.
        """
        sample_id = str(int(time.time() * 1000000))  # microsecond timestamp for uniqueness
        test_set_path = f"outputs/{sample_id}.json"
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]
        with open(test_set_path, "w", encoding="utf-8") as f:
            json.dump(test_set, f)

        try:
            batch = self._load_batch(test_set_path)
            if batch is None:
                raise RuntimeError(f"Sample {sample_id} was rejected during preprocessing")

            eval_cfg = self.config.evaluation
            latent = generate_latent(
                self.cfm_model,
                batch,
                eval_cfg.sample_kwargs,
                eval_cfg.negative_style_prompt,
                eval_cfg.ignore_style,
                device=self.device
            )[0][0]
            original_duration = batch['test_metadata'][0]['duration']

            pred_audio = normalize_audio(self._decode_latent(latent))
            sample_rate = 44100
            # Slicing past the end is a no-op, so shorter audio is kept whole.
            pred_audio_trimmed = pred_audio[:, :int(original_duration * sample_rate)]

            output_path = f'outputs/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
        finally:
            # Also runs on failure, so the temporary test-set JSON never lingers.
            self.cleanup_old_files(sample_id)
        return output_path