import numpy as np
import argparse
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from modeling_oursvd_llama import CovSVDLinear
from transformers.models.llama.modeling_llama import LlamaSdpaAttention
from transformers.models.llama.modeling_llama import LlamaMLP
import os
import json

from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig


@torch.no_grad()
def _dense_from_covsvd(module: nn.Module) -> nn.Linear:
    """Build one dense ``nn.Linear`` equivalent to a ``CovSVDLinear``'s factors.

    The merged weight is ``ALinear.weight @ BLinear.weight + weight_residual``,
    the same formula this script has always used.
    NOTE(review): this presumes CovSVDLinear's forward is
    ``ALinear(BLinear(x)) + x @ weight_residual.T``; confirm against
    ``modeling_oursvd_llama``.

    Any biases on the factors are folded in as
    ``ALinear.weight @ BLinear.bias + ALinear.bias``, instead of being silently
    dropped.
    """
    a_lin, b_lin = module.ALinear, module.BLinear
    residual = getattr(module, "weight_residual", None)

    # Keep the dtype the original `A @ B + R` expression would have produced...
    out_dtype = torch.promote_types(a_lin.weight.dtype, b_lin.weight.dtype)
    if residual is not None:
        out_dtype = torch.promote_types(out_dtype, residual.dtype)
    # ...but accumulate in at least fp32: half-precision CPU matmuls are slow
    # (or unsupported on older torch) and lose precision.
    work_dtype = torch.promote_types(out_dtype, torch.float32)

    a_weight = a_lin.weight.to(work_dtype)
    weight = a_weight @ b_lin.weight.to(work_dtype)
    if residual is not None:
        weight = weight + residual.to(work_dtype)

    bias = None
    if b_lin.bias is not None:
        bias = a_weight @ b_lin.bias.to(work_dtype)
    if a_lin.bias is not None:
        # copy=True so the new layer never aliases the discarded module's tensor.
        a_bias = a_lin.bias.to(work_dtype, copy=True)
        bias = a_bias if bias is None else bias + a_bias

    # Build on the meta device so no full-size random init is allocated just to
    # be overwritten; both parameters are replaced right below.
    dense = nn.Linear(
        b_lin.in_features,
        a_lin.out_features,
        bias=bias is not None,
        device="meta",
    )
    dense.weight = nn.Parameter(weight.to(out_dtype))
    if bias is not None:
        dense.bias = nn.Parameter(bias.to(out_dtype))
    return dense


def merge_covsvd_linears(model: nn.Module) -> int:
    """Replace every ``CovSVDLinear`` submodule of *model* with a dense ``nn.Linear``.

    Modules are matched by class name, so no import of the CovSVD implementation
    is needed. The root module itself is never replaced.

    Args:
        model: Module tree to rewrite in place.

    Returns:
        The number of layers that were replaced.
    """
    # Materialise the list first: the module tree is mutated while replacing.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name and type(module).__name__ == "CovSVDLinear"
    ]
    for full_name, module in targets:
        parent_name, _, attr = full_name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, _dense_from_covsvd(module))
    return len(targets)


def main(args: argparse.Namespace) -> None:
    """Load a CovSVD LLaVA checkpoint, merge its factorized layers, optionally save.

    Args:
        args: Parsed CLI arguments. Reads ``model_id``, plus ``save_model`` and
            ``save_path`` to decide whether and where to write the merged model.
    """
    # Load the tokenizer (saved alongside the merged model).
    model_id = args.model_id
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    from modeling_oursvd_llama import CovSVDLlavaLlamaForCausalLM
    model = CovSVDLlavaLlamaForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True
    )

    # ----------- (disabled) load only the ViT as vision_tower -----------
    # config_path = os.path.join(model_id, "config.json")
    # if os.path.exists(config_path):
    #     with open(config_path, "r") as f:
    #         config = json.load(f)
    #     vision_path = config.get("mm_vision_tower", None)
    #     if vision_path is not None:
    #         try:
    #             from transformers import CLIPVisionModel
    #             vision_model = CLIPVisionModel.from_pretrained(vision_path)
    #             if hasattr(model, "vision_tower"):
    #                 model.vision_tower = vision_model
    #             elif hasattr(model.model, "vision_tower"):
    #                 model.model.vision_tower = vision_model
    #             print(f"Automatically loaded ViT vision backbone: {vision_path}")
    #         except Exception as e:
    #             print(f"Failed to load ViT: {e}")
    # ----------- end of disabled block -----------

    print("\n---- model before merge ---\n")
    # print(model)

    # Fold every CovSVDLinear back into a dense nn.Linear.
    print("\nbegin merge. \n")
    num_merged = merge_covsvd_linears(model)
    print(f"merged {num_merged} CovSVDLinear layer(s)")

    print("\n---- model after merge ---\n")
    # print(model)

    if getattr(args, "save_model", False):
        # NOTE(review): model.config is the one loaded from the CovSVD checkpoint;
        # confirm the saved directory reloads as a plain (dense) LLaVA model.
        model.save_pretrained(args.save_path)
        tokenizer.save_pretrained(args.save_path)
        print(f"saved merged model and tokenizer to {args.save_path}")

    

def str2bool(value):
    """Parse a CLI boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the merge script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default='/media/raid/workspace/jiangkailin/LLaVA/checkpoints/llava-v1.5-7b-task',
        help="Pretrained model ID",
    )
    # Without an explicit type, "--save_model False" would yield the truthy
    # string "False"; str2bool parses it properly. A bare "--save_model" means True.
    parser.add_argument(
        "--save_model",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to save the merged model and tokenizer (default: true)",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default='mllm_save_LoRA_Null_adapter_llama2_PT_128_math_Null_v1_merged',
        help="Output directory for the merged model",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(args)