import argparse
import torch
import torch.nn as nn
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from modeling_oursvd_llama import CovSVDLlavaLlamaForCausalLM
import json
import os


def _merge_covsvd_linear(module):
    """Collapse one ``CovSVDLinear`` module into a single ``nn.Linear``.

    The merged weight is ``ALinear.weight @ BLinear.weight + weight_residual``.

    Args:
        module: An object exposing ``ALinear`` and ``BLinear`` (each with
            ``weight``, and ``in_features`` / ``out_features`` respectively)
            and a ``weight_residual`` tensor.

    Returns:
        A new ``nn.Linear`` with the merged weight, created on the same
        device and with the same dtype as the merged weight.
    """
    a_linear = module.ALinear
    b_linear = module.BLinear

    # The merge is pure weight arithmetic; keep autograd out of it so the
    # result carries no graph even if weight_residual requires grad.
    with torch.no_grad():
        merged_weight = a_linear.weight @ b_linear.weight + module.weight_residual

        # NOTE(review): bias folding assumes the adapter computes
        # ALinear(BLinear(x)) + x @ weight_residual.T, which is the same
        # assumption the weight formula above already makes -- confirm
        # against CovSVDLinear.forward. With bias-free adapters this is a
        # no-op and the result is a bias-free nn.Linear, as before.
        a_bias = getattr(a_linear, "bias", None)
        b_bias = getattr(b_linear, "bias", None)
        merged_bias = None
        if b_bias is not None:
            merged_bias = a_linear.weight @ b_bias
        if a_bias is not None:
            merged_bias = a_bias.clone() if merged_bias is None else merged_bias + a_bias

    merged = nn.Linear(
        b_linear.in_features,
        a_linear.out_features,
        bias=merged_bias is not None,
        device=merged_weight.device,
        dtype=merged_weight.dtype,
    )
    merged.weight.data = merged_weight
    if merged_bias is not None:
        merged.bias.data = merged_bias
    return merged


def _save_merged_model(model, tokenizer, save_path):
    """Write tokenizer, weights and a cleaned-up ``config.json`` to *save_path*."""
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)

    # save_pretrained() has already written a config.json; it is overwritten
    # below with a copy that no longer carries the adapter-specific entries.
    config = model.config.to_dict()
    # NOTE(review): the checkpoint is loaded through LLaVA's loader but the
    # saved config advertises LlamaForCausalLM -- confirm this is intended.
    config["architectures"] = ["LlamaForCausalLM"]
    for adapter_key in ("lora_r", "auto_map", "_name_or_path"):
        config.pop(adapter_key, None)

    with open(os.path.join(save_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


def merge_adapter(args):
    """Merge every ``CovSVDLinear`` in a checkpoint into a plain ``nn.Linear``.

    The model is loaded on CPU with ``load_pretrained_model``, each module
    whose class is named ``CovSVDLinear`` is replaced in its parent by the
    merged layer, and the model is printed before and after the merge.
    If ``args.save_model`` is set, the tokenizer, the model and a rewritten
    ``config.json`` are saved to ``args.save_path``.

    Args:
        args: Namespace with ``model_path``, ``model_base``, ``save_model``
            and ``save_path`` attributes.

    Raises:
        ValueError: If ``args.save_model`` is set but ``args.save_path`` is
            ``None``. Checked before the model is loaded so a bad invocation
            fails immediately instead of after the merge.
    """
    if args.save_model and args.save_path is None:
        raise ValueError("save_path must be provided when save_model is set")

    # Load the model through LLaVA's loader.
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, _image_processor, _context_len = load_pretrained_model(
        args.model_path,
        args.model_base,
        model_name,
        device_map='cpu'
    )

    print("\n---- model before merge ---\n")
    print(model)

    # Snapshot name -> module once. The root module is present under the
    # empty name, so a top-level child resolves its parent like any other.
    # Iterating the snapshot (not the live model) makes it safe to swap
    # children during the loop.
    modules_by_name = dict(model.named_modules())

    # Merge the adapter weights.
    print("\nbegin merge. \n")
    for full_name, module in modules_by_name.items():
        # Matched by class name because CovSVDLinear itself is not imported
        # in this file.
        if type(module).__name__ != "CovSVDLinear":
            continue
        parent_name, _sep, child_name = full_name.rpartition(".")
        # Assigning a module to an existing child name replaces the child.
        setattr(modules_by_name[parent_name], child_name, _merge_covsvd_linear(module))

    print("\n---- model after merge ---\n")
    print(model)

    # Save the merged model.
    if args.save_model:
        _save_merged_model(model, tokenizer, args.save_path)
        print(f"Done merging adapter into the original model architecture in {args.save_path}")





def build_parser():
    """Build the command-line parser for the adapter-merge script.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--model_path``,
        ``--model_base``, ``--save_model`` / ``--no_save_model`` and
        ``--save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        # NOTE(review): machine-specific default kept so that running the
        # script without arguments behaves as before; consider making this
        # argument required.
        default='/media/raid/workspace/jiangkailin/LLaVA/checkpoints/llava-v1.5-7b-task',
        help="Path to the model with adapter weights",
    )
    parser.add_argument(
        "--model_base",
        type=str,
        default=None,
        help="Path to the base model (if needed)",
    )
    parser.add_argument(
        "--save_model",
        action="store_true",
        default=True,
        help="Whether to save the merged model (default: enabled)",
    )
    # With store_true and default=True alone, saving could never be turned
    # off; this companion flag writes False into the same destination.
    parser.add_argument(
        "--no_save_model",
        dest="save_model",
        action="store_false",
        help="Run the merge without saving the merged model",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default='mllm_save_LoRA_Null_adapter_llama2_PT_128_math_Null_v1_merged',
        help="Path to save the merged model",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    merge_adapter(args)