realzliu committed
Commit 7c15ab5
1 Parent(s): 2e5b533
.DS_Store ADDED
Binary file (6.15 kB).
 
app.py ADDED
@@ -0,0 +1,187 @@
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import cv2
6
+ import os
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # --- Import local modules (Ensure these files are uploaded to the Space) ---
11
+ try:
12
+ from segment_predictor_cache import GenerativeSegmenter
13
+ from model.segment_anything import sam_model_registry, SamPredictor
14
+ from eval.utils import compute_logits_from_mask, masks_sample_points
15
+ except ImportError as e:
16
+ raise ImportError(f"Could not import custom modules: {e}. Please ensure STAMP source code (model/, eval/, segment_predictor_cache.py) is uploaded to the Space.")
17
+
18
+ # --- Configuration ---
19
+ MODEL_PATH = "JiaZL/STAMP-2B-uni"
20
+ # Use a specific repo to download SAM weights automatically
21
+ SAM_REPO_ID = "HCMUE-Research/SAM-vit-h"
22
+ SAM_FILENAME = "sam_vit_h_4b8939.pth"
23
+
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ print(f"Running on {DEVICE}...")
27
+
28
+ # --- Load Models (Cached globally) ---
29
+ def load_models():
30
+ print(f"Loading STAMP model from {MODEL_PATH}...")
31
+ # Adjust min/max pixels if running into OOM on smaller GPUs
32
+ segmenter = GenerativeSegmenter(
33
+ MODEL_PATH,
34
+ device_map=DEVICE,
35
+ min_pixels=512 * 28 * 28, # Reduced slightly for Space stability
36
+ max_pixels=1024 * 28 * 28
37
+ )
38
+
39
+ print("Downloading and Loading SAM model...")
40
+ sam_checkpoint = hf_hub_download(repo_id=SAM_REPO_ID, filename=SAM_FILENAME)
41
+ sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
42
+ sam = sam.to(dtype=torch.float32, device=DEVICE)
43
+ predictor = SamPredictor(sam)
44
+
45
+ return segmenter, predictor
46
+
47
+ # Initialize models
48
+ segmenter, sam_predictor = load_models()
49
+
50
+ # --- Core Inference Function ---
51
+ def run_inference(image, query, use_sam=True):
52
+ if image is None:
53
+ return None, "Please upload an image."
54
+ if not query:
55
+ return None, "Please enter a query."
56
+
57
+ # Convert to RGB PIL Image
58
+ image_pil = Image.fromarray(image).convert("RGB")
59
+ w_ori, h_ori = image_pil.size
60
+
61
+ with torch.inference_mode():
62
+ # 1. Set SAM image embedding
63
+ if use_sam:
64
+ sam_predictor.set_image(np.array(image_pil))
65
+
66
+ # 2. Generate Coarse Mask using STAMP
67
+ print(f"Generating coarse mask for query: {query}")
68
+ segmentation_masks, response_text = segmenter.generate_with_segmentation(
69
+ image_pil, query
70
+ )
71
+
72
+ if not segmentation_masks or len(segmentation_masks) == 0:
73
+ return image, f"No mask generated. Model response: {response_text}"
74
+
75
+ # Extract the first mask
76
+ mask = segmentation_masks[0]
77
+
78
+ # Resize coarse mask to original image size
79
+ mask_pred = F.interpolate(
80
+ mask.unsqueeze(0).unsqueeze(0).double(),
81
+ size=(h_ori, w_ori),
82
+ mode='nearest'
83
+ ).squeeze(0).squeeze(0)
84
+
85
+ # --- SAM Refinement ---
86
+ final_mask = np.zeros((h_ori, w_ori), dtype=np.float32)
87
+
88
+ if use_sam:
89
+ print("Refining mask with SAM...")
90
+ unique_classes = torch.unique(mask_pred)
91
+
92
+ for class_id in unique_classes:
93
+ if class_id == 0: continue
94
+
95
+ # Get binary mask for current class
96
+ binary_mask = (mask_pred == class_id).double().cpu()
97
+
98
+ try:
99
+ logits = compute_logits_from_mask(binary_mask)
100
+ point_coords, point_labels = masks_sample_points(binary_mask)
101
+
102
+ # First pass
103
+ sam_mask, _, logit = sam_predictor.predict(
104
+ point_coords=point_coords,
105
+ point_labels=point_labels,
106
+ mask_input=logits,
107
+ multimask_output=False
108
+ )
109
+
110
+ # Iterative refinement
111
+ for _ in range(2):
112
+ sam_mask, _, logit = sam_predictor.predict(
113
+ point_coords=point_coords,
114
+ point_labels=point_labels,
115
+ mask_input=logit,
116
+ multimask_output=False
117
+ )
118
+
119
+ current_refined_mask = sam_mask[0].astype(np.float32)
120
+ final_mask = np.maximum(final_mask, current_refined_mask)
121
+
122
+ except Exception as e:
123
+ print(f"SAM Error for class {class_id}: {e}")
124
+ final_mask = np.maximum(final_mask, binary_mask.numpy())
125
+ else:
126
+ final_mask = mask_pred.cpu().numpy()
127
+
128
+ # --- Visualization ---
129
+ # Convert mask to uint8 (0 or 255)
130
+ mask_uint8 = (final_mask > 0).astype(np.uint8) * 255
131
+
132
+ # Create a red overlay
133
+ overlay = image.copy()
134
+ # Paint red where mask is present
135
+ # Format is BGR in OpenCV if read via cv2, but Gradio sends RGB numpy array
136
+ # We want Red: (255, 0, 0)
137
+
138
+ # Create colored mask
139
+ color_mask = np.zeros_like(image)
140
+ color_mask[:, :, 0] = 255 # R
141
+ color_mask[:, :, 1] = 0 # G
142
+ color_mask[:, :, 2] = 0 # B
143
+
144
+ # Blend
145
+ alpha = 0.5
146
+ mask_indices = mask_uint8 > 0
147
+ overlay[mask_indices] = (alpha * image[mask_indices] + (1 - alpha) * color_mask[mask_indices]).astype(np.uint8)
148
+
149
+ # Alternatively, just return the raw mask or the overlay.
150
+ # Here we return the overlay.
151
+
152
+ return overlay, f"Success! {response_text}"
153
+
154
+ # --- Gradio Interface ---
155
+ with gr.Blocks(title="STAMP + SAM Segmentation Demo") as demo:
156
+ gr.Markdown("# STAMP + SAM: Multimodal Segmentation")
157
+ gr.Markdown("Upload an image and provide a text query to segment objects using STAMP-2B-uni refined by SAM.")
158
+
159
+ with gr.Row():
160
+ with gr.Column():
161
+ input_image = gr.Image(label="Input Image", type="numpy")
162
+ text_query = gr.Textbox(label="Text Prompt", placeholder="e.g., segment the white horse")
163
+ use_sam_checkbox = gr.Checkbox(label="Refine with SAM", value=True)
164
+ submit_btn = gr.Button("Segment", variant="primary")
165
+
166
+ with gr.Column():
167
+ output_image = gr.Image(label="Segmentation Result")
168
+ status_text = gr.Textbox(label="Status/Response", interactive=False)
169
+
170
+ submit_btn.click(
171
+ fn=run_inference,
172
+ inputs=[input_image, text_query, use_sam_checkbox],
173
+ outputs=[output_image, status_text]
174
+ )
175
+
176
+ # Add examples
177
+ gr.Examples(
178
+ examples=[
179
+ ["images/horses.png", "segment the white horse", True]
180
+ ],
181
+ inputs=[input_image, text_query, use_sam_checkbox],
182
+ fn=run_inference, # Dummy fn for cache
183
+ cache_examples=False # Disable cache if no GPU on build
184
+ )
185
+
186
+ if __name__ == "__main__":
187
+ demo.launch()
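A quick way to sanity-check `run_inference` outside the Gradio UI is to call it directly with a numpy image. The snippet below is a minimal sketch, not part of the commit: it assumes `app.py` imports cleanly (i.e. the STAMP and SAM weights download successfully) and that the example image `images/horses.png` referenced in `gr.Examples` above is present next to `app.py`.

```python
# Minimal local smoke test for app.py (sketch, not part of the commit).
import numpy as np
from PIL import Image

from app import run_inference  # reuses the globally loaded segmenter / sam_predictor

img = np.array(Image.open("images/horses.png").convert("RGB"))
overlay, status = run_inference(img, "segment the white horse", use_sam=True)
print(status)
if overlay is not None:
    Image.fromarray(overlay).save("overlay.png")
```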
eval/utils.py ADDED
@@ -0,0 +1,347 @@
1
+
2
+
3
+ import numpy as np
4
+ from skimage.feature import peak_local_max
5
+ from skimage.filters import gaussian
6
+ from scipy.ndimage import distance_transform_edt
7
+ from .transforms import ResizeLongestSide
8
+ import torch
9
+ from enum import Enum
10
+ import torch.distributed as dist
11
+ from torchvision.ops.boxes import box_area
12
+ import matplotlib.pyplot as plt  # needed by show_box below (plt.Rectangle)
13
+
14
+
15
+ def translate_sequence(sequence_str, labels_set):
16
+
17
+ # Split the string into a list of categories
18
+ sequence = sequence_str.split('|')
19
+
20
+ # strip the whitespace from each category
21
+ sequence = [seq.strip() for seq in sequence]
22
+
23
+ # Translate the sequence using the dictionary
24
+ # translated_sequence = [labels_set[item] for item in sequence]
25
+ translated_sequence = [labels_set.get(item, 0) for item in sequence]
26
+
27
+
28
+ return translated_sequence
29
+
30
+ def decode_mask(encoded_str):
31
+ rows = encoded_str.strip("\n").split("\n ")
32
+ decoded_list = []
33
+ for row in rows:
34
+ tokens = row.split("| ")
35
+ for token in tokens:
36
+ label, count = token.split(" *")
37
+ decoded_list.extend([label] * int(count))
38
+ return "|".join(decoded_list)
39
+
40
+ # compute the bounding box from a mask. SAM expects the following input:
41
+ # box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
42
+ def compute_box_from_mask(mask, original_size=None, box_extension=0):
43
+ coords = np.where(mask == 1)
44
+ min_y, min_x = coords[0].min(), coords[1].min()
45
+ max_y, max_x = coords[0].max(), coords[1].max()
46
+ box = np.array([min_y, min_x, max_y + 1, max_x + 1])
47
+ return process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
48
+
49
+
50
+ # sample positive and negative point prompts from a mask, in the format SAM expects
51
+ def compute_points_from_mask(mask, original_size, box_extension):
52
+ box = compute_box_from_mask(mask, box_extension=box_extension)
53
+
54
+ # get slice and offset in python coordinate convention
55
+ bb = (slice(box[1], box[3]), slice(box[0], box[2]))
56
+ offset = np.array([box[1], box[0]])
57
+
58
+ # crop the mask and compute distances
59
+ cropped_mask = mask[bb]
60
+ inner_distances = gaussian(distance_transform_edt(cropped_mask == 1))
61
+ outer_distances = gaussian(distance_transform_edt(cropped_mask == 0))
62
+
63
+ # sample positives and negatives from the distance maxima
64
+ inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
65
+ outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)
66
+
67
+ # derive the positive (=inner maxima) and negative (=outer maxima) points
68
+ point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
69
+ point_coords += offset
70
+
71
+ if original_size is not None:
72
+ scale_factor = np.array([
73
+ original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
74
+ ])[None]
75
+ point_coords *= scale_factor
76
+
77
+ # get the point labels
78
+ point_labels = np.concatenate(
79
+ [
80
+ np.ones(len(inner_maxima), dtype="uint8"),
81
+ np.zeros(len(outer_maxima), dtype="uint8"),
82
+ ]
83
+ )
84
+ return point_coords[:, ::-1], point_labels
85
+
86
+ def compute_logits_from_mask(mask, eps=1e-3):
87
+
88
+ def inv_sigmoid(x):
89
+ return np.log(x / (1 - x))
90
+
91
+ logits = np.zeros(mask.shape, dtype="float32")
92
+ logits[mask == 1] = 1
93
+ logits[mask == 0] = 0
94
+
95
+ # resize to the expected mask shape of SAM (256x256)
96
+ assert logits.ndim == 2
97
+ expected_shape = (256, 256)
98
+
99
+ if logits.shape == expected_shape: # shape matches, do nothing
100
+ pass
101
+
102
+ elif logits.shape[0] == logits.shape[1]: # shape is square
103
+ trafo = ResizeLongestSide(expected_shape[0])
104
+ logits = trafo.apply_image(logits[..., None])
105
+
106
+ else: # shape is not square
107
+ # resize the longest side to expected shape
108
+ trafo = ResizeLongestSide(expected_shape[0])
109
+ logits = trafo.apply_image(logits[..., None])
110
+
111
+ # pad the other side
112
+ h, w = logits.shape
113
+ padh = expected_shape[0] - h
114
+ padw = expected_shape[1] - w
115
+ # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
116
+ pad_width = ((0, padh), (0, padw))
117
+ logits = np.pad(logits, pad_width, mode="constant", constant_values=-1)
118
+
119
+ logits = logits / 255.0
120
+ logits[logits >= 1] = 1 - eps
121
+ logits[logits == 0] = eps
122
+ logits[logits == -1] = 0
123
+ # print(logits)
124
+ logits = inv_sigmoid(logits)
125
+
126
+ logits = logits[None]
127
+ assert logits.shape == (1, 256, 256), f"{logits.shape}"
128
+ return logits
129
+
130
+ def process_box(box, shape, original_size=None, box_extension=0):
131
+ if box_extension == 0: # no extension
132
+ extension_y, extension_x = 0, 0
133
+ elif box_extension >= 1: # extension by a fixed factor
134
+ extension_y, extension_x = box_extension, box_extension
135
+ else: # extension by fraction of the box len
136
+ len_y, len_x = box[2] - box[0], box[3] - box[1]
137
+ extension_y, extension_x = box_extension * len_y, box_extension * len_x
138
+
139
+ box = np.array([
140
+ max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
141
+ min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
142
+ ])
143
+
144
+ if original_size is not None:
145
+ trafo = ResizeLongestSide(max(original_size))
146
+ box = trafo.apply_boxes(box[None], (256, 256)).squeeze()
147
+ return box
148
+
149
+ def masks_sample_points(masks):
150
+ """Sample points on mask
151
+ """
152
+ masks = masks.unsqueeze(0)
153
+ if masks.numel() == 0:
154
+ return torch.zeros((0, 2), device=masks.device)
155
+
156
+ h, w = masks.shape[-2:]
157
+
158
+ y = torch.arange(0, h, dtype=torch.float)
159
+ x = torch.arange(0, w, dtype=torch.float)
160
+ y, x = torch.meshgrid(y, x)
161
+ y = y.to(masks)
162
+ x = x.to(masks)
163
+
164
+ k = np.random.randint(10, 11)
165
+ samples_pos = []
166
+ for b_i in range(len(masks)):
167
+ select_mask = (masks[b_i] > 0.5)
168
+ x_idx = torch.masked_select(x, select_mask)
169
+ y_idx = torch.masked_select(y, select_mask)
170
+
171
+ perm = torch.randperm(x_idx.size(0))
172
+ idx = perm[:k]
173
+ samples_x = x_idx[idx]
174
+ samples_y = y_idx[idx]
175
+ samples_xy = torch.cat((samples_x[:, None], samples_y[:, None]), dim=1)
176
+ samples_pos.append(samples_xy)
177
+
178
+ samples_pos = torch.cat(samples_pos)
179
+
180
+ k = np.random.randint(10, 11)
181
+ samples_neg = []
182
+ for b_i in range(len(masks)):
183
+ select_mask = (masks[b_i] < 0.5)
184
+ x_idx = torch.masked_select(x, select_mask)
185
+ y_idx = torch.masked_select(y, select_mask)
186
+
187
+ perm = torch.randperm(x_idx.size(0))
188
+ idx = perm[:k]
189
+ samples_x = x_idx[idx]
190
+ samples_y = y_idx[idx]
191
+ samples_xy = torch.cat((samples_x[:, None], samples_y[:, None]), dim=1)
192
+ samples_neg.append(samples_xy)
193
+
194
+ samples_neg = torch.cat(samples_neg)
195
+
196
+ # get the point labels
197
+ point_labels = np.concatenate(
198
+ [
199
+ np.ones(len(samples_pos), dtype="uint8"),
200
+ np.zeros(len(samples_neg), dtype="uint8"),
201
+ ], axis=0
202
+ )
203
+ point_coords = np.concatenate([samples_pos, samples_neg], axis=0).astype("float64")
204
+
205
+ return point_coords, point_labels
206
+
207
+ def masks_to_boxes(masks):
208
+ """Compute the bounding boxes around the provided masks
209
+
210
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
211
+
212
+ Returns a [N, 4] tensors, with the boxes in xyxy format
213
+ """
214
+ if masks.numel() == 0:
215
+ return torch.zeros((0, 4), device=masks.device)
216
+
217
+ h, w = masks.shape[-2:]
218
+
219
+ y = torch.arange(0, h, dtype=torch.float)
220
+ x = torch.arange(0, w, dtype=torch.float)
221
+ y, x = torch.meshgrid(y, x)
222
+
223
+ x_mask = (masks * x.unsqueeze(0))
224
+ x_max = x_mask.flatten(1).max(-1)[0]
225
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
226
+
227
+ y_mask = (masks * y.unsqueeze(0))
228
+ y_max = y_mask.flatten(1).max(-1)[0]
229
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
230
+
231
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
232
+
233
+ def box_iou(boxes1, boxes2):
234
+ area1 = box_area(boxes1)
235
+ area2 = box_area(boxes2)
236
+
237
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
238
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
239
+
240
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
241
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
242
+
243
+ union = area1[:, None] + area2 - inter
244
+
245
+ iou = inter / union
246
+ return iou, union
247
+
248
+
249
+ def show_points(coords, labels, ax, marker_size=375):
250
+ pos_points = coords[labels == 1]
251
+ neg_points = coords[labels == 0]
252
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
253
+ linewidth=1.25)
254
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
255
+ linewidth=1.25)
256
+
257
+
258
+ def show_box(box, ax):
259
+ x0, y0 = box[0], box[1]
260
+ w, h = box[2] - box[0], box[3] - box[1]
261
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
262
+
263
+
264
+ class Summary(Enum):
265
+ NONE = 0
266
+ AVERAGE = 1
267
+ SUM = 2
268
+ COUNT = 3
269
+
270
+
271
+ class AverageMeter(object):
272
+ """Computes and stores the average and current value"""
273
+
274
+ def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
275
+ self.name = name
276
+ self.fmt = fmt
277
+ self.summary_type = summary_type
278
+ self.reset()
279
+
280
+ def reset(self):
281
+ self.val = 0
282
+ self.avg = 0
283
+ self.sum = 0
284
+ self.count = 0
285
+
286
+ def update(self, val, n=1):
287
+ self.val = val
288
+ self.sum += val * n
289
+ self.count += n
290
+ self.avg = self.sum / self.count
291
+
292
+ def all_reduce(self):
293
+ device = "cuda" if torch.cuda.is_available() else "cpu"
294
+ if isinstance(self.sum, np.ndarray):
295
+ total = torch.tensor(
296
+ self.sum.tolist()
297
+ + [
298
+ self.count,
299
+ ],
300
+ dtype=torch.float32,
301
+ device=device,
302
+ )
303
+ else:
304
+ total = torch.tensor(
305
+ [self.sum, self.count], dtype=torch.float32, device=device
306
+ )
307
+
308
+ dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
309
+ if total.shape[0] > 2:
310
+ self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
311
+ else:
312
+ self.sum, self.count = total.tolist()
313
+ self.avg = self.sum / (self.count + 1e-5)
314
+
315
+ def __str__(self):
316
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
317
+ return fmtstr.format(**self.__dict__)
318
+
319
+ def summary(self):
320
+ fmtstr = ""
321
+ if self.summary_type is Summary.NONE:
322
+ fmtstr = ""
323
+ elif self.summary_type is Summary.AVERAGE:
324
+ fmtstr = "{name} {avg:.3f}"
325
+ elif self.summary_type is Summary.SUM:
326
+ fmtstr = "{name} {sum:.3f}"
327
+ elif self.summary_type is Summary.COUNT:
328
+ fmtstr = "{name} {count:.3f}"
329
+ else:
330
+ raise ValueError("invalid summary type %r" % self.summary_type)
331
+
332
+ return fmtstr.format(**self.__dict__)
333
+
334
+
335
+ def intersectionAndUnionGPU(output, target, K, ignore_index=255):
336
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
337
+ assert output.dim() in [1, 2, 3]
338
+ assert output.shape == target.shape
339
+ output = output.view(-1)
340
+ target = target.view(-1)
341
+ output[target == ignore_index] = ignore_index
342
+ intersection = output[output == target]
343
+ area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
344
+ area_output = torch.histc(output, bins=K, min=0, max=K - 1)
345
+ area_target = torch.histc(target, bins=K, min=0, max=K - 1)
346
+ area_union = area_output + area_target - area_intersection
347
+ return area_intersection, area_union, area_target
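For readers skimming the helpers above, here is a small self-contained sketch (not part of the commit) that exercises `masks_to_boxes` and `box_iou` on a toy mask. It only assumes the `eval` package is importable, including `eval/transforms.py`, which `eval/utils.py` imports at the top.

```python
# Toy example for the box helpers in eval/utils.py (sketch, not part of the commit).
import torch
from eval.utils import masks_to_boxes, box_iou

mask = torch.zeros(1, 16, 16)        # [N, H, W] with N = 1
mask[0, 4:10, 6:12] = 1.0            # a 6x6 foreground square
boxes = masks_to_boxes(mask)         # xyxy box: tensor([[6., 4., 11., 9.]])
iou, union = box_iou(boxes, boxes)   # IoU of a box with itself -> 1.0
print(boxes, iou)
```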
model/modeling_qwen2_vl.py ADDED
@@ -0,0 +1,1579 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch.nn import LayerNorm
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
14
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
15
+ from transformers.modeling_layers import GradientCheckpointingLayer
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
17
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
18
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
19
+ from transformers.processing_utils import Unpack
20
+ from transformers.utils import (
21
+ TransformersKwargs,
22
+ auto_docstring,
23
+ can_return_tuple,
24
+ is_torchdynamo_compiling,
25
+ logging,
26
+ )
27
+ from transformers.utils.deprecation import deprecate_kwarg
28
+ from transformers.models.qwen2.modeling_qwen2 import (
29
+ Qwen2RMSNorm,
30
+ )
31
+ from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ @dataclass
38
+ @auto_docstring(
39
+ custom_intro="""
40
+ Base class for Qwen2VL outputs, with hidden states and attentions.
41
+ """
42
+ )
43
+ class Qwen2VLModelOutputWithPast(ModelOutput):
44
+ r"""
45
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
46
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
47
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
48
+
49
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
50
+ `past_key_values` input) to speed up sequential decoding.
51
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
52
+ The rope index difference between sequence length and multimodal rope.
53
+ """
54
+
55
+ last_hidden_state: torch.FloatTensor = None
56
+ past_key_values: Optional[list[torch.FloatTensor]] = None
57
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
58
+ attentions: Optional[tuple[torch.FloatTensor]] = None
59
+ rope_deltas: Optional[torch.LongTensor] = None
60
+
61
+
62
+ @dataclass
63
+ @auto_docstring(
64
+ custom_intro="""
65
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
66
+ """
67
+ )
68
+ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
69
+ r"""
70
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
71
+ Language modeling loss (for next-token prediction).
72
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
73
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
74
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
75
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
76
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
77
+
78
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
79
+ `past_key_values` input) to speed up sequential decoding.
80
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
81
+ The rope index difference between sequence length and multimodal rope.
82
+ """
83
+
84
+ loss: Optional[torch.FloatTensor] = None
85
+ logits: Optional[torch.FloatTensor] = None
86
+ past_key_values: Optional[list[torch.FloatTensor]] = None
87
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
88
+ attentions: Optional[tuple[torch.FloatTensor]] = None
89
+ rope_deltas: Optional[torch.LongTensor] = None
90
+
91
+
92
+ class Qwen2VLRotaryEmbedding(nn.Module):
93
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
94
+
95
+ def __init__(self, config: Qwen2VLTextConfig, device=None):
96
+ super().__init__()
97
+ # BC: "rope_type" was originally "type"
98
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
99
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
100
+ else:
101
+ self.rope_type = "default"
102
+ self.max_seq_len_cached = config.max_position_embeddings
103
+ self.original_max_seq_len = config.max_position_embeddings
104
+
105
+ self.config = config
106
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
107
+
108
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
109
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
110
+ self.original_inv_freq = self.inv_freq
111
+
112
+ @torch.no_grad()
113
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
114
+ def forward(self, x, position_ids):
115
+ # In contrast to other models, Qwen2_VL has different position ids for the grids
116
+ # So we expand the inv_freq to shape (3, ...)
117
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
118
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
119
+
120
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
121
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
122
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
123
+ emb = torch.cat((freqs, freqs), dim=-1)
124
+ cos = emb.cos() * self.attention_scaling
125
+ sin = emb.sin() * self.attention_scaling
126
+
127
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
128
+
129
+
130
+ def rotate_half(x):
131
+ """Rotates half the hidden dims of the input."""
132
+ x1 = x[..., : x.shape[-1] // 2]
133
+ x2 = x[..., x.shape[-1] // 2 :]
134
+ return torch.cat((-x2, x1), dim=-1)
135
+
136
+
137
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
138
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
139
+
140
+ Explanation:
141
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
142
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
143
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
144
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
145
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
146
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
147
+ difference with modern LLMs.
148
+
149
+ Args:
150
+ q (`torch.Tensor`): The query tensor.
151
+ k (`torch.Tensor`): The key tensor.
152
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
153
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
154
+ position_ids (`torch.Tensor`):
155
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
156
+ used to pass offsetted position ids when working with a KV-cache.
157
+ mrope_section(`List(int)`):
158
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
159
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
160
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
161
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
162
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
163
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
164
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
165
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
166
+ Returns:
167
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
168
+ """
169
+ mrope_section = mrope_section * 2
170
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
171
+ unsqueeze_dim
172
+ )
173
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
174
+ unsqueeze_dim
175
+ )
176
+
177
+ q_embed = (q * cos) + (rotate_half(q) * sin)
178
+ k_embed = (k * cos) + (rotate_half(k) * sin)
179
+ return q_embed, k_embed
180
+
181
+
182
+ def apply_rotary_pos_emb_vision(
183
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
184
+ ) -> tuple[torch.Tensor, torch.Tensor]:
185
+ orig_q_dtype = q.dtype
186
+ orig_k_dtype = k.dtype
187
+ q, k = q.float(), k.float()
188
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
189
+ q_embed = (q * cos) + (rotate_half(q) * sin)
190
+ k_embed = (k * cos) + (rotate_half(k) * sin)
191
+ q_embed = q_embed.to(orig_q_dtype)
192
+ k_embed = k_embed.to(orig_k_dtype)
193
+ return q_embed, k_embed
194
+
195
+
196
+ class VisionRotaryEmbedding(nn.Module):
197
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
198
+
199
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
200
+ super().__init__()
201
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
202
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
203
+
204
+ def forward(self, seqlen: int) -> torch.Tensor:
205
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
206
+ freqs = torch.outer(seq, self.inv_freq)
207
+ return freqs
208
+
209
+
210
+ class PatchEmbed(nn.Module):
211
+ def __init__(
212
+ self,
213
+ patch_size: int = 14,
214
+ temporal_patch_size: int = 2,
215
+ in_channels: int = 3,
216
+ embed_dim: int = 1152,
217
+ ) -> None:
218
+ super().__init__()
219
+ self.patch_size = patch_size
220
+ self.temporal_patch_size = temporal_patch_size
221
+ self.in_channels = in_channels
222
+ self.embed_dim = embed_dim
223
+
224
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
225
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
226
+
227
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
228
+ target_dtype = self.proj.weight.dtype
229
+ hidden_states = hidden_states.view(
230
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
231
+ )
232
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
233
+ return hidden_states
234
+
235
+
236
+ class PatchMerger(nn.Module):
237
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
238
+ super().__init__()
239
+ self.hidden_size = context_dim * (spatial_merge_size**2)
240
+ self.ln_q = LayerNorm(context_dim, eps=1e-6)
241
+ self.mlp = nn.Sequential(
242
+ nn.Linear(self.hidden_size, self.hidden_size),
243
+ nn.GELU(),
244
+ nn.Linear(self.hidden_size, dim),
245
+ )
246
+
247
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
248
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
249
+ return x
250
+
251
+
252
+ class VisionMlp(nn.Module):
253
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
254
+ super().__init__()
255
+ self.fc1 = nn.Linear(dim, hidden_dim)
256
+ self.act = ACT2FN[hidden_act]
257
+ self.fc2 = nn.Linear(hidden_dim, dim)
258
+
259
+ def forward(self, x) -> torch.Tensor:
260
+ return self.fc2(self.act(self.fc1(x)))
261
+
262
+
263
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
264
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
265
+ """
266
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
267
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
268
+ """
269
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
270
+ if n_rep == 1:
271
+ return hidden_states
272
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
273
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
274
+
275
+
276
+ def eager_attention_forward(
277
+ module: nn.Module,
278
+ query: torch.Tensor,
279
+ key: torch.Tensor,
280
+ value: torch.Tensor,
281
+ attention_mask: Optional[torch.Tensor],
282
+ scaling: float,
283
+ dropout: float = 0.0,
284
+ **kwargs,
285
+ ):
286
+ key_states = repeat_kv(key, module.num_key_value_groups)
287
+ value_states = repeat_kv(value, module.num_key_value_groups)
288
+
289
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
290
+ if attention_mask is not None:
291
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
292
+ attn_weights = attn_weights + causal_mask
293
+
294
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
295
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
296
+ attn_output = torch.matmul(attn_weights, value_states)
297
+ attn_output = attn_output.transpose(1, 2).contiguous()
298
+
299
+ return attn_output, attn_weights
300
+
301
+
302
+ class VisionAttention(nn.Module):
303
+ def __init__(self, config: Qwen2VLVisionConfig) -> None:
304
+ super().__init__()
305
+ self.dim = config.embed_dim
306
+ self.num_heads = config.num_heads
307
+ self.head_dim = self.dim // self.num_heads
308
+ self.num_key_value_groups = 1 # needed for eager attention
309
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
310
+ self.proj = nn.Linear(self.dim, self.dim)
311
+ self.scaling = self.head_dim**-0.5
312
+ self.config = config
313
+ self.attention_dropout = 0.0
314
+ self.is_causal = False
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states: torch.Tensor,
319
+ cu_seqlens: torch.Tensor,
320
+ rotary_pos_emb: Optional[torch.Tensor] = None,
321
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
322
+ **kwargs,
323
+ ) -> torch.Tensor:
324
+ seq_length = hidden_states.shape[0]
325
+ query_states, key_states, value_states = (
326
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
327
+ )
328
+ if position_embeddings is None:
329
+ logger.warning_once(
330
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
331
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
332
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
333
+ "removed and `position_embeddings` will be mandatory."
334
+ )
335
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
336
+ cos = emb.cos()
337
+ sin = emb.sin()
338
+ else:
339
+ cos, sin = position_embeddings
340
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
341
+
342
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
343
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
344
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
345
+
346
+ attention_interface: Callable = eager_attention_forward
347
+ if self.config._attn_implementation != "eager":
348
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
349
+
350
+ if self.config._attn_implementation == "flash_attention_2":
351
+ # Flash Attention 2: Use cu_seqlens for variable length attention
352
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
353
+ attn_output, _ = attention_interface(
354
+ self,
355
+ query_states,
356
+ key_states,
357
+ value_states,
358
+ attention_mask=None,
359
+ scaling=self.scaling,
360
+ dropout=0.0 if not self.training else self.attention_dropout,
361
+ cu_seq_lens_q=cu_seqlens,
362
+ cu_seq_lens_k=cu_seqlens,
363
+ max_length_q=max_seqlen,
364
+ max_length_k=max_seqlen,
365
+ is_causal=False,
366
+ **kwargs,
367
+ )
368
+ else:
369
+ # Other implementations: Process each chunk separately
370
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
371
+ splits = [
372
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
373
+ ]
374
+
375
+ attn_outputs = [
376
+ attention_interface(
377
+ self,
378
+ q,
379
+ k,
380
+ v,
381
+ attention_mask=None,
382
+ scaling=self.scaling,
383
+ dropout=0.0 if not self.training else self.attention_dropout,
384
+ is_causal=False,
385
+ **kwargs,
386
+ )[0]
387
+ for q, k, v in zip(*splits)
388
+ ]
389
+ attn_output = torch.cat(attn_outputs, dim=1)
390
+
391
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
392
+ attn_output = self.proj(attn_output)
393
+ return attn_output
394
+
395
+
396
+ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
397
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
398
+ super().__init__()
399
+ self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
400
+ self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
401
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
402
+
403
+ self.attn = VisionAttention(config=config)
404
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states: torch.Tensor,
409
+ cu_seqlens: torch.Tensor,
410
+ rotary_pos_emb: Optional[torch.Tensor] = None,
411
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
412
+ **kwargs,
413
+ ) -> torch.Tensor:
414
+ hidden_states = hidden_states + self.attn(
415
+ self.norm1(hidden_states),
416
+ cu_seqlens=cu_seqlens,
417
+ rotary_pos_emb=rotary_pos_emb,
418
+ position_embeddings=position_embeddings,
419
+ **kwargs,
420
+ )
421
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
422
+ return hidden_states
423
+
424
+
425
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
426
+ class Qwen2MLP(nn.Module):
427
+ def __init__(self, config):
428
+ super().__init__()
429
+ self.config = config
430
+ self.hidden_size = config.hidden_size
431
+ self.intermediate_size = config.intermediate_size
432
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
433
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
434
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
435
+ self.act_fn = ACT2FN[config.hidden_act]
436
+
437
+ def forward(self, x):
438
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
439
+ return down_proj
440
+
441
+
442
+ class Qwen2VLAttention(nn.Module):
443
+ """
444
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
445
+ and "Generating Long Sequences with Sparse Transformers".
446
+ """
447
+
448
+ def __init__(self, config: Qwen2VLTextConfig, layer_idx: Optional[int] = None):
449
+ super().__init__()
450
+ self.config = config
451
+ self.layer_idx = layer_idx
452
+ if layer_idx is None:
453
+ logger.warning_once(
454
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
455
+ "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
456
+ "when creating this class."
457
+ )
458
+
459
+ self.hidden_size = config.hidden_size
460
+ self.num_heads = config.num_attention_heads
461
+ self.head_dim = self.hidden_size // self.num_heads
462
+ self.num_key_value_heads = config.num_key_value_heads
463
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
464
+ self.is_causal = True
465
+ self.attention_dropout = config.attention_dropout
466
+ self.rope_scaling = config.rope_scaling
467
+ self.scaling = self.head_dim**-0.5
468
+
469
+ if (self.head_dim * self.num_heads) != self.hidden_size:
470
+ raise ValueError(
471
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
472
+ f" and `num_heads`: {self.num_heads})."
473
+ )
474
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
475
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
476
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
477
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
478
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
479
+
480
+ self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
481
+
482
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
483
+ def forward(
484
+ self,
485
+ hidden_states: torch.Tensor,
486
+ attention_mask: Optional[torch.Tensor] = None,
487
+ position_ids: Optional[torch.LongTensor] = None,
488
+ past_key_values: Optional[Cache] = None,
489
+ output_attentions: bool = False,
490
+ use_cache: bool = False,
491
+ cache_position: Optional[torch.LongTensor] = None,
492
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
493
+ **kwargs: Unpack[FlashAttentionKwargs],
494
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
495
+ bsz, q_len, _ = hidden_states.size()
496
+
497
+ query_states = self.q_proj(hidden_states)
498
+ key_states = self.k_proj(hidden_states)
499
+ value_states = self.v_proj(hidden_states)
500
+
501
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
502
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
503
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
504
+
505
+ cos, sin = position_embeddings
506
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
507
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
508
+ )
509
+
510
+ if past_key_values is not None:
511
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
512
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
513
+
514
+ attention_interface: Callable = eager_attention_forward
515
+ if self.config._attn_implementation != "eager":
516
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
517
+
518
+ attn_output, attn_weights = attention_interface(
519
+ self,
520
+ query_states,
521
+ key_states,
522
+ value_states,
523
+ attention_mask,
524
+ dropout=0.0 if not self.training else self.attention_dropout,
525
+ scaling=self.scaling,
526
+ sliding_window=self.sliding_window,
527
+ position_ids=position_ids, # pass positions for FA2
528
+ **kwargs,
529
+ )
530
+
531
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
532
+ attn_output = self.o_proj(attn_output)
533
+ return attn_output, attn_weights
534
+
535
+
536
+ class Qwen2VLDecoderLayer(GradientCheckpointingLayer):
537
+ def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
538
+ super().__init__()
539
+ self.hidden_size = config.hidden_size
540
+
541
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
542
+ logger.warning_once(
543
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
544
+ "unexpected results may be encountered."
545
+ )
546
+ self.self_attn = Qwen2VLAttention(config, layer_idx)
547
+
548
+ self.mlp = Qwen2MLP(config)
549
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
550
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
551
+ self.attention_type = config.layer_types[layer_idx]
552
+
553
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
554
+ def forward(
555
+ self,
556
+ hidden_states: torch.Tensor,
557
+ attention_mask: Optional[torch.Tensor] = None,
558
+ position_ids: Optional[torch.LongTensor] = None,
559
+ past_key_values: Optional[tuple[torch.Tensor]] = None,
560
+ output_attentions: Optional[bool] = False,
561
+ use_cache: Optional[bool] = False,
562
+ cache_position: Optional[torch.LongTensor] = None,
563
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
564
+ **kwargs: Unpack[FlashAttentionKwargs],
565
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
566
+ """
567
+ Args:
568
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
569
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
570
+ `(batch, sequence_length)` where padding elements are indicated by 0.
571
+ output_attentions (`bool`, *optional*):
572
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
573
+ returned tensors for more detail.
574
+ use_cache (`bool`, *optional*):
575
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
576
+ (see `past_key_values`).
577
+ past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
578
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
579
+ Indices depicting the position of the input sequence tokens in the sequence.
580
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
581
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
582
+ with `head_dim` being the embedding dimension of each attention head.
583
+ kwargs (`dict`, *optional*):
584
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
585
+ into the model
586
+ """
587
+
588
+ residual = hidden_states
589
+
590
+ hidden_states = self.input_layernorm(hidden_states)
591
+
592
+ # Self Attention
593
+ hidden_states, self_attn_weights = self.self_attn(
594
+ hidden_states=hidden_states,
595
+ attention_mask=attention_mask,
596
+ position_ids=position_ids,
597
+ past_key_values=past_key_values,
598
+ output_attentions=output_attentions,
599
+ use_cache=use_cache,
600
+ cache_position=cache_position,
601
+ position_embeddings=position_embeddings,
602
+ **kwargs,
603
+ )
604
+ hidden_states = residual + hidden_states
605
+
606
+ # Fully Connected
607
+ residual = hidden_states
608
+ hidden_states = self.post_attention_layernorm(hidden_states)
609
+ hidden_states = self.mlp(hidden_states)
610
+ hidden_states = residual + hidden_states
611
+
612
+ outputs = (hidden_states,)
613
+
614
+ if output_attentions:
615
+ outputs += (self_attn_weights,)
616
+
617
+ return outputs
618
+
619
+
620
+ @auto_docstring
621
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
622
+ config: Qwen2VLConfig
623
+ base_model_prefix = "model"
624
+ supports_gradient_checkpointing = True
625
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
626
+ _skip_keys_device_placement = "past_key_values"
627
+ _supports_flash_attn = True
628
+ _supports_sdpa = True
629
+
630
+ _can_compile_fullgraph = True
631
+ _supports_attention_backend = True
632
+
633
+
634
+ @auto_docstring
635
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
636
+ config: Qwen2VLVisionConfig
637
+ _no_split_modules = ["Qwen2VLVisionBlock"]
638
+
639
+ def __init__(self, config) -> None:
640
+ super().__init__(config)
641
+ self.spatial_merge_size = config.spatial_merge_size
642
+
643
+ self.patch_embed = PatchEmbed(
644
+ patch_size=config.patch_size,
645
+ temporal_patch_size=config.temporal_patch_size,
646
+ in_channels=config.in_channels,
647
+ embed_dim=config.embed_dim,
648
+ )
649
+
650
+ head_dim = config.embed_dim // config.num_heads
651
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
652
+
653
+ self.blocks = nn.ModuleList([Qwen2VLVisionBlock(config) for _ in range(config.depth)])
654
+ self.merger = PatchMerger(
655
+ dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
656
+ )
657
+ self.gradient_checkpointing = False
658
+
659
+ def get_dtype(self) -> torch.dtype:
660
+ return self.blocks[0].mlp.fc2.weight.dtype
661
+
662
+ def get_device(self) -> torch.device:
663
+ return self.blocks[0].mlp.fc2.weight.device
664
+
665
+ def rot_pos_emb(self, grid_thw):
666
+ pos_ids = []
667
+ for t, h, w in grid_thw:
668
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
669
+ hpos_ids = hpos_ids.reshape(
670
+ h // self.spatial_merge_size,
671
+ self.spatial_merge_size,
672
+ w // self.spatial_merge_size,
673
+ self.spatial_merge_size,
674
+ )
675
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
676
+ hpos_ids = hpos_ids.flatten()
677
+
678
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
679
+ wpos_ids = wpos_ids.reshape(
680
+ h // self.spatial_merge_size,
681
+ self.spatial_merge_size,
682
+ w // self.spatial_merge_size,
683
+ self.spatial_merge_size,
684
+ )
685
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
686
+ wpos_ids = wpos_ids.flatten()
687
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
688
+ pos_ids = torch.cat(pos_ids, dim=0)
689
+ max_grid_size = grid_thw[:, 1:].max()
690
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
691
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
692
+ return rotary_pos_emb
693
+
694
+ @auto_docstring
695
+ def forward(
696
+ self,
697
+ hidden_states: torch.Tensor,
698
+ grid_thw: torch.Tensor,
699
+ **kwargs,
700
+ ) -> torch.Tensor:
701
+ r"""
702
+ grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
703
+ The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
704
+ """
705
+ hidden_states = self.patch_embed(hidden_states)
706
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
707
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
708
+ position_embeddings = (emb.cos(), emb.sin())
709
+
710
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
711
+ dim=0,
712
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
713
+ )
714
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
715
+
716
+ for blk in self.blocks:
717
+ hidden_states = blk(
718
+ hidden_states,
719
+ cu_seqlens=cu_seqlens,
720
+ position_embeddings=position_embeddings,
721
+ **kwargs,
722
+ )
723
+
724
+ return self.merger(hidden_states)
725
+
726
+
727
+ @auto_docstring
728
+ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
729
+ config: Qwen2VLTextConfig
730
+
731
+ def __init__(self, config: Qwen2VLTextConfig):
732
+ super().__init__(config)
733
+ self.padding_idx = config.pad_token_id
734
+ self.vocab_size = config.vocab_size
735
+
736
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
737
+ self.layers = nn.ModuleList(
738
+ [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
739
+ )
740
+ self._attn_implementation = config._attn_implementation
741
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
742
+ self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
743
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
744
+
745
+ self.gradient_checkpointing = False
746
+ # Initialize weights and apply final processing
747
+ self.post_init()
748
+
749
+ @auto_docstring
750
+ def forward(
751
+ self,
752
+ input_ids: Optional[torch.LongTensor] = None,
753
+ attention_mask: Optional[torch.Tensor] = None,
754
+ position_ids: Optional[torch.LongTensor] = None,
755
+ past_key_values: Optional[Cache] = None,
756
+ inputs_embeds: Optional[torch.FloatTensor] = None,
757
+ use_cache: Optional[bool] = None,
758
+ output_attentions: Optional[bool] = None,
759
+ output_hidden_states: Optional[bool] = None,
760
+ return_dict: Optional[bool] = None,
761
+ cache_position: Optional[torch.LongTensor] = None,
762
+ **kwargs: Unpack[FlashAttentionKwargs],
763
+ ) -> Union[tuple, BaseModelOutputWithPast]:
764
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
765
+ output_hidden_states = (
766
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
767
+ )
768
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
769
+
770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
771
+
772
+ if (input_ids is None) ^ (inputs_embeds is not None):
773
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
774
+
775
+ if self.gradient_checkpointing and self.training:
776
+ if use_cache:
777
+ logger.warning_once(
778
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
779
+ )
780
+ use_cache = False
781
+
782
+ # torch.jit.trace() doesn't support cache objects in the output
783
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
784
+ past_key_values = DynamicCache(config=self.config)
785
+
786
+ if inputs_embeds is None:
787
+ inputs_embeds = self.embed_tokens(input_ids)
788
+
789
+ if cache_position is None:
790
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
791
+ cache_position = torch.arange(
792
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
793
+ )
794
+
795
+ # the hard coded `3` is for temporal, height and width.
796
+ if position_ids is None:
797
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
798
+ elif position_ids.ndim == 2:
799
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
800
+
801
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
802
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
803
+ # There are two scenarios when FA2-like packed masking might be activated.
804
+ # 1. User specifically passed packed `position_ids` and no attention mask.
805
+ # In this case we expect the user to create correct position ids for all 3 grids
806
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
807
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
808
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
809
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
810
+ # text-only positions will cause incorrect mask construction, do not change `prepare_inputs_for_generation`
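# Illustrative sketch (assumed toy values, not part of the committed patch): for a
# 5-token prompt the packed tensor looks like
#   position_ids[0]  -> 1D text positions, e.g. [0, 1, 2, 3, 4]
#   position_ids[1:] -> the three M-RoPE grids (temporal / height / width)
# After the split below, the text row drives mask construction while the 3D part
# feeds the rotary embedding.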
811
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
812
+ text_position_ids = position_ids[0]
813
+ position_ids = position_ids[1:]
814
+ else:
815
+ text_position_ids = position_ids[0]
816
+
817
+ # It may already have been prepared by e.g. `generate`
818
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
819
+ # Prepare mask arguments
820
+ mask_kwargs = {
821
+ "config": self.config,
822
+ "input_embeds": inputs_embeds,
823
+ "attention_mask": attention_mask,
824
+ "cache_position": cache_position,
825
+ "past_key_values": past_key_values,
826
+ "position_ids": text_position_ids,
827
+ }
828
+ # Create the masks
829
+ causal_mask_mapping = {
830
+ "full_attention": create_causal_mask(**mask_kwargs),
831
+ }
832
+ # The sliding window alternating layers are not always activated depending on the config
833
+ if self.has_sliding_layers:
834
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
835
+
836
+ hidden_states = inputs_embeds
837
+
838
+ # create position embeddings to be shared across the decoder layers
839
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
840
+
841
+ # decoder layers
842
+ all_hidden_states = () if output_hidden_states else None
843
+ all_self_attns = () if output_attentions else None
844
+ # print(1)
845
+ for decoder_layer in self.layers:
846
+ if output_hidden_states:
847
+ all_hidden_states += (hidden_states,)
848
+
849
+ layer_outputs = decoder_layer(
850
+ hidden_states,
851
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
852
+ position_ids=text_position_ids,
853
+ past_key_values=past_key_values,
854
+ output_attentions=output_attentions,
855
+ use_cache=use_cache,
856
+ cache_position=cache_position,
857
+ position_embeddings=position_embeddings,
858
+ **kwargs,
859
+ )
860
+
861
+ hidden_states = layer_outputs[0]
862
+
863
+ if output_attentions:
864
+ all_self_attns += (layer_outputs[1],)
865
+
866
+ hidden_states = self.norm(hidden_states)
867
+
868
+ # add hidden states from the last decoder layer
869
+ if output_hidden_states:
870
+ all_hidden_states += (hidden_states,)
871
+
872
+ if not return_dict:
873
+ return tuple(
874
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
875
+ )
876
+ return BaseModelOutputWithPast(
877
+ last_hidden_state=hidden_states,
878
+ past_key_values=past_key_values,
879
+ hidden_states=all_hidden_states,
880
+ attentions=all_self_attns,
881
+ )
882
+
883
+
884
+ @auto_docstring
885
+ class Qwen2VLModel(Qwen2VLPreTrainedModel):
886
+ base_model_prefix = ""
887
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
888
+ accepts_loss_kwargs = False
889
+
890
+ def __init__(self, config: Qwen2VLConfig):
891
+ super().__init__(config)
892
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
893
+ self.language_model = Qwen2VLTextModel._from_config(config.text_config)
894
+ self.rope_deltas = None # cache rope_deltas here
895
+
896
+ # Initialize weights and apply final processing
897
+ self.post_init()
898
+
899
+ def get_input_embeddings(self):
900
+ return self.language_model.get_input_embeddings()
901
+
902
+ def set_input_embeddings(self, value):
903
+ self.language_model.set_input_embeddings(value)
904
+
905
+ def set_decoder(self, decoder):
906
+ self.language_model = decoder
907
+
908
+ def get_decoder(self):
909
+ return self.language_model
910
+
911
+ def get_rope_index(
912
+ self,
913
+ input_ids: Optional[torch.LongTensor] = None,
914
+ image_grid_thw: Optional[torch.LongTensor] = None,
915
+ video_grid_thw: Optional[torch.LongTensor] = None,
916
+ attention_mask: Optional[torch.Tensor] = None,
917
+ ) -> tuple[torch.Tensor, torch.Tensor]:
918
+ """
919
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
920
+
921
+ Explanation:
922
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
923
+
924
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
925
+ Examples:
926
+ input_ids: [T T T T T], here T is for text.
927
+ temporal position_ids: [0, 1, 2, 3, 4]
928
+ height position_ids: [0, 1, 2, 3, 4]
929
+ width position_ids: [0, 1, 2, 3, 4]
930
+
931
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
932
+ and 1D rotary position embedding for text part.
933
+ Examples:
934
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
935
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
936
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
937
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
938
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
939
+ text temporal position_ids: [3, 4, 5, 6, 7]
940
+ text height position_ids: [3, 4, 5, 6, 7]
941
+ text width position_ids: [3, 4, 5, 6, 7]
942
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
943
+
944
+ Args:
945
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
946
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
947
+ it.
948
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
949
+ The temporal, height and width of feature shape of each image in LLM.
950
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
951
+ The temporal, height and width of feature shape of each video in LLM.
952
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
953
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
954
+
955
+ - 1 for tokens that are **not masked**,
956
+ - 0 for tokens that are **masked**.
957
+
958
+ Returns:
959
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
960
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
961
+ """
962
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
963
+ image_token_id = self.config.image_token_id
964
+ video_token_id = self.config.video_token_id
965
+ vision_start_token_id = self.config.vision_start_token_id
966
+ mrope_position_deltas = []
967
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
968
+ total_input_ids = input_ids
969
+ if attention_mask is None:
970
+ attention_mask = torch.ones_like(total_input_ids)
971
+ position_ids = torch.ones(
972
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
973
+ )
974
+ image_index, video_index = 0, 0
975
+ for i, input_ids in enumerate(total_input_ids):
976
+ input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
977
+ image_nums, video_nums = 0, 0
978
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
979
+ vision_tokens = input_ids[vision_start_indices + 1]
980
+ image_nums = (vision_tokens == image_token_id).sum()
981
+ video_nums = (vision_tokens == video_token_id).sum()
982
+ input_tokens = input_ids.tolist()
983
+ llm_pos_ids_list: list = []
984
+ st = 0
985
+ remain_images, remain_videos = image_nums, video_nums
986
+ for _ in range(image_nums + video_nums):
987
+ if image_token_id in input_tokens and remain_images > 0:
988
+ ed_image = input_tokens.index(image_token_id, st)
989
+ else:
990
+ ed_image = len(input_tokens) + 1
991
+ if video_token_id in input_tokens and remain_videos > 0:
992
+ ed_video = input_tokens.index(video_token_id, st)
993
+ else:
994
+ ed_video = len(input_tokens) + 1
995
+ if ed_image < ed_video:
996
+ t, h, w = (
997
+ image_grid_thw[image_index][0],
998
+ image_grid_thw[image_index][1],
999
+ image_grid_thw[image_index][2],
1000
+ )
1001
+ image_index += 1
1002
+ remain_images -= 1
1003
+ ed = ed_image
1004
+ else:
1005
+ t, h, w = (
1006
+ video_grid_thw[video_index][0],
1007
+ video_grid_thw[video_index][1],
1008
+ video_grid_thw[video_index][2],
1009
+ )
1010
+ video_index += 1
1011
+ remain_videos -= 1
1012
+ ed = ed_video
1013
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1014
+ t.item(),
1015
+ h.item() // spatial_merge_size,
1016
+ w.item() // spatial_merge_size,
1017
+ )
1018
+ text_len = ed - st
1019
+
1020
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1021
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1022
+
1023
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1024
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1025
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1026
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1027
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1028
+
1029
+ if st < len(input_tokens):
1030
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1031
+ text_len = len(input_tokens) - st
1032
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1033
+
1034
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1035
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1036
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1037
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1038
+ return position_ids, mrope_position_deltas
1039
+ else:
1040
+ if attention_mask is not None:
1041
+ position_ids = attention_mask.long().cumsum(-1) - 1
1042
+ position_ids.masked_fill_(attention_mask == 0, 1)
1043
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1044
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1045
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1046
+ else:
1047
+ position_ids = (
1048
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1049
+ .view(1, 1, -1)
1050
+ .expand(3, input_ids.shape[0], -1)
1051
+ )
1052
+ mrope_position_deltas = torch.zeros(
1053
+ [input_ids.shape[0], 1],
1054
+ device=input_ids.device,
1055
+ dtype=input_ids.dtype,
1056
+ )
1057
+
1058
+ return position_ids, mrope_position_deltas
1059
+
1060
+ def get_video_features(
1061
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1062
+ ):
1063
+ """
1064
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
1065
+
1066
+ Args:
1067
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1068
+ The tensors corresponding to the input videos.
1069
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1070
+ The temporal, height and width of feature shape of each video in LLM.
1071
+ """
1072
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
1073
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
1074
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1075
+ video_embeds = torch.split(video_embeds, split_sizes)
1076
+ return video_embeds
1077
+
1078
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1079
+ """
1080
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1081
+
1082
+ Args:
1083
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1084
+ The tensors corresponding to the input images.
1085
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1086
+ The temporal, height and width of feature shape of each image in LLM.
1087
+ """
1088
+ pixel_values = pixel_values.type(self.visual.dtype)
1089
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
1090
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1091
+ image_embeds = torch.split(image_embeds, split_sizes)
1092
+ return image_embeds
1093
+
1094
+ def get_placeholder_mask(
1095
+ self,
1096
+ input_ids: torch.LongTensor,
1097
+ inputs_embeds: torch.FloatTensor,
1098
+ image_features: torch.FloatTensor = None,
1099
+ video_features: torch.FloatTensor = None,
1100
+ ):
1101
+ """
1102
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1103
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1104
+ """
1105
+ if input_ids is None:
1106
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1107
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1108
+ )
1109
+ special_image_mask = special_image_mask.all(-1)
1110
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
1111
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
1112
+ )
1113
+ special_video_mask = special_video_mask.all(-1)
1114
+ else:
1115
+ special_image_mask = input_ids == self.config.image_token_id
1116
+ special_video_mask = input_ids == self.config.video_token_id
1117
+
1118
+ n_image_tokens = special_image_mask.sum()
1119
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1120
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
1121
+ raise ValueError(
1122
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
1123
+ )
1124
+
1125
+ n_video_tokens = special_video_mask.sum()
1126
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1127
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
1128
+ raise ValueError(
1129
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
1130
+ )
1131
+
1132
+ return special_image_mask, special_video_mask
1133
+
1134
+ @auto_docstring
1135
+ def forward(
1136
+ self,
1137
+ input_ids: torch.LongTensor = None,
1138
+ attention_mask: Optional[torch.Tensor] = None,
1139
+ position_ids: Optional[torch.LongTensor] = None,
1140
+ past_key_values: Optional[Cache] = None,
1141
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1142
+ use_cache: Optional[bool] = None,
1143
+ output_attentions: Optional[bool] = None,
1144
+ output_hidden_states: Optional[bool] = None,
1145
+ return_dict: Optional[bool] = None,
1146
+ pixel_values: Optional[torch.Tensor] = None,
1147
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1148
+ image_grid_thw: Optional[torch.LongTensor] = None,
1149
+ video_grid_thw: Optional[torch.LongTensor] = None,
1150
+ rope_deltas: Optional[torch.LongTensor] = None,
1151
+ cache_position: Optional[torch.LongTensor] = None,
1152
+ seg_mask: Optional[torch.Tensor] = None,
1153
+ **kwargs: Unpack[TransformersKwargs],
1154
+ ) -> Union[tuple, Qwen2VLModelOutputWithPast]:
1155
+ r"""
1156
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1157
+ The temporal, height and width of feature shape of each image in LLM.
1158
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1159
+ The temporal, height and width of feature shape of each video in LLM.
1160
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1161
+ The rope index difference between sequence length and multimodal rope.
1162
+ """
1163
+
1164
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1165
+ output_hidden_states = (
1166
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1167
+ )
1168
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1169
+
1170
+ if inputs_embeds is None:
1171
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1172
+
1173
+ if pixel_values is not None:
1174
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
1175
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1176
+ image_mask, _ = self.get_placeholder_mask(
1177
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1178
+ )
1179
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1180
+
1181
+ if pixel_values_videos is not None:
1182
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1183
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1184
+ _, video_mask = self.get_placeholder_mask(
1185
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1186
+ )
1187
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1188
+
1189
+ if position_ids is None:
1190
+ if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
1191
+ position_ids, rope_deltas = self.get_rope_index(
1192
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
1193
+ )
1194
+ self.rope_deltas = rope_deltas
1195
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1196
+ else:
1197
+ batch_size, seq_length, _ = inputs_embeds.shape
1198
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1199
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1200
+ if cache_position is not None:
1201
+ delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1202
+ else:
1203
+ delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
1204
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1205
+ position_ids += delta.to(position_ids.device)
1206
+
1207
+ ### CHANGE
1208
+ if seg_mask is not None:
1209
+ (
1210
+ attention_mask,
1211
+ final_position_ids,
1212
+ final_past_key_values,
1213
+ final_use_cache,
1214
+ final_cache_position,
1215
+ ) = self._create_hybrid_mask_and_dependencies(
1216
+ seg_mask,
1217
+ inputs_embeds,
1218
+ attention_mask,
1219
+ position_ids,
1220
+ **kwargs,
1221
+ )
1222
+
1223
+ if past_key_values is not None:
1224
+ inputs_embeds = inputs_embeds[seg_mask == 1].unsqueeze(0)
1225
+ position_ids = position_ids[:, seg_mask == 1].unsqueeze(1)
1226
+ attention_mask = attention_mask[:, seg_mask == 1].unsqueeze(1)
1227
+ # attention_mask = attention_mask.unsqueeze(-2)[:, :, :, seg_mask == 1]
1228
+
1229
+ attention_mask = {'full_attention': attention_mask}
1230
+
1231
+
1232
+ # print(3)
1233
+ # print(2)
1234
+ ###############
1235
+ outputs = self.language_model(
1236
+ input_ids=None,
1237
+ position_ids=position_ids,
1238
+ attention_mask=attention_mask,
1239
+ past_key_values=past_key_values,
1240
+ inputs_embeds=inputs_embeds,
1241
+ use_cache=use_cache,
1242
+ output_attentions=output_attentions,
1243
+ output_hidden_states=output_hidden_states,
1244
+ return_dict=True,
1245
+ cache_position=cache_position,
1246
+ **kwargs,
1247
+ )
1248
+
1249
+ output = Qwen2VLModelOutputWithPast(
1250
+ last_hidden_state=outputs.last_hidden_state,
1251
+ past_key_values=outputs.past_key_values,
1252
+ hidden_states=outputs.hidden_states,
1253
+ attentions=outputs.attentions,
1254
+ rope_deltas=self.rope_deltas,
1255
+ )
1256
+ return output if return_dict else output.to_tuple()
1257
+
1258
+
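# Illustrative sketch (toy shapes assumed, not part of the committed patch): during
# cached decoding with seg_mask set, only the segmentation placeholder rows are forwarded.
# With inputs_embeds of shape [1, 6, hidden] and seg_mask = [[0, 0, 1, 1, 1, 0]]:
#   inputs_embeds[seg_mask == 1].unsqueeze(0)   -> shape [1, 3, hidden]
#   position_ids[:, seg_mask == 1].unsqueeze(1) -> shape [3, 1, 3]
# The precomputed hybrid mask is then wrapped as {'full_attention': ...} so the text
# model reuses it instead of rebuilding its own causal mask mapping.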
1259
+ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
1260
+ _checkpoint_conversion_mapping = {
1261
+ "^visual": "model.visual",
1262
+ r"^model(?!\.(language_model|visual))": "model.language_model",
1263
+ }
1264
+ _tied_weights_keys = ["lm_head.weight"]
1265
+
1266
+ def __init__(self, config):
1267
+ super().__init__(config)
1268
+ self.model = Qwen2VLModel(config)
1269
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1270
+
1271
+ self.post_init()
1272
+
1273
+ def get_input_embeddings(self):
1274
+ return self.model.get_input_embeddings()
1275
+
1276
+ def set_input_embeddings(self, value):
1277
+ self.model.set_input_embeddings(value)
1278
+
1279
+ def set_decoder(self, decoder):
1280
+ self.model.set_decoder(decoder)
1281
+
1282
+ def get_decoder(self):
1283
+ return self.model.get_decoder()
1284
+
1285
+ def get_video_features(
1286
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1287
+ ):
1288
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
1289
+
1290
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1291
+ return self.model.get_image_features(pixel_values, image_grid_thw)
1292
+
1293
+ # Make modules available through conditional class for BC
1294
+ @property
1295
+ def language_model(self):
1296
+ return self.model.language_model
1297
+
1298
+ @property
1299
+ def visual(self):
1300
+ return self.model.visual
1301
+
1302
+ @can_return_tuple
1303
+ @auto_docstring
1304
+ def forward(
1305
+ self,
1306
+ input_ids: torch.LongTensor = None,
1307
+ attention_mask: Optional[torch.Tensor] = None,
1308
+ position_ids: Optional[torch.LongTensor] = None,
1309
+ past_key_values: Optional[Cache] = None,
1310
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1311
+ labels: Optional[torch.LongTensor] = None,
1312
+ use_cache: Optional[bool] = None,
1313
+ output_attentions: Optional[bool] = None,
1314
+ output_hidden_states: Optional[bool] = None,
1315
+ pixel_values: Optional[torch.Tensor] = None,
1316
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1317
+ image_grid_thw: Optional[torch.LongTensor] = None,
1318
+ video_grid_thw: Optional[torch.LongTensor] = None,
1319
+ rope_deltas: Optional[torch.LongTensor] = None,
1320
+ cache_position: Optional[torch.LongTensor] = None,
1321
+ **kwargs: Unpack[TransformersKwargs],
1322
+ ) -> Union[tuple, Qwen2VLCausalLMOutputWithPast]:
1323
+
1324
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1325
+ output_hidden_states = (
1326
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1327
+ )
1328
+
1329
+ outputs = self.model(
1330
+ input_ids=input_ids,
1331
+ pixel_values=pixel_values,
1332
+ pixel_values_videos=pixel_values_videos,
1333
+ image_grid_thw=image_grid_thw,
1334
+ video_grid_thw=video_grid_thw,
1335
+ position_ids=position_ids,
1336
+ attention_mask=attention_mask,
1337
+ past_key_values=past_key_values,
1338
+ inputs_embeds=inputs_embeds,
1339
+ use_cache=use_cache,
1340
+ output_attentions=output_attentions,
1341
+ output_hidden_states=output_hidden_states,
1342
+ return_dict=True,
1343
+ cache_position=cache_position,
1344
+ **kwargs,
1345
+ )
1346
+
1347
+ hidden_states = outputs[0]
1348
+ logits = self.lm_head(hidden_states)
1349
+
1350
+ loss = None
1351
+ if labels is not None:
1352
+ loss = self.loss_function(
1353
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1354
+ )
1355
+
1356
+ return Qwen2VLCausalLMOutputWithPast(
1357
+ loss=loss,
1358
+ logits=logits,
1359
+ past_key_values=outputs.past_key_values,
1360
+ hidden_states=outputs.hidden_states,
1361
+ attentions=outputs.attentions,
1362
+ rope_deltas=outputs.rope_deltas,
1363
+ )
1364
+
1365
+ def prepare_inputs_for_generation(
1366
+ self,
1367
+ input_ids,
1368
+ past_key_values=None,
1369
+ attention_mask=None,
1370
+ inputs_embeds=None,
1371
+ cache_position=None,
1372
+ position_ids=None,
1373
+ use_cache=True,
1374
+ pixel_values=None,
1375
+ pixel_values_videos=None,
1376
+ image_grid_thw=None,
1377
+ video_grid_thw=None,
1378
+ **kwargs,
1379
+ ):
1380
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1381
+
1382
+ model_inputs = super().prepare_inputs_for_generation(
1383
+ input_ids,
1384
+ past_key_values=past_key_values,
1385
+ attention_mask=attention_mask,
1386
+ inputs_embeds=inputs_embeds,
1387
+ cache_position=cache_position,
1388
+ position_ids=position_ids,
1389
+ pixel_values=pixel_values,
1390
+ pixel_values_videos=pixel_values_videos,
1391
+ image_grid_thw=image_grid_thw,
1392
+ video_grid_thw=video_grid_thw,
1393
+ use_cache=use_cache,
1394
+ **kwargs,
1395
+ )
1396
+
1397
+ # Qwen2-VL position_ids are prepared with rope_deltas in forward
1398
+ if position_ids is None:
1399
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1400
+ # When compiling, we can't check tensor values thus we check only input length
1401
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
1402
+ # models currently cannot do assisted decoding
1403
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
1404
+ (input_ids is not None and input_ids.shape[1] != 1)
1405
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
1406
+ )
1407
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
1408
+ (cache_position is not None and cache_position[0] == 0)
1409
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
1410
+ )
1411
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
1412
+ vision_positions, rope_deltas = self.model.get_rope_index(
1413
+ model_inputs.get("input_ids", None),
1414
+ image_grid_thw=image_grid_thw,
1415
+ video_grid_thw=video_grid_thw,
1416
+ attention_mask=attention_mask,
1417
+ )
1418
+ self.model.rope_deltas = rope_deltas
1419
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1420
+ elif "position_ids" in model_inputs:
1421
+ position_ids = model_inputs["position_ids"][None, ...]
1422
+ delta = self.model.rope_deltas
1423
+ delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
1424
+ vision_positions = position_ids + delta.expand_as(position_ids)
1425
+ vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
1426
+
1427
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
1428
+ if "position_ids" not in model_inputs:
1429
+ text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[None, None, :]
1430
+ else:
1431
+ text_positions = model_inputs["position_ids"][None, ...]
1432
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
1433
+
1434
+ if model_inputs["cache_position"][0] != 0:
1435
+ model_inputs["pixel_values"] = None
1436
+ model_inputs["pixel_values_videos"] = None
1437
+
1438
+ return model_inputs
1439
+
1440
+ def _get_image_nums_and_video_nums(
1441
+ self,
1442
+ input_ids: Optional[torch.LongTensor],
1443
+ inputs_embeds: Optional[torch.Tensor] = None,
1444
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1445
+ """
1446
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
1447
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
1448
+
1449
+ Args:
1450
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1451
+ Indices of input sequence tokens in the vocabulary.
1452
+
1453
+ Returns:
1454
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
1455
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
1456
+ """
1457
+ image_token_id = self.config.image_token_id
1458
+ video_token_id = self.config.video_token_id
1459
+ vision_start_token_id = self.config.vision_start_token_id
1460
+
1461
+ if inputs_embeds is not None:
1462
+ vision_start_mask = (
1463
+ inputs_embeds
1464
+ == self.get_input_embeddings()(
1465
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1466
+ )
1467
+ )[..., 0]
1468
+ image_mask = (
1469
+ inputs_embeds
1470
+ == self.get_input_embeddings()(
1471
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
1472
+ )
1473
+ )[..., 0]
1474
+ video_mask = (
1475
+ inputs_embeds
1476
+ == self.get_input_embeddings()(
1477
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
1478
+ )
1479
+ )[..., 0]
1480
+ else:
1481
+ vision_start_mask = input_ids == vision_start_token_id
1482
+ image_mask = input_ids == image_token_id
1483
+ video_mask = input_ids == video_token_id
1484
+
1485
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
1486
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
1487
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
1488
+
1489
+ return image_nums, video_nums
1490
+
1491
+ def _expand_inputs_for_generation(
1492
+ self,
1493
+ expand_size: int = 1,
1494
+ is_encoder_decoder: bool = False,
1495
+ input_ids: Optional[torch.LongTensor] = None,
1496
+ **model_kwargs,
1497
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1498
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1499
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
1500
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1501
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1502
+
1503
+ if expand_size == 1:
1504
+ return input_ids, model_kwargs
1505
+
1506
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
1507
+
1508
+ def _expand_dict_for_generation_visual(dict_to_expand):
1509
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1510
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
1511
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
1512
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
1513
+ )
1514
+
1515
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1516
+ samples = torch.split(x, lengths)
1517
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1518
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1519
+ return result
1520
+
1521
+ for key in dict_to_expand:
1522
+ if key == "pixel_values":
1523
+ # split images into samples
1524
+ samples = torch.split(image_grid_thw, list(image_nums))
1525
+ # compute the sequence length of images for each sample
1526
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1527
+ dict_to_expand[key] = _repeat_interleave_samples(
1528
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1529
+ )
1530
+ elif key == "image_grid_thw":
1531
+ # get the num of images for each sample
1532
+ lengths = list(image_nums)
1533
+ dict_to_expand[key] = _repeat_interleave_samples(
1534
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1535
+ )
1536
+ elif key == "pixel_values_videos":
1537
+ samples = torch.split(video_grid_thw, list(video_nums))
1538
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1539
+ dict_to_expand[key] = _repeat_interleave_samples(
1540
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1541
+ )
1542
+ elif key == "video_grid_thw":
1543
+ lengths = list(video_nums)
1544
+ dict_to_expand[key] = _repeat_interleave_samples(
1545
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1546
+ )
1547
+ elif key == "second_per_grid_ts":
1548
+ dict_to_expand[key] = _repeat_interleave_samples(
1549
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1550
+ )
1551
+ return dict_to_expand
1552
+
1553
+ def _expand_dict_for_generation(dict_to_expand):
1554
+ for key in dict_to_expand:
1555
+ if (
1556
+ key != "cache_position"
1557
+ and dict_to_expand[key] is not None
1558
+ and isinstance(dict_to_expand[key], torch.Tensor)
1559
+ and key not in visual_keys
1560
+ ):
1561
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1562
+ return dict_to_expand
1563
+
1564
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1565
+
1566
+ if input_ids is not None:
1567
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1568
+
1569
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1570
+
1571
+ if is_encoder_decoder:
1572
+ if model_kwargs.get("encoder_outputs") is None:
1573
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1574
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1575
+
1576
+ return input_ids, model_kwargs
1577
+
1578
+
1579
+ __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]
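The 3D rope index logic above is easiest to see on a toy grid. Below is a minimal, self-contained sketch (assumed sizes, not part of the committed file) that reproduces the vision index construction from `get_rope_index` for a single t=3, h=2, w=2 block:

import torch

llm_grid_t, llm_grid_h, llm_grid_w = 3, 2, 2  # assumed toy grid after spatial merging
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
vision_positions = torch.stack([t_index, h_index, w_index])  # shape [3, 12]
# t: 0 0 0 0 1 1 1 1 2 2 2 2
# h: 0 0 1 1 0 0 1 1 0 0 1 1
# w: 0 1 0 1 0 1 0 1 0 1 0 1
# In the model these indices are shifted by text_len + st_idx before being concatenated
# with the 1D text positions.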
model/qwen_changes.py ADDED
@@ -0,0 +1,433 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Callable
3
+
4
+ import torch
5
+ from torch import nn
6
+ from transformers import DynamicCache
7
+ from .modeling_qwen2_vl import Qwen2VLForConditionalGeneration
8
+ from transformers.masking_utils import create_causal_mask
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
+ def replace_token_pair_vectorized(
13
+ input_ids: torch.Tensor,
14
+ seg_start_token_id: int,
15
+ seg_holder_token_id: int,
16
+ vision_start_token_id: int,
17
+ image_token_id: int,
18
+ ) -> torch.Tensor:
19
+ modified_ids = input_ids.clone()
20
+
21
+ # Create aligned views of the current and next tokens
22
+ current_tokens = modified_ids[..., :-1]
23
+ next_tokens = modified_ids[..., 1:]
24
+
25
+ # parallel find all positions where (current == start) & (next == holder)
26
+ mask = (current_tokens == seg_start_token_id) & (next_tokens == seg_holder_token_id)
27
+
28
+ # Use the mask to perform all replacements at once, in parallel
29
+ modified_ids[..., :-1][mask] = vision_start_token_id
30
+ modified_ids[seg_holder_token_id == modified_ids] = image_token_id
31
+
32
+ return modified_ids, mask.sum()
33
+
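# Illustrative example with hypothetical token ids (not the model's real vocabulary):
# seg_start=7, seg_holder=8, vision_start=5, image=6.
#   replace_token_pair_vectorized(torch.tensor([[1, 7, 8, 8, 2]]), 7, 8, 5, 6)
#   -> (tensor([[1, 5, 6, 6, 2]]), tensor(1))
# The seg-start token becomes a vision-start token and every holder token becomes an
# image token, so get_rope_index below treats the segmentation span like one more image.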
34
+ import torch
35
+
36
+ def get_rope_index(
37
+ self,
38
+ input_ids: Optional[torch.LongTensor] = None,
39
+ image_grid_thw: Optional[torch.LongTensor] = None,
40
+ video_grid_thw: Optional[torch.LongTensor] = None,
41
+ attention_mask: Optional[torch.Tensor] = None,
42
+ seg_start_token_id: Optional[int] = None,
43
+ seg_holder_token_id: Optional[int] = None,
44
+ ) -> tuple[torch.Tensor, torch.Tensor]:
45
+
46
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
47
+ image_token_id = self.config.image_token_id
48
+ video_token_id = self.config.video_token_id
49
+ vision_start_token_id = self.config.vision_start_token_id
50
+
51
+ input_ids = input_ids.clone()
52
+ if seg_start_token_id is not None and seg_holder_token_id is not None:
53
+ input_ids, num = replace_token_pair_vectorized(input_ids, seg_start_token_id, seg_holder_token_id,
54
+ vision_start_token_id, image_token_id)
55
+ mask_grid_thw = image_grid_thw[-1].clone()
56
+ mask_grid_thw = mask_grid_thw.unsqueeze(0).repeat([num, 1])
57
+ image_grid_thw = torch.cat((image_grid_thw, mask_grid_thw), dim=0)
58
+
59
+ mrope_position_deltas = []
60
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
61
+ total_input_ids = input_ids
62
+ if attention_mask is None:
63
+ attention_mask = torch.ones_like(total_input_ids)
64
+ position_ids = torch.ones(
65
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
66
+ )
67
+ if isinstance(attention_mask, dict):
68
+ attention_mask = attention_mask['raw_attention']
69
+ image_index, video_index = 0, 0
70
+ for i, input_ids in enumerate(total_input_ids):
71
+ input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
72
+ image_nums, video_nums = 0, 0
73
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
74
+ vision_tokens = input_ids[vision_start_indices + 1]
75
+ image_nums = (vision_tokens == image_token_id).sum()
76
+ video_nums = (vision_tokens == video_token_id).sum()
77
+ input_tokens = input_ids.tolist()
78
+ llm_pos_ids_list: list = []
79
+ st = 0
80
+ remain_images, remain_videos = image_nums, video_nums
81
+ for _ in range(image_nums + video_nums):
82
+ if image_token_id in input_tokens and remain_images > 0:
83
+ ed_image = input_tokens.index(image_token_id, st)
84
+ else:
85
+ ed_image = len(input_tokens) + 1
86
+ if video_token_id in input_tokens and remain_videos > 0:
87
+ ed_video = input_tokens.index(video_token_id, st)
88
+ else:
89
+ ed_video = len(input_tokens) + 1
90
+ if ed_image < ed_video:
91
+ t, h, w = (
92
+ image_grid_thw[image_index][0],
93
+ image_grid_thw[image_index][1],
94
+ image_grid_thw[image_index][2],
95
+ )
96
+ image_index += 1
97
+ remain_images -= 1
98
+ ed = ed_image
99
+ else:
100
+ t, h, w = (
101
+ video_grid_thw[video_index][0],
102
+ video_grid_thw[video_index][1],
103
+ video_grid_thw[video_index][2],
104
+ )
105
+ video_index += 1
106
+ remain_videos -= 1
107
+ ed = ed_video
108
+ llm_grid_t, llm_grid_h, llm_grid_w = (
109
+ t.item(),
110
+ h.item() // spatial_merge_size,
111
+ w.item() // spatial_merge_size,
112
+ )
113
+ text_len = ed - st
114
+
115
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
116
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
117
+
118
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
119
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
120
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
121
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
122
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
123
+
124
+ if st < len(input_tokens):
125
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
126
+ text_len = len(input_tokens) - st
127
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
128
+
129
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
130
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
131
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
132
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
133
+ return position_ids, mrope_position_deltas
134
+ else:
135
+ if attention_mask is not None:
136
+ position_ids = attention_mask.long().cumsum(-1) - 1
137
+ position_ids.masked_fill_(attention_mask == 0, 1)
138
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
139
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
140
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
141
+ else:
142
+ position_ids = (
143
+ torch.arange(input_ids.shape[1], device=input_ids.device)
144
+ .view(1, 1, -1)
145
+ .expand(3, input_ids.shape[0], -1)
146
+ )
147
+ mrope_position_deltas = torch.zeros(
148
+ [input_ids.shape[0], 1],
149
+ device=input_ids.device,
150
+ dtype=input_ids.dtype,
151
+ )
152
+
153
+ return position_ids, mrope_position_deltas
154
+
155
+ def get_rope_index_2_5(
156
+ self,
157
+ input_ids: Optional[torch.LongTensor] = None,
158
+ image_grid_thw: Optional[torch.LongTensor] = None,
159
+ video_grid_thw: Optional[torch.LongTensor] = None,
160
+ second_per_grid_ts: Optional[torch.Tensor] = None,
161
+ attention_mask: Optional[torch.Tensor] = None,
162
+ seg_start_token_id: Optional[int] = None,
163
+ seg_holder_token_id: Optional[int] = None,
164
+ ) -> tuple[torch.Tensor, torch.Tensor]:
165
+
166
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
167
+ image_token_id = self.config.image_token_id
168
+ video_token_id = self.config.video_token_id
169
+ vision_start_token_id = self.config.vision_start_token_id
170
+ input_ids = input_ids.clone()
171
+ if seg_start_token_id is not None and seg_holder_token_id is not None:
172
+ input_ids, num = replace_token_pair_vectorized(input_ids, seg_start_token_id, seg_holder_token_id,
173
+ vision_start_token_id, image_token_id)
174
+ mask_grid_thw = image_grid_thw[-1].clone()
175
+ mask_grid_thw = mask_grid_thw.unsqueeze(0).repeat([num, 1])
176
+ image_grid_thw = torch.cat((image_grid_thw, mask_grid_thw), dim=0)
177
+
178
+ mrope_position_deltas = []
179
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
180
+ total_input_ids = input_ids
181
+ if attention_mask is None:
182
+ attention_mask = torch.ones_like(total_input_ids)
183
+ position_ids = torch.ones(
184
+ 3,
185
+ input_ids.shape[0],
186
+ input_ids.shape[1],
187
+ dtype=input_ids.dtype,
188
+ device=input_ids.device,
189
+ )
190
+ image_index, video_index = 0, 0
191
+ attention_mask = attention_mask.to(total_input_ids.device)
192
+ for i, input_ids in enumerate(total_input_ids):
193
+ input_ids = input_ids[attention_mask[i] == 1]
194
+ image_nums, video_nums = 0, 0
195
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
196
+ vision_tokens = input_ids[vision_start_indices + 1]
197
+ image_nums = (vision_tokens == image_token_id).sum()
198
+ video_nums = (vision_tokens == video_token_id).sum()
199
+ input_tokens = input_ids.tolist()
200
+ llm_pos_ids_list: list = []
201
+ st = 0
202
+ remain_images, remain_videos = image_nums, video_nums
203
+ for _ in range(image_nums + video_nums):
204
+ if image_token_id in input_tokens and remain_images > 0:
205
+ ed_image = input_tokens.index(image_token_id, st)
206
+ else:
207
+ ed_image = len(input_tokens) + 1
208
+ if video_token_id in input_tokens and remain_videos > 0:
209
+ ed_video = input_tokens.index(video_token_id, st)
210
+ else:
211
+ ed_video = len(input_tokens) + 1
212
+ if ed_image < ed_video:
213
+ t, h, w = (
214
+ image_grid_thw[image_index][0],
215
+ image_grid_thw[image_index][1],
216
+ image_grid_thw[image_index][2],
217
+ )
218
+ second_per_grid_t = 0
219
+ image_index += 1
220
+ remain_images -= 1
221
+ ed = ed_image
222
+
223
+ else:
224
+ t, h, w = (
225
+ video_grid_thw[video_index][0],
226
+ video_grid_thw[video_index][1],
227
+ video_grid_thw[video_index][2],
228
+ )
229
+ if second_per_grid_ts is not None:
230
+ second_per_grid_t = second_per_grid_ts[video_index]
231
+ else:
232
+ second_per_grid_t = 1.0
233
+ video_index += 1
234
+ remain_videos -= 1
235
+ ed = ed_video
236
+ llm_grid_t, llm_grid_h, llm_grid_w = (
237
+ t.item(),
238
+ h.item() // spatial_merge_size,
239
+ w.item() // spatial_merge_size,
240
+ )
241
+ text_len = ed - st
242
+
243
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
244
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
245
+
246
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
247
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
248
+
249
+ # Normalize the dtype and move it to the same device as range_tensor.
250
+ second_per_grid_t = torch.as_tensor(
251
+ second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
252
+ )
253
+
254
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
255
+
256
+ time_tensor_long = time_tensor.long()
257
+ t_index = time_tensor_long.flatten()
258
+
259
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
260
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
261
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
262
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
263
+
264
+ if st < len(input_tokens):
265
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
266
+ text_len = len(input_tokens) - st
267
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
268
+
269
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
270
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
271
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
272
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
273
+ return position_ids, mrope_position_deltas
274
+ else:
275
+ if attention_mask is not None:
276
+ position_ids = attention_mask.long().cumsum(-1) - 1
277
+ position_ids.masked_fill_(attention_mask == 0, 1)
278
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
279
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
280
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
281
+ else:
282
+ position_ids = (
283
+ torch.arange(input_ids.shape[1], device=input_ids.device)
284
+ .view(1, 1, -1)
285
+ .expand(3, input_ids.shape[0], -1)
286
+ )
287
+ mrope_position_deltas = torch.zeros(
288
+ [input_ids.shape[0], 1],
289
+ device=input_ids.device,
290
+ dtype=input_ids.dtype,
291
+ )
292
+
293
+ return position_ids, mrope_position_deltas
294
+
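# Illustrative note with assumed numbers (not part of the committed patch): in
# get_rope_index_2_5 the temporal index is scaled by real time. With llm_grid_t=3,
# llm_grid_h=llm_grid_w=2, second_per_grid_t=1.0 and tokens_per_second=2,
# time_tensor = expanded_range * 1.0 * 2, so
# t_index = [0, 0, 0, 0, 2, 2, 2, 2, 4, 4, 4, 4]; temporal positions advance with
# elapsed video time rather than with the raw frame count.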
295
+ @dataclass
296
+ class CustomModelOutput(ModelOutput):
297
+ loss: Optional[torch.FloatTensor] = None
298
+ logits: torch.FloatTensor = None
299
+ bi_logits: Optional[torch.FloatTensor] = None
300
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
301
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
302
+
303
+
304
+ import torch
305
+
306
+
307
+ def create_bidirectional_lookup_function(seg_mask_tensor: torch.Tensor) -> Callable:
308
+
309
+ def lookup_function(batch_idx, head_idx, q_idx, kv_idx) -> bool:
310
+ is_query_in_seg = seg_mask_tensor[batch_idx, q_idx]
311
+
312
+ return is_query_in_seg
313
+
314
+ return lookup_function
315
+
316
+ def _create_hybrid_mask_and_dependencies(
317
+ self,
318
+ seg_mask: torch.Tensor,
319
+ inputs_embeds: torch.Tensor,
320
+ attention_mask: torch.Tensor,
321
+ position_ids: torch.Tensor,
322
+ **kwargs,
323
+ ):
324
+
325
+
326
+ bidirectional_mask_fn = create_bidirectional_lookup_function(seg_mask)
327
+
328
+ use_cache = kwargs.get('use_cache', None)
329
+ if self.is_gradient_checkpointing and self.training:
330
+ if use_cache:
331
+ use_cache = False
332
+
333
+ past_key_values = kwargs.get('past_key_values', None)
334
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
335
+ past_key_values = DynamicCache(config=self.config)
336
+
337
+ cache_position = kwargs.get('cache_position', None)
338
+ if cache_position is None:
339
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
340
+ cache_position = torch.arange(
341
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
342
+ )
343
+
344
+ if position_ids is None:
345
+ local_position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
346
+ elif position_ids.ndim == 2:
347
+ local_position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
348
+ else:
349
+ local_position_ids = position_ids
350
+
351
+ if local_position_ids.ndim == 3 and local_position_ids.shape[0] == 4:
352
+ text_position_ids = local_position_ids[0]
353
+ final_position_ids = local_position_ids[1:]
354
+ else:
355
+ text_position_ids = local_position_ids[0]
356
+ final_position_ids = position_ids
357
+
358
+ mask_kwargs = {
359
+ "config": self.config,
360
+ "input_embeds": inputs_embeds,
361
+ "attention_mask": attention_mask,
362
+ "cache_position": cache_position,
363
+ "past_key_values": past_key_values,
364
+ "position_ids": text_position_ids,
365
+ "or_mask_function": bidirectional_mask_fn,
366
+ }
367
+ hybrid_attention_mask = create_causal_mask(**mask_kwargs)
368
+
369
+ return hybrid_attention_mask, final_position_ids, past_key_values, use_cache, cache_position
370
+
371
+ class SegQwenVL(Qwen2VLForConditionalGeneration):
372
+ def __init__(self, config):
373
+ super().__init__(config)
374
+ self.classifier = nn.Linear(config.hidden_size, 1)
375
+ self.model._create_hybrid_mask_and_dependencies = _create_hybrid_mask_and_dependencies.__get__(self)
376
+ self.model.get_rope_index = get_rope_index.__get__(self)
377
+
378
+ def forward(self, input_ids: torch.LongTensor = None, attention_mask: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
379
+ position_ids=None, labels: torch.LongTensor = None, do_classification: bool=False, output_hidden_states=False, **kwargs,):
380
+
381
+ if do_classification:
382
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
383
+ image_embeds = self.model.get_image_features(pixel_values, kwargs['image_grid_thw'])
384
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
385
+ image_mask, _ = self.model.get_placeholder_mask(
386
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
387
+ )
388
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
389
+ seg_mask = (input_ids == self.mask_token_id)
390
+
391
+ inputs_embeds[seg_mask] = inputs_embeds[seg_mask] + image_embeds[-seg_mask.sum():]
392
+
393
+ outputs = self.model(
394
+ input_ids=input_ids,
395
+ inputs_embeds=inputs_embeds,
396
+ attention_mask=attention_mask,
397
+ pixel_values=None,
398
+ output_hidden_states=True,
399
+ position_ids=position_ids,
400
+ seg_mask=seg_mask,
401
+ **kwargs,
402
+ )
403
+ last_hidden_state = outputs.hidden_states[-1]
404
+ logits = self.classifier(last_hidden_state)
405
+
406
+ return CustomModelOutput(
407
+ bi_logits=logits,
408
+ # hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
411
+
412
+ else:
413
+ if labels is not None:
414
+ output_hidden_states = True
415
+
416
+ original_output = super().forward(
417
+ input_ids=input_ids,
418
+ attention_mask=attention_mask,
419
+ pixel_values=pixel_values,
420
+ labels=labels,
421
+ output_hidden_states=output_hidden_states,
422
+ position_ids=position_ids,
423
+ **kwargs,
424
+ )
425
+ if labels is not None:
426
+ last_hidden_state = original_output.hidden_states[-1]
427
+ dummy_logits = self.classifier(last_hidden_state)
428
+ if hasattr(original_output, 'loss') and original_output.loss is not None:
429
+ dummy_loss = dummy_logits[0, 0].sum() * 0.0
430
+ original_output.loss += dummy_loss
431
+
432
+ return original_output
433
+
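The `or_mask_function` hook used in `_create_hybrid_mask_and_dependencies` effectively OR-s a per-query rule into the causal mask: any query inside the segmentation span may attend to every key. A hand-rolled sketch of the resulting pattern (toy sizes and values assumed, written without the transformers masking utilities):

import torch

seq_len = 6
seg = torch.tensor([0, 0, 1, 1, 1, 0], dtype=torch.bool)            # 1 marks <seg> placeholders
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# allowed(q, k) = causal(q, k) OR seg[q]: segmentation queries see the full sequence,
# all other queries keep the usual lower-triangular pattern.
hybrid = causal | seg.unsqueeze(1)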
model/segment_anything/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+
2
+ from .build_sam import (
3
+ build_sam,
4
+ build_sam_vit_h,
5
+ build_sam_vit_l,
6
+ build_sam_vit_b,
7
+ sam_model_registry,
8
+ )
9
+ from .predictor import SamPredictor
10
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
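For completeness, a hedged usage sketch of the automatic mask generator exported here; the `sam` model and the `image_rgb` array are assumed to exist already:

from model.segment_anything import SamAutomaticMaskGenerator

# `sam` is an already-built Sam model and `image_rgb` an HxWx3 uint8 RGB numpy array
# (both assumed). The keyword values mirror the defaults of the class defined below.
mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=32, pred_iou_thresh=0.88)
masks = mask_generator.generate(image_rgb)
# Each record contains "segmentation", "area", "bbox", "predicted_iou",
# "point_coords", "stability_score" and "crop_box".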
model/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,296 @@
1
+
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from .modeling import Sam
10
+ from .predictor import SamPredictor
11
+ from .utils.amg import (
12
+ MaskData,
13
+ area_from_rle,
14
+ batch_iterator,
15
+ batched_mask_to_box,
16
+ box_xyxy_to_xywh,
17
+ build_all_layer_point_grids,
18
+ calculate_stability_score,
19
+ coco_encode_rle,
20
+ generate_crop_boxes,
21
+ is_box_near_crop_edge,
22
+ mask_to_rle_pytorch,
23
+ remove_small_regions,
24
+ rle_to_mask,
25
+ uncrop_boxes_xyxy,
26
+ uncrop_masks,
27
+ uncrop_points,
28
+ )
29
+
30
+
31
+ class SamAutomaticMaskGenerator:
32
+ def __init__(
33
+ self,
34
+ model: Sam,
35
+ points_per_side: Optional[int] = 32,
36
+ points_per_batch: int = 64,
37
+ pred_iou_thresh: float = 0.88,
38
+ stability_score_thresh: float = 0.95,
39
+ stability_score_offset: float = 1.0,
40
+ box_nms_thresh: float = 0.7,
41
+ crop_n_layers: int = 0,
42
+ crop_nms_thresh: float = 0.7,
43
+ crop_overlap_ratio: float = 512 / 1500,
44
+ crop_n_points_downscale_factor: int = 1,
45
+ point_grids: Optional[List[np.ndarray]] = None,
46
+ min_mask_region_area: int = 0,
47
+ output_mode: str = "binary_mask",
48
+ ) -> None:
49
+
50
+
51
+ assert (points_per_side is None) != (
52
+ point_grids is None
53
+ ), "Exactly one of points_per_side or point_grid must be provided."
54
+ if points_per_side is not None:
55
+ self.point_grids = build_all_layer_point_grids(
56
+ points_per_side,
57
+ crop_n_layers,
58
+ crop_n_points_downscale_factor,
59
+ )
60
+ elif point_grids is not None:
61
+ self.point_grids = point_grids
62
+ else:
63
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
64
+
65
+ assert output_mode in [
66
+ "binary_mask",
67
+ "uncompressed_rle",
68
+ "coco_rle",
69
+ ], f"Unknown output_mode {output_mode}."
70
+ if output_mode == "coco_rle":
71
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
72
+
73
+ if min_mask_region_area > 0:
74
+ import cv2 # type: ignore # noqa: F401
75
+
76
+ self.predictor = SamPredictor(model)
77
+ self.points_per_batch = points_per_batch
78
+ self.pred_iou_thresh = pred_iou_thresh
79
+ self.stability_score_thresh = stability_score_thresh
80
+ self.stability_score_offset = stability_score_offset
81
+ self.box_nms_thresh = box_nms_thresh
82
+ self.crop_n_layers = crop_n_layers
83
+ self.crop_nms_thresh = crop_nms_thresh
84
+ self.crop_overlap_ratio = crop_overlap_ratio
85
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
86
+ self.min_mask_region_area = min_mask_region_area
87
+ self.output_mode = output_mode
88
+
89
+ @torch.no_grad()
90
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
91
+
92
+
93
+ # Generate masks
94
+ mask_data = self._generate_masks(image)
95
+
96
+ # Filter small disconnected regions and holes in masks
97
+ if self.min_mask_region_area > 0:
98
+ mask_data = self.postprocess_small_regions(
99
+ mask_data,
100
+ self.min_mask_region_area,
101
+ max(self.box_nms_thresh, self.crop_nms_thresh),
102
+ )
103
+
104
+ # Encode masks
105
+ if self.output_mode == "coco_rle":
106
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
107
+ elif self.output_mode == "binary_mask":
108
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
109
+ else:
110
+ mask_data["segmentations"] = mask_data["rles"]
111
+
112
+ # Write mask records
113
+ curr_anns = []
114
+ for idx in range(len(mask_data["segmentations"])):
115
+ ann = {
116
+ "segmentation": mask_data["segmentations"][idx],
117
+ "area": area_from_rle(mask_data["rles"][idx]),
118
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
119
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
120
+ "point_coords": [mask_data["points"][idx].tolist()],
121
+ "stability_score": mask_data["stability_score"][idx].item(),
122
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
123
+ }
124
+ curr_anns.append(ann)
125
+
126
+ return curr_anns
127
+
128
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
129
+ orig_size = image.shape[:2]
130
+ crop_boxes, layer_idxs = generate_crop_boxes(
131
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
132
+ )
133
+
134
+ # Iterate over image crops
135
+ data = MaskData()
136
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
137
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
138
+ data.cat(crop_data)
139
+
140
+ # Remove duplicate masks between crops
141
+ if len(crop_boxes) > 1:
142
+ # Prefer masks from smaller crops
143
+ scores = 1 / box_area(data["crop_boxes"])
144
+ scores = scores.to(data["boxes"].device)
145
+ keep_by_nms = batched_nms(
146
+ data["boxes"].float(),
147
+ scores,
148
+ torch.zeros_like(data["boxes"][:, 0]), # categories
149
+ iou_threshold=self.crop_nms_thresh,
150
+ )
151
+ data.filter(keep_by_nms)
152
+
153
+ data.to_numpy()
154
+ return data
155
+
156
+ def _process_crop(
157
+ self,
158
+ image: np.ndarray,
159
+ crop_box: List[int],
160
+ crop_layer_idx: int,
161
+ orig_size: Tuple[int, ...],
162
+ ) -> MaskData:
163
+ # Crop the image and calculate embeddings
164
+ x0, y0, x1, y1 = crop_box
165
+ cropped_im = image[y0:y1, x0:x1, :]
166
+ cropped_im_size = cropped_im.shape[:2]
167
+ self.predictor.set_image(cropped_im)
168
+
169
+ # Get points for this crop
170
+ points_scale = np.array(cropped_im_size)[None, ::-1]
171
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
172
+
173
+ # Generate masks for this crop in batches
174
+ data = MaskData()
175
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
176
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
177
+ data.cat(batch_data)
178
+ del batch_data
179
+ self.predictor.reset_image()
180
+
181
+ # Remove duplicates within this crop.
182
+ keep_by_nms = batched_nms(
183
+ data["boxes"].float(),
184
+ data["iou_preds"],
185
+ torch.zeros_like(data["boxes"][:, 0]), # categories
186
+ iou_threshold=self.box_nms_thresh,
187
+ )
188
+ data.filter(keep_by_nms)
189
+
190
+ # Return to the original image frame
191
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
192
+ data["points"] = uncrop_points(data["points"], crop_box)
193
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
194
+
195
+ return data
196
+
197
+ def _process_batch(
198
+ self,
199
+ points: np.ndarray,
200
+ im_size: Tuple[int, ...],
201
+ crop_box: List[int],
202
+ orig_size: Tuple[int, ...],
203
+ ) -> MaskData:
204
+ orig_h, orig_w = orig_size
205
+
206
+ # Run model on this batch
207
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
208
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
209
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
210
+ masks, iou_preds, _ = self.predictor.predict_torch(
211
+ in_points[:, None, :],
212
+ in_labels[:, None],
213
+ multimask_output=True,
214
+ return_logits=True,
215
+ )
216
+
217
+ # Serialize predictions and store in MaskData
218
+ data = MaskData(
219
+ masks=masks.flatten(0, 1),
220
+ iou_preds=iou_preds.flatten(0, 1),
221
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
222
+ )
223
+ del masks
224
+
225
+ # Filter by predicted IoU
226
+ if self.pred_iou_thresh > 0.0:
227
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
228
+ data.filter(keep_mask)
229
+
230
+ # Calculate stability score
231
+ data["stability_score"] = calculate_stability_score(
232
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
233
+ )
234
+ if self.stability_score_thresh > 0.0:
235
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
236
+ data.filter(keep_mask)
237
+
238
+ # Threshold masks and calculate boxes
239
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
240
+ data["boxes"] = batched_mask_to_box(data["masks"])
241
+
242
+ # Filter boxes that touch crop boundaries
243
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
244
+ if not torch.all(keep_mask):
245
+ data.filter(keep_mask)
246
+
247
+ # Compress to RLE
248
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
249
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
250
+ del data["masks"]
251
+
252
+ return data
253
+
254
+ @staticmethod
255
+ def postprocess_small_regions(
256
+ mask_data: MaskData, min_area: int, nms_thresh: float
257
+ ) -> MaskData:
258
+
259
+ if len(mask_data["rles"]) == 0:
260
+ return mask_data
261
+
262
+ # Filter small disconnected regions and holes
263
+ new_masks = []
264
+ scores = []
265
+ for rle in mask_data["rles"]:
266
+ mask = rle_to_mask(rle)
267
+
268
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
269
+ unchanged = not changed
270
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
271
+ unchanged = unchanged and not changed
272
+
273
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
274
+ # Give score=0 to changed masks and score=1 to unchanged masks
275
+ # so NMS will prefer ones that didn't need postprocessing
276
+ scores.append(float(unchanged))
277
+
278
+ # Recalculate boxes and remove any new duplicates
279
+ masks = torch.cat(new_masks, dim=0)
280
+ boxes = batched_mask_to_box(masks)
281
+ keep_by_nms = batched_nms(
282
+ boxes.float(),
283
+ torch.as_tensor(scores),
284
+ torch.zeros_like(boxes[:, 0]), # categories
285
+ iou_threshold=nms_thresh,
286
+ )
287
+
288
+ # Only recalculate RLEs for masks that have changed
289
+ for i_mask in keep_by_nms:
290
+ if scores[i_mask] == 0.0:
291
+ mask_torch = masks[i_mask].unsqueeze(0)
292
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
293
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
294
+ mask_data.filter(keep_by_nms)
295
+
296
+ return mask_data
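A minimal usage sketch (not part of this commit) for the generator above, assuming the repo layout added here; the checkpoint path and image path are hypothetical.

import cv2
from model.segment_anything import sam_model_registry
from model.segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # assumed local checkpoint
mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=16, pred_iou_thresh=0.9)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical path; HWC, RGB, uint8
masks = mask_generator.generate(image)  # list of dicts: "segmentation", "bbox", "predicted_iou", ...
print(len(masks), masks[0]["bbox"], masks[0]["predicted_iou"])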
model/segment_anything/build_sam.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12
+
13
+
14
+ def build_sam_vit_h(checkpoint=None):
15
+ return _build_sam(
16
+ encoder_embed_dim=1280,
17
+ encoder_depth=32,
18
+ encoder_num_heads=16,
19
+ encoder_global_attn_indexes=[7, 15, 23, 31],
20
+ checkpoint=checkpoint,
21
+ )
22
+
23
+
24
+ build_sam = build_sam_vit_h
25
+
26
+
27
+ def build_sam_vit_l(checkpoint=None):
28
+ return _build_sam(
29
+ encoder_embed_dim=1024,
30
+ encoder_depth=24,
31
+ encoder_num_heads=16,
32
+ encoder_global_attn_indexes=[5, 11, 17, 23],
33
+ checkpoint=checkpoint,
34
+ )
35
+
36
+
37
+ def build_sam_vit_b(checkpoint=None):
38
+ return _build_sam(
39
+ encoder_embed_dim=768,
40
+ encoder_depth=12,
41
+ encoder_num_heads=12,
42
+ encoder_global_attn_indexes=[2, 5, 8, 11],
43
+ checkpoint=checkpoint,
44
+ )
45
+
46
+
47
+ sam_model_registry = {
48
+ "default": build_sam_vit_h,
49
+ "vit_h": build_sam_vit_h,
50
+ "vit_l": build_sam_vit_l,
51
+ "vit_b": build_sam_vit_b,
52
+ }
53
+
54
+
55
+ def _build_sam(
56
+ encoder_embed_dim,
57
+ encoder_depth,
58
+ encoder_num_heads,
59
+ encoder_global_attn_indexes,
60
+ checkpoint=None,
61
+ ):
62
+ prompt_embed_dim = 256
63
+ image_size = 1024
64
+ vit_patch_size = 16
65
+ image_embedding_size = image_size // vit_patch_size
66
+ sam = Sam(
67
+ image_encoder=ImageEncoderViT(
68
+ depth=encoder_depth,
69
+ embed_dim=encoder_embed_dim,
70
+ img_size=image_size,
71
+ mlp_ratio=4,
72
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73
+ num_heads=encoder_num_heads,
74
+ patch_size=vit_patch_size,
75
+ qkv_bias=True,
76
+ use_rel_pos=True,
77
+ global_attn_indexes=encoder_global_attn_indexes,
78
+ window_size=14,
79
+ out_chans=prompt_embed_dim,
80
+ ),
81
+ prompt_encoder=PromptEncoder(
82
+ embed_dim=prompt_embed_dim,
83
+ image_embedding_size=(image_embedding_size, image_embedding_size),
84
+ input_image_size=(image_size, image_size),
85
+ mask_in_chans=16,
86
+ ),
87
+ mask_decoder=MaskDecoder(
88
+ num_multimask_outputs=3,
89
+ transformer=TwoWayTransformer(
90
+ depth=2,
91
+ embedding_dim=prompt_embed_dim,
92
+ mlp_dim=2048,
93
+ num_heads=8,
94
+ ),
95
+ transformer_dim=prompt_embed_dim,
96
+ iou_head_depth=3,
97
+ iou_head_hidden_dim=256,
98
+ ),
99
+ pixel_mean=[123.675, 116.28, 103.53],
100
+ pixel_std=[58.395, 57.12, 57.375],
101
+ )
102
+ from huggingface_hub import hf_hub_download
103
+ import os
104
+ if checkpoint is None or not os.path.exists(checkpoint):
105
+ # If the checkpoint is not provided or does not exist locally, download it from Hugging Face
106
+ print(f"Model file not found locally: {checkpoint}, downloading from Hugging Face...")
107
+
108
+ try:
109
+ checkpoint = hf_hub_download(
110
+ repo_id="HCMUE-Research/SAM-vit-h",
111
+ filename="sam_vit_h_4b8939.pth"
112
+ )
113
+ print(f": {checkpoint}")
114
+ except Exception as e:
115
+ raise RuntimeError(f"Model download failed, please check your network or download manually: {e}")
116
+
117
+ with open(checkpoint, "rb") as f:
118
+ state_dict = torch.load(f)
119
+ sam.load_state_dict(state_dict)
120
+ return sam
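A short sketch of the registry entry points defined above; passing None (or a missing path) exercises the Hugging Face download fallback. The device choice is an assumption.

import torch
from model.segment_anything.build_sam import sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint=None)  # triggers the hf_hub_download fallback
sam = sam.to("cuda" if torch.cuda.is_available() else "cpu").eval()
print(f"{sum(p.numel() for p in sam.parameters()) / 1e6:.1f}M parameters")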
model/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .sam import Sam
+ from .image_encoder import ImageEncoderViT
+ from .mask_decoder import MaskDecoder
+ from .prompt_encoder import PromptEncoder
+ from .transformer import TwoWayTransformer
model/segment_anything/modeling/common.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Type
5
+
6
+
7
+ class MLPBlock(nn.Module):
8
+ def __init__(
9
+ self,
10
+ embedding_dim: int,
11
+ mlp_dim: int,
12
+ act: Type[nn.Module] = nn.GELU,
13
+ ) -> None:
14
+ super().__init__()
15
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
16
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
17
+ self.act = act()
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ return self.lin2(self.act(self.lin1(x)))
21
+
22
+
23
+ class LayerNorm2d(nn.Module):
24
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
25
+ super().__init__()
26
+ self.weight = nn.Parameter(torch.ones(num_channels))
27
+ self.bias = nn.Parameter(torch.zeros(num_channels))
28
+ self.eps = eps
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ u = x.mean(1, keepdim=True)
32
+ s = (x - u).pow(2).mean(1, keepdim=True)
33
+ x = (x - u) / torch.sqrt(s + self.eps)
34
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
35
+ return x
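A small shape sketch for the two blocks above: LayerNorm2d normalizes the channel dimension of NCHW feature maps, while MLPBlock operates on channels-last token tensors. The tensor sizes are arbitrary.

import torch
from model.segment_anything.modeling.common import LayerNorm2d, MLPBlock

feat = torch.randn(2, 256, 64, 64)                              # B x C x H x W
print(LayerNorm2d(256)(feat).shape)                             # torch.Size([2, 256, 64, 64])

tokens = torch.randn(2, 100, 256)                               # B x N x C
print(MLPBlock(embedding_dim=256, mlp_dim=2048)(tokens).shape)  # torch.Size([2, 100, 256])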
model/segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,287 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from typing import Optional, Tuple, Type
7
+
8
+ from .common import LayerNorm2d, MLPBlock
9
+
10
+
11
+ class ImageEncoderViT(nn.Module):
12
+ def __init__(
13
+ self,
14
+ img_size: int = 1024,
15
+ patch_size: int = 16,
16
+ in_chans: int = 3,
17
+ embed_dim: int = 768,
18
+ depth: int = 12,
19
+ num_heads: int = 12,
20
+ mlp_ratio: float = 4.0,
21
+ out_chans: int = 256,
22
+ qkv_bias: bool = True,
23
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
24
+ act_layer: Type[nn.Module] = nn.GELU,
25
+ use_abs_pos: bool = True,
26
+ use_rel_pos: bool = False,
27
+ rel_pos_zero_init: bool = True,
28
+ window_size: int = 0,
29
+ global_attn_indexes: Tuple[int, ...] = (),
30
+ ) -> None:
31
+ super().__init__()
32
+ self.img_size = img_size
33
+
34
+ self.patch_embed = PatchEmbed(
35
+ kernel_size=(patch_size, patch_size),
36
+ stride=(patch_size, patch_size),
37
+ in_chans=in_chans,
38
+ embed_dim=embed_dim,
39
+ )
40
+
41
+ self.pos_embed: Optional[nn.Parameter] = None
42
+ if use_abs_pos:
43
+ # Initialize absolute positional embedding with pretrain image size.
44
+ self.pos_embed = nn.Parameter(
45
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
46
+ )
47
+
48
+ self.blocks = nn.ModuleList()
49
+ for i in range(depth):
50
+ block = Block(
51
+ dim=embed_dim,
52
+ num_heads=num_heads,
53
+ mlp_ratio=mlp_ratio,
54
+ qkv_bias=qkv_bias,
55
+ norm_layer=norm_layer,
56
+ act_layer=act_layer,
57
+ use_rel_pos=use_rel_pos,
58
+ rel_pos_zero_init=rel_pos_zero_init,
59
+ window_size=window_size if i not in global_attn_indexes else 0,
60
+ input_size=(img_size // patch_size, img_size // patch_size),
61
+ )
62
+ self.blocks.append(block)
63
+
64
+ self.neck = nn.Sequential(
65
+ nn.Conv2d(
66
+ embed_dim,
67
+ out_chans,
68
+ kernel_size=1,
69
+ bias=False,
70
+ ),
71
+ LayerNorm2d(out_chans),
72
+ nn.Conv2d(
73
+ out_chans,
74
+ out_chans,
75
+ kernel_size=3,
76
+ padding=1,
77
+ bias=False,
78
+ ),
79
+ LayerNorm2d(out_chans),
80
+ )
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ x = self.patch_embed(x)
84
+ if self.pos_embed is not None:
85
+ x = x + self.pos_embed
86
+
87
+ for blk in self.blocks:
88
+ x = blk(x)
89
+
90
+ x = self.neck(x.permute(0, 3, 1, 2))
91
+
92
+ return x
93
+
94
+
95
+ class Block(nn.Module):
96
+
97
+ def __init__(
98
+ self,
99
+ dim: int,
100
+ num_heads: int,
101
+ mlp_ratio: float = 4.0,
102
+ qkv_bias: bool = True,
103
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
104
+ act_layer: Type[nn.Module] = nn.GELU,
105
+ use_rel_pos: bool = False,
106
+ rel_pos_zero_init: bool = True,
107
+ window_size: int = 0,
108
+ input_size: Optional[Tuple[int, int]] = None,
109
+ ) -> None:
110
+ super().__init__()
111
+ self.norm1 = norm_layer(dim)
112
+ self.attn = Attention(
113
+ dim,
114
+ num_heads=num_heads,
115
+ qkv_bias=qkv_bias,
116
+ use_rel_pos=use_rel_pos,
117
+ rel_pos_zero_init=rel_pos_zero_init,
118
+ input_size=input_size if window_size == 0 else (window_size, window_size),
119
+ )
120
+
121
+ self.norm2 = norm_layer(dim)
122
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
123
+
124
+ self.window_size = window_size
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ shortcut = x
128
+ x = self.norm1(x)
129
+ # Window partition
130
+ if self.window_size > 0:
131
+ H, W = x.shape[1], x.shape[2]
132
+ x, pad_hw = window_partition(x, self.window_size)
133
+
134
+ x = self.attn(x)
135
+ # Reverse window partition
136
+ if self.window_size > 0:
137
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
138
+
139
+ x = shortcut + x
140
+ x = x + self.mlp(self.norm2(x))
141
+
142
+ return x
143
+
144
+
145
+ class Attention(nn.Module):
146
+
147
+ def __init__(
148
+ self,
149
+ dim: int,
150
+ num_heads: int = 8,
151
+ qkv_bias: bool = True,
152
+ use_rel_pos: bool = False,
153
+ rel_pos_zero_init: bool = True,
154
+ input_size: Optional[Tuple[int, int]] = None,
155
+ ) -> None:
156
+ super().__init__()
157
+ self.num_heads = num_heads
158
+ head_dim = dim // num_heads
159
+ self.scale = head_dim**-0.5
160
+
161
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
162
+ self.proj = nn.Linear(dim, dim)
163
+
164
+ self.use_rel_pos = use_rel_pos
165
+ if self.use_rel_pos:
166
+ assert (
167
+ input_size is not None
168
+ ), "Input size must be provided if using relative positional encoding."
169
+ # initialize relative positional embeddings
170
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
171
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ B, H, W, _ = x.shape
175
+ # qkv with shape (3, B, nHead, H * W, C)
176
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
177
+ # q, k, v with shape (B * nHead, H * W, C)
178
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
179
+
180
+ attn = (q * self.scale) @ k.transpose(-2, -1)
181
+
182
+ if self.use_rel_pos:
183
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
184
+
185
+ attn = attn.softmax(dim=-1)
186
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
187
+ x = self.proj(x)
188
+
189
+ return x
190
+
191
+
192
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
193
+ B, H, W, C = x.shape
194
+
195
+ pad_h = (window_size - H % window_size) % window_size
196
+ pad_w = (window_size - W % window_size) % window_size
197
+ if pad_h > 0 or pad_w > 0:
198
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
199
+ Hp, Wp = H + pad_h, W + pad_w
200
+
201
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
202
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
203
+ return windows, (Hp, Wp)
204
+
205
+
206
+ def window_unpartition(
207
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
208
+ ) -> torch.Tensor:
209
+ Hp, Wp = pad_hw
210
+ H, W = hw
211
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
212
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
213
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
214
+
215
+ if Hp > H or Wp > W:
216
+ x = x[:, :H, :W, :].contiguous()
217
+ return x
218
+
219
+
220
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
221
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
222
+ # Interpolate rel pos if needed.
223
+ if rel_pos.shape[0] != max_rel_dist:
224
+ # Interpolate rel pos.
225
+ rel_pos_resized = F.interpolate(
226
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
227
+ size=max_rel_dist,
228
+ mode="linear",
229
+ )
230
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
231
+ else:
232
+ rel_pos_resized = rel_pos
233
+
234
+ # Scale the coords with short length if shapes for q and k are different.
235
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
236
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
237
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
238
+
239
+ return rel_pos_resized[relative_coords.long()]
240
+
241
+
242
+ def add_decomposed_rel_pos(
243
+ attn: torch.Tensor,
244
+ q: torch.Tensor,
245
+ rel_pos_h: torch.Tensor,
246
+ rel_pos_w: torch.Tensor,
247
+ q_size: Tuple[int, int],
248
+ k_size: Tuple[int, int],
249
+ ) -> torch.Tensor:
250
+ q_h, q_w = q_size
251
+ k_h, k_w = k_size
252
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
253
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
254
+
255
+ B, _, dim = q.shape
256
+ r_q = q.reshape(B, q_h, q_w, dim)
257
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
258
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
259
+
260
+ attn = (
261
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
262
+ ).view(B, q_h * q_w, k_h * k_w)
263
+
264
+ return attn
265
+
266
+
267
+ class PatchEmbed(nn.Module):
268
+
269
+ def __init__(
270
+ self,
271
+ kernel_size: Tuple[int, int] = (16, 16),
272
+ stride: Tuple[int, int] = (16, 16),
273
+ padding: Tuple[int, int] = (0, 0),
274
+ in_chans: int = 3,
275
+ embed_dim: int = 768,
276
+ ) -> None:
277
+ super().__init__()
278
+
279
+ self.proj = nn.Conv2d(
280
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
281
+ )
282
+
283
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
284
+ x = self.proj(x)
285
+ # B C H W -> B H W C
286
+ x = x.permute(0, 2, 3, 1)
287
+ return x
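A quick round-trip check of the windowing helpers above, using the SAM ViT defaults (a 64x64 token grid from a 1024-pixel input with 16-pixel patches, window size 14); the tensors are random.

import torch
from model.segment_anything.modeling.image_encoder import window_partition, window_unpartition

x = torch.randn(1, 64, 64, 768)               # B x H x W x C tokens
windows, pad_hw = window_partition(x, 14)     # pads 64 -> 70, giving 5 x 5 windows
print(windows.shape, pad_hw)                  # torch.Size([25, 14, 14, 768]) (70, 70)
x_rec = window_unpartition(windows, 14, pad_hw, (64, 64))
print(torch.allclose(x, x_rec))               # True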
model/segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from typing import List, Tuple, Type
6
+
7
+ from .common import LayerNorm2d
8
+
9
+
10
+ class MaskDecoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ *,
14
+ transformer_dim: int,
15
+ transformer: nn.Module,
16
+ num_multimask_outputs: int = 3,
17
+ activation: Type[nn.Module] = nn.GELU,
18
+ iou_head_depth: int = 3,
19
+ iou_head_hidden_dim: int = 256,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.transformer_dim = transformer_dim
23
+ self.transformer = transformer
24
+
25
+ self.num_multimask_outputs = num_multimask_outputs
26
+
27
+ self.iou_token = nn.Embedding(1, transformer_dim)
28
+ self.num_mask_tokens = num_multimask_outputs + 1
29
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
30
+
31
+ self.output_upscaling = nn.Sequential(
32
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
33
+ LayerNorm2d(transformer_dim // 4),
34
+ activation(),
35
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
36
+ activation(),
37
+ )
38
+ self.output_hypernetworks_mlps = nn.ModuleList(
39
+ [
40
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
41
+ for i in range(self.num_mask_tokens)
42
+ ]
43
+ )
44
+
45
+ self.iou_prediction_head = MLP(
46
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
47
+ )
48
+
49
+ def forward(
50
+ self,
51
+ image_embeddings: torch.Tensor,
52
+ image_pe: torch.Tensor,
53
+ sparse_prompt_embeddings: torch.Tensor,
54
+ dense_prompt_embeddings: torch.Tensor,
55
+ multimask_output: bool,
56
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
57
+ masks, iou_pred = self.predict_masks(
58
+ image_embeddings=image_embeddings,
59
+ image_pe=image_pe,
60
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
61
+ dense_prompt_embeddings=dense_prompt_embeddings,
62
+ )
63
+
64
+ # Select the correct mask or masks for output
65
+ if multimask_output:
66
+ mask_slice = slice(1, None)
67
+ else:
68
+ mask_slice = slice(0, 1)
69
+ masks = masks[:, mask_slice, :, :]
70
+ iou_pred = iou_pred[:, mask_slice]
71
+
72
+ # Prepare output
73
+ return masks, iou_pred
74
+
75
+ def predict_masks(
76
+ self,
77
+ image_embeddings: torch.Tensor,
78
+ image_pe: torch.Tensor,
79
+ sparse_prompt_embeddings: torch.Tensor,
80
+ dense_prompt_embeddings: torch.Tensor,
81
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ """Predicts masks. See 'forward' for more details."""
83
+ # Concatenate output tokens
84
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
85
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
86
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
87
+
88
+ # Expand per-image data in batch direction to be per-mask
89
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
90
+ src = src + dense_prompt_embeddings
91
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
92
+ b, c, h, w = src.shape
93
+
94
+ # Run the transformer
95
+ hs, src = self.transformer(src, pos_src, tokens)
96
+ iou_token_out = hs[:, 0, :]
97
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
98
+
99
+ # Upscale mask embeddings and predict masks using the mask tokens
100
+ src = src.transpose(1, 2).view(b, c, h, w)
101
+ upscaled_embedding = self.output_upscaling(src)
102
+ hyper_in_list: List[torch.Tensor] = []
103
+ for i in range(self.num_mask_tokens):
104
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
105
+ hyper_in = torch.stack(hyper_in_list, dim=1)
106
+ b, c, h, w = upscaled_embedding.shape
107
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
108
+
109
+ # Generate mask quality predictions
110
+ iou_pred = self.iou_prediction_head(iou_token_out)
111
+
112
+ return masks, iou_pred
113
+
114
+
115
+ class MLP(nn.Module):
116
+ def __init__(
117
+ self,
118
+ input_dim: int,
119
+ hidden_dim: int,
120
+ output_dim: int,
121
+ num_layers: int,
122
+ sigmoid_output: bool = False,
123
+ ) -> None:
124
+ super().__init__()
125
+ self.num_layers = num_layers
126
+ h = [hidden_dim] * (num_layers - 1)
127
+ self.layers = nn.ModuleList(
128
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
129
+ )
130
+ self.sigmoid_output = sigmoid_output
131
+
132
+ def forward(self, x):
133
+ for i, layer in enumerate(self.layers):
134
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
135
+ if self.sigmoid_output:
136
+ x = F.sigmoid(x)
137
+ return x
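A shape sketch for the MLP head above, mirroring the hypernetwork MLPs that MaskDecoder builds (transformer_dim 256 mapped down to transformer_dim // 8 over three layers); the batch size is arbitrary.

import torch
from model.segment_anything.modeling.mask_decoder import MLP

head = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
mask_tokens = torch.randn(4, 256)     # B x transformer_dim
print(head(mask_tokens).shape)        # torch.Size([4, 32])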
model/segment_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,169 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+
6
+ from typing import Any, Optional, Tuple, Type
7
+
8
+ from .common import LayerNorm2d
9
+
10
+
11
+ class PromptEncoder(nn.Module):
12
+ def __init__(
13
+ self,
14
+ embed_dim: int,
15
+ image_embedding_size: Tuple[int, int],
16
+ input_image_size: Tuple[int, int],
17
+ mask_in_chans: int,
18
+ activation: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.embed_dim = embed_dim
22
+ self.input_image_size = input_image_size
23
+ self.image_embedding_size = image_embedding_size
24
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
25
+
26
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
27
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
28
+ self.point_embeddings = nn.ModuleList(point_embeddings)
29
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
30
+
31
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
32
+ self.mask_downscaling = nn.Sequential(
33
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
34
+ LayerNorm2d(mask_in_chans // 4),
35
+ activation(),
36
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
37
+ LayerNorm2d(mask_in_chans),
38
+ activation(),
39
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
40
+ )
41
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
42
+
43
+ def get_dense_pe(self) -> torch.Tensor:
44
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
45
+
46
+ def _embed_points(
47
+ self,
48
+ points: torch.Tensor,
49
+ labels: torch.Tensor,
50
+ pad: bool,
51
+ ) -> torch.Tensor:
52
+ points = points + 0.5 # Shift to center of pixel
53
+ if pad:
54
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
55
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
56
+ points = torch.cat([points, padding_point], dim=1)
57
+ labels = torch.cat([labels, padding_label], dim=1)
58
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
59
+ point_embedding[labels == -1] = 0.0
60
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
61
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
62
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
63
+ return point_embedding
64
+
65
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
66
+ """Embeds box prompts."""
67
+ boxes = boxes + 0.5 # Shift to center of pixel
68
+ coords = boxes.reshape(-1, 2, 2)
69
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
70
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
71
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
72
+ return corner_embedding
73
+
74
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
75
+ """Embeds mask inputs."""
76
+ mask_embedding = self.mask_downscaling(masks)
77
+ return mask_embedding
78
+
79
+ def _get_batch_size(
80
+ self,
81
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
82
+ boxes: Optional[torch.Tensor],
83
+ masks: Optional[torch.Tensor],
84
+ ) -> int:
85
+ """
86
+ Gets the batch size of the output given the batch size of the input prompts.
87
+ """
88
+ if points is not None:
89
+ return points[0].shape[0]
90
+ elif boxes is not None:
91
+ return boxes.shape[0]
92
+ elif masks is not None:
93
+ return masks.shape[0]
94
+ else:
95
+ return 1
96
+
97
+ def _get_device(self) -> torch.device:
98
+ return self.point_embeddings[0].weight.device
99
+
100
+ def forward(
101
+ self,
102
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
103
+ boxes: Optional[torch.Tensor],
104
+ masks: Optional[torch.Tensor],
105
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
106
+ bs = self._get_batch_size(points, boxes, masks)
107
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
108
+ if points is not None:
109
+ coords, labels = points
110
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
111
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
112
+ if boxes is not None:
113
+ box_embeddings = self._embed_boxes(boxes)
114
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
115
+
116
+ if masks is not None:
117
+ dense_embeddings = self._embed_masks(masks)
118
+ else:
119
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
120
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
121
+ )
122
+
123
+ return sparse_embeddings, dense_embeddings
124
+
125
+
126
+ class PositionEmbeddingRandom(nn.Module):
127
+ """
128
+ Positional encoding using random spatial frequencies.
129
+ """
130
+
131
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
132
+ super().__init__()
133
+ if scale is None or scale <= 0.0:
134
+ scale = 1.0
135
+ self.register_buffer(
136
+ "positional_encoding_gaussian_matrix",
137
+ scale * torch.randn((2, num_pos_feats)),
138
+ )
139
+
140
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
141
+ """Positionally encode points that are normalized to [0,1]."""
142
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
143
+ coords = 2 * coords - 1
144
+ coords = coords @ self.positional_encoding_gaussian_matrix
145
+ coords = 2 * np.pi * coords
146
+ # outputs d_1 x ... x d_n x C shape
147
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
148
+
149
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
150
+ """Generate positional encoding for a grid of the specified size."""
151
+ h, w = size
152
+ device: Any = self.positional_encoding_gaussian_matrix.device
153
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
154
+ y_embed = grid.cumsum(dim=0) - 0.5
155
+ x_embed = grid.cumsum(dim=1) - 0.5
156
+ y_embed = y_embed / h
157
+ x_embed = x_embed / w
158
+
159
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
160
+ return pe.permute(2, 0, 1) # C x H x W
161
+
162
+ def forward_with_coords(
163
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
164
+ ) -> torch.Tensor:
165
+ """Positionally encode points that are not normalized to [0,1]."""
166
+ coords = coords_input.clone()
167
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
168
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
169
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
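A minimal sketch of the prompt encoder in isolation with the default SAM geometry (1024-pixel input, 64x64 embedding grid); the click coordinates are arbitrary.

import torch
from model.segment_anything.modeling.prompt_encoder import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
coords = torch.tensor([[[512.0, 384.0]]])   # B x N x 2, in input-image pixels
labels = torch.tensor([[1]])                # 1 = foreground point
sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape, dense.shape)            # torch.Size([1, 2, 256]) torch.Size([1, 256, 64, 64])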
model/segment_anything/modeling/sam.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ from .image_encoder import ImageEncoderViT
8
+ from .mask_decoder import MaskDecoder
9
+ from .prompt_encoder import PromptEncoder
10
+
11
+
12
+ class Sam(nn.Module):
13
+ mask_threshold: float = 0.0
14
+ image_format: str = "RGB"
15
+
16
+ def __init__(
17
+ self,
18
+ image_encoder: ImageEncoderViT,
19
+ prompt_encoder: PromptEncoder,
20
+ mask_decoder: MaskDecoder,
21
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
22
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
23
+ ) -> None:
24
+ super().__init__()
25
+ self.image_encoder = image_encoder
26
+ self.prompt_encoder = prompt_encoder
27
+ self.mask_decoder = mask_decoder
28
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
29
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
30
+
31
+ @property
32
+ def device(self) -> Any:
33
+ return self.pixel_mean.device
34
+
35
+ @torch.no_grad()
36
+ def forward(
37
+ self,
38
+ batched_input: List[Dict[str, Any]],
39
+ multimask_output: bool,
40
+ ) -> List[Dict[str, torch.Tensor]]:
41
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
42
+ image_embeddings = self.image_encoder(input_images)
43
+
44
+ outputs = []
45
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
46
+ if "point_coords" in image_record:
47
+ points = (image_record["point_coords"], image_record["point_labels"])
48
+ else:
49
+ points = None
50
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
51
+ points=points,
52
+ boxes=image_record.get("boxes", None),
53
+ masks=image_record.get("mask_inputs", None),
54
+ )
55
+ low_res_masks, iou_predictions = self.mask_decoder(
56
+ image_embeddings=curr_embedding.unsqueeze(0),
57
+ image_pe=self.prompt_encoder.get_dense_pe(),
58
+ sparse_prompt_embeddings=sparse_embeddings,
59
+ dense_prompt_embeddings=dense_embeddings,
60
+ multimask_output=multimask_output,
61
+ )
62
+ masks = self.postprocess_masks(
63
+ low_res_masks,
64
+ input_size=image_record["image"].shape[-2:],
65
+ original_size=image_record["original_size"],
66
+ )
67
+ masks = masks > self.mask_threshold
68
+ outputs.append(
69
+ {
70
+ "masks": masks,
71
+ "iou_predictions": iou_predictions,
72
+ "low_res_logits": low_res_masks,
73
+ }
74
+ )
75
+ return outputs
76
+
77
+ def postprocess_masks(
78
+ self,
79
+ masks: torch.Tensor,
80
+ input_size: Tuple[int, ...],
81
+ original_size: Tuple[int, ...],
82
+ ) -> torch.Tensor:
83
+ masks = F.interpolate(
84
+ masks,
85
+ (self.image_encoder.img_size, self.image_encoder.img_size),
86
+ mode="bilinear",
87
+ align_corners=False,
88
+ )
89
+ masks = masks[..., : input_size[0], : input_size[1]]
90
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
91
+ return masks
92
+
93
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
94
+ """Normalize pixel values and pad to a square input."""
95
+ # Normalize colors
96
+ x = (x - self.pixel_mean) / self.pixel_std
97
+
98
+ # Pad
99
+ h, w = x.shape[-2:]
100
+ padh = self.image_encoder.img_size - h
101
+ padw = self.image_encoder.img_size - w
102
+ x = F.pad(x, (0, padw, 0, padh))
103
+ return x
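A sketch of the batched forward interface above; the checkpoint path, image sizes, and prompt coordinates are all assumptions, and the image is random noise purely to illustrate shapes.

import torch
from model.segment_anything import sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").eval()  # assumed local checkpoint
image = torch.randint(0, 256, (3, 1024, 768)).float()    # C x H x W, long side already resized to 1024
batched_input = [{
    "image": image,
    "original_size": (2048, 1536),                        # assumed size before resizing
    "point_coords": torch.tensor([[[500.0, 400.0]]]),     # 1 prompt x 1 point x (x, y)
    "point_labels": torch.tensor([[1]]),
}]
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]["masks"].shape)                          # torch.Size([1, 3, 2048, 1536])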
model/segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+ import math
5
+ from typing import Tuple, Type
6
+
7
+ from .common import MLPBlock
8
+
9
+
10
+ class TwoWayTransformer(nn.Module):
11
+ def __init__(
12
+ self,
13
+ depth: int,
14
+ embedding_dim: int,
15
+ num_heads: int,
16
+ mlp_dim: int,
17
+ activation: Type[nn.Module] = nn.ReLU,
18
+ attention_downsample_rate: int = 2,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.depth = depth
22
+ self.embedding_dim = embedding_dim
23
+ self.num_heads = num_heads
24
+ self.mlp_dim = mlp_dim
25
+ self.layers = nn.ModuleList()
26
+
27
+ for i in range(depth):
28
+ self.layers.append(
29
+ TwoWayAttentionBlock(
30
+ embedding_dim=embedding_dim,
31
+ num_heads=num_heads,
32
+ mlp_dim=mlp_dim,
33
+ activation=activation,
34
+ attention_downsample_rate=attention_downsample_rate,
35
+ skip_first_layer_pe=(i == 0),
36
+ )
37
+ )
38
+
39
+ self.final_attn_token_to_image = Attention(
40
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
41
+ )
42
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
43
+
44
+ def forward(
45
+ self,
46
+ image_embedding: Tensor,
47
+ image_pe: Tensor,
48
+ point_embedding: Tensor,
49
+ ) -> Tuple[Tensor, Tensor]:
50
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
51
+ bs, c, h, w = image_embedding.shape
52
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
53
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
54
+
55
+ # Prepare queries
56
+ queries = point_embedding
57
+ keys = image_embedding
58
+
59
+ # Apply transformer blocks and final layernorm
60
+ for layer in self.layers:
61
+ queries, keys = layer(
62
+ queries=queries,
63
+ keys=keys,
64
+ query_pe=point_embedding,
65
+ key_pe=image_pe,
66
+ )
67
+
68
+ # Apply the final attention layer from the points to the image
69
+ q = queries + point_embedding
70
+ k = keys + image_pe
71
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
72
+ queries = queries + attn_out
73
+ queries = self.norm_final_attn(queries)
74
+
75
+ return queries, keys
76
+
77
+
78
+ class TwoWayAttentionBlock(nn.Module):
79
+ def __init__(
80
+ self,
81
+ embedding_dim: int,
82
+ num_heads: int,
83
+ mlp_dim: int = 2048,
84
+ activation: Type[nn.Module] = nn.ReLU,
85
+ attention_downsample_rate: int = 2,
86
+ skip_first_layer_pe: bool = False,
87
+ ) -> None:
88
+ super().__init__()
89
+ self.self_attn = Attention(embedding_dim, num_heads)
90
+ self.norm1 = nn.LayerNorm(embedding_dim)
91
+
92
+ self.cross_attn_token_to_image = Attention(
93
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
94
+ )
95
+ self.norm2 = nn.LayerNorm(embedding_dim)
96
+
97
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
98
+ self.norm3 = nn.LayerNorm(embedding_dim)
99
+
100
+ self.norm4 = nn.LayerNorm(embedding_dim)
101
+ self.cross_attn_image_to_token = Attention(
102
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
103
+ )
104
+
105
+ self.skip_first_layer_pe = skip_first_layer_pe
106
+
107
+ def forward(
108
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
109
+ ) -> Tuple[Tensor, Tensor]:
110
+ # Self attention block
111
+ if self.skip_first_layer_pe:
112
+ queries = self.self_attn(q=queries, k=queries, v=queries)
113
+ else:
114
+ q = queries + query_pe
115
+ attn_out = self.self_attn(q=q, k=q, v=queries)
116
+ queries = queries + attn_out
117
+ queries = self.norm1(queries)
118
+
119
+ # Cross attention block, tokens attending to image embedding
120
+ q = queries + query_pe
121
+ k = keys + key_pe
122
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
123
+ queries = queries + attn_out
124
+ queries = self.norm2(queries)
125
+
126
+ # MLP block
127
+ mlp_out = self.mlp(queries)
128
+ queries = queries + mlp_out
129
+ queries = self.norm3(queries)
130
+
131
+ # Cross attention block, image embedding attending to tokens
132
+ q = queries + query_pe
133
+ k = keys + key_pe
134
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
135
+ keys = keys + attn_out
136
+ keys = self.norm4(keys)
137
+
138
+ return queries, keys
139
+
140
+
141
+ class Attention(nn.Module):
142
+ """
143
+ An attention layer that allows for downscaling the size of the embedding
144
+ after projection to queries, keys, and values.
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ embedding_dim: int,
150
+ num_heads: int,
151
+ downsample_rate: int = 1,
152
+ ) -> None:
153
+ super().__init__()
154
+ self.embedding_dim = embedding_dim
155
+ self.internal_dim = embedding_dim // downsample_rate
156
+ self.num_heads = num_heads
157
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
158
+
159
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
160
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
161
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
162
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
163
+
164
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
165
+ b, n, c = x.shape
166
+ x = x.reshape(b, n, num_heads, c // num_heads)
167
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
168
+
169
+ def _recombine_heads(self, x: Tensor) -> Tensor:
170
+ b, n_heads, n_tokens, c_per_head = x.shape
171
+ x = x.transpose(1, 2)
172
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
173
+
174
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
175
+ # Input projections
176
+ q = self.q_proj(q)
177
+ k = self.k_proj(k)
178
+ v = self.v_proj(v)
179
+
180
+ # Separate into heads
181
+ q = self._separate_heads(q, self.num_heads)
182
+ k = self._separate_heads(k, self.num_heads)
183
+ v = self._separate_heads(v, self.num_heads)
184
+
185
+ # Attention
186
+ _, _, _, c_per_head = q.shape
187
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
188
+ attn = attn / math.sqrt(c_per_head)
189
+ attn = torch.softmax(attn, dim=-1)
190
+
191
+ # Get output
192
+ out = attn @ v
193
+ out = self._recombine_heads(out)
194
+ out = self.out_proj(out)
195
+
196
+ return out
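A shape sketch of the two-way transformer in isolation, with the dimensions that build_sam wires up (256-d embeddings, a 64x64 image grid) and a handful of random prompt tokens.

import torch
from model.segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)
tokens = torch.randn(1, 7, 256)                 # e.g. iou token + mask tokens + point prompts
queries, keys = transformer(image_embedding, image_pe, tokens)
print(queries.shape, keys.shape)                # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])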
model/segment_anything/predictor.py ADDED
@@ -0,0 +1,180 @@
1
+
2
+
3
+ import numpy as np
4
+ import torch
5
+ import time
6
+ from model.segment_anything.modeling import Sam
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ from model.segment_anything.utils.transforms import ResizeLongestSide
11
+
12
+
13
+ class SamPredictor:
14
+ def __init__(
15
+ self,
16
+ sam_model: Sam,
17
+ ) -> None:
18
+ """
19
+ Uses SAM to calculate the image embedding for an image, and then
20
+ allow repeated, efficient mask prediction given prompts.
21
+
22
+ Arguments:
23
+ sam_model (Sam): The model to use for mask prediction.
24
+ """
25
+ super().__init__()
26
+ self.model = sam_model
27
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
28
+ self.reset_image()
29
+
30
+ def set_image(
31
+ self,
32
+ image: np.ndarray,
33
+ image_format: str = "RGB",
34
+ ) -> None:
35
+
36
+ assert image_format in [
37
+ "RGB",
38
+ "BGR",
39
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
40
+ if image_format != self.model.image_format:
41
+ image = image[..., ::-1]
42
+
43
+ # Transform the image to the form expected by the model
44
+ input_image = self.transform.apply_image(image)
45
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
46
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
47
+
48
+ self.set_torch_image(input_image_torch, image.shape[:2])
49
+
50
+ @torch.no_grad()
51
+ def set_torch_image(
52
+ self,
53
+ transformed_image: torch.Tensor,
54
+ original_image_size: Tuple[int, ...],
55
+ ) -> None:
56
+
57
+ assert (
58
+ len(transformed_image.shape) == 4
59
+ and transformed_image.shape[1] == 3
60
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
61
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
62
+ self.reset_image()
63
+
64
+ self.original_size = original_image_size
65
+ self.input_size = tuple(transformed_image.shape[-2:])
66
+ input_image = self.model.preprocess(transformed_image)
67
+ self.features = self.model.image_encoder(input_image)
68
+ self.is_image_set = True
69
+
70
+ def predict(
71
+ self,
72
+ point_coords: Optional[np.ndarray] = None,
73
+ point_labels: Optional[np.ndarray] = None,
74
+ box: Optional[np.ndarray] = None,
75
+ mask_input: Optional[np.ndarray] = None,
76
+ multimask_output: bool = True,
77
+ return_logits: bool = False,
78
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
79
+ if not self.is_image_set:
80
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
81
+
82
+ # Transform input prompts
83
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
84
+ if point_coords is not None:
85
+ assert (
86
+ point_labels is not None
87
+ ), "point_labels must be supplied if point_coords is supplied."
88
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
89
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
90
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
91
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
92
+ if box is not None:
93
+ box = self.transform.apply_boxes(box, self.original_size)
94
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
95
+ box_torch = box_torch[None, :]
96
+ if mask_input is not None:
97
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
98
+ mask_input_torch = mask_input_torch[None, :, :, :]
99
+
100
+ masks, iou_predictions, low_res_masks = self.predict_torch(
101
+ coords_torch,
102
+ labels_torch,
103
+ box_torch,
104
+ mask_input_torch,
105
+ multimask_output,
106
+ return_logits=return_logits,
107
+ )
108
+
109
+ masks_np = masks[0].detach().cpu().numpy()
110
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
111
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
112
+ return masks_np, iou_predictions_np, low_res_masks_np
113
+
114
+ @torch.no_grad()
115
+ def predict_torch(
116
+ self,
117
+ point_coords: Optional[torch.Tensor],
118
+ point_labels: Optional[torch.Tensor],
119
+ boxes: Optional[torch.Tensor] = None,
120
+ mask_input: Optional[torch.Tensor] = None,
121
+ multimask_output: bool = True,
122
+ return_logits: bool = False,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
124
+ if not self.is_image_set:
125
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
126
+
127
+ if point_coords is not None:
128
+ points = (point_coords, point_labels)
129
+ else:
130
+ points = None
131
+
132
+ # Embed prompts
133
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
134
+ points=points,
135
+ boxes=boxes,
136
+ masks=mask_input,
137
+ )
138
+
139
+ # Predict masks
140
+ low_res_masks, iou_predictions = self.model.mask_decoder(
141
+ image_embeddings=self.features,
142
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
143
+ sparse_prompt_embeddings=sparse_embeddings,
144
+ dense_prompt_embeddings=dense_embeddings,
145
+ multimask_output=multimask_output,
146
+ )
147
+
148
+ # Upscale the masks to the original image resolution
149
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
150
+
151
+ if not return_logits:
152
+ masks = masks > self.model.mask_threshold
153
+
154
+ return masks, iou_predictions, low_res_masks
155
+
156
+ def get_image_embedding(self) -> torch.Tensor:
157
+ """
158
+ Returns the image embeddings for the currently set image, with
159
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
160
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
161
+ """
162
+ if not self.is_image_set:
163
+ raise RuntimeError(
164
+ "An image must be set with .set_image(...) to generate an embedding."
165
+ )
166
+ assert self.features is not None, "Features must exist if an image has been set."
167
+ return self.features
168
+
169
+ @property
170
+ def device(self) -> torch.device:
171
+ return self.model.device
172
+
173
+ def reset_image(self) -> None:
174
+ """Resets the currently set image."""
175
+ self.is_image_set = False
176
+ self.features = None
177
+ self.orig_h = None
178
+ self.orig_w = None
179
+ self.input_h = None
180
+ self.input_w = None
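A minimal point-prompt sketch with the predictor above; the checkpoint and image paths are hypothetical.

import cv2
import numpy as np
from model.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # assumed local checkpoint
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)    # hypothetical image path
predictor.set_image(image)                        # computes the image embedding once
masks, scores, low_res = predictor.predict(
    point_coords=np.array([[320, 240]]),          # one foreground click, (x, y)
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)                        # (3, H, W) boolean masks and their predicted IoUs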
model/segment_anything/utils/__init__.py ADDED
File without changes
model/segment_anything/utils/amg.py ADDED
@@ -0,0 +1,329 @@
1
+
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ import math
7
+ from copy import deepcopy
8
+ from itertools import product
9
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
10
+
11
+
12
+ class MaskData:
13
+ """
14
+ A structure for storing masks and their related data in batched format.
15
+ Implements basic filtering and concatenation.
16
+ """
17
+
18
+ def __init__(self, **kwargs) -> None:
19
+ for v in kwargs.values():
20
+ assert isinstance(
21
+ v, (list, np.ndarray, torch.Tensor)
22
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
23
+ self._stats = dict(**kwargs)
24
+
25
+ def __setitem__(self, key: str, item: Any) -> None:
26
+ assert isinstance(
27
+ item, (list, np.ndarray, torch.Tensor)
28
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
29
+ self._stats[key] = item
30
+
31
+ def __delitem__(self, key: str) -> None:
32
+ del self._stats[key]
33
+
34
+ def __getitem__(self, key: str) -> Any:
35
+ return self._stats[key]
36
+
37
+ def items(self) -> ItemsView[str, Any]:
38
+ return self._stats.items()
39
+
40
+ def filter(self, keep: torch.Tensor) -> None:
41
+ for k, v in self._stats.items():
42
+ if v is None:
43
+ self._stats[k] = None
44
+ elif isinstance(v, torch.Tensor):
45
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
46
+ elif isinstance(v, np.ndarray):
47
+ self._stats[k] = v[keep.detach().cpu().numpy()]
48
+ elif isinstance(v, list) and keep.dtype == torch.bool:
49
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
50
+ elif isinstance(v, list):
51
+ self._stats[k] = [v[i] for i in keep]
52
+ else:
53
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
54
+
55
+ def cat(self, new_stats: "MaskData") -> None:
56
+ for k, v in new_stats.items():
57
+ if k not in self._stats or self._stats[k] is None:
58
+ self._stats[k] = deepcopy(v)
59
+ elif isinstance(v, torch.Tensor):
60
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
61
+ elif isinstance(v, np.ndarray):
62
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
63
+ elif isinstance(v, list):
64
+ self._stats[k] = self._stats[k] + deepcopy(v)
65
+ else:
66
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
67
+
68
+ def to_numpy(self) -> None:
69
+ for k, v in self._stats.items():
70
+ if isinstance(v, torch.Tensor):
71
+ self._stats[k] = v.detach().cpu().numpy()
72
+
73
+
74
+ def is_box_near_crop_edge(
75
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
76
+ ) -> torch.Tensor:
77
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
78
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
79
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
80
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
81
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
82
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
83
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
84
+ return torch.any(near_crop_edge, dim=1)
85
+
86
+
87
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
88
+ box_xywh = deepcopy(box_xyxy)
89
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
90
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
91
+ return box_xywh
92
+
93
+
94
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
95
+ assert len(args) > 0 and all(
96
+ len(a) == len(args[0]) for a in args
97
+ ), "Batched iteration must have inputs of all the same size."
98
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
99
+ for b in range(n_batches):
100
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
101
+
102
+
103
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
104
+ # Put in fortran order and flatten h,w
105
+ b, h, w = tensor.shape
106
+ tensor = tensor.permute(0, 2, 1).flatten(1)
107
+
108
+ # Compute change indices
109
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
110
+ change_indices = diff.nonzero()
111
+
112
+ # Encode run length
113
+ out = []
114
+ for i in range(b):
115
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
116
+ cur_idxs = torch.cat(
117
+ [
118
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
119
+ cur_idxs + 1,
120
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
121
+ ]
122
+ )
123
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
124
+ counts = [] if tensor[i, 0] == 0 else [0]
125
+ counts.extend(btw_idxs.detach().cpu().tolist())
126
+ out.append({"size": [h, w], "counts": counts})
127
+ return out
128
+
129
+
130
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
131
+ """Compute a binary mask from an uncompressed RLE."""
132
+ h, w = rle["size"]
133
+ mask = np.empty(h * w, dtype=bool)
134
+ idx = 0
135
+ parity = False
136
+ for count in rle["counts"]:
137
+ mask[idx : idx + count] = parity
138
+ idx += count
139
+ parity ^= True
140
+ mask = mask.reshape(w, h)
141
+ return mask.transpose() # Put in C order
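A quick round-trip sketch of the RLE helpers above on a random boolean mask; the shapes are arbitrary.

import torch
from model.segment_anything.utils.amg import mask_to_rle_pytorch, rle_to_mask

masks = torch.rand(1, 32, 32) > 0.5            # B x H x W boolean masks
rles = mask_to_rle_pytorch(masks)              # uncompressed RLE dicts, one per mask
recovered = rle_to_mask(rles[0])               # back to an H x W numpy bool array
print((recovered == masks[0].numpy()).all())   # True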
142
+
143
+
144
+ def area_from_rle(rle: Dict[str, Any]) -> int:
145
+ return sum(rle["counts"][1::2])
146
+
147
+
148
+ def calculate_stability_score(
149
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
150
+ ) -> torch.Tensor:
151
+ # One mask is always contained inside the other.
152
+ # Save memory by preventing unnecessary cast to torch.int64
153
+ intersections = (
154
+ (masks > (mask_threshold + threshold_offset))
155
+ .sum(-1, dtype=torch.int16)
156
+ .sum(-1, dtype=torch.int32)
157
+ )
158
+ unions = (
159
+ (masks > (mask_threshold - threshold_offset))
160
+ .sum(-1, dtype=torch.int16)
161
+ .sum(-1, dtype=torch.int32)
162
+ )
163
+ return intersections / unions
164
+
165
+
166
+ def build_point_grid(n_per_side: int) -> np.ndarray:
167
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
168
+ offset = 1 / (2 * n_per_side)
169
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
170
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
171
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
172
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
173
+ return points
174
+
175
+
176
+ def build_all_layer_point_grids(
177
+ n_per_side: int, n_layers: int, scale_per_layer: int
178
+ ) -> List[np.ndarray]:
179
+ """Generates point grids for all crop layers."""
180
+ points_by_layer = []
181
+ for i in range(n_layers + 1):
182
+ n_points = int(n_per_side / (scale_per_layer**i))
183
+ points_by_layer.append(build_point_grid(n_points))
184
+ return points_by_layer
185
+
186
+
187
+ def generate_crop_boxes(
188
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
189
+ ) -> Tuple[List[List[int]], List[int]]:
190
+ """
191
+ Generates a list of crop boxes of different sizes. Each layer
192
+ has (2**i)**2 boxes for the ith layer.
193
+ """
194
+ crop_boxes, layer_idxs = [], []
195
+ im_h, im_w = im_size
196
+ short_side = min(im_h, im_w)
197
+
198
+ # Original image
199
+ crop_boxes.append([0, 0, im_w, im_h])
200
+ layer_idxs.append(0)
201
+
202
+ def crop_len(orig_len, n_crops, overlap):
203
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
204
+
205
+ for i_layer in range(n_layers):
206
+ n_crops_per_side = 2 ** (i_layer + 1)
207
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
208
+
209
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
210
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
211
+
212
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
213
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
214
+
215
+ # Crops in XYXY format, clipped to the image bounds
216
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
217
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
218
+ crop_boxes.append(box)
219
+ layer_idxs.append(i_layer + 1)
220
+
221
+ return crop_boxes, layer_idxs
222
+
223
+
224
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
225
+ x0, y0, _, _ = crop_box
226
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
227
+ # Check if boxes has a channel dimension
228
+ if len(boxes.shape) == 3:
229
+ offset = offset.unsqueeze(1)
230
+ return boxes + offset
231
+
232
+
233
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
234
+ x0, y0, _, _ = crop_box
235
+ offset = torch.tensor([[x0, y0]], device=points.device)
236
+ # Check if points has a channel dimension
237
+ if len(points.shape) == 3:
238
+ offset = offset.unsqueeze(1)
239
+ return points + offset
240
+
241
+
242
+ def uncrop_masks(
243
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
244
+ ) -> torch.Tensor:
245
+ x0, y0, x1, y1 = crop_box
246
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
247
+ return masks
248
+ # Coordinate transform masks
249
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
250
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
251
+ return torch.nn.functional.pad(masks, pad, value=0)
252
+
253
+
254
+ def remove_small_regions(
255
+ mask: np.ndarray, area_thresh: float, mode: str
256
+ ) -> Tuple[np.ndarray, bool]:
257
+ """
258
+ Removes small disconnected regions and holes in a mask. Returns the
259
+ mask and an indicator of if the mask has been modified.
260
+ """
261
+ import cv2 # type: ignore
262
+
263
+ assert mode in ["holes", "islands"]
264
+ correct_holes = mode == "holes"
265
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
266
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
267
+ sizes = stats[:, -1][1:] # Row 0 is background label
268
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
269
+ if len(small_regions) == 0:
270
+ return mask, False
271
+ fill_labels = [0] + small_regions
272
+ if not correct_holes:
273
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
274
+ # If every region is below threshold, keep largest
275
+ if len(fill_labels) == 0:
276
+ fill_labels = [int(np.argmax(sizes)) + 1]
277
+ mask = np.isin(regions, fill_labels)
278
+ return mask, True
279
+
280
+
281
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
282
+ from pycocotools import mask as mask_utils # type: ignore
283
+
284
+ h, w = uncompressed_rle["size"]
285
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
286
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
287
+ return rle
288
+
289
+
290
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
291
+ # torch.max below raises an error on empty inputs, just skip in this case
292
+ if torch.numel(masks) == 0:
293
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
294
+
295
+ # Normalize shape to CxHxW
296
+ shape = masks.shape
297
+ h, w = shape[-2:]
298
+ if len(shape) > 2:
299
+ masks = masks.flatten(0, -3)
300
+ else:
301
+ masks = masks.unsqueeze(0)
302
+
303
+ # Get top and bottom edges
304
+ in_height, _ = torch.max(masks, dim=-1)
305
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
306
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
307
+ in_height_coords = in_height_coords + h * (~in_height)
308
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
309
+
310
+ # Get left and right edges
311
+ in_width, _ = torch.max(masks, dim=-2)
312
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
313
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
314
+ in_width_coords = in_width_coords + w * (~in_width)
315
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
316
+
317
+ # If the mask is empty the right edge will be to the left of the left edge.
318
+ # Replace these boxes with [0, 0, 0, 0]
319
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
320
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
321
+ out = out * (~empty_filter).unsqueeze(-1)
322
+
323
+ # Return to original shape
324
+ if len(shape) > 2:
325
+ out = out.reshape(*shape[:-2], 4)
326
+ else:
327
+ out = out[0]
328
+
329
+ return out
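
The helpers above produce an uncompressed, column-major RLE whose counts always begin with a (possibly empty) background run, so the foreground area is the sum of the odd-indexed counts. A minimal round-trip sketch, assuming this file is model/segment_anything/utils/amg.py (onnx.py below imports it as .amg); the toy mask is illustrative:

import torch
from model.segment_anything.utils.amg import (
    area_from_rle, mask_to_rle_pytorch, rle_to_mask,
)

# One 4x4 boolean mask containing a 2x2 foreground square.
masks = torch.zeros((1, 4, 4), dtype=torch.bool)
masks[0, 1:3, 1:3] = True

rles = mask_to_rle_pytorch(masks)    # list of uncompressed RLE dicts
recovered = rle_to_mask(rles[0])     # back to a (4, 4) boolean numpy array

assert int(recovered.sum()) == 4
assert area_from_rle(rles[0]) == 4   # foreground area from the run lengths

coco_encode_rle can then compress the same dictionary with pycocotools when a JSON-serializable encoding is needed.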
model/segment_anything/utils/onnx.py ADDED
@@ -0,0 +1,135 @@
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from typing import Tuple
8
+
9
+ from ..modeling import Sam
10
+ from .amg import calculate_stability_score
11
+
12
+
13
+ class SamOnnxModel(nn.Module):
14
+
15
+
16
+ def __init__(
17
+ self,
18
+ model: Sam,
19
+ return_single_mask: bool,
20
+ use_stability_score: bool = False,
21
+ return_extra_metrics: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.mask_decoder = model.mask_decoder
25
+ self.model = model
26
+ self.img_size = model.image_encoder.img_size
27
+ self.return_single_mask = return_single_mask
28
+ self.use_stability_score = use_stability_score
29
+ self.stability_score_offset = 1.0
30
+ self.return_extra_metrics = return_extra_metrics
31
+
32
+ @staticmethod
33
+ def resize_longest_image_size(
34
+ input_image_size: torch.Tensor, longest_side: int
35
+ ) -> torch.Tensor:
36
+ input_image_size = input_image_size.to(torch.float32)
37
+ scale = longest_side / torch.max(input_image_size)
38
+ transformed_size = scale * input_image_size
39
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
40
+ return transformed_size
41
+
42
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
43
+ point_coords = point_coords + 0.5
44
+ point_coords = point_coords / self.img_size
45
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
46
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
47
+
48
+ point_embedding = point_embedding * (point_labels != -1)
49
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
50
+ point_labels == -1
51
+ )
52
+
53
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
54
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
55
+ i
56
+ ].weight * (point_labels == i)
57
+
58
+ return point_embedding
59
+
60
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
61
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
62
+ mask_embedding = mask_embedding + (
63
+ 1 - has_mask_input
64
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
65
+ return mask_embedding
66
+
67
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
68
+ masks = F.interpolate(
69
+ masks,
70
+ size=(self.img_size, self.img_size),
71
+ mode="bilinear",
72
+ align_corners=False,
73
+ )
74
+
75
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
76
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
77
+
78
+ orig_im_size = orig_im_size.to(torch.int64)
79
+ h, w = orig_im_size[0], orig_im_size[1]
80
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
81
+ return masks
82
+
83
+ def select_masks(
84
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
85
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
86
+ # Determine if we should return the multiclick mask or not from the number of points.
87
+ # The reweighting is used to avoid control flow.
88
+ score_reweight = torch.tensor(
89
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
90
+ ).to(iou_preds.device)
91
+ score = iou_preds + (num_points - 2.5) * score_reweight
92
+ best_idx = torch.argmax(score, dim=1)
93
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
94
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
95
+
96
+ return masks, iou_preds
97
+
98
+ @torch.no_grad()
99
+ def forward(
100
+ self,
101
+ image_embeddings: torch.Tensor,
102
+ point_coords: torch.Tensor,
103
+ point_labels: torch.Tensor,
104
+ mask_input: torch.Tensor,
105
+ has_mask_input: torch.Tensor,
106
+ orig_im_size: torch.Tensor,
107
+ ):
108
+ sparse_embedding = self._embed_points(point_coords, point_labels)
109
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
110
+
111
+ masks, scores = self.model.mask_decoder.predict_masks(
112
+ image_embeddings=image_embeddings,
113
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
114
+ sparse_prompt_embeddings=sparse_embedding,
115
+ dense_prompt_embeddings=dense_embedding,
116
+ )
117
+
118
+ if self.use_stability_score:
119
+ scores = calculate_stability_score(
120
+ masks, self.model.mask_threshold, self.stability_score_offset
121
+ )
122
+
123
+ if self.return_single_mask:
124
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
125
+
126
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
127
+
128
+ if self.return_extra_metrics:
129
+ stability_scores = calculate_stability_score(
130
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
131
+ )
132
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
133
+ return upscaled_masks, scores, stability_scores, areas, masks
134
+
135
+ return upscaled_masks, scores, masks
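
SamOnnxModel wraps only the prompt encoder and mask decoder, so the image embedding is computed outside the exported graph and fed in as an input. A hedged export sketch (the checkpoint path, output filename, dummy shapes and opset are illustrative; they follow the forward signature above and the official SAM export script):

import torch
from model.segment_anything import sam_model_registry
from model.segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="/path/to/sam_vit_h.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size   # (64, 64) for ViT-H
mask_input_size = [4 * x for x in embed_size]           # low-res mask prompt size

dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1080, 1920], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_prompt_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    opset_version=17,
)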
model/segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,93 @@
1
+
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
7
+
8
+ from copy import deepcopy
9
+ from typing import Tuple
10
+
11
+
12
+ class ResizeLongestSide:
13
+
14
+ def __init__(self, target_length: int) -> None:
15
+ self.target_length = target_length
16
+
17
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
18
+ """
19
+ Expects a numpy array with shape HxWxC in uint8 format.
20
+ """
21
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
22
+ return np.array(resize(to_pil_image(image), target_size))
23
+
24
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
25
+ """
26
+ Expects a numpy array of length 2 in the final dimension. Requires the
27
+ original image size in (H, W) format.
28
+ """
29
+ old_h, old_w = original_size
30
+ new_h, new_w = self.get_preprocess_shape(
31
+ original_size[0], original_size[1], self.target_length
32
+ )
33
+ coords = deepcopy(coords).astype(float)
34
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
35
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
36
+ return coords
37
+
38
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
39
+ """
40
+ Expects a numpy array shape Bx4. Requires the original image size
41
+ in (H, W) format.
42
+ """
43
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
44
+ return boxes.reshape(-1, 4)
45
+
46
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Expects batched images with shape BxCxHxW and float format. This
49
+ transformation may not exactly match apply_image. apply_image is
50
+ the transformation expected by the model.
51
+ """
52
+ # Expects an image in BCHW format. May not exactly match apply_image.
53
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
54
+ return F.interpolate(
55
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
56
+ )
57
+
58
+ def apply_coords_torch(
59
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
60
+ ) -> torch.Tensor:
61
+ """
62
+ Expects a torch tensor with length 2 in the last dimension. Requires the
63
+ original image size in (H, W) format.
64
+ """
65
+ old_h, old_w = original_size
66
+ new_h, new_w = self.get_preprocess_shape(
67
+ original_size[0], original_size[1], self.target_length
68
+ )
69
+ coords = deepcopy(coords).to(torch.float)
70
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
71
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
72
+ return coords
73
+
74
+ def apply_boxes_torch(
75
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
76
+ ) -> torch.Tensor:
77
+ """
78
+ Expects a torch tensor with shape Bx4. Requires the original image
79
+ size in (H, W) format.
80
+ """
81
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
82
+ return boxes.reshape(-1, 4)
83
+
84
+ @staticmethod
85
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
86
+ """
87
+ Compute the output size given input size and target long side length.
88
+ """
89
+ scale = long_side_length * 1.0 / max(oldh, oldw)
90
+ newh, neww = oldh * scale, oldw * scale
91
+ neww = int(neww + 0.5)
92
+ newh = int(newh + 0.5)
93
+ return (newh, neww)
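
ResizeLongestSide scales an image so that its longest side equals target_length and applies the same factors to point and box prompts. A quick numeric check (the image size and point are made up):

import numpy as np
from model.segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

# A 768x1152 (H, W) image: the longest side (1152) is scaled to 1024.
print(transform.get_preprocess_shape(768, 1152, 1024))   # (683, 1024)

# Point prompts in (x, y) pixel coordinates are rescaled with the same factors.
coords = np.array([[576.0, 384.0]])
print(transform.apply_coords(coords, (768, 1152)))        # approx. [[512. 341.5]]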
requirements.txt ADDED
@@ -0,0 +1,112 @@
1
+ accelerate==1.11.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.2
4
+ aiosignal==1.4.0
5
+ annotated-types==0.7.0
6
+ anyio==4.11.0
7
+ async-timeout==5.0.1
8
+ attrs==25.4.0
9
+ av==16.0.1
10
+ bitsandbytes==0.48.1
11
+ certifi==2025.10.5
12
+ charset-normalizer==3.4.4
13
+ click==8.3.0
14
+ contourpy==1.3.2
15
+ cycler==0.12.1
16
+ datasets==4.3.0
17
+ deepspeed==0.18.2
18
+ dill==0.4.0
19
+ einops==0.8.1
20
+ exceptiongroup==1.3.0
21
+ filelock==3.20.0
22
+ flash_attn==2.8.3
23
+ fonttools==4.60.1
24
+ frozenlist==1.8.0
25
+ fsspec==2025.9.0
26
+ gitdb==4.0.12
27
+ GitPython==3.1.45
28
+ h11==0.16.0
29
+ hf-xet==1.2.0
30
+ hjson==3.1.0
31
+ httpcore==1.0.9
32
+ httpx==0.28.1
33
+ huggingface-hub==0.36.0
34
+ idna==3.11
35
+ imageio==2.37.0
36
+ Jinja2==3.1.6
37
+ kiwisolver==1.4.9
38
+ lazy_loader==0.4
39
+ MarkupSafe==3.0.3
40
+ matplotlib==3.10.7
41
+ mpmath==1.3.0
42
+ msgpack==1.1.2
43
+ multidict==6.7.0
44
+ multiprocess==0.70.16
45
+ networkx==3.4.2
46
+ ninja==1.13.0
47
+ numpy==2.2.6
48
+ nvidia-cublas-cu12==12.6.4.1
49
+ nvidia-cuda-cupti-cu12==12.6.80
50
+ nvidia-cuda-nvrtc-cu12==12.6.77
51
+ nvidia-cuda-runtime-cu12==12.6.77
52
+ nvidia-cudnn-cu12==9.10.2.21
53
+ nvidia-cufft-cu12==11.3.0.4
54
+ nvidia-cufile-cu12==1.11.1.6
55
+ nvidia-curand-cu12==10.3.7.77
56
+ nvidia-cusolver-cu12==11.7.1.2
57
+ nvidia-cusparse-cu12==12.5.4.2
58
+ nvidia-cusparselt-cu12==0.7.1
59
+ nvidia-ml-py==13.580.82
60
+ nvidia-nccl-cu12==2.27.5
61
+ nvidia-nvjitlink-cu12==12.6.85
62
+ nvidia-nvshmem-cu12==3.3.20
63
+ nvidia-nvtx-cu12==12.6.77
64
+ opencv-python==4.12.0.88
65
+ packaging==25.0
66
+ pandas==2.3.3
67
+ peft==0.17.1
68
+ pillow==12.0.0
69
+ platformdirs==4.5.0
70
+ propcache==0.4.1
71
+ protobuf==6.33.0
72
+ psutil==7.1.1
73
+ py-cpuinfo==9.0.0
74
+ pyarrow==22.0.0
75
+ pycocotools==2.0.10
76
+ pydantic==2.12.4
77
+ pydantic_core==2.41.5
78
+ pyparsing==3.2.5
79
+ python-dateutil==2.9.0.post0
80
+ pytz==2025.2
81
+ PyYAML==6.0.3
82
+ qwen-vl-utils==0.0.14
83
+ regex==2025.10.23
84
+ requests==2.32.5
85
+ safetensors==0.6.2
86
+ scikit-image==0.25.2
87
+ scipy==1.15.3
88
+ sentencepiece==0.2.1
89
+ sentry-sdk==2.45.0
90
+ shellingham==1.5.4
91
+ six==1.17.0
92
+ smmap==5.0.2
93
+ sniffio==1.3.1
94
+ sympy==1.14.0
95
+ tifffile==2025.5.10
96
+ tiktoken==0.12.0
97
+ tokenizers==0.22.1
98
+ torch==2.9.0+cu126
99
+ torchaudio==2.9.0+cu126
100
+ torchvision==0.24.0+cu126
101
+ tqdm==4.67.1
102
+ transformers==4.57.1
103
+ triton==3.5.0
104
+ trl==0.24.0
105
+ typer-slim==0.20.0
106
+ typing-inspection==0.4.2
107
+ typing_extensions==4.15.0
108
+ tzdata==2025.2
109
+ urllib3==2.5.0
110
+ wandb==0.23.0
111
+ xxhash==3.6.0
112
+ yarl==1.22.0
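
Note that the torch==2.9.0+cu126, torchvision==0.24.0+cu126 and torchaudio==2.9.0+cu126 pins are CUDA 12.6 builds that are not hosted on PyPI, so one way to satisfy this file is to add the PyTorch wheel index, e.g.

pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126

flash_attn typically also needs torch installed first, since its build compiles against it.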
run_seg_ref.py ADDED
@@ -0,0 +1,203 @@
1
+ import argparse
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ from model.segment_anything import SamPredictor, sam_model_registry
7
+ from eval.utils import compute_logits_from_mask, show_points, masks_sample_points
8
+ import cv2
9
+
10
+ import requests
11
+ from PIL import Image
12
+ from io import BytesIO
13
+ import re
14
+
15
+ from segment_predictor_cache import GenerativeSegmenter
16
+
17
+
18
+ def image_parser(args):
19
+ out = args.image_file.split(args.sep)
20
+ return out
21
+
22
+
23
+ def load_image(image_file):
24
+ if image_file.startswith("http") or image_file.startswith("https"):
26
+ response = requests.get(image_file)
27
+ image = Image.open(BytesIO(response.content)).convert("RGB")
28
+ else:
29
+ image = Image.open(image_file).convert("RGB")
30
+ return image
31
+
32
+
33
+ def load_images(image_files):
34
+ out = []
35
+ for image_file in image_files:
36
+ image = load_image(image_file)
37
+ out.append(image)
38
+ return out
39
+
40
+
41
+ def upsample_tensor_vectorized(a, s):
42
+ h, w = a.shape
43
+ sh, sw = int(h * s), int(w * s)
44
+ # Create an output tensor of zeros
45
+ result = torch.zeros((sh, sw), dtype=a.dtype, device=a.device)
46
+ # Calculate the target indices
47
+ offset = int(s / 2)
48
+ i_indices = torch.arange(h) * s + offset
49
+ j_indices = torch.arange(w) * s + offset
50
+ # Use broadcasting to fill the result tensor
51
+ result[i_indices[:, None].long(), j_indices.long()] = a
52
+ return result
53
+
54
+
55
+ def translate_sequence(sequence_str):
56
+ """
57
+ Translates a '|'-separated sequence of categorical data to numerical labels,
58
+ identifying categories from the sequence.
59
+
60
+ Parameters:
61
+ sequence_str (str): The '|'-separated sequence of categorical data.
62
+
63
+ Returns:
64
+ list: The sequence of numerical labels.
65
+ """
66
+ # Split the string into a list of categories
67
+ sequence = sequence_str.split('|')
68
+
69
+ # strip the whitespace from each category
70
+ sequence = [seq.strip() for seq in sequence]
71
+
72
+ # Identify unique categories from the sequence
73
+ unique_categories = list(dict.fromkeys(sequence))
74
+
75
+ # place "others" at the beginning of the list
76
+ if "others" in unique_categories:
77
+ unique_categories.remove("others")
78
+ unique_categories.insert(0, "others")
79
+
80
+ # Create a dictionary to map each category to a unique integer
81
+ category_to_label = {
82
+ category: idx
83
+ for idx, category in enumerate(unique_categories)
84
+ }
85
+
86
+ # Translate the sequence using the dictionary
87
+ translated_sequence = [category_to_label[item] for item in sequence]
88
+
89
+ return translated_sequence
90
+
91
+
92
+ def decode_mask(encoded_str):
93
+ rows = encoded_str.strip("\n").split("\n ")
94
+ decoded_list = []
95
+ for row in rows:
96
+ tokens = row.split("| ")
97
+ for token in tokens:
98
+ label, count = token.split(" *")
99
+ decoded_list.extend([label] * int(count))
100
+ return "|".join(decoded_list)
101
+
102
+
103
+ def run_model(args):
104
+ # Model
105
+
106
+ segmenter = GenerativeSegmenter(
107
+ args.model_path,
108
+ device_map="cuda",
109
+ min_pixels=1024 * 28 * 28,
110
+ max_pixels=1280 * 28 * 28
111
+ )
112
+ sam_post_process = True
113
+
114
+ sam = sam_model_registry["vit_h"](checkpoint=args.sam_path)
115
+ sam = sam.to(dtype=torch.float32, device='cuda')
116
+ predictor = SamPredictor(sam)
117
+
118
+ prompt_seg_single = args.query
119
+
120
+ image_files = image_parser(args)
121
+ images = load_images(image_files)
122
+ image = images[0]
123
+ w_ori, h_ori = image.size
124
+
125
+ with torch.inference_mode():
126
+
127
+
128
+ predictor.set_image(np.array(image))
129
+ segmentation_masks, response_text = segmenter.generate_with_segmentation(
130
+ image, prompt_seg_single
131
+ )
132
+
133
+
134
+ print("Last response text:")
135
+ print(response_text) # This will print the last iteration's response_text
136
+
137
+ if segmentation_masks is None or len(segmentation_masks) == 0:
138
+ print("No mask found.")
139
+ return
140
+
141
+ assert len(segmentation_masks) == 1
142
+
143
+ mask = segmentation_masks[0] # This will use the last iteration's mask
144
+
145
+ mask_pred = pred_mask = F.interpolate(
146
+ mask.unsqueeze(0).unsqueeze(0).double(),
147
+ size=(h_ori, w_ori),
148
+ mode='nearest'
149
+ ).squeeze(0).squeeze(0)
150
+
151
+ new_mask_pred = np.zeros((mask_pred.shape[0], mask_pred.shape[1]))
152
+ unique_classes = np.unique(mask_pred)
153
+
154
+ if sam_post_process:
155
+ unique_classes = torch.unique(pred_mask)
156
+ for class_id in unique_classes:
157
+ if class_id == 0:
158
+ continue
159
+ binary_mask = (pred_mask == class_id).double().cpu()
160
+ try:
161
+ logits = compute_logits_from_mask(pred_mask.cpu())
162
+ point_coords, point_labels = masks_sample_points(binary_mask)
163
+ sam_mask, _, logit = predictor.predict(
164
+ point_coords=point_coords,
165
+ point_labels=point_labels,
166
+ mask_input=logits,
167
+ multimask_output=False
168
+ )
169
+ for _ in range(2):
170
+ sam_mask, _, logit = predictor.predict(
171
+ point_coords=point_coords,
172
+ point_labels=point_labels,
173
+ mask_input=logit,
174
+ multimask_output=False
175
+ )
176
+ sam_mask = sam_mask[0].astype(np.float32)
177
+ except Exception as E:
178
+ print(f"Error: {E}")
179
+ sam_mask = np.zeros((h_ori, w_ori))
180
+ new_mask_pred = torch.from_numpy(sam_mask).to(pred_mask.device)
181
+ else:
182
+ new_mask_pred = mask_pred
183
+ new_mask_pred = new_mask_pred.unsqueeze(-1).repeat(1, 1, 3).numpy()
184
+
185
+ os.makedirs("STAMP/images", exist_ok=True)
186
+ image_path = image_files[0]  # name the saved mask after the actual input image
187
+ base_name = image_path.split("/")[-1].split(".")[0]
188
+ save_path = "{}/{}_mask.jpg".format(
189
+ "STAMP/images", base_name)
190
+ cv2.imwrite(save_path, new_mask_pred * 255)
191
+
192
+
193
+
194
+ if __name__ == "__main__":
195
+ parser = argparse.ArgumentParser()
196
+ parser.add_argument("--model-path", type=str, default="JiaZL/STAMP-2B-uni")
197
+ parser.add_argument("--image-file", type=str, default='STAMP/images/horses.png')
198
+ parser.add_argument("--sam_path", type=str, default='HCMUE-Research/SAM-vit-h')  # expects a local SAM ViT-H .pth checkpoint path
199
+ parser.add_argument("--query", type=str, default='Please segment the white horse in the image.')
200
+ parser.add_argument("--sep", type=str, default=",")
201
+ args = parser.parse_args()
202
+
203
+ run_model(args)
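
run_seg_ref.py is a single-image demo driven by the argparse defaults above. An illustrative invocation (note that sam_model_registry["vit_h"](checkpoint=...) loads a local .pth file, so --sam_path should point to a downloaded SAM ViT-H checkpoint rather than the Hub repo id used as the default):

python run_seg_ref.py \
    --model-path JiaZL/STAMP-2B-uni \
    --image-file STAMP/images/horses.png \
    --sam_path /path/to/sam_vit_h.pth \
    --query "Please segment the white horse in the image."

The SAM-refined mask is written to STAMP/images/<image-name>_mask.jpg.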
segment_predictor_cache.py ADDED
@@ -0,0 +1,212 @@
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import AutoProcessor, DynamicCache
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ from model.qwen_changes import get_rope_index, SegQwenVL
7
+ import os
8
+ import json
9
+ import time
10
+
11
+
12
+ def find_image_patch_info(image_pad_id, input_ids: torch.Tensor):
13
+ """
14
+ From the end to the beginning, find consecutive image_pad_id in the input tensor and return their count.
15
+
16
+ Parameters:
17
+ image_pad_id (int): The ID of the image padding token.
18
+ input_ids (torch.Tensor): The input tensor of IDs.
19
+
20
+ Returns:
21
+ int: The number of consecutive image patches.
22
+
23
+ Raises:
24
+ RuntimeError: If no image patches (<|image_pad|>) are found in input_ids.
25
+ """
26
+ input_ids_list = input_ids.squeeze().tolist()
27
+
28
+ # Reverse the list to search from the end to the beginning
29
+ reversed_input_ids_list = input_ids_list[::-1]
30
+
31
+ try:
32
+ # Find the first occurrence of image_pad_id in the reversed list
33
+ start_idx_rev = reversed_input_ids_list.index(image_pad_id)
34
+ end_idx_rev = start_idx_rev
35
+
36
+ # Continue to find consecutive image_pad_id
37
+ while end_idx_rev + 1 < len(reversed_input_ids_list) and reversed_input_ids_list[
38
+ end_idx_rev + 1] == image_pad_id:
39
+ end_idx_rev += 1
40
+
41
+ num_patches = (end_idx_rev - start_idx_rev) + 1
42
+ return num_patches
43
+ except ValueError:
44
+ raise RuntimeError("No image patches (<|image_pad|>) found in input_ids.")
45
+
46
+
47
+ class GenerativeSegmenter:
48
+ def __init__(self, model_path: str, min_pixels, max_pixels, **kwargs):
49
+ min_pixels = min_pixels
50
+ max_pixels = max_pixels
51
+ self.device = kwargs.get("device_map", "cuda" if torch.cuda.is_available() else "cpu")
52
+
53
+ # --- New intelligent loading logic ---
54
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
55
+
56
+ if os.path.exists(adapter_config_path):
57
+ print(f"Detected PEFT adapter configuration: {adapter_config_path}. Will load base model first, then load adapter.")
58
+ # Read the base model path from the adapter configuration
59
+ with open(adapter_config_path, 'r', encoding='utf-8') as f:
60
+ adapter_config = json.load(f)
61
+ # Base model path, if not present in the config, you need to specify it manually
62
+ base_model_path = adapter_config.get("base_model_name_or_path")
63
+ if not base_model_path:
64
+ # ********************************************************************************
65
+ # ** Important: If adapter_config.json does not contain base_model_name_or_path,
66
+ # ** please manually specify the correct base model name or path here
67
+ # ** Based on your previous error messages, the base model is likely "Qwen/Qwen2-VL-7B-Instruct"
68
+ # ********************************************************************************
69
+ base_model_path = "Qwen/Qwen2-VL-7B-Instruct"
70
+ print(f"Warning: 'base_model_name_or_path' not found in adapter configuration. Using default base model: '{base_model_path}'")
71
+ # 1. Load the base model
72
+ print(f"Loading base model from '{base_model_path}'...")
73
+ self.model = SegQwenVL.from_pretrained(
74
+ base_model_path,
75
+ torch_dtype="auto",
76
+ trust_remote_code=True,
77
+ # attn_implementation="flash_attention_2",
78
+ **kwargs
79
+ )
80
+ self.processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True,
81
+ min_pixels=min_pixels, max_pixels=max_pixels)
82
+ self.tokenizer = self.processor.tokenizer
83
+ self._add_special_tokens()
84
+ # 2. Load the adapter
85
+ print(f"Loading adapter from '{model_path}'...")
86
+ self.model.load_adapter(model_path)
87
+
88
+ else:
89
+ print(f"No PEFT adapter detected. Loading full model directly from '{model_path}'.")
90
+ # Keep the original direct loading method
91
+ self.model = SegQwenVL.from_pretrained(
92
+ model_path,
93
+ torch_dtype="auto",
94
+ trust_remote_code=True,
95
+ # attn_implementation="flash_attention_2",
96
+ **kwargs
97
+ )
98
+ self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, min_pixels=min_pixels,
99
+ max_pixels=max_pixels)
100
+ self.tokenizer = self.processor.tokenizer
101
+ self._add_special_tokens()
102
+ # --- Intelligent loading logic ends ---
103
+
104
+ TargetClass = type(self.model.model)
105
+ TargetClass.get_rope_index = get_rope_index
106
+
107
+ # Get key token IDs
108
+ self.yes_token_id = self.tokenizer.convert_tokens_to_ids("<|yes|>")
109
+ self.no_token_id = self.tokenizer.convert_tokens_to_ids("<|no|>")
110
+ self.seg_token_id = self.tokenizer.convert_tokens_to_ids("<|seg|>")
111
+ self.mask_token_id = self.tokenizer.convert_tokens_to_ids("<|mask|>")
112
+ self.image_pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
113
+ self.eos_token_id = self.tokenizer.eos_token_id
114
+ self.model.mask_token_id = self.mask_token_id
115
+
116
+ def _add_special_tokens(self):
117
+ special_tokens = {'additional_special_tokens': ["<|seg|>", "<|mask|>", "<|yes|>", "<|no|>"]}
118
+ num_added = self.tokenizer.add_special_tokens(special_tokens)
119
+ if num_added > 0:
120
+ print(f"Added {num_added} special tokens. Resizing model embedding layer...")
121
+ self.model.resize_token_embeddings(len(self.tokenizer))
122
+ # Check if the resized size matches your model's expectations
123
+ print(
124
+ f"Resized vocabulary size: {len(self.tokenizer)}, Model embedding layer size: {self.model.get_input_embeddings().weight.shape[0]}")
125
+ if self.tokenizer.pad_token_id is None:
126
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
127
+
128
+ @torch.no_grad()
129
+ def generate_with_segmentation(self, image: Image.Image, prompt: str):
130
+ messages = [{"role": "user", "content": [{"image": image}, {"text": prompt}]}]
131
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
132
+ inputs = self.processor(text=[text], images=[image], return_tensors="pt")
133
+ merge_size = self.processor.image_processor.merge_size
134
+
135
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
136
+ prompt_len = inputs['input_ids'].shape[1]
137
+ image_grid_thw = inputs.get('image_grid_thw').to(self.device) # Qwen2.5-VL may use this key
138
+ attention_mask_raw = inputs['attention_mask'].to(self.device)
139
+
140
+ outputs = self.model.generate(
141
+ **inputs,
142
+ max_new_tokens=1024,
143
+ use_cache=True,
144
+ return_dict_in_generate=True,
145
+ eos_token_id=self.eos_token_id,
146
+ pad_token_id=self.tokenizer.pad_token_id
147
+ )
148
+
149
+ sequence = outputs.sequences[0]
150
+ full_past_key_values = outputs.past_key_values
151
+
152
+ # Find all <seg> token positions
153
+ seg_indices = torch.where(sequence == self.seg_token_id)[0].tolist()
154
+
155
+ all_segmentation_masks = []
156
+ seg_forward_times = [] # Initialize list to store times
157
+ if not seg_indices: # If there are no segmentation tasks
158
+ generated_ids = sequence[prompt_len:]
159
+ response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
160
+ return None, response_text
161
+
162
+ num_patches = find_image_patch_info(self.image_pad_id, inputs['input_ids'])
163
+
164
+ # Iterate over each <seg> token and perform segmentation
165
+ for i, idx in enumerate(seg_indices):
166
+ sliced_len = idx + 1
167
+ attention_mask = attention_mask_raw[:, :sliced_len]
168
+ legacy_cache = full_past_key_values.to_legacy_cache()
169
+ # 2. Slice each tensor in the tuple
170
+ past_key_values_sliced = tuple(
171
+ (
172
+ key_layer[:, :, :sliced_len, :],
173
+ value_layer[:, :, :sliced_len, :]
174
+ )
175
+ for key_layer, value_layer in legacy_cache
176
+ )
177
+ past_key_values_sliced = DynamicCache.from_legacy_cache(past_key_values_sliced)
178
+
179
+ mask_query_ids = torch.full((1, num_patches), self.mask_token_id, dtype=torch.long, device=self.device)
180
+ mask_query_attention_mask = torch.ones((1, num_patches + sliced_len - attention_mask[0].sum()),
181
+ dtype=torch.long, device=self.device)
182
+ mask_query_attention_mask = torch.cat((attention_mask, mask_query_attention_mask), dim=1)
183
+ mask_grid_thw = image_grid_thw[-1].clone()
184
+ mask_grid_thw = mask_grid_thw.unsqueeze(0)
185
+
186
+ mask_pre_ids = sequence.clone().unsqueeze(0)
187
+ mask_ids = torch.cat([mask_pre_ids[0, :sliced_len], mask_query_ids[0]], dim=0).unsqueeze(0)
188
+
189
+ seg_forward_outputs = self.model(
190
+ input_ids=mask_ids,
191
+ attention_mask=mask_query_attention_mask,
192
+ image_grid_thw=image_grid_thw,
193
+ pixel_values=inputs['pixel_values'],
194
+ past_key_values=past_key_values_sliced,
195
+ return_dict=True,
196
+ do_classification=True
197
+ )
198
+
199
+ mask_logits = seg_forward_outputs.bi_logits[:, -num_patches:]
200
+
201
+ segmentation_preds = (mask_logits > 0).long().squeeze().cpu()
202
+ h_grid, w_grid = mask_grid_thw[0, 1:]
203
+ h_grid, w_grid = int(h_grid / merge_size), int(w_grid / merge_size)
204
+ segmentation_preds = segmentation_preds.view(h_grid, w_grid)
205
+ all_segmentation_masks.append(segmentation_preds)
206
+
207
+ generated_ids = sequence[prompt_len:]
208
+ response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
209
+
210
+
211
+ return all_segmentation_masks, response_text
212
+
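
A minimal usage sketch for GenerativeSegmenter outside the demo script (the model path, pixel budgets, image and prompt mirror the defaults in run_seg_ref.py; the returned masks are patch-level grids that still need upsampling to pixel resolution):

import torch.nn.functional as F
from PIL import Image

from segment_predictor_cache import GenerativeSegmenter

segmenter = GenerativeSegmenter(
    "JiaZL/STAMP-2B-uni",
    device_map="cuda",
    min_pixels=1024 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)

image = Image.open("STAMP/images/horses.png").convert("RGB")
masks, text = segmenter.generate_with_segmentation(
    image, "Please segment the white horse in the image."
)
print(text)

if masks:  # None when the response contains no <|seg|> token
    w, h = image.size
    # Each mask holds one value per merged vision patch; upsample to pixel space.
    full_res = F.interpolate(
        masks[0].unsqueeze(0).unsqueeze(0).double(), size=(h, w), mode="nearest"
    ).squeeze()
    print(full_res.shape)  # (h, w) tensor with values in {0, 1}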