ZJU-AI4H committed
Commit d8e2cf1 · verified · 1 Parent(s): cf8ef18

Upload folder using huggingface_hub

chat_template.json ADDED
@@ -0,0 +1,3 @@
+ {
+ "chat_template": "\n{%- set identifier = 'im' %}\n{% for message in messages %}\n {% if add_system_prompt and loop.first and message['role'] != 'system' %}\n {{- '<|im_start|>system\nYou are Hulu-Med, a helpful health assistant that can understand text, 2D images, videos, and 3D images.<|im_end|>\n' -}}\n {% endif %}\n {% if message['role'] == 'stream' %}\n {% set identifier = 'stream' %}\n {% else %}\n {% set identifier = 'im' %}\n {% endif %}\n {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}\n {% if message['content'] is string %}\n {{- message['content'] + '<|' + identifier + '_end|>\n' -}}\n {% else %}\n {% for content in message['content'] %}\n {% if content is string %}\n {{- content -}}\n {% elif content['type'] == 'text' or 'text' in content %}\n {{- content['text'] -}}\n {% elif content['type'] == 'image' or 'image' in content %}\n {% if 'timestamp' in content %}\n {{- 'Time ' + content['timestamp'] | round(1) | string + 's: ' -}}\n {% endif %}\n {{- image_token + '\n' -}}\n {% elif content['type'] == 'video' or 'video' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n {{- image_token + ',' -}}\n {% else %}\n {{- image_token + '\n' -}}\n {% endif %}\n {% endfor %}\n {% endif %}\n {% endfor %}\n {% if identifier == 'stream' %}\n {{- '<|' + identifier + '_end|>' -}}\n {% else %}\n {{- '<|' + identifier + '_end|>\n' -}}\n {% endif %}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' -}}\n{% endif %}\n"
+ }
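Note: a minimal sketch of rendering this chat template with the tokenizer, assuming a local clone of the repository. The local path, the example message, and the "<image>" placeholder string are illustrative assumptions; add_system_prompt and image_token are the custom variables referenced by the template above and are forwarded to the Jinja renderer as extra keyword arguments.

import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)  # assumed local clone
with open("chat_template.json") as f:
    chat_template = json.load(f)["chat_template"]

messages = [
    {"role": "user", "content": [
        {"type": "image"},  # rendered as image_token by the template
        {"type": "text", "text": "Describe this chest X-ray."},
    ]},
]

prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_template,      # use the template shown above explicitly
    tokenize=False,
    add_generation_prompt=True,
    add_system_prompt=True,           # custom template variable (see template)
    image_token="<image>",            # hypothetical placeholder; the processor defines the real token
)
print(prompt)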
config.json CHANGED
@@ -1,58 +1,118 @@
  {
- "architectures": [
- "HulumedQwen2ForCausalLM"
- ],
- "attention_bias": false,
- "attention_dropout": 0.0,
- "bos_token_id": 151643,
- "decoder_sparse_step": 1,
- "eos_token_id": 151645,
- "head_dim": 128,
- "hidden_act": "silu",
- "hidden_size": 5120,
- "image_aspect_ratio": "avt",
- "image_size": -1,
- "image_token_index": 151669,
- "image_token_length": 1,
- "initializer_range": 0.02,
- "intermediate_size": 17408,
- "is_alignment": false,
- "llm_lr": 5e-05,
- "max_frames": 180,
- "max_position_embeddings": 40960,
- "max_window_layers": 40,
- "mlp_only_layers": [],
- "mm_hidden_size": 1152,
- "mm_projector_lr": 1e-05,
- "mm_projector_type": "mlp2x_gelu",
- "mm_vision_encoder": "./SigLIP-NaViT",
- "mm_vision_select_feature": "patch",
- "mm_vision_select_layer": -1,
- "model_type": "hulumed_qwen2",
- "moe_intermediate_size": 768,
- "norm_topk_prob": false,
- "num_attention_heads": 40,
- "num_experts": 128,
- "num_experts_per_tok": 8,
- "num_hidden_layers": 40,
- "num_key_value_heads": 8,
- "output_router_logits": false,
- "rms_norm_eps": 1e-06,
- "rope_scaling": null,
- "rope_theta": 1000000,
- "router_aux_loss_coef": 0.001,
- "sliding_window": null,
- "tie_word_embeddings": false,
- "tokenizer_model_max_length": 16384,
- "tokenizer_padding_side": "right",
- "torch_dtype": "bfloat16",
- "transformers_version": "4.51.2",
- "use_cache": true,
- "use_flash_loss": false,
- "use_mm_proj": true,
- "use_sliding_window": false,
- "use_token_compression": false,
- "vision_encoder": "./SigLIP-NaViT",
- "vision_encoder_lr": 2e-06,
- "vocab_size": 151936
- }
+ "architectures": [
+ "HulumedQwen3ForCausalLM"
+ ],
+ "auto_map": {
+ "AutoConfig": "configuration_hulumed_qwen3.HulumedQwen3Config",
+ "AutoModelForCausalLM": "modeling_hulumed_qwen3.HulumedQwen3ForCausalLM"
+ },
+ "vocab_size": 151936,
+ "max_position_embeddings": 40960,
+ "hidden_size": 5120,
+ "intermediate_size": 17408,
+ "num_hidden_layers": 40,
+ "num_attention_heads": 40,
+ "use_sliding_window": false,
+ "sliding_window": null,
+ "max_window_layers": 40,
+ "num_key_value_heads": 8,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "initializer_range": 0.02,
+ "rms_norm_eps": 1e-06,
+ "use_cache": true,
+ "rope_theta": 1000000,
+ "rope_scaling": null,
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "return_dict": true,
+ "output_hidden_states": false,
+ "output_attentions": false,
+ "torchscript": false,
+ "torch_dtype": "bfloat16",
+ "use_bfloat16": false,
+ "tf_legacy_loss": false,
+ "pruned_heads": {},
+ "tie_word_embeddings": false,
+ "chunk_size_feed_forward": 0,
+ "is_encoder_decoder": false,
+ "is_decoder": false,
+ "cross_attention_hidden_size": null,
+ "add_cross_attention": false,
+ "tie_encoder_decoder": false,
+ "max_length": 20,
+ "min_length": 0,
+ "do_sample": false,
+ "early_stopping": false,
+ "num_beams": 1,
+ "num_beam_groups": 1,
+ "diversity_penalty": 0.0,
+ "temperature": 1.0,
+ "top_k": 50,
+ "top_p": 1.0,
+ "typical_p": 1.0,
+ "repetition_penalty": 1.0,
+ "length_penalty": 1.0,
+ "no_repeat_ngram_size": 0,
+ "encoder_no_repeat_ngram_size": 0,
+ "bad_words_ids": null,
+ "num_return_sequences": 1,
+ "output_scores": false,
+ "return_dict_in_generate": false,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "remove_invalid_values": false,
+ "exponential_decay_length_penalty": null,
+ "suppress_tokens": null,
+ "begin_suppress_tokens": null,
+ "finetuning_task": null,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "tokenizer_class": null,
+ "prefix": null,
+ "bos_token_id": 151643,
+ "pad_token_id": null,
+ "eos_token_id": 151645,
+ "sep_token_id": null,
+ "decoder_start_token_id": null,
+ "task_specific_params": null,
+ "problem_type": null,
+ "_attn_implementation_autoset": true,
+ "transformers_version": "4.51.2",
+ "decoder_sparse_step": 1,
+ "image_aspect_ratio": "avt",
+ "image_size": -1,
+ "image_token_index": 151669,
+ "image_token_length": 1,
+ "is_alignment": false,
+ "llm_lr": 5e-05,
+ "max_frames": 180,
+ "mlp_only_layers": [],
+ "mm_hidden_size": 1152,
+ "mm_projector_lr": 1e-05,
+ "mm_projector_type": "mlp2x_gelu",
+ "model_type": "hulumed_qwen3",
+ "moe_intermediate_size": 768,
+ "norm_topk_prob": false,
+ "num_experts": 128,
+ "num_experts_per_tok": 8,
+ "output_router_logits": false,
+ "router_aux_loss_coef": 0.001,
+ "tokenizer_model_max_length": 16384,
+ "tokenizer_padding_side": "right",
+ "use_token_compression": false,
+ "vision_encoder_config": {
+ "hidden_size": 1152,
+ "intermediate_size": 4304,
+ "model_type": "hulumed_vision_encoder",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "patch_size": 14
+ }
+ }
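Note: a hedged sketch of loading the updated configuration. The "./" path assumes a local clone of this repository; trust_remote_code lets AutoConfig follow the auto_map entry above to configuration_hulumed_qwen3.HulumedQwen3Config.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("./", trust_remote_code=True)  # assumed local clone
print(config.model_type)                        # hulumed_qwen3
print(config.vision_encoder_config.patch_size)  # 14, from the nested vision_encoder_config above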
configuration.json ADDED
@@ -0,0 +1 @@
+ {"framework": "pytorch", "task": "text-generation", "allow_remote": true}
configuration_hulumed_encoder.py ADDED
@@ -0,0 +1,49 @@
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
+ # Below is the original copyright:
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """HuluMed vision encoder model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class HulumedVisionEncoderConfig(PretrainedConfig):
+
+     model_type = "hulumed_vision_encoder"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_channels=3,
+         patch_size=16,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
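Note: a hedged sketch of instantiating this config directly with the values that config.json stores under vision_encoder_config; the import path assumes the file above is on the Python path (e.g. run from a local clone).

from configuration_hulumed_encoder import HulumedVisionEncoderConfig

# Values mirror the nested vision_encoder_config block in config.json above.
vision_cfg = HulumedVisionEncoderConfig(
    hidden_size=1152,
    intermediate_size=4304,
    num_attention_heads=16,
    num_hidden_layers=27,
    patch_size=14,
)
print(vision_cfg.model_type)  # hulumed_vision_encoder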
configuration_hulumed_qwen3.py ADDED
@@ -0,0 +1,83 @@
+ """HuluMed model configuration."""
+
+ import importlib.util
+ import os.path as osp
+ from typing import Optional, Dict, Any
+
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, Qwen3Config
+
+ try:
+     from .configuration_hulumed_encoder import HulumedVisionEncoderConfig
+ except ModuleNotFoundError:
+     spec = importlib.util.spec_from_file_location(
+         "configuration_hulumed_encoder",
+         osp.join(osp.dirname(__file__), "configuration_hulumed_encoder.py"),
+     )
+     configuration_hulumed_encoder = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(configuration_hulumed_encoder)
+     HulumedVisionEncoderConfig = getattr(
+         configuration_hulumed_encoder,
+         "HulumedVisionEncoderConfig",
+     )
+
+ try:
+     from .modeling_hulumed_encoder import HulumedVisionEncoderModel
+ except ModuleNotFoundError:
+     spec = importlib.util.spec_from_file_location(
+         "modeling_hulumed_encoder",
+         osp.join(osp.dirname(__file__), "modeling_hulumed_encoder.py"),
+     )
+     modeling_hulumed_encoder = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(modeling_hulumed_encoder)
+     HulumedVisionEncoderModel = getattr(
+         modeling_hulumed_encoder,
+         "HulumedVisionEncoderModel",
+     )
+
+ AutoConfig.register("hulumed_vision_encoder", HulumedVisionEncoderConfig)
+ AutoModel.register(HulumedVisionEncoderConfig, HulumedVisionEncoderModel)
+
+
+ class HulumedQwen3Config(Qwen3Config):
+     """
+     HuluMed model configuration.
+
+     This configuration class extends Qwen3Config to store the configuration of a HuluMed model.
+     It includes configuration for the vision encoder and multimodal projector.
+     """
+
+     model_type = "hulumed_qwen3"
+     sub_configs = {"vision_encoder_config": HulumedVisionEncoderConfig}
+
+     def __init__(
+         self,
+         vision_encoder: Optional[str] = None,
+         vision_encoder_config: Dict[str, Any] = {},
+         mm_projector_type: str = "mlp2x_gelu",
+         use_token_compression: bool = True,
+         image_token_index: int = -1,
+         **kwargs,
+     ):
+         """
+         Initialize HuluMed configuration.
+
+         Args:
+             vision_encoder (str, optional): Path or identifier of the vision encoder.
+             vision_encoder_config (dict, optional): Configuration for the vision encoder.
+             mm_projector_type (str): Type of multimodal projector. Default is "mlp2x_gelu".
+             use_token_compression (bool): Whether to use token compression for videos. Default is True.
+             image_token_index (int): Token index for image placeholders. Default is -1.
+             **kwargs: Additional arguments passed to Qwen3Config.
+         """
+         super().__init__(**kwargs)
+         self.model_type = "hulumed_qwen3"
+
+         self.vision_encoder = vision_encoder
+
+         if vision_encoder_config is not None and not isinstance(vision_encoder_config, PretrainedConfig):
+             vision_encoder_config = HulumedVisionEncoderConfig(**vision_encoder_config)
+
+         self.vision_encoder_config = vision_encoder_config
+         self.mm_projector_type = mm_projector_type
+         self.use_token_compression = use_token_compression
+         self.image_token_index = image_token_index
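Note: a hedged sketch of how a plain vision_encoder_config dict is promoted to a HulumedVisionEncoderConfig by the __init__ above; the values shown are illustrative, and the import assumes the files above sit in the working directory.

from configuration_hulumed_qwen3 import HulumedQwen3Config

cfg = HulumedQwen3Config(
    vision_encoder_config={"hidden_size": 1152, "patch_size": 14},  # plain dict on purpose
    use_token_compression=False,
)
print(type(cfg.vision_encoder_config).__name__)  # HulumedVisionEncoderConfig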
generation_config.json CHANGED
@@ -6,8 +6,9 @@
   151643
  ],
  "pad_token_id": 151643,
- "temperature": 0.6,
+ "repetition_penalty": 1.05,
+ "temperature": 0.7,
  "top_k": 20,
- "top_p": 0.95,
+ "top_p": 0.8,
  "transformers_version": "4.51.2"
  }
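Note: a hedged text-only generation sketch using the sampling defaults introduced above. The local path and the prompt are illustrative assumptions; generate() would pick these settings up from generation_config.json automatically, they are spelled out here only for clarity.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./"  # assumed local clone
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("What does an elevated troponin level suggest?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,        # as in the updated generation_config.json
    top_p=0.8,
    top_k=20,
    repetition_penalty=1.05,
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))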
image_processing_hulumed.py ADDED
@@ -0,0 +1,481 @@
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
2
+ # Below is the original copyright:
3
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """Image processor class for HuluMed."""
22
+
23
+ import math
24
+ from typing import Dict, List, Optional, Union
25
+
26
+ import numpy as np
27
+
28
+ import torch
29
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
30
+ from transformers.image_utils import ImageInput
31
+ from transformers.image_transforms import (
32
+ convert_to_rgb,
33
+ resize,
34
+ to_channel_dimension_format,
35
+ )
36
+ from transformers.image_utils import (
37
+ OPENAI_CLIP_MEAN,
38
+ OPENAI_CLIP_STD,
39
+ ChannelDimension,
40
+ ImageInput,
41
+ PILImageResampling,
42
+ VideoInput,
43
+ get_image_size,
44
+ infer_channel_dimension_format,
45
+ is_scaled_image,
46
+ is_valid_image,
47
+ make_list_of_images,
48
+ to_numpy_array,
49
+ )
50
+ from transformers.utils import TensorType, is_vision_available, logging
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ if is_vision_available():
57
+ from PIL import Image
58
+
59
+
60
+ def is_valid_video(video) -> bool:
61
+ if isinstance(video, (list, tuple)):
62
+ return all(is_valid_image(frame) for frame in video)
63
+ elif isinstance(video, np.ndarray):
64
+ return video.ndim == 4
65
+ elif isinstance(video, torch.Tensor):
66
+ return video.ndim == 4
67
+ return False
68
+
69
+
70
+ def make_batched_images(images) -> List[List[ImageInput]]:
71
+ """
72
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
73
+
74
+ Args:
75
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
76
+ The input image.
77
+
78
+ Returns:
79
+ list: A list of images.
80
+ """
81
+ if isinstance(images, (list, tuple)):
82
+ # list of images/videos
83
+ if not all(is_valid_video(image) or is_valid_image(image) for image in images):
84
+ raise ValueError(f"Could not make batched images from {images}")
85
+ return images
86
+ elif is_valid_video(images) or is_valid_image(images):
87
+ # single image/video
88
+ return [images]
89
+
90
+ raise ValueError(f"Could not make batched images from {images}")
91
+
92
+
93
+ def simple_batched_resize(
94
+ images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
95
+ ):
96
+ min_pixels = min_tokens * factor * factor
97
+ max_pixels = max_tokens * factor * factor
98
+
99
+ num_images = 0
100
+ for image in images:
101
+ if is_valid_video(image):
102
+ num_images += len(image)
103
+ else:
104
+ num_images += 1
105
+
106
+ image_sizes = []
107
+ for image in images:
108
+ if is_valid_video(image):
109
+ image = image[0]
110
+ if isinstance(image, Image.Image):
111
+ height, width = image.size
112
+ else:
113
+ height, width = get_image_size(image, channel_dim=input_data_format)
114
+ image_sizes.append([height, width])
115
+
116
+ tmp_image_sizes = []
117
+ for height, width in image_sizes:
118
+ h_bar = round(height / factor) * factor
119
+ w_bar = round(width / factor) * factor
120
+ if h_bar * w_bar > (max_pixels // num_images):
121
+ beta = math.sqrt((height * width) / (max_pixels // num_images))
122
+ h_bar = math.floor(height / beta / factor) * factor
123
+ w_bar = math.floor(width / beta / factor) * factor
124
+ # per image min_pixels
125
+ if h_bar * w_bar < min_pixels:
126
+ beta = math.sqrt(min_pixels / (height * width))
127
+ h_bar = math.ceil(height * beta / factor) * factor
128
+ w_bar = math.ceil(width * beta / factor) * factor
129
+ tmp_image_sizes.append((h_bar, w_bar))
130
+ image_sizes = tmp_image_sizes
131
+ return image_sizes
132
+
133
+
134
+ def batched_resize(
135
+ images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
136
+ ):
137
+ image_sizes = []
138
+ for image in images:
139
+ if is_valid_video(image):
140
+ num_frame = len(image)
141
+ image = image[0]
142
+ else:
143
+ num_frame = 1
144
+ if isinstance(image, Image.Image):
145
+ height, width = image.size
146
+ else:
147
+ height, width = get_image_size(image, channel_dim=input_data_format)
148
+ image_sizes.append([num_frame, height, width])
149
+
150
+ # global max_pixels
151
+ smart_scale_factors = 1.0
152
+ total_tokens = 0
153
+ for (num_frame, height, width), factor in zip(image_sizes, factors):
154
+ total_tokens += num_frame * math.ceil(height / factor) * math.ceil(width / factor)
155
+
156
+ # TODO: add min_pixels
157
+ if total_tokens > max_tokens:
158
+ beta = math.sqrt(total_tokens / max_tokens)
159
+ tmp_image_sizes = []
160
+ for (_, height, width), factor in zip(image_sizes, factors):
161
+ h_bar = math.floor(height / beta / factor) * factor
162
+ w_bar = math.floor(width / beta / factor) * factor
163
+ tmp_image_sizes.append((h_bar, w_bar))
164
+ image_sizes = tmp_image_sizes
165
+ else:
166
+ tmp_image_sizes = []
167
+ for (_, height, width), factor in zip(image_sizes, factors):
168
+ height = round(height / factor) * factor
169
+ width = round(width / factor) * factor
170
+ tmp_image_sizes.append((height, width))
171
+ image_sizes = tmp_image_sizes
172
+
173
+ return image_sizes
174
+
175
+
176
+ class HulumedImageProcessor(BaseImageProcessor):
177
+ r"""
178
+ Constructs a HuluMed image processor that dynamically resizes images based on the original images.
179
+
180
+ Args:
181
+ do_resize (`bool`, *optional*, defaults to `True`):
182
+ Whether to resize the image's (height, width) dimensions.
183
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
184
+ Resampling filter to use when resizing the image.
185
+ do_rescale (`bool`, *optional*, defaults to `True`):
186
+ Whether to rescale the image by the specified scale `rescale_factor`.
187
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
188
+ Scale factor to use if rescaling the image.
189
+ do_normalize (`bool`, *optional*, defaults to `True`):
190
+ Whether to normalize the image.
191
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
192
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
193
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
194
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
195
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
196
+ Whether to convert the image to RGB.
197
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
198
+ The min pixels of the image to resize the image.
199
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
200
+ The max pixels of the image to resize the image.
201
+ patch_size (`int`, *optional*, defaults to 14):
202
+ The spatial patch size of the vision encoder.
203
+ merge_size (`int`, *optional*, defaults to `None`):
204
+ The default merge size for processing. If None, no default merge size is applied.
205
+ """
206
+
207
+ model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
208
+
209
+ def __init__(
210
+ self,
211
+ do_resize: bool = True,
212
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
213
+ do_rescale: bool = True,
214
+ rescale_factor: Union[int, float] = 1 / 255,
215
+ do_normalize: bool = True,
216
+ image_mean: Optional[Union[float, List[float]]] = None,
217
+ image_std: Optional[Union[float, List[float]]] = None,
218
+ do_convert_rgb: bool = True,
219
+ min_tokens: int = 4 * 4,
220
+ max_tokens: int = 16384,
221
+ patch_size: int = 14,
222
+ merge_size: Optional[int] = None,
223
+ **kwargs,
224
+ ) -> None:
225
+ super().__init__(**kwargs)
226
+ self.do_resize = do_resize
227
+ self.resample = resample
228
+ self.do_rescale = do_rescale
229
+ self.rescale_factor = rescale_factor
230
+ self.do_normalize = do_normalize
231
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
232
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
233
+ self.min_tokens = min_tokens
234
+ self.max_tokens = max_tokens
235
+ self.patch_size = patch_size
236
+ self.do_convert_rgb = do_convert_rgb
237
+ self.merge_size = merge_size # Added: default merge_size
238
+
239
+ def _preprocess(
240
+ self,
241
+ images: Union[ImageInput, VideoInput],
242
+ target_size: List[int],
243
+ merge_size: int = 1,
244
+ do_resize: bool = None,
245
+ resample: PILImageResampling = None,
246
+ do_rescale: bool = None,
247
+ rescale_factor: float = None,
248
+ do_normalize: bool = None,
249
+ image_mean: Optional[Union[float, List[float]]] = None,
250
+ image_std: Optional[Union[float, List[float]]] = None,
251
+ do_convert_rgb: bool = None,
252
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
253
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
254
+ ):
255
+ """
256
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
257
+
258
+ Args:
259
+ images (`ImageInput`):
260
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
261
+ target_size (`List[int]`):
262
+ The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
263
+ merge_size (`int`, *optional*, defaults to `1`):
264
+ The merge size after the vision encoder.
265
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
266
+ Whether to resize the image.
267
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
268
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
269
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
270
+ Whether to rescale the image.
271
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
272
+ Scale factor to use if rescaling the image.
273
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
274
+ Whether to normalize the image.
275
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
276
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
277
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
278
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
279
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
280
+ Whether to convert the image to RGB.
281
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
282
+ The channel dimension format for the output image. Can be one of:
283
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
284
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
285
+ - Unset: Use the channel dimension format of the input image.
286
+ input_data_format (`ChannelDimension` or `str`, *optional*):
287
+ The channel dimension format for the input image. Can be one of:
288
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
289
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
290
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
291
+ """
292
+ images = make_list_of_images(images)
293
+
294
+ if do_convert_rgb:
295
+ images = [convert_to_rgb(image) for image in images]
296
+
297
+ # All transformations expect numpy arrays.
298
+ images = [to_numpy_array(image) for image in images]
299
+
300
+ if is_scaled_image(images[0]) and do_rescale:
301
+ logger.warning_once(
302
+ "It looks like you are trying to rescale already rescaled images. If the input"
303
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
304
+ )
305
+ if input_data_format is None:
306
+ # We assume that all images have the same channel dimension format.
307
+ input_data_format = infer_channel_dimension_format(images[0])
308
+
309
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
310
+ resized_height, resized_width = height, width
311
+ processed_images = []
312
+ for image in images:
313
+ if do_resize:
314
+ resized_height, resized_width = target_size
315
+ image = resize(
316
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
317
+ )
318
+
319
+ if do_rescale:
320
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
321
+
322
+ if do_normalize:
323
+ image = self.normalize(
324
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
325
+ )
326
+
327
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
328
+ processed_images.append(image)
329
+
330
+ patches = np.array(processed_images)
331
+ if data_format == ChannelDimension.LAST:
332
+ patches = patches.transpose(0, 3, 1, 2)
333
+ t = patches.shape[0]
334
+ channel = patches.shape[1]
335
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
336
+ patches = patches.reshape(
337
+ t,
338
+ channel,
339
+ grid_h // merge_size,
340
+ merge_size,
341
+ self.patch_size,
342
+ grid_w // merge_size,
343
+ merge_size,
344
+ self.patch_size,
345
+ )
346
+ patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
347
+ flatten_patches = patches.reshape(
348
+ t * grid_h * grid_w, channel * self.patch_size * self.patch_size
349
+ )
350
+
351
+ return flatten_patches, (t, grid_h, grid_w)
352
+
353
+ def preprocess(
354
+ self,
355
+ images: ImageInput,
356
+ do_resize: bool = None,
357
+ resample: PILImageResampling = None,
358
+ do_rescale: bool = None,
359
+ rescale_factor: float = None,
360
+ do_normalize: bool = None,
361
+ image_mean: Optional[Union[float, List[float]]] = None,
362
+ image_std: Optional[Union[float, List[float]]] = None,
363
+ do_convert_rgb: bool = None,
364
+ merge_size: Optional[Union[int, List[int]]] = None,
365
+ return_tensors: Optional[Union[str, TensorType]] = None,
366
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
367
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
368
+ ):
369
+ """
370
+ Args:
371
+ images (`ImageInput`):
372
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
373
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
374
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
375
+ Whether to resize the image.
376
+ resample (`int`, *optional*, defaults to `self.resample`):
377
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
378
+ has an effect if `do_resize` is set to `True`.
379
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
380
+ Whether to rescale the image.
381
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
382
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
383
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
384
+ Whether to normalize the image.
385
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
386
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
387
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
388
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
389
+ `True`.
390
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
391
+ Whether to convert the image to RGB.
392
+ merge_size (`int` or `List[int]`, *optional*, defaults to `self.merge_size`):
393
+ The merge size for processing. Can be a single value or a list of values (one per image).
394
+ return_tensors (`str` or `TensorType`, *optional*):
395
+ The type of tensors to return. Can be one of:
396
+ - Unset: Return a list of `np.ndarray`.
397
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
398
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
399
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
400
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
401
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
402
+ The channel dimension format for the output image. Can be one of:
403
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
404
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
405
+ - Unset: Use the channel dimension format of the input image.
406
+ input_data_format (`ChannelDimension` or `str`, *optional*):
407
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
408
+ from the input image. Can be one of:
409
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
410
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
411
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
412
+
413
+ """
414
+ do_resize = do_resize if do_resize is not None else self.do_resize
415
+ resample = resample if resample is not None else self.resample
416
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
417
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
418
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
419
+ image_mean = image_mean if image_mean is not None else self.image_mean
420
+ image_std = image_std if image_std is not None else self.image_std
421
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
422
+
423
+ # Handle merge_size: use provided value, or fall back to instance default, or use 1
424
+ if merge_size is None:
425
+ merge_size = self.merge_size if self.merge_size is not None else 1
426
+
427
+ images = make_batched_images(images)
428
+
429
+ if isinstance(merge_size, (list, tuple)):
430
+ assert len(merge_size) == len(images), "Merge size must be the same length as images."
431
+ merge_sizes = merge_size
432
+ else:
433
+ merge_sizes = [merge_size for _ in images]
434
+ if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
435
+ target_sizes = simple_batched_resize(
436
+ images,
437
+ factor=self.patch_size * merge_sizes[0],
438
+ min_tokens=self.min_tokens,
439
+ max_tokens=self.max_tokens,
440
+ input_data_format=input_data_format,
441
+ )
442
+ else:
443
+ target_sizes = batched_resize(
444
+ images,
445
+ factors=[self.patch_size * merge_size for merge_size in merge_sizes],
446
+ min_tokens=self.min_tokens,
447
+ max_tokens=self.max_tokens,
448
+ input_data_format=input_data_format,
449
+ )
450
+
451
+ pixel_values, grid_sizes = [], []
452
+ for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
453
+ patches, grid_size = self._preprocess(
454
+ image,
455
+ target_size=target_size,
456
+ merge_size=merge_size,
457
+ do_resize=do_resize,
458
+ resample=resample,
459
+ do_rescale=do_rescale,
460
+ rescale_factor=rescale_factor,
461
+ do_normalize=do_normalize,
462
+ image_mean=image_mean,
463
+ image_std=image_std,
464
+ data_format=data_format,
465
+ do_convert_rgb=do_convert_rgb,
466
+ input_data_format=input_data_format,
467
+ )
468
+ pixel_values.append(patches)
469
+ grid_sizes.append(grid_size)
470
+
471
+ pixel_values = np.concatenate(pixel_values, axis=0)
472
+ grid_sizes = np.array(grid_sizes)
473
+ merge_sizes = np.array(merge_sizes)
474
+
475
+ data = {
476
+ "pixel_values": pixel_values,
477
+ "grid_sizes": grid_sizes,
478
+ "merge_sizes": merge_sizes,
479
+ }
480
+
481
+ return BatchFeature(data=data, tensor_type=return_tensors)
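Note: a hedged usage sketch for the image processor defined above. The input file name is a hypothetical example, and the import assumes the module is on the Python path; pixel_values are flattened patches, grid_sizes are (t, h, w) per input, and merge_sizes hold the per-input merge factor.

from PIL import Image
from image_processing_hulumed import HulumedImageProcessor

processor = HulumedImageProcessor(patch_size=14, merge_size=1)
image = Image.open("chest_xray.png")  # hypothetical input image

batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape, batch["grid_sizes"], batch["merge_sizes"])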
model-00001-of-00007.safetensors → model-00001-of-00006.safetensors RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:0739e62c2f60dd8a641dad171929acd1e100d373c66cdf27404cdd76d2b85671
- size 4984780784
+ oid sha256:851ec07ebcf50eb48ada241b5fdba017a75b3b327256844ca65fc7f4073837a0
+ size 5341296816
model-00002-of-00007.safetensors → model-00002-of-00006.safetensors RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ee847dd850c5ba59e24b1795d91764aeef9c84e92040d6d8598fc15f6b15b873
- size 4980892048
+ oid sha256:97680c87cfe9290fd89ebf8a72aa4d284b0aaec55674fba644fe71600b7062cf
+ size 5285001096
model-00003-of-00007.safetensors → model-00003-of-00006.safetensors RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:59d7c3f708b91163dda35b0800969137253838cf67975f9fde148a54d41dd5b2
- size 4928485104
+ oid sha256:f6a520cf82ff8dd49062b758541e3bde486fdd76a3d2db181e9b7776463b24a4
+ size 5285001144
model-00004-of-00007.safetensors → model-00004-of-00006.safetensors RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:98b1b3d44fdda704f811c5eedb923afab16341660d8ee1f8b8aed5b66ecdd033
- size 4980892112
+ oid sha256:ae6fea0a295f020e83aaa8d7d2eb53cef4f4dca9f9ba189aa55a4dce6254cdd0
+ size 5285001144
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d25ef567536d22a0d153b301d044944974dd594b6b124b6dffab7547fbc6a06e
+ size 5285001144
model-00005-of-00007.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:85d22e011701f7fb6ff51d0c453abf6b248786c04d7a101b537c13f32e44310c
- size 4928485104
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a8dc222611730fc1672e0ebe7ffd28af977721849209dd3df7fdb26f2feafbf
+ size 3943963256
model-00006-of-00007.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:33012c7d219507ff15784abbc8c431084383636704d177d3c7ba48e2644671a1
- size 4065911432
model-00007-of-00007.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:dc3972be4a080cd08d4bf691716a852a35ef8284ef4fdf082a5bdcb88141d80d
- size 1555824768
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
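Since the index diff is not rendered here, below is a hedged sketch of inspecting the shard layout locally; the path assumes a local clone, and the metadata/weight_map layout is the standard safetensors index format.

import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# Count how many tensors map to each of the six shards listed above.
for shard, n in sorted(Counter(index["weight_map"].values()).items()):
    print(shard, n)
print(index["metadata"]["total_size"])  # total parameter bytes across shards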
 
modeling_hulumed_encoder.py ADDED
@@ -0,0 +1,534 @@
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
2
+ # Below is the original copyright:
3
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """PyTorch HuluMed vision encoder model."""
22
+
23
+ import importlib.util
24
+ import os.path as osp
25
+ import math
26
+ import warnings
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ from torch.nn.init import _calculate_fan_in_and_fan_out
33
+
34
+ from transformers.activations import ACT2FN
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import is_flash_attn_2_available
37
+
38
+ if is_flash_attn_2_available():
39
+ from flash_attn import flash_attn_varlen_func
40
+ else:
41
+ flash_attn_varlen_func = None
42
+
43
+ try:
44
+ from .configuration_hulumed_encoder import HulumedVisionEncoderConfig
45
+ except ImportError:
46
+ spec = importlib.util.spec_from_file_location(
47
+ "configuration_hulumed_encoder",
48
+ osp.join(osp.dirname(__file__), "configuration_hulumed_encoder.py"),
49
+ )
50
+ configuration_hulumed_encoder = importlib.util.module_from_spec(spec)
51
+ spec.loader.exec_module(configuration_hulumed_encoder)
52
+ HulumedVisionEncoderConfig = getattr(
53
+ configuration_hulumed_encoder,
54
+ "HulumedVisionEncoderConfig",
55
+ )
56
+
57
+
58
+ def _trunc_normal_(tensor, mean, std, a, b):
59
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
60
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
61
+ def norm_cdf(x):
62
+ # Computes standard normal cumulative distribution function
63
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
64
+
65
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
66
+ warnings.warn(
67
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
68
+ "The distribution of values may be incorrect.",
69
+ stacklevel=2,
70
+ )
71
+
72
+ # Values are generated by using a truncated uniform distribution and
73
+ # then using the inverse CDF for the normal distribution.
74
+ # Get upper and lower cdf values
75
+ l = norm_cdf((a - mean) / std)
76
+ u = norm_cdf((b - mean) / std)
77
+
78
+ # Uniformly fill tensor with values from [l, u], then translate to
79
+ # [2l-1, 2u-1].
80
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
81
+
82
+ # Use inverse cdf transform for normal distribution to get truncated
83
+ # standard normal
84
+ tensor.erfinv_()
85
+
86
+ # Transform to proper mean, std
87
+ tensor.mul_(std * math.sqrt(2.0))
88
+ tensor.add_(mean)
89
+
90
+ # Clamp to ensure it's in the proper range
91
+ tensor.clamp_(min=a, max=b)
92
+
93
+
94
+ def trunc_normal_tf_(
95
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
96
+ ) -> torch.Tensor:
97
+ """Fills the input Tensor with values drawn from a truncated
98
+ normal distribution. The values are effectively drawn from the
99
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
100
+ with values outside :math:`[a, b]` redrawn until they are within
101
+ the bounds. The method used for generating the random values works
102
+ best when :math:`a \\leq \text{mean} \\leq b`.
103
+
104
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
105
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
106
+ and the result is subsequently scaled and shifted by the mean and std args.
107
+
108
+ Args:
109
+ tensor: an n-dimensional `torch.Tensor`
110
+ mean: the mean of the normal distribution
111
+ std: the standard deviation of the normal distribution
112
+ a: the minimum cutoff value
113
+ b: the maximum cutoff value
114
+ """
115
+ with torch.no_grad():
116
+ _trunc_normal_(tensor, 0, 1.0, a, b)
117
+ tensor.mul_(std).add_(mean)
118
+
119
+
120
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
121
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
122
+ if mode == "fan_in":
123
+ denom = fan_in
124
+ elif mode == "fan_out":
125
+ denom = fan_out
126
+ elif mode == "fan_avg":
127
+ denom = (fan_in + fan_out) / 2
128
+
129
+ variance = scale / denom
130
+
131
+ if distribution == "truncated_normal":
132
+ # constant is stddev of standard normal truncated to (-2, 2)
133
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
134
+ elif distribution == "normal":
135
+ with torch.no_grad():
136
+ tensor.normal_(std=math.sqrt(variance))
137
+ elif distribution == "uniform":
138
+ bound = math.sqrt(3 * variance)
139
+ with torch.no_grad():
140
+ tensor.uniform_(-bound, bound)
141
+ else:
142
+ raise ValueError(f"invalid distribution {distribution}")
143
+
144
+
145
+ def lecun_normal_(tensor):
146
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
147
+
148
+
149
+ def default_flax_embed_init(tensor):
150
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
162
+ orig_dtype = tensor.dtype
163
+ tensor = tensor.float()
164
+ cos = freqs.cos()
165
+ sin = freqs.sin()
166
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
167
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
168
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
169
+ output = output.to(orig_dtype)
170
+ return output
171
+
172
+
173
+ class VisionRotaryEmbedding(nn.Module):
174
+
175
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
176
+ super().__init__()
177
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
178
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
179
+
180
+ def forward(self, seqlen: int) -> torch.Tensor:
181
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
182
+ freqs = torch.outer(seq, self.inv_freq)
183
+ return freqs
184
+
185
+
186
+ class HulumedVisionEmbeddings(nn.Module):
187
+
188
+ def __init__(self, config: HulumedVisionEncoderConfig):
189
+ super().__init__()
190
+ self.config = config
191
+ self.embed_dim = config.hidden_size
192
+ self.patch_size = config.patch_size
193
+
194
+ self.patch_embedding = nn.Conv2d(
195
+ in_channels=config.num_channels,
196
+ out_channels=self.embed_dim,
197
+ kernel_size=self.patch_size,
198
+ stride=self.patch_size,
199
+ padding="valid",
200
+ )
201
+
202
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
203
+ hidden_states = hidden_states.view(
204
+ -1, self.config.num_channels, self.patch_size, self.patch_size
205
+ )
206
+ patch_embeds = self.patch_embedding(hidden_states) # shape = [*, width, grid, grid]
207
+ # embeddings = patch_embeds.flatten(2).transpose(1, 2)
208
+ embeddings = patch_embeds.view(-1, self.embed_dim)
209
+
210
+ return embeddings
211
+
212
+
213
+ class VisionAttention(nn.Module):
214
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
215
+
216
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
217
+ def __init__(self, config):
218
+ super().__init__()
219
+ self.config = config
220
+ self.embed_dim = config.hidden_size
221
+ self.num_heads = config.num_attention_heads
222
+ self.head_dim = self.embed_dim // self.num_heads
223
+ if self.head_dim * self.num_heads != self.embed_dim:
224
+ raise ValueError(
225
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
226
+ f" {self.num_heads})."
227
+ )
228
+ self.scale = self.head_dim**-0.5
229
+ self.dropout = config.attention_dropout
230
+
231
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
232
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
233
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
234
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states: torch.Tensor,
239
+ cu_seqlens: torch.Tensor,
240
+ rotary_pos_emb: torch.Tensor = None,
241
+ ) -> torch.Tensor:
242
+ """Input shape: Time x Channel"""
243
+
244
+ q_len, _ = hidden_states.size()
245
+
246
+ query_states = self.q_proj(hidden_states)
247
+ key_states = self.k_proj(hidden_states)
248
+ value_states = self.v_proj(hidden_states)
249
+
250
+ query_states = query_states.view(q_len, self.num_heads, self.head_dim)
251
+ key_states = key_states.view(q_len, self.num_heads, self.head_dim)
252
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
253
+
254
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
255
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
256
+
257
+ attention_mask = torch.zeros([1, q_len, q_len], device=query_states.device, dtype=torch.bool)
258
+ for i in range(1, len(cu_seqlens)):
259
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
260
+
261
+ query_states = query_states.transpose(0, 1)
262
+ key_states = key_states.transpose(0, 1)
263
+ value_states = value_states.transpose(0, 1)
264
+
265
+ attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
266
+ attn_weights = attn_weights + attention_mask
267
+
268
+ # upcast attention to fp32
269
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
270
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
271
+ attn_output = torch.matmul(attn_weights, value_states)
272
+
273
+ attn_output = attn_output.transpose(0, 1)
274
+ attn_output = attn_output.reshape(q_len, -1)
275
+ attn_output = self.out_proj(attn_output)
276
+
277
+ return attn_output
278
+
279
+
280
+ class VisionFlashAttention2(VisionAttention):
281
+
282
+ def __init__(self, *args, **kwargs):
283
+ super().__init__(*args, **kwargs)
284
+
285
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
286
+ def forward(
287
+ self,
288
+ hidden_states: torch.Tensor,
289
+ cu_seqlens: torch.Tensor,
290
+ rotary_pos_emb: torch.Tensor = None,
291
+ ) -> torch.Tensor:
292
+ q_len, _ = hidden_states.size()
293
+
294
+ query_states = self.q_proj(hidden_states)
295
+ key_states = self.k_proj(hidden_states)
296
+ value_states = self.v_proj(hidden_states)
297
+
298
+ # Flash attention requires the input to have the shape
299
+ # batch_size x seq_length x head_dim x hidden_dim
300
+ # therefore we just need to keep the original shape
301
+ query_states = query_states.view(q_len, self.num_heads, self.head_dim)
302
+ key_states = key_states.view(q_len, self.num_heads, self.head_dim)
303
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
304
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
305
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
306
+
307
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
308
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
309
+ q_len, -1
310
+ )
311
+ attn_output = self.out_proj(attn_output)
312
+
313
+ return attn_output
314
+
315
+
316
+ class VisionSdpaAttention(VisionAttention):
317
+
318
+ def forward(
319
+ self,
320
+ hidden_states: torch.Tensor,
321
+ cu_seqlens: torch.Tensor,
322
+ rotary_pos_emb: torch.Tensor = None,
323
+ ) -> torch.Tensor:
324
+ seq_length = hidden_states.shape[0]
325
+ query_states = self.q_proj(hidden_states)
326
+ key_states = self.k_proj(hidden_states)
327
+ value_states = self.v_proj(hidden_states)
328
+
329
+ query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
330
+ key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
331
+ value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
332
+
333
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
334
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
335
+
336
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
337
+ for i in range(1, len(cu_seqlens)):
338
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
339
+
340
+ query_states = query_states.transpose(0, 1)
341
+ key_states = key_states.transpose(0, 1)
342
+ value_states = value_states.transpose(0, 1)
343
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
344
+ attn_output = attn_output.transpose(0, 1)
345
+ attn_output = attn_output.reshape(seq_length, -1)
346
+ attn_output = self.out_proj(attn_output)
347
+ return attn_output
348
+
349
+
350
+ VISION_ATTENTION_CLASSES = {
351
+ "eager": VisionAttention,
352
+ "flash_attention_2": VisionFlashAttention2,
353
+ "sdpa": VisionSdpaAttention,
354
+ }
355
+
356
+
357
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Hulumed
358
+ class HulumedVisionMLP(nn.Module):
359
+
360
+ def __init__(self, config):
361
+ super().__init__()
362
+ self.config = config
363
+ self.activation_fn = ACT2FN[config.hidden_act]
364
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
365
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
366
+
367
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
368
+ hidden_states = self.fc1(hidden_states)
369
+ hidden_states = self.activation_fn(hidden_states)
370
+ hidden_states = self.fc2(hidden_states)
371
+ return hidden_states
372
+
373
+
374
+ class HulumedVisionEncoderLayer(nn.Module):
375
+
376
+ def __init__(self, config: HulumedVisionEncoderConfig):
377
+ super().__init__()
378
+ self.embed_dim = config.hidden_size
379
+ self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
380
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
381
+ self.mlp = HulumedVisionMLP(config)
382
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
383
+
384
+ # Ignore copy
385
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
386
+ hidden_states = hidden_states + self.self_attn(
387
+ self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
388
+ )
389
+ hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
390
+ return hidden_states
391
+
392
+
393
+ class HulumedVisionTransformerEncoder(nn.Module):
394
+
395
+ def __init__(self, config: HulumedVisionEncoderConfig):
396
+ super().__init__()
397
+ self.config = config
398
+ head_dim = config.hidden_size // config.num_attention_heads
399
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
400
+ self.layers = nn.ModuleList([HulumedVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
401
+ self.gradient_checkpointing = False
402
+
403
+ def rot_pos_emb(self, grid_sizes, merge_sizes):
404
+ pos_ids = []
405
+ for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
406
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
407
+ hpos_ids = hpos_ids.reshape(
408
+ h // merge_size,
409
+ merge_size,
410
+ w // merge_size,
411
+ merge_size,
412
+ )
413
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
414
+ hpos_ids = hpos_ids.flatten()
415
+
416
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
417
+ wpos_ids = wpos_ids.reshape(
418
+ h // merge_size,
419
+ merge_size,
420
+ w // merge_size,
421
+ merge_size,
422
+ )
423
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
424
+ wpos_ids = wpos_ids.flatten()
425
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
426
+
427
+ pos_ids = torch.cat(pos_ids, dim=0)
428
+ max_grid_size = grid_sizes[:, 1:].max()
429
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
430
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
431
+
432
+ return rotary_pos_emb
433
+
434
+ def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
435
+ rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
436
+
437
+ cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
438
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
439
+
440
+ for blk in self.layers:
441
+ if self.gradient_checkpointing and self.training:
442
+ hidden_states = self._gradient_checkpointing_func(
443
+ blk.__call__,
444
+ hidden_states,
445
+ cu_seqlens,
446
+ rotary_pos_emb
447
+ )
448
+ else:
449
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
450
+
451
+ return hidden_states
452
+
453
+
454
+ class HulumedVisionEncoderModel(PreTrainedModel):
455
+
456
+ config_class = HulumedVisionEncoderConfig
457
+ base_model_prefix = "hulumed"
458
+ main_input_name = "pixel_values"
459
+ supports_gradient_checkpointing = True
460
+ _no_split_modules = [
461
+ "HulumedVisionEncoderLayer",
462
+ "HulumedVisionEmbeddings",
463
+ ]
464
+ _supports_flash_attn_2 = True
465
+ _supports_sdpa = True
466
+
467
+ def __init__(self, config: HulumedVisionEncoderConfig):
468
+ super().__init__(config=config)
469
+ embed_dim = config.hidden_size
470
+
471
+ self.embeddings = HulumedVisionEmbeddings(config)
472
+ self.encoder = HulumedVisionTransformerEncoder(config)
473
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
474
+
475
+ self.post_init()
476
+
477
+ def forward(self, pixel_values, grid_sizes, merge_sizes=None) -> torch.Tensor:
478
+ hidden_states = self.embeddings(pixel_values)
479
+ hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
480
+ hidden_states = self.post_layernorm(hidden_states)
481
+
482
+ hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
483
+ outputs = []
484
+
485
+ for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
486
+ # NOTE: previous implementation, which supports downsampling with any factor
487
+ c = hidden_states.shape[-1]
488
+ hidden_states = hidden_states.view(
489
+ grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
490
+ ).permute(0, 1, 3, 2, 4, 5)
491
+ hidden_states = hidden_states.reshape(
492
+ grid_size[0], grid_size[1], grid_size[2], c
493
+ ).permute(0, 3, 1, 2)
494
+ hidden_states = torch.nn.functional.interpolate(
495
+ hidden_states,
496
+ size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
497
+ mode='bilinear'
498
+ )
499
+ hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)
500
+
501
+ # NOTE: simplified implementation, which only supports downsampling with integer factor
502
+ # NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2, though floating-point rounding may cause slightly different results
503
+ # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
504
+ # hidden_states = hidden_states.mean(dim=1)
505
+
506
+ outputs.append(hidden_states)
507
+
508
+ return torch.cat(outputs, dim=0)
509
+
510
+ def _init_weights(self, module):
511
+ """Initialize the weights"""
512
+ if isinstance(module, nn.Embedding):
513
+ default_flax_embed_init(module.weight)
514
+ elif isinstance(module, VisionAttention):
515
+ nn.init.xavier_uniform_(module.q_proj.weight)
516
+ nn.init.xavier_uniform_(module.k_proj.weight)
517
+ nn.init.xavier_uniform_(module.v_proj.weight)
518
+ nn.init.xavier_uniform_(module.out_proj.weight)
519
+ nn.init.zeros_(module.q_proj.bias)
520
+ nn.init.zeros_(module.k_proj.bias)
521
+ nn.init.zeros_(module.v_proj.bias)
522
+ nn.init.zeros_(module.out_proj.bias)
523
+ elif isinstance(module, HulumedVisionMLP):
524
+ nn.init.xavier_uniform_(module.fc1.weight)
525
+ nn.init.xavier_uniform_(module.fc2.weight)
526
+ nn.init.normal_(module.fc1.bias, std=1e-6)
527
+ nn.init.normal_(module.fc2.bias, std=1e-6)
528
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
529
+ lecun_normal_(module.weight)
530
+ if module.bias is not None:
531
+ nn.init.zeros_(module.bias)
532
+ elif isinstance(module, nn.LayerNorm):
533
+ module.bias.data.zero_()
534
+ module.weight.data.fill_(1.0)
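
A quick illustration (not part of the diff) of the patch-merging step in HulumedVisionEncoderModel.forward above: features arrive grouped by merge window, are rearranged back into a (t, h, w) grid, and are then bilinearly downsampled by merge_size along each spatial axis. The hidden size and grid shape below are placeholders.

# Minimal sketch of the downsampling performed per chunk in the forward pass above.
import torch
import torch.nn.functional as F

def downsample_chunk(features, grid_size, merge_size):
    t, h, w = grid_size
    c = features.shape[-1]
    # undo the merge-window patch ordering, back to a plain (t, h, w, c) grid
    x = features.view(t, h // merge_size, w // merge_size, merge_size, merge_size, c)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(t, h, w, c).permute(0, 3, 1, 2)
    # bilinear downsampling by merge_size per spatial axis
    x = F.interpolate(x, size=(h // merge_size, w // merge_size), mode="bilinear")
    return x.permute(0, 2, 3, 1).reshape(-1, c)

# a 28x28 patch grid merged 2x2 yields 14 * 14 = 196 output tokens
tokens = downsample_chunk(torch.randn(1 * 28 * 28, 1152), (1, 28, 28), 2)
print(tokens.shape)  # torch.Size([196, 1152])
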
modeling_hulumed_qwen3.py ADDED
@@ -0,0 +1,524 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch HuluMed model."""
16
+
17
+ import importlib.util
18
+ import os.path as osp
19
+ import re
20
+ from abc import ABC, abstractmethod
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.utils.checkpoint
26
+ from transformers import (AutoConfig, AutoModelForCausalLM, Qwen3Config, AutoModel,
27
+ Qwen3ForCausalLM, Qwen3Model)
28
+ from transformers.generation.utils import GenerateOutput
29
+ from transformers.modeling_outputs import CausalLMOutputWithPast
30
+
31
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
32
+ WORKER_HEART_BEAT_INTERVAL = 15
33
+
34
+ LOGDIR = "."
35
+
36
+ # Model Constants
37
+ IGNORE_INDEX = -100
38
+
39
+ # Image arguments
40
+ IMAGE_TOKEN_INDEX = -200
41
+ DEFAULT_IMAGE_TOKEN = "<image>"
42
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
43
+ DEFAULT_IM_START_TOKEN = "<im_start>"
44
+ DEFAULT_IM_END_TOKEN = "<im_end>"
45
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
46
+
47
+ # Video arguments
48
+ VIDEO_TOKEN_INDEX = -201
49
+ DEFAULT_VIDEO_TOKEN = "<video>"
50
+ NUM_FRAMES = 128
51
+ MAX_FRAMES = 768
52
+ NUM_FRAMES_PER_SECOND = 1
53
+
54
+ # Audio arguments
55
+ AUDIO_TOKEN_INDEX = -202
56
+ DEFAULT_AUDIO_TOKEN = "<audio>"
57
+
58
+ # Stream arguments
59
+ STREAM_START_TOKEN = "<|stream_start|>"
60
+ STREAM_END_TOKEN = "<|stream_end|>"
61
+ STREAM_MAX_FRAMES = 400
62
+
63
+ MODAL_INDEX_MAP = {
64
+ "<image>": -200,
65
+ "<video>": -201,
66
+ "<audio>": -202,
67
+ }
68
+
69
+ subimage_token_num = 196
70
+ try:
71
+ from .configuration_hulumed_qwen3 import HulumedQwen3Config
72
+ except ModuleNotFoundError:
73
+ spec = importlib.util.spec_from_file_location(
74
+ "configuration_hulumed_qwen3",
75
+ osp.join(osp.dirname(__file__), "configuration_hulumed_qwen3.py"),
76
+ )
77
+ configuration_hulumed_qwen3 = importlib.util.module_from_spec(spec)
78
+ spec.loader.exec_module(configuration_hulumed_qwen3)
79
+ HulumedQwen3Config = getattr(
80
+ configuration_hulumed_qwen3,
81
+ "HulumedQwen3Config",
82
+ )
83
+
84
+
85
+ def build_mlp(depth, hidden_size, output_hidden_size):
86
+ """Build MLP layers for projection."""
87
+ modules = [nn.Linear(hidden_size, output_hidden_size)]
88
+ for _ in range(1, depth):
89
+ modules.append(nn.GELU())
90
+ modules.append(nn.Linear(output_hidden_size, output_hidden_size))
91
+ return nn.Sequential(*modules)
92
+
93
+
94
+ def build_vision_projector(config, delay_load=False, **kwargs):
95
+ """Build vision projector based on config."""
96
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
97
+
98
+ if projector_type == "linear":
99
+ return nn.Linear(config.vision_encoder_config.hidden_size, config.hidden_size)
100
+ elif projector_type.startswith("mlp"):
101
+ return MlpGeluProjector(config, projector_type)
102
+ else:
103
+ raise ValueError(f'Unknown projector type: {projector_type}')
104
+
105
+
106
+ class MlpGeluProjector(nn.Module):
107
+ """MLP projector with GELU activation."""
108
+
109
+ def __init__(self, config, projector_type):
110
+ super().__init__()
111
+
112
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
113
+ if mlp_gelu_match is None:
114
+ raise ValueError(f"Invalid projector type format: {projector_type}")
115
+ mlp_depth = int(mlp_gelu_match.group(1))
116
+
117
+ self.readout = build_mlp(
118
+ mlp_depth,
119
+ config.vision_encoder_config.hidden_size,
120
+ config.hidden_size
121
+ )
122
+
123
+ def forward(self, x):
124
+ return self.readout(x)
125
+
126
+
127
+ class HulumedMetaModel:
128
+ """Meta model for HuluMed that handles vision encoder initialization."""
129
+
130
+ def __init__(self, config):
131
+ super(HulumedMetaModel, self).__init__(config)
132
+ print('config.vision_encoder',config.vision_encoder)
133
+ if config.vision_encoder is not None:
134
+ # Load from pretrained path
135
+ print('Load from pretrained path')
136
+ self.vision_encoder = AutoModel.from_pretrained(
137
+ config.vision_encoder,
138
+ attn_implementation=self.config._attn_implementation,
139
+ torch_dtype=self.dtype,
140
+ )
141
+ self.config.vision_encoder_config = self.vision_encoder.config
142
+ self.config.vision_encoder = None
143
+ elif config.vision_encoder_config is not None:
144
+ # Build from config
145
+ print('Build from config')
146
+ self.vision_encoder = AutoModel.from_config(
147
+ self.config.vision_encoder_config,
148
+ attn_implementation=self.config._attn_implementation,
149
+ torch_dtype=self.dtype,
150
+ )
151
+ else:
152
+ raise ValueError("Vision encoder is not provided in config")
153
+
154
+ self.mm_projector = build_vision_projector(config)
155
+
156
+ def get_vision_encoder(self):
157
+ return self.vision_encoder
158
+
159
+ def get_mm_projector(self):
160
+ return self.mm_projector
161
+
162
+
163
+ class HulumedQwen3Model(HulumedMetaModel, Qwen3Model):
164
+
165
+ config_class = HulumedQwen3Config
166
+
167
+ def __init__(self, config: HulumedQwen3Config):
168
+ super(HulumedQwen3Model, self).__init__(config)
169
+
170
+
171
+ class HulumedMetaForCausalLM(ABC):
172
+ """Meta class for HuluMed Causal LM with multimodal support."""
173
+
174
+ @abstractmethod
175
+ def get_model(self):
176
+ pass
177
+
178
+ def get_vision_encoder(self):
179
+ return self.get_model().get_vision_encoder()
180
+
181
+ def get_mm_projector(self):
182
+ return self.get_model().get_mm_projector()
183
+
184
+ def encode_images(
185
+ self,
186
+ pixel_values: torch.FloatTensor,
187
+ grid_sizes: torch.LongTensor,
188
+ merge_sizes: torch.LongTensor,
189
+ ) -> torch.FloatTensor:
190
+ """Encode images using vision encoder and projector."""
191
+ mm_features = self.get_model().get_vision_encoder()(
192
+ pixel_values=pixel_values,
193
+ grid_sizes=grid_sizes,
194
+ merge_sizes=merge_sizes,
195
+ )
196
+ mm_features = self.get_model().mm_projector(mm_features)
197
+ return mm_features
198
+
199
+ def _get_valid_visual_tokens(
200
+ self,
201
+ mm_features: torch.FloatTensor,
202
+ batched_num_patches: torch.LongTensor,
203
+ modals: List[str],
204
+ ):
205
+ """Filter out text-only samples and keep only valid visual tokens."""
206
+ valid_masks = []
207
+ for num_patches, modal in zip(batched_num_patches, modals):
208
+ valid_mask = torch.full(
209
+ (num_patches,),
210
+ modal != "text",
211
+ dtype=torch.bool,
212
+ device=mm_features.device
213
+ )
214
+ valid_masks.append(valid_mask)
215
+ mm_features = mm_features[torch.cat(valid_masks)]
216
+ return mm_features
217
+
218
+ def _maybe_truncate_visual_tokens(
219
+ self,
220
+ mm_features: torch.FloatTensor,
221
+ compression_mask: torch.BoolTensor,
222
+ batched_num_patches: torch.LongTensor,
223
+ modals: List[str],
224
+ input_ids: torch.LongTensor,
225
+ position_ids: Optional[torch.LongTensor] = None,
226
+ ):
227
+ """Truncate visual tokens if necessary based on position_ids."""
228
+ if position_ids is None or mm_features.shape[0] == input_ids.eq(self.config.image_token_index).sum():
229
+ return mm_features, compression_mask
230
+
231
+ truncation_mask = []
232
+ for num_patches, modal in zip(batched_num_patches, modals):
233
+ if modal == "text":
234
+ truncation_mask.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))
235
+ else:
236
+ truncation_mask.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))
237
+
238
+ seq_end_indices = torch.nonzero(position_ids == 0)[:, 0]
239
+ seq_end_indices = seq_end_indices[seq_end_indices > 0].tolist() + [len(input_ids)]
240
+ seq_start_indices = [0] + seq_end_indices[:-1]
241
+ num_visual_tokens = [
242
+ input_ids[start:end].eq(self.config.image_token_index).sum()
243
+ for start, end in zip(seq_start_indices, seq_end_indices)
244
+ ]
245
+
246
+ for n, mask in zip(num_visual_tokens, truncation_mask):
247
+ if len(mask) > 0:
248
+ mask[n:] = False
249
+ truncation_mask = torch.cat(truncation_mask)
250
+
251
+ return mm_features[truncation_mask], compression_mask[truncation_mask]
252
+
253
+ def _get_compression_mask(
254
+ self,
255
+ pixel_values: torch.FloatTensor,
256
+ batched_num_patches: torch.LongTensor,
257
+ grid_sizes: torch.LongTensor,
258
+ merge_sizes: torch.LongTensor,
259
+ modals: List[str],
260
+ threshold: float = 0.1,
261
+ min_tokens: int = 1,
262
+ ) -> torch.BoolTensor:
263
+ """Get compression mask for video tokens based on frame differences."""
264
+ batched_images = pixel_values.split(grid_sizes.prod(dim=1).tolist(), dim=0)
265
+ compression_masks = []
266
+
267
+ for images, num_patches, grid_size, merge_size, modal in zip(
268
+ batched_images, batched_num_patches, grid_sizes, merge_sizes, modals
269
+ ):
270
+ t, h, w = grid_size
271
+ if modal == "image" or (modal == "video" and t == 1):
272
+ compression_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=images.device))
273
+
274
+ elif modal == "video":
275
+ # Video token compression based on pixel differences
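+ # Heuristic: a merged patch position is kept only when its mean absolute difference from the
+ # previous frame exceeds `threshold`; the first frame is always kept, and frames that would
+ # keep fewer than `min_tokens` positions get their first `min_tokens` positions forced on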
276
+ images = images.view(t, (h // merge_size) * (w // merge_size), -1)
277
+
278
+ pixel_diff = images[1:] - images[:-1]
279
+ pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
280
+ pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
281
+ mask = (pixel_diff / 255.0) > threshold
282
+ padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
283
+ mask[padding_ids, :min_tokens] = 1
284
+ compression_masks.append(mask.flatten())
285
+
286
+ else:
287
+ # Pseudo image case
288
+ compression_masks.append(torch.ones((0,), dtype=torch.bool, device=images.device))
289
+
290
+ return torch.cat(compression_masks)
291
+
292
+ def _compress_visual_tokens(
293
+ self,
294
+ compression_mask: torch.BoolTensor,
295
+ mm_features: torch.FloatTensor,
296
+ input_ids: torch.LongTensor,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ position_ids: Optional[torch.LongTensor] = None,
299
+ labels: Optional[torch.LongTensor] = None,
300
+ ):
301
+ """Compress visual tokens based on compression mask."""
302
+ mm_features = mm_features[compression_mask]
303
+ image_selected = (input_ids == self.config.image_token_index)
304
+
305
+ text_masks = torch.logical_not(image_selected)
306
+ text_masks[image_selected] = compression_mask
307
+ input_ids = input_ids[text_masks]
308
+
309
+ if attention_mask is not None:
310
+ attention_mask = attention_mask[text_masks]
311
+ if labels is not None:
312
+ labels = labels[text_masks]
313
+ if position_ids is not None:
314
+ position_ids = position_ids[text_masks]
315
+ pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
316
+ pos_end = pos_start[1:] + [len(input_ids)]
317
+ position_ids = torch.cat([
318
+ torch.arange(end - start, device=input_ids.device)
319
+ for start, end in zip(pos_start, pos_end)
320
+ ])
321
+
322
+ return mm_features, input_ids, attention_mask, position_ids, labels
323
+
324
+ def prepare_inputs_labels_for_multimodal(
325
+ self,
326
+ input_ids: torch.LongTensor = None,
327
+ attention_mask: Optional[torch.Tensor] = None,
328
+ position_ids: Optional[torch.LongTensor] = None,
329
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
330
+ labels: Optional[torch.LongTensor] = None,
331
+ pixel_values: Optional[torch.FloatTensor] = None,
332
+ grid_sizes: Optional[torch.LongTensor] = None,
333
+ merge_sizes: Optional[torch.LongTensor] = None,
334
+ modals: Optional[List[str]] = None,
335
+ ):
336
+ """Prepare inputs and labels for multimodal training/inference."""
337
+ vision_encoder = self.get_vision_encoder()
338
+
339
+ # Text-only situation
340
+ if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
341
+ return input_ids, attention_mask, position_ids, past_key_values, None, labels
342
+
343
+ # 1. Flatten text inputs
344
+ B, N = input_ids.shape
345
+ input_ids = input_ids.view(B * N)
346
+ if attention_mask is not None:
347
+ attention_mask = attention_mask.view(B * N)
348
+ if position_ids is not None:
349
+ position_ids = position_ids.view(B * N)
350
+ if labels is not None:
351
+ labels = labels.view(B * N)
352
+
353
+ # 2. Embed visual tokens
354
+ batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
355
+ mm_features = self.encode_images(pixel_values, grid_sizes, merge_sizes).to(input_ids.device)
356
+ mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)
357
+
358
+ compression_mask = self._get_compression_mask(
359
+ pixel_values, batched_num_patches, grid_sizes, merge_sizes, modals
360
+ )
361
+ mm_features, compression_mask = self._maybe_truncate_visual_tokens(
362
+ mm_features, compression_mask, batched_num_patches, modals, input_ids, position_ids
363
+ )
364
+
365
+ # 3. Compress visual tokens if enabled
366
+ if self.config.use_token_compression:
367
+ assert B == 1, "Token compression is only supported for batch_size=1"
368
+ mm_features, input_ids, attention_mask, position_ids, labels = self._compress_visual_tokens(
369
+ compression_mask, mm_features, input_ids, attention_mask, position_ids, labels
370
+ )
371
+
372
+ # 4. Embed text tokens
373
+ inputs_embeds = self.get_model().embed_tokens(input_ids).clone()
374
+
375
+ # 5. Replace multimodal tokens with features
376
+ image_selected = (input_ids == self.config.image_token_index)
377
+ inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features
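+ # the `* 0.0 +` term keeps the text-embedding slice in the autograd graph (so embed_tokens
+ # still participates in backward, e.g. under distributed wrappers) while the values come from mm_features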
378
+
379
+ # 6. Reshape back to batched format
380
+ C = inputs_embeds.shape[-1]
381
+ inputs_embeds = inputs_embeds.reshape(B, -1, C)
382
+ if attention_mask is not None:
383
+ attention_mask = attention_mask.view(B, -1)
384
+ if labels is not None:
385
+ labels = labels.view(B, -1)
386
+ if position_ids is not None:
387
+ position_ids = position_ids.view(B, -1)
388
+
389
+ return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels
390
+
391
+
392
+ class HulumedQwen3ForCausalLM(Qwen3ForCausalLM, HulumedMetaForCausalLM):
393
+
394
+ config_class = HulumedQwen3Config
395
+
396
+ def __init__(self, config, **kwargs):
397
+ super(Qwen3ForCausalLM, self).__init__(config)
398
+ self.model = HulumedQwen3Model(config)
399
+ self.vocab_size = config.vocab_size
400
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
401
+
402
+ # Initialize weights and apply final processing
403
+ self.post_init()
404
+
405
+ def get_model(self):
406
+ return self.model
407
+
408
+ def forward(
409
+ self,
410
+ input_ids: torch.LongTensor = None,
411
+ attention_mask: Optional[torch.Tensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
414
+ inputs_embeds: Optional[torch.FloatTensor] = None,
415
+ labels: Optional[torch.LongTensor] = None,
416
+ use_cache: Optional[bool] = None,
417
+ output_attentions: Optional[bool] = None,
418
+ output_hidden_states: Optional[bool] = None,
419
+ return_dict: Optional[bool] = None,
420
+ cache_position: Optional[torch.LongTensor] = None,
421
+ num_logits_to_keep: int = 0,
422
+ # Multimodal inputs
423
+ pixel_values: Optional[torch.FloatTensor] = None,
424
+ grid_sizes: Optional[torch.LongTensor] = None,
425
+ merge_sizes: Optional[torch.LongTensor] = None,
426
+ modals: Optional[List[str]] = None,
427
+ **loss_kwargs,
428
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
429
+ """Forward pass with multimodal support."""
430
+ if inputs_embeds is None:
431
+ (
432
+ input_ids,
433
+ attention_mask,
434
+ position_ids,
435
+ past_key_values,
436
+ inputs_embeds,
437
+ labels,
438
+ ) = self.prepare_inputs_labels_for_multimodal(
439
+ input_ids=input_ids,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ past_key_values=past_key_values,
443
+ labels=labels,
444
+ pixel_values=pixel_values,
445
+ grid_sizes=grid_sizes,
446
+ merge_sizes=merge_sizes,
447
+ modals=modals,
448
+ )
449
+
450
+ return super().forward(
451
+ input_ids=input_ids,
452
+ attention_mask=attention_mask,
453
+ position_ids=position_ids,
454
+ past_key_values=past_key_values,
455
+ inputs_embeds=inputs_embeds,
456
+ labels=labels,
457
+ use_cache=use_cache,
458
+ output_attentions=output_attentions,
459
+ output_hidden_states=output_hidden_states,
460
+ return_dict=return_dict,
461
+ cache_position=cache_position,
462
+ num_logits_to_keep=num_logits_to_keep,
463
+ **loss_kwargs,
464
+ )
465
+
466
+ @torch.no_grad()
467
+ def generate(
468
+ self,
469
+ # Multimodal inputs
470
+ pixel_values: Optional[torch.FloatTensor] = None,
471
+ grid_sizes: Optional[torch.LongTensor] = None,
472
+ merge_sizes: Optional[torch.LongTensor] = None,
473
+ modals: Optional[List[str]] = None,
474
+ **kwargs,
475
+ ) -> Union[GenerateOutput, torch.LongTensor]:
476
+ """Generate with multimodal support."""
477
+ input_ids = kwargs.pop("input_ids", None)
478
+ attention_mask = kwargs.pop("attention_mask", None)
479
+ position_ids = kwargs.pop("position_ids", None)
480
+ past_key_values = kwargs.pop("past_key_values", None)
481
+
482
+ if "inputs_embeds" in kwargs:
483
+ raise NotImplementedError("`inputs_embeds` is not supported")
484
+
485
+ if pixel_values is not None:
486
+ (
487
+ input_ids,
488
+ attention_mask,
489
+ position_ids,
490
+ past_key_values,
491
+ inputs_embeds,
492
+ labels,
493
+ ) = self.prepare_inputs_labels_for_multimodal(
494
+ input_ids=input_ids,
495
+ attention_mask=attention_mask,
496
+ position_ids=position_ids,
497
+ past_key_values=past_key_values,
498
+ labels=None,
499
+ pixel_values=pixel_values,
500
+ grid_sizes=grid_sizes,
501
+ merge_sizes=merge_sizes,
502
+ modals=modals,
503
+ )
504
+ else:
505
+ inputs_embeds = self.get_model().embed_tokens(input_ids)
506
+
507
+ return super().generate(
508
+ position_ids=position_ids,
509
+ attention_mask=attention_mask,
510
+ inputs_embeds=inputs_embeds,
511
+ **kwargs
512
+ )
513
+
514
+ def prepare_inputs_for_generation(
515
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
516
+ ):
517
+ """Prepare inputs for generation."""
518
+ images = kwargs.pop("images", None)
519
+ _inputs = super().prepare_inputs_for_generation(
520
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
521
+ )
522
+ if images is not None:
523
+ _inputs['images'] = images
524
+ return _inputs
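
For orientation, a minimal end-to-end sketch (not part of the diff) of how this class is typically driven. The checkpoint path is a placeholder, trust_remote_code is assumed to resolve the repo's auto_map to these classes, and the HulumedProcessor added below is what actually builds pixel_values, grid_sizes, merge_sizes and modals.

# Hypothetical usage sketch; the model id and file path are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model = AutoModelForCausalLM.from_pretrained(
    "path/to/hulumed-qwen3", trust_remote_code=True, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("path/to/hulumed-qwen3", trust_remote_code=True)

conversation = [
    {"role": "user", "content": [
        {"type": "image", "image": {"image_path": "xray.png"}},
        {"type": "text", "text": "Describe this image."},
    ]},
]
inputs = processor(conversation=conversation, add_generation_prompt=True, return_tensors="pt")
# pixel_values / grid_sizes / merge_sizes / modals are routed through
# prepare_inputs_labels_for_multimodal inside generate() above
output_ids = model.generate(**inputs.to(model.device), max_new_tokens=128)
print(processor.decode(output_ids[0]))
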
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_hulumed.HulumedImageProcessor",
4
+ "AutoProcessor": "processing_hulumed.HulumedProcessor"
5
+ },
6
+ "do_convert_rgb": true,
7
+ "do_normalize": true,
8
+ "do_rescale": true,
9
+ "do_resize": true,
10
+ "image_mean": [
11
+ 0.5,
12
+ 0.5,
13
+ 0.5
14
+ ],
15
+ "image_processor_type": "HulumedImageProcessor",
16
+ "image_std": [
17
+ 0.5,
18
+ 0.5,
19
+ 0.5
20
+ ],
21
+ "max_tokens": 16384,
22
+ "min_tokens": 16,
23
+ "patch_size": 14,
24
+ "processor_class": "HulumedProcessor",
25
+ "resample": 3,
26
+ "rescale_factor": 0.00392156862745098
27
+ }
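
As a sanity check on the values above: rescale_factor is 1/255 and mean/std are both 0.5, so uint8 pixels end up in roughly [-1, 1]. A tiny sketch of that arithmetic:

# rescale then normalize, as specified by this preprocessor config
import numpy as np
pixels = np.array([0, 128, 255], dtype=np.float32)
scaled = pixels * 0.00392156862745098   # == 1 / 255
normalized = (scaled - 0.5) / 0.5
print(normalized)                       # ~[-1.0, 0.004, 1.0]
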
processing_hulumed.py ADDED
@@ -0,0 +1,856 @@
1
+ """Processor class for HuluMed with 3D support."""
2
+
3
+ import copy
4
+ import importlib.util
5
+ import os
6
+ import os.path as osp
7
+ import warnings
8
+ from collections import defaultdict
9
+ from typing import Any, List, Union, Dict, Optional, Tuple, TypedDict
10
+
11
+ import cv2
12
+ import ffmpeg
13
+ import imageio
14
+ import json
15
+ import numpy as np
16
+ import torch
17
+ import transformers
18
+ from decord import VideoReader, cpu
19
+ from PIL import Image
20
+ from transformers.feature_extraction_utils import BatchFeature
21
+ from transformers.image_utils import ImageInput
22
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
23
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
24
+
25
+ try:
26
+ import nibabel as nib
27
+ NIBABEL_AVAILABLE = True
28
+ except ImportError:
29
+ NIBABEL_AVAILABLE = False
30
+ warnings.warn("nibabel is not installed. 3D medical imaging support will be limited. Install with: pip install nibabel")
31
+
32
+ try:
33
+ from . import image_processing_hulumed
34
+ from .image_processing_hulumed import (
35
+ is_valid_image, is_valid_video,
36
+ )
37
+ except ModuleNotFoundError:
38
+ spec = importlib.util.spec_from_file_location(
39
+ "image_processing_hulumed",
40
+ osp.join(osp.dirname(__file__), "image_processing_hulumed.py"),
41
+ )
42
+ image_processing_hulumed = importlib.util.module_from_spec(spec)
43
+ spec.loader.exec_module(image_processing_hulumed)
44
+ is_valid_image = getattr(image_processing_hulumed, "is_valid_image")
45
+ is_valid_video = getattr(image_processing_hulumed, "is_valid_video")
46
+
47
+ DEFAULT_IMAGE_TOKEN = "<image>"
48
+ IGNORE_INDEX = -100
49
+
50
+ Conversation = List[Dict[str, Any]]
51
+ SingleImage = Union[Image.Image, np.ndarray, torch.Tensor]
52
+ SingleVideo = Union[List[SingleImage], np.ndarray, torch.Tensor]
53
+ BatchedImage = List[Union[SingleImage, SingleVideo]]
54
+ BatchedNamedImage = List[Tuple[str, Union[SingleImage, SingleVideo]]]
55
+
56
+
57
+ def _custom_import(class_name: str):
58
+ try:
59
+ attribute_class = getattr(transformers, class_name)
60
+ except AttributeError:
61
+ attribute_class = getattr(image_processing_hulumed, class_name)
62
+ return attribute_class
63
+
64
+
65
+ def is_named_image(image) -> bool:
66
+ return isinstance(image, (list, tuple)) and \
67
+ len(image) == 2 and \
68
+ isinstance(image[0], str) and \
69
+ image[0] in ["image", "video", "3d"] and \
70
+ (is_valid_image(image[1]) or is_valid_video(image[1]))
71
+
72
+
73
+ def make_batched_images(images) -> Tuple[List[str], List[ImageInput]]:
74
+ if isinstance(images, (list, tuple)) and all(is_named_image(image) for image in images):
75
+ modals = [image[0] if image[0] != "3d" else "video" for image in images]
76
+ data = [image[1] for image in images]
77
+ return modals, data
78
+ elif isinstance(images, (list, tuple)) and all(is_valid_image(image) or is_valid_video(image) for image in images):
79
+ batch = []
80
+ for image in images:
81
+ if is_valid_video(image):
82
+ batch.append(("video", image))
83
+ elif is_valid_image(image):
84
+ batch.append(("image", image))
85
+ else:
86
+ raise ValueError(f"Could not make batched images from {images}")
87
+ return [x[0] for x in batch], [x[1] for x in batch]
88
+ elif is_named_image(images):
89
+ modal = images[0] if images[0] != "3d" else "video"
90
+ return [modal], [images[1]]
91
+ elif is_valid_video(images):
92
+ return ["video"], [images]
93
+ elif is_valid_image(images):
94
+ return ["image"], [images]
95
+
96
+ raise ValueError(f"Could not make batched images from {images}")
97
+
98
+
99
+ def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
100
+ if mode == 'uniform':
101
+ assert num_frames is not None, "Number of frames must be provided for uniform sampling."
102
+ if duration <= num_frames:
103
+ return np.arange(duration).astype(int)
104
+ return np.linspace(0, duration-1, num_frames, dtype=int)
105
+ elif mode == 'fps':
106
+ assert vid_fps is not None, "FPS must be provided for FPS sampling."
107
+ assert fps is not None, "FPS must be provided for FPS sampling."
108
+ segment_len = min(vid_fps // fps, duration)
109
+ return np.arange(segment_len // 2, duration, segment_len, dtype=int)
110
+ else:
111
+ raise ValueError(f'Unsupported frame sampling mode: {mode}')
112
+
113
+
114
+ def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=128, temporal_factor=1):
115
+ if s is not None and e is not None:
116
+ s = s if s >= 0. else 0.
117
+ e = e if e >= 0. else 0.
118
+ if s > e:
119
+ s, e = e, s
120
+ elif s == e:
121
+ e = s + 1
122
+
123
+ if os.path.isdir(video_path):
124
+ frame_files = sorted(os.listdir(video_path))
125
+ vid_fps = 3
126
+ num_frames_of_video = len(frame_files)
127
+ elif video_path.endswith('.gif'):
128
+ gif_reader = imageio.get_reader(video_path)
129
+ vid_fps = 25
130
+ num_frames_of_video = len(gif_reader)
131
+ else:
132
+ vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
133
+ vid_fps = vreader.get_avg_fps()
134
+ num_frames_of_video = len(vreader)
135
+
136
+ f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
137
+ f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
138
+ frame_indices = list(range(f_start, f_end + 1))
139
+
140
+ duration = len(frame_indices)
141
+ if fps is not None and duration / vid_fps < max_frames:
142
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
143
+ else:
144
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
145
+
146
+ if os.path.isdir(video_path):
147
+ frames = np.array([cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices])
148
+ elif video_path.endswith('.gif'):
149
+ frames = np.array([cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices])
150
+ else:
151
+ frames = vreader.get_batch(sampled_frame_indices).asnumpy()
152
+
153
+ frames = frames.transpose(0, 3, 1, 2)
154
+ timestamps = [x / vid_fps for x in sampled_frame_indices]
155
+
156
+ if temporal_factor > 1:
157
+ pad_length = temporal_factor - len(frames) % temporal_factor
158
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
159
+ # extend the timestamps to cover the padded frames; fall back to the source fps when no target fps is given
+ step = 1 / fps if fps else 1 / vid_fps
+ for _ in range(pad_length):
+ timestamps.append(timestamps[-1] + step)
160
+
161
+ frames = [frame for frame in frames]
162
+
163
+ return frames, timestamps
164
+
165
+
166
+ class ChatTemplateKwargs(TypedDict, total=False):
167
+ chat_template: Optional[str]
168
+ add_system_prompt: Optional[bool]
169
+ add_generation_prompt: Optional[bool]
170
+
171
+
172
+ class HulumedProcessorKwargs(ProcessingKwargs, ChatTemplateKwargs, total=False):
173
+ chat_template_kwargs: ChatTemplateKwargs = {
174
+ **ChatTemplateKwargs.__annotations__,
175
+ }
176
+
177
+ _defaults = {
178
+ "text_kwargs": {
179
+ "padding": False,
180
+ },
181
+ "images_kwargs": {
182
+
183
+ },
184
+ "chat_template_kwargs": {
185
+ "chat_template": None,
186
+ "add_system_prompt": False,
187
+ "add_generation_prompt": False,
188
+ },
189
+ }
190
+
191
+
192
+ class HulumedProcessor(ProcessorMixin):
193
+ attributes = ["image_processor", "tokenizer"]
194
+ image_processor_class = "HulumedImageProcessor"
195
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
196
+ valid_kwargs = ["chat_template", "image_merge_size", "video_merge_size", "fps", "max_frames"]
197
+
198
+ def __init__(
199
+ self,
200
+ image_processor=None,
201
+ tokenizer=None,
202
+ chat_template: str = None,
203
+ image_merge_size: int = 1,
204
+ video_merge_size: int = 2,
205
+ fps: Optional[int] = 1,
206
+ max_frames: Optional[int] = 128,
207
+ ):
208
+ self.image_processor = image_processor
209
+ self.tokenizer = tokenizer
210
+ if chat_template is None:
211
+ chat_template = self.tokenizer.chat_template
212
+ self.chat_template = chat_template
213
+
214
+ self.image_merge_size = image_merge_size
215
+ self.video_merge_size = video_merge_size
216
+ self.fps = fps
217
+ self.max_frames = max_frames
218
+
219
+ self.generation_prompt = self._infer_generation_prompt()
220
+ self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
221
+ self.generation_prompt_length = len(self.generation_prompt_ids[0])
222
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
223
+ self.eos_token_id = self.tokenizer.eos_token_id
224
+
225
+ @classmethod
226
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
227
+ args = []
228
+ for attribute_name in cls.attributes:
229
+ class_name = getattr(cls, f"{attribute_name}_class")
230
+ if isinstance(class_name, tuple):
231
+ classes = tuple(_custom_import(n) if n is not None else None for n in class_name)
232
+ use_fast = kwargs.get("use_fast", True)
233
+ if use_fast and classes[1] is not None:
234
+ attribute_class = classes[1]
235
+ else:
236
+ attribute_class = classes[0]
237
+ else:
238
+ attribute_class = _custom_import(class_name)
239
+
240
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
241
+ return args
242
+
243
+ def get_generation_prompt(self):
244
+ return self.generation_prompt
245
+
246
+ def get_generation_prompt_ids(self):
247
+ return self.generation_prompt_ids
248
+
249
+ def _infer_generation_prompt(self):
250
+ pseudo_message = [{"role": "user", "content": ""}]
251
+ instruction = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
252
+ conversation = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
253
+ return instruction.replace(conversation, "")
254
+
255
+ def _get_downsampled_grid_sizes(self, image_inputs: Dict[str, Any]):
256
+ grid_sizes = []
257
+ for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
258
+ if not torch.all(grid_size[1:] % merge_size == 0):
259
+ warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
260
+ if grid_size[0] == 1:
261
+ grid_sizes.append(grid_size[1:] / merge_size)
262
+ elif grid_size[0] > 1:
263
+ grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
264
+ return grid_sizes
265
+
266
+ def _get_visual_seq_len(self, grid_size: torch.Tensor):
267
+ num_tokens = int(grid_size.prod().item())
268
+ return num_tokens
269
+
270
+ def load_images(self, image_path: Union[str, List[str], Image.Image, List[Image.Image]]):
271
+ if isinstance(image_path, str) and os.path.isfile(image_path):
272
+ images = [Image.open(image_path).convert('RGB')]
273
+ elif isinstance(image_path, str) and os.path.isdir(image_path):
274
+ images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
275
+ elif isinstance(image_path, list) and isinstance(image_path[0], str):
276
+ images = [Image.open(f).convert('RGB') for f in image_path]
277
+ elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
278
+ images = [np.array(x) for x in image_path]
279
+ elif isinstance(image_path, Image.Image):
280
+ images = [np.array(image_path)]
281
+ else:
282
+ raise ValueError(f"Unsupported image path type: {type(image_path)}")
283
+ return images
284
+
285
+ def load_nii(
286
+ self,
287
+ nii_path: str,
288
+ num_slices: Optional[int] = None,
289
+ axis: int = 2,
290
+ window_center: Optional[float] = None,
291
+ window_width: Optional[float] = None,
292
+ normalize: bool = True,
293
+ ):
294
+ if not NIBABEL_AVAILABLE:
295
+ raise ImportError("nibabel is required for NIfTI support. Install with: pip install nibabel")
296
+
297
+ if not os.path.exists(nii_path):
298
+ raise FileNotFoundError(f"NIfTI file not found: {nii_path}")
299
+
300
+ nii_img = nib.load(nii_path)
301
+ volume = nii_img.get_fdata()
302
+
303
+ if axis == 0:
304
+ slices = [volume[i, :, :] for i in range(volume.shape[0])]
305
+ elif axis == 1:
306
+ slices = [volume[:, i, :] for i in range(volume.shape[1])]
307
+ elif axis == 2:
308
+ slices = [volume[:, :, i] for i in range(volume.shape[2])]
309
+ else:
310
+ raise ValueError(f"Invalid axis: {axis}. Must be 0, 1, or 2.")
311
+
312
+ if num_slices is not None and num_slices < len(slices):
313
+ indices = np.linspace(0, len(slices) - 1, num_slices, dtype=int)
314
+ slices = [slices[i] for i in indices]
315
+
316
+ processed_slices = []
317
+ for slice_2d in slices:
318
+ if window_center is not None and window_width is not None:
319
+ lower = window_center - window_width / 2
320
+ upper = window_center + window_width / 2
321
+ slice_2d = np.clip(slice_2d, lower, upper)
322
+
323
+ if normalize:
324
+ slice_min = slice_2d.min()
325
+ slice_max = slice_2d.max()
326
+ if slice_max > slice_min:
327
+ slice_2d = (slice_2d - slice_min) / (slice_max - slice_min) * 255.0
328
+ else:
329
+ slice_2d = np.zeros_like(slice_2d)
330
+
331
+ slice_2d = slice_2d.astype(np.uint8)
332
+ slice_rgb = np.stack([slice_2d] * 3, axis=0)
333
+
334
+ processed_slices.append(slice_rgb)
335
+
336
+ return processed_slices
337
+
338
+ def load_video(
339
+ self,
340
+ video_path: str,
341
+ start_time: Optional[float] = None,
342
+ end_time: Optional[float] = None,
343
+ fps: Optional[float] = None,
344
+ max_frames: Optional[float] = None,
345
+ size: Optional[int] = None,
346
+ size_divisible: int = 1,
347
+ precise_time: bool = False,
348
+ verbose: bool = False,
349
+ temporal_factor: int = 1
350
+ ):
351
+ fps = self.fps if fps is None else fps
352
+ max_frames = self.max_frames if max_frames is None else max_frames
353
+
354
+ if start_time is not None and end_time is not None and end_time - start_time < 1:
355
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
356
+ if os.path.isdir(video_path):
357
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
358
+ if video_path.endswith('.gif'):
359
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
360
+
361
+ probe = ffmpeg.probe(video_path)
362
+ duration = float(probe['format']['duration'])
363
+ video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
364
+ w, h = int(video_stream['width']), int(video_stream['height'])
365
+
366
+ kwargs, input_kwargs, output_kwargs = {}, {}, {}
367
+ do_trim = start_time is not None or end_time is not None
368
+ if start_time is not None:
369
+ new_start_time = max(float(video_stream['start_time']), start_time)
370
+ duration -= new_start_time - start_time
371
+ start_time = new_start_time
372
+ else:
373
+ start_time = float(video_stream['start_time'])
374
+ if end_time is not None:
375
+ duration = min(duration, end_time - start_time)
376
+ if do_trim:
377
+ kwargs = {'ss': start_time, 't': duration}
378
+ if precise_time:
379
+ output_kwargs.update(kwargs)
380
+ else:
381
+ input_kwargs.update(kwargs)
382
+
383
+ if size is not None:
384
+ scale_factor = size / min(w, h)
385
+ new_w, new_h = round(w * scale_factor), round(h * scale_factor)
386
+ else:
387
+ new_w, new_h = w, h
388
+ new_w = new_w // size_divisible * size_divisible
389
+ new_h = new_h // size_divisible * size_divisible
390
+
391
+ stream = ffmpeg.input(video_path, **input_kwargs)
392
+ if fps is not None:
393
+ stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
394
+ if new_w != w or new_h != h:
395
+ stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
396
+ stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
397
+ out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
398
+
399
+ frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
400
+
401
+ if fps is not None:
402
+ timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
403
+ else:
404
+ timestamps = np.linspace(start_time, start_time + duration, len(frames))
405
+
406
+ if max_frames is not None and len(frames) > max_frames:
407
+ indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
408
+ frames = frames[indices]
409
+ timestamps = timestamps[indices]
410
+
411
+ if temporal_factor > 1:
412
+ pad_length = temporal_factor - len(frames) % temporal_factor
413
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
414
+ timestamps = np.concatenate([timestamps, timestamps[-1:].repeat(pad_length) + np.arange(1, pad_length + 1) / fps])
415
+
416
+ frames = [frame for frame in frames]
417
+ timestamps = [timestamp for timestamp in timestamps]
418
+
419
+ return frames, timestamps
420
+
421
+ def _load_multimodal_data(self, conversation: Conversation):
422
+ multimodal_info = defaultdict(list)
423
+ new_conversation = []
424
+ for message in conversation:
425
+ new_message = {"role": message["role"]}
426
+ if not isinstance(message["content"], (list, tuple)):
427
+ new_message["content"] = message["content"]
428
+ new_conversation.append(new_message)
429
+ continue
430
+
431
+ new_contents = []
432
+ for content in message["content"]:
433
+ if not isinstance(content, dict):
434
+ new_contents.append(content)
435
+ continue
436
+ assert "type" in content, "Content must have 'type' field."
437
+
438
+ if content["type"] in ["image", "video", "3d"] and content["type"] in content and isinstance(content[content["type"]], dict):
439
+ load_args = content[content["type"]]
440
+ data_id = json.dumps({k: v for k, v in load_args.items() if k not in ["start_time", "end_time"]})
441
+ new_content = copy.deepcopy(content)
442
+ multimodal_info[data_id].append(new_content)
443
+ new_contents.append(new_content)
444
+ else:
445
+ new_contents.append(content)
446
+
447
+ new_message["content"] = new_contents
448
+ new_conversation.append(new_message)
449
+
450
+ for data_id, contents in multimodal_info.items():
451
+ data_type = contents[0]["type"]
452
+
453
+ if data_type == "image":
454
+ image = self.load_images(contents[0][data_type]["image_path"])[0]
455
+ for content in contents:
456
+ content["image"] = [image.copy()]
457
+
458
+ elif data_type == "3d":
459
+ load_args = contents[0]["3d"]
460
+ nii_path = load_args["image_path"]
461
+ num_slices = load_args.get("nii_num_slices", None)
462
+ axis = load_args.get("nii_axis", 2)
463
+ window_center = load_args.get("window_center", None)
464
+ window_width = load_args.get("window_width", None)
465
+
466
+ slices = self.load_nii(
467
+ nii_path=nii_path,
468
+ num_slices=num_slices,
469
+ axis=axis,
470
+ window_center=window_center,
471
+ window_width=window_width,
472
+ )
473
+
474
+ for content in contents:
475
+ content["type"] = "video"
476
+ content["video"] = slices
477
+ content["num_frames"] = len(slices)
478
+ content.pop("3d", None)
479
+
480
+ elif data_type == "video":
481
+ start_times = [content["video"].get("start_time", 0.) for content in contents]
482
+ end_times = [content["video"].get("end_time", float("inf")) for content in contents]
483
+
484
+ load_args = contents[0][data_type]
485
+ start_time, end_time = min(start_times), max(end_times)
486
+ if start_time > 0:
487
+ load_args["start_time"] = start_time
488
+ if end_time < float("inf"):
489
+ load_args["end_time"] = end_time
490
+ images, timestamps = self.load_video(**load_args)
491
+
492
+ for content, start_time, end_time in zip(contents, start_times, end_times):
493
+ cur_images, cur_timestamps = [], []
494
+ for image, timestamp in zip(images, timestamps):
495
+ if start_time <= timestamp <= end_time:
496
+ cur_images.append(image.copy())
497
+ cur_timestamps.append(timestamp)
498
+
499
+ content[data_type] = cur_images
500
+ content["num_frames"] = len(cur_images)
501
+ content["timestamps"] = cur_timestamps
502
+
503
+ return new_conversation
504
+
505
+ def _gather_multimodal_data(self, conversation: Conversation):
506
+ images = []
507
+ for message in conversation:
508
+ if not isinstance(message["content"], (list, tuple)):
509
+ continue
510
+ for content in message["content"]:
511
+ if not isinstance(content, dict):
512
+ continue
513
+ if content["type"] == "video":
514
+ video = content["video"]
515
+ assert is_valid_video(video), f"Invalid video data: {video}."
516
+ images.append(("video", video))
517
+ elif content["type"] == "image":
518
+ image = content["image"]
519
+ images.append(("image", image))
520
+ images = images if len(images) > 0 else None
521
+ return images
522
+
523
+ def _process_conversation_with_label(
524
+ self,
525
+ conversation: Conversation,
526
+ image_inputs: Dict[str, Any],
527
+ **kwargs,
528
+ ):
529
+ assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
530
+ assert "add_generation_prompt" not in kwargs, "'add_generation_prompt' argument is not supported when return_labels=True."
531
+
532
+ output_kwargs = self._merge_kwargs(
533
+ HulumedProcessorKwargs,
534
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
535
+ **kwargs,
536
+ )
537
+ output_kwargs["chat_template_kwargs"].pop("add_generation_prompt")
538
+
539
+ grid_sizes = self._get_downsampled_grid_sizes(image_inputs)
540
+ text_inputs = {"input_ids": [], "labels": []}
541
+ sample_types_list = []
542
+ image_idx = 0
543
+
544
+ for message_idx, message in enumerate(conversation):
545
+ prompt = self.apply_chat_template(
546
+ [message],
547
+ tokenize=False,
548
+ add_generation_prompt=False,
549
+ **output_kwargs["chat_template_kwargs"],
550
+ )
551
+ prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
552
+ prompt = []
553
+ for chunk_idx in range(len(prompt_chunks) - 1):
554
+ prompt.append(prompt_chunks[chunk_idx])
555
+ num_tokens = self._get_visual_seq_len(grid_sizes[image_idx])
556
+ prompt.append(DEFAULT_IMAGE_TOKEN * num_tokens)
557
+ image_idx += 1
558
+ prompt.append(prompt_chunks[-1])
559
+ prompt = "".join(prompt)
560
+
561
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt", **output_kwargs["text_kwargs"])[0]
562
+ text_inputs["input_ids"].append(input_ids)
563
+
564
+ targets = torch.full_like(input_ids, IGNORE_INDEX)
565
+ sample_types = torch.full_like(input_ids, IGNORE_INDEX)
566
+ if message["role"] == "assistant":
567
+ targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
568
+ elif message["role"] == "stream":
569
+ diff = torch.diff((input_ids == self.image_token_id).float())
570
+ image_end_indices = torch.nonzero(diff < 0)[:, 0]
571
+ targets[image_end_indices + 1] = input_ids[image_end_indices + 1]
572
+ sample_types = targets.clone()
573
+ sample_types[torch.logical_and(sample_types > 0, sample_types != self.eos_token_id)] = 0
574
+ targets[-2] = input_ids[-2]
575
+
576
+ if message_idx > 0 and conversation[message_idx - 1]["role"] == "stream":
577
+ targets[0] = input_ids[0]
578
+ sample_types[0] = input_ids[0]
579
+
580
+ text_inputs["labels"].append(targets)
581
+ sample_types_list.append(sample_types)
582
+
583
+ text_inputs = {k: torch.cat(v) for k, v in text_inputs.items()}
584
+ sample_types = torch.cat(sample_types_list)
585
+ types, counts = torch.unique(sample_types[sample_types > -1], return_counts=True)
586
+
587
+ if len(types) > 0:
588
+ target_num_samples = counts.amin()
589
+ for type_id, type_count in zip(types, counts):
590
+ if type_count > target_num_samples:
591
+ indices = torch.nonzero(sample_types == type_id)[:, 0]
592
+ random_selector = torch.randperm(indices.size(0))[:-target_num_samples]
593
+ text_inputs["labels"][indices[random_selector]] = IGNORE_INDEX
594
+
595
+ assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
596
+
597
+ return text_inputs
598
+
599
+ def _process_conversation_without_label(
600
+ self,
601
+ conversation: Conversation,
602
+ image_inputs: Dict[str, Any],
603
+ **kwargs,
604
+ ):
605
+ output_kwargs = self._merge_kwargs(
606
+ HulumedProcessorKwargs,
607
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
608
+ **kwargs,
609
+ )
610
+ prompt = self.apply_chat_template(
611
+ conversation,
612
+ tokenize=False,
613
+ **output_kwargs["chat_template_kwargs"],
614
+ )
615
+ return self.process_text(prompt, image_inputs, **output_kwargs["text_kwargs"])
616
+
617
+ def _process_conversation(
618
+ self,
619
+ conversation: Conversation,
620
+ images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
621
+ return_labels: bool = False,
622
+ **kwargs: Unpack[HulumedProcessorKwargs],
623
+ ) -> BatchFeature:
624
+ assert isinstance(conversation, list), "Conversation must be a list of messages."
625
+
626
+ if images is None:
627
+ conversation = self._load_multimodal_data(conversation)
628
+ images = self._gather_multimodal_data(conversation)
629
+
630
+ output_kwargs = self._merge_kwargs(
631
+ HulumedProcessorKwargs,
632
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
633
+ **kwargs,
634
+ )
635
+
636
+ if images is not None:
637
+ image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
638
+ else:
639
+ image_inputs = {}
640
+
641
+ if return_labels:
642
+ text_inputs = self._process_conversation_with_label(conversation, image_inputs, **kwargs)
643
+ else:
644
+ text_inputs = self._process_conversation_without_label(conversation, image_inputs, **kwargs)
645
+
646
+ return BatchFeature(data={**text_inputs, **image_inputs})
647
+
648
+ def _process_plain(
649
+ self,
650
+ text: Union[TextInput, PreTokenizedInput] = None,
651
+ images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
652
+ return_labels: bool = False,
653
+ **kwargs: Unpack[HulumedProcessorKwargs],
654
+ ) -> BatchFeature:
655
+ if text is None:
656
+ raise ValueError("You must provide 'text' or 'conversation'.")
657
+ if return_labels:
658
+ raise ValueError("return_labels is not supported for plain text processing.")
659
+
660
+ output_kwargs = self._merge_kwargs(
661
+ HulumedProcessorKwargs,
662
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
663
+ **kwargs,
664
+ )
665
+
666
+ if images is not None:
667
+ image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
668
+ else:
669
+ image_inputs = {}
670
+
671
+ text_inputs = self.process_text(text, image_inputs, **output_kwargs["text_kwargs"])
672
+
673
+ return BatchFeature(data={**text_inputs, **image_inputs})
674
+
675
+ def process_images(self, images: Union[BatchedImage, BatchedNamedImage], **kwargs):
676
+ modals, images = make_batched_images(images)
677
+
678
+ if "merge_size" not in kwargs:
679
+ kwargs["merge_size"] = [
680
+ self.video_merge_size if modal == "video" else self.image_merge_size
681
+ for modal in modals
682
+ ]
683
+
684
+ image_inputs = self.image_processor(images=images, **kwargs)
685
+         image_inputs["modals"] = modals
+         return image_inputs
+
+     def process_text(
+         self,
+         text: TextInput,
+         image_inputs: Dict[str, Any],
+         **kwargs,
+     ):
+         grid_sizes = self._get_downsampled_grid_sizes(image_inputs)
+
+         kwargs.pop("padding", None)
+         kwargs.pop("padding_side", None)
+
+         if len(grid_sizes) > 0:
+             image_idx = 0
+             while DEFAULT_IMAGE_TOKEN in text:
+                 num_tokens = self._get_visual_seq_len(grid_sizes[image_idx])
+                 text = text.replace(DEFAULT_IMAGE_TOKEN, "<placeholder>" * num_tokens, 1)
+                 image_idx += 1
+             text = text.replace("<placeholder>", DEFAULT_IMAGE_TOKEN)
+
+             assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
+
+         text_inputs = self.tokenizer(text, **kwargs)
+         return text_inputs
+
+     def __call__(
+         self,
+         text: Optional[TextInput] = None,
+         conversation: Optional[Conversation] = None,
+         images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
+         return_labels: bool = False,
+         **kwargs: Unpack[HulumedProcessorKwargs],
+     ) -> BatchFeature:
+         if conversation is not None:
+             if text is not None:
+                 raise ValueError("You cannot provide both 'conversation' and 'text'.")
+             return self._process_conversation(conversation, images, return_labels, **kwargs)
+         return self._process_plain(text, images, return_labels, **kwargs)
+
+     def batch_decode(self, *args, skip_special_tokens=True, use_think=False, **kwargs):
+         outputs = self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
+
+         if not use_think:
+             outputs = [self._remove_think_tags(output) for output in outputs]
+
+         return outputs
+
+     def decode(self, *args, skip_special_tokens=True, use_think=False, **kwargs):
+         output = self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
+
+         if not use_think:
+             output = self._remove_think_tags(output)
+
+         return output
+
+     def _remove_think_tags(self, text: str) -> str:
+         import re
+         pattern = r'<think>.*?</think>'
+         cleaned = re.sub(pattern, '', text, flags=re.DOTALL)
+         cleaned = re.sub(r'\n\s*\n', '\n\n', cleaned)
+         cleaned = cleaned.strip()
+         return cleaned
+
+     def apply_chat_template(
+         self,
+         conversation: Conversation,
+         chat_template: Optional[str] = None,
+         tokenize: bool = False,
+         add_system_prompt: bool = False,
+         add_generation_prompt: bool = False,
+         image_token: Optional[str] = DEFAULT_IMAGE_TOKEN,
+         **kwargs,
+     ) -> str:
+         if chat_template is None:
+             if self.chat_template is not None:
+                 chat_template = self.chat_template
+             else:
+                 raise ValueError(
+                     "No chat template is set for this processor. Please either set the `chat_template` attribute, "
+                     "or provide a chat template as an argument."
+                 )
+         return self.tokenizer.apply_chat_template(
+             conversation,
+             chat_template=chat_template,
+             tokenize=tokenize,
+             add_system_prompt=add_system_prompt,
+             add_generation_prompt=add_generation_prompt,
+             image_token=image_token,
+             **kwargs
+         )
+
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + ["modals"]
+
+     def _merge_kwargs(
+         self,
+         ModelProcessorKwargs: ProcessingKwargs,
+         tokenizer_init_kwargs: Optional[Dict] = None,
+         **kwargs,
+     ) -> Dict[str, Dict]:
+         output_kwargs = {
+             "text_kwargs": {},
+             "images_kwargs": {},
+             "audio_kwargs": {},
+             "videos_kwargs": {},
+             "chat_template_kwargs": {},
+             "common_kwargs": {},
+         }
+
+         default_kwargs = {
+             "text_kwargs": {},
+             "images_kwargs": {},
+             "audio_kwargs": {},
+             "videos_kwargs": {},
+             "chat_template_kwargs": {},
+             "common_kwargs": {},
+         }
+
+         used_keys = set()
+
+         for modality in default_kwargs:
+             default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+             for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
+                 if modality_key in tokenizer_init_kwargs:
+                     value = (
+                         getattr(self.tokenizer, modality_key)
+                         if hasattr(self.tokenizer, modality_key)
+                         else tokenizer_init_kwargs[modality_key]
+                     )
+                     default_kwargs[modality][modality_key] = value
+
+         output_kwargs.update(default_kwargs)
+
+         non_modality_kwargs = set(kwargs) - set(output_kwargs)
+         for modality in output_kwargs:
+             for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
+                 if modality in kwargs:
+                     kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
+                     if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
+                         raise ValueError(
+                             f"Keyword argument {modality_key} was passed twice: "
+                             f"in a dictionary for {modality} and as a **kwarg."
+                         )
+                 elif modality_key in kwargs:
+                     kwarg_value = kwargs.get(modality_key, "__empty__")
+                 else:
+                     kwarg_value = "__empty__"
+                 if kwarg_value != "__empty__":
+                     output_kwargs[modality][modality_key] = kwarg_value
+                     used_keys.add(modality_key)
+
+         if any(key in default_kwargs for key in kwargs):
+             for modality, subdict in kwargs.items():
+                 if modality in default_kwargs:
+                     for subkey, subvalue in subdict.items():
+                         if subkey not in used_keys:
+                             output_kwargs[modality][subkey] = subvalue
+                             used_keys.add(subkey)
+         else:
+             for key in kwargs:
+                 if key not in used_keys:
+                     output_kwargs["common_kwargs"][key] = kwargs[key]
+
+         for modality in output_kwargs:
+             output_kwargs[modality].update(output_kwargs["common_kwargs"])
+
+         return output_kwargs
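For orientation, a minimal usage sketch of the processor methods added in processing_hulumed.py above. The repository id, the image path, and the exact format accepted by the images argument are assumptions (helpers such as _process_conversation are defined earlier in the file and not shown in this hunk); only apply_chat_template, __call__, and batch_decode from this file are exercised.

# Sketch only -- repo id, file path, and images format are placeholders, not taken from this commit.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "ZJU-AI4H/Hulu-Med",     # hypothetical repository id
    trust_remote_code=True,  # required: HulumedProcessor is custom code (see auto_map below)
)

# Build a prompt; the chat template emits one <image> token per image content item.
conversation = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe any abnormal findings in this scan."},
    ]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# __call__ tokenizes the prompt; process_text (above) expands each <image> token to the
# visual sequence length derived from the image grid.
inputs = processor(text=prompt, images=["scan.png"], return_tensors="pt")  # placeholder image path

# After model.generate(...), decoding with use_think=False strips <think>...</think> spans:
# answer = processor.batch_decode(generated_ids, use_think=False)[0]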
processor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "auto_map": {
+     "AutoProcessor": "processing_hulumed.HulumedProcessor"
+   },
+   "fps": 1,
+   "image_merge_size": 1,
+   "max_frames": 128,
+   "processor_class": "HulumedProcessor",
+   "video_merge_size": 2
+ }
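The auto_map entry above routes AutoProcessor to the custom HulumedProcessor class, so loading requires trust_remote_code=True. The remaining fields look like processor defaults (video sampled at 1 fps, at most 128 frames, 2x token merging for video); that reading is inferred from the field names, and whether they can be overridden at load time as below is an assumption rather than documented behaviour.

# Sketch, assuming from_pretrained forwards these kwargs to HulumedProcessor.__init__.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "ZJU-AI4H/Hulu-Med",   # hypothetical repository id
    trust_remote_code=True,
    max_frames=64,         # halve the default of 128 frames per video
    fps=1,
)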
tokenizer.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c569fae330e56901f76e6da36771c6febaefd461368057efdba32696f69381de
- size 11423222
tokenizer_config.json CHANGED
@@ -251,7 +251,7 @@
  "<|video_pad|>"
  ],
  "bos_token": null,
- "chat_template": "\n{%- set identifier = 'im' %}\n{% for message in messages %}\n {% if message['role'] == 'stream' %}\n {% set identifier = 'stream' %}\n {% else %}\n {% set identifier = 'im' %}\n {% endif %}\n {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}\n {% if message['content'] is string %}\n {{- message['content'] + '<|' + identifier + '_end|>\n' -}}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n\n {{- '<image>\n' -}}\n\n {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if i < content['num_frames'] - 1 %}\n\n {{- '<image>,' -}}\n\n {% else %}\n\n {{- '<image>\n' -}}\n\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'text' or 'text' in content %}\n {{- content['text'] -}}\n {% endif %}\n {% endfor %}\n {{- '<|' + identifier + '_end|>\n' -}}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' -}}\n{% endif %}\n",
+ "chat_template": "\n{%- set identifier = 'im' %}\n{% for message in messages %}\n {% if message['role'] == 'stream' %}\n {% set identifier = 'stream' %}\n {% else %}\n {% set identifier = 'im' %}\n {% endif %}\n {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}\n {% if message['content'] is string %}\n {{- message['content'] + '<|' + identifier + '_end|>\n' -}}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n {% if 'time' in content %}\n {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}\n {% endif %}\n\n {{- '<image>\n' -}}\n\n {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n\n {{- '<image>,' -}}\n\n {% else %}\n\n {{- '<image>\n' -}}\n\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'text' or 'text' in content %}\n {{- content['text'] -}}\n {% endif %}\n {% endfor %}\n {{- '<|' + identifier + '_end|>\n' -}}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' -}}\n{% endif %}\n",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
vocab.json CHANGED
The diff for this file is too large to render. See raw diff