# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
from functools import partial

import torch
from datasets import Dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments

from trl import IterativeSFTTrainer


class IterativeTrainerTester(unittest.TestCase):
    def setUp(self):
        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # get t5 as seq2seq example:
        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
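
    # The helpers below build tiny dummy datasets for the two input styles that
    # `IterativeSFTTrainer.step` accepts: pre-tokenized tensors and raw text.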
    def _init_tensor_dummy_dataset(self):
        dummy_dataset_dict = {
            "input_ids": [
                torch.tensor([5303, 3621, 3666, 1438, 318]),
                torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
                torch.tensor([5303, 3621, 3666, 1438, 318]),
            ],
            "attention_mask": [
                torch.tensor([1, 1, 1, 1, 1]),
                torch.tensor([1, 1, 1, 1, 1, 1]),
                torch.tensor([1, 1, 1, 1, 1]),
            ],
            "labels": [
                torch.tensor([5303, 3621, 3666, 1438, 318]),
                torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
                torch.tensor([5303, 3621, 3666, 1438, 318]),
            ],
        }

        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
        dummy_dataset.set_format("torch")
        return dummy_dataset

    def _init_textual_dummy_dataset(self):
        dummy_dataset_dict = {
            "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"],
            "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"],
        }

        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
        dummy_dataset.set_format("torch")
        return dummy_dataset
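
    # One test body covers every (model, input style) combination: `step` should
    # run a full optimization pass regardless of how the data is provided.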
    @parameterized.expand(
        [
            ["qwen", "tensor"],
            ["qwen", "text"],
            ["t5", "tensor"],
            ["t5", "text"],
        ]
    )
    def test_iterative_step_from_tensor(self, model_name, input_name):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # initialize dataset
            if input_name == "tensor":
                dummy_dataset = self._init_tensor_dummy_dataset()
                inputs = {
                    "input_ids": dummy_dataset["input_ids"],
                    "attention_mask": dummy_dataset["attention_mask"],
                    "labels": dummy_dataset["labels"],
                }
            else:
                dummy_dataset = self._init_textual_dummy_dataset()
                inputs = {
                    "texts": dummy_dataset["texts"],
                    "texts_labels": dummy_dataset["texts_labels"],
                }

            if model_name == "qwen":
                model = self.model
                tokenizer = self.tokenizer
            else:
                model = self.t5_model
                tokenizer = self.t5_tokenizer

            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=2,
                learning_rate=1e-3,
                report_to="none",
            )
            iterative_trainer = IterativeSFTTrainer(model=model, args=training_args, processing_class=tokenizer)
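
            # `zero_grad(set_to_none=True)` would reset gradients to None between
            # steps; forcing set_to_none=False keeps them as zeroed tensors so the
            # assertion below can verify that gradients were populated.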
            iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False)

            iterative_trainer.step(**inputs)

            for param in iterative_trainer.model.parameters():
                self.assertIsNotNone(param.grad)