import math
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

from .parametrized_layer import Parametrization
from .utils import use_init_empty_weights

logger = getLogger(__name__)


class CompressionCriterion(ABC):
    """
    Abstract class for compression criterion of a (target) parameter of a parametrized module.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: A tensor of any shape.

        Returns: A boolean mask of the same shape as `x` where `False` indicates that the entry can be removed.
        """
        raise NotImplementedError


class ThresholdCriterion(CompressionCriterion):
    """
    Compression criterion based on a threshold. All entries at or below `self.threshold` can be removed.
    """

    def __init__(self, threshold: float = 0.0):
        self.threshold = threshold

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x > self.threshold


class ProjectedLinearParametrization(Parametrization, ABC):
    """
    Implementation of a linear layer parametrization, factorizing the weight matrix as
    `weight = ortho.weight @ torch.diag(mask) @ base.weight`.
    Here, `ortho` is a linear layer with orthogonal columns, `mask` represents a (binary) diagonal matrix
    that can be pruned, and `base` is a linear layer (determined by the choice of `ortho`).
    Any child class needs to implement `_ortho_init`, which creates `ortho`. Based on this, `mask` and `base` are
    initialized such that the original weight matrix is obtained at initialization.

    `mask` is the only target parameter of this parametrization. Pruning it results in
    a low-rank matrix representation of the parametrized linear module.
    """

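    # Shape sketch (illustrative; assuming the parametrized module is an nn.Linear with in_features=I,
    # out_features=O, and proj_dim = min(I, O) as set up in `_initialize`):
    #   base.weight:  (proj_dim, I)  -- projects the input into the proj_dim-dimensional subspace
    #   mask:         (proj_dim,)    -- learnable diagonal; pruning entries lowers the rank
    #   ortho.weight: (O, proj_dim)  -- maps back to the output space (and carries the bias, if any)
    # so `ortho.weight @ torch.diag(mask) @ base.weight` has shape (O, I), matching the original weight.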
    base_class = nn.Linear

    def __init__(
        self,
        mask_func: Literal["ste", "relu", "none"] = "ste",
        mask_scaling_factor: float | str = "norm",
        compression_criterion: CompressionCriterion = ThresholdCriterion(),
    ):
        """
        Args:
            mask_func: A function applied to the mask parameter in each forward pass implementing
                custom functionalities. Available options: ["ste", "relu", "none"].
                "ste" means using a straight-through estimator, i.e., in the forward pass, `mask` is binarized,
                which is ignored in the backward pass. Before binarization, `mask` is passed through a ReLU
                activation.
                "relu" means that `mask` is passed through a ReLU activation.
                "none" means that `mask` is not modified.
            mask_scaling_factor: Conceptually, `mask` is initialized with ones, but rescaling it to a smaller value
                can vastly improve the training speed. `mask_scaling_factor` specifies this rescaling factor.
                The rescaling should be compensated by scaling `ortho` accordingly in `self._ortho_init`.
                If `mask_scaling_factor='norm'`, the scaling factor is chosen such that `mask` has unit L2 norm
                (note that this can lead to a different behavior in model tuning than a fixed factor
                when target parameters have different numbers of elements).
            compression_criterion: `CompressionCriterion` to be used in `self.reset_target_params(mode="compress")`.
        """
        super().__init__()
        self.mask_func = {
            "ste": mask_func_ste,
            "relu": mask_func_relu,
            "none": mask_func_none,
        }[mask_func]
        self._mask_scaling_factor = mask_scaling_factor
        self.compression_criterion = compression_criterion

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to F.linear(x, self._weight(), self._bias()), but computed factor by factor.
        x = self.base(x)
        x = self.mask_func(self.mask, self.mask_scaling_factor) * x
        x = self.ortho(x)
        return x

    def _weight(self) -> torch.Tensor:
        # Materialize the effective weight matrix of the parametrized linear layer.
        mask = self.mask_func(self.mask, self.mask_scaling_factor)
        return self.ortho.weight @ torch.diag(mask) @ self.base.weight

    def _bias(self) -> torch.Tensor | None:
        return self.ortho.bias

    def _initialize(self, base_module: base_class) -> None:
        factory_kwargs = {"device": base_module.weight.device, "dtype": base_module.weight.dtype}
        in_dim, out_dim = base_module.in_features, base_module.out_features
        proj_dim = min(in_dim, out_dim)

        # `ortho` maps from the projection space back to the output space and carries the original bias.
        self.add_module(
            "ortho",
            nn.Linear(in_features=proj_dim, out_features=out_dim, bias=base_module.bias is not None, **factory_kwargs),
        )
        self._ortho_init(base_module.weight)
        if base_module.bias is not None:
            self.ortho.bias.data.copy_(base_module.bias.data)

        # `base` projects the input into the projection space; it is determined by `ortho` and the original weight.
        base = base_module.__class__(in_features=in_dim, out_features=proj_dim, bias=False, **factory_kwargs)
        base.weight.data.copy_(self.ortho.weight.data.T @ base_module.weight.data)
        self.add_module("base", base)

        # The mask is the (prunable) target parameter; it is rescaled in `reset_target_params` below.
        self.register_parameter("mask", torch.nn.Parameter(torch.ones(proj_dim, **factory_kwargs)))

        self.reset_target_params()

    @abstractmethod
    def _ortho_init(self, weight: torch.Tensor) -> None:
        """
        Initialize ortho layer. Must be implemented by child class.

        Args:
            weight: Weight matrix of the original linear layer module.
        """
        raise NotImplementedError

    def get_target_params(self) -> dict[str, torch.nn.Parameter]:
        return {"mask": self.mask}

    @property
    def mask_scaling_factor(self) -> float:
        if self._mask_scaling_factor == "norm":
            # Compute the factor lazily and cache it: a mask of ones scaled by 1/sqrt(numel) has unit L2 norm.
            self._mask_scaling_factor = 1 / math.sqrt(self.mask.numel())
            return self._mask_scaling_factor
        elif isinstance(self._mask_scaling_factor, float):
            return self._mask_scaling_factor
        else:
            raise ValueError(f"Invalid mask_scaling_factor: {self._mask_scaling_factor}")

    @property
    def in_features(self) -> int:
        return self.base.in_features

    @property
    def out_features(self) -> int:
        return self.ortho.out_features

    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        with torch.no_grad():
            if mode == "full":
                # Reset all mask entries to the (rescaled) initial value.
                self.mask.data = torch.ones_like(self.mask.data) * self.mask_scaling_factor
            elif mode == "nonzero":
                # Keep the sign pattern: positive entries are reset to the initial value, negative ones are zeroed.
                self.mask.data[self.mask.data > 0] = 1.0 * self.mask_scaling_factor
                self.mask.data[self.mask.data < 0] = 0.0
            elif mode == "compress":
                if self.compression_criterion is None:
                    logger.warning("Compression criterion is not set. No op...")
                    return
                # Physically remove the pruned dimensions from `base` (rows), `ortho` (columns), and `mask`.
                dim_select = self.compression_criterion(self.mask)
                new_base = new_linear_from_mask(self.base, dim_select, column_select=False)
                new_ortho = new_linear_from_mask(self.ortho, dim_select, column_select=True)
                new_mask = self.mask[dim_select].clone().detach()
                del self.mask, self.base, self.ortho
                self.register_module("base", new_base)
                self.register_module("ortho", new_ortho)
                self.register_parameter("mask", nn.Parameter(new_mask))
            else:
                raise ValueError(f"Invalid mode: {mode}")

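    # Parameter accounting for `get_num_params` below: the dense layer needs in_features * out_features
    # (+ out_features with bias) parameters, while the factorized form with r surviving mask entries needs
    # r * (in_features + out_features) (+ bias). Illustrative, assumed numbers: a 4096 -> 1024 layer with
    # r = 256 surviving entries needs 256 * 5120 = 1,310,720 parameters versus 4,194,304 for the dense weight.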
    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        if not compressed:
            # Dense parameter count of the original linear layer.
            num_params = self.in_features * self.out_features
            if self.bias is not None:
                num_params += self.out_features
            return num_params
        else:
            # Low-rank parameter count based on the number of surviving mask entries.
            if target_params is not None:
                sparsity = mask_sparsity(target_params["mask"] != 0.0, threshold=0.0)
            else:
                sparsity = mask_sparsity(self.mask)

            num_params = self.in_features * sparsity + sparsity * self.out_features
            if self.bias is not None:
                num_params += self.out_features

            # The factorized representation is never reported as larger than the dense layer.
            num_params = min(self.get_num_params(compressed=False), num_params)
            return num_params


class SVDLinearParametrization(ProjectedLinearParametrization):
    """
    Implementation of a linear layer parametrization using SVD decomposition.
    If the SVD of weight is U * S * V^T, then `ortho.weight = U` and `base.weight = S * V^T`.
    As base is computed automatically by `_initialize`, `_ortho_init` only needs to compute U and
    scale it properly with `mask_scaling_factor`. The singular values S are buffered just in case they are needed
    in the tuning process.
    """

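    # Scaling note (derived from the code below): `_initialize` sets base.weight = ortho.weight.T @ weight.
    # With ortho.weight = c * U and an effective mask value of s = mask_scaling_factor per entry, the product
    # ortho.weight @ diag(mask) @ base.weight equals c**2 * s * U @ S @ V^T, so choosing c = sqrt(1 / s)
    # (i.e., k**(1/4) for the "norm" setting, where s = 1 / sqrt(k)) reproduces the original weight exactly.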
    def _ortho_init(self, weight: torch.Tensor) -> None:
        k = min(weight.shape[0], weight.shape[1])
        if use_init_empty_weights.get():
            # Skip the SVD; placeholder tensors are used and the actual weights are expected to be loaded later.
            logger.debug("Parametrizing with empty weights.")
            U = torch.empty(weight.shape[0], k)
            S = torch.empty(k, 1)
        else:
            U, S, _ = torch.linalg.svd(weight.detach().float(), full_matrices=False)

        # Rescale U to compensate for the rescaled mask (see the scaling note above).
        if self._mask_scaling_factor == "norm":
            U = math.pow(k, 1 / 4) * U
        else:
            U = math.sqrt(1 / self._mask_scaling_factor) * U
        factory_kwargs = {"device": weight.device, "dtype": weight.dtype}
        self.ortho.weight.data.copy_(U.detach().to(**factory_kwargs))
        self.register_buffer("S", S.detach().flatten().to(**factory_kwargs))


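# Mask transforms applied to the target parameter in every forward pass. In `mask_func_ste` below, the
# forward value is the binarized (ReLU-ed) mask times `mask_scaling_factor`; the trailing
# `+ mask - mask.detach()` is zero in the forward pass but lets gradients flow to `mask` through the ReLU,
# implementing a straight-through estimator. Illustrative example (assumed values):
# mask = [0.3, -0.2] with factor 0.5 gives forward values [0.5, 0.0].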
def mask_func_ste(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    mask = F.relu(mask)
    return (mask > 0).to(mask.dtype).detach() * mask_scaling_factor + mask - mask.detach()


def mask_func_relu(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    return F.relu(mask)


def mask_func_none(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    return mask


def mask_sparsity(mask: torch.Tensor, threshold: float = 0.0) -> int:
    """Simple util function to compute the number of non-zero elements of a mask, where an element is considered
    non-zero if its value is strictly greater than `threshold`."""
    return torch.count_nonzero(mask > threshold).item()


def new_linear_from_mask(module: nn.Linear, dim_select: torch.Tensor, column_select: bool = True) -> nn.Linear:
    """
    Creates a new linear layer from an existing one based on a mask indicating which columns/rows to keep.

    Args:
        module: Module to be pruned.
        dim_select: Boolean tensor mask indicating which columns/rows to keep.
        column_select: Whether to prune columns (True) or rows (False) according to `dim_select`.

    Returns: Pruned module.
    """
    assert dim_select.dtype == torch.bool, "dim_select must be boolean"

    in_features, out_features = module.in_features, module.out_features
    sparsity = dim_select.sum().item()
    if column_select:
        in_features = sparsity
    else:
        out_features = sparsity
    new_module = module.__class__(
        in_features=in_features,
        out_features=out_features,
        bias=module.bias is not None,
        device=module.weight.device,
        dtype=module.weight.dtype,
    )
    weight = module.weight.data
    if column_select:
        weight = weight[:, dim_select]
    else:
        weight = weight[dim_select, :]
    new_module.weight.data.copy_(weight.detach())

    if new_module.bias is not None:
        if column_select:
            new_module.bias.data.copy_(module.bias.detach())
        else:
            # Output rows were pruned, so the bias entries must be pruned accordingly.
            new_module.bias.data.copy_(module.bias[dim_select].detach())

    return new_module

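# Illustrative usage sketch (not part of the library API): a minimal, self-contained check of the
# module-level helpers defined above. The layer shape and mask values are made-up examples.
if __name__ == "__main__":
    linear = nn.Linear(8, 4)
    # Pretend two of the four mask entries survived training.
    mask = torch.tensor([0.3, 0.0, -0.2, 0.7])
    keep = ThresholdCriterion(threshold=0.0)(mask)
    print("surviving entries:", mask_sparsity(mask))
    print("STE forward values:", mask_func_ste(mask, mask_scaling_factor=0.5))
    pruned = new_linear_from_mask(linear, keep, column_select=False)
    print("pruned weight shape:", tuple(pruned.weight.shape))  # (2, 8)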