E2e device cuda (#575)
Browse files* use torch.cuda.current_device() instead of local_rank
* ignore NVML errors for gpu stats
* llama lora packing e2e tests
- .github/workflows/e2e.yml +1 -0
- src/axolotl/utils/bench.py +8 -5
- src/axolotl/utils/config.py +1 -1
- tests/e2e/test_lora_llama.py +42 -0
.github/workflows/e2e.yml
CHANGED
|
@@ -24,6 +24,7 @@ jobs:
|
|
| 24 |
- name: Install dependencies
|
| 25 |
run: |
|
| 26 |
pip3 install -e .
|
|
|
|
| 27 |
pip3 install -r requirements-tests.txt
|
| 28 |
|
| 29 |
- name: Run e2e tests
|
|
|
|
| 24 |
- name: Install dependencies
|
| 25 |
run: |
|
| 26 |
pip3 install -e .
|
| 27 |
+
pip3 install flash-attn
|
| 28 |
pip3 install -r requirements-tests.txt
|
| 29 |
|
| 30 |
- name: Run e2e tests
|
src/axolotl/utils/bench.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
import pynvml
|
| 4 |
import torch
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def gpu_memory_usage(device=0):
|
|
@@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0):
|
|
| 20 |
device = device.index
|
| 21 |
if isinstance(device, str) and device.startswith("cuda:"):
|
| 22 |
device = int(device[5:])
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def log_gpu_memory_usage(log, msg, device):
|
|
|
|
| 2 |
|
| 3 |
import pynvml
|
| 4 |
import torch
|
| 5 |
+
from pynvml.nvml import NVMLError
|
| 6 |
|
| 7 |
|
| 8 |
def gpu_memory_usage(device=0):
|
|
|
|
| 21 |
device = device.index
|
| 22 |
if isinstance(device, str) and device.startswith("cuda:"):
|
| 23 |
device = int(device[5:])
|
| 24 |
+
try:
|
| 25 |
+
pynvml.nvmlInit()
|
| 26 |
+
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
| 27 |
+
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
| 28 |
+
return info.used / 1024.0**3
|
| 29 |
+
except NVMLError:
|
| 30 |
+
return 0.0
|
| 31 |
|
| 32 |
|
| 33 |
def log_gpu_memory_usage(log, msg, device):
|
src/axolotl/utils/config.py
CHANGED
|
@@ -29,7 +29,7 @@ def choose_device(cfg):
|
|
| 29 |
cfg.device_map = "auto"
|
| 30 |
else:
|
| 31 |
if cfg.device.startswith("cuda"):
|
| 32 |
-
cfg.device_map = {"":
|
| 33 |
else:
|
| 34 |
cfg.device_map = {"": cfg.device}
|
| 35 |
|
|
|
|
| 29 |
cfg.device_map = "auto"
|
| 30 |
else:
|
| 31 |
if cfg.device.startswith("cuda"):
|
| 32 |
+
cfg.device_map = {"": torch.cuda.current_device()}
|
| 33 |
else:
|
| 34 |
cfg.device_map = {"": cfg.device}
|
| 35 |
|
tests/e2e/test_lora_llama.py
CHANGED
|
@@ -78,3 +78,45 @@ class TestLoraLlama(unittest.TestCase):
|
|
| 78 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 79 |
|
| 80 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 79 |
|
| 80 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 81 |
+
|
| 82 |
+
def test_lora_packing(self):
|
| 83 |
+
cfg = DictDefault(
|
| 84 |
+
{
|
| 85 |
+
"base_model": "JackFram/llama-68m",
|
| 86 |
+
"base_model_config": "JackFram/llama-68m",
|
| 87 |
+
"tokenizer_type": "LlamaTokenizer",
|
| 88 |
+
"sequence_len": 1024,
|
| 89 |
+
"sample_packing": True,
|
| 90 |
+
"flash_attention": True,
|
| 91 |
+
"load_in_8bit": True,
|
| 92 |
+
"adapter": "lora",
|
| 93 |
+
"lora_r": 32,
|
| 94 |
+
"lora_alpha": 64,
|
| 95 |
+
"lora_dropout": 0.05,
|
| 96 |
+
"lora_target_linear": True,
|
| 97 |
+
"val_set_size": 0.1,
|
| 98 |
+
"special_tokens": {
|
| 99 |
+
"unk_token": "<unk>",
|
| 100 |
+
"bos_token": "<s>",
|
| 101 |
+
"eos_token": "</s>",
|
| 102 |
+
},
|
| 103 |
+
"datasets": [
|
| 104 |
+
{
|
| 105 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 106 |
+
"type": "alpaca",
|
| 107 |
+
},
|
| 108 |
+
],
|
| 109 |
+
"num_epochs": 2,
|
| 110 |
+
"micro_batch_size": 8,
|
| 111 |
+
"gradient_accumulation_steps": 1,
|
| 112 |
+
"output_dir": tempfile.mkdtemp(),
|
| 113 |
+
"learning_rate": 0.00001,
|
| 114 |
+
"optimizer": "adamw_torch",
|
| 115 |
+
"lr_scheduler": "cosine",
|
| 116 |
+
}
|
| 117 |
+
)
|
| 118 |
+
normalize_config(cfg)
|
| 119 |
+
cli_args = TrainerCliArgs()
|
| 120 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 121 |
+
|
| 122 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|