Fix: bf16 support for inference (#981)
* Fix: bf16 torch dtype
* simplify casting to device and dtype
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/cli/__init__.py
CHANGED
@@ -103,7 +103,7 @@ def do_inference(
         importlib.import_module("axolotl.prompters"), prompter
     )
 
-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     while True:
         print("=" * 80)
@@ -168,7 +168,7 @@ def do_inference_gradio(
         importlib.import_module("axolotl.prompters"), prompter
     )
 
-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     def generate(instruction):
         if not instruction:
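For context, the change collapses device placement and dtype casting into one call: `torch.nn.Module.to` accepts both a device and a dtype, so the model is moved to `cfg.device` and cast to `cfg.torch_dtype` in a single step. The sketch below illustrates the pattern with a plain `SimpleNamespace` standing in for axolotl's config object; the dtype value and the bf16-flag mapping are assumptions for illustration, not axolotl's actual config resolution.

```python
# Minimal sketch of the casting pattern applied in this PR, assuming a
# cfg-like object that exposes `device` and `torch_dtype` (presumably
# resolved from the bf16/fp16 flags in the training config).
from types import SimpleNamespace

import torch
from torch import nn

cfg = SimpleNamespace(
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    # Assumed mapping for illustration: bf16 inference -> torch.bfloat16.
    torch_dtype=torch.bfloat16,
)

model = nn.Linear(16, 16)

# Before: only the device was set, so the model kept whatever dtype it was
# loaded in, which is the mismatch this PR addresses for bf16 inference.
# model = model.to(cfg.device)

# After: one call moves the parameters to the target device and casts them
# to the configured dtype.
model = model.to(cfg.device, dtype=cfg.torch_dtype)

print(next(model.parameters()).device, next(model.parameters()).dtype)
```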