from __future__ import annotations

import warnings

import torch
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
from peft.utils import (
    TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING,
)
from peft.utils.other import get_pattern_key

from .layer import WaveFTLayer, WaveFTLinear


class WaveFTModel(BaseTuner):
    """Creates a WaveFT (wavelet fine-tuning) model from a pretrained transformers model."""

    prefix: str = "waveft_"
    tuner_layer_cls: type[BaseTunerLayer] = WaveFTLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING
|
    def _calculate_proportional_parameters(self, model: torch.nn.Module, waveft_config):
        """Calculate proportional parameter allocation for all target modules."""
        # Collect (name, input_dim, output_dim) for every module targeted by the
        # config; unwrap already-adapted layers to measure the base weight shape.
        target_modules_info = []
        for name, module in model.named_modules():
            if check_target_module_exists(waveft_config, name):
                if isinstance(module, WaveFTLayer):
                    base_module = module.base_layer
                    if isinstance(base_module, torch.nn.Linear):
                        input_dim, output_dim = base_module.in_features, base_module.out_features
                    elif isinstance(base_module, Conv1D):
                        input_dim, output_dim = base_module.weight.shape[1], base_module.weight.shape[0]
                    else:
                        continue
                elif isinstance(module, torch.nn.Linear):
                    input_dim, output_dim = module.in_features, module.out_features
                elif isinstance(module, Conv1D):
                    input_dim, output_dim = module.weight.shape[1], module.weight.shape[0]
                else:
                    continue
                target_modules_info.append((name, input_dim, output_dim))

        if not target_modules_info:
            raise ValueError("No target modules found for proportional parameter allocation.")

        # Distribute the total frequency budget (n_frequency per layer, summed over
        # all target layers) in proportion to each layer's weight matrix size.
        total_sum = sum(input_dim * output_dim for (_, input_dim, output_dim) in target_modules_info)
        num_layers = len(target_modules_info)
        total_budget = waveft_config.n_frequency * num_layers

        n_frequency_dict = {}
        for name, input_dim, output_dim in target_modules_info:
            layer_ratio = (input_dim * output_dim) / total_sum
            n_freq = round(layer_ratio * total_budget)
            n_frequency_dict[name] = n_freq

        return n_frequency_dict
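    # A worked example of the allocation above (hypothetical shapes, for
    # illustration only): with n_frequency=1000 and two target layers of weight
    # shape 768x768 and 768x3072, total_budget = 2000 and the weight sizes are
    # 589824 and 2359296 (a 1:4 ratio), so the layers are assigned
    # round(2000 * 0.2) = 400 and round(2000 * 0.8) = 1600 wavelet
    # coefficients respectively; larger layers receive a larger share.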
|
    def _create_and_replace(
        self,
        waveft_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
        **optional_kwargs,
    ):
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")

        # Compute the proportional allocation once per adapter and cache it;
        # this method is called once for every target module.
        if waveft_config.proportional_parameters:
            if not hasattr(self, "_proportional_params_cache"):
                self._proportional_params_cache = {}
            if adapter_name not in self._proportional_params_cache:
                n_frequency_dict = self._calculate_proportional_parameters(self.model, waveft_config)
                self._proportional_params_cache[adapter_name] = n_frequency_dict

        # Resolve n_frequency with the following precedence: proportional
        # allocation, then an explicit value in optional_kwargs, then the
        # config's n_frequency_pattern, then the config default.
        n_frequency = None
        if (
            waveft_config.proportional_parameters
            and hasattr(self, "_proportional_params_cache")
            and adapter_name in self._proportional_params_cache
        ):
            n_frequency = self._proportional_params_cache[adapter_name].get(current_key)

        if n_frequency is None and "n_frequency" in optional_kwargs:
            n_frequency = optional_kwargs["n_frequency"]

        if n_frequency is None:
            pattern_keys = list(waveft_config.n_frequency_pattern.keys())
            target_name_key = get_pattern_key(pattern_keys, current_key)
            n_frequency = waveft_config.n_frequency_pattern.get(target_name_key, waveft_config.n_frequency)

        # The wavelet family can be overridden per call; otherwise fall back to the config.
        wavelet_family = optional_kwargs.get("wavelet_family")
        if wavelet_family is None:
            wavelet_family = waveft_config.wavelet_family

        scaling = waveft_config.scaling
        random_loc_seed = waveft_config.random_loc_seed
        bias = hasattr(target, "bias") and target.bias is not None

        kwargs = {
            "n_frequency": n_frequency,
            "scaling": scaling,
            "fan_in_fan_out": waveft_config.fan_in_fan_out,
            "init_weights": waveft_config.init_weights,
            "random_loc_seed": waveft_config.random_loc_seed,
            "wavelet_family": wavelet_family,
            "bias": bias,
        }

        if isinstance(target, WaveFTLayer):
            target.update_layer(
                adapter_name,
                n_frequency,
                scaling,
                waveft_config.init_weights,
                random_loc_seed,
                wavelet_family=wavelet_family,
                use_idwt=waveft_config.use_idwt,
            )
        else:
            new_module = self._create_new_module(waveft_config, adapter_name, target, **kwargs)
            if adapter_name != self.active_adapter:
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)
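    # Example of the n_frequency precedence above (hypothetical values): with
    # proportional_parameters=False, empty optional_kwargs, and
    # n_frequency_pattern={"q_proj": 800}, a module whose key matches "q_proj"
    # gets 800 coefficients while every other target falls back to
    # waveft_config.n_frequency.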
|
    @staticmethod
    def _create_new_module(waveft_config, adapter_name, target, **kwargs):
        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        if isinstance(target_base_layer, torch.nn.Linear):
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = waveft_config.fan_in_fan_out = False
        elif isinstance(target_base_layer, Conv1D):
            kwargs["is_target_conv_1d_layer"] = True
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = waveft_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. Currently, only the following modules are supported: "
                "`torch.nn.Linear` and `transformers.pytorch_utils.Conv1D`."
            )

        kwargs["wavelet_family"] = waveft_config.wavelet_family
        kwargs["use_idwt"] = waveft_config.use_idwt
        new_module = WaveFTLinear(target, adapter_name, **kwargs)

        return new_module
|
    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.
        """
        super().delete_adapter(adapter_name)

        # Drop the cached proportional allocation for this adapter, if any.
        if hasattr(self, "_proportional_params_cache") and adapter_name in self._proportional_params_cache:
            del self._proportional_params_cache[adapter_name]
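
# A minimal usage sketch (kept as a comment so the module stays import-safe;
# it assumes the standard PEFT entry points `WaveFTConfig` and `get_peft_model`,
# and the model/module names are hypothetical):
#
#     from peft import WaveFTConfig, get_peft_model
#
#     config = WaveFTConfig(n_frequency=1000, target_modules=["q_proj", "v_proj"])
#     peft_model = get_peft_model(base_model, config)  # wraps targets in WaveFTLinear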