add validation/warning for bettertransformers and torch version
Browse files
src/axolotl/utils/validation.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Module for validating config files"""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
-
|
| 5 |
|
| 6 |
def validate_config(cfg):
|
| 7 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
@@ -63,7 +63,10 @@ def validate_config(cfg):
|
|
| 63 |
if cfg.fp16 or cfg.bf16:
|
| 64 |
raise ValueError("AMP is not supported with BetterTransformer")
|
| 65 |
if cfg.float16 is not True:
|
| 66 |
-
logging.warning("You should probably set float16 to true")
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# TODO
|
| 69 |
# MPT 7b
|
|
|
|
| 1 |
"""Module for validating config files"""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
def validate_config(cfg):
|
| 7 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
|
|
| 63 |
if cfg.fp16 or cfg.bf16:
|
| 64 |
raise ValueError("AMP is not supported with BetterTransformer")
|
| 65 |
if cfg.float16 is not True:
|
| 66 |
+
logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
|
| 67 |
+
if torch.__version__.split(".")[0] < 2:
|
| 68 |
+
logging.warning("torch>=2.0.0 required")
|
| 69 |
+
raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")
|
| 70 |
|
| 71 |
# TODO
|
| 72 |
# MPT 7b
|