Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass, field | |
| import os | |
| from os.path import isdir, isfile | |
| from pathlib import Path | |
| import sys | |
| from transformers import AutoTokenizer | |
| class GptqConfig: | |
| ckpt: str = field( | |
| default=None, | |
| metadata={ | |
| "help": "Load quantized model. The path to the local GPTQ checkpoint." | |
| }, | |
| ) | |
| wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) | |
| groupsize: int = field( | |
| default=-1, | |
| metadata={"help": "Groupsize to use for quantization; default uses full row."}, | |
| ) | |
| act_order: bool = field( | |
| default=True, | |
| metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, | |
| ) | |
| def load_gptq_quantized(model_name, gptq_config: GptqConfig): | |
| print("Loading GPTQ quantized model...") | |
| try: | |
| script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) | |
| module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") | |
| sys.path.insert(0, module_path) | |
| from llama import load_quant | |
| except ImportError as e: | |
| print(f"Error: Failed to load GPTQ-for-LLaMa. {e}") | |
| print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") | |
| sys.exit(-1) | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
| # only `fastest-inference-4bit` branch cares about `act_order` | |
| if gptq_config.act_order: | |
| model = load_quant( | |
| model_name, | |
| find_gptq_ckpt(gptq_config), | |
| gptq_config.wbits, | |
| gptq_config.groupsize, | |
| act_order=gptq_config.act_order, | |
| ) | |
| else: | |
| # other branches | |
| model = load_quant( | |
| model_name, | |
| find_gptq_ckpt(gptq_config), | |
| gptq_config.wbits, | |
| gptq_config.groupsize, | |
| ) | |
| return model, tokenizer | |
| def find_gptq_ckpt(gptq_config: GptqConfig): | |
| if Path(gptq_config.ckpt).is_file(): | |
| return gptq_config.ckpt | |
| for ext in ["*.pt", "*.safetensors"]: | |
| matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) | |
| if len(matched_result) > 0: | |
| return str(matched_result[-1]) | |
| print("Error: gptq checkpoint not found") | |
| sys.exit(1) | |