import gc
import json
import os
import shutil

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors import safe_open
from safetensors.torch import load_file, save_file

api = HfApi()

def info_fn(text):
    gr.Info(text)


def warning_fn(text):
    gr.Warning(text)

def load_lora_state(lora_model_name):
    """Download and load LoRA adapter weights."""
    temp_lora_dir = "/tmp/lora_adapter"
    os.makedirs(temp_lora_dir, exist_ok=True)

    # Download adapter config
    config_path = hf_hub_download(
        repo_id=lora_model_name,
        filename="adapter_config.json",
        local_dir=temp_lora_dir,
        local_dir_use_symlinks=False
    )
    with open(config_path, 'r') as f:
        lora_config = json.load(f)

    # Download adapter weights (prefer safetensors, fall back to the .bin format)
    try:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.safetensors",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = load_file(adapter_path, device='cpu')
    except Exception:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.bin",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = torch.load(adapter_path, map_location='cpu')

    return lora_state, lora_config, temp_lora_dir
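
# Note: the merge below only reads `r` and `lora_alpha` from the adapter config.
# A PEFT adapter_config.json typically contains fields along these lines
# (illustrative values):
#
#   {
#     "r": 16,
#     "lora_alpha": 32,
#     "target_modules": ["q_proj", "v_proj"],
#     ...
#   }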

def find_lora_weights(lora_state, key):
    """Find the corresponding LoRA A and B weights for a given base-model key."""
    lora_A = None
    lora_B = None

    # Remove the trailing ".weight" suffix so the key can be matched against LoRA keys
    clean_key = key[:-len('.weight')] if key.endswith('.weight') else key

    for lora_key, lora_weight in lora_state.items():
        if clean_key in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight

    # Either both matrices were found or neither was
    if (lora_A is None) != (lora_B is None):
        return None, None

    return lora_A, lora_B
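
# PEFT adapter keys normally embed the base-model parameter path, which is why a
# substring match is sufficient. Illustrative example (typical PEFT naming):
#
#   base key : model.layers.0.self_attn.q_proj.weight
#   LoRA keys: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
#              base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight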

def download_and_upload_non_model_files(base_model_name, output_repo_name):
    """Download and upload non-model files (config, tokenizer, etc.)."""
    temp_config_dir = "/tmp/config_files"
    os.makedirs(temp_config_dir, exist_ok=True)
    try:
        # List all files in the repository
        files = list_repo_files(repo_id=base_model_name)

        # Filter non-model files
        non_model_files = [
            f for f in files
            if not (f.startswith('model') and f.endswith('.safetensors'))
        ]

        # Download and upload each non-model file
        for filename in non_model_files:
            if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                continue  # Skip other model formats
            try:
                file_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=filename,
                    local_dir=temp_config_dir,
                    local_dir_use_symlinks=False
                )
                # Upload to output repo
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )
            except Exception as e:
                info_fn(f"Skipping {filename}: {e}")
    finally:
        shutil.rmtree(temp_config_dir, ignore_errors=True)
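
# Note: for a typical Transformers checkpoint the helper above copies files such as
# config.json, generation_config.json, tokenizer.json, tokenizer_config.json,
# special_tokens_map.json and model.safetensors.index.json, so the merged repository
# remains loadable while the weight shards are re-uploaded separately below.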

def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                         scale_factor, progress=gr.Progress()):
    temp_lora_dir = None
    try:
        # Validate scale factor
        if not (-2 <= scale_factor <= 2):
            error_msg = "Scale factor must be in the range [-2, 2]"
            warning_fn(error_msg)
            return f"❌ Error: {error_msg}"

        login(hf_token)

        progress(0.1, desc="Loading LoRA adapter...")
        info_fn("Loading LoRA adapter...")

        # Load LoRA state (this downloads the adapter)
        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)

        # Calculate scale with user factor
        base_scale = lora_config['lora_alpha'] / lora_config['r']
        scale = base_scale * scale_factor
        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")
        progress(0.2, desc="Creating output repository...")

        # Create repository
        try:
            repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
            info_fn(f"Repository created/updated: {repo_url}")
        except Exception as e:
            warning_fn(f"Repository might already exist: {e}")

        progress(0.3, desc="Uploading configuration files...")
        info_fn("Uploading configuration files...")

        # Download and upload non-model files
        download_and_upload_non_model_files(base_model_name, output_repo_name)

        progress(0.4, desc="Finding model shards...")
        info_fn("Finding model shards...")

        # Get list of all safetensors files
        all_files = list_repo_files(repo_id=base_model_name)
        shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]

        if not shard_files:
            raise FileNotFoundError("No model safetensors files found in the repository")

        info_fn(f"Found {len(shard_files)} model shards to process")

        merged_tensors = 0
        total_shards = len(shard_files)

        # Process each shard individually
        for i, shard_filename in enumerate(shard_files):
            progress(0.4 + (i / total_shards) * 0.5,
                     desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
            info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")

            # Create temporary directory for this shard only
            temp_shard_dir = f"/tmp/shard_{i}"
            os.makedirs(temp_shard_dir, exist_ok=True)

            try:
                # Download the current shard
                shard_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=shard_filename,
                    local_dir=temp_shard_dir,
                    local_dir_use_symlinks=False
                )

                # Process the shard
                tensors = {}
                shard_merged_count = 0

                with safe_open(shard_path, framework='pt', device='cpu') as f:
                    # Get metadata if available
                    metadata = f.metadata() if hasattr(f, 'metadata') else {}

                    for key in f.keys():
                        tensor = f.get_tensor(key)

                        # Try to find corresponding LoRA weights
                        lora_A, lora_B = find_lora_weights(lora_state, key)

                        if lora_A is not None and lora_B is not None:
                            info_fn(f"Merging LoRA weights for {key}")
                            shard_merged_count += 1
                            merged_tensors += 1

                            # Convert to float32 for computation
                            original_dtype = tensor.dtype
                            tensor = tensor.to(torch.float32)

                            # lora_B is (out_features, r) and lora_A is (r, in_features),
                            # so B @ A has the same shape as the base weight matrix
                            lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)

                            # Validate dimensions for additive LoRA
                            if lora_delta.shape != tensor.shape:
                                raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")

                            tensor += lora_delta

                            # Convert back to original dtype
                            tensor = tensor.to(original_dtype)

                            # Clean up intermediate tensors
                            del lora_delta
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        tensors[key] = tensor

                # Save processed shard to temporary file
                output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                save_file(tensors, output_shard_path, metadata=metadata)

                info_fn(f"Shard {shard_filename}: Merged {shard_merged_count} tensors")

                # Upload the processed shard
                api.upload_file(
                    path_or_fileobj=output_shard_path,
                    path_in_repo=shard_filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

                # Clean up this shard's data
                del tensors
                gc.collect()

            finally:
                # Always clean up the temporary shard directory
                shutil.rmtree(temp_shard_dir, ignore_errors=True)

        progress(1.0, desc="Upload completed!")

        success_msg = (
            f"✅ Successfully merged and uploaded model!\n"
            f"Model URL: https://huggingface.co/{output_repo_name}\n"
            f"Scale factor: {scale_factor}\n"
            f"Processed {total_shards} shards\n"
            f"Merged {merged_tensors} layers with LoRA weights"
        )
        info_fn("Merge completed successfully!")
        return success_msg

    except Exception as e:
        error_msg = f"❌ Error during merge: {str(e)}"
        warning_fn(error_msg)
        return error_msg

    finally:
        # Cleanup LoRA directory
        if temp_lora_dir and os.path.exists(temp_lora_dir):
            shutil.rmtree(temp_lora_dir, ignore_errors=True)
        gc.collect()

INTRODUCTION_TEXT = """
## Memory-Efficient LoRA Merge

This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a memory-efficient approach that processes model files individually, significantly reducing memory requirements compared to traditional methods.

### Key Features

- **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
- **Automatic Cleanup**: Temporary files are automatically removed after processing
- **Progress Tracking**: Real-time status updates throughout the merge process
- **Advanced Options**: Custom scale factors (including negative values)
"""

DETAILS_TEXT = """
### How It Works

LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model:

`W_new = W + scale × (α/r) × B @ A`

### Scale Factor

The scale factor (-2 ≤ scale ≤ 2) controls the strength of the LoRA merge:

- **1.0**: Full strength (default)
- **0.5**: Half strength
- **-1.0**: Reverse effect (removes LoRA impact)
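
For example, merging an adapter at a scale factor of 1.0 and then merging the same adapter again at -1.0 into the result approximately restores the original base weights, because the second update is the exact negative of the first (up to rounding from the dtype conversions).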

### Memory Efficiency

- **Traditional approach**: Loads entire model (~15GB+ for 7B parameter models)
- **This approach**: Peak usage determined by largest shard size, not total model size
- **Result**: Enables merging of much larger models on limited hardware

### Example Usage

- **Base Model:** `microsoft/DialoGPT-medium`
- **LoRA Adapter:** `username/my-trained-lora`
- **Output Name:** `username/dialogpt-merged`

### Attribution

This tool builds upon excellent work from the community:

- **Base implementation:** [Weyaxi/merge-lora](https://huggingface.co/spaces/Weyaxi/merge-lora)
- **Memory-efficient method:** [qlora-pipe](https://github.com/tdrussell/qlora-pipe/blob/main/tools/merge_lora.py) by tdrussell
"""

with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="hf_...",
                type="password",
                info="Token with write access to create repositories"
            )
            base_model_name = gr.Textbox(
                label="Base Model Repository",
                placeholder="microsoft/DialoGPT-medium",
                info="The original model to merge LoRA into"
            )
            lora_model_name = gr.Textbox(
                label="LoRA Adapter Repository",
                placeholder="username/my-lora-adapter",
                info="Repository containing adapter_model.safetensors"
            )
            output_repo_name = gr.Textbox(
                label="Output Repository Name",
                placeholder="username/my-merged-model",
                info="Name for the new merged model repository"
            )

            gr.Markdown("### Advanced Options")
            scale_factor = gr.Slider(
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.01,
                label="Scale Factor",
                info="Strength of LoRA merge (-2 ≤ scale ≤ 2)"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Status")
            output_text = gr.Textbox(
                label="Merge Progress & Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    with gr.Row():
        submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")

    submit_btn.click(
        fn=merge_lora_efficient,
        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                scale_factor],
        outputs=output_text
    )

    gr.Markdown(DETAILS_TEXT)

demo.queue()
demo.launch(show_error=True)