import dacite
import pandas as pd
import torch
import json
import numpy as np
import multiprocessing as mp

from dataclasses import dataclass, fields
from abc import ABC, abstractmethod
from typing import Union, List, Dict, Optional

from ..data_types import ChatMLSample, TextContent, AudioContent
from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN

from loguru import logger

# The Whisper processor maps 30 seconds of audio to 3000 features.
# The audio tower then downsamples by 4, reducing 3000 features to 750, which gives 25 Hz.
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
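
# Worked example of the arithmetic above (a sketch, not used by the pipeline):
# a 6-second clip sampled at 16 kHz has 96_000 samples, and `cal_num_tokens` below
# estimates its contribution as ceil(25 * 96_000 / 16_000) = 150 hidden states,
# i.e. 25 states per second of audio.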


@dataclass
class ChatMLDatasetSample:
    input_ids: torch.LongTensor  # Shape (seq_len,): The input text tokens.
    label_ids: torch.LongTensor  # Shape (seq_len,): The label ids.
    audio_ids_concat: torch.LongTensor  # Shape (num_codebooks, audio_seq_len): The concatenated audio tokens.
    # Here `audio_seq_len` is the total length of the concatenated audio tokens.
    audio_ids_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio segment in the concatenated audio tokens.
    audio_waveforms_concat: (
        torch.Tensor
    )  # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
    audio_waveforms_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
    audio_sample_rate: torch.Tensor  # Shape (num_audios,): The sampling rate of the audio waveforms.
    audio_speaker_indices: (
        torch.LongTensor
    )  # Shape (num_audios,): The speaker index for each audio; -1 means unknown speaker.
    audio_label_ids_concat: Optional[torch.LongTensor] = (
        None  # Shape (num_codebooks, audio_seq_len): The concatenated audio label tokens.
    )
    # Here `audio_seq_len` is the total length of the concatenated audio tokens.
    reward: Optional[float] = None

    def num_audios(self):
        return max(len(self.audio_waveforms_start), len(self.audio_ids_start))

    def get_audio_codes(self, idx):
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_ids_concat[:, code_start:code_end]

    def get_audio_codes_labels(self, idx):
        if self.audio_label_ids_concat is None:
            return None
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_label_ids_concat[:, code_start:code_end]

    def get_wv(self, idx):
        wv_start = self.audio_waveforms_start[idx]
        sr = self.audio_sample_rate[idx]
        if idx < len(self.audio_waveforms_start) - 1:
            wv_end = self.audio_waveforms_start[idx + 1]
        else:
            wv_end = self.audio_waveforms_concat.shape[-1]
        return self.audio_waveforms_concat[wv_start:wv_end], sr

    def cal_num_tokens(
        self,
        encode_whisper_embed: bool = True,
        encode_audio_in_tokens: bool = False,
        encode_audio_out_tokens: bool = True,
        audio_in_token_id: int = 128015,
        audio_out_token_id: int = 128016,
    ) -> int:
        # We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace
        # those positions with the actual audio features and audio token ids.
        # It is assumed that we always have audio_ids when audio_waveforms are present (but not vice versa).
        num_tokens = len(self.input_ids) - len(self.audio_ids_start)
        if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
            audio_lengths = torch.diff(self.audio_waveforms_start)
            if len(audio_lengths):
                # Sum before calling .item()
                num_tokens += (
                    (
                        np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
                    ).sum()
                ).item()
            # Add the last audio's token estimation.
            num_tokens += (
                np.ceil(
                    WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
                    * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
                    / self.audio_sample_rate[-1]
                )
            ).item()
        if self.audio_ids_concat.size(1) > 0:
            audio_io_ids = self.input_ids[
                (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
            ]
            audio_io_id_lengths = torch.concat(
                [
                    torch.diff(self.audio_ids_start),
                    torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
                ]
            )
            if encode_audio_in_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
            if encode_audio_out_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
        return int(num_tokens)

    @classmethod
    def merge(
        cls,
        samples: List["ChatMLDatasetSample"],
        eos_token_id: int,
        ignore_index: int,
        padding_size: Optional[int] = None,
    ) -> "ChatMLDatasetSample":
        """Merge a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index
        between them and adjusting the offsets of audio_ids_start and audio_waveforms_start.

        Args:
            samples (List[ChatMLDatasetSample]): List of samples to merge.
            eos_token_id (int): Token inserted into input_ids between samples.
            ignore_index (int): Default label for padding.
            padding_size (Optional[int]): If provided, pad the sequence by this many tokens.

        Returns:
            ChatMLDatasetSample: Merged and potentially padded sample.
        """
        if not samples:
            logger.fatal("The samples list is empty and cannot be merged.")
            raise ValueError("The samples list is empty and cannot be merged.")
        # Initialize empty lists for concatenation
        input_ids_list = []
        label_ids_list = []
        audio_ids_concat_list = []
        audio_ids_start_list = []
        audio_waveforms_concat_list = []
        audio_waveforms_start_list = []
        audio_sample_rate_list = []
        audio_speaker_indices_list = []

        # Track offsets
        audio_ids_offset = 0
        audio_waveforms_offset = 0

        for sample in samples:
            # Add input_ids and label_ids with padding
            if input_ids_list:
                input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
                label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
            input_ids_list.append(sample.input_ids)
            label_ids_list.append(sample.label_ids)

            # Add audio_ids_concat and handle empty audio ids
            if sample.audio_ids_concat.size(1) > 0:
                audio_ids_concat_list.append(sample.audio_ids_concat)
                # Offset and add audio_ids_start
                audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
                audio_ids_offset += sample.audio_ids_concat.size(
                    1
                )  # (num_codebooks, seq_len): Update offset by audio_seq_len

            # Add audio_waveforms_concat
            if sample.audio_waveforms_concat.size(0) > 0:
                # Check dimensions of the audio waveform to ensure consistency
                if (
                    audio_waveforms_concat_list
                    and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
                ):
                    logger.warning(
                        f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
                    )
                    continue

                audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
                audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
                audio_waveforms_offset += sample.audio_waveforms_concat.size(0)

            # Add audio_sample_rate and audio_speaker_indices
            audio_sample_rate_list.append(sample.audio_sample_rate)
            audio_speaker_indices_list.append(sample.audio_speaker_indices)

        # Concatenate all tensors
        input_ids = torch.cat(input_ids_list, dim=0)
        label_ids = torch.cat(label_ids_list, dim=0)

        # Apply padding if padding_size is specified
        if padding_size is not None and padding_size > 0:
            input_ids = torch.cat(
                [
                    input_ids,
                    torch.full((padding_size,), eos_token_id, dtype=torch.long),
                ],
                dim=0,
            )
            label_ids = torch.cat(
                [
                    label_ids,
                    torch.full((padding_size,), ignore_index, dtype=torch.long),
                ],
                dim=0,
            )
        # Safely concatenate audio tensors with proper error handling
        try:
            audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
            audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])

            # Check for dimensional consistency in audio waveforms
            if audio_waveforms_concat_list:
                dims = [t.dim() for t in audio_waveforms_concat_list]
                if not all(d == dims[0] for d in dims):
                    # If dimensions don't match, log warning and filter out the problematic tensors
                    logger.warning(
                        f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
                    )
                    expected_dim = max(set(dims), key=dims.count)  # Most common dimension
                    audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
                    # Recalculate audio_waveforms_start with the filtered list
                    if audio_waveforms_concat_list:
                        audio_waveforms_offset = 0
                        audio_waveforms_start_list = []
                        for waveform in audio_waveforms_concat_list:
                            audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
                            audio_waveforms_offset += waveform.size(0)

            audio_waveforms_concat = (
                torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
            )
            audio_waveforms_start = (
                torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
            )
            audio_sample_rate = (
                torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
            )
            audio_speaker_indices = (
                torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
            )
        except RuntimeError as e:
            logger.error(f"Error during tensor concatenation: {str(e)}")
            logger.warning("Falling back to empty audio tensors")
            # Fall back to empty tensors
            audio_ids_concat = torch.tensor([[]])
            audio_ids_start = torch.tensor([])
            audio_waveforms_concat = torch.tensor([])
            audio_waveforms_start = torch.tensor([])
            audio_sample_rate = torch.tensor([])
            audio_speaker_indices = torch.tensor([])

        # Create the merged sample
        merged_sample = cls(
            input_ids=input_ids,
            label_ids=label_ids,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=audio_waveforms_concat,
            audio_waveforms_start=audio_waveforms_start,
            audio_sample_rate=audio_sample_rate,
            audio_speaker_indices=audio_speaker_indices,
        )
        return merged_sample
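

# A minimal usage sketch for the dataclass above (values are made up for illustration;
# `eos_token_id` and `ignore_index` would come from your tokenizer and loss configuration):
#
#   empty_audio = dict(
#       audio_ids_concat=torch.zeros((8, 0), dtype=torch.long),
#       audio_ids_start=torch.zeros((0,), dtype=torch.long),
#       audio_waveforms_concat=torch.zeros((0,)),
#       audio_waveforms_start=torch.zeros((0,), dtype=torch.long),
#       audio_sample_rate=torch.zeros((0,)),
#       audio_speaker_indices=torch.zeros((0,), dtype=torch.long),
#   )
#   a = ChatMLDatasetSample(input_ids=torch.tensor([1, 2, 3]), label_ids=torch.tensor([-100, 2, 3]), **empty_audio)
#   b = ChatMLDatasetSample(input_ids=torch.tensor([4, 5]), label_ids=torch.tensor([-100, 5]), **empty_audio)
#   packed = ChatMLDatasetSample.merge([a, b], eos_token_id=128001, ignore_index=-100)
#   packed.cal_num_tokens()  # == 6: the 3 + 2 text tokens plus the inserted EOS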


@dataclass
class RankedChatMLDatasetSampleTuple:
    samples: List[ChatMLDatasetSample]
    scores: List[float]

    def max_score_sample(self) -> ChatMLDatasetSample:
        idx = self.scores.index(max(self.scores))
        self.samples[idx].reward = self.scores[idx]
        return self.samples[idx]

    def min_score_sample(self) -> ChatMLDatasetSample:
        idx = self.scores.index(min(self.scores))
        self.samples[idx].reward = self.scores[idx]
        return self.samples[idx]
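
# Sketch of how a ranked tuple might be consumed (illustrative only; the `chosen`/`rejected`
# naming follows common preference-pair conventions and is not defined in this module):
#
#   ranked = RankedChatMLDatasetSampleTuple(samples=[a, b], scores=[0.9, 0.1])
#   chosen, rejected = ranked.max_score_sample(), ranked.min_score_sample()
#   # each returned sample carries its score in `.reward`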


@dataclass
class ChatMLDatasetStorageSample:
    input_tokens: torch.LongTensor
    label_tokens: torch.LongTensor
    audio_bytes_cache_dir_index: int
    audio_codes_cache_dir_index: int
    audio_bytes_indices: torch.LongTensor
    audio_codes_indices: torch.LongTensor
    speaker_indices: torch.LongTensor
    file_index: int
    original_sample_index: int


# TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
# Currently, we assume that the speaker id is stored in the "misc" field of ChatMLSample.
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    """Preprocess the ChatML sample to get the tokens for the text part.

    Args:
        sample (ChatMLSample): The ChatML sample to preprocess.
        tokenizer: The tokenizer to use for encoding the text.
    """
    try:
        if not isinstance(sample, ChatMLSample):
            # Handle all fields that could be NaN
            if "speaker" in sample and pd.isna(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and pd.isna(sample["start_index"]):
                sample["start_index"] = None
            if "content" in sample and pd.isna(sample["content"]):
                sample["content"] = ""

            # Convert any other potential NaN values in nested structures
            def convert_nan_to_none(obj):
                if isinstance(obj, (pd.Series, np.ndarray)):
                    return obj.tolist()
                elif pd.api.types.is_scalar(obj) and pd.isna(obj):
                    return None
                elif isinstance(obj, dict):
                    return {k: convert_nan_to_none(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):  # Handle both list and tuple
                    return [convert_nan_to_none(item) for item in obj]
                return obj

            # Clean the sample data
            clean_sample = convert_nan_to_none(sample)
            val_keys = []
            for field in fields(ChatMLSample):
                if field.name in clean_sample:
                    val_keys.append(field.name)
            clean_sample = {k: clean_sample[k] for k in val_keys}

            try:
                sample = dacite.from_dict(
                    data_class=ChatMLSample,
                    data=clean_sample,
                    config=dacite.Config(strict=True, check_types=True),
                )
            except Exception as e:
                print(f"Failed to convert to ChatMLSample: {e}")
                print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
                return None, None, None, None

        input_tokens = []
        label_tokens = []
        audio_contents = []
        speaker_id = None
        if sample.speaker is not None:
            speaker_id = sample.speaker
        elif sample.misc is not None:
            if "speaker" in sample.misc:
                speaker_id = sample.misc["speaker"]

        total_m = len(sample.messages)
        for turn_id, message in enumerate(sample.messages):
            role = message.role
            recipient = message.recipient
            content = message.content
            content_l = []

            if isinstance(content, str):
                content_l.append(TextContent(text=content))
            elif isinstance(content, TextContent):
                content_l.append(content)
            elif isinstance(content, AudioContent):
                content_l.append(content)
            elif isinstance(content, list):
                for ele in content:
                    if isinstance(ele, str):
                        content_l.append(TextContent(text=ele))
                    else:
                        content_l.append(ele)

            if turn_id == 0:
                prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
            else:
                prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            eot_postfix = "<|eot_id|>"
            eom_postfix = "<|eom_id|>"

            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
            input_tokens.extend(prefix_tokens)
            label_tokens.extend([-100 for _ in prefix_tokens])

            if recipient:
                assert role == "assistant", "Recipient is only available for assistant role."
                recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
                input_tokens.extend(recipient_tokens)
                label_tokens.extend(recipient_tokens)

            for content in content_l:
                if content.type == "text":
                    text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
                    input_tokens.extend(text_tokens)
                    if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                        label_tokens.extend(text_tokens)
                    else:
                        label_tokens.extend([-100 for _ in text_tokens])
                elif content.type == "audio":
                    # Generate the text part of the audio tokens
                    audio_contents.append(content)
                    if role == "user" or role == "system":
                        # Add the text tokens for the audio-in part
                        text_tokens = tokenizer.encode(
                            "<|audio_bos|><|AUDIO|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        label_tokens.extend([-100 for _ in text_tokens])
                    elif role == "assistant":
                        # Add the text tokens for the audio-out part
                        text_tokens = tokenizer.encode(
                            "<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        if sample.start_index is None or turn_id >= sample.start_index:
                            label_tokens.extend(text_tokens)
                        else:
                            label_tokens.extend([-100 for _ in text_tokens])

            next_id = turn_id + 1
            if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
                postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            else:
                postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                label_tokens.extend(postfix_tokens)
            else:
                label_tokens.extend([-100 for _ in postfix_tokens])

        return input_tokens, label_tokens, audio_contents, speaker_id

    except Exception as e:
        print(f"Error in prepare_chatml_sample: {str(e)}")
        print(f"Sample data: {json.dumps(sample, indent=2, default=str)}")
        return None, None, None, None
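

# Usage sketch (a minimal example, assuming ChatMLSample accepts a plain dict with a `messages`
# list of role/content entries, and a Hugging Face-style tokenizer whose vocabulary includes the
# special tokens used above; the tokenizer path is a placeholder, not prescribed by this module):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/your/tokenizer")
#   chatml = {
#       "messages": [
#           {"role": "user", "content": "Transcribe the audio."},
#           {"role": "assistant", "content": "Hello world."},
#       ]
#   }
#   input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(chatml, tokenizer)
#   # user tokens are labeled -100; only the assistant turn contributes training labels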


def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
    """Extract the generation prompt and reference answer from the input tokens.

    For example:

    Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
                  What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
                  <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
    -->
    Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
              What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
              <|start_header_id|>assistant<|end_header_id|>\n\n'
    Reference = 'At first they went by quick, too quick to even get.'

    Args:
        input_tokens: The input tokens.
        tokenizer: The tokenizer to use for decoding the text.

    Returns:
        prompt_tokens: The tokens for the prompt.
        reference_answer: The reference answer.
        num_audios_in_reference: The number of audios in the reference answer.
    """
    input_text = tokenizer.decode(input_tokens)
    generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    postfix = "<|eot_id|>"
    assert generation_prefix in input_text
    generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
    generation_prompt = input_text[:generation_prompt_end_loc]
    reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
    num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
    return (
        tokenizer.encode(generation_prompt, add_special_tokens=False),
        reference_answer,
        num_audios_in_reference,
    )
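
# Sketch of the intended round trip (illustrative; it relies on the same tokenizer that produced
# `input_tokens` via `prepare_chatml_sample`):
#
#   input_tokens, _, _, _ = prepare_chatml_sample(chatml, tokenizer)
#   prompt_tokens, reference, n_ref_audios = extract_generation_prompt_from_input_tokens(input_tokens, tokenizer)
#   # `prompt_tokens` ends right after the final assistant header; `reference` is the answer text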


def prepare_chatml_dataframe_single_process(df, tokenizer):
    """Prepare the ChatML DataFrame."""
    ret = []
    for _, row in df.iterrows():
        input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
        ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
    return ret


def prepare_chatml_dataframe(df, tokenizer, num_process=16):
    if num_process is None:
        return prepare_chatml_dataframe_single_process(df, tokenizer)
    else:
        num_process = max(min(len(df) // 1000, num_process), 1)
        workloads = np.array_split(df, num_process)
        with mp.Pool(num_process) as pool:
            ret = pool.starmap(
                prepare_chatml_dataframe_single_process,
                [(workload, tokenizer) for workload in workloads],
            )
        return sum(ret, [])
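
# Usage sketch (the worker-count heuristic above spawns roughly one process per 1000 rows,
# capped at `num_process`; the tokenizer must be picklable for the pool workers to receive it):
#
#   results = prepare_chatml_dataframe(df, tokenizer, num_process=8)
#   # `results` is a flat list of (input_tokens, label_tokens, audio_contents, speaker_id) tuples.
#   # Pass num_process=None to run everything in the current process (useful for debugging).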


class DatasetInterface(ABC):
    @abstractmethod
    def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve a dataset sample by index."""
        raise NotImplementedError


class IterableDatasetInterface(ABC):
    @abstractmethod
    def __iter__(
        self,
    ) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve samples by iterating through the dataset."""
        raise NotImplementedError


@dataclass
class DatasetInfo:
    dataset_type: str
    group_type: Optional[str] = None
    mask_text: Optional[bool] = None  # Whether to mask the text tokens for pretraining samples.
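

# Minimal sketch of a map-style dataset implementing the interface above (the in-memory list
# and the class name are illustrative, not part of this module):
#
#   class InMemoryChatMLDataset(DatasetInterface):
#       def __init__(self, samples: List[ChatMLDatasetSample]):
#           self._samples = samples
#
#       def __len__(self):
#           return len(self._samples)
#
#       def __getitem__(self, idx) -> ChatMLDatasetSample:
#           return self._samples[idx]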