File size: 2,243 Bytes
df2d4df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# coding=utf-8
import os
import torch
import shutil
import argparse
from transformers import AutoModelForCausalLM
from peft import PeftModel


def main(args):
    """Merge a PEFT adapter into its base model and save the merged model.

    Loads the base model from ``args.base_model_path`` in bfloat16, wraps it
    with the adapter stored at ``args.peft_model_path``, and folds the
    adapter weights into the base weights via ``merge_and_unload()``. It then
    writes the result to ``args.save_dir`` using safetensors serialization.
    Tokenizer files found in the adapter directory are copied next to the
    merged weights.

    Args:
        args: Namespace providing ``base_model_path``, ``peft_model_path``
            and ``save_dir`` (all path strings). ``peft_model_path`` must be
            a local directory, because tokenizer files are copied out of it
            with ``os.path.join``.

    Raises:
        FileNotFoundError: if none of the known tokenizer files exist in
            ``args.peft_model_path``.
    """
    base_model = AutoModelForCausalLM.from_pretrained(args.base_model_path,
                                                      torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
    print("\n>>> Base Model + PEFT before merging:\n", model)
    model = model.merge_and_unload()
    print("\n>>> Base Model + PEFT after merging:\n", model)
    print("\n>>> Save model into {}".format(args.save_dir))
    model.save_pretrained(args.save_dir, safe_serialization=True)

    print("\n>>> Copy tokenization files...")
    # Not every tokenizer saves every artifact (e.g. tokenizer.json exists
    # only for fast tokenizers, added_tokens.json only if tokens were added).
    # Copy what is present instead of crashing on the first missing file,
    # which previously aborted the run after the weights were already saved.
    tokenizer_files = (
        "special_tokens_map.json",
        "tokenizer.model",
        "tokenizer_config.json",
        "tokenizer.json",
        "added_tokens.json",
    )
    copied = []
    for filename in tokenizer_files:
        src = os.path.join(args.peft_model_path, filename)
        if not os.path.isfile(src):
            print(">>> Skip missing tokenizer file: {}".format(src))
            continue
        shutil.copyfile(src, os.path.join(args.save_dir, filename))
        copied.append(filename)
    if not copied:
        # Without any tokenizer file the output directory cannot be used on
        # its own, so keep failing loudly in that case.
        raise FileNotFoundError(
            "No tokenizer files found in {}".format(args.peft_model_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Merge PEFT models")
    parser.add_argument("--base_model_path", default="/home/shuai/pretrained/meta-llama/Llama-2-13b-hf", type=str,
                        help="Path to the base model the adapter was trained on.")
    parser.add_argument("--peft_model_path", default="/home/shuai/output/rlhf/sim_conf_sft_llama_13b_001", type=str,
                        help="Local directory holding the PEFT adapter and its tokenizer files.")
    parser.add_argument("--save_dir", default=None, type=str,
                        help="Output directory; defaults to '<peft_model_path>_merged'.")
    args = parser.parse_args()
    if args.save_dir is None:
        # normpath drops a trailing separator, so "run/" yields the sibling
        # "run_merged" instead of "run/_merged" inside the adapter directory.
        args.save_dir = os.path.normpath(args.peft_model_path) + "_merged"
    print(vars(args))
    main(args)