Update modeling_flmr.py
modeling_flmr.py  (+16 -14)
@@ -584,13 +584,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
         self.text_encoder_embedding_size = self.config.text_config.hidden_size
         self.late_interaction_embedding_size = self.config.dim
 
-        self.context_vision_projection = FLMRMultiLayerPerceptron(
-            (
-                self.vision_encoder_embedding_size,
-                (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
-                self.late_interaction_embedding_size * self.mapping_network_prefix_length,
+        if self.config.use_vision_encoder:
+            self.context_vision_projection = FLMRMultiLayerPerceptron(
+                (
+                    self.vision_encoder_embedding_size,
+                    (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
+                    self.late_interaction_embedding_size * self.mapping_network_prefix_length,
+                )
             )
-        )
 
         if self.config.use_vision_encoder:
             self.context_vision_encoder = FLMRVisionModel(config.vision_config)
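The tuple passed to FLMRMultiLayerPerceptron lists the layer sizes of the vision mapping network: from the vision encoder's hidden size, through a hidden layer of half the target width, to `late_interaction_embedding_size * mapping_network_prefix_length` output features, which are presumably reshaped into `mapping_network_prefix_length` late-interaction vectors. Below is a rough sketch of that dimension flow; the `nn.Sequential` stand-in, the activation, and the concrete sizes (768 / 128 / 32) are illustrative assumptions, not the repo's actual FLMRMultiLayerPerceptron or its defaults.

```python
# Sketch only: a plain two-layer projection over the same size tuple.
import torch
import torch.nn as nn

vision_encoder_embedding_size = 768      # e.g. ViT hidden size (assumed value)
late_interaction_embedding_size = 128    # config.dim (assumed value)
mapping_network_prefix_length = 32       # assumed value

sizes = (
    vision_encoder_embedding_size,
    (late_interaction_embedding_size * mapping_network_prefix_length) // 2,
    late_interaction_embedding_size * mapping_network_prefix_length,
)

projection = nn.Sequential(
    nn.Linear(sizes[0], sizes[1]),
    nn.ReLU(),
    nn.Linear(sizes[1], sizes[2]),
)

pooled_vision_feature = torch.randn(4, vision_encoder_embedding_size)  # (batch, hidden)
projected = projection(pooled_vision_feature)                          # (batch, prefix_len * dim)
# Reshape the flat output into prefix_length late-interaction vectors.
vision_embeddings = projected.view(4, mapping_network_prefix_length, late_interaction_embedding_size)
print(vision_embeddings.shape)  # torch.Size([4, 32, 128])
```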
@@ -636,13 +637,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
             self.query_text_encoder_linear = self.context_text_encoder_linear
             self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
 
-        if self.config.separate_query_and_context_vision_encoder:
-            self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
-            self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
-        else:
-            self.query_vision_encoder = self.context_vision_encoder
-            self.query_vision_projection = self.context_vision_projection
-            self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
+        if self.config.use_vision_encoder:
+            if self.config.separate_query_and_context_vision_encoder:
+                self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
+                self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
+            else:
+                self.query_vision_encoder = self.context_vision_encoder
+                self.query_vision_projection = self.context_vision_projection
+                self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
 
         if self.config.load_cpu_extension:
             try:
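The relocated block mirrors the text-encoder handling just above it: when `separate_query_and_context_vision_encoder` is set, the query side gets an independent deep copy; otherwise both attributes alias the same modules and the weights stay tied (hence the `_tied_weights_keys` entry). A toy illustration of that aliasing-vs-deepcopy distinction, with `nn.Linear` standing in for FLMRVisionModel and the projection:

```python
# Sketch only: shared module vs. independent copy.
import copy
import torch.nn as nn

context_vision_encoder = nn.Linear(16, 16)

separate_query_and_context_vision_encoder = False  # assumed config value

if separate_query_and_context_vision_encoder:
    query_vision_encoder = copy.deepcopy(context_vision_encoder)   # independent weights
else:
    query_vision_encoder = context_vision_encoder                  # same object, tied weights

print(query_vision_encoder is context_vision_encoder)              # True when shared
print(query_vision_encoder.weight.data_ptr()
      == context_vision_encoder.weight.data_ptr())                 # storage is shared too
```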
@@ -1304,7 +1306,7 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
             # TODO: fix the engine to support masks with discontinuous 0 and 1.
             D = torch.cat([vision_embeddings, text_embeddings], dim=1)
             # concatenate the mask
-            mask = torch.cat([
+            mask = torch.cat([image_mask, mask], dim=1)
         elif concat_output_from_vision_encoder:
             D = vision_embeddings
             mask = image_mask
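The mask fix keeps the mask in the same order as the embeddings: D is built as `[vision_embeddings, text_embeddings]` along dim=1, so the combined mask has to put `image_mask` first as well. A small self-contained sketch (shapes made up for illustration):

```python
# Sketch only: mask order must match embedding order after concatenation.
import torch

batch, n_img_tokens, n_txt_tokens, dim = 2, 3, 5, 8

vision_embeddings = torch.randn(batch, n_img_tokens, dim)
text_embeddings = torch.randn(batch, n_txt_tokens, dim)
image_mask = torch.ones(batch, n_img_tokens, dtype=torch.long)
mask = torch.randint(0, 2, (batch, n_txt_tokens))  # text attention mask

D = torch.cat([vision_embeddings, text_embeddings], dim=1)   # (batch, 3 + 5, dim)
full_mask = torch.cat([image_mask, mask], dim=1)             # (batch, 3 + 5)

assert D.shape[1] == full_mask.shape[1]
# Position i in full_mask now masks position i in D; reversing the order
# would apply the text mask to the vision tokens and vice versa.
```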