import torch
from safetensors.torch import load_file, save_file
# Shard files to merge, loaded in this order.
model_files = [
    'model-00001-of-00003.safetensors',
    'model-00002-of-00003.safetensors',
    'model-00003-of-00003.safetensors',
]


def merge_safetensors_shards(shard_paths, output_path='pytorch_model.bin', load_fn=None):
    """Merge safetensors checkpoint shards into one file written with ``torch.save``.

    Shards are treated as disjoint pieces of a single state dict, so the merge
    is a union of their keys. A key present in several shards is accepted only
    when every copy is identical (e.g. a tensor stored twice); conflicting
    copies raise instead of being combined. The previous version summed them,
    which silently corrupted the weights.

    Args:
        shard_paths: Iterable of shard file paths, loaded in order.
        output_path: Destination passed to ``torch.save``.
        load_fn: Callable mapping a shard path to a ``{name: tensor}`` dict.
            Defaults to ``safetensors.torch.load_file``.

    Returns:
        The merged ``{name: tensor}`` dict that was saved.

    Raises:
        ValueError: If a key appears in more than one shard with different
            tensors.
    """
    if load_fn is None:
        load_fn = load_file
    merged = {}
    for path in shard_paths:
        for key, tensor in load_fn(path).items():
            if key not in merged:
                merged[key] = tensor
            elif not torch.equal(merged[key], tensor):
                # Never add/overwrite: a conflicting duplicate means the
                # shard set is inconsistent and the output would be wrong.
                raise ValueError(
                    f'key {key!r} in shard {path!r} conflicts with a tensor '
                    f'from an earlier shard'
                )
            # Identical duplicates are kept once (first occurrence).
    torch.save(merged, output_path)
    return merged


if __name__ == '__main__':
    merged_state_dict = merge_safetensors_shards(model_files, 'pytorch_model.bin')