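# FastChat AWQ loader: builds a model skeleton from its Hugging Face config
# and fills it with weights from a local AWQ (activation-aware weight
# quantization) checkpoint using TinyChat's quantized kernels.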
from dataclasses import dataclass, field
from pathlib import Path
import sys

import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils


@dataclass
class AWQConfig:
    ckpt: str = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local AWQ checkpoint."
        },
    )
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )


def load_awq_quantized(model_name, awq_config: AWQConfig, device):
    """Build `model_name` from its config, load AWQ-quantized weights from
    the checkpoint in `awq_config`, and return (model, tokenizer)."""
    print("Loading AWQ quantized model...")

    try:
        from tinychat.utils import load_quant
        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    except ImportError as e:
        print(f"Error: Failed to import tinychat. {e}")
        print("Please double check if you have successfully installed AWQ")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md")
        sys.exit(-1)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )
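
    # The AWQ checkpoint supplies every weight, so random initialization is
    # wasted work: stub out the common initializers and disable transformers'
    # own _init_weights pass before building the skeleton.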
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
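
    # Instantiate an uninitialized fp16 skeleton from the config alone; the
    # real weights come from the AWQ checkpoint below.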
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    ckpt_path = find_awq_ckpt(awq_config)
    if any(name in ckpt_path for name in ["llama", "vicuna"]):
        # Llama-family checkpoints take the fast path: load the quantized
        # weights, then swap in TinyChat's fused attention/norm/MLP modules.
        model = load_quant.load_awq_llama_fast(
            model,
            ckpt_path,
            awq_config.wbits,
            awq_config.groupsize,
            device,
        )
        make_quant_attn(model, device)
        make_quant_norm(model)
        make_fused_mlp(model)
    else:
        model = load_quant.load_awq_model(
            model,
            ckpt_path,
            awq_config.wbits,
            awq_config.groupsize,
            device,
        )
    return model, tokenizer


def find_awq_ckpt(awq_config: AWQConfig):
    """Resolve awq_config.ckpt to a checkpoint file: accept a direct file
    path, or pick the last *.pt / *.safetensors match inside a directory."""
    if Path(awq_config.ckpt).is_file():
        return awq_config.ckpt

    for ext in ["*.pt", "*.safetensors"]:
        matched_result = sorted(Path(awq_config.ckpt).glob(ext))
        if len(matched_result) > 0:
            return str(matched_result[-1])

    print("Error: AWQ checkpoint not found")
    sys.exit(1)
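

if __name__ == "__main__":
    # Usage sketch, not part of the original module: the model name,
    # checkpoint path, and 4-bit / group-size-128 settings below are
    # illustrative assumptions; substitute the output of your own AWQ
    # quantization run (see docs/awq.md linked above).
    demo_config = AWQConfig(
        ckpt="/path/to/awq-checkpoint-dir", wbits=4, groupsize=128
    )
    model, tokenizer = load_awq_quantized(
        "lmsys/vicuna-7b-v1.5", demo_config, "cuda:0"
    )
    print(f"Loaded {type(model).__name__}; vocab size = {len(tokenizer)}")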