Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate
Browse files- src/axolotl/utils/validation.py +3 -0
- tests/test_validation.py +33 -0
src/axolotl/utils/validation.py
CHANGED
|
@@ -54,6 +54,9 @@ def validate_config(cfg):
|
|
| 54 |
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
| 55 |
)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
# TODO
|
| 58 |
# MPT 7b
|
| 59 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 54 |
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
| 55 |
)
|
| 56 |
|
| 57 |
+
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
| 58 |
+
raise ValueError("FSDP is not supported for falcon models")
|
| 59 |
+
|
| 60 |
# TODO
|
| 61 |
# MPT 7b
|
| 62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
|
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
|
|
| 165 |
)
|
| 166 |
|
| 167 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
)
|
| 166 |
|
| 167 |
validate_config(cfg)
|
| 168 |
+
|
| 169 |
+
def test_falcon_fsdp(self):
|
| 170 |
+
regex_exp = r".*FSDP is not supported for falcon models.*"
|
| 171 |
+
|
| 172 |
+
# Check for lower-case
|
| 173 |
+
cfg = DictDefault(
|
| 174 |
+
{
|
| 175 |
+
"base_model": "tiiuae/falcon-7b",
|
| 176 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
with pytest.raises(ValueError, match=regex_exp):
|
| 181 |
+
validate_config(cfg)
|
| 182 |
+
|
| 183 |
+
# Check for upper-case
|
| 184 |
+
cfg = DictDefault(
|
| 185 |
+
{
|
| 186 |
+
"base_model": "Falcon-7b",
|
| 187 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
| 188 |
+
}
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
with pytest.raises(ValueError, match=regex_exp):
|
| 192 |
+
validate_config(cfg)
|
| 193 |
+
|
| 194 |
+
cfg = DictDefault(
|
| 195 |
+
{
|
| 196 |
+
"base_model": "tiiuae/falcon-7b",
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
validate_config(cfg)
|