from __future__ import annotations

import warnings
from typing import Optional

from transformers import PreTrainedModel

from .auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from .config import PeftConfig
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING
from .mixed_model import PeftMixedModel
from .peft_model import PeftModel
from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
from .utils import _prepare_prompt_learning_config


def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
) -> PeftModel | PeftMixedModel:
    """
    Returns a Peft model object from a model and a config, where the model will be modified in-place.

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter to be injected. If not provided, the default adapter name ("default") is used.
        mixed (`bool`, *optional*, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
        autocast_adapter_dtype (`bool`, *optional*):
            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
            using float16 or bfloat16 to float32, as this is typically required for stable training, and only affects
            select PEFT tuners.
        revision (`str`, *optional*, defaults to `main`):
            The revision of the base model. If this isn't set, the saved PEFT model will load the `main` revision for
            the base model.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
            Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
            `False` if you intend to train the model, unless the adapter weights will be replaced by different weights
            before training starts.
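
    Example (a minimal usage sketch; the checkpoint name and LoRA hyperparameters below are placeholders chosen for
    illustration, not recommendations):

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import LoraConfig, TaskType, get_peft_model

        >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
        >>> lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.1)
        >>> peft_model = get_peft_model(base_model, lora_config)
        >>> peft_model.print_trainable_parameters()
        ```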
    """
    model_config = BaseTuner.get_model_config(model)
    old_name = peft_config.base_model_name_or_path
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name
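
    # Warn if the model already contains PEFT tuner layers, i.e. it is being modified with PEFT a second time.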
    if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
        warnings.warn(
            "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a "
            "different config, make sure to call `.unload()` before."
        )
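
    # Warn if the base model name recorded in the config differs from the model that is actually being wrapped.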
    if (old_name is not None) and (old_name != new_name):
        warnings.warn(
            f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
            "Please ensure that the correct base model is loaded when loading this checkpoint."
        )
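
    # An explicitly passed revision overrides the revision stored in the config.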
    if revision is not None:
        if peft_config.revision is not None and peft_config.revision != revision:
            warnings.warn(
                f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
            )
        peft_config.revision = revision
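
    # LoRA with EVA initialization benefits from low_cpu_mem_usage=True (larger possible batch size), so suggest it.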
    if (
        (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
        and (peft_config.init_lora_weights == "eva")
        and not low_cpu_mem_usage
    ):
        warnings.warn(
            "lora with eva initialization used with low_cpu_mem_usage=False. "
            "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
        )
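
    # The adapter name must not appear in the tuner's parameter prefix, or its weights may be reinitialized on load.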
    prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
    if prefix and adapter_name in prefix:
        warnings.warn(
            f"Adapter name '{adapter_name}' should not be contained in the prefix '{prefix}'. "
            "This may lead to reinitialization of the adapter weights during loading."
        )
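
    # Mixed adapter types are handled by PeftMixedModel (constructed without the autocast / low_cpu_mem_usage options).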
    if mixed:
        return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
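
    # Task types without a dedicated task-specific class (and no prompt learning) fall back to the generic PeftModel.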
    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
        return PeftModel(
            model,
            peft_config,
            adapter_name=adapter_name,
            autocast_adapter_dtype=autocast_adapter_dtype,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
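
    # Otherwise dispatch to the task-specific PEFT model class; prompt learning configs need extra preparation first.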
    if peft_config.is_prompt_learning:
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
        model,
        peft_config,
        adapter_name=adapter_name,
        autocast_adapter_dtype=autocast_adapter_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )