TRL documentation
BEMA for Reference Model
BEMA for Reference Model
This feature implements the BEMA (Bias-Corrected Exponential Moving Average) algorithm to update the reference model during DPO training.
Usage
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# The policy, the reference model and the tokenizer are all loaded from the same checkpoint.
checkpoint = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

# Policy model to be trained, plus its tokenizer.
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Separate copy of the checkpoint used as the DPO reference model.
ref_model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Preference dataset used for training.
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# Callback configured to also update the reference model during training.
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    train_dataset=pref_dataset,
    processing_class=tokenizer,
    callbacks=[bema_callback],
)
trainer.train()