Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len
Browse files
src/axolotl/utils/models.py
CHANGED
|
@@ -255,8 +255,15 @@ def load_model(
         )
         # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
         # when training starts
-        if config  [old line truncated in extraction]
-            config.max_seq_len = cfg.sequence_len
+        if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
+            config.max_seq_len = cfg.sequence_len
+            logging.warning(f"increasing context length to {cfg.sequence_len}")
+        elif (
+            hasattr(config, "max_sequence_length")
+            and cfg.sequence_len > config.max_sequence_length
+        ):
+            config.max_sequence_length = cfg.sequence_len
+            logging.warning(f"increasing context length to {cfg.sequence_len}")
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,