from __future__ import annotations

import warnings
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
from peft.utils.integrations import check_deepspeed_zero3_enabled, gather_params_ctx


class TrainableTokensLayer(nn.Module, BaseTunerLayer):
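    """
    Wraps an embedding layer (or, via tied adapters, a weight-tied `nn.Linear` such as an LM head) and makes
    only the embedding rows listed in `token_indices` trainable. For each adapter it stores a trainable
    replacement row per targeted token plus a copy of the original rows, so that merging into the base weight
    can be undone.

    Minimal usage sketch (hypothetical names such as `model.embed_tokens` and `input_ids`; in practice this
    layer is created by PEFT when a config is applied rather than instantiated by hand):

        >>> layer = TrainableTokensLayer(model.embed_tokens, "default", token_indices=[3, 5, 7])
        >>> layer.update_layer("default", token_indices=[3, 5, 7], init_weights=True)
        >>> output = layer(input_ids)  # rows 3, 5 and 7 come from the trainable delta rows
    """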

    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names = ("trainable_tokens_delta",)
    # All names of other parameters that may contain adapter-related parameters
    other_param_names = ("token_indices", "trainable_tokens_original")

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        token_indices: list[int],
        tied_adapter: Optional[TrainableTokensLayer] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.base_layer = base_layer
        self._active_adapter = adapter_name
        self.kwargs = kwargs
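
        # Keep a reference to the adapter this layer is tied to (typically the embedding-side adapter when this
        # instance wraps a weight-tied LM head). Wrapping it in a list keeps it from being registered as a
        # submodule, which would duplicate its parameters in this layer's state dict.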
        self._tied_adapter = [tied_adapter] if tied_adapter else []

        if not self.tied_adapter:
            self.trainable_tokens_delta = nn.ParameterDict({})
            self.trainable_tokens_original = BufferDict({})
            self.token_indices = {}
        else:
            self.trainable_tokens_delta = self.tied_adapter.trainable_tokens_delta
            self.trainable_tokens_original = self.tied_adapter.trainable_tokens_original
            self.token_indices = self.tied_adapter.token_indices

        # Mark the weight as unmerged
        self.merged_adapters = []

        in_features, out_features = _get_in_out_features(self.get_base_layer())
        self.in_features = in_features
        self.out_features = out_features

    @property
    def tied_adapter(self):
        if self._tied_adapter:
            return self._tied_adapter[0]
        return None

    def _collect_token_weights(self, weight: torch.Tensor, rows: torch.Tensor, embed_dim: int) -> torch.Tensor:
        """DeepSpeed ZeRO-3 specific code to initialize trainable tokens.

        Ensures that only the necessary weights are collected on a single rank, initialized, and then shared with all
        ranks.
        """
        src_rank = 0
        device = torch.device("cuda", torch.cuda.current_device())

        with gather_params_ctx([weight], modifier_rank=None):
            if dist.get_rank() == src_rank:
                token_weights = weight[rows].clone()
            else:
                # Non-source ranks only allocate a receive buffer; the values arrive via the broadcast below.
                token_weights = torch.empty(
                    (len(rows), embed_dim),
                    dtype=weight.dtype,
                    device=device,
                )

        dist.broadcast(token_weights, src=src_rank)
        return token_weights

    def update_layer(self, adapter_name, **kwargs):
        if kwargs.get("tied_adapter", None):
            # Tied adapters share the parameters of the adapter they are tied to, so there is nothing to set up here.
            return

        self.token_indices[adapter_name] = kwargs["token_indices"]
        init_weights = kwargs.get("init_weights", True)
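
        # Note that the per-adapter "delta" weights are full replacement rows for the targeted token indices
        # rather than additive offsets: at merge/forward time they overwrite the corresponding rows of the base
        # weight, and a copy of the original rows is kept so that unmerge() can restore them.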
        weight = self.get_base_layer().weight
        embed_dim = self.get_base_layer().embedding_dim

        if init_weights:
            if check_deepspeed_zero3_enabled():
                values = self._collect_token_weights(weight, self.token_indices[adapter_name], embed_dim)
            else:
                values = weight[self.token_indices[adapter_name]]
        else:
            # Random initialization of the targeted rows (init_weights=False), mainly useful for testing.
            values = torch.randn(
                (len(self.token_indices[adapter_name]), embed_dim),
                dtype=weight.dtype,
                device=weight.device,
            )

        self.trainable_tokens_delta[adapter_name] = nn.Parameter(values.clone(), requires_grad=True)
        self.trainable_tokens_original[adapter_name] = values.clone()

        self._move_adapter_to_device_of_base_layer(adapter_name)

    def _check_overlapping_tokens(self, adapter_names):
        """Raises an error if the token indices of the given adapter names are overlapping.

        This is currently not supported and can lead to undefined behavior of the model if no specific merging between
        the overlapping indices' values is applied.
        """
        if len(adapter_names) <= 1:
            return

        indices = set()

        # Also account for adapters that are already merged, since their rows are baked into the base weight.
        for adapter_name in set(adapter_names + self.merged_adapters):
            index_set = set(self.token_indices[adapter_name])
            if indices.intersection(index_set):
                raise ValueError(
                    f"Token indices of adapter {adapter_name} are already defined and would result in "
                    "undefined merging behavior. Only disjoint token indices are currently supported."
                )
            indices.update(index_set)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # No adapter to merge.
            return

        self._check_overlapping_tokens(adapter_names)

        merged = self.base_layer.weight.data
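
        # index_copy replaces the targeted rows with the adapter's rows (it does not add to them), matching how
        # the delta values are initialized in update_layer.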
        for adapter_name in adapter_names:
            index = torch.tensor(self.token_indices[adapter_name]).to(merged.device)
            deltas = self.trainable_tokens_delta[adapter_name].to(merged)
            merged = merged.index_copy(dim=0, index=index, source=deltas)

            if safe_merge and not torch.isfinite(merged).all():
                raise ValueError(f"NaNs detected in the merged weights. The adapter {adapter_name} seems to be broken")

        self.base_layer.weight.data = merged
        self.merged_adapters.extend(adapter_names)

    def unmerge(self) -> None:
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            adapter_name = self.merged_adapters.pop()

            index = torch.tensor(self.token_indices[adapter_name]).to(self.base_layer.weight.device)
            originals = self.trainable_tokens_original[adapter_name].to(self.base_layer.weight)
            self.base_layer.weight.data.index_copy_(dim=0, index=index, source=originals)

    def get_merged_weights(self, active_adapters):
        """Returns the base weight with the rows of all active adapters replaced by their delta rows (out of place)."""
        W = self.base_layer.weight

        for adapter_name in active_adapters:
            index = torch.tensor(self.token_indices[adapter_name]).to(W.device)
            deltas = self.trainable_tokens_delta[adapter_name].to(W)
            W = W.index_copy(dim=0, index=index, source=deltas)

        return W

    def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor:
        if self.disable_adapters or not active_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            self._check_overlapping_tokens(active_adapters)

            W = self.get_merged_weights(active_adapters)
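
            # Instead of mutating the base weight, the merged matrix W (built out of place above) is passed
            # directly to the functional call, so gradients keep flowing into the per-adapter delta rows.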
            if isinstance(self.base_layer, torch.nn.Embedding):
                result = F.embedding(
                    input=x,
                    weight=W,
                    padding_idx=self.base_layer.padding_idx,
                    max_norm=self.base_layer.max_norm,
                    norm_type=self.base_layer.norm_type,
                    scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
                    sparse=self.base_layer.sparse,
                )
            elif isinstance(self.base_layer, torch.nn.Linear):
                # Most likely a weight-tied layer, e.g. an LM head sharing its weight with the input embedding
                # (handled via tied adapters).
                result = F.linear(
                    input=x,
                    weight=W,
                )
            else:
                raise ValueError(
                    "TrainableTokensLayer wraps an unknown layer type, maybe you are targeting the wrong layer?"
                )

        return result

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.forward_adapters(x, self.active_adapters, *args, **kwargs)