Update modeling_internvl_chat.py
Browse files
modeling_internvl_chat.py
CHANGED
|
@@ -38,7 +38,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
| 38 |
_supports_flash_attn_2 = True
|
| 39 |
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
|
| 40 |
|
| 41 |
-
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
|
| 42 |
super().__init__(config)
|
| 43 |
|
| 44 |
assert version_cmp(transformers.__version__, '4.37.0', 'ge')
|
|
@@ -81,7 +81,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
| 81 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
| 82 |
)
|
| 83 |
|
| 84 |
-
self.img_context_token_id =
|
| 85 |
self.mr_prompt = MRPromptV3()
|
| 86 |
|
| 87 |
def forward(
|
|
|
|
| 38 |
_supports_flash_attn_2 = True
|
| 39 |
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
|
| 40 |
|
| 41 |
+
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True, img_context_token_id=None):
|
| 42 |
super().__init__(config)
|
| 43 |
|
| 44 |
assert version_cmp(transformers.__version__, '4.37.0', 'ge')
|
|
|
|
| 81 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
| 82 |
)
|
| 83 |
|
| 84 |
+
self.img_context_token_id = img_context_token_id
|
| 85 |
self.mr_prompt = MRPromptV3()
|
| 86 |
|
| 87 |
def forward(
|