Fix ORPO multi gpu (#1433)
Browse files
* don't drop attention_mask for orpo
* handle multi-gpu cases better for orpo
* revert change to not drop the attention_mask from inputs for orpo
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -30,6 +30,7 @@ from transformers import (
|
|
| 30 |
from transformers.trainer_utils import seed_worker
|
| 31 |
from transformers.utils import is_sagemaker_mp_enabled
|
| 32 |
from trl import DPOTrainer
|
|
|
|
| 33 |
|
| 34 |
from axolotl.loraplus import create_loraplus_optimizer
|
| 35 |
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
|
@@ -472,6 +473,58 @@ class AxolotlTrainer(Trainer):
|
|
| 472 |
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
| 473 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
def orpo_compute_custom_loss(self, logits, labels):
|
| 476 |
logits = logits.contiguous()
|
| 477 |
loss = 0.0
|
|
@@ -512,45 +565,46 @@ class AxolotlTrainer(Trainer):
|
|
| 512 |
dim=2,
|
| 513 |
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
| 514 |
).squeeze(2)
|
| 515 |
-
return torch.mul(per_token_logps, mask.
|
| 516 |
-
dtype=torch.float64
|
| 517 |
-
) / mask.sum(dim=1).to(dtype=torch.float64)
|
| 518 |
|
| 519 |
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
},
|
| 526 |
-
output_hidden_states=True,
|
| 527 |
)
|
| 528 |
-
|
|
|
|
|
|
|
| 529 |
**{
|
| 530 |
-
"input_ids":
|
| 531 |
-
"attention_mask":
|
| 532 |
-
"labels":
|
| 533 |
},
|
| 534 |
output_hidden_states=True,
|
| 535 |
)
|
| 536 |
|
|
|
|
|
|
|
|
|
|
| 537 |
# Calculate NLL loss
|
| 538 |
pos_loss = self.orpo_compute_custom_loss(
|
| 539 |
-
logits=outputs_pos
|
| 540 |
)
|
| 541 |
|
| 542 |
# Calculate Log Probability
|
| 543 |
pos_prob = self.orpo_compute_logps(
|
| 544 |
-
prompt_attention_mask=
|
| 545 |
-
chosen_inputs=
|
| 546 |
-
chosen_attention_mask=
|
| 547 |
-
logits=outputs_pos
|
| 548 |
)
|
| 549 |
neg_prob = self.orpo_compute_logps(
|
| 550 |
-
prompt_attention_mask=
|
| 551 |
-
chosen_inputs=
|
| 552 |
-
chosen_attention_mask=
|
| 553 |
-
logits=outputs_neg
|
| 554 |
)
|
| 555 |
|
| 556 |
# Calculate log odds
|
|
@@ -1247,6 +1301,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 1247 |
train_dataset=self.train_dataset,
|
| 1248 |
eval_dataset=self.eval_dataset,
|
| 1249 |
args=training_args,
|
|
|
|
| 1250 |
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
| 1251 |
eval_data_collator=self.build_collator(
|
| 1252 |
training_args, is_eval=True, **data_collator_kwargs
|
|
|
|
| 30 |
from transformers.trainer_utils import seed_worker
|
| 31 |
from transformers.utils import is_sagemaker_mp_enabled
|
| 32 |
from trl import DPOTrainer
|
| 33 |
+
from trl.trainer.utils import pad_to_length
|
| 34 |
|
| 35 |
from axolotl.loraplus import create_loraplus_optimizer
|
| 36 |
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
|
|
|
| 473 |
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
| 474 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 475 |
|
| 476 |
+
@staticmethod
|
| 477 |
+
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
| 478 |
+
concatenated_batch = {}
|
| 479 |
+
|
| 480 |
+
max_length = max(
|
| 481 |
+
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
|
| 482 |
+
)
|
| 483 |
+
# Concatenate positive and negative inputs
|
| 484 |
+
concatenated_batch["input_ids"] = pad_to_length(
|
| 485 |
+
inputs["input_ids"], max_length, pad_token
|
| 486 |
+
)
|
| 487 |
+
concatenated_batch["rejected_input_ids"] = pad_to_length(
|
| 488 |
+
inputs["rejected_input_ids"], max_length, pad_token
|
| 489 |
+
)
|
| 490 |
+
concatenated_batch["labels"] = pad_to_length(
|
| 491 |
+
inputs["labels"], max_length, label_pad_token
|
| 492 |
+
)
|
| 493 |
+
concatenated_batch["rejected_labels"] = pad_to_length(
|
| 494 |
+
inputs["rejected_labels"], max_length, label_pad_token
|
| 495 |
+
)
|
| 496 |
+
concatenated_batch["attention_mask"] = pad_to_length(
|
| 497 |
+
inputs["attention_mask"], max_length, 0
|
| 498 |
+
)
|
| 499 |
+
concatenated_batch["rejected_attention_mask"] = pad_to_length(
|
| 500 |
+
inputs["rejected_attention_mask"], max_length, 0
|
| 501 |
+
)
|
| 502 |
+
concatenated_batch["prompt_attention_mask"] = pad_to_length(
|
| 503 |
+
inputs["prompt_attention_mask"], max_length, 0
|
| 504 |
+
).to(device=device)
|
| 505 |
+
|
| 506 |
+
input_ids = torch.cat(
|
| 507 |
+
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
|
| 508 |
+
dim=0,
|
| 509 |
+
).to(device=device)
|
| 510 |
+
attention_mask = torch.cat(
|
| 511 |
+
[
|
| 512 |
+
concatenated_batch["attention_mask"],
|
| 513 |
+
concatenated_batch["rejected_attention_mask"],
|
| 514 |
+
],
|
| 515 |
+
dim=0,
|
| 516 |
+
).to(device=device)
|
| 517 |
+
labels = torch.cat(
|
| 518 |
+
[concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
|
| 519 |
+
).to(device=device)
|
| 520 |
+
|
| 521 |
+
return {
|
| 522 |
+
"input_ids": input_ids,
|
| 523 |
+
"labels": labels,
|
| 524 |
+
"attention_mask": attention_mask,
|
| 525 |
+
"prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
def orpo_compute_custom_loss(self, logits, labels):
|
| 529 |
logits = logits.contiguous()
|
| 530 |
loss = 0.0
|
|
|
|
| 565 |
dim=2,
|
| 566 |
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
| 567 |
).squeeze(2)
|
| 568 |
+
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
|
|
|
|
|
|
| 569 |
|
| 570 |
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
| 571 |
+
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
| 572 |
+
inputs,
|
| 573 |
+
label_pad_token=-100,
|
| 574 |
+
pad_token=self.tokenizer.pad_token_id,
|
| 575 |
+
device=self.accelerator.device,
|
|
|
|
|
|
|
| 576 |
)
|
| 577 |
+
|
| 578 |
+
# Perform a single forward pass
|
| 579 |
+
outputs = model(
|
| 580 |
**{
|
| 581 |
+
"input_ids": concat_inputs["input_ids"],
|
| 582 |
+
"attention_mask": concat_inputs["attention_mask"],
|
| 583 |
+
"labels": concat_inputs["labels"],
|
| 584 |
},
|
| 585 |
output_hidden_states=True,
|
| 586 |
)
|
| 587 |
|
| 588 |
+
# Split the outputs for positive and negative examples
|
| 589 |
+
outputs_pos, outputs_neg = outputs.logits.chunk(2)
|
| 590 |
+
|
| 591 |
# Calculate NLL loss
|
| 592 |
pos_loss = self.orpo_compute_custom_loss(
|
| 593 |
+
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
|
| 594 |
)
|
| 595 |
|
| 596 |
# Calculate Log Probability
|
| 597 |
pos_prob = self.orpo_compute_logps(
|
| 598 |
+
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
| 599 |
+
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
|
| 600 |
+
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
|
| 601 |
+
logits=outputs_pos,
|
| 602 |
)
|
| 603 |
neg_prob = self.orpo_compute_logps(
|
| 604 |
+
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
| 605 |
+
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
|
| 606 |
+
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
|
| 607 |
+
logits=outputs_neg,
|
| 608 |
)
|
| 609 |
|
| 610 |
# Calculate log odds
|
|
|
|
| 1301 |
train_dataset=self.train_dataset,
|
| 1302 |
eval_dataset=self.eval_dataset,
|
| 1303 |
args=training_args,
|
| 1304 |
+
tokenizer=self.tokenizer,
|
| 1305 |
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
| 1306 |
eval_data_collator=self.build_collator(
|
| 1307 |
training_args, is_eval=True, **data_collator_kwargs
|