Spaces:
Running
Running
Update model/utils.py
Browse files- model/utils.py +2 -2
model/utils.py
CHANGED
|
@@ -562,8 +562,8 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
|
| 562 |
|
| 563 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 564 |
if ckpt_type == "safetensors":
|
| 565 |
-
|
| 566 |
-
|
| 567 |
else:
|
| 568 |
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
| 569 |
|
|
|
|
| 562 |
|
| 563 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 564 |
if ckpt_type == "safetensors":
|
| 565 |
+
from safetensors.torch import load_file
|
| 566 |
+
checkpoint = load_file(ckpt_path, device=device)
|
| 567 |
else:
|
| 568 |
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
| 569 |
|