|
|
|
|
|
import os |
|
|
import torch |
|
|
import shutil |
|
|
import argparse |
|
|
from transformers import AutoModelForCausalLM |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
def main(args):
    """Merge a PEFT (LoRA) adapter into its base model and save the result.

    Loads the base causal LM in bfloat16, attaches the adapter weights from
    ``args.peft_model_path``, folds them into the base weights, and writes
    the standalone merged model (plus the adapter's tokenizer files) to
    ``args.save_dir``.

    Args:
        args: Namespace with ``base_model_path``, ``peft_model_path`` and
            ``save_dir`` attributes (``save_dir`` is derived by the caller).
    """
    base_model = AutoModelForCausalLM.from_pretrained(args.base_model_path,
                                                      torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
    print("\n>>> Base Model + PEFT before merging:\n", model)

    # Fold the adapter weights into the base weights so the saved model is a
    # plain transformers checkpoint with no runtime PEFT dependency.
    model = model.merge_and_unload()
    print("\n>>> Base Model + PEFT after merging:\n", model)

    print(f"\n>>> Save model into {args.save_dir}")
    model.save_pretrained(args.save_dir, safe_serialization=True)

    print("\n>>> Copy tokenization files...")
    tokenizer_files = ("special_tokens_map.json",
                      "tokenizer.model",
                      "tokenizer_config.json")
    for fname in tokenizer_files:
        src = os.path.join(args.peft_model_path, fname)
        if os.path.isfile(src):
            shutil.copyfile(src, os.path.join(args.save_dir, fname))
        else:
            # Fast tokenizers ship tokenizer.json instead of tokenizer.model;
            # skip missing files instead of crashing after the expensive merge.
            print(f">>> Skipping missing tokenizer file: {src}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse paths, derive the output directory next to the
    # adapter checkpoint, then run the merge.
    arg_parser = argparse.ArgumentParser(description="Merge PEFT models")
    arg_parser.add_argument("--base_model_path", type=str,
                            default="/home/shuai/pretrained/meta-llama/Llama-2-13b-hf")
    arg_parser.add_argument("--peft_model_path", type=str,
                            default="/home/shuai/output/rlhf/sim_conf_sft_llama_13b_001")
    cli_args = arg_parser.parse_args()

    # Merged weights land beside the adapter directory, suffixed "_merged".
    cli_args.save_dir = cli_args.peft_model_path + "_merged"

    print(vars(cli_args))
    main(cli_args)