# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import operator
from contextlib import contextmanager
from functools import partial

from torch import nn

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.road.config import RoadConfig
from peft.tuners.tuners_utils import (
    BaseTuner,
)
from peft.utils import TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING

from .layer import RoadLayer, dispatch_default


def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # Pre-forward hook that injects the adapter_names argument when running mixed adapter batch inference.
    kwargs["adapter_names"] = adapter_names
    return args, kwargs


class RoadModel(BaseTuner):
    """
    Creates a RoAd model from a pretrained transformers model.

    The adapter is described by a `RoadConfig`; matching target modules are replaced with (or extended by)
    `RoadLayer` modules holding the trainable adapter parameters.
    """

    prefix: str = "road_"
    tuner_layer_cls = RoadLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING
    def _create_and_replace(
        self,
        road_config: RoadConfig,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key,
    ) -> None:
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")

        variant = road_config.variant
        group_size = road_config.group_size

        kwargs = {
            "variant": variant,
            "group_size": group_size,
            "init_weights": road_config.init_weights,
            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
        }

        # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
        try:
            kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
                "hf_quantizer.quantization_config.get_apply_tensor_subclass"
            )(self.model)
        except AttributeError:
            pass
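
        # If the target is already a RoadLayer, just register the new adapter on it. Otherwise, create a
        # replacement module via the dispatchers and swap it into the parent module in place of the target.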
        if isinstance(target, RoadLayer):
            target.update_layer(
                adapter_name,
                variant,
                group_size,
                init_weights=road_config.init_weights,
            )
        else:
            device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
            new_module = self._create_new_module(road_config, adapter_name, target, device_map=device_map, **kwargs)
            if adapter_name not in self.active_adapters:
                # adding an additional adapter: it is not automatically trainable
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)

    @staticmethod
    def _create_new_module(road_config: RoadConfig, adapter_name, target, **kwargs):
        dispatchers = []

        # avoid eager bnb import
        if is_bnb_available():
            from .bnb import dispatch_bnb_8bit

            dispatchers.append(dispatch_bnb_8bit)

        if is_bnb_4bit_available():
            from .bnb import dispatch_bnb_4bit

            dispatchers.append(dispatch_bnb_4bit)

        dispatchers.extend(
            [
                dispatch_default,
            ]
        )
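
        # Each dispatcher inspects the type of `target` (e.g. a bitsandbytes 8-bit/4-bit linear layer or a plain
        # nn.Linear) and returns a RoAd module wrapping it, or None if it does not handle that layer type.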
        new_module = None
        for dispatcher in dispatchers:
            new_module = dispatcher(target, adapter_name, road_config=road_config, **kwargs)
            if new_module is not None:  # first match wins
                break

        if new_module is None:
            # no module could be matched
            raise ValueError(
                f"Target module {target} is not supported. Currently, only the following modules are supported: "
                "`torch.nn.Linear`."
            )

        return new_module

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        # If adapter_names is passed as an argument, we inject it into the forward arguments.
        adapter_names = kwargs.pop("adapter_names", None)
        if adapter_names is None:
            # nothing to do
            yield
            return

        if self.training:
            raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")

        # Check that users only passed actually existing adapters.
        # Note: We cannot do this on the layer level, as each individual layer may not have every adapter. Still, we
        # want to check that each requested adapter exists on at least one layer, or else typos can easily slip
        # through.
        expected_adapters = set()
        for layer in self.modules():
            if isinstance(layer, RoadLayer):
                expected_adapters |= layer.road_theta.keys()
        unique_adapters = {name for name in adapter_names if name != "__base__"}
        unexpected_adapters = unique_adapters - expected_adapters
        if unexpected_adapters:
            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
        hook_handles = []
        for module in self.modules():
            if isinstance(module, RoadLayer):
                pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                hook_handles.append(handle)

        # TODO LoRA also has hooks for beam search, ignore this for now
        yield

        for handle in hook_handles:
            handle.remove()
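

# Usage sketch (illustrative only, not part of this module): with the forward hooks above, a batch can mix
# adapters at inference time by passing `adapter_names`, one entry per sample ("__base__" selects the base
# model). The adapter names and paths below are placeholders.
#
#     from peft import PeftModel
#
#     model = PeftModel.from_pretrained(base_model, "path/to/road-adapter-a", adapter_name="adapter_a")
#     model.load_adapter("path/to/road-adapter-b", adapter_name="adapter_b")
#     outputs = model(**inputs, adapter_names=["adapter_a", "adapter_b", "__base__"])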