try to detect accelerate and only use device_map=None in that case (#373)
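In short: `choose_device` now clears `cfg.device_map` whenever an `ACCELERATE_USE_*` environment variable is present (i.e. the process was started by `accelerate launch`), and `load_model` now passes `cfg.device_map` through to each `from_pretrained` call so that decision actually takes effect.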
src/axolotl/utils/config.py
CHANGED
@@ -30,6 +30,12 @@ def choose_device(cfg):
     else:
         cfg.device_map = {"": cfg.device}
 
+    # in `accelerate launch`, we need to not pass through any device map and let
+    # accelerate figure out which parts of the model to put on which gpu
+    accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
+    if accelerate_vars:
+        cfg.device_map = None
+
 
 def normalize_config(cfg):
     # setup some derived config / hyperparams
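For context: `accelerate launch` exports environment variables prefixed with `ACCELERATE_USE_` (for example `ACCELERATE_USE_DEEPSPEED`) for whichever backend it is driving, and the new check keys on that prefix. A minimal standalone sketch of the same logic; `resolve_device_map` and its `device` argument are hypothetical stand-ins for the `cfg` plumbing, not part of the patch:

    import os

    def resolve_device_map(device: str):
        # default: pin the whole model to a single device, as choose_device does
        device_map = {"": device}
        # `accelerate launch` sets ACCELERATE_USE_* for its active backend;
        # if any such variable is present, defer placement to accelerate
        accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
        if accelerate_vars:
            return None
        return device_map

    print(resolve_device_map("cuda:0"))
    # {'': 'cuda:0'} in a plain run; None under `accelerate launch`

Returning None matters because a concrete `{"": device}` map would fight accelerate's own placement decisions when it prepares the model.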
src/axolotl/utils/models.py
CHANGED
@@ -235,6 +235,7 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -269,6 +270,7 @@ def load_model(
     elif model_type and not cfg.trust_remote_code:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -299,6 +301,7 @@ def load_model(
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -312,6 +315,7 @@ def load_model(
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
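With the `models.py` half in place, every loader path forwards `cfg.device_map` to `from_pretrained`. A hedged sketch of the two resulting behaviors in Transformers; the model name is a placeholder, not from the PR:

    from transformers import AutoModelForCausalLM

    # plain run: choose_device left device_map as {"": "cuda:0"}, so
    # transformers places the whole model on that one device at load time
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map={"": "cuda:0"})

    # under `accelerate launch`: device_map is None, so transformers does no
    # placement of its own and accelerate moves the model when preparing it
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=None)

Before this change the calls omitted `device_map` entirely, so the value computed in `choose_device` never reached the model load.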