Spaces:
Runtime error
Runtime error
| import os | |
| import zipfile | |
| import requests | |
| from tqdm import tqdm | |
| from typing import List, Tuple | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| import librosa | |
| import torch | |
# Sampling rate handed to ``librosa.load`` (as ``sr``) when a clip is read.
SAMPLE_RATE: int = 22050

# Upper bound, in seconds, on how much of each clip is read
# (handed to ``librosa.load`` as ``duration``).
DURATION: float = 1.4
class EmodbDataset(Dataset):
    """Map-style dataset over the EMO-DB recordings for four emotion labels.

    On construction the dataset makes sure ``<root_path>/wav`` exists
    (downloading and unpacking the archive from ``__url__`` if it does not),
    then indexes every ``.wav`` file whose stem ends in one of the
    two-character codes listed in ``__suffixes__``. Files with any other
    ending are ignored.

    Each item is ``(audio, target)`` where ``audio`` is the float32 tensor
    produced by :meth:`load_audio` (after ``transform``, if one was given) and
    ``target`` is the integer index of the label in ``__labels__``.

    Attributes:
        ids: Paths of the indexed audio files, sorted by file name.
        targets: ``np.int64`` array of label indices, aligned with ``ids``.
        transform: Optional callable applied to the loaded audio.
    """

    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    # Last two characters of a file stem -> label that the file belongs to.
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"],
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        """Index the audio files under ``root_path``, downloading if needed.

        Args:
            root_path: Directory that holds (or will hold) the ``wav`` folder.
            transform: Optional callable applied to each loaded clip in
                ``__getitem__``.

        Raises:
            FileNotFoundError: If the ``wav`` folder is still missing after
                the archive was downloaded and extracted.
        """
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")

        # Ensure the dataset is downloaded
        self._ensure_dataset()

        # Invert ``__suffixes__`` once so each file needs a single lookup.
        suffix_to_label = {
            suffix: label
            for label, suffixes in self.__suffixes__.items()
            for suffix in suffixes
        }

        ids = []
        targets = []
        # ``os.listdir`` order is arbitrary; sort so that a given index always
        # refers to the same file across runs and machines.
        for audio_file in sorted(os.listdir(self.audio_root_path)):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue
            label = suffix_to_label.get(f_name[-2:])
            if label is None:
                # Not one of the four labels this dataset exposes.
                continue
            ids.append(os.path.join(self.audio_root_path, audio_file))
            targets.append(self.label2id(label))

        self.ids = ids
        self.targets = np.array(targets, dtype=np.int64)
        self.transform = transform

    def _ensure_dataset(self):
        """Download and extract the dataset unless the ``wav`` folder exists.

        Raises:
            FileNotFoundError: If extraction did not produce the ``wav``
                folder.
        """
        if os.path.isdir(self.audio_root_path):
            return
        print(f"Dataset not found at {self.audio_root_path}. Downloading...")
        self._download_and_extract()
        if not os.path.isdir(self.audio_root_path):
            # Fail here with a clear message instead of letting the later
            # ``os.listdir`` call raise without context.
            raise FileNotFoundError(
                f"Expected audio folder {self.audio_root_path} was not "
                f"created by extracting {self.__url__}"
            )

    def _download_and_extract(self):
        """Download the archive into ``root_path``, unpack it, delete the zip.

        The zip file is removed even when the download or the extraction
        fails, so a truncated archive is never left behind.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Ensure the root path exists
        os.makedirs(self.root_path, exist_ok=True)

        zip_path = os.path.join(self.root_path, "emodb.zip")
        try:
            # A timeout keeps an unresponsive server from blocking forever.
            with requests.get(self.__url__, stream=True, timeout=60) as r:
                r.raise_for_status()
                # ``None`` lets tqdm show a plain counter when the server does
                # not report a size, instead of a bar with a total of zero.
                total_size = int(r.headers.get("content-length", 0)) or None
                with open(zip_path, "wb") as f, tqdm(
                    desc="Downloading EMO-DB dataset",
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                        bar.update(len(chunk))

            # Extract the dataset
            print("Extracting dataset...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(self.root_path)
        finally:
            # Clean up the zip file, whether or not the steps above succeeded.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple:
        """Return ``(audio, target)`` for the sample at position ``idx``."""
        target = self.targets[idx]
        audio = self.load_audio(self.ids[idx])  # float32 torch tensor
        if self.transform:
            audio = self.transform(audio)  # Apply transform
        return audio, target

    @staticmethod
    def id2label(idx: int) -> str:
        """Return the label name stored at position ``idx`` of ``__labels__``."""
        return EmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        """Return the index of ``label`` in ``__labels__``.

        Raises:
            ValueError: If ``label`` is not one of the known labels.
        """
        if label not in EmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return EmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        """Load one clip with librosa and return it as a float32 tensor.

        The file is read with ``sr=SAMPLE_RATE`` and ``duration=DURATION``.
        NOTE(review): clips shorter than ``DURATION`` presumably come back
        shorter than the others, which would break default batching - confirm
        whether padding is needed.
        """
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        # Sanity guard kept from the original implementation.
        assert SAMPLE_RATE == sr, "broken audio file"
        # Convert numpy array to PyTorch tensor
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        """Return the label names as a new list, in index order."""
        return list(EmodbDataset.__labels__)