Hotfix
Browse fileshttps://github.com/huggingface/safetensors/pull/102
- convert.py +4 -3
convert.py
CHANGED
|
@@ -45,7 +45,8 @@ def check_file_size(sf_filename: str, pt_filename: str):
|
|
| 45 |
|
| 46 |
|
| 47 |
def rename(pt_filename: str) -> str:
|
| 48 |
-
|
|
|
|
| 49 |
local = local.replace("pytorch_model", "model")
|
| 50 |
return local
|
| 51 |
|
|
@@ -103,7 +104,7 @@ def convert_file(
|
|
| 103 |
# For tensors to be contiguous
|
| 104 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
| 105 |
|
| 106 |
-
dirname =
|
| 107 |
os.makedirs(dirname, exist_ok=True)
|
| 108 |
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
| 109 |
check_file_size(sf_filename, pt_filename)
|
|
@@ -199,7 +200,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
| 199 |
prefix, ext = os.path.splitext(filename)
|
| 200 |
if ext in extensions:
|
| 201 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 202 |
-
sf_in_repo = f"{
|
| 203 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 204 |
convert_file(pt_filename, sf_filename)
|
| 205 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def rename(pt_filename: str) -> str:
|
| 48 |
+
filename, ext = os.path.splitext(pt_filename)
|
| 49 |
+
local = f"{filename}.safetensors"
|
| 50 |
local = local.replace("pytorch_model", "model")
|
| 51 |
return local
|
| 52 |
|
|
|
|
| 104 |
# For tensors to be contiguous
|
| 105 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
| 106 |
|
| 107 |
+
dirname = os.path.dirname(sf_filename)
|
| 108 |
os.makedirs(dirname, exist_ok=True)
|
| 109 |
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
| 110 |
check_file_size(sf_filename, pt_filename)
|
|
|
|
| 200 |
prefix, ext = os.path.splitext(filename)
|
| 201 |
if ext in extensions:
|
| 202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 203 |
+
sf_in_repo = f"{prefix}.safetensors"
|
| 204 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 205 |
convert_file(pt_filename, sf_filename)
|
| 206 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|