Update convert.py
Browse files- convert.py +7 -7
convert.py
CHANGED
|
@@ -71,7 +71,7 @@ def rename(pt_filename: str) -> str:
|
|
| 71 |
return local
|
| 72 |
|
| 73 |
|
| 74 |
-
def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
|
| 75 |
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
|
| 76 |
with open(filename, "r") as f:
|
| 77 |
data = json.load(f)
|
|
@@ -79,7 +79,7 @@ def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
|
|
| 79 |
filenames = set(data["weight_map"].values())
|
| 80 |
local_filenames = []
|
| 81 |
for filename in filenames:
|
| 82 |
-
pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
|
| 83 |
|
| 84 |
sf_filename = rename(pt_filename)
|
| 85 |
sf_filename = os.path.join(folder, sf_filename)
|
|
@@ -102,7 +102,7 @@ def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
|
|
| 102 |
return operations, errors
|
| 103 |
|
| 104 |
|
| 105 |
-
def convert_single(model_id: str, folder: str, token: str) -> ConversionResult:
|
| 106 |
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
|
| 107 |
|
| 108 |
sf_name = "model.safetensors"
|
|
@@ -156,8 +156,8 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
|
|
| 156 |
return "\n".join(errors)
|
| 157 |
|
| 158 |
|
| 159 |
-
def check_final_model(model_id: str, folder: str):
|
| 160 |
-
config = hf_hub_download(repo_id=model_id, filename="config.json")
|
| 161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 162 |
config = AutoConfig.from_pretrained(folder)
|
| 163 |
|
|
@@ -236,7 +236,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
|
|
| 236 |
return None
|
| 237 |
|
| 238 |
|
| 239 |
-
def convert_generic(model_id: str, folder: str, filenames: Set[str], token: str) -> ConversionResult:
|
| 240 |
operations = []
|
| 241 |
errors = []
|
| 242 |
|
|
@@ -288,7 +288,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
| 288 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
| 289 |
else:
|
| 290 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 291 |
-
check_final_model(model_id, folder)
|
| 292 |
else:
|
| 293 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
| 294 |
|
|
|
|
| 71 |
return local
|
| 72 |
|
| 73 |
|
| 74 |
+
def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 75 |
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
|
| 76 |
with open(filename, "r") as f:
|
| 77 |
data = json.load(f)
|
|
|
|
| 79 |
filenames = set(data["weight_map"].values())
|
| 80 |
local_filenames = []
|
| 81 |
for filename in filenames:
|
| 82 |
+
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token)
|
| 83 |
|
| 84 |
sf_filename = rename(pt_filename)
|
| 85 |
sf_filename = os.path.join(folder, sf_filename)
|
|
|
|
| 102 |
return operations, errors
|
| 103 |
|
| 104 |
|
| 105 |
+
def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 106 |
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
|
| 107 |
|
| 108 |
sf_name = "model.safetensors"
|
|
|
|
| 156 |
return "\n".join(errors)
|
| 157 |
|
| 158 |
|
| 159 |
+
def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
| 160 |
+
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
|
| 161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 162 |
config = AutoConfig.from_pretrained(folder)
|
| 163 |
|
|
|
|
| 236 |
return None
|
| 237 |
|
| 238 |
|
| 239 |
+
def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
|
| 240 |
operations = []
|
| 241 |
errors = []
|
| 242 |
|
|
|
|
| 288 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
| 289 |
else:
|
| 290 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 291 |
+
check_final_model(model_id, folder, token=api.token)
|
| 292 |
else:
|
| 293 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
| 294 |
|