|
|
""" |
|
|
2025.9.14 |
|
|
2025.9.11 |
|
|
4.56.2 |
|
|
0.23.0 |
|
|
__UNSLOTH_VERSIONING__ |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import Tensor |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable |
|
|
from trl.trainer.alignprop_trainer import (
    Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable,
    DDPOStableDiffusionPipeline, Optional, Path, ProjectConfiguration,
    PyTorchModelHubMixin, Union, defaultdict, generate_model_card,
    get_comet_experiment_url, is_wandb_available, logger, logging, os,
    set_seed, textwrap, torch, wandb, warnings,
)
|
|
|
|
|
|
|
|
import os |
|
|
from typing import * |
|
|
from dataclasses import dataclass, field |
|
|
from packaging.version import Version |
|
|
import torch |
|
|
import numpy as np |
|
|
from contextlib import nullcontext |
|
|
from torch.nn import functional as F |
|
|
from transformers import (
    DataCollatorForSeq2Seq,
    DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling,
)
|
|
from transformers.training_args import ParallelMode |
|
|
|
|
|
|
|
|
import functools |
|
|
from types import MethodType |
|
|
def prepare_for_training_mode(f): |
|
|
@functools.wraps(f) |
|
|
def wrapper(self, *args, **kwargs): |
|
|
|
|
|
if hasattr(self, 'model') and hasattr(self.model, "for_training"): |
|
|
self.model.for_training() |
|
|
output = f(self, *args, **kwargs) |
|
|
|
|
|
if hasattr(self, 'model') and hasattr(self.model, "for_inference"): |
|
|
self.model.for_inference() |
|
|
return output |
|
|
return wrapper |
|
|
pass |
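
# Minimal usage sketch (hypothetical demo classes, not used by this module):
# the decorator flips `self.model` into training mode for the duration of the
# wrapped call, then back to inference mode afterwards.
def _example_prepare_for_training_mode():
    class _DemoModel:
        def for_training(self):  print("-> training mode")
        def for_inference(self): print("-> inference mode")
    class _DemoTrainer:
        def __init__(self): self.model = _DemoModel()
        @prepare_for_training_mode
        def train(self): return "done"
    return _DemoTrainer().train()  # prints both mode switches, returns "done"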
|
|
|
|
|
torch_compile_options = {
    "epilogue_fusion"   : True,   # fuse pointwise epilogues into the main kernel
    "max_autotune"      : False,  # skip slow autotuning for faster compiles
    "shape_padding"     : True,   # pad shapes for better kernel efficiency
    "trace.enabled"     : False,  # no debug trace dumps
    "triton.cudagraphs" : False,  # CUDA graphs off to stay memory-safe
}
|
|
|
|
|
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) |
|
|
def chunked_selective_log_softmax(logits, index): |
|
|
|
|
|
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) |
|
|
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) |
|
|
all_per_token_logps = [] |
|
|
|
|
|
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): |
|
|
chunk_logits = chunk_logits.to(torch.float32) |
|
|
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) |
|
|
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) |
|
|
per_token_logps = selected_logits - logsumexp_values |
|
|
all_per_token_logps.append(per_token_logps) |
|
|
pass |
|
|
all_per_token_logps = torch.concat(all_per_token_logps) |
|
|
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) |
|
|
return all_per_token_logps |
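
# Quick equivalence check (illustrative helper, not called anywhere in this
# module): the chunked computation should match a direct log_softmax + gather.
# Note the first call pays torch.compile's one-off compilation cost.
def _example_chunked_selective_log_softmax():
    logits = torch.randn(2, 8, 16)         # (batch, seq, vocab)
    index  = torch.randint(0, 16, (2, 8))  # token ids to select
    chunked = chunked_selective_log_softmax(logits, index)
    direct  = torch.gather(
        F.log_softmax(logits.float(), dim = -1), -1, index.unsqueeze(-1),
    ).squeeze(-1)
    assert torch.allclose(chunked, direct, atol = 1e-5)
    return chunked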
|
|
|
|
|
def calculate_pad_tokens_in_prompt( |
|
|
input_ids: torch.Tensor, |
|
|
logits_to_keep: int, |
|
|
pad_token_id: int |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
    Given a left-padded prompt tensor, count the left-padding tokens in each
    sequence. For example, [pad, pad, pad, cat] contains 3 pad tokens.
|
|
""" |
|
|
if logits_to_keep >= input_ids.shape[1]: |
|
|
raise ValueError("logits_to_keep must be smaller than the sequence length.") |
|
|
|
|
|
prompt_section = input_ids[:, :-logits_to_keep] |
|
|
|
|
|
padding_mask = (prompt_section == pad_token_id) |
|
|
|
|
|
pad_token_counts = padding_mask.sum(dim=1) |
|
|
|
|
|
return pad_token_counts |
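
# Worked example (illustrative helper, not called anywhere in this module):
# with pad id 0 and 2 completion positions, the prompt sections [0, 0, 5, 7]
# and [0, 6, 8, 9] contain 2 and 1 left-pad tokens respectively.
def _example_calculate_pad_tokens_in_prompt():
    input_ids = torch.tensor([[0, 0, 5, 7, 3, 4],
                              [0, 6, 8, 9, 2, 1]])
    counts = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep = 2, pad_token_id = 0)
    assert counts.tolist() == [2, 1]
    return counts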
|
|
|
|
|
def create_completion_attention_mask( |
|
|
completion_input_ids: torch.Tensor, |
|
|
left_pad_tokens_per_prompt: torch.Tensor, |
|
|
max_left_pad: int, |
|
|
pad_token_id: int |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
    Given a sequence such as [p, p, p, c, c, c, pad, pad, pad], where p are extra
    prompt tokens left over from slicing the tensor, c are completion tokens, and
    pad are padding tokens, build a completion mask that zeroes out the p and pad
    tokens. For this example the mask is [0, 0, 0, 1, 1, 1, 0, 0, 0].
""" |
|
|
batch_size, completion_len = completion_input_ids.shape |
|
|
device = completion_input_ids.device |
|
|
|
|
|
num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt |
|
|
|
|
|
indices = torch.arange(completion_len, device=device).unsqueeze(0) |
|
|
shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) |
|
|
|
|
|
non_padding_mask = (completion_input_ids != pad_token_id) |
|
|
|
|
|
final_mask = shift_mask & non_padding_mask |
|
|
|
|
|
return final_mask |
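
# Worked example (illustrative helper, not called anywhere in this module):
# two completion slices with pad id 0. Prompt 0 had 1 left-pad token and
# prompt 1 had 2, so with max_left_pad = 2 the first slice starts with one
# leftover prompt token that is masked out along with the trailing padding.
def _example_create_completion_attention_mask():
    completion_ids = torch.tensor([[7, 5, 5, 0],
                                   [5, 5, 0, 0]])
    mask = create_completion_attention_mask(
        completion_ids,
        left_pad_tokens_per_prompt = torch.tensor([1, 2]),
        max_left_pad = 2,
        pad_token_id = 0,
    )
    assert mask.tolist() == [[False, True, True, False],
                             [True, True, False, False]]
    return mask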
|
|
|
|
|
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: |
|
|
""" |
|
|
    Packs the non-padding tokens of each sequence to the left, moving all
    padding tokens to the right while preserving token order.
|
|
""" |
|
|
mask = (tensor != pad_id) |
|
|
|
|
|
sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) |
|
|
packed_tensor = torch.gather(tensor, 1, sorted_indices) |
|
|
return packed_tensor |
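
# Worked example (illustrative helper, not called anywhere in this module):
# with pad id 0, [[0, 0, 5, 7]] packs to [[5, 7, 0, 0]] and the interleaved
# row [9, 0, 8, 0] packs to [9, 8, 0, 0], preserving token order.
def _example_left_pack_padding():
    packed = left_pack_padding(torch.tensor([[0, 0, 5, 7],
                                             [9, 0, 8, 0]]), pad_id = 0)
    assert packed.tolist() == [[5, 7, 0, 0], [9, 8, 0, 0]]
    return packed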
|
|
@dataclass |
|
|
class UnslothAlignPropConfig(AlignPropConfig): |
|
|
""" |
|
|
|
|
|
Configuration class for the [`AlignPropTrainer`]. |
|
|
|
|
|
Using [`~transformers.HfArgumentParser`] we can turn this class into |
|
|
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the |
|
|
command line. |
|
|
|
|
|
Parameters: |
|
|
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`): |
|
|
Name of this experiment (defaults to the file name without the extension). |
|
|
run_name (`str`, *optional*, defaults to `""`): |
|
|
Name of this run. |
|
|
seed (`int`, *optional*, defaults to `0`): |
|
|
Random seed for reproducibility. |
|
|
log_with (`str` or `None`, *optional*, defaults to `None`): |
|
|
Log with either `"wandb"` or `"tensorboard"`. Check |
|
|
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details. |
|
|
log_image_freq (`int`, *optional*, defaults to `1`): |
|
|
Frequency for logging images. |
|
|
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): |
|
|
Keyword arguments for the tracker (e.g., `wandb_project`). |
|
|
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): |
|
|
Keyword arguments for the accelerator. |
|
|
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): |
|
|
Keyword arguments for the accelerator project config (e.g., `logging_dir`). |
|
|
tracker_project_name (`str`, *optional*, defaults to `"trl"`): |
|
|
Name of project to use for tracking. |
|
|
logdir (`str`, *optional*, defaults to `"logs"`): |
|
|
Top-level logging directory for checkpoint saving. |
|
|
num_epochs (`int`, *optional*, defaults to `100`): |
|
|
Number of epochs to train. |
|
|
save_freq (`int`, *optional*, defaults to `1`): |
|
|
Number of epochs between saving model checkpoints. |
|
|
num_checkpoint_limit (`int`, *optional*, defaults to `5`): |
|
|
Number of checkpoints to keep before overwriting old ones. |
|
|
mixed_precision (`str`, *optional*, defaults to `"fp16"`): |
|
|
Mixed precision training. |
|
|
allow_tf32 (`bool`, *optional*, defaults to `True`): |
|
|
Allow `tf32` on Ampere GPUs. |
|
|
resume_from (`str`, *optional*, defaults to `""`): |
|
|
Path to resume training from a checkpoint. |
|
|
sample_num_steps (`int`, *optional*, defaults to `50`): |
|
|
Number of sampler inference steps. |
|
|
sample_eta (`float`, *optional*, defaults to `1.0`): |
|
|
Eta parameter for the DDIM sampler. |
|
|
sample_guidance_scale (`float`, *optional*, defaults to `5.0`): |
|
|
Classifier-free guidance weight. |
|
|
train_batch_size (`int`, *optional*, defaults to `1`): |
|
|
Batch size for training. |
|
|
train_use_8bit_adam (`bool`, *optional*, defaults to `False`): |
|
|
Whether to use the 8bit Adam optimizer from `bitsandbytes`. |
|
|
train_learning_rate (`float`, *optional*, defaults to `1e-3`): |
|
|
Learning rate. |
|
|
train_adam_beta1 (`float`, *optional*, defaults to `0.9`): |
|
|
Beta1 for Adam optimizer. |
|
|
train_adam_beta2 (`float`, *optional*, defaults to `0.999`): |
|
|
Beta2 for Adam optimizer. |
|
|
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`): |
|
|
Weight decay for Adam optimizer. |
|
|
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`): |
|
|
Epsilon value for Adam optimizer. |
|
|
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`): |
|
|
Number of gradient accumulation steps. |
|
|
train_max_grad_norm (`float`, *optional*, defaults to `1.0`): |
|
|
Maximum gradient norm for gradient clipping. |
|
|
negative_prompts (`str` or `None`, *optional*, defaults to `None`): |
|
|
Comma-separated list of prompts to use as negative examples. |
|
|
truncated_backprop_rand (`bool`, *optional*, defaults to `True`): |
|
|
If `True`, randomized truncation to different diffusion timesteps is used. |
|
|
truncated_backprop_timestep (`int`, *optional*, defaults to `49`): |
|
|
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`. |
|
|
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`): |
|
|
Range of diffusion timesteps for randomized truncated backpropagation. |
|
|
push_to_hub (`bool`, *optional*, defaults to `False`): |
|
|
Whether to push the final model to the Hub. |
|
|
|
|
|
""" |
|
|
vllm_sampling_params: Optional[Any] = field( |
|
|
default = None, |
|
|
metadata = {'help': 'vLLM SamplingParams'}, |
|
|
) |
|
|
unsloth_num_chunks : Optional[int] = field( |
|
|
default = -1, |
|
|
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, |
|
|
) |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
exp_name = 'colab_kernel_launcher', |
|
|
run_name = '', |
|
|
seed = 3407, |
|
|
log_with = None, |
|
|
log_image_freq = 1, |
|
|
tracker_project_name = 'trl', |
|
|
logdir = 'logs', |
|
|
num_epochs = 100, |
|
|
save_freq = 1, |
|
|
num_checkpoint_limit = 5, |
|
|
mixed_precision = 'fp16', |
|
|
allow_tf32 = True, |
|
|
resume_from = '', |
|
|
sample_num_steps = 50, |
|
|
sample_eta = 1.0, |
|
|
sample_guidance_scale = 5.0, |
|
|
train_batch_size = 1, |
|
|
train_use_8bit_adam = False, |
|
|
train_learning_rate = 5e-05, |
|
|
train_adam_beta1 = 0.9, |
|
|
train_adam_beta2 = 0.999, |
|
|
train_adam_weight_decay = 0.01, |
|
|
train_adam_epsilon = 1e-08, |
|
|
train_gradient_accumulation_steps = 2, |
|
|
train_max_grad_norm = 1.0, |
|
|
negative_prompts = None, |
|
|
truncated_backprop_rand = True, |
|
|
        truncated_backprop_timestep = 49,
        truncated_rand_backprop_minmax = (0, 50),
        push_to_hub = False,
|
|
vllm_sampling_params = None, |
|
|
unsloth_num_chunks = -1, |
|
|
|
|
|
**kwargs, |
|
|
): |
|
|
        if train_learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{train_learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if train_learning_rate > 1: print(f'Unsloth: Your learning rate of `{train_learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
|
|
|
|
super().__init__( |
|
|
exp_name = exp_name, |
|
|
run_name = run_name, |
|
|
seed = seed, |
|
|
log_with = log_with, |
|
|
log_image_freq = log_image_freq, |
|
|
tracker_project_name = tracker_project_name, |
|
|
logdir = logdir, |
|
|
num_epochs = num_epochs, |
|
|
save_freq = save_freq, |
|
|
num_checkpoint_limit = num_checkpoint_limit, |
|
|
mixed_precision = mixed_precision, |
|
|
allow_tf32 = allow_tf32, |
|
|
resume_from = resume_from, |
|
|
sample_num_steps = sample_num_steps, |
|
|
sample_eta = sample_eta, |
|
|
sample_guidance_scale = sample_guidance_scale, |
|
|
train_batch_size = train_batch_size, |
|
|
train_use_8bit_adam = train_use_8bit_adam, |
|
|
train_learning_rate = train_learning_rate, |
|
|
train_adam_beta1 = train_adam_beta1, |
|
|
train_adam_beta2 = train_adam_beta2, |
|
|
train_adam_weight_decay = train_adam_weight_decay, |
|
|
train_adam_epsilon = train_adam_epsilon, |
|
|
train_gradient_accumulation_steps = train_gradient_accumulation_steps, |
|
|
train_max_grad_norm = train_max_grad_norm, |
|
|
negative_prompts = negative_prompts, |
|
|
truncated_backprop_rand = truncated_backprop_rand, |
|
|
            truncated_backprop_timestep = truncated_backprop_timestep,
            truncated_rand_backprop_minmax = truncated_rand_backprop_minmax,
            push_to_hub = push_to_hub,
            **kwargs,
        )
|
|
self.vllm_sampling_params = vllm_sampling_params |
|
|
self.unsloth_num_chunks = unsloth_num_chunks |
|
|
|
|
|
pass |
|
|
|
|
|
class _UnslothAlignPropTrainer(PyTorchModelHubMixin): |
|
|
"""""" |
|
|
|
|
|
_tag_names = ["trl", "alignprop"] |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: AlignPropConfig, |
|
|
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor], |
|
|
prompt_function: Callable[[], tuple[str, Any]], |
|
|
sd_pipeline: DDPOStableDiffusionPipeline, |
|
|
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, |
|
|
): |
|
|
warnings.warn( |
|
|
"AlignPropTrainer is deprecated and will be removed in version 0.23.0.", |
|
|
DeprecationWarning, |
|
|
) |
|
|
if image_samples_hook is None: |
|
|
logger.warning("No image_samples_hook provided; no images will be logged") |
|
|
|
|
|
self.prompt_fn = prompt_function |
|
|
self.reward_fn = reward_function |
|
|
self.config = config |
|
|
self.image_samples_callback = image_samples_hook |
|
|
|
|
|
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) |
|
|
|
|
|
if self.config.resume_from: |
|
|
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) |
|
|
if "checkpoint_" not in os.path.basename(self.config.resume_from): |
|
|
|
|
|
checkpoints = list( |
|
|
filter( |
|
|
lambda x: "checkpoint_" in x, |
|
|
os.listdir(self.config.resume_from), |
|
|
) |
|
|
) |
|
|
if len(checkpoints) == 0: |
|
|
raise ValueError(f"No checkpoints found in {self.config.resume_from}") |
|
|
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) |
|
|
self.config.resume_from = os.path.join( |
|
|
self.config.resume_from, |
|
|
f"checkpoint_{checkpoint_numbers[-1]}", |
|
|
) |
|
|
|
|
|
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 |
|
|
|
|
|
self.accelerator = Accelerator( |
|
|
log_with=self.config.log_with, |
|
|
mixed_precision=self.config.mixed_precision, |
|
|
project_config=accelerator_project_config, |
|
|
|
|
|
|
|
|
|
|
|
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps, |
|
|
**self.config.accelerator_kwargs, |
|
|
) |
|
|
|
|
|
        is_using_tensorboard = config.log_with == "tensorboard"
|
|
|
|
|
if self.accelerator.is_main_process: |
|
|
self.accelerator.init_trackers( |
|
|
self.config.tracker_project_name, |
|
|
config=dict(alignprop_trainer_config=config.to_dict()) |
|
|
if not is_using_tensorboard |
|
|
else config.to_dict(), |
|
|
init_kwargs=self.config.tracker_kwargs, |
|
|
) |
|
|
|
|
|
logger.info(f"\n{config}") |
|
|
|
|
|
set_seed(self.config.seed, device_specific=True) |
|
|
|
|
|
self.sd_pipeline = sd_pipeline |
|
|
|
|
|
self.sd_pipeline.set_progress_bar_config( |
|
|
position=1, |
|
|
disable=not self.accelerator.is_local_main_process, |
|
|
leave=False, |
|
|
desc="Timestep", |
|
|
dynamic_ncols=True, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.accelerator.mixed_precision == "fp16": |
|
|
inference_dtype = torch.float16 |
|
|
elif self.accelerator.mixed_precision == "bf16": |
|
|
inference_dtype = torch.bfloat16 |
|
|
else: |
|
|
inference_dtype = torch.float32 |
|
|
|
|
|
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) |
|
|
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) |
|
|
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) |
|
|
|
|
|
trainable_layers = self.sd_pipeline.get_trainable_layers() |
|
|
|
|
|
self.accelerator.register_save_state_pre_hook(self._save_model_hook) |
|
|
self.accelerator.register_load_state_pre_hook(self._load_model_hook) |
|
|
|
|
|
|
|
|
|
|
|
if self.config.allow_tf32 and torch.cuda.is_available(): |
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
|
|
|
self.optimizer = self._setup_optimizer( |
|
|
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers |
|
|
) |
|
|
|
|
|
self.neg_prompt_embed = self.sd_pipeline.text_encoder( |
|
|
self.sd_pipeline.tokenizer( |
|
|
[""] if self.config.negative_prompts is None else self.config.negative_prompts, |
|
|
return_tensors="pt", |
|
|
padding="max_length", |
|
|
truncation=True, |
|
|
max_length=self.sd_pipeline.tokenizer.model_max_length, |
|
|
).input_ids.to(self.accelerator.device) |
|
|
)[0] |
|
|
|
|
|
|
|
|
|
|
|
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast |
|
|
|
|
|
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora: |
|
|
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) |
|
|
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters())) |
|
|
else: |
|
|
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) |
|
|
|
|
|
if config.resume_from: |
|
|
logger.info(f"Resuming from {config.resume_from}") |
|
|
self.accelerator.load_state(config.resume_from) |
|
|
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 |
|
|
else: |
|
|
self.first_epoch = 0 |
|
|
|
|
|
def compute_rewards(self, prompt_image_pairs): |
|
|
reward, reward_metadata = self.reward_fn( |
|
|
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"] |
|
|
) |
|
|
return reward |
|
|
|
|
|
def step(self, epoch: int, global_step: int): |
|
|
""" |
|
|
Perform a single step of training. |
|
|
|
|
|
Args: |
|
|
epoch (int): The current epoch. |
|
|
global_step (int): The current global step. |
|
|
|
|
|
Side Effects: |
|
|
- Model weights are updated |
|
|
- Logs the statistics to the accelerator trackers. |
|
|
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, |
|
|
and the accelerator tracker. |
|
|
|
|
|
Returns: |
|
|
global_step (int): The updated global step. |
|
|
""" |
|
|
info = defaultdict(list) |
|
|
|
|
|
self.sd_pipeline.unet.train() |
|
|
|
|
|
for _ in range(self.config.train_gradient_accumulation_steps): |
|
|
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad(): |
|
|
prompt_image_pairs = self._generate_samples( |
|
|
batch_size=self.config.train_batch_size, |
|
|
) |
|
|
|
|
|
rewards = self.compute_rewards(prompt_image_pairs) |
|
|
|
|
|
prompt_image_pairs["rewards"] = rewards |
|
|
|
|
|
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy() |
|
|
|
|
|
loss = self.calculate_loss(rewards) |
|
|
|
|
|
self.accelerator.backward(loss) |
|
|
|
|
|
if self.accelerator.sync_gradients: |
|
|
self.accelerator.clip_grad_norm_( |
|
|
self.trainable_layers.parameters() |
|
|
if not isinstance(self.trainable_layers, list) |
|
|
else self.trainable_layers, |
|
|
self.config.train_max_grad_norm, |
|
|
) |
|
|
|
|
|
self.optimizer.step() |
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
info["reward_mean"].append(rewards_vis.mean()) |
|
|
info["reward_std"].append(rewards_vis.std()) |
|
|
info["loss"].append(loss.item()) |
|
|
|
|
|
|
|
|
if self.accelerator.sync_gradients: |
|
|
|
|
|
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()} |
|
|
info = self.accelerator.reduce(info, reduction="mean") |
|
|
info.update({"epoch": epoch}) |
|
|
self.accelerator.log(info, step=global_step) |
|
|
global_step += 1 |
|
|
info = defaultdict(list) |
|
|
else: |
|
|
raise ValueError( |
|
|
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." |
|
|
) |
|
|
|
|
|
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0: |
|
|
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0]) |
|
|
|
|
|
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: |
|
|
self.accelerator.save_state() |
|
|
|
|
|
return global_step |
|
|
|
|
|
def calculate_loss(self, rewards): |
|
|
""" |
|
|
        Calculate the loss for a batch of unpacked samples.
|
|
|
|
|
Args: |
|
|
rewards (torch.Tensor): |
|
|
Differentiable reward scalars for each generated image, shape: [batch_size] |
|
|
|
|
|
Returns: |
|
|
            loss (`torch.Tensor`): Scalar loss, shape `(1,)`.
|
|
""" |
|
|
|
|
|
        # The constant offset keeps the logged loss positive; it does not affect gradients.
        loss = 10.0 - rewards.mean()
|
|
return loss |
|
|
|
|
|
def loss( |
|
|
self, |
|
|
advantages: torch.Tensor, |
|
|
clip_range: float, |
|
|
ratio: torch.Tensor, |
|
|
): |
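        # PPO-style clipped surrogate: the mean of max(-A * r, -A * clip(r, 1 - eps, 1 + eps))
        # over the batch, which is equivalent to maximizing
        # E[min(A * r, A * clip(r, 1 - eps, 1 + eps))] for ratio r and advantage A.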
|
|
unclipped_loss = -advantages * ratio |
|
|
clipped_loss = -advantages * torch.clamp( |
|
|
ratio, |
|
|
1.0 - clip_range, |
|
|
1.0 + clip_range, |
|
|
) |
|
|
return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) |
|
|
|
|
|
def _setup_optimizer(self, trainable_layers_parameters): |
|
|
if self.config.train_use_8bit_adam: |
|
|
import bitsandbytes |
|
|
|
|
|
optimizer_cls = bitsandbytes.optim.AdamW8bit |
|
|
else: |
|
|
optimizer_cls = torch.optim.AdamW |
|
|
|
|
|
return optimizer_cls( |
|
|
trainable_layers_parameters, |
|
|
lr=self.config.train_learning_rate, |
|
|
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), |
|
|
weight_decay=self.config.train_adam_weight_decay, |
|
|
eps=self.config.train_adam_epsilon, |
|
|
) |
|
|
|
|
|
def _save_model_hook(self, models, weights, output_dir): |
|
|
self.sd_pipeline.save_checkpoint(models, weights, output_dir) |
|
|
weights.pop() |
|
|
|
|
|
def _load_model_hook(self, models, input_dir): |
|
|
self.sd_pipeline.load_checkpoint(models, input_dir) |
|
|
models.pop() |
|
|
|
|
|
def _generate_samples(self, batch_size, with_grad=True, prompts=None): |
|
|
""" |
|
|
Generate samples from the model |
|
|
|
|
|
Args: |
|
|
batch_size (int): Batch size to use for sampling |
|
|
            with_grad (bool): Whether the generated RGBs should have gradients attached to them.
|
|
prompts (list[str], *optional*): If provided, use these prompts instead of generating new ones. |
|
|
|
|
|
Returns: |
|
|
prompt_image_pairs (dict[Any]) |
|
|
""" |
|
|
prompt_image_pairs = {} |
|
|
|
|
|
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) |
|
|
|
|
|
if prompts is None: |
|
|
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) |
|
|
else: |
|
|
prompt_metadata = [{} for _ in range(batch_size)] |
|
|
|
|
|
prompt_ids = self.sd_pipeline.tokenizer( |
|
|
prompts, |
|
|
return_tensors="pt", |
|
|
padding="max_length", |
|
|
truncation=True, |
|
|
max_length=self.sd_pipeline.tokenizer.model_max_length, |
|
|
).input_ids.to(self.accelerator.device) |
|
|
|
|
|
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] |
|
|
|
|
|
if with_grad: |
|
|
sd_output = self.sd_pipeline.rgb_with_grad( |
|
|
prompt_embeds=prompt_embeds, |
|
|
negative_prompt_embeds=sample_neg_prompt_embeds, |
|
|
num_inference_steps=self.config.sample_num_steps, |
|
|
guidance_scale=self.config.sample_guidance_scale, |
|
|
eta=self.config.sample_eta, |
|
|
truncated_backprop_rand=self.config.truncated_backprop_rand, |
|
|
truncated_backprop_timestep=self.config.truncated_backprop_timestep, |
|
|
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax, |
|
|
output_type="pt", |
|
|
) |
|
|
else: |
|
|
sd_output = self.sd_pipeline( |
|
|
prompt_embeds=prompt_embeds, |
|
|
negative_prompt_embeds=sample_neg_prompt_embeds, |
|
|
num_inference_steps=self.config.sample_num_steps, |
|
|
guidance_scale=self.config.sample_guidance_scale, |
|
|
eta=self.config.sample_eta, |
|
|
output_type="pt", |
|
|
) |
|
|
|
|
|
images = sd_output.images |
|
|
|
|
|
prompt_image_pairs["images"] = images |
|
|
prompt_image_pairs["prompts"] = prompts |
|
|
prompt_image_pairs["prompt_metadata"] = prompt_metadata |
|
|
|
|
|
return prompt_image_pairs |
|
|
|
|
|
def train(self, epochs: Optional[int] = None): |
|
|
""" |
|
|
Train the model for a given number of epochs |
|
|
""" |
|
|
global_step = 0 |
|
|
if epochs is None: |
|
|
epochs = self.config.num_epochs |
|
|
for epoch in range(self.first_epoch, epochs): |
|
|
global_step = self.step(epoch, global_step) |
|
|
|
|
|
def _save_pretrained(self, save_directory): |
|
|
self.sd_pipeline.save_pretrained(save_directory) |
|
|
self.create_model_card() |
|
|
|
|
|
|
|
|
def _save_checkpoint(self, model, trial): |
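        # Assumes a HF `Trainer`-style parent class that provides `self.args`
        # and `_save_checkpoint`; `PyTorchModelHubMixin` alone does not.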
|
|
if self.args.hub_model_id is None: |
|
|
model_name = Path(self.args.output_dir).name |
|
|
else: |
|
|
model_name = self.args.hub_model_id.split("/")[-1] |
|
|
self.create_model_card(model_name=model_name) |
|
|
super()._save_checkpoint(model, trial) |
|
|
|
|
|
def create_model_card( |
|
|
self, |
|
|
model_name: Optional[str] = None, |
|
|
dataset_name: Optional[str] = None, |
|
|
tags: Union[str, list[str], None] = None, |
|
|
): |
|
|
""" |
|
|
Creates a draft of a model card using the information available to the `Trainer`. |
|
|
|
|
|
Args: |
|
|
model_name (`str` or `None`, *optional*, defaults to `None`): |
|
|
Name of the model. |
|
|
dataset_name (`str` or `None`, *optional*, defaults to `None`): |
|
|
Name of the dataset used for training. |
|
|
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): |
|
|
Tags to be associated with the model card. |
|
|
""" |
|
|
if not self.is_world_process_zero(): |
|
|
return |
|
|
|
|
|
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): |
|
|
base_model = self.model.config._name_or_path |
|
|
else: |
|
|
base_model = None |
|
|
|
|
|
|
|
|
if tags is None: |
|
|
tags = set() |
|
|
elif isinstance(tags, str): |
|
|
tags = {tags} |
|
|
else: |
|
|
tags = set(tags) |
|
|
|
|
|
if hasattr(self.model.config, "unsloth_version"): |
|
|
tags.add("unsloth") |
|
|
|
|
|
if "JOB_ID" in os.environ: |
|
|
tags.add("hf_jobs") |
|
|
|
|
|
tags.update(self._tag_names) |
|
|
|
|
|
|
|
|
citation = textwrap.dedent("""\ |
|
|
@article{prabhudesai2024aligning, |
|
|
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}}, |
|
|
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki}, |
|
|
year = 2024, |
|
|
eprint = {arXiv:2310.03739} |
|
|
}""") |
|
|
|
|
|
model_card = generate_model_card( |
|
|
base_model=base_model, |
|
|
model_name=model_name, |
|
|
hub_model_id=self.hub_model_id, |
|
|
dataset_name=dataset_name, |
|
|
tags=tags, |
|
|
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, |
|
|
comet_url=get_comet_experiment_url(), |
|
|
trainer_name="AlignProp", |
|
|
trainer_citation=citation, |
|
|
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation", |
|
|
paper_id="2310.03739", |
|
|
) |
|
|
|
|
|
model_card.save(os.path.join(self.args.output_dir, "README.md")) |
|
|
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer): |
|
|
""" |
|
|
|
|
|
    The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is
    heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/. As of now, only Stable Diffusion-based
    pipelines are supported.
|
|
|
|
|
Args: |
|
|
config (`AlignPropConfig`): |
|
|
            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
|
|
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`): |
|
|
Reward function to be used |
|
|
prompt_function (`Callable[[], tuple[str, Any]]`): |
|
|
Function to generate prompts to guide model |
|
|
sd_pipeline (`DDPOStableDiffusionPipeline`): |
|
|
Stable Diffusion pipeline to be used for training. |
|
|
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`): |
|
|
Hook to be called to log images |
|
|
|
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
config, |
|
|
reward_function, |
|
|
prompt_function, |
|
|
sd_pipeline, |
|
|
image_samples_hook = None, |
|
|
**kwargs |
|
|
): |
|
|
        if config is None: config = UnslothAlignPropConfig()
|
|
other_metrics = [] |
|
|
|
|
|
from unsloth_zoo.logging_utils import PatchRLStatistics |
|
|
PatchRLStatistics('alignprop_trainer', other_metrics) |
|
|
|
|
|
|
|
|
|
|
|
        if getattr(config, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and getattr(config, "n_gpu", 1) > 1:
            if getattr(config, "_n_gpu", 1) != 1:
                config._n_gpu = 1
|
|
if "model" in locals() and hasattr(model, "for_training"): |
|
|
model.for_training() |
|
|
super().__init__( |
|
|
config = config, |
|
|
reward_function = reward_function, |
|
|
prompt_function = prompt_function, |
|
|
sd_pipeline = sd_pipeline, |
|
|
            image_samples_hook = image_samples_hook,
            **kwargs,
        )
|
|
if "model" in locals() and hasattr(model, "for_inference"): |
|
|
model.for_inference() |
|
|
|
|
|
pass |
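
# Minimal usage sketch (assumes a TRL AlignProp-style setup; `reward_fn`,
# `prompt_fn`, and `pipeline` are placeholders the caller must provide):
#
#     config  = UnslothAlignPropConfig(num_epochs = 10, train_batch_size = 2)
#     trainer = UnslothAlignPropTrainer(
#         config = config,
#         reward_function = reward_fn,
#         prompt_function = prompt_fn,
#         sd_pipeline = pipeline,
#     )
#     trainer.train()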
|
|
|
|
|
|
|
|
if hasattr(logger, "addFilter"): |
|
|
import logging |
|
|
    class HideLoggingMessage(logging.Filter):
        def __init__(self, text):
            super().__init__()
            self.text = text
        def filter(self, record):
            return self.text not in record.getMessage()
    pass
|
|
logger.addFilter(HideLoggingMessage("`use_cache=True`")) |
|
|
|
|
|
|