Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import os | |
| import shutil | |
def normalize_ratios(ratios):
    """Divide every ratio by the total of all ratios.

    Args:
        ratios: Iterable of numbers. It is materialized into a list first, so
            one-shot iterables such as generators are supported.

    Returns:
        A new list with each value divided by the sum of all values, in the
        original order. An empty input yields an empty list.

    Raises:
        ZeroDivisionError: If the input is non-empty and its plain Python
            numbers sum to zero.
    """
    # sum() would exhaust a generator and leave nothing for the division pass.
    values = list(ratios)
    total = sum(values)
    return [value / total for value in values]
| from torch.nn.utils.rnn import pad_sequence | |
def collate_fn_transformer(batch):
    """Collate variable-length raw waveforms into one padded batch.

    Args:
        batch: List of (waveform, label) pairs. Each waveform is a tensor (or
            anything ``torch.as_tensor`` accepts); size-1 dimensions are
            squeezed away, so the documented input shape is
            [sequence_length].

    Returns:
        padded_waveforms: Zero-padded tensor of shape
            [batch_size, max_seq_len].
        attention_mask: Long tensor with the same shape as
            ``padded_waveforms``; 1 marks positions that come from a real
            waveform, 0 marks padding.
        labels: Long tensor of shape [batch_size].
    """
    # Separate waveforms and labels
    waveforms, labels = zip(*batch)

    sequences = []
    for waveform in waveforms:
        # as_tensor avoids the extra copy (and UserWarning) that
        # torch.tensor() incurs when handed an existing tensor.
        sequence = torch.as_tensor(waveform).detach().squeeze()
        if sequence.dim() == 0:
            # squeeze() collapses a single-sample waveform to a scalar,
            # which pad_sequence cannot handle; keep it as length 1.
            sequence = sequence.reshape(1)
        sequences.append(sequence)

    # Pad sequences to the same length
    padded_waveforms = pad_sequence(sequences, batch_first=True)

    # Build the mask from the true lengths rather than from `!= 0`, so that
    # genuine zero-valued samples (silence) are not mistaken for padding.
    attention_mask = pad_sequence(
        [torch.ones_like(sequence, dtype=torch.long) for sequence in sequences],
        batch_first=True,
    )

    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_waveforms, attention_mask, labels
def collate_fn(batch):
    """Collate (input, target, input_length, target_length) samples.

    Args:
        batch: List of 4-tuples, one per sample.

    Returns:
        A tuple ``(inputs, targets, input_lengths, target_lengths)`` where
        the inputs are stacked along a new leading batch dimension, the
        targets are concatenated into one tensor, and both length fields are
        long tensors with one entry per sample.
    """
    # Transpose the list of samples into one tuple per field.
    feature_seq, label_seq, feature_lens, label_lens = list(zip(*batch))

    stacked_inputs = torch.stack(feature_seq)
    flat_targets = torch.cat(label_seq)
    return (
        stacked_inputs,
        flat_targets,
        torch.tensor(feature_lens, dtype=torch.long),
        torch.tensor(label_lens, dtype=torch.long),
    )
def save_test_data(test_dataset, dataset, save_dir):
    """Copy the audio files of a test split into per-label folders.

    Any existing ``save_dir`` is deleted first, then recreated. Each selected
    file is copied to ``save_dir/<label>/``.

    Args:
        test_dataset: Object exposing ``indices``, the positions in
            ``dataset`` that belong to the test split.
        dataset: Object exposing ``audio_files`` (source paths) and
            ``labels``, both indexable by those positions.
        save_dir: Destination directory. WARNING: it is removed with all of
            its contents before copying starts.
    """
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")
    os.makedirs(save_dir, exist_ok=True)

    for idx in test_dataset.indices:
        audio_file_path = dataset.audio_files[idx]
        label = dataset.labels[idx]

        # Create a directory for the label if it doesn't exist
        label_dir = os.path.join(save_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)

        file_name = os.path.basename(audio_file_path)
        destination = os.path.join(label_dir, file_name)
        if os.path.exists(destination):
            # save_dir starts out empty, so an existing destination means
            # another source file with the same base name was already copied
            # here. Add the dataset index instead of silently overwriting it.
            stem, extension = os.path.splitext(file_name)
            destination = os.path.join(label_dir, f"{stem}_{idx}{extension}")

        # Copy the audio file to the label directory
        shutil.copy(audio_file_path, destination)
    print(f"Test data saved in {save_dir}")