import torch
import torch.nn as nn
from typing import Callable, Tuple
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).
    """
    protected = 0

    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    assert r > 0, r

    with torch.no_grad():
        # Cosine similarity: normalize, then split the tokens into two
        # alternating sets (even indices -> a, odd indices -> b).
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        # For each token in a, find its most similar token in b, then keep
        # the r highest-scoring edges as the tokens to merge.
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # unmerged tokens
        src_idx = edge_idx[..., :r, :]  # merged tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        # Note: scatter_add always sums, so `mode` is effectively ignored;
        # merge_wavg below relies on sum semantics for its weighted average.
        dst = dst.scatter_add(-2, dst_idx.expand(n, r, c), src)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        # Restore the original token count: dst tokens return to the odd
        # slots, unmerged and merged src tokens to the even slots.
        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge
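
# A minimal shape sketch (illustrative values, not from the original file):
# with t = 8 tokens the even/odd split gives 4 src and 4 dst tokens, so for
# r = 3 merge() maps [2, 8, 4] -> [2, 5, 4] (1 unmerged src + 4 dst tokens)
# and unmerge() maps the result back to [2, 8, 4]:
#
#   metric = torch.randn(2, 8, 4)
#   merge, unmerge = bipartite_soft_matching(metric, r=3)
#   merged = merge(torch.randn(2, 8, 4))   # -> [2, 5, 4]
#   restored = unmerge(merged)             # -> [2, 8, 4]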
					
						
def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies the merge function by taking a weighted average based on token size.
    Returns the merged tensor and the new token sizes.
    """
    if size is None:
        # Every token initially represents exactly one input token.
        size = torch.ones_like(x[..., 0, None])

    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")

    x = x / size
    return x, size
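
# Hypothetical invariant check (not in the original file): sizes always sum
# to the original token count, since merging only pools tokens together.
#
#   x = torch.randn(2, 8, 4)
#   merge, _ = bipartite_soft_matching(x, r=3)
#   x, size = merge_wavg(merge, x)
#   assert size.sum(dim=1).eq(8).all()   # the [2, 5, 1] sizes sum to 8 per batch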
					
						
class ToMe16_mlp_hd64(nn.Module):
    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        self.mm_hidden_size = config.mm_hidden_size
        self.hw = vision_cfg.image_size // vision_cfg.patch_size
        self.num_attention_heads = vision_cfg.num_attention_heads
        self.mlp = nn.Sequential(
            nn.Linear(config.mm_hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )
        self.max_pos_hw = self.hw
        self.max_pos_num_frames = config.mm_pos_num_frames
        self.num_image_patches_per_side = 8
        self.num_frame_patches_per_side = 4

    def merge_tokens(self, x, target_num_token):
        r"""
        x = torch.randn(10, 2560, c)
        x = self.merge_tokens(x, target_num_token=1280)
        """
        size = None
        b, p, c = x.shape
        tmp_p = p
        r_merge_list = []
        assert tmp_p > target_num_token, f"{tmp_p} should be greater than {target_num_token}"
        # Each bipartite pass can remove at most half of the tokens, so build
        # a schedule of r values that halves the count until the target fits.
        while tmp_p != target_num_token:
            if tmp_p - target_num_token <= (tmp_p // 2):
                r_merge_list.append(tmp_p - target_num_token)
                break
            else:
                r_merge_list.append(tmp_p // 2)
                tmp_p = tmp_p - (tmp_p // 2)

        head = self.num_attention_heads
        dim = c // head
        for r in r_merge_list:
            # Head-averaged features serve as the similarity metric.
            metric = x.reshape(b, p, head, dim).mean(2)
            merge, _ = bipartite_soft_matching(metric, r)
            x, size = merge_wavg(merge, x, size)
            _, p, _ = x.shape

        return x
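
    # Worked schedule example (illustrative numbers): reducing p = 2560 tokens
    # to target_num_token = 160 yields r_merge_list = [1280, 640, 320, 160],
    # i.e. 2560 -> 1280 -> 640 -> 320 -> 160 over four merge passes.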
					
						
    def forward(self, x, compress=False, local_num_frames=-1):
        height = width = self.hw
        assert height * width == x.shape[1]

        if local_num_frames != -1 and local_num_frames != 1:
            assert compress is True
        if compress:
            # Video path: flatten frames into one token sequence and keep a
            # budget of 16 tokens per frame.
            if local_num_frames != -1:
                num_frames = local_num_frames
                x = x.reshape(x.shape[0] // local_num_frames, -1, x.shape[-1])
            else:
                num_frames = x.shape[0]
                x = x.reshape(1, -1, x.shape[-1])
            num_tome_tokens = 16 * num_frames
        else:
            # Image path: keep 64 tokens per image.
            num_tome_tokens = 64

        x = self.merge_tokens(x, target_num_token=num_tome_tokens)
        x = self.mlp(x)
        return x
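
    # Shape sketch (illustrative, assuming a 336px / 14px-patch ViT, hw = 24):
    # an image batch [4, 576, C] -> [4, 64, hidden]; the same batch treated as
    # a 4-frame clip with compress=True, local_num_frames=4 is flattened to
    # [1, 2304, C] and merged down to [1, 16 * 4, hidden].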
					
						
    @property
    def config(self):
        return {"mm_projector_type": "tome16_mlp_hd64"}


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "tome16_mlp_hd64":
        return ToMe16_mlp_hd64(config, kwargs["vision_cfg"])

    raise ValueError(f"Unknown projector type: {projector_type}")
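
# A hypothetical smoke test (not part of the original module). The config
# fields and vision settings below are stand-in values; it checks the image
# path (64 tokens per image) and the compressed video path (16 tokens/frame).
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        mm_hidden_size=1024,
        hidden_size=2048,
        mm_pos_num_frames=8,
        mm_projector_type="tome16_mlp_hd64",
    )
    vision_cfg = SimpleNamespace(image_size=336, patch_size=14, num_attention_heads=16)

    proj = build_vision_projector(config, vision_cfg=vision_cfg)
    x = torch.randn(4, (336 // 14) ** 2, 1024)  # 4 frames of 24x24 = 576 patches

    out_img = proj(x)  # image path
    assert out_img.shape == (4, 64, 2048)

    out_vid = proj(x, compress=True, local_num_frames=4)  # video path
    assert out_vid.shape == (1, 16 * 4, 2048)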
					
						