# Llama-2-7b-hf-conf-sft / merge_model.py
# Uploaded by mzhaoshuai: "Upload folder using huggingface_hub"
# Commit df2d4df (verified)
# coding=utf-8
import os
import torch
import shutil
import argparse
from transformers import AutoModelForCausalLM
from peft import PeftModel
def main(args):
    """Merge a PEFT adapter into its base model and save the merged model.

    Loads the base causal LM from ``args.base_model_path`` in bfloat16, wraps it
    with the adapter stored at ``args.peft_model_path``, merges the adapter via
    ``merge_and_unload()``, and writes the result to ``args.save_dir`` with
    ``safe_serialization=True``. Tokenizer files are then copied from the
    adapter directory into the save directory.

    Args:
        args: Namespace providing the string attributes ``base_model_path``,
            ``peft_model_path`` and ``save_dir``.

    Raises:
        FileNotFoundError: If a required tokenizer file is missing from
            ``args.peft_model_path`` (raised by ``shutil.copyfile``).
    """
    # Files that must exist next to the adapter; copied in this order.
    required_tokenizer_files = (
        "special_tokens_map.json",
        "tokenizer.model",
        "tokenizer_config.json",
    )
    # Files that a tokenizer save may or may not have produced. They are
    # copied only when present, so the merged directory does not lose the
    # fast-tokenizer file or tokens added during fine-tuning.
    optional_tokenizer_files = (
        "tokenizer.json",
        "added_tokens.json",
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
    print("\n>>> Base Model + PEFT before merging:\n", model)

    model = model.merge_and_unload()
    print("\n>>> Base Model + PEFT after merging:\n", model)

    print(f"\n>>> Save model into {args.save_dir}")
    model.save_pretrained(args.save_dir, safe_serialization=True)

    print("\n>>> Copy tokenization files...")
    for filename in required_tokenizer_files:
        shutil.copyfile(
            os.path.join(args.peft_model_path, filename),
            os.path.join(args.save_dir, filename),
        )
    for filename in optional_tokenizer_files:
        source_path = os.path.join(args.peft_model_path, filename)
        if os.path.isfile(source_path):
            shutil.copyfile(source_path, os.path.join(args.save_dir, filename))

    # Optional loading test (left disabled; it reloads the merged checkpoint):
    # merged_model = AutoModelForCausalLM.from_pretrained(args.save_dir)
    # print(">>> Merged Model:\n", merged_model)
if __name__ == "__main__":
    # Script entry point: parse the two model paths, derive the output
    # directory from the adapter path, then run the merge.
    parser = argparse.ArgumentParser(description="Merge PEFT models")
    # Alternative defaults kept for reference:
    # parser.add_argument("--base_model_path", default="/home/shuai/pretrained/meta-llama/Llama-2-7b-hf", type=str)
    # parser.add_argument("--base_model_path", default="/home/shuai/pretrained/HuggingFaceH4/zephyr-7b-alpha", type=str)
    # parser.add_argument("--base_model_path", default="/home/shuai/pretrained/mistralai/Mistral-7B-v0.1", type=str)
    # parser.add_argument("--peft_model_path", default="/home/shuai/output/rlhf/sim_conf_sft_zephyr_007", type=str)
    parser.add_argument("--base_model_path", default="/home/shuai/pretrained/meta-llama/Llama-2-13b-hf", type=str)
    parser.add_argument("--peft_model_path", default="/home/shuai/output/rlhf/sim_conf_sft_llama_13b_001", type=str)
    args = parser.parse_args()

    # Strip trailing path separators before appending the suffix; otherwise
    # "run_001/" would become "run_001/_merged" (a directory nested inside the
    # adapter directory) instead of the sibling "run_001_merged".
    path_separators = os.sep + (os.altsep or "")
    adapter_path = args.peft_model_path.rstrip(path_separators) or args.peft_model_path
    args.save_dir = adapter_path + "_merged"

    print(vars(args))
    main(args)