# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import textwrap
import unittest
from time import strftime

from datasets import Dataset, DatasetDict
from parameterized import parameterized
from transformers import AutoProcessor, AutoTokenizer

from trl.data_utils import (
    apply_chat_template,
    extract_prompt,
    is_conversational,
    is_conversational_from_value,
    maybe_apply_chat_template,
    maybe_convert_to_chatml,
    maybe_extract_prompt,
    maybe_unpair_preference_dataset,
    pack_dataset,
    truncate_dataset,
    unpair_preference_dataset,
)


class IsConversationalTester(unittest.TestCase):
    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
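        # Harmony-format examples (as used by gpt-oss models): assistant turns may carry a separate
        # "thinking" field alongside "content".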
        {  # Language modeling with harmony
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Prompt-only with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        },
        {  # Prompt-completion with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Preference with implicit prompt and harmony
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Unpaired preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @parameterized.expand(itertools.product(conversational_examples))
    def test_conversational(self, example):
        self.assertTrue(is_conversational(example))

    @parameterized.expand(itertools.product(non_conversational_examples))
    def test_non_conversational(self, example):
        self.assertFalse(is_conversational(example))


class IsConversationalFromValueTester(unittest.TestCase):
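    # is_conversational_from_value detects the legacy "from"/"value" message format (ShareGPT-style),
    # as opposed to the "role"/"content" (ChatML) format.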
    def test_positive_1(self):
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ],
        }
        self.assertTrue(is_conversational_from_value(example))

    def test_negative_1(self):
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        }
        self.assertFalse(is_conversational_from_value(example))

    def test_negative_2(self):
        example = {"text": "The sky is blue."}
        self.assertFalse(is_conversational_from_value(example))


class ApplyChatTemplateTester(unittest.TestCase):
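    # apply_chat_template is checked against several tiny tokenizers and every supported dataset format
    # (language modeling, prompt-only, prompt-completion, preference, preference with implicit prompt,
    # unpaired preference).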
    tokenizers = [
        "trl-internal-testing/tiny-CohereForCausalLM",
        "trl-internal-testing/tiny-DbrxForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
        "trl-internal-testing/tiny-FalconMambaForCausalLM",
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "trl-internal-testing/tiny-GemmaForCausalLM",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-LlamaForCausalLM-3",
        "trl-internal-testing/tiny-MistralForCausalLM-0.1",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "trl-internal-testing/tiny-Phi3ForCausalLM",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "trl-internal-testing/tiny-Qwen3ForCausalLM",
    ]

    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"text": "The sky is blue."},  # Language modeling
        {"prompt": "The sky is"},  # Prompt-only
        {"prompt": "The sky is", "completion": " blue."},  # Prompt-completion
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},  # Preference
        {"chosen": "The sky is blue.", "rejected": "The sky is green."},  # Preference with implicit prompt
        {"prompt": "The sky is", "completion": " blue.", "label": True},  # Unpaired preference
    ]

    @parameterized.expand(itertools.product(tokenizers, conversational_examples))
    def test_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = apply_chat_template(example, tokenizer)

        # Checking if the result is a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception: for "messages", the key becomes "text" once the chat template is applied
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])

    @parameterized.expand(
        itertools.product(tokenizers, conversational_examples + non_conversational_examples)
    )  # both conversational and non-conversational examples
    def test_maybe_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = maybe_apply_chat_template(example, tokenizer)

        # Checking if the result is a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception: for "messages", the key becomes "text" once the chat template is applied
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])

    def test_apply_chat_template_with_tools(self):
        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

        # Define dummy test tools
        def get_current_temperature(location: str):
            """
            Gets the temperature at a given location.

            Args:
                location: The location to get the temperature for
            """
            return 22.0

        # Define test case
        test_case = {
            "prompt": [
                {"content": "Whats the temperature in London?", "role": "user"},
            ]
        }

        # Test with tools
        result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])

        # Verify tools are included in the output
        self.assertIn("get_current_temperature", result_with_tools["prompt"])

        # Test without tools
        result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)

        # Verify tools are not included in the output
        self.assertNotIn("get_current_temperature", result_without_tools["prompt"])


class ApplyChatTemplateHarmonyTester(unittest.TestCase):
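    # These tests target the gpt-oss Harmony chat template: the rendered system block carries the model
    # identity, knowledge cutoff, current date and reasoning effort, and assistant "thinking" is rendered
    # into the analysis channel while "content" goes to the final channel.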
    def test_language_modeling(self):
        messages = {
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")

        self.assertEqual(output["text"], expected)

    def test_prompt_only(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")

        self.assertEqual(output["prompt"], expected)

    def test_prompt_completion(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["completion"], expected_completion)

    def test_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
        expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["chosen"], expected_chosen)
        self.assertEqual(output["rejected"], expected_rejected)

    def test_preference_with_implicit_prompt(self):
        messages = {
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected_chosen = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")
        expected_rejected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""")

        self.assertEqual(output["chosen"], expected_chosen)
        self.assertEqual(output["rejected"], expected_rejected)

    def test_unpaired_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}
        Reasoning: low
        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
        Respond in a friendly manner.<|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["completion"], expected_completion)
        self.assertTrue(output["label"])


class UnpairPreferenceDatasetTester(unittest.TestCase):
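    # Unpairing turns each (prompt, chosen, rejected) row into two (prompt, completion, label) rows:
    # chosen completions get label=True and rejected ones label=False.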
    paired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is"],
            "chosen": [" blue.", " in the sky."],
            "rejected": [" green.", " in the sea."],
        }
    )

    unpaired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
            "label": [True, True, False, False],
        }
    )

    def test_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired
        unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_already_unpaired(self):
        # Test that an already unpaired dataset remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )

    def test_maybe_unpair_preference_dataset_dict_already_unpaired(self):
        # Test that an already unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )


class ExtractPromptTester(unittest.TestCase):
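    # extract_prompt moves the common leading part shared by "chosen" and "rejected" (messages or string
    # prefix) into an explicit "prompt" field.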
    example_implicit_prompt_conversational = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_explicit_prompt_conversational = {
        "prompt": [
            {"role": "user", "content": "What color is the sky?"},
        ],
        "chosen": [
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_implicit_prompt_standard = {
        "chosen": "The sky is blue.",
        "rejected": "The sky is green.",
    }

    example_explicit_prompt_standard = {
        "prompt": "The sky is",
        "chosen": " blue.",
        "rejected": " green.",
    }

    def test_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt should remain unchanged.",
        )

    def test_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt should remain unchanged.",
        )


class TestPackDatasetWrapped(unittest.TestCase):
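    # The "wrapped" strategy concatenates all sequences and re-splits them into chunks of exactly
    # seq_length, ignoring example boundaries (the last chunk may be shorter).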
    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)


class TestPackDatasetBfd(unittest.TestCase):
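    # The "bfd" (best-fit decreasing) strategy packs whole sequences into bins of at most seq_length,
    # records the per-bin "seq_lengths", and truncates sequences longer than seq_length.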
    def test_simple(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)

    def test_with_truncation(self):
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
            "attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
            "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
            "seq_lengths": [[4], [4], [2, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        self.assertEqual(dataset.to_dict(), expected_output)


class TestTruncateExamples(unittest.TestCase):
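    # truncate_dataset clips "input_ids" and "attention_mask" to max_length; other columns are left
    # untouched.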
    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)

    def test_with_extra_column(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = Dataset.from_dict(examples)
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = truncate_dataset(dataset, max_length)
        self.assertEqual(dataset.to_dict(), expected_output)


class TestMaybeConvertToChatML(unittest.TestCase):
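    # maybe_convert_to_chatml maps "from" -> "role" and "value" -> "content" in message dicts and renames
    # a top-level "conversations" key to "messages"; examples already in ChatML format pass through
    # unchanged.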
    def test_with_conversations_key(self):
        # Particular case where the key is "conversations": we rename it to "messages"
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ]
        }
        expected_output = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        self.assertEqual(maybe_convert_to_chatml(example), expected_output)

    def test_without_conversations_key(self):
        # Same as before, but we don't rename the keys
        example = {
            "prompt": [{"from": "user", "value": "What color is the sky?"}],
            "completion": [{"from": "assistant", "value": "It is blue."}],
        }
        expected_output = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        }
        self.assertEqual(maybe_convert_to_chatml(example), expected_output)

    def test_not_conversational(self):
        # When not needed, the example should remain unchanged
        example = {"text": "The sky is blue."}
        self.assertEqual(maybe_convert_to_chatml(example), example)

    def test_already_chatml(self):
        # When the example is already in ChatML format, it should remain unchanged
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        self.assertEqual(maybe_convert_to_chatml(example), example)


# Run the tests
if __name__ == "__main__":
    unittest.main()