from typing import Optional, Union
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin, Cache, BatchFeature
from modelscope import AutoConfig, AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
from transformers.image_utils import ImageInput
from transformers.modeling_outputs import CausalLMOutputWithPast
from liger_kernel.transformers import LigerCrossEntropyLoss
from transformers.processing_utils import Unpack, ProcessorMixin
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import TransformersKwargs


class Qwenov3Config(PretrainedConfig):
    model_type = "Qwenov3"

    def __init__(self, llm_model_path='Qwen/Qwen3-0.6B',
                 vision_model_path='facebook/dinov3-vitl16-pretrain-lvd1689m',
                 freeze_vision_model=False,
                 freeze_llm_model=False,
                 image_pad_num=49,
                 training_scratch=False,
                 num_hidden_layers=None,
                 hidden_size=None,
                 num_attention_heads=None,
                 vocab_size=None,
                 **kwargs):
        self.vision_model_path = vision_model_path
        self.llm_model_path = llm_model_path
        self.freeze_vision_model = freeze_vision_model
        self.freeze_llm_model = freeze_llm_model
        self.image_pad_num = image_pad_num
        self.training_scratch = training_scratch
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        super().__init__(**kwargs)


class Qwenov3Processor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_num=49, **kwargs):
        self.image_token = "<|image_pad|>"
        self.image_pad_num = image_pad_num
        if chat_template is None and tokenizer is not None:
            chat_template = getattr(tokenizer, "chat_template", None)
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchFeature:
        image_inputs = {}
        if images is not None:
            image_inputs = {'pixel_values': self.image_processor(images=images, return_tensors="pt")['pixel_values']}

        if not isinstance(text, list):
            text = [text]

        processed_text = []
        for t in text:
            if images is not None:
                # Expand the single <image> placeholder into the vision span expected by the model.
                replacement = '<|vision_start|>' + '<|image_pad|>' * self.image_pad_num + '<|vision_end|>'
                if '<image>' not in t:
                    # Prompts without an explicit placeholder get one inserted before the first <|im_end|>.
                    t = t.replace('<|im_end|>', '<image><|im_end|>', 1)
                t = t.replace('<image>', replacement)
            processed_text.append(t)

        tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in ['images']}
        text_inputs = self.tokenizer(processed_text, return_tensors=return_tensors, **tokenizer_kwargs)

        return BatchFeature(data={**text_inputs, **image_inputs})


class Qwenov3(GenerationMixin, PreTrainedModel):
    config_class = Qwenov3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        if self.config.training_scratch:
            # Start from the pretrained vision and language checkpoints.
            self.vision_model = AutoModel.from_pretrained(self.config.vision_model_path, low_cpu_mem_usage=True,
                                                          dtype=torch.bfloat16, attn_implementation="flash_attention_2")
            self.llm_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path, low_cpu_mem_usage=True,
                                                                  dtype=torch.bfloat16,
                                                                  attn_implementation="flash_attention_2")
        else:
            # Build the architectures from their configs only; weights are expected to be
            # restored later from a saved Qwenov3 checkpoint.
            vision_config = AutoConfig.from_pretrained(self.config.vision_model_path)
            self.vision_model = AutoModel.from_config(vision_config, attn_implementation="sdpa", dtype=torch.bfloat16)
            llm_config = AutoConfig.from_pretrained(self.config.llm_model_path)
            self.llm_model = AutoModelForCausalLM.from_config(llm_config, attn_implementation="sdpa", dtype=torch.bfloat16)

        if self.config.num_hidden_layers is None:
            self.config.num_hidden_layers = self.llm_model.config.num_hidden_layers
        if self.config.hidden_size is None:
            self.config.hidden_size = self.llm_model.config.hidden_size
        if self.config.num_attention_heads is None:
            self.config.num_attention_heads = self.llm_model.config.num_attention_heads
        if self.config.vocab_size is None:
            self.config.vocab_size = self.llm_model.config.vocab_size

        self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path, use_fast=True)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Register the vision special tokens and grow the embedding matrix once if any are missing.
        vision_tokens = ['<|image_pad|>', '<|vision_start|>', '<|vision_end|>']
        missing_tokens = [tok for tok in vision_tokens if tok not in self.tokenizer.get_vocab()]
        if missing_tokens:
            self.tokenizer.add_tokens(missing_tokens)
            self.llm_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)

        # Adapter that projects groups of 4 concatenated DINOv3 patch embeddings
        # (4 x 1024 = 4096 features) into the language model's hidden size.
        self.adapter = nn.Sequential(
            nn.RMSNorm(4096, dtype=torch.bfloat16),
            nn.Linear(4096, self.llm_model.config.hidden_size, dtype=torch.bfloat16),
            nn.GELU(),
            nn.Linear(self.llm_model.config.hidden_size, self.llm_model.config.hidden_size, dtype=torch.bfloat16)
        )

        if self.config.freeze_vision_model:
            for param in self.vision_model.parameters():
                param.requires_grad = False
        if self.config.freeze_llm_model:
            for param in self.llm_model.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        text_embeds = self.llm_model.get_input_embeddings()(input_ids)
        if pixel_values is not None:
            # Vision tower output: drop the CLS/register tokens, keep the patch tokens.
            image_embeds = self.vision_model(pixel_values).last_hidden_state
            patch_embeds = image_embeds[:, 5:, :]  # [batch, 196, 1024]
            b, num_patches, hidden_dim = patch_embeds.shape
            # Concatenate groups of 4 consecutive patches to shrink the sequence to image_pad_num tokens.
            patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)  # [batch, 49, 4096]
            image_features = self.adapter(patch_embeds)
            text_embeds = text_embeds.to(image_features.dtype)
            inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
        else:
            inputs_embeds = text_embeds

        # Run only the language-model backbone on the merged embeddings (passing input_ids as well
        # would raise, since exactly one of input_ids / inputs_embeds is expected). This assumes a
        # Qwen/Llama-style causal LM that exposes `.model` and `.lm_head`.
        outputs = self.llm_model.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.llm_model.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Labels are expected to already be aligned with the logits positions, with
            # prompt/padding positions set to the pad id, which is ignored by the loss.
            loss_fct = LigerCrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1).to(logits.device))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(self, input_ids=None, pixel_values=None, attention_mask=None,
                 max_new_tokens=512, temperature=0.7, top_p=0.8, top_k=20,
                 do_sample=True, num_beams=1, use_cache=True, **kwargs):
        if pixel_values is not None:
            text_embeds = self.llm_model.get_input_embeddings()(input_ids)
            image_embeds = self.vision_model(pixel_values).last_hidden_state
            patch_embeds = image_embeds[:, 5:, :]
            b, num_patches, hidden_dim = patch_embeds.shape
            patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)
            image_features = self.adapter(patch_embeds)
            text_embeds = text_embeds.to(image_features.dtype)
            inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
        else:
            inputs_embeds = self.llm_model.get_input_embeddings()(input_ids)
        return self.llm_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_beams=num_beams,
            use_cache=use_cache,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs
        )

    def can_generate(self):
        return True

    def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
        # Scatter the projected patch embeddings into the positions occupied by <|image_pad|> tokens.
        num_images, num_image_patches, embed_dim = image_features.shape
        image_pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
        batch_indices, image_indices = torch.where(input_ids == image_pad_id)
        if len(batch_indices) == 0:
            return inputs_embeds
        inputs_embeds[batch_indices, image_indices] = image_features.view(-1, embed_dim)
        return inputs_embeds
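

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition). It
# assumes the Qwen/Qwen3-0.6B and facebook/dinov3-vitl16-pretrain-lvd1689m
# checkpoints are downloadable, a CUDA device with bfloat16 and flash-attn
# support is available, and 'example.jpg' is a hypothetical local RGB image.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoImageProcessor

    config = Qwenov3Config(training_scratch=True, freeze_vision_model=True)
    model = Qwenov3(config).to("cuda")

    processor = Qwenov3Processor(
        image_processor=AutoImageProcessor.from_pretrained(config.vision_model_path),
        tokenizer=model.tokenizer,
        image_pad_num=config.image_pad_num,
    )

    prompt = (
        "<|im_start|>user\n<image>\nDescribe this picture.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    inputs = processor(images=Image.open("example.jpg").convert("RGB"), text=prompt).to("cuda")

    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"].to(torch.bfloat16),
        attention_mask=inputs["attention_mask"],
        max_new_tokens=64,
    )
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))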