TianYeZ1214 committed on
Commit e043d8b · verified · 1 Parent(s): 3ba517c

Upload 3 files

Files changed (3)
  1. Qwenov3Config.py +186 -0
  2. StreamlitUI.py +204 -0
  3. inference.py +50 -0
Qwenov3Config.py ADDED
@@ -0,0 +1,186 @@
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
+ from modelscope import AutoConfig, AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import torch.nn as nn
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+ class Qwenov3Config(PretrainedConfig):
+     model_type = "Qwenov3"
+
+     def __init__(self, llm_model_path='Qwen/Qwen3-0.6B',
+                  vision_model_path='facebook/dinov3-vitl16-pretrain-lvd1689m',
+                  freeze_vision_model=False,
+                  freeze_llm_model=False,
+                  image_pad_num=49,
+                  training_scratch=False,
+                  num_hidden_layers=None,
+                  hidden_size=None,
+                  num_attention_heads=None,
+                  vocab_size=None,
+                  **kwargs):
+         self.vision_model_path = vision_model_path
+         self.llm_model_path = llm_model_path
+         self.freeze_vision_model = freeze_vision_model
+         self.freeze_llm_model = freeze_llm_model
+         self.image_pad_num = image_pad_num
+         self.training_scratch = training_scratch
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.vocab_size = vocab_size
+
+         super().__init__(**kwargs)
+
+
+ class Qwenov3(GenerationMixin, PreTrainedModel):
+     config_class = Qwenov3Config
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         if self.config.training_scratch:
+             # Training from scratch: pull pretrained DINOv3 and Qwen3 weights from the hub.
+             self.vision_model = AutoModel.from_pretrained(self.config.vision_model_path, low_cpu_mem_usage=True,
+                                                           dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+             self.llm_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path, low_cpu_mem_usage=True,
+                                                                   dtype=torch.bfloat16,
+                                                                   attn_implementation="flash_attention_2")
+         else:
+             # Inference: build empty backbones from their configs; weights come from the Qwenov3 checkpoint.
+             vision_config = AutoConfig.from_pretrained(self.config.vision_model_path)
+             self.vision_model = AutoModel.from_config(vision_config, attn_implementation="sdpa", dtype=torch.bfloat16)
+             llm_config = AutoConfig.from_pretrained(self.config.llm_model_path)
+             self.llm_model = AutoModelForCausalLM.from_config(llm_config, attn_implementation="sdpa", dtype=torch.bfloat16)
+
+         if self.config.num_hidden_layers is None:
+             self.config.num_hidden_layers = self.llm_model.config.num_hidden_layers
+         if self.config.hidden_size is None:
+             self.config.hidden_size = self.llm_model.config.hidden_size
+         if self.config.num_attention_heads is None:
+             self.config.num_attention_heads = self.llm_model.config.num_attention_heads
+         if self.config.vocab_size is None:
+             self.config.vocab_size = self.llm_model.config.vocab_size
+
+         self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path, use_fast=True)
+
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         if '<|image_pad|>' not in self.tokenizer.get_vocab():
+             self.tokenizer.add_tokens(['<|image_pad|>'])
+             self.llm_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)
+         if '<|vision_start|>' not in self.tokenizer.get_vocab():
+             self.tokenizer.add_tokens(['<|vision_start|>'])
+             self.llm_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)
+         if '<|vision_end|>' not in self.tokenizer.get_vocab():
+             self.tokenizer.add_tokens(['<|vision_end|>'])
+             self.llm_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)
+
+         # Adapter: project folded DINOv3 patch features (4 x 1024 = 4096) into the LLM hidden size.
+         self.adapter = nn.Sequential(
+             nn.RMSNorm(4096, dtype=torch.bfloat16),
+             nn.Linear(4096, self.llm_model.config.hidden_size, dtype=torch.bfloat16),
+             nn.GELU(),
+             nn.Linear(self.llm_model.config.hidden_size, self.llm_model.config.hidden_size, dtype=torch.bfloat16)
+         )
+
+         if self.config.freeze_vision_model:
+             for param in self.vision_model.parameters():
+                 param.requires_grad = False
+         if self.config.freeze_llm_model:
+             for param in self.llm_model.parameters():
+                 param.requires_grad = False
+
+     def forward(self, input_ids=None, labels=None, pixel_values=None, attention_mask=None,
+                 inputs_embeds=None, past_key_values=None, use_cache=None, **kwargs):
+
+         if inputs_embeds is None:
+             text_embeds = self.llm_model.get_input_embeddings()(input_ids)
+             if pixel_values is not None:
+                 image_embeds = self.vision_model(pixel_values).last_hidden_state
+                 patch_embeds = image_embeds[:, 5:, :]  # drop class/register tokens -> [batch, 196, 1024]
+                 b, num_patches, hidden_dim = patch_embeds.shape
+                 patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)  # [batch, 49, 4096]
+                 image_features = self.adapter(patch_embeds)
+                 text_embeds = text_embeds.to(image_features.dtype)
+                 inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
+             else:
+                 inputs_embeds = text_embeds
+
+         outputs = self.llm_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             return_dict=True
+         )
+
+         logits = outputs.logits
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
+             loss = loss_fct(
+                 logits.view(-1, logits.size(-1)), labels.view(-1).to(logits.device)
+             )
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions
+         )
+
+     @torch.inference_mode()
+     def generate(self, input_ids=None, pixel_values=None, attention_mask=None,
+                  max_new_tokens=512, temperature=0.7, top_p=0.8, top_k=20,
+                  do_sample=True, num_beams=1, use_cache=True, **kwargs):
+         if pixel_values is not None:
+             text_embeds = self.llm_model.get_input_embeddings()(input_ids)
+             image_embeds = self.vision_model(pixel_values).last_hidden_state
+             patch_embeds = image_embeds[:, 5:, :]
+             b, num_patches, hidden_dim = patch_embeds.shape
+             patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)
+             image_features = self.adapter(patch_embeds)
+             text_embeds = text_embeds.to(image_features.dtype)
+             inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
+             return self.llm_model.generate(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 use_cache=use_cache,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 **kwargs
+             )
+         else:
+             return self.llm_model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 use_cache=use_cache,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 **kwargs
+             )
+
+     def can_generate(self):
+         return True
+
+     def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
+         # Scatter the adapted image features into the <|image_pad|> positions of the text embeddings.
+         num_images, num_image_patches, embed_dim = image_features.shape
+         batch_indices, image_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])
+         if len(batch_indices) == 0:
+             return inputs_embeds
+         inputs_embeds[batch_indices, image_indices] = image_features.view(-1, embed_dim)
+         return inputs_embeds
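
For reference, a minimal sketch of the patch-folding arithmetic that the adapter and image_pad_num=49 rely on. It assumes the DINOv3 processor resizes inputs to 224x224, so ViT-L/16 yields 196 patch tokens of width 1024 once the first five tokens (class/register tokens, per the [:, 5:, :] slice) are dropped; the 224x224 input size is an assumption, not stated in the code, and the tensor below is a random stand-in for the real activations.

import torch

# Stand-in for vision_model(pixel_values).last_hidden_state: [batch, 5 special tokens + 196 patches, 1024]
image_embeds = torch.randn(1, 201, 1024)

patch_embeds = image_embeds[:, 5:, :]                            # [1, 196, 1024] patch tokens only
b, num_patches, hidden_dim = patch_embeds.shape
folded = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)  # [1, 49, 4096]
assert folded.shape == (1, 49, 4096)  # 49 == image_pad_num, 4096 == adapter input width

Each of the 49 folded vectors is then projected by the adapter to the LLM hidden size and written into one <|image_pad|> position of the prompt by merge_input_ids_with_image_features.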
StreamlitUI.py ADDED
@@ -0,0 +1,204 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoModelForCausalLM, TextIteratorStreamer, AutoConfig
+ import gc
+ from threading import Thread
+ from Qwenov3Config import Qwenov3Config, Qwenov3
+ from PIL import Image
+
+ MODEL_MAPPING = {
+     'QwenoV3-Pretrain': '',
+     'QwenoV3-SFT': '',
+ }
+
+
+ def unload_model():
+     if 'model' in st.session_state:
+         del st.session_state.model
+     if 'tokenizer' in st.session_state:
+         del st.session_state.tokenizer
+     if 'processor' in st.session_state:
+         del st.session_state.processor
+     if 'streamer' in st.session_state:
+         del st.session_state.streamer
+     torch.cuda.empty_cache()
+     gc.collect()
+
+
+ def call_model(info_placeholder, messages, generated_text, message_placeholder, image=None):
+     info_placeholder.markdown(f'Running this task with {st.session_state.model_display}')
+     if image is not None:
+         image = Image.open(image).convert('RGB')
+         if '<image>' not in messages[1]['content']:
+             messages[1]['content'] = '<image>\n' + messages[1]['content']
+
+     query_text = st.session_state.tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+         enable_thinking=False
+     )
+     if '<image>' in query_text:
+         query_text = query_text.replace('<image>', '<|vision_start|>' + '<|image_pad|>' *
+                                         st.session_state.model.config.image_pad_num + '<|vision_end|>')
+     text_inputs = st.session_state.tokenizer(query_text, return_tensors="pt")
+     input_ids = text_inputs['input_ids'].to(st.session_state.model.device)
+     attention_mask = text_inputs['attention_mask'].to(st.session_state.model.device)
+     text_embeds = st.session_state.model.llm_model.get_input_embeddings()(input_ids)
+
+     if image is not None:
+         pixel_values = st.session_state.processor(images=image, return_tensors="pt")['pixel_values'].to(
+             st.session_state.model.device)
+         image_embeds = st.session_state.model.vision_model(pixel_values).last_hidden_state
+         patch_embeds = image_embeds[:, 5:, :]
+         b, num_patches, hidden_dim = patch_embeds.shape
+         patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)
+         image_features = st.session_state.model.adapter(patch_embeds)
+         text_embeds = text_embeds.to(image_features.dtype)
+         inputs_embeds = st.session_state.model.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
+     else:
+         inputs_embeds = text_embeds
+
+     generate_params = dict(
+         inputs_embeds=inputs_embeds,
+         attention_mask=attention_mask,
+         max_new_tokens=st.session_state.max_new_tokens,
+         min_new_tokens=st.session_state.min_new_tokens,
+         do_sample=True,
+         temperature=st.session_state.temperature,
+         top_k=st.session_state.top_k,
+         top_p=st.session_state.top_p,
+         min_p=0.0,
+         repetition_penalty=st.session_state.repetition_penalty,
+         streamer=st.session_state.streamer,
+         eos_token_id=st.session_state.tokenizer.eos_token_id
+     )
+     # Run generation on a background thread and stream decoded chunks into the placeholder.
+     thread = Thread(target=st.session_state.model.llm_model.generate, kwargs=generate_params)
+     thread.start()
+
+     for new_text in st.session_state.streamer:
+         generated_text += new_text
+         message_placeholder.markdown(generated_text)
+
+     return generated_text
+
+
+ def ini_message():
+     if 'messages' not in st.session_state:
+         st.session_state.messages = [
+             {"role": "system", "content": "You are QwenoV3, a helpful assistant created by 天烨."},
+         ]
+     if 'uploaded_image' not in st.session_state:
+         st.session_state.uploaded_image = None
+
+
+ def parameter_settings():
+     with st.sidebar:
+         previous_model = st.session_state.get('model_display', None)
+         st.session_state.model_display = st.selectbox("Select model", list(MODEL_MAPPING.keys()),
+                                                       index=len(MODEL_MAPPING.keys()) - 1, help="Select the model to use")
+         st.session_state.model_path = MODEL_MAPPING[st.session_state.model_display]
+         with st.expander("Chat parameters", expanded=False):
+             col1, col2 = st.columns(2)
+             with col1:
+                 st.session_state.temperature = st.slider("Temperature", 0.0, 2.0, 0.7, 0.1,
+                                                          help="Controls response diversity; higher values give more varied replies")
+                 st.session_state.min_new_tokens = st.number_input("Min Tokens",
+                                                                   min_value=0,
+                                                                   max_value=512,
+                                                                   value=10,
+                                                                   help="Minimum length of the generated text")
+                 st.session_state.max_new_tokens = st.number_input("Max Tokens",
+                                                                   min_value=1,
+                                                                   max_value=4096,
+                                                                   value=512,
+                                                                   help="Maximum length of the generated text")
+             with col2:
+                 st.session_state.top_p = st.slider("Top P", 0.0, 1.0, 0.8, 0.1,
+                                                    help="Controls the diversity of token selection; higher values allow a wider candidate pool")
+                 st.session_state.top_k = st.slider("Top K", 0, 80, 20, 1,
+                                                    help="Controls the diversity of token selection; higher values allow a wider candidate pool")
+                 st.session_state.repetition_penalty = st.slider("Repetition Penalty", 0.0, 2.0, 1.05, 0.1,
+                                                                 help="Controls repetition; higher values reduce repeated content")
+
+         with st.expander("Image upload", expanded=False):
+             st.session_state.uploaded_image = st.file_uploader(
+                 "Upload an image",
+                 type=["jpg", "jpeg", "png"]
+             )
+             if st.session_state.uploaded_image:
+                 image = Image.open(st.session_state.uploaded_image)
+                 width, height = image.size
+                 if width > 256 or height > 256:
+                     scale = 256 / max(height, width)
+                     new_h, new_w = int(height * scale), int(width * scale)
+                     image = image.resize((new_w, new_h), Image.BILINEAR)
+                 st.image(image, caption="Image preview")
+
+         if st.button("Start a new chat", help="Starting a new chat clears the current conversation history"):
+             st.session_state.uploaded_image = None
+             st.session_state.messages = [
+                 {"role": "system", "content": "You are QwenoV3, a helpful assistant created by 天烨."},
+             ]
+             st.success("A new conversation has been started")
+             st.rerun()
+
+         if previous_model != st.session_state.model_display or 'tokenizer' not in st.session_state or 'model' not in st.session_state or 'processor' not in st.session_state:
+             unload_model()
+             try:
+                 with st.spinner('Loading model...'):
+                     AutoConfig.register("Qwenov3", Qwenov3Config)
+                     AutoModelForCausalLM.register(Qwenov3Config, Qwenov3)
+                     st.session_state.model = AutoModelForCausalLM.from_pretrained(
+                         st.session_state.model_path,
+                         torch_dtype=torch.bfloat16,
+                         device_map="auto",
+                         low_cpu_mem_usage=True,
+                         trust_remote_code=True
+                     )
+                     st.session_state.tokenizer = st.session_state.model.tokenizer
+                     st.session_state.processor = st.session_state.model.processor
+                     st.session_state.streamer = TextIteratorStreamer(st.session_state.tokenizer,
+                                                                      skip_prompt=True, skip_special_tokens=True)
+             except Exception as e:
+                 st.error(f'Model loading error: {e}')
+                 return
+
+
+ def main():
+     st.markdown("""
+         <h1 style='text-align: center;'>
+             QwenoV3 - Marrying DinoV3 With Qwen3 🫡
+         </h1>
+         <div style='text-align: center; margin-bottom: 20px;'>
+         </div>
+         """, unsafe_allow_html=True)
+     ini_message()
+     parameter_settings()
+
+     for message in st.session_state.messages:
+         if message["role"] == "system":
+             continue
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     if user_input := st.chat_input("Type your question here:", key="chat_input"):
+         with st.chat_message("user"):
+             st.markdown(user_input)
+         st.session_state.messages.append({"role": "user", "content": user_input})
+
+         with st.chat_message("assistant"):
+             info_placeholder = st.empty()
+             message_placeholder = st.empty()
+             generated_text = ""
+             try:
+                 with torch.inference_mode():
+                     generated_text = call_model(info_placeholder, st.session_state.messages, generated_text,
+                                                 message_placeholder, st.session_state.uploaded_image)
+                 st.session_state.messages.append({"role": "assistant", "content": generated_text})
+             except Exception as e:
+                 st.error(f"Error while generating a response: {str(e)}")
+
+
+ if __name__ == '__main__':
+     main()
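
A note on the streaming path: call_model() runs llm_model.generate on a background thread while the main thread iterates the TextIteratorStreamer and re-renders the partial answer as chunks arrive. Below is a minimal, self-contained sketch of that pattern; it substitutes a plain Qwen3-0.6B causal LM for the full Qwenov3 wrapper purely for illustration, and the prompt and generation settings are placeholders.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Plain LLM stand-in for st.session_state.model.llm_model (illustrative assumption).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

inputs = tokenizer("Hello, who are you?", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread; the streamer is consumed on the main thread.
thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=64, streamer=streamer))
thread.start()

generated_text = ""
for new_text in streamer:        # yields decoded text chunks as generation progresses
    generated_text += new_text   # the Streamlit UI re-renders this partial reply on every chunk
    print(new_text, end="", flush=True)
thread.join()

The app itself is launched with "streamlit run StreamlitUI.py" once the empty MODEL_MAPPING entries point at actual Qwenov3 checkpoints.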
inference.py ADDED
@@ -0,0 +1,50 @@
+ from transformers import AutoModelForCausalLM, AutoConfig
+ from PIL import Image
+ from Qwenov3Config import Qwenov3Config, Qwenov3
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_path = ''  # path to a Qwenov3 checkpoint
+ AutoConfig.register("Qwenov3", Qwenov3Config)
+ AutoModelForCausalLM.register(Qwenov3Config, Qwenov3)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, dtype=torch.bfloat16,
+                                              trust_remote_code=True).to(device)
+ model.eval()
+ processor = model.processor
+ tokenizer = model.tokenizer
+ messages = [
+     {"role": "system", "content": 'You are a helpful assistant.'},
+     {"role": "user", "content": '<image>\nDescribe the image content in Chinese.'},
+ ]
+ if '<image>' not in messages[1]['content']:
+     messages[1]['content'] = '<image>\n' + messages[1]['content']
+
+ print(messages)
+
+ # Expand the <image> placeholder into the vision tokens expected by Qwenov3.
+ q_text = tokenizer.apply_chat_template(messages,
+                                        tokenize=False,
+                                        add_generation_prompt=True,
+                                        enable_thinking=False).replace('<image>',
+                                                                       '<|vision_start|>' + '<|image_pad|>' * model.config.image_pad_num + '<|vision_end|>')
+ print(q_text)
+
+ text_inputs = tokenizer(q_text, return_tensors='pt')
+ input_ids = text_inputs['input_ids'].to(device)
+ attention_mask = text_inputs['attention_mask'].to(device)
+
+ image = Image.open('')  # path to a test image
+ pixel_values = processor(images=image, return_tensors="pt")['pixel_values'].to(device)
+
+ output_ids = model.generate(
+     input_ids=input_ids,
+     attention_mask=attention_mask,
+     pixel_values=pixel_values,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_k=20,
+     top_p=0.8,
+     do_sample=True,
+     repetition_penalty=1.00,
+ )
+
+ print(tokenizer.decode(output_ids[0], skip_special_tokens=True))