import io
import re
import struct
from enum import IntEnum
from math import floor

import requests

import gradio as gr

class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12

_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

value_type_info = {
    GGUFValueType.UINT8: 1,
    GGUFValueType.INT8: 1,
    GGUFValueType.UINT16: 2,
    GGUFValueType.INT16: 2,
    GGUFValueType.UINT32: 4,
    GGUFValueType.INT32: 4,
    GGUFValueType.FLOAT32: 4,
    GGUFValueType.UINT64: 8,
    GGUFValueType.INT64: 8,
    GGUFValueType.FLOAT64: 8,
    GGUFValueType.BOOL: 1,
}
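# The byte counts above mirror struct.calcsize() of the corresponding format
# strings in _simple_value_packing. A quick sanity check (illustrative only,
# not executed by the app):
#
#     assert all(struct.calcsize(fmt) == value_type_info[vt]
#                for vt, fmt in _simple_value_packing.items())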
					
						
def get_single(value_type, file):
    """Read a single value of the given GGUF type from a binary file-like object."""
    if value_type == GGUFValueType.STRING:
        value_length = struct.unpack("<Q", file.read(8))[0]
        value = file.read(value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            # Keep the raw bytes if the string is not valid UTF-8
            pass
    else:
        type_str = _simple_value_packing.get(value_type)
        bytes_length = value_type_info.get(value_type)
        value = struct.unpack(type_str, file.read(bytes_length))[0]

    return value
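# Illustrative example (not executed by the app): reading a little-endian
# uint32 from an in-memory buffer.
#
#     get_single(GGUFValueType.UINT32, io.BytesIO(struct.pack("<I", 42)))  # -> 42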
					
						
def load_metadata_from_file(file_obj):
    """Load GGUF metadata from a binary file-like object."""
    metadata = {}

    # Header: magic, version, tensor count, key/value count
    GGUF_MAGIC = struct.unpack("<I", file_obj.read(4))[0]
    GGUF_VERSION = struct.unpack("<I", file_obj.read(4))[0]
    ti_data_count = struct.unpack("<Q", file_obj.read(8))[0]
    kv_data_count = struct.unpack("<Q", file_obj.read(8))[0]

    if GGUF_VERSION == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    for _ in range(kv_data_count):
        key_length = struct.unpack("<Q", file_obj.read(8))[0]
        key = file_obj.read(key_length)

        value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
        if value_type == GGUFValueType.ARRAY:
            ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
            length = struct.unpack("<Q", file_obj.read(8))[0]

            arr = [get_single(ltype, file_obj) for _ in range(length)]
            metadata[key.decode()] = arr
        else:
            value = get_single(value_type, file_obj)
            metadata[key.decode()] = value

    # Extract the architecture-agnostic fields used by the VRAM formula
    extracted_fields = {}
    for key, value in metadata.items():
        if key.endswith('.block_count'):
            extracted_fields['n_layers'] = value
        elif key.endswith('.attention.head_count_kv'):
            extracted_fields['n_kv_heads'] = max(value) if isinstance(value, list) else value
        elif key.endswith('.embedding_length'):
            extracted_fields['embedding_dim'] = value
        elif key.endswith('.context_length'):
            extracted_fields['context_length'] = value
        elif key.endswith('.feed_forward_length'):
            extracted_fields['feed_forward_dim'] = value

    metadata.update(extracted_fields)
    return metadata
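# Illustrative usage with a local file (the path below is hypothetical):
#
#     with open("model.gguf", "rb") as f:
#         meta = load_metadata_from_file(f)
#     print(meta.get("n_layers"), meta.get("context_length"))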
					
						
def download_gguf_partial(url, max_bytes=25 * 1024 * 1024):
    """Download the first max_bytes from a GGUF URL"""
    try:
        headers = {'Range': f'bytes=0-{max_bytes-1}'}

        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()

        content = response.content

        return io.BytesIO(content)

    except Exception as e:
        raise Exception(f"Failed to download GGUF file: {str(e)}")
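# GGUF places its header and key/value metadata at the start of the file, before
# the tensor data, so a ranged download of the first ~25 MiB is normally enough
# to read the metadata. Illustrative call (the URL below is hypothetical):
#
#     file_obj = download_gguf_partial("https://example.com/model.gguf")
#     metadata = load_metadata_from_file(file_obj)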
					
						
def load_metadata(model_url, current_metadata):
    """Load metadata from the model URL and return the updated metadata dict"""
    if not model_url or model_url.strip() == "":
        return {}, gr.update(), "Please enter a model URL"

    try:
        model_size_mb = get_model_size_mb_from_url(model_url)

        normalized_url = normalize_huggingface_url(model_url)

        file_obj = download_gguf_partial(normalized_url)

        metadata = load_metadata_from_file(file_obj)

        gguf_filename = model_url.split('/')[-1].split('?')[0]

        # Derive a short "user/repo" name for HuggingFace URLs
        model_name = model_url
        if "huggingface.co/" in model_url:
            try:
                parts = model_url.split("huggingface.co/")[1].split("/")
                if len(parts) >= 2:
                    model_name = f"{parts[0]}/{parts[1]}"
            except Exception:
                model_name = model_url

        metadata['url'] = model_url
        metadata['model_name'] = model_name
        metadata['model_size_mb'] = model_size_mb
        metadata['loaded'] = True

        return metadata, gr.update(value=metadata["n_layers"], maximum=metadata["n_layers"]), f"Metadata loaded successfully for: {gguf_filename}"

    except Exception as e:
        error_msg = f"Error loading metadata: {str(e)}"
        return {}, gr.update(), error_msg

def normalize_huggingface_url(url: str) -> str:
    """Normalize a HuggingFace URL to the 'resolve' format for direct file access"""
    if 'huggingface.co' not in url:
        return url

    # Strip query parameters such as ?download=true
    base_url = url.split('?')[0]

    # Convert the web-view ('blob') URL to the direct-download ('resolve') URL
    if '/blob/' in base_url:
        base_url = base_url.replace('/blob/', '/resolve/')

    return base_url
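# Example of the rewrite this performs on the app's default model URL:
#
#     normalize_huggingface_url(
#         "https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf"
#     )
#     # -> "https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/resolve/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf"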
					
						
def get_model_size_mb_from_url(model_url: str) -> float:
    """Get the model size in MB from the URL without downloading, handling multi-part files"""
    try:
        normalized_url = normalize_huggingface_url(model_url)

        # Size of the file that was actually linked
        response = requests.head(normalized_url, allow_redirects=True)
        response.raise_for_status()
        main_file_size = int(response.headers.get('content-length', 0))

        filename = normalized_url.split('/')[-1]

        # Multi-part GGUF files follow the pattern <name>-NNNNN-of-MMMMM.gguf
        match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)

        if match:
            base_pattern = match.group(1)
            total_parts = int(match.group(3))

            total_size = 0
            base_url = '/'.join(normalized_url.split('/')[:-1]) + '/'

            # Sum the sizes of all parts via HEAD requests
            for part_num in range(1, total_parts + 1):
                part_filename = f"{base_pattern}-{part_num:05d}-of-{total_parts:05d}.gguf"
                part_url = base_url + part_filename

                try:
                    part_response = requests.head(part_url, allow_redirects=True)
                    part_response.raise_for_status()
                    part_size = int(part_response.headers.get('content-length', 0))
                    total_size += part_size
                except requests.RequestException:
                    print(f"Warning: Could not get size of {part_filename}, estimating...")

                    # Estimate the remaining parts from the average size seen so far,
                    # or fall back to assuming every part matches the linked file
                    if total_size > 0:
                        avg_size = total_size / (part_num - 1)
                        remaining_parts = total_parts - (part_num - 1)
                        total_size += avg_size * remaining_parts
                    else:
                        total_size = main_file_size * total_parts
                    break

            return total_size / (1024 ** 2)
        else:
            # Single-file model
            return main_file_size / (1024 ** 2)

    except Exception as e:
        print(f"Error getting model size: {e}")
        return 0.0
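# Illustrative match against the default model's filename: the capture groups
# give the base name, the part number, and the total part count.
#
#     m = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$',
#                  "Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf")
#     m.groups()  # -> ('Qwen3-235B-A22B-UD-Q2_K_XL', '00001', '00002')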
					
						
def estimate_vram(metadata, gpu_layers, ctx_size, cache_type):
    """Calculate VRAM usage using the actual formula"""
    try:
        n_layers = metadata.get('n_layers')
        n_kv_heads = metadata.get('n_kv_heads')
        embedding_dim = metadata.get('embedding_dim')
        context_length = metadata.get('context_length')
        feed_forward_dim = metadata.get('feed_forward_dim')
        size_in_mb = metadata.get('model_size_mb', 0)

        # Validate that the required metadata fields are present
        required_fields = [n_layers, n_kv_heads, embedding_dim, context_length, feed_forward_dim]
        if any(field is None for field in required_fields):
            missing = [name for name, field in zip(
                ['n_layers', 'n_kv_heads', 'embedding_dim', 'context_length', 'feed_forward_dim'],
                required_fields) if field is None]
            raise ValueError(f"Missing required metadata fields: {missing}")

        # Cannot offload more layers than the model has
        if gpu_layers > n_layers:
            gpu_layers = n_layers

        # Map the cache type to a bit width
        if cache_type == 'q4_0':
            cache_type = 4
        elif cache_type == 'q8_0':
            cache_type = 8
        else:
            cache_type = 16

        # Derived quantities used by the fitted formula
        size_per_layer = size_in_mb / max(n_layers, 1e-6)
        kv_cache_factor = n_kv_heads * cache_type * ctx_size
        embedding_per_context = embedding_dim / ctx_size

        vram = (
            (size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor)
            * (gpu_layers + max(0.9690636483914102, cache_type - (floor(50.77817218646521 * embedding_per_context) + 9.987899908205632)))
            + 1516.522943869404
        )

        return vram

    except Exception as e:
        print(f"Error in VRAM calculation: {e}")
        raise
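# Illustrative call with hypothetical metadata for a ~4.5 GB quantized model
# (the field values below are made up for demonstration, not taken from a real file):
#
#     estimate_vram(
#         {'n_layers': 32, 'n_kv_heads': 8, 'embedding_dim': 4096,
#          'context_length': 32768, 'feed_forward_dim': 14336,
#          'model_size_mb': 4500.0},
#         gpu_layers=32, ctx_size=8192, cache_type='fp16',
#     )  # -> estimated VRAM in MiB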
					
						
def estimate_vram_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
    """Wrapper function to estimate VRAM usage"""
    if not model_metadata or 'model_name' not in model_metadata:
        return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"

    try:
        result = estimate_vram(model_metadata, gpu_layers, ctx_size, cache_type)
        conservative = result + 577
        return f"""<div id="vram-info">
<div>Expected VRAM usage: <span class="value">{result:.0f} MiB</span></div>
<div>Safe estimate: <span class="value">{conservative:.0f} MiB</span> - 95% chance the VRAM is at most this.</div>
</div>"""
    except Exception as e:
        return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">Error: {str(e)}</span></div>"

def create_ui():
    """Create the simplified UI"""
    css = """
    body {
        max-width: 810px !important;
        margin: 0 auto !important;
    }

    #vram-info {
        padding: 10px;
        border-radius: 4px;
        background-color: var(--background-fill-secondary);
    }

    #vram-info .value {
        font-weight: bold;
        color: var(--primary-500);
    }
    """

    with gr.Blocks(css=css) as demo:
        model_metadata = gr.State(value={})

        gr.Markdown("# Accurate GGUF VRAM Calculator\n\nCalculate VRAM for GGUF models from GPU layers and context length using an accurate formula.\n\nFor an explanation about how this works, consult this blog post: https://oobabooga.github.io/blog/posts/gguf-vram-formula/")
        with gr.Row():
            with gr.Column():
                model_url = gr.Textbox(
                    label="GGUF Model URL",
                    value="https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf"
                )

                load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')

                gpu_layers = gr.Slider(
                    label="GPU Layers",
                    minimum=0,
                    maximum=256,
                    value=256,
                    info='`--gpu-layers` in llama.cpp.'
                )

                ctx_size = gr.Slider(
                    label='Context Length',
                    minimum=512,
                    maximum=131072,
                    step=256,
                    value=8192,
                    info='`--ctx-size` in llama.cpp.'
                )

                cache_type = gr.Radio(
                    choices=['fp16', 'q8_0', 'q4_0'],
                    value='fp16',
                    label="Cache Type",
                    info='Cache quantization.'
                )

                vram_info = gr.HTML(
                    value="<div id=\"vram-info\">Estimated VRAM to load the model:</div>"
                )

                status = gr.Textbox(
                    label="Status",
                    value="No model loaded",
                    interactive=False
                )

        load_metadata_btn.click(
            load_metadata,
            inputs=[model_url, model_metadata],
            outputs=[model_metadata, gpu_layers, status],
            show_progress=True
        ).then(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

        for component in [gpu_layers, ctx_size, cache_type]:
            component.change(
                estimate_vram_wrapper,
                inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
                outputs=[vram_info],
                show_progress=False
            )

        model_metadata.change(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

    return demo

if __name__ == "__main__":
    demo = create_ui()
    demo.launch()