Update app.py

app.py CHANGED
@@ -32,8 +32,6 @@ def load_lora_state(lora_model_name):
     with open(config_path, 'r') as f:
         lora_config = json.load(f)
 
-    scale = lora_config['lora_alpha'] / lora_config['r']
-
     # Download adapter weights
     try:
         adapter_path = hf_hub_download(
@@ -52,18 +50,18 @@ def load_lora_state(lora_model_name):
         )
     lora_state = torch.load(adapter_path, map_location='cpu')
 
-    return lora_state,
+    return lora_state, lora_config, temp_lora_dir
 
 def find_lora_weights(lora_state, key):
     """Find corresponding LoRA A and B weights for a given key"""
     lora_A = None
     lora_B = None
 
-    # Remove .weight suffix
-    clean_key = key.
+    # Remove .weight suffix for matching
+    clean_key = key.removesuffix('.weight')  # not strip(), which removes a character set, not a suffix
 
     for lora_key, lora_weight in lora_state.items():
-        if clean_key in lora_key
+        if clean_key in lora_key:
             if 'lora_A' in lora_key:
                 lora_A = lora_weight
             elif 'lora_B' in lora_key:
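
The substring match above works because adapter state-dict keys embed the base-model key. A minimal sketch of the lookup, using made-up PEFT-style key names (real adapter checkpoints may name keys differently):

import torch

# Hypothetical adapter state dict; key names and shapes are illustrative only
lora_state = {
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64),
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8),
}

key = "model.layers.0.self_attn.q_proj.weight"  # key as it appears in a model shard
clean_key = key.removesuffix(".weight")

# Mirrors find_lora_weights: clean_key is a substring of both adapter keys
matches = [k for k in lora_state if clean_key in k]
assert len(matches) == 2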
@@ -118,17 +116,27 @@ def download_and_upload_non_model_files(base_model_name, output_repo_name):
     shutil.rmtree(temp_config_dir, ignore_errors=True)
 
 def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
-                         multiplicative_lora, progress=gr.Progress()):
+                         scale_factor, multiplicative_lora, inverse_lora, progress=gr.Progress()):
     temp_lora_dir = None
     try:
+        # Validate scale factor
+        if not (0 < scale_factor < 2):
+            error_msg = "Scale factor must be in the range (0, 2)"
+            warning_fn(error_msg)
+            return f"❌ Error: {error_msg}"
+
         login(hf_token)
 
         progress(0.1, desc="Loading LoRA adapter...")
         info_fn("Loading LoRA adapter...")
 
         # Load LoRA state (this downloads the adapter)
-        lora_state,
-
+        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)
+
+        # Calculate scale with user factor
+        base_scale = lora_config['lora_alpha'] / lora_config['r']
+        scale = base_scale * scale_factor
+        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")
 
         progress(0.2, desc="Creating output repository...")
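
For concreteness, a worked example of the scale computation above, with illustrative config values (lora_alpha=16, r=8 are common LoRA settings, not values taken from this repo):

lora_config = {"lora_alpha": 16, "r": 8}
scale_factor = 0.5                                          # user-chosen strength
base_scale = lora_config["lora_alpha"] / lora_config["r"]   # 16 / 8 = 2.0
scale = base_scale * scale_factor                           # 2.0 * 0.5 = 1.0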
@@ -157,6 +165,18 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
 
         info_fn(f"Found {len(shard_files)} model shards to process")
 
+        # Determine merge mode
+        if multiplicative_lora and inverse_lora:
+            merge_mode = "Multiplicative Inverse"
+        elif multiplicative_lora:
+            merge_mode = "Multiplicative"
+        elif inverse_lora:
+            merge_mode = "Additive Inverse"
+        else:
+            merge_mode = "Additive"
+
+        info_fn(f"Merge mode: {merge_mode}")
+
         merged_tensors = 0
         total_shards = len(shard_files)
 
@@ -194,29 +214,47 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
             lora_A, lora_B = find_lora_weights(lora_state, key)
 
             if lora_A is not None and lora_B is not None:
-
-                info_fn(f"Merging {lora_type} LoRA weights for {key}")
+                info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                 shard_merged_count += 1
                 merged_tensors += 1
 
                 # Convert to float32 for computation
                 original_dtype = tensor.dtype
-
-
-                lora_B_f32 = lora_B.to(torch.float32)
+                tensor = tensor.to(torch.float32)
+                lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)
 
                 if multiplicative_lora:
-                    #
-
+                    # Validate dimensions for multiplicative LoRA
+                    if lora_delta.shape[0] != lora_delta.shape[1]:
+                        raise ValueError(f"Multiplicative LoRA requires square delta matrix for {key}: got shape {lora_delta.shape}")
+                    if lora_delta.shape[-1] != tensor.shape[-2]:
+                        raise ValueError(f"Multiplicative LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                    if inverse_lora:
+                        # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
+                        identity = torch.eye(lora_delta.shape[0], device=lora_delta.device, dtype=torch.float32)
+                        inverse_matrix = torch.linalg.inv(identity + lora_delta)
+                        tensor = inverse_matrix @ tensor
+                    else:
+                        # Forward multiplicative: tensor = (I + lora_delta) @ tensor
+                        tensor += lora_delta @ tensor
                 else:
-                    #
-
+                    # Validate dimensions for additive LoRA
+                    if lora_delta.shape != tensor.shape:
+                        raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                    if inverse_lora:
+                        # Inverse additive: tensor = tensor - lora_delta
+                        tensor -= lora_delta
+                    else:
+                        # Forward additive: tensor = tensor + lora_delta
+                        tensor += lora_delta
 
                 # Convert back to original dtype
-                tensor =
+                tensor = tensor.to(original_dtype)
 
                 # Clean up intermediate tensors
-                del
+                del lora_delta
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
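
A quick sanity check of the branch logic above: on toy tensors, each inverse mode should undo its forward mode. Shapes and values here are made up, and (I + delta) is assumed invertible, as the app's inverse mode also requires:

import torch

torch.manual_seed(0)
W = torch.randn(4, 4)
B, A = torch.randn(4, 2), torch.randn(2, 4)
delta = 1.0 * B @ A                            # scale = 1.0; same shape as W

# Additive merge, then additive inverse, recovers W exactly
assert torch.allclose((W + delta) - delta, W)

# Multiplicative merge (I + delta) @ W, undone by inv(I + delta)
identity = torch.eye(4)
W_mul = (identity + delta) @ W
assert torch.allclose(torch.linalg.inv(identity + delta) @ W_mul, W, atol=1e-5)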
@@ -246,7 +284,7 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
 
         progress(1.0, desc="Upload completed!")
 
-        success_msg = f"✅ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
+        success_msg = f"✅ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nMerge mode: {merge_mode}\nScale factor: {scale_factor}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
         info_fn("Merge completed successfully!")
 
         return success_msg
@@ -272,15 +310,23 @@ This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a me
 - **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
 - **Automatic Cleanup**: Temporary files are automatically removed after processing
 - **Progress Tracking**: Real-time status updates throughout the merge process
-- **Advanced Options**: Multiplicative LoRA
+- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
 """
 
 DETAILS_TEXT = """
 ### How It Works
-LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool
+LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes:
+
+- **Additive LoRA**: `W_new = W + scale × B @ A`
+- **Additive Inverse**: `W_new = W - scale × B @ A` (removes LoRA effect)
+- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
+- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`
 
-
-
+### Scale Factor
+The scale factor (0 < scale < 2) controls the strength of the LoRA merge:
+- **1.0**: Full strength (default)
+- **0.5**: Half strength
+- **1.5**: 150% strength
 
 ### Memory Efficiency
 - **Traditional approach**: Loads entire model (~15GB+ for 7B parameter models)
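
The multiplicative rule in the list above, `W_new = W + scale × B @ A @ W`, is the same operation the merge loop applies as `tensor += lora_delta @ tensor`, i.e. `(I + scale × B @ A) @ W`. A one-line check with made-up shapes:

import torch

W, B, A, s = torch.randn(3, 3), torch.randn(3, 2), torch.randn(2, 3), 0.7
assert torch.allclose(W + s * B @ A @ W, (torch.eye(3) + s * B @ A) @ W, atol=1e-5)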
@@ -328,10 +374,23 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
             )
 
             gr.Markdown("### Advanced Options")
+            scale_factor = gr.Slider(
+                minimum=0.01,
+                maximum=1.99,
+                value=1.0,
+                step=0.01,
+                label="Scale Factor",
+                info="Strength of LoRA merge (0 < scale < 2)"
+            )
             multiplicative_lora = gr.Checkbox(
                 label="Multiplicative LoRA",
                 value=False,
-                info="Apply
+                info="Apply multiplicative LoRA instead of additive LoRA"
+            )
+            inverse_lora = gr.Checkbox(
+                label="Inverse Merge",
+                value=False,
+                info="Apply inverse operation (subtract/invert the LoRA effect)"
             )
 
         with gr.Column(scale=1):
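
Note that the slider bounds (0.01 to 1.99, step 0.01) keep the submitted value strictly inside the open interval (0, 2) that merge_lora_efficient validates, so the UI itself cannot trigger the scale-factor error path.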
@@ -348,7 +407,8 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
 
     submit_btn.click(
         fn=merge_lora_efficient,
-        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+                scale_factor, multiplicative_lora, inverse_lora],
         outputs=output_text
     )
 