|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import platform |
|
|
import re |
|
|
import warnings |
|
|
from typing import Optional |
|
|
|
|
|
import huggingface_hub |
|
|
import torch |
|
|
from huggingface_hub import file_exists, hf_hub_download |
|
|
from huggingface_hub.errors import EntryNotFoundError, LocalEntryNotFoundError |
|
|
from safetensors.torch import load_file as safe_load_file |
|
|
from transformers.utils import http_user_agent |
|
|
|
|
|
from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING |
|
|
|
|
|
from .constants import INCLUDE_LINEAR_LAYERS_SHORTHAND |
|
|
from .other import ( |
|
|
EMBEDDING_LAYER_NAMES, |
|
|
SAFETENSORS_WEIGHTS_NAME, |
|
|
WEIGHTS_NAME, |
|
|
AuxiliaryTrainingWrapper, |
|
|
check_file_exists_on_hf_hub, |
|
|
infer_device, |
|
|
match_target_against_key, |
|
|
) |
|
|
from .peft_types import PeftType |
|
|
|
|
|
|
|
|
def has_valid_embedding_base_layer(layer):
    """Check if the layer wraps a valid embedding base layer.

    Returns `True` only if `layer` exposes a `base_layer` attribute that is a `torch.nn.Linear` or
    `torch.nn.Embedding` (the layer types PEFT treats as embedding-like for saving purposes).
    """
    base = getattr(layer, "base_layer", None)
    # isinstance(None, ...) is False, so a missing base_layer attribute also yields False
    return isinstance(base, (torch.nn.Linear, torch.nn.Embedding))
|
|
|
|
|
|
|
|
def get_embedding_layer_name(model, layer, is_embedding_in_target_modules):
    """Return the fully qualified module name of `layer` (or of its `base_layer`) inside `model`.

    When `is_embedding_in_target_modules` is True, only a match against `layer.base_layer` counts;
    otherwise a direct identity match against `layer` itself is also accepted. Returns `None` if no
    module of `model` matches.
    """
    # hoisted out of the loop; missing attribute -> None, which never matches a real module
    wrapped_base = getattr(layer, "base_layer", None)
    for module_name, candidate in model.named_modules():
        if candidate == wrapped_base:
            return module_name
        if not is_embedding_in_target_modules and candidate == layer:
            return module_name
    return None
|
|
|
|
|
|
|
|
def get_peft_model_state_dict(
    model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
):
    """
    Get the state dict of the given adapter of the PEFT model.

    This only includes the PEFT parameters, not the parameters of the base model. Thus the returned `state_dict` is
    generally small compared to the full model size. To retrieve the full `state_dict`, just call `model.state_dict()`.

    Note that the adapter name is removed from the `state_dict`, as this is just an arbitrary name that can be changed
    when loading the adapter. So e.g. if the adapter name is `'default'` and the original key is
    `'model.q_proj.lora_A.default.weight'`, the returned key will be `'model.q_proj.lora_A.weight'`. Use this function
    in conjunction with [`set_peft_model_state_dict`] to take care of the adapter name when loading weights.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the passed model will be used.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be returned.
        unwrap_compiled (`bool`, *optional*, defaults to `False`):
            Whether to unwrap the model if torch.compile was used.
        save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`):
            If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding
            layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. Based on it
            sets the boolean flag. This only works for 🤗 transformers models.

    """
    if unwrap_compiled:
        # torch.compile stores the original module under _orig_mod
        model = getattr(model, "_orig_mod", model)

    config = model.peft_config[adapter_name]
    if state_dict is None:
        state_dict = model.state_dict()

    # The following branches select the PEFT-specific parameters from the full state_dict, one branch per PEFT type
    # (or family of types). Each branch builds `to_return`.
    if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
        # LoRA/AdaLoRA: keep lora_* params and, depending on config.bias, also bias terms.
        bias = config.bias
        if bias == "none":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
        elif bias == "lora_only":
            to_return = {}
            for k in state_dict:
                if "lora_" in k:
                    to_return[k] = state_dict[k]
                    # keep the bias that belongs to the same module as this lora_ param, if present
                    bias_name = k.split("lora_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError
        # restrict lora_ params to the requested adapter; biases are adapter-agnostic and kept as-is
        to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                # strip the adapter name from the rank_pattern keys so the saved config is adapter-name agnostic
                rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
                config.rank_pattern = rank_pattern
                to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)

        if config.use_dora:
            # DoRA: strip the trailing ".weight" (7 chars) from lora_magnitude_vector keys.
            # NOTE(review): presumably for checkpoint-format compatibility with earlier DoRA releases — the loading
            # side (set_peft_model_state_dict) re-appends ".weight"; confirm against the corresponding loader.
            new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

            def renamed_dora_weights(k):
                if k.endswith(new_dora_suffix):
                    k = k[:-7]  # len(".weight") == 7
                return k

            to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}

    elif config.peft_type == PeftType.BOFT:
        # BOFT: same bias handling scheme as LoRA, with the boft_ prefix
        bias = config.bias
        if bias == "none":
            to_return = {k: state_dict[k] for k in state_dict if "boft_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "boft_" in k or "bias" in k}
        elif bias == "boft_only":
            to_return = {}
            for k in state_dict:
                if "boft_" in k:
                    to_return[k] = state_dict[k]
                    bias_name = k.split("boft_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError

    elif config.peft_type == PeftType.ADAPTION_PROMPT:
        # keep params whose last key component starts with "adaption_"
        to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}

    elif config.is_prompt_learning:
        # prompt learning methods save the (virtual token) prompt embeddings rather than module weights
        to_return = {}
        if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
            to_return["prefix_task_cols"] = model.prompt_encoder[adapter_name].prefix_task_cols
            to_return["prefix_task_rows"] = model.prompt_encoder[adapter_name].prefix_task_rows
            prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
        else:
            if config.inference_mode:
                # in inference mode, the embeddings can be read directly from the prompt encoder
                prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
            else:
                prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
        to_return["prompt_embeddings"] = prompt_embeddings

    elif config.peft_type == PeftType.SHIRA:
        shira_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if shira_prefix in k}
        if platform.system() == "Windows":
            warnings.warn(
                "Windows has issues saving integers into safetensors. Hence, we convert shira_indices to float32 "
                "before saving on Windows OS. The shira_indices will always be converted to integers when loading."
            )
        # shira_indices live on the modules (not in the state_dict), so collect them explicitly
        for name, module in model.named_modules():
            if hasattr(module, "shira_indices"):
                for k, v in module.shira_indices.items():
                    # see the warning above: integer tensors are stored as float32 on Windows
                    to_return[f"{name}.shira_indices.{k}"] = (
                        v.to(torch.float32) if platform.system() == "Windows" else v
                    )

    elif config.peft_type == PeftType.VERA:
        vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k}
        if config.save_projection:
            # the shared projections vera_A/vera_B are only present if the model was initialized to save them
            if f"base_model.vera_A.{adapter_name}" not in state_dict:
                raise ValueError(
                    "Model was initialised to not save vera_A and vera_B but config now specifies to save projection!"
                    " Set `config.save_projection` to `False`."
                )
            to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name]
            to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name]
    elif config.peft_type == PeftType.XLORA:
        to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k}
    elif config.peft_type == PeftType.VBLORA:
        to_return = {}
        # pick the smallest integer dtype that can represent all vector-bank indices
        if config.num_vectors < 2**8:
            indices_dtype = torch.uint8
        elif config.num_vectors < 2**15:
            indices_dtype = torch.int16
        elif config.num_vectors < 2**31:
            indices_dtype = torch.int32
        else:
            indices_dtype = torch.int64
        if config.save_only_topk_weights:
            # store only the top-k indices and (renormalized) weights instead of the full logits; the last weight
            # column is dropped because it can be reconstructed (weights sum to 1) — see set_peft_model_state_dict
            for k in state_dict:
                if "vblora_logits" in k:
                    logits, indices = state_dict[k].topk(config.topk)
                    to_return.update({k + "_topk_indices": indices.to(dtype=indices_dtype)})
                    to_return.update({k + "_topk_weights": torch.softmax(logits, dim=-1)[:, :, :-1].contiguous()})
        else:
            to_return = {k: state_dict[k] for k in state_dict if "vblora_logits" in k}
        to_return["base_model.vblora_vector_bank." + adapter_name] = state_dict[
            "base_model.vblora_vector_bank." + adapter_name
        ]
    elif config.peft_type in list(PeftType):
        # generic fallback: select by the PEFT type's parameter prefix
        prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if prefix in k}
    else:
        raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")

    # collect extra adapter params from auxiliary wrappers (e.g. modules_to_save / trainable tokens)
    for name, module in model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            if name.startswith("_fsdp_wrapped_module."):
                # FSDP prefixes module names; strip it so keys match the unwrapped model
                name = name.removeprefix("_fsdp_wrapped_module.")
            # sub-state_dict of this wrapper, with the wrapper's own prefix removed
            module_state_dict = {
                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
            }
            to_return.update(
                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
            )

    # Determine whether embedding layers were targeted by the adapter; this influences the "auto" resolution of
    # save_embedding_layers below.
    embedding_is_targeted = False
    if hasattr(config, "target_modules"):
        if isinstance(config.target_modules, str) and (config.target_modules != INCLUDE_LINEAR_LAYERS_SHORTHAND):
            # target_modules is a regex-like string: check it against the names of known embedding modules
            _model = model.get_base_model() if hasattr(model, "get_base_model") else model
            embedding_is_targeted = any(
                match_target_against_key(config.target_modules, k)
                for k, _ in _model.named_modules()
                if any(re.match(rf"(.*\.)?{e}$", k) for e in EMBEDDING_LAYER_NAMES)
            )
        elif config.target_modules:
            embedding_is_targeted = any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)

    using_trainable_tokens = (
        config.peft_type == PeftType.TRAINABLE_TOKENS or getattr(config, "trainable_token_indices", None) is not None
    )

    if save_embedding_layers == "auto" and embedding_is_targeted and not using_trainable_tokens:
        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
        save_embedding_layers = True
    elif save_embedding_layers == "auto":
        # "auto" without targeted embeddings: save them only if the vocabulary was resized vs. the base model
        vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
        model_id = getattr(config, "base_model_name_or_path", None)

        has_base_config = False

        # check if the base model config is available (locally or on the Hub)
        if model_id is not None:
            local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
            exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
            if exists is None:
                # check_file_exists_on_hf_hub returns None when the check itself could not be performed
                warnings.warn(
                    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
                )
                has_base_config = False
            else:
                has_base_config = exists

        # compare the current vocab size with the base model's vocab size to detect embedding resizing
        if (
            vocab_size
            and model_id
            and has_base_config
            and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
        ):
            warnings.warn(
                "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
            )
            save_embedding_layers = True
        else:
            save_embedding_layers = False

    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
            # if the embedding is targeted, only save it when it wraps a valid base layer
            if not embedding_is_targeted or has_valid_embedding_base_layer(layer):
                embedding_module_name = get_embedding_layer_name(model, layer, embedding_is_targeted)
                if embedding_module_name:
                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
    elif save_embedding_layers:
        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

    # finally, strip the adapter name from all keys (see the function docstring for an example)
    pattern = re.compile(re.escape(f".{adapter_name}") + r"$")

    def remove_adapter_name(key):
        if "." not in key:
            # e.g. "prompt_embeddings" — nothing to strip
            return key

        if key.endswith(f".{adapter_name}"):
            # e.g. "...vblora_vector_bank.default" -> "...vblora_vector_bank"
            return key.removesuffix(f".{adapter_name}")

        # adapter name is embedded before the last component, e.g. "...lora_A.default.weight"
        key, _, suffix = key.rpartition(".")

        if (config.peft_type == PeftType.VBLORA) and suffix.startswith(f"{adapter_name}_"):
            # VBLoRA top-k keys look like "...vblora_logits_A.default_topk_indices"; rejoin with "_" so the
            # loader can find the "_topk_indices" / "_topk_weights" markers again
            return key + "_" + suffix.removeprefix(f"{adapter_name}_")

        # drop a trailing ".{adapter_name}" from the stem, then re-attach the last component
        key = pattern.sub("", key)
        return f"{key}.{suffix}"

    to_return = {remove_adapter_name(k): v for k, v in to_return.items()}
    return to_return
|
|
|
|
|
|
|
|
def _find_mismatched_keys( |
|
|
model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False |
|
|
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]: |
|
|
if not ignore_mismatched_sizes: |
|
|
return peft_model_state_dict, [] |
|
|
|
|
|
mismatched = [] |
|
|
state_dict = model.state_dict() |
|
|
for key, tensor in peft_model_state_dict.items(): |
|
|
if key not in state_dict: |
|
|
continue |
|
|
|
|
|
|
|
|
if (state_dict[key].shape[-1] == 1) and (state_dict[key].numel() * 2 == tensor.numel()): |
|
|
|
|
|
|
|
|
|
|
|
continue |
|
|
|
|
|
if state_dict[key].shape != tensor.shape: |
|
|
mismatched.append((key, tensor.shape, state_dict[key].shape)) |
|
|
|
|
|
for key, _, _ in mismatched: |
|
|
del peft_model_state_dict[key] |
|
|
|
|
|
return peft_model_state_dict, mismatched |
|
|
|
|
|
|
|
|
def _insert_adapter_name_into_state_dict( |
|
|
state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str |
|
|
) -> dict[str, torch.Tensor]: |
|
|
"""Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name.""" |
|
|
peft_model_state_dict = {} |
|
|
for key, val in state_dict.items(): |
|
|
if parameter_prefix in key: |
|
|
_, _, suffix = key.rpartition(parameter_prefix) |
|
|
if "." in suffix: |
|
|
suffix_to_replace = ".".join(suffix.split(".")[1:]) |
|
|
|
|
|
|
|
|
key = re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key) |
|
|
else: |
|
|
key = f"{key}.{adapter_name}" |
|
|
peft_model_state_dict[key] = val |
|
|
else: |
|
|
peft_model_state_dict[key] = val |
|
|
return peft_model_state_dict |
|
|
|
|
|
|
|
|
def set_peft_model_state_dict(
    model,
    peft_model_state_dict,
    adapter_name="default",
    ignore_mismatched_sizes: bool = False,
    low_cpu_mem_usage: bool = False,
) -> None:
    """
    Set the state dict of the PEFT model.

    Given a PEFT `state_dict` (as returned by [`get_peft_model_state_dict`]), insert the weights into the model. The
    model needs to have the PEFT adapters already in place (e.g. via [`inject_adapter_in_model`]).

    Setting the adapter weights also takes care of re-inserting the adapter name. This name may be a different name
    than the one originally used to train the adapter.

    Args:
        model ([`PeftModel`]):
            The Peft model.
        peft_model_state_dict (`dict`):
            The state dict of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be set.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether to ignore mismatched (i.e. differently shaped) sizes in the state dict.
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            This argument must be `True` if the `model` was loaded with adapter weights on the meta device, e.g. after
            calling `inject_adapter_in_model` with `low_cpu_mem_usage=True`. Otherwise, leave it as `False`.

    """
    config = model.peft_config[adapter_name]
    # alias: the incoming dict is mutated in place by the remapping steps below
    state_dict = peft_model_state_dict

    # Remap keys that belong to auxiliary wrappers (e.g. modules_to_save / trainable tokens) according to each
    # wrapper's own load map.
    for name, module in model.named_modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            key_map = module.adapter_state_dict_load_map(adapter_name)
            if name.startswith("_fsdp_wrapped_module."):
                # FSDP prefixes module names; strip it so keys match the unwrapped model
                name = name.removeprefix("_fsdp_wrapped_module.")
            for k in key_map:
                lookup_key = f"{name}.{k}"
                store_key = f"{name}.{key_map[k]}"

                state_dict[store_key] = peft_model_state_dict[lookup_key]

                # remove the old key (state_dict aliases peft_model_state_dict)
                # NOTE(review): if key_map ever maps a key onto itself, this would delete the entry that was just
                # stored — verify that adapter_state_dict_load_map never returns identity mappings.
                del state_dict[lookup_key]

    # Re-insert the adapter name into the keys, with PEFT-type specific pre/post-processing.
    if config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT:
        peft_model_state_dict = state_dict
    elif config.peft_type == PeftType.XLORA:
        peft_model_state_dict = state_dict
    elif config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING:
        peft_model_state_dict = {}
        parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights:
            # reconstruct full logits matrices from the saved top-k indices/weights (see get_peft_model_state_dict)
            num_vectors, _ = model.vblora_vector_bank[adapter_name].shape
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if "_topk_indices" in k:
                    v = state_dict[k].to(torch.long)
                    original_key = k.replace("_topk_indices", "")

                    topk_weights = state_dict[k.replace("_topk_indices", "_topk_weights")]

                    # the dropped last weight column is recovered from the constraint that weights sum to 1
                    topk_weights = torch.cat([topk_weights, 1 - topk_weights.sum(-1, keepdim=True)], dim=-1)

                    # invert the softmax (up to an additive constant): logits = log(weights)
                    topk_logits = torch.log(topk_weights)
                    # scatter the k logits into a full (-inf elsewhere) logits matrix of size num_vectors
                    matrix = (
                        torch.zeros([*(topk_logits.shape[:-1]), num_vectors])
                        .fill_(float("-inf"))
                        .to(topk_logits.device)
                        .scatter(-1, v, topk_logits)
                    )

                    state_dict[original_key] = matrix

                    del state_dict[k]
                    del state_dict[k.replace("_topk_indices", "_topk_weights")]

        peft_model_state_dict = _insert_adapter_name_into_state_dict(
            state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
        )

        if config.peft_type == PeftType.ADALORA:
            # resize the modules to the per-layer ranks stored in the config before loading
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
        elif config.peft_type == PeftType.SHIRA:
            if platform.system() == "Windows":
                warnings.warn(
                    "Windows has issues saving integers into safetensors. Hence, we had converted shira_indices "
                    "to float32 before saving on Windows OS. The shira_indices will always be converted to integers "
                    "when loading."
                )
            # shira_indices are restored onto the modules directly (they are not regular state_dict entries)
            for name, module in model.named_modules():
                if hasattr(module, "shira_indices"):
                    if f"{name}.shira_indices.{adapter_name}" in peft_model_state_dict:
                        shira_indices_values = peft_model_state_dict.pop(f"{name}.shira_indices.{adapter_name}")
                        # always cast back to int (they may have been saved as float32 on Windows, see above)
                        module.shira_indices[adapter_name] = shira_indices_values.to(torch.int)
        elif config.peft_type == PeftType.VERA:
            if config.save_projection and "base_model.vera_A" not in peft_model_state_dict:
                raise ValueError(
                    "Specified to load vera_A and vera_B from state dictionary however they were not present!"
                )
            elif not config.save_projection and "base_model.vera_A" in peft_model_state_dict:
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary however they are present in state"
                    " dictionary! Consider using them to ensure checkpoint loading is correct on all platforms using"
                    " `peft_config.save_projection = True`"
                )
            elif not config.save_projection:
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary. This means we will be relying on"
                    " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may"
                    " not be accurate on all system configurations."
                )
        elif config.peft_type == PeftType.LORA:
            # DoRA checkpoints store lora_magnitude_vector keys without the trailing ".weight"
            # (see get_peft_model_state_dict); re-append it here.
            old_dora_suffix = f"lora_magnitude_vector.{adapter_name}"

            def renamed_dora_weights(k):
                if k.endswith(old_dora_suffix):
                    k = k + ".weight"
                return k

            peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}
        elif config.peft_type == PeftType.OFT:
            if any(".oft_r." in key for key in peft_model_state_dict):
                raise ValueError(
                    "Trying to load old OFT checkpoint, which is no longer supported. Please install PEFT <= v0.15.2 to load it or train a new OFT adapter."
                )
    else:
        raise NotImplementedError

    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )
    if low_cpu_mem_usage:
        # assign=True replaces (meta) parameters instead of copying into them
        load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
        # ensure that the assigned adapter weights end up on the same device as their base layer
        for module in model.modules():
            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
                module._move_adapter_to_device_of_base_layer(adapter_name)
    else:
        load_result = model.load_state_dict(peft_model_state_dict, strict=False)

    if config.is_prompt_learning:
        # prompt embeddings are loaded into the prompt encoder, not via the regular state_dict mechanism
        model.prompt_encoder[adapter_name].embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )

    if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
        model.prompt_encoder[adapter_name].load_state_dict(peft_model_state_dict, strict=False)

    if mismatched_keys:
        # mirror the warning emitted by transformers when sizes mismatch
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        msg = (
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )
        warnings.warn(msg)
    return load_result
|
|
|
|
|
|
|
|
|
|
|
def torch_load(*args, weights_only=True, **kwargs):
    """Call torch.load and handle weights_only.

    Defaults to weights_only=True to anticipate upcoming switch on the PyTorch side.

    """
    load_kwargs = dict(kwargs)
    load_kwargs["weights_only"] = weights_only
    return torch.load(*args, **load_kwargs)
|
|
|
|
|
|
|
|
def load_peft_weights(
    model_id: str, device: Optional[str] = None, key_mapping: Optional[dict[str, str]] = None, **hf_hub_download_kwargs
) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or locally

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter to load from the HuggingFace Hub.
        device (`str`):
            The device to load the weights onto.
        key_mapping (dict, *optional*, defaults to None):
            Extra mapping of PEFT `state_dict` keys applied before loading the `state_dict`. When this mapping is
            applied, the PEFT-specific `"base_model.model"` prefix is removed beforehand and the adapter name (e.g.
            `"default"`) is not inserted yet. Only pass this argument if you know what you're doing.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when loading from the HuggingFace Hub.
    """
    # local lookup path: model_id, optionally extended by the subfolder
    path = (
        os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
        if hf_hub_download_kwargs.get("subfolder", None) is not None
        else model_id
    )

    if device is None:
        device = infer_device()

    def get_hub_filename(use_safetensors=True):
        # filename as expected by the Hub API (subfolder joined into the filename, not the repo id)
        weights_name = SAFETENSORS_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
        return (
            os.path.join(hf_hub_download_kwargs["subfolder"], weights_name)
            if hf_hub_download_kwargs.get("subfolder", None) is not None
            else weights_name
        )

    if "user_agent" not in hf_hub_download_kwargs:
        hf_hub_download_kwargs["user_agent"] = http_user_agent()

    # Resolve the weights file: prefer a local safetensors file, then a local pickle file, then the cache (offline
    # mode), and finally download from the Hub.
    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    elif huggingface_hub.constants.HF_HUB_OFFLINE:
        # offline mode: only the local HF cache can be consulted, first safetensors, then the pickle format
        hub_filename = get_hub_filename(use_safetensors=True)
        # drop a caller-provided local_files_only to avoid passing the kwarg twice below
        hf_hub_download_kwargs.pop("local_files_only", None)
        try:
            filename = hf_hub_download(model_id, hub_filename, local_files_only=True, **hf_hub_download_kwargs)
            use_safetensors = True
        except LocalEntryNotFoundError:
            # fall back to the torch pickle weights in the cache; if this also fails, the error propagates
            hub_filename = get_hub_filename(use_safetensors=False)
            filename = hf_hub_download(model_id, hub_filename, local_files_only=True, **hf_hub_download_kwargs)
            use_safetensors = False
    else:
        token = hf_hub_download_kwargs.get("token", None)
        if token is None:
            # legacy kwarg name still accepted as a fallback
            token = hf_hub_download_kwargs.get("use_auth_token", None)

        # probe the Hub for a safetensors file before deciding which format to download
        hub_filename = get_hub_filename(use_safetensors=True)
        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=hub_filename,
            revision=hf_hub_download_kwargs.get("revision", None),
            repo_type=hf_hub_download_kwargs.get("repo_type", None),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(
                model_id,
                SAFETENSORS_WEIGHTS_NAME,
                **hf_hub_download_kwargs,
            )
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs)
            except EntryNotFoundError:
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}."
                )

    if use_safetensors:
        if hasattr(torch.backends, "mps") and (device == torch.device("mps")):
            # safetensors cannot load directly onto MPS; load on CPU instead
            adapters_weights = safe_load_file(filename, device="cpu")
        else:
            adapters_weights = safe_load_file(filename, device=device)
    else:
        adapters_weights = torch_load(filename, map_location=torch.device(device))

    if not key_mapping:
        remapped_adapters_weights = adapters_weights
    else:
        # Apply the user-provided key mapping. The PEFT prefix ("base_model.model." or "base_model.") is stripped
        # before matching and re-attached afterwards; only the first matching pattern is applied per key.
        remapped_adapters_weights = {}
        for key, val in adapters_weights.items():
            if key.startswith("base_model.model."):
                prefix = "base_model.model."
            elif key.startswith("base_model."):
                prefix = "base_model."
            else:
                raise ValueError(
                    "An error occurred while trying to load a PEFT state_dict with key_mapping. This should not "
                    "happen. Please open an issue on https://github.com/huggingface/peft/issues and report the error."
                )

            key = key.removeprefix(prefix)
            for pattern, replacement in key_mapping.items():
                key_new, n_replace = re.subn(pattern, replacement, key)
                # first matching pattern wins
                if n_replace > 0:
                    key = key_new
                    break
            key_with_prefix = f"{prefix}{key}"
            remapped_adapters_weights[key_with_prefix] = val

    return remapped_adapters_weights
|
|
|