|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import operator |
|
|
from contextlib import contextmanager |
|
|
from functools import partial |
|
|
|
|
|
from torch import nn |
|
|
|
|
|
from peft.import_utils import is_bnb_4bit_available, is_bnb_available |
|
|
from peft.tuners.road.config import RoadConfig |
|
|
from peft.tuners.tuners_utils import ( |
|
|
BaseTuner, |
|
|
) |
|
|
from peft.utils import TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING |
|
|
|
|
|
from .layer import RoadLayer, dispatch_default |
|
|
|
|
|
|
|
|
def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): |
|
|
|
|
|
kwargs["adapter_names"] = adapter_names |
|
|
return args, kwargs |
|
|
|
|
|
|
|
|
class RoadModel(BaseTuner):
    """Tuner that replaces targeted modules of a base model with RoAd adapter layers.

    Class attributes:
        prefix: Name prefix (``"road_"``) associated with this tuner's adapter parameters.
        tuner_layer_cls: Layer class (``RoadLayer``) used to recognise already-wrapped modules.
        target_module_mapping: Mapping used to look up target modules for known transformers
            architectures.
    """

    prefix: str = "road_"
    tuner_layer_cls = RoadLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING

    def _create_and_replace(
        self,
        road_config: RoadConfig,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key,
    ) -> None:
        """Add the adapter ``adapter_name`` to ``target``, wrapping it in a RoAd layer if needed.

        If ``target`` already is a ``RoadLayer``, the adapter is added to it in place via
        ``update_layer``. Otherwise a new module is built with ``_create_new_module`` and
        swapped in for ``target`` on ``parent`` under ``target_name``.

        Args:
            road_config: Configuration providing ``variant``, ``group_size`` and ``init_weights``.
            adapter_name: Name of the adapter being added.
            target: The module to adapt.
            target_name: Attribute name of ``target`` on ``parent``.
            parent: The module that owns ``target``.
            current_key: Key of the target module; must not be ``None``.

        Raises:
            ValueError: If ``current_key`` is ``None``.
        """
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")

        variant = road_config.variant
        group_size = road_config.group_size

        kwargs = {
            "variant": variant,
            "group_size": group_size,
            "init_weights": road_config.init_weights,
            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
        }

        # Best effort: forward `get_apply_tensor_subclass` only when the model exposes the whole
        # `hf_quantizer.quantization_config.get_apply_tensor_subclass` attribute chain.
        try:
            kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
                "hf_quantizer.quantization_config.get_apply_tensor_subclass"
            )(self.model)
        except AttributeError:
            pass

        if isinstance(target, RoadLayer):
            # Already wrapped: register the additional adapter on the existing layer.
            target.update_layer(
                adapter_name,
                variant,
                group_size,
                init_weights=road_config.init_weights,
            )
            return

        device_map = getattr(self.model, "hf_device_map", None)
        new_module = self._create_new_module(road_config, adapter_name, target, device_map=device_map, **kwargs)
        if adapter_name not in self.active_adapters:
            # The adapter being added is not among the active ones, so freeze the new module.
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)

    @staticmethod
    def _create_new_module(road_config: RoadConfig, adapter_name, target, **kwargs):
        """Build the RoAd module wrapping ``target`` by trying the available dispatchers in order.

        The bitsandbytes 8-bit and 4-bit dispatchers are tried first (each only when the
        corresponding availability check passes), followed by the default dispatcher. The first
        dispatcher returning a non-``None`` module wins.

        Args:
            road_config: Configuration forwarded to every dispatcher as ``road_config``.
            adapter_name: Name of the adapter being created.
            target: The module to wrap.
            **kwargs: Extra keyword arguments forwarded to every dispatcher.

        Returns:
            The module produced by the first matching dispatcher.

        Raises:
            ValueError: If no dispatcher supports ``target``.
        """
        dispatchers = []

        # The bnb dispatchers are imported lazily, only when the respective check succeeds.
        if is_bnb_available():
            from .bnb import dispatch_bnb_8bit

            dispatchers.append(dispatch_bnb_8bit)

        if is_bnb_4bit_available():
            from .bnb import dispatch_bnb_4bit

            dispatchers.append(dispatch_bnb_4bit)

        # The default dispatcher goes last so the more specific ones get the first chance.
        dispatchers.append(dispatch_default)

        new_module = None
        for dispatcher in dispatchers:
            new_module = dispatcher(target, adapter_name, road_config=road_config, **kwargs)
            if new_module is not None:
                break

        if new_module is None:
            raise ValueError(
                f"Target module {target} is not supported. Currently, only the following modules are supported: "
                "`torch.nn.Linear`."
            )

        return new_module

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        """Context manager that passes ``adapter_names`` to every ``RoadLayer`` during forward.

        When ``adapter_names`` is absent from ``kwargs`` (or ``None``) this is a no-op. Otherwise
        a forward pre-hook is registered on each ``RoadLayer`` that injects ``adapter_names`` into
        the layer's keyword arguments; the hooks are removed again when the context exits, also
        if the wrapped code raises.

        Args:
            *args: Ignored.
            **kwargs: May contain ``adapter_names``, an iterable of adapter names. Entries equal
                to ``"__base__"`` are exempt from the existence check.

        Raises:
            ValueError: If the model is in training mode, or if a requested adapter does not
                exist on any ``RoadLayer``.
        """
        adapter_names = kwargs.pop("adapter_names", None)
        if adapter_names is None:
            # Nothing to inject.
            yield
            return

        if self.training:
            raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")

        road_layers = [module for module in self.modules() if isinstance(module, RoadLayer)]

        # Validate the requested names up front so that no hook is registered on failure.
        expected_adapters = set()
        for layer in road_layers:
            expected_adapters.update(layer.road_theta.keys())
        unique_adapters = {name for name in adapter_names if name != "__base__"}
        unexpected_adapters = unique_adapters - expected_adapters
        if unexpected_adapters:
            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")

        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
        hook_handles = []
        try:
            for layer in road_layers:
                hook_handles.append(layer.register_forward_pre_hook(pre_forward, with_kwargs=True))
            yield
        finally:
            # Always detach the hooks; otherwise an exception raised inside the `with` body
            # would leave them installed and leak `adapter_names` into later forward calls.
            for handle in hook_handles:
                handle.remove()
|
|
|