# STAMP-2B-uni/segment_predictor_cache.py
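"""Generative segmentation inference with KV-cache reuse.

Loads a SegQwenVL checkpoint (either full weights or a PEFT adapter on top of
a base model), generates free-form text, and, for every <|seg|> token the
model emits, slices the generation KV cache and re-runs a forward pass with
<|mask|> query tokens to classify each image patch as inside or outside the
requested mask.
"""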
import json
import os

import torch
from PIL import Image
from transformers import AutoProcessor, DynamicCache

from model.qwen_changes import get_rope_index, SegQwenVL

def find_image_patch_info(image_pad_id, input_ids: torch.Tensor):
    """Count the last consecutive run of image padding tokens in ``input_ids``.

    Searches from the end of the sequence toward the beginning for consecutive
    ``image_pad_id`` tokens and returns the length of that run, i.e. the number
    of image patches fed to the language model.

    Args:
        image_pad_id (int): ID of the image padding token (<|image_pad|>).
        input_ids (torch.Tensor): Tensor of input token IDs.

    Returns:
        int: The number of consecutive image patch tokens.

    Raises:
        RuntimeError: If no image patches (<|image_pad|>) are found in input_ids.
    """
    input_ids_list = input_ids.squeeze().tolist()
    # Reverse the list so the search runs from the end toward the beginning.
    reversed_input_ids_list = input_ids_list[::-1]
    try:
        # Locate the first occurrence of image_pad_id in the reversed list ...
        start_idx_rev = reversed_input_ids_list.index(image_pad_id)
        end_idx_rev = start_idx_rev
        # ... then extend over the whole consecutive run of image_pad_id tokens.
        while (end_idx_rev + 1 < len(reversed_input_ids_list)
               and reversed_input_ids_list[end_idx_rev + 1] == image_pad_id):
            end_idx_rev += 1
        return (end_idx_rev - start_idx_rev) + 1
    except ValueError:
        raise RuntimeError("No image patches (<|image_pad|>) found in input_ids.")
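# Worked example (hypothetical token IDs): with image_pad_id == 151655, an
# input_ids tensor ending in [..., 151655, 151655, 151655, 198] has a last
# run of three pad tokens, so find_image_patch_info returns 3.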
class GenerativeSegmenter:
def __init__(self, model_path: str, min_pixels, max_pixels, **kwargs):
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
self.device = kwargs.get("device_map", "cuda" if torch.cuda.is_available() else "cpu")
        # --- Adapter-aware loading: PEFT adapter checkpoints vs. full model weights ---
adapter_config_path = os.path.join(model_path, "adapter_config.json")
if os.path.exists(adapter_config_path):
            print(f"Detected PEFT adapter configuration: {adapter_config_path}. Loading the base model first, then the adapter.")
# Read the base model path from the adapter configuration
with open(adapter_config_path, 'r', encoding='utf-8') as f:
adapter_config = json.load(f)
            # Base model path; if it is absent from the config, it must be specified manually below.
base_model_path = adapter_config.get("base_model_name_or_path")
if not base_model_path:
                # NOTE: if adapter_config.json does not contain 'base_model_name_or_path',
                # manually specify the correct base model name or path here.
                base_model_path = "Qwen/Qwen2-VL-7B-Instruct"
                print(f"Warning: 'base_model_name_or_path' not found in adapter configuration. Using default base model: '{base_model_path}'")
# 1. Load the base model
print(f"Loading base model from '{base_model_path}'...")
self.model = SegQwenVL.from_pretrained(
base_model_path,
torch_dtype="auto",
trust_remote_code=True,
# attn_implementation="flash_attention_2",
**kwargs
)
self.processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True,
min_pixels=min_pixels, max_pixels=max_pixels)
self.tokenizer = self.processor.tokenizer
self._add_special_tokens()
# 2. Load the adapter
print(f"Loading adapter from '{model_path}'...")
self.model.load_adapter(model_path)
else:
print(f"No PEFT adapter detected. Loading full model directly from '{model_path}'.")
            # Load the full model weights directly.
self.model = SegQwenVL.from_pretrained(
model_path,
torch_dtype="auto",
trust_remote_code=True,
# attn_implementation="flash_attention_2",
**kwargs
)
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, min_pixels=min_pixels,
max_pixels=max_pixels)
self.tokenizer = self.processor.tokenizer
self._add_special_tokens()
        # --- End of adapter-aware loading ---
        # Monkey-patch the inner language-model class so that position IDs are
        # computed with this repo's modified get_rope_index implementation.
        TargetClass = type(self.model.model)
        TargetClass.get_rope_index = get_rope_index
# Get key token IDs
self.yes_token_id = self.tokenizer.convert_tokens_to_ids("<|yes|>")
self.no_token_id = self.tokenizer.convert_tokens_to_ids("<|no|>")
self.seg_token_id = self.tokenizer.convert_tokens_to_ids("<|seg|>")
self.mask_token_id = self.tokenizer.convert_tokens_to_ids("<|mask|>")
self.image_pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
self.eos_token_id = self.tokenizer.eos_token_id
self.model.mask_token_id = self.mask_token_id
def _add_special_tokens(self):
special_tokens = {'additional_special_tokens': ["<|seg|>", "<|mask|>", "<|yes|>", "<|no|>"]}
num_added = self.tokenizer.add_special_tokens(special_tokens)
if num_added > 0:
print(f"Added {num_added} special tokens. Resizing model embedding layer...")
self.model.resize_token_embeddings(len(self.tokenizer))
            # Sanity check: tokenizer vocabulary and embedding matrix sizes must agree.
            print(
                f"Resized vocabulary size: {len(self.tokenizer)}, model embedding rows: {self.model.get_input_embeddings().weight.shape[0]}")
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
@torch.no_grad()
def generate_with_segmentation(self, image: Image.Image, prompt: str):
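        """Generate a text response plus one patch-level mask per <|seg|> token.

        Runs ordinary autoregressive generation first; then, for each <|seg|>
        token in the output, slices the generation KV cache up to that token
        and performs one extra forward pass in which <|mask|> query tokens
        (one per image patch) produce binary logits that are thresholded into
        a segmentation mask.

        Returns:
            A tuple (masks, text): masks is a list of (H, W) 0/1 tensors on the
            merged patch grid, or None if no <|seg|> token was generated; text
            is the decoded response.
        """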
messages = [{"role": "user", "content": [{"image": image}, {"text": prompt}]}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(text=[text], images=[image], return_tensors="pt")
merge_size = self.processor.image_processor.merge_size
inputs = {k: v.to(self.device) for k, v in inputs.items()}
prompt_len = inputs['input_ids'].shape[1]
        image_grid_thw = inputs['image_grid_thw']  # (num_images, 3): [t, h, w] patch grid per image
        attention_mask_raw = inputs['attention_mask']
outputs = self.model.generate(
**inputs,
max_new_tokens=1024,
use_cache=True,
return_dict_in_generate=True,
eos_token_id=self.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
)
sequence = outputs.sequences[0]
full_past_key_values = outputs.past_key_values
        # Find all <|seg|> token positions in the generated sequence.
        seg_indices = torch.where(sequence == self.seg_token_id)[0].tolist()
        all_segmentation_masks = []
        if not seg_indices:  # No <|seg|> token was generated, so there is nothing to segment.
generated_ids = sequence[prompt_len:]
response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
return None, response_text
        # Number of <|image_pad|> tokens, i.e. one <|mask|> query per image patch.
        num_patches = find_image_patch_info(self.image_pad_id, inputs['input_ids'])
        # Iterate over each <|seg|> token and run one segmentation forward pass.
        for idx in seg_indices:
sliced_len = idx + 1
attention_mask = attention_mask_raw[:, :sliced_len]
            # 1. Convert the cache to the legacy tuple format so it can be sliced.
            legacy_cache = full_past_key_values.to_legacy_cache()
            # 2. Truncate every layer's key/value states to the prefix ending at <|seg|>.
            past_key_values_sliced = tuple(
                (
                    key_layer[:, :, :sliced_len, :],
                    value_layer[:, :, :sliced_len, :]
                )
                for key_layer, value_layer in legacy_cache
            )
            past_key_values_sliced = DynamicCache.from_legacy_cache(past_key_values_sliced)
            # Append one <|mask|> query token per image patch; extend the attention
            # mask so it covers both the cached prefix and the mask queries.
            mask_query_ids = torch.full((1, num_patches), self.mask_token_id, dtype=torch.long, device=self.device)
            mask_query_attention_mask = torch.ones((1, num_patches + sliced_len - attention_mask[0].sum()),
                                                   dtype=torch.long, device=self.device)
            mask_query_attention_mask = torch.cat((attention_mask, mask_query_attention_mask), dim=1)
            mask_grid_thw = image_grid_thw[-1].clone().unsqueeze(0)
            # Prefix up to and including <|seg|>, followed by the <|mask|> queries.
            mask_ids = torch.cat([sequence[:sliced_len], mask_query_ids[0]], dim=0).unsqueeze(0)
seg_forward_outputs = self.model(
input_ids=mask_ids,
attention_mask=mask_query_attention_mask,
image_grid_thw=image_grid_thw,
pixel_values=inputs['pixel_values'],
past_key_values=past_key_values_sliced,
return_dict=True,
do_classification=True
)
            # Binary logits for the <|mask|> queries; > 0 marks a patch as inside the mask.
            mask_logits = seg_forward_outputs.bi_logits[:, -num_patches:]
            segmentation_preds = (mask_logits > 0).long().squeeze().cpu()
            # Reshape the flat patch predictions onto the merged (H, W) patch grid.
            h_grid, w_grid = mask_grid_thw[0, 1:]
            h_grid, w_grid = int(h_grid) // merge_size, int(w_grid) // merge_size
            segmentation_preds = segmentation_preds.view(h_grid, w_grid)
all_segmentation_masks.append(segmentation_preds)
generated_ids = sequence[prompt_len:]
response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
return all_segmentation_masks, response_text
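

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint path, pixel budget, and prompt are
    # placeholders (assumptions for illustration, not values from this repo).
    segmenter = GenerativeSegmenter(
        model_path="path/to/checkpoint",  # full model dir or PEFT adapter dir
        min_pixels=256 * 28 * 28,         # assumed Qwen2-VL-style pixel budget
        max_pixels=1280 * 28 * 28,
        device_map="cuda",
    )
    img = Image.open("example.jpg").convert("RGB")
    masks, text = segmenter.generate_with_segmentation(img, "Segment the dog in the image.")
    print(text)
    if masks is not None:
        for m in masks:
            print(m.shape)  # one (h_grid, w_grid) 0/1 tensor per <|seg|> token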