import functools
import math
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

from .parallel import ParallelBackendEnum
from .utils.import_utils import is_bitsandbytes_available


class OptimizerWrapper(Stateful):
    r"""
    Optimizer wrapper that:
    - allows step/zero_grad on multiple optimizers, as needed for virtual pipeline stages
    - saves/loads the optimizer state_dict at checkpointing time
    """

    def __init__(
        self,
        model_parts: List[torch.nn.Module],
        optimizer_cls: Type[torch.optim.Optimizer],
        optimizer_kwargs: Dict[str, Any],
    ) -> None:
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs

        self.optimizers = []
        self.model_parts = model_parts

        # One optimizer per model part (e.g. per virtual pipeline stage).
        for model in self.model_parts:
            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
            self.optimizers.append(optimizer)

    def step(self) -> None:
        for optimizer in self.optimizers:
            optimizer.step()

    def zero_grad(self) -> None:
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def state_dict(self) -> Dict[str, Any]:
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))

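
# Illustrative sketch (not executed): because OptimizerWrapper implements the `Stateful`
# protocol, it can be passed directly to torch.distributed.checkpoint, which calls
# `state_dict()` / `load_state_dict()` on it. The checkpoint path and the `wrapper`
# object below are hypothetical.
#
#     import torch.distributed.checkpoint as dcp
#
#     dcp.save({"optimizer": wrapper}, checkpoint_id="/tmp/ckpt")  # save
#     dcp.load({"optimizer": wrapper}, checkpoint_id="/tmp/ckpt")  # restore in place
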

class SchedulerWrapper:
    def __init__(
        self, optimizers, scheduler_lambda_fn: Callable[[int], float], last_epoch: int
    ) -> None:
        self.schedulers = []
        for optimizer in optimizers:
            self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def get_last_lr(self) -> Dict[str, List[float]]:
        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        state_dict = {}
        if len(self.schedulers) == 1:
            state_dict["lr_scheduler"] = self.schedulers[0]
        else:
            # When there are multiple model parts (virtual pipeline stages), store one
            # entry per scheduler so each can be restored independently.
            for idx, lr_scheduler in enumerate(self.schedulers):
                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
        return state_dict


def get_optimizer(
    parallel_backend: ParallelBackendEnum,
    name: str,
    model_parts: List[torch.nn.Module],
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.999,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    fused: bool = False,
) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
    name = name.lower()

    _raise_errors_if_packages_not_available(name)

    # Note: beta3 is accepted for API compatibility but is not used by the Adam/AdamW
    # variants configured below.
    if name == "adam":
        optimizer_cls = torch.optim.Adam
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
            "fused": fused,
        }
    elif name == "adamw":
        optimizer_cls = torch.optim.AdamW
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
            "fused": fused,
        }
    elif name == "adam-bnb":
        from bitsandbytes.optim import Adam

        optimizer_cls = Adam
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adamw-bnb":
        from bitsandbytes.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adam-bnb-8bit":
        from bitsandbytes.optim import Adam8bit

        optimizer_cls = Adam8bit
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adamw-bnb-8bit":
        from bitsandbytes.optim import AdamW8bit

        optimizer_cls = AdamW8bit
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    else:
        raise ValueError(f"Unsupported optimizer: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
    elif parallel_backend == ParallelBackendEnum.PTD:
        return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)


def get_optimizer_accelerate(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> torch.optim.Optimizer:
    params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
    optimizer = optimizer_cls(params, **optimizer_kwargs)
    return optimizer


def get_optimizer_ptd(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> OptimizerWrapper:
    return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)


def get_lr_scheduler(
    parallel_backend: ParallelBackendEnum,
    name: str,
    optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    lr_init: float = 1e-3,
    lr_end: float = 1e-7,
    last_epoch: int = -1,
) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
    name = name.lower()

    if name == "constant":
        scheduler_lambda_fn = get_constant_schedule()
    elif name == "constant_with_warmup":
        scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
    elif name == "piecewise_constant":
        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
    elif name == "linear":
        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
    elif name == "cosine":
        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
    elif name == "cosine_with_restarts":
        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
            num_warmup_steps, num_training_steps, num_cycles
        )
    elif name == "polynomial":
        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
            num_warmup_steps, num_training_steps, lr_init, lr_end, power
        )
    else:
        raise ValueError(f"Unsupported scheduler: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
    elif parallel_backend == ParallelBackendEnum.PTD:
        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)


def get_lr_scheduler_accelerate(
    optimizer: torch.optim.Optimizer,
    scheduler_lambda_fn: Callable[[int], float],
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
    return scheduler


def get_lr_scheduler_ptd(
    optimizer: OptimizerWrapper, scheduler_lambda_fn: Callable[[int], float], last_epoch: int = -1
) -> SchedulerWrapper:
    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)

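
# Illustrative sketch (not executed): a trainer would typically construct the optimizer
# and scheduler through the two entry points above. The `backend`, `model`, and step
# counts below are hypothetical placeholders.
#
#     backend = ParallelBackendEnum.PTD
#     optimizer = get_optimizer(backend, "adamw", [model], learning_rate=1e-4, fused=True)
#     lr_scheduler = get_lr_scheduler(
#         backend, "cosine", optimizer, num_warmup_steps=500, num_training_steps=10_000
#     )
#     # per training step:
#     #     optimizer.step(); lr_scheduler.step(); optimizer.zero_grad()
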

def get_constant_schedule() -> Callable[[int], float]:
    r"""
    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
    """

    def lr_lambda(current_step: int):
        return 1.0

    return lr_lambda


def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
    r"""
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return lr_lambda


def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
    r"""
    Create a schedule with a piecewise constant learning rate, using the learning rate set in the optimizer.

    Args:
        step_rules (`string`):
            The rules for the learning rate multiplier. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"` means
            the learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by
            0.005 for all remaining steps.
    """

    rules_dict = {}
    rule_list = step_rules.split(",")
    # Every rule except the last is "multiplier:step_threshold"; the final entry is the
    # multiplier applied once all thresholds have been passed.
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        rules_dict[steps] = value
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(rules_dict, last_lr_multiple):
        def rule_func(steps: int) -> float:
            sorted_steps = sorted(rules_dict.keys())
            for i, sorted_step in enumerate(sorted_steps):
                if steps < sorted_step:
                    return rules_dict[sorted_steps[i]]
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)
    return rules_func

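
# Illustrative sketch (not executed): what the rule above evaluates to for the example
# string from the docstring.
#
#     rule = get_piecewise_constant_schedule("1:10,0.1:20,0.01:30,0.005")
#     rule(0)   # -> 1.0    (before step 10)
#     rule(15)  # -> 0.1    (between steps 10 and 20)
#     rule(25)  # -> 0.01   (between steps 20 and 30)
#     rule(99)  # -> 0.005  (after step 30)
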

def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
    r"""
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return lr_lambda


def get_cosine_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> Callable[[int], float]:
    r"""
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in the schedule (the default is to just decrease from the
            max value to 0 following a half-cosine).
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return lr_lambda

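
# Illustrative sketch (not executed): with the default half-cosine (num_cycles=0.5) the
# multiplier rises linearly to 1.0 over the warmup phase and then decays to 0.0 at
# `num_training_steps`. The step counts below are hypothetical.
#
#     fn = get_cosine_schedule_with_warmup(num_warmup_steps=100, num_training_steps=1000)
#     fn(0)     # -> 0.0  (start of warmup)
#     fn(100)   # -> 1.0  (end of warmup)
#     fn(1000)  # -> 0.0  (end of training)
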

def get_cosine_with_hard_restarts_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
) -> Callable[[int], float]:
    r"""
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, with several hard restarts, after a warmup period during which it
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return lr_lambda


def get_polynomial_decay_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    lr_init: float,
    lr_end: float = 1e-7,
    power: float = 1.0,
) -> Callable[[int], float]:
    r"""
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to the end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to
    the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_init (`float`):
            The initial learning rate; it should match the learning rate set in the optimizer, since the returned
            multiplier is normalized by it.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
    """

    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # LambdaLR multiplies the optimizer lr by this value
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # LambdaLR multiplies the optimizer lr by this value

    return lr_lambda

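
# Illustrative sketch (not executed): the lambda returns a multiplier for LambdaLR, so the
# decayed value is divided by `lr_init`; the absolute `lr_end` semantics hold when
# `lr_init` matches the learning rate set in the optimizer. The numbers below are
# hypothetical.
#
#     fn = get_polynomial_decay_schedule_with_warmup(
#         num_warmup_steps=0, num_training_steps=1000, lr_init=1e-3, lr_end=1e-7, power=1.0
#     )
#     fn(0)     # -> 1.0                 (multiplier; effective lr = 1e-3)
#     fn(1000)  # -> 1e-7 / 1e-3 = 1e-4  (effective lr = 1e-7)
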

def _raise_errors_if_packages_not_available(name: str) -> None:
    name_split = name.split("-")
    if len(name_split) < 2:
        return
    package_name = name_split[1]
    if package_name == "bnb":
        if not is_bitsandbytes_available():
            raise ImportError(
                f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
            )