import torch
from torch import nn

def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively swap every nn.Linear in `module` for a LinearLora layer,
    reusing the original weight and bias so the base mapping is unchanged."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,  # constructor expects a bool, not the bias tensor
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # Reuse the original parameters so the wrapped layer starts out
            # computing exactly what the plain nn.Linear computed.
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None

            setattr(module, name, new_lora)
        else:
            # Not a Linear: recurse into the submodule.
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )

class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,  # `bias` is already a bool; `bias is not None` would always be True
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # Clamp the rank: a low-rank update cannot exceed min(out_features, in_features).
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        # Low-rank factors: the update is scale * (B @ A), with A: in -> rank, B: rank -> out.
        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Base projection plus the scaled low-rank correction.
        base_out = super().forward(input)

        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
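

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal demo of the two entry points above. `TinyNet` is a hypothetical
# model invented here purely to exercise replace_linear_with_lora; the rank
# and scale values are arbitrary.
if __name__ == "__main__":
    class TinyNet(nn.Module):  # hypothetical example model
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(16, 32)
            self.head = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.proj(x))

    model = TinyNet()
    replace_linear_with_lora(model, max_rank=8, scale=1.0)

    x = torch.randn(2, 16)
    out = model(x)  # base path plus LoRA update
    print(out.shape)  # torch.Size([2, 4])

    # Setting scale to 0.0 silences the LoRA contribution at inference time.
    for m in model.modules():
        if isinstance(m, LinearLora):
            m.set_scale(0.0)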