import functools

import torch
from accelerate.logging import get_logger
from peft.tuners.tuners_utils import BaseTunerLayer

from .constants import FINETRAINERS_LOG_LEVEL

logger = get_logger("finetrainers")  # pylint: disable=invalid-name
logger.setLevel(FINETRAINERS_LOG_LEVEL)

def perform_peft_patches() -> None:
    _perform_patch_move_adapter_to_device_of_base_layer()

def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
    # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them.
    # If the model weights are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we patch the
    # method to avoid downcasting the LoRA weights from their higher-precision dtype.
    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
        BaseTunerLayer._move_adapter_to_device_of_base_layer
    )

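# A hedged usage sketch (not part of this module): apply the patch before the LoRA
# adapters are created, so that adapter weights on a float16/float8 base model are
# not downcast when PEFT moves them onto the base layer's device:
#
#     perform_peft_patches()
#     transformer.add_adapter(lora_config)  # hypothetical caller code
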
def _patched_move_adapter_to_device_of_base_layer(func):
    # Wrap the original method so that any dtype conversion attempted while moving
    # the adapter becomes a no-op; device movement still happens normally.
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with DisableTensorToDtype():
            return func(self, *args, **kwargs)

    return wrapper

class DisableTensorToDtype:
    # Context manager that temporarily monkey-patches torch.Tensor.to so that any
    # dtype argument is stripped; all other arguments are forwarded unchanged.
    def __enter__(self):
        self.original_to = torch.Tensor.to

        def modified_to(tensor, *args, **kwargs):
            # Remove dtype from args/kwargs if present so no conversion takes place
            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
            kwargs.pop("dtype", None)
            return self.original_to(tensor, *args, **kwargs)

        torch.Tensor.to = modified_to

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.Tensor.to = self.original_to
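

if __name__ == "__main__":
    # Minimal sanity-check sketch (an assumption, not part of finetrainers): inside
    # the context manager, dtype requests passed to Tensor.to are stripped, so the
    # tensor keeps its original dtype; outside, .to behaves normally again.
    t = torch.randn(2, 2, dtype=torch.float32)
    with DisableTensorToDtype():
        assert t.to(dtype=torch.float16).dtype == torch.float32  # dtype kwarg stripped
    assert t.to(dtype=torch.float16).dtype == torch.float16  # normal behavior restored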