import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from scipy.ndimage import zoom
import argparse
import sys
from matplotlib.colors import PowerNorm

# Colormap reserved for CKA heatmaps: a white-to-deep-blue gradient, so pairs
# with low similarity render pale and pairs with high similarity render dark.
CMAP_CKA = sns.blend_palette(
    ["#ffffff", "#e3f2fd", "#2196f3", "#1976d2", "#0d47a1"],  # white, light blue, blue, dark blue, deeper blue
    as_cmap=True,
)

def center_gram_matrix(gram_matrix):
    """
    Double-center a square Gram matrix, i.e. compute ``H @ K @ H`` with
    ``H = I - ones((n, n)) / n``.

    The product is evaluated as ``K - column_means - row_means + grand_mean``.
    This is algebraically identical to ``H @ K @ H``, but it runs in O(n^2)
    time and memory instead of two O(n^3) matrix multiplications plus two
    dense n x n temporaries.

    Args:
        gram_matrix: square (n, n) array.

    Returns:
        A new centered (n, n) array; the input is not modified. Real inputs are
        promoted to at least float64, matching the original implementation
        (which multiplied by float64 centering matrices).
    """
    gram = np.asarray(gram_matrix)
    # Promote (e.g. float32/int -> float64) like the original H @ K @ H product did.
    gram = gram.astype(np.result_type(gram.dtype, np.float64), copy=False)
    column_means = gram.mean(axis=0, keepdims=True)
    row_means = gram.mean(axis=1, keepdims=True)
    return gram - column_means - row_means + gram.mean()

def compute_cka_similarity(matrix1, matrix2):
    """
    Compute linear Centered Kernel Alignment (CKA) between two matrices.

    ``CKA = <K~, L~>_F / (||K~||_F * ||L~||_F)``, where ``K = X X^T`` and
    ``L = Y Y^T`` are the Gram matrices and ``~`` denotes double centering.

    Args:
        matrix1: first matrix, (n_samples, n_features1).
        matrix2: second matrix, (n_samples, n_features2). The feature counts may
            differ. If the row counts differ, the matrix with more rows is
            resized with ``scipy.ndimage.zoom`` (linear interpolation, same
            factor on every axis) down to the smaller row count.

    Returns:
        cka_score: float in [0, 1] (1 means identical up to the invariances of
        linear CKA). Returns 0.0 when either centered Gram matrix is zero or
        when the computation yields non-finite values.
    """
    # Work in float64: the squared Frobenius norms below scale with the 4th
    # power of the input entries and can overflow float32.
    matrix1 = np.asarray(matrix1, dtype=np.float64)
    matrix2 = np.asarray(matrix2, dtype=np.float64)

    # CKA needs the same number of rows; resample the larger one if needed.
    min_rows = min(matrix1.shape[0], matrix2.shape[0])
    if matrix1.shape[0] != min_rows:
        matrix1 = zoom(matrix1, min_rows / matrix1.shape[0], order=1)
    if matrix2.shape[0] != min_rows:
        matrix2 = zoom(matrix2, min_rows / matrix2.shape[0], order=1)

    centered_gram1 = center_gram_matrix(matrix1 @ matrix1.T)
    centered_gram2 = center_gram_matrix(matrix2 @ matrix2.T)

    # Centered Gram matrices are symmetric, so trace(A @ B) == sum(A * B):
    # the Frobenius inner product costs O(n^2) instead of an O(n^3) matmul.
    numerator = np.vdot(centered_gram1, centered_gram2)
    # Take each norm separately instead of sqrt(trace1 * trace2) to avoid
    # overflowing the product of two large traces.
    denominator = np.linalg.norm(centered_gram1) * np.linalg.norm(centered_gram2)

    # Degenerate (zero-variance) or overflowed inputs: similarity is undefined.
    # Report 0.0 rather than letting NaN be clamped to 1.0 below.
    if denominator == 0 or not np.isfinite(denominator) or not np.isfinite(numerator):
        return 0.0

    cka_score = float(numerator / denominator)
    # Clamp tiny floating-point excursions outside [0, 1].
    return max(0.0, min(1.0, cka_score))

def normalize_similarity_matrix(similarity_matrix):
    """
    Min-max rescale the off-diagonal entries of a similarity matrix to [0, 1]
    so that differences between pairs become more visible.

    The min and max are taken over the strict upper triangle. The rescaled
    upper-triangle values are mirrored into the lower triangle, so the result is
    symmetric. The diagonal is left unchanged.

    Args:
        similarity_matrix: square (n, n) array. It is not modified.

    Returns:
        normalized_matrix: the rescaled copy. Integer input is promoted to
            float64 so fractional values are not truncated.
        original_min: minimum of the original upper triangle.
        original_max: maximum of the original upper triangle.

        If n < 2 there are no off-diagonal entries and ``(copy, 0, 1)`` is
        returned. If all off-diagonal values are equal, the copy is returned
        unchanged together with that value as both min and max.
    """
    rows, cols = np.triu_indices_from(similarity_matrix, k=1)
    upper_triangle = similarity_matrix[rows, cols]

    if upper_triangle.size == 0:
        return similarity_matrix.copy(), 0, 1

    original_min = np.min(upper_triangle)
    original_max = np.max(upper_triangle)

    # Constant off-diagonal values cannot be rescaled (division by zero).
    if original_max == original_min:
        return similarity_matrix.copy(), original_min, original_max

    # A plain copy of an integer matrix would truncate the fractions written below.
    if np.issubdtype(similarity_matrix.dtype, np.floating):
        normalized_matrix = similarity_matrix.copy()
    else:
        normalized_matrix = similarity_matrix.astype(np.float64)

    scaled = (upper_triangle - original_min) / (original_max - original_min)
    normalized_matrix[rows, cols] = scaled
    normalized_matrix[cols, rows] = scaled  # mirror to keep the matrix symmetric

    return normalized_matrix, original_min, original_max

def load_covariance_data(file_path):
    """
    Load a checkpoint with ``torch.load``, mapping all tensors to the CPU.

    Args:
        file_path: path of the covariance-matrix file.

    Returns:
        The deserialized object, or None when the file does not exist or
        ``torch.load`` raises. The failure is printed, not re-raised, so the
        caller can simply skip this dataset.
    """
    file_exists = os.path.exists(file_path)
    if not file_exists:
        print(f"❌ 错误：文件不存在: {file_path}")
        return None

    try:
        print(f"🔄 正在加载: {file_path}")
        loaded = torch.load(file_path, map_location='cpu')
        print("✅ 文件加载成功！")
    except Exception as exc:
        # Deliberate best-effort: report and signal failure with None.
        print(f"❌ 加载文件失败: {str(exc)}")
        loaded = None
    return loaded

def get_layer_matrices(covariance_data, target_layer):
    """
    获取指定层的所有矩阵
    """
    if not isinstance(covariance_data, dict):
        print("❌ 数据不是字典格式")
        return {}
    
    target_keys = [key for key in covariance_data.keys() if key.startswith(f"{target_layer}.")]
    layer_matrices = {}
    
    for key in target_keys:
        if isinstance(covariance_data[key], torch.Tensor) and len(covariance_data[key].shape) == 2:
            layer_matrices[key] = covariance_data[key].detach().cpu().numpy()
    
    return layer_matrices

def downsample_matrix(matrix, target_size=32):
    """
    Resize ``matrix`` with linear interpolation so its first axis has
    ``target_size`` entries.

    ``scipy.ndimage.zoom`` applies the same factor (target_size / rows) to every
    axis. A square matrix therefore becomes target_size x target_size, and a
    non-square one keeps its aspect ratio. If the row count already equals
    ``target_size``, an unchanged copy is returned.
    """
    n_rows = matrix.shape[0]
    if n_rows == target_size:
        return matrix.copy()
    return zoom(matrix, target_size / n_rows, order=1)

def _load_datasets_for_layer(file_paths_and_names, target_layer, target_size):
    """Load each dataset file and keep its target-layer matrices, already downsampled.

    Returns (all_datasets_data, dataset_names): dataset name -> {matrix name ->
    downsampled array}, and the successfully loaded names in input order.
    """
    all_datasets_data = {}
    dataset_names = []

    for file_path, dataset_name in file_paths_and_names:
        print(f"\n🔄 加载数据集: {dataset_name}")
        covariance_data = load_covariance_data(file_path)
        if covariance_data is None:
            print("  ❌ 加载失败")
            continue

        layer_matrices = get_layer_matrices(covariance_data, target_layer)
        # Release the rest of the checkpoint before the next file is loaded.
        del covariance_data
        if not layer_matrices:
            print(f"  ❌ 未找到层 {target_layer} 的数据")
            continue

        # Downsample right away so only the small target_size copies stay
        # resident; full-resolution matrices for every dataset may be large.
        all_datasets_data[dataset_name] = {
            matrix_name: downsample_matrix(matrix, target_size)
            for matrix_name, matrix in layer_matrices.items()
        }
        dataset_names.append(dataset_name)
        print(f"  ✅ 成功加载 {len(layer_matrices)} 个矩阵")

    return all_datasets_data, dataset_names


def _pairwise_cka_matrix(matrices_data, dataset_names):
    """Return the symmetric (n, n) CKA matrix for ``dataset_names`` with a unit diagonal.

    CKA is symmetric in its arguments, so each unordered pair is computed once
    and mirrored.
    """
    n_datasets = len(dataset_names)
    similarity_matrix = np.eye(n_datasets)
    for i in range(n_datasets):
        for j in range(i + 1, n_datasets):
            score = compute_cka_similarity(
                matrices_data[dataset_names[i]],
                matrices_data[dataset_names[j]],
            )
            similarity_matrix[i, j] = similarity_matrix[j, i] = score
    return similarity_matrix


def _print_similarity_table(similarity_matrix, dataset_names, title):
    """Print the matrix as a right-aligned table with dataset names on both axes."""
    print(f"\n  📈 {title}:")
    # Indent the header by the 12-character row-label column so that each
    # column name sits above its values.
    print(" " * 12 + "".join(f" {name:>12}" for name in dataset_names))
    for name, row in zip(dataset_names, similarity_matrix):
        print(f"{name:>12}" + "".join(f" {value:>12.4f}" for value in row))


def _upper_triangle_stats(similarity_matrix):
    """Return (mean, std, min, max) over the strict upper triangle (the distinct pairs)."""
    upper_triangle = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
    return (
        np.mean(upper_triangle),
        np.std(upper_triangle),
        np.min(upper_triangle),
        np.max(upper_triangle),
    )


def _save_heatmap(similarity_matrix, dataset_names, title, cbar_label, image_path):
    """Render the matrix as an annotated heatmap (fixed 0-1 colour range) and save it as PNG."""
    # Scope the CJK-capable font settings to this plot instead of permanently
    # mutating the global rcParams.
    font_rc = {'font.sans-serif': ['SimHei', 'DejaVu Sans'], 'axes.unicode_minus': False}
    with plt.rc_context(font_rc):
        fig, ax = plt.subplots(figsize=(14, 12))
        try:
            sns.heatmap(
                similarity_matrix,
                annot=True,
                fmt='.4f',
                cmap=CMAP_CKA,
                square=True,
                xticklabels=dataset_names,
                yticklabels=dataset_names,
                cbar_kws={'label': cbar_label},
                vmin=0,
                vmax=1,
                ax=ax,
            )
            ax.set_title(title, fontsize=16, fontweight='bold')
            ax.set_xlabel('数据集', fontsize=14)
            ax.set_ylabel('数据集', fontsize=14)
            fig.tight_layout()
            fig.savefig(image_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure so a plotting error does not leak it.
            plt.close(fig)


def _rank_datasets(similarity_matrix, dataset_names):
    """Rank datasets by mean similarity to every *other* dataset, highest first."""
    n_datasets = len(dataset_names)
    averages = []
    for i, name in enumerate(dataset_names):
        others = [similarity_matrix[i, j] for j in range(n_datasets) if j != i]
        averages.append((name, np.mean(others)))
    return sorted(averages, key=lambda item: item[1], reverse=True)


def _analyze_single_matrix(matrix_name, all_datasets_data, dataset_names,
                           target_layer, output_dir, normalize):
    """Run the CKA comparison for one matrix name and write its heatmap, .npz and ranking."""
    print(f"\n🔬 分析矩阵: {matrix_name}")

    available_datasets = [name for name in dataset_names if matrix_name in all_datasets_data[name]]
    if len(available_datasets) < 2:
        print(f"  ⚠️  只有 {len(available_datasets)} 个数据集包含此矩阵，跳过")
        return

    print(f"  📊 分析 {len(available_datasets)} 个数据集")
    matrices_data = {name: all_datasets_data[name][matrix_name] for name in available_datasets}
    similarity_matrix = _pairwise_cka_matrix(matrices_data, available_datasets)

    # Keep the raw matrix so the normalized run can still save it.
    original_similarity_matrix = similarity_matrix.copy()
    original_min = None
    original_max = None
    if normalize:
        similarity_matrix, original_min, original_max = normalize_similarity_matrix(similarity_matrix)
        print(f"  🔧 归一化完成 - 原始范围: [{original_min:.4f}, {original_max:.4f}]")
    has_original_range = normalize and original_min is not None and original_max is not None

    matrix_type = "归一化CKA相似度矩阵" if normalize else "CKA相似度矩阵"
    _print_similarity_table(similarity_matrix, available_datasets, matrix_type)

    mean_similarity, std_similarity, min_similarity, max_similarity = _upper_triangle_stats(similarity_matrix)
    print(f"\n  📊 统计信息:")
    print(f"    平均相似度: {mean_similarity:.4f}")
    print(f"    标准差: {std_similarity:.4f}")
    print(f"    最小相似度: {min_similarity:.4f}")
    print(f"    最大相似度: {max_similarity:.4f}")
    if has_original_range:
        print(f"    原始范围: [{original_min:.4f}, {original_max:.4f}]")

    title_suffix = " (归一化)" if normalize else ""
    cbar_label = "归一化CKA相似度" if normalize else "CKA相似度"
    safe_matrix_name = matrix_name.replace('.', '_').replace(':', '_')
    suffix = "_normalized" if normalize else ""

    # Heatmap
    cka_image_path = output_dir / f"{safe_matrix_name}_cka_similarity{suffix}.png"
    plot_title = (f'第{target_layer}层 - {matrix_name} CKA相似度分析{title_suffix}\n'
                  f'平均相似度: {mean_similarity:.4f} ± {std_similarity:.4f}')
    _save_heatmap(similarity_matrix, available_datasets, plot_title, cbar_label, cka_image_path)
    print(f"  ✅ CKA热力图已保存: {cka_image_path}")

    # Raw data
    cka_data_path = output_dir / f"{safe_matrix_name}_cka_data{suffix}.npz"
    save_data = {
        'similarity_matrix': similarity_matrix,
        'dataset_names': available_datasets,
        'matrix_name': matrix_name,
        'layer': target_layer,
        'mean_similarity': mean_similarity,
        'std_similarity': std_similarity,
        'min_similarity': min_similarity,
        'max_similarity': max_similarity,
        'normalized': normalize,
    }
    if normalize:
        save_data.update({
            'original_similarity_matrix': original_similarity_matrix,
            'original_min': original_min,
            'original_max': original_max,
        })
    np.savez_compressed(cka_data_path, **save_data)
    print(f"  ✅ CKA数据已保存: {cka_data_path}")

    # Ranking
    sorted_datasets = _rank_datasets(similarity_matrix, available_datasets)
    print(f"\n  📊 数据集相似度排名:")
    for rank, (dataset_name, avg_similarity) in enumerate(sorted_datasets, 1):
        print(f"    {rank:2d}. {dataset_name}: {avg_similarity:.4f}")

    ranking_data_path = output_dir / f"{safe_matrix_name}_similarity_ranking{suffix}.txt"
    with open(ranking_data_path, 'w', encoding='utf-8') as f:
        f.write(f"CKA相似度排名 - 第{target_layer}层 {matrix_name}{title_suffix}\n")
        f.write("=" * 50 + "\n")
        f.write(f"平均相似度: {mean_similarity:.4f} ± {std_similarity:.4f}\n")
        f.write(f"相似度范围: [{min_similarity:.4f}, {max_similarity:.4f}]\n")
        if has_original_range:
            f.write(f"原始相似度范围: [{original_min:.4f}, {original_max:.4f}]\n")
        f.write("\n数据集排名:\n")
        for rank, (dataset_name, avg_similarity) in enumerate(sorted_datasets, 1):
            f.write(f"{rank:2d}. {dataset_name}: {avg_similarity:.4f}\n")
    print(f"  ✅ 排名数据已保存: {ranking_data_path}")


def analyze_dataset_cka_similarities(file_paths_and_names, target_layer, output_path, target_size=32, normalize=False):
    """
    Compare every pair of datasets with linear CKA, separately for each matrix
    of one layer.

    For every matrix name found under ``target_layer`` in at least two datasets,
    this builds an N x N CKA similarity matrix, where N is the number of
    datasets containing that matrix. It prints the matrix with summary
    statistics and writes the following files into
    ``<output_path>/cka_analysis_layer_<target_layer>`` (plus a ``normalized``
    sub-directory when ``normalize`` is set):

      - ``<matrix>_cka_similarity[_normalized].png``: annotated heatmap
      - ``<matrix>_cka_data[_normalized].npz``: matrix and statistics
      - ``<matrix>_similarity_ranking[_normalized].txt``: datasets ranked by
        their mean similarity to the others

    Args:
        file_paths_and_names: iterable of (checkpoint path, dataset name) pairs.
        target_layer: layer index whose matrices are compared.
        output_path: base output directory.
        target_size: every matrix is downsampled to this many rows before CKA.
        normalize: if True, min-max rescale the off-diagonal similarities
            (see ``normalize_similarity_matrix``) to make differences more visible.

    Returns:
        None. Prints an error and returns early if fewer than two datasets load.
    """
    print("=" * 80)
    print(f"数据集间CKA相似度分析 - 第{target_layer}层")
    if normalize:
        print("🔧 启用相似度归一化模式")
    print("=" * 80)

    all_datasets_data, dataset_names = _load_datasets_for_layer(
        file_paths_and_names, target_layer, target_size
    )

    if len(all_datasets_data) < 2:
        print("❌ 至少需要2个数据集才能进行CKA分析")
        return

    print(f"\n📊 成功加载 {len(all_datasets_data)} 个数据集")
    print(f"数据集: {dataset_names}")

    # Union of matrix names over all datasets. Names present in fewer than two
    # datasets are skipped later, per matrix.
    all_matrix_names = set()
    for matrices in all_datasets_data.values():
        all_matrix_names.update(matrices.keys())

    print(f"\n🔍 找到 {len(all_matrix_names)} 个矩阵结构（所有数据集的并集）")
    print(f"矩阵名称: {sorted(all_matrix_names)}")

    output_dir = Path(output_path) / f"cka_analysis_layer_{target_layer}"
    if normalize:
        output_dir = output_dir / "normalized"
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📁 输出目录: {output_dir}")

    for matrix_name in sorted(all_matrix_names):
        _analyze_single_matrix(
            matrix_name, all_datasets_data, dataset_names,
            target_layer, output_dir, normalize,
        )

    print(f"\n" + "=" * 80)
    print("✅ CKA相似度分析完成！")
    print(f"📁 所有分析结果保存在: {output_dir}")
    print("=" * 80)

def parse_arguments(argv=None):
    """
    Parse the command-line options of the CKA analysis tool.

    Args:
        argv: optional list of argument strings. The default None parses
            ``sys.argv[1:]`` exactly as before; passing a list lets the parser
            be driven from tests or other scripts.

    Returns:
        argparse.Namespace with ``layer`` (int, required), ``output_path`` (str),
        ``target_size`` (int, default 32) and ``normalize`` (bool).
    """
    parser = argparse.ArgumentParser(description='数据集间CKA相似度分析工具')

    parser.add_argument('--layer', '-l', type=int, required=True,
                        help='指定要分析的层数 (0-31)')
    parser.add_argument('--output-path', '-o', type=str,
                        default="/home/bingxing2/ailab/scx6mh7/jkl/LLaVA_8_8_null_space/cka_outputs",
                        help='输出路径')
    parser.add_argument('--target-size', '-t', type=int, default=32,
                        help='下采样目标大小（默认: 32）')
    parser.add_argument('--normalize', '-n', action='store_true',
                        help='是否对相似度进行归一化处理（让差距更明显）')

    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_arguments()

    # Every covariance checkpoint lives in one directory and follows one naming
    # scheme: <cache_dir>/llava-v1.5-7b_covariance_matrices_from_<TAG>_pre_256_seed_233.pt
    cache_dir = "/home/bingxing2/ailab/group/ai4bio/renyuchen/jkl/benchmark_covariance_matrix/cache"
    file_template = "llava-v1.5-7b_covariance_matrices_from_{tag}_pre_256_seed_233.pt"

    # (file tag, display name); comment out an entry to exclude that benchmark.
    benchmarks = [
        ("MME", "MME"),
        ("MMBench_DEV_EN", "MMBench"),
        ("OCRBench", "OCRBench"),
        ("SEEDBench2_Plus", "SEEDBench2_Plus"),
        ("ScienceQA_VAL", "ScienceQA"),
        ("MMMU_TEST", "MMMU"),
        ("MathVista_MINI", "MathVista"),
        ("MathVision", "MathVision"),
        ("MMDU", "MMDU"),
        # ("MIA-Bench", "MIA-Bench"),
        ("POPE", "POPE"),
        ("HallusionBench", "HallusionBench"),
    ]
    file_paths_and_names = [
        (f"{cache_dir}/{file_template.format(tag=tag)}", name)
        for tag, name in benchmarks
    ]

    target_layer = args.layer

    # Valid range as documented by the --layer help text.
    if not 0 <= target_layer <= 31:
        print(f"❌ 错误：层数必须在 0-31 范围内，当前指定层数：{target_layer}")
        sys.exit(1)

    if args.target_size <= 0:
        print(f"❌ 错误：目标大小必须大于0，当前指定大小：{args.target_size}")
        sys.exit(1)

    print(f"🎯 分析模式：第{target_layer}层")
    print(f"📊 数据集数量：{len(file_paths_and_names)}")
    print(f"📁 输出路径: {args.output_path}")
    print(f"📐 目标大小: {args.target_size}")
    if args.normalize:
        print("🔧 归一化模式: 已启用")

    analyze_dataset_cka_similarities(
        file_paths_and_names,
        target_layer,
        args.output_path,
        args.target_size,
        args.normalize,
    )

    print("\n✅ 分析完成！")