| import json | |
| from typing import Dict | |
| from safetensors.torch import load_file, save_file | |
| from huggingface_hub import split_torch_state_dict_into_shards | |
| import torch | |
| import os | |
def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
    """Write ``state_dict`` to ``save_directory`` as safetensors shard files.

    The tensor names are partitioned with
    ``split_torch_state_dict_into_shards`` using the
    ``consolidated{suffix}.safetensors`` filename pattern, and every resulting
    shard is written with ``save_file``. When the split reports
    ``is_sharded``, a ``consolidated.safetensors.index.json`` file holding the
    split metadata and the tensor-name -> shard-filename map is written next
    to the shards.

    Args:
        state_dict: Mapping from tensor name to tensor.
        save_directory: Directory that receives the shard files. It is
            created (including parents) if it does not exist yet.
    """
    # os.path.join treats "" as "current directory", but os.makedirs("")
    # raises, so only create the directory for a non-empty path.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern='consolidated{suffix}.safetensors'
    )

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        # `tensors` holds the names assigned to this shard; pull the actual
        # tensors out of the full state dict.
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        print("Saving", save_directory, filename)
        save_file(shard, os.path.join(save_directory, filename))

    # The index is only written when the split says the output is sharded.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2))
# Script entry: load the single-file checkpoint from the current working
# directory and write it back out, split into shards, to that same directory.
# NOTE(review): if the split produces only one shard, the output filename
# presumably equals `big_file` and would overwrite the input -- confirm.
big_file = 'consolidated.safetensors'
loaded = load_file(big_file)
save_state_dict(loaded, save_directory='.')