shuai bai
commited on
Commit
·
616d8e0
1
Parent(s):
5a88f83
Update modeling_qwen.py
Browse files- modeling_qwen.py +9 -6
modeling_qwen.py
CHANGED
|
@@ -564,7 +564,13 @@ class QWenModel(QWenPreTrainedModel):
|
|
| 564 |
|
| 565 |
images = self.visual.encode(images)
|
| 566 |
assert images.shape[0] == len(images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
else:
|
|
|
|
| 568 |
images = None
|
| 569 |
|
| 570 |
output_attentions = (
|
|
@@ -623,11 +629,6 @@ class QWenModel(QWenPreTrainedModel):
|
|
| 623 |
|
| 624 |
if inputs_embeds is None:
|
| 625 |
inputs_embeds = self.wte(input_ids)
|
| 626 |
-
if self.training and images == None: # Compatible with plain text data training
|
| 627 |
-
fake_images=torch.zeros(1,3,224,224).to(
|
| 628 |
-
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
|
| 629 |
-
image_embeds = self.visual(fake_images)
|
| 630 |
-
inputs_embeds = inputs_embeds + image_embeds.mean()*0
|
| 631 |
|
| 632 |
if batch_size <= 0:
|
| 633 |
raise ValueError("batch_size has to be defined and > 0")
|
|
@@ -657,7 +658,9 @@ class QWenModel(QWenPreTrainedModel):
|
|
| 657 |
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
|
| 658 |
|
| 659 |
hidden_states = self.drop(hidden_states).clone()
|
| 660 |
-
if
|
|
|
|
|
|
|
| 661 |
for idx, (i, a, b) in enumerate(img_pos):
|
| 662 |
hidden_states[i][a + 1 : b] = images[idx]
|
| 663 |
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
| 564 |
|
| 565 |
images = self.visual.encode(images)
|
| 566 |
assert images.shape[0] == len(images)
|
| 567 |
+
fake_images = None
|
| 568 |
+
elif self.training:
|
| 569 |
+
fake_images=torch.zeros(1,3,224,224).to(
|
| 570 |
+
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
|
| 571 |
+
images = self.visual(fake_images)
|
| 572 |
else:
|
| 573 |
+
fake_images = None
|
| 574 |
images = None
|
| 575 |
|
| 576 |
output_attentions = (
|
|
|
|
| 629 |
|
| 630 |
if inputs_embeds is None:
|
| 631 |
inputs_embeds = self.wte(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
|
| 633 |
if batch_size <= 0:
|
| 634 |
raise ValueError("batch_size has to be defined and > 0")
|
|
|
|
| 658 |
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
|
| 659 |
|
| 660 |
hidden_states = self.drop(hidden_states).clone()
|
| 661 |
+
if fake_images is not None:
|
| 662 |
+
hidden_states = hidden_states + images.mean()*0
|
| 663 |
+
elif images is not None:
|
| 664 |
for idx, (i, a, b) in enumerate(img_pos):
|
| 665 |
hidden_states[i][a + 1 : b] = images[idx]
|
| 666 |
output_shape = input_shape + (hidden_states.size(-1),)
|