# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer

from .testing_utils import torch_device
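

# These tests cover the LoRA-FA optimizer: LoRA-FA keeps the lora_A matrices frozen at their
# initialization and only updates the lora_B matrices, which is exactly what the assertions
# below check.

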
class SimpleNet(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.embedding = nn.Embedding(100, 20)
self.layer_norm = nn.LayerNorm(20)
self.lin0 = nn.Linear(20, 20, bias=bias)
self.relu = nn.ReLU()
self.lin1 = nn.Linear(20, 16, bias=bias)
def forward(self, X):
X = self.lin0(self.layer_norm(self.embedding(X)))
X = self.relu(X)
X = self.lin1(X)
return X


def test_lorafa_init_default():
    """
    Test that the LoRA-FA optimizer is created correctly: the scaling factor matches the default LoRA scaling and
    only the lora_B parameters remain trainable.
    """
lora_rank = 16
lora_alpha = 32
lr = 7e-5
model = SimpleNet()
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["lin0", "lin1"],
bias="none",
)
model = get_peft_model(model, config)
optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)
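    # With the default (non-rsLoRA) setup, the LoRA scaling factor is lora_alpha / r;
    # the LoRA-FA optimizer records it as "scaling_factor" in its param groups.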
    assert optimizer is not None
    assert math.isclose(optimizer.param_groups[0]["scaling_factor"], lora_alpha / lora_rank, rel_tol=1e-9, abs_tol=0.0)

    all_A_fixed = True
    all_B_trainable = True
for name, param in model.named_parameters():
if "lora_A" in name:
all_A_fixed &= not param.requires_grad
elif "lora_B" in name:
all_B_trainable &= param.requires_grad
assert all_A_fixed and all_B_trainable


def test_lorafa_init_rslora():
    """
    Test that the optimizer is created correctly when use_rslora=True, i.e. with the rank-stabilized scaling factor
    lora_alpha / sqrt(r).
    """
lora_rank = 16
lora_alpha = 32
lr = 7e-5
model = SimpleNet()
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["lin0", "lin1"],
bias="none",
)
model = get_peft_model(model, config)
optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr, use_rslora=True)
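    # rsLoRA (rank-stabilized LoRA) replaces the lora_alpha / r scaling with lora_alpha / sqrt(r),
    # which is what the optimizer should report here.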
assert math.isclose(
optimizer.param_groups[0]["scaling_factor"], lora_alpha / math.sqrt(lora_rank), rel_tol=1e-9, abs_tol=0.0
)


def test_LoraFAOptimizer_step():
    """
    Test that the optimizer's step() runs without raising and that, after several steps, the lora_A weights are
    unchanged while the lora_B weights have been updated away from zero.
    """
lora_rank = 16
lora_alpha = 32
lr = 7e-5
num_steps = 5
model = SimpleNet()
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["lin0", "lin1"],
bias="none",
)
model = get_peft_model(model, config).to(torch_device)
    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)
loss = torch.nn.CrossEntropyLoss()
# Save initial weights of lora_A
initial_lora_A_weights = {name: param.clone() for name, param in model.named_parameters() if "lora_A" in name}
# Ensure lora_B is initialized to zero
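    # (with the default LoRA initialization, lora_B starts at zero so the adapter initially contributes nothing)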
for name, param in model.named_parameters():
if "lora_B" in name:
assert torch.all(param == 0), f"lora_B weights not initialized to zero for {name}"
for _ in range(num_steps): # Run the optimizer step multiple times
# Generate random input and label for each step
x = torch.randint(100, (2, 4, 10)).to(torch_device)
        output = model(x).permute(0, 3, 1, 2)  # move the class dimension to dim 1, as expected by CrossEntropyLoss
label = torch.randint(16, (2, 4, 10)).to(torch_device)
# Calculate loss and perform backward pass
loss_value = loss(output, label)
loss_value.backward()
# Perform optimizer step
optimizer.step()
# Zero the gradients after each step to prevent accumulation
optimizer.zero_grad()
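
    # After training, LoRA-FA should have left lora_A untouched while lora_B moved away from zero.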
# Check if lora_A weights have not changed
for name, param in model.named_parameters():
if "lora_A" in name:
assert torch.equal(param, initial_lora_A_weights[name]), f"lora_A weights changed for {name}"
# Check if lora_B weights are non-zero
for name, param in model.named_parameters():
if "lora_B" in name:
assert torch.any(param != 0), f"lora_B weights are still zero for {name}"