feat: add check for quantized model (#913)
Browse files
* feat: add check for quantized model
* chore: refactor and add another check
* Update src/axolotl/utils/models.py
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/utils/models.py +23 -0
src/axolotl/utils/models.py
CHANGED
|
@@ -28,6 +28,27 @@ from axolotl.utils.dict import DictDefault
|
|
| 28 |
LOG = logging.getLogger("axolotl")
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def load_model_config(cfg):
|
| 32 |
model_config_name = cfg.base_model_config or cfg.base_model
|
| 33 |
trust_remote_code = cfg.trust_remote_code is True
|
|
@@ -38,6 +59,8 @@ def load_model_config(cfg):
|
|
| 38 |
for key, val in cfg.model_config.items():
|
| 39 |
setattr(model_config, key, val)
|
| 40 |
|
|
|
|
|
|
|
| 41 |
return model_config
|
| 42 |
|
| 43 |
|
|
|
|
| 28 |
LOG = logging.getLogger("axolotl")
|
| 29 |
|
| 30 |
|
| 31 |
+
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
    """Check that the `gptq` flag in cfg agrees with the model's quantization config.

    Args:
        cfg: Training configuration; only its `gptq` attribute is read here.
        model_config: Loaded model configuration; only its optional
            `quantization_config` attribute is read here. It is accessed with
            `in` and `[...]`, so it is expected to be dict-like.

    Raises:
        ValueError: If `cfg.gptq` is truthy but the model config has no
            quantization config whose `quant_method` equals "gptq", or if
            `cfg.gptq` is falsy but the model config carries a (non-empty)
            quantization config.
    """
    # `hasattr` alone is not enough: the attribute may be present but None or
    # empty, which would make the membership test below raise TypeError (None)
    # or wrongly flag a non-quantized model as quantized (empty mapping).
    quant_config = getattr(model_config, "quantization_config", None)
    quant_config_exists = bool(quant_config)
    quant_config_method_is_gptq = (
        quant_config_exists
        and "quant_method" in quant_config
        and quant_config["quant_method"] == "gptq"
    )

    if cfg.gptq and not quant_config_method_is_gptq:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if not cfg.gptq and quant_config_exists:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def load_model_config(cfg):
|
| 53 |
model_config_name = cfg.base_model_config or cfg.base_model
|
| 54 |
trust_remote_code = cfg.trust_remote_code is True
|
|
|
|
| 59 |
for key, val in cfg.model_config.items():
|
| 60 |
setattr(model_config, key, val)
|
| 61 |
|
| 62 |
+
check_model_config(cfg, model_config)
|
| 63 |
+
|
| 64 |
return model_config
|
| 65 |
|
| 66 |
|