import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
from pathlib import Path
# torch.nn.functional is imported as well, under its conventional alias F
import torch.nn.functional as F

# Cached covariance-matrix files, one per benchmark. Every file lives in the
# same directory and follows the same naming scheme, so only the benchmark
# split embedded in the file name differs from entry to entry.
_CACHE_DIR = "/home/bingxing2/ailab/group/ai4bio/renyuchen/jkl/benchmark_covariance_matrix/cache"
_FILE_TEMPLATE = "llava-v1.5-7b_covariance_matrices_from_{split}_pre_256_seed_233.pt"

# (split as spelled in the file name, short name used to select the dataset)
_SPLITS_AND_NAMES = (
    ("MME", "MME"),
    ("MMBench_DEV_EN", "MMBench"),
    ("OCRBench", "OCRBench"),
    ("SEEDBench2_Plus", "SEEDBench2_Plus"),
    ("ScienceQA_VAL", "ScienceQA"),
    ("MMMU_TEST", "MMMU"),
    ("MathVista_MINI", "MathVista"),
    ("MathVision", "MathVision"),
    ("MMDU", "MMDU"),
    ("MIA-Bench", "MIA-Bench"),
    ("POPE", "POPE"),
    ("HallusionBench", "HallusionBench"),
)

# List of (file path, short name) pairs, in the order given above.
file_paths_and_names = [
    (f"{_CACHE_DIR}/{_FILE_TEMPLATE.format(split=split)}", short_name)
    for split, short_name in _SPLITS_AND_NAMES
]

# Lookup table: short dataset name -> path of its cache file.
name_to_path = dict(
    (short_name, cache_file) for cache_file, short_name in file_paths_and_names
)

def _layer_index_of(module_name):
    """Return the layer index encoded in a dotted module name, or ``None``.

    Two layouts are recognised:

    * a ``layers.<idx>`` pair anywhere in the name, e.g.
      ``model.layers.3.mlp.up_proj``;
    * a leading index, e.g. ``3.mlp.up_proj`` (the form produced by
      ``get_layer_modules_for_layer``).
    """
    parts = str(module_name).split('.')
    for marker, candidate in zip(parts, parts[1:]):
        if marker == 'layers' and candidate.isdecimal():
            return int(candidate)
    if parts[0].isdecimal():
        return int(parts[0])
    return None


def print_model_structure(data_dict, dataset_name):
    """Print an overview of the modules stored for one dataset.

    Modules are grouped by the layer index found in their name and the
    groups are listed in numeric order, followed by summary statistics.
    Names without a recognisable layer index are not listed under any layer
    but still count towards the module total.

    Args:
        data_dict: Mapping from module name to the stored object. The
            object's ``shape`` is printed when it has one; otherwise its
            type name is printed instead.
        dataset_name: Label used in the printed header.
    """
    print(f"\n=== {dataset_name} 模型结构 ===")
    # This figure counts dictionary entries (modules), not layers; the
    # number of distinct layers is reported in the statistics below.
    print(f"总模块数: {len(data_dict)}")

    layer_names = sorted(data_dict.keys())

    # Group module names by layer index. Names are visited in sorted order,
    # so every group ends up sorted as well.
    layer_groups = {}
    for layer_name in layer_names:
        layer_idx = _layer_index_of(layer_name)
        if layer_idx is not None:
            layer_groups.setdefault(layer_idx, []).append(layer_name)

    # Integer keys sort numerically (layer 2 before layer 10).
    for layer_idx in sorted(layer_groups):
        print(f"\n第 {layer_idx} 层:")
        for module_name in layer_groups[layer_idx]:
            entry = data_dict[module_name]
            shape = getattr(entry, 'shape', type(entry).__name__)
            print(f"  {module_name}: {shape}")

    print("\n统计信息:")
    print(f"  总模块数: {len(layer_names)}")
    print(f"  总层数: {len(layer_groups)}")

    # Show only a sample of the names when there are many of them.
    if len(layer_names) > 10:
        print(f"  前5个模块: {layer_names[:5]}")
        print(f"  后5个模块: {layer_names[-5:]}")
    else:
        print(f"  所有模块: {layer_names}")

def load_data_for_datasets(dataset_names, print_structure=False):
    """Load the cached data file of every requested dataset.

    Args:
        dataset_names: Iterable of dataset short names (keys of
            ``name_to_path``). Unknown names are reported and skipped;
            a name given more than once is loaded only once.
        print_structure: When true, dump the structure of each file right
            after loading it. Off by default: ``main`` prints the structure
            itself when ``--print-structure`` is given, so dumping it here
            unconditionally made that flag redundant and duplicated the
            output whenever it was set.

    Returns:
        Dict mapping each recognised dataset name to the object returned by
        ``torch.load`` for its file, in order of first appearance.

    Raises:
        ValueError: If none of the requested names is known.
    """
    selected_files = []
    # dict.fromkeys de-duplicates while keeping the caller's order.
    for name in dict.fromkeys(dataset_names):
        if name in name_to_path:
            selected_files.append((name_to_path[name], name))
        else:
            print(f"警告: 未找到数据集 '{name}' 对应的文件路径")

    if not selected_files:
        raise ValueError("没有找到任何有效的数据集文件")

    three_dict = {}
    for file_path, name in selected_files:
        print(f"正在加载 {name} 的数据文件...")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load cache files from a trusted source;
        # consider passing weights_only=True if the installed torch version
        # supports it and the files contain plain tensors -- TODO confirm.
        data = torch.load(file_path, map_location="cpu")
        three_dict[name] = data

        if print_structure:
            print_model_structure(data, name)

    print(f"\n成功加载了 {len(three_dict)} 个数据文件")
    return three_dict

def get_layer_modules_for_layer(layer_num):
    """Build the module names that belong to one layer.

    Args:
        layer_num: Layer identifier; it is formatted into every name as-is.

    Returns:
        A list of seven names of the form ``"<layer_num>.<block>.<proj>"``:
        the four ``self_attn`` projections (q, k, v, o) followed by the
        three ``mlp`` projections (gate, up, down).
    """
    attention = [
        f'{layer_num}.self_attn.{proj}_proj' for proj in ('q', 'k', 'v', 'o')
    ]
    feed_forward = [
        f'{layer_num}.mlp.{proj}_proj' for proj in ('gate', 'up', 'down')
    ]
    return attention + feed_forward

def visualize(layer_name, three_dict, target_size, output_path, unified_color_scale=False):
    """Render one module's covariance matrices as side-by-side heat maps.

    For every dataset that contains ``layer_name`` the matrix is
    average-pooled down to ``target_size`` x ``target_size``, transformed
    with ``exp(2 * x)`` and drawn in its own subplot with a colour bar. The
    figure is written to ``<output_path>/<layer_name with '.' -> '_'>.png``.

    Args:
        layer_name: Key of the module to plot in each dataset's dictionary.
        three_dict: Mapping from dataset name to a mapping from module name
            to tensor. Each tensor is reshaped to ``(1, 1, n, n)`` with
            ``n = tensor.size(0)``, so it is expected to be an n-by-n matrix.
        target_size: Side length (int) of the pooled matrix.
        output_path: Existing directory the image is written to.
        unified_color_scale: When true and more than one dataset is
            plotted, all subplots share one colour range.

    Returns:
        None. When no dataset contains ``layer_name`` a message is printed
        and nothing is drawn or saved.
    """
    print(f"--- 处理层: {layer_name}")

    # Down-sampling layer shared by all datasets.
    pool = nn.AdaptiveAvgPool2d(target_size)

    heatmaps = []
    dataset_names = []
    for dataset_name, all_cov_matrix in three_dict.items():
        if layer_name not in all_cov_matrix:
            print(f"警告: 在 {dataset_name} 中未找到层 {layer_name}")
            continue

        # Detach and move to CPU so the result can be converted to NumPy.
        cov = all_cov_matrix[layer_name].detach().cpu()
        # NumPy cannot represent bfloat16, and reduced-precision pooling on
        # CPU may be unsupported depending on the torch version -- promote.
        if cov.dtype not in (torch.float32, torch.float64):
            cov = cov.float()

        side = cov.size(0)
        # reshape (unlike view) also accepts non-contiguous tensors.
        pooled = pool(cov.reshape(1, 1, side, side)).reshape(target_size, target_size)
        # Apply the exponential transform once; it is reused for both the
        # shared colour range and the plot itself.
        heatmaps.append(torch.exp(pooled * 2))
        dataset_names.append(dataset_name)

        print(f"{dataset_name} 处理完成")

    if not heatmaps:
        print(f"错误: 没有找到层 {layer_name} 的数据")
        return

    # Global colour limits; None lets every subplot scale independently.
    if unified_color_scale and len(heatmaps) > 1:
        all_values = torch.cat([heatmap.flatten() for heatmap in heatmaps])
        vmin, vmax = all_values.min().item(), all_values.max().item()
        print(f"统一颜色范围: [{vmin:.4f}, {vmax:.4f}]")
    else:
        vmin, vmax = None, None

    num_datasets = len(heatmaps)
    # squeeze=False always yields a 2-D axes array, even for one subplot.
    fig, axes = plt.subplots(
        1, num_datasets, figsize=(6 * num_datasets, 6), squeeze=False
    )
    try:
        fig.suptitle(layer_name, fontsize=16)

        for ax, heatmap, dataset_name in zip(axes[0], heatmaps, dataset_names):
            im = ax.imshow(
                heatmap.numpy(), cmap='YlOrRd', aspect='auto', vmin=vmin, vmax=vmax
            )
            ax.set_title(dataset_name, fontsize=14)
            ax.set_xticks([])
            ax.set_yticks([])
            fig.colorbar(im, ax=ax, shrink=0.8)

        fig.tight_layout()

        output_file = os.path.join(output_path, f"{layer_name.replace('.', '_')}.png")
        fig.savefig(output_file, format='png', bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure: the caller catches exceptions and moves
        # on to the next module, so a failed plot would otherwise stay open.
        plt.close(fig)

    print(f"图像已保存到: {output_file}")

def _build_cli():
    """Assemble the command-line interface of the heat-map script."""
    cli = argparse.ArgumentParser(description='可视化协方差矩阵热力图')
    cli.add_argument('--datasets', nargs='+', required=True,
                     help='要处理的数据集名称列表')
    cli.add_argument('--output-path', required=True,
                     help='输出图像的路径')
    cli.add_argument('--target-size', type=int, default=32,
                     help='目标下采样尺寸 (默认: 32)')
    cli.add_argument('--layers', nargs='+', type=int, required=True,
                     help='要处理的层数列表，例如: 0 1 2')
    cli.add_argument('--print-structure', action='store_true',
                     help='是否打印模型结构信息')
    cli.add_argument('--unified-color-scale', action='store_true',
                     help='是否使用统一的颜色深度范围 (所有子图使用相同的颜色度量)')
    return cli


def main():
    """Script entry point: parse arguments, load data, plot every module.

    One sub-directory per requested layer is created under the output
    directory, and one image per module of that layer is written into it.
    A failure while plotting a module is reported and the run continues
    with the next module.
    """
    args = _build_cli().parse_args()

    # Make sure the root output directory exists before anything is loaded.
    out_root = Path(args.output_path)
    out_root.mkdir(parents=True, exist_ok=True)

    three_dict = load_data_for_datasets(args.datasets)

    # Optional structure dump for every loaded dataset.
    if args.print_structure:
        rule = "="*50
        print("\n" + rule)
        print("详细模型结构信息:")
        print(rule)
        for dataset_name, data in three_dict.items():
            print_model_structure(data, dataset_name)
        print(rule)

    print(f"\n开始处理层: {args.layers}")
    print(f"目标尺寸: {args.target_size}")
    print(f"输出路径: {out_root}")
    print(f"统一颜色深度: {args.unified_color_scale}")

    for layer_num in args.layers:
        print(f"\n=== 处理第 {layer_num} 层 ===")

        # Each layer gets its own sub-directory.
        layer_dir = out_root / str(layer_num)
        layer_dir.mkdir(exist_ok=True)

        layer_modules = get_layer_modules_for_layer(layer_num)
        module_count = len(layer_modules)

        print(f"第 {layer_num} 层包含 {module_count} 个模块")

        for position, layer_name in enumerate(layer_modules, start=1):
            print(f"进度: {position}/{module_count} - {layer_name}")
            try:
                visualize(layer_name, three_dict, args.target_size, layer_dir, args.unified_color_scale)
            except Exception as e:
                # Best effort: report the failure and go on to the next module.
                print(f"处理层 {layer_name} 时出错: {e}")

    print("\n所有处理完成!")


if __name__ == "__main__":
    main()

