from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# 模型路径（根据实际调整）
model_path = "/data/jhb_data/codes/LLaVA/save_LoRA_Null_adapter_llama2_PT_128"
OFFICIAL_MODEL_NAME = "liuhaotian/llava-v1.5-7b"  # 示例：7B官方模型
# 必须从LLaVA代码库导入专用类（关键步骤！）
from llava.model import LlavaLlamaForCausalLM  # 需先安装llava包

# 加载配置和模型
model = LlavaLlamaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # 半精度节省显存
    device_map="auto"
)
i=0
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
    if i>10:
        break
    i+=1
# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True
)

# 验证视觉编码器是否加载
assert hasattr(model.model, 'vision_tower'), "视觉编码器未正确加载！"
print(f"视觉编码器类型: {type(model.model.vision_tower)}")