Make this work on diffusers out of the box.
Browse files- convert.py +8 -2
convert.py
CHANGED
|
@@ -200,7 +200,13 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
| 200 |
prefix, ext = os.path.splitext(filename)
|
| 201 |
if ext in extensions:
|
| 202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 205 |
convert_file(pt_filename, sf_filename)
|
| 206 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
|
@@ -219,7 +225,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
| 219 |
try:
|
| 220 |
operations = None
|
| 221 |
pr = previous_pr(api, model_id, pr_title)
|
| 222 |
-
if ("
|
| 223 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
| 224 |
elif pr is not None and not force:
|
| 225 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|
|
|
|
| 200 |
prefix, ext = os.path.splitext(filename)
|
| 201 |
if ext in extensions:
|
| 202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 203 |
+
_, raw_filename = os.path.split(filename)
|
| 204 |
+
if raw_filename == "pytorch_model.bin":
|
| 205 |
+
# XXX: This is a special case to handle `transformers` and the
|
| 206 |
+
# `transformers` part of the model which is actually loaded by `transformers`.
|
| 207 |
+
sf_in_repo = "model.safetensors"
|
| 208 |
+
else:
|
| 209 |
+
sf_in_repo = f"{prefix}.safetensors"
|
| 210 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 211 |
convert_file(pt_filename, sf_filename)
|
| 212 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
|
|
|
| 225 |
try:
|
| 226 |
operations = None
|
| 227 |
pr = previous_pr(api, model_id, pr_title)
|
| 228 |
+
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
| 229 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
| 230 |
elif pr is not None and not force:
|
| 231 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|