Update modelling_longitudinal.py
Browse filesFixed issues with attn_implementation and decoder_inputs['past_key_values'].
modelling_longitudinal.py
CHANGED
|
@@ -127,6 +127,7 @@ class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 127 |
encoder = MultiCvtWithProjectionHead(config=config.encoder)
|
| 128 |
|
| 129 |
# Decoder:
|
|
|
|
| 130 |
if decoder is None:
|
| 131 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 132 |
|
|
@@ -288,7 +289,7 @@ class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 288 |
'decoder_token_type_ids': token_type_ids,
|
| 289 |
'decoder_position_ids': decoder_position_ids,
|
| 290 |
'encoder_outputs': encoder_outputs,
|
| 291 |
-
'past_key_values':
|
| 292 |
'use_cache': use_cache,
|
| 293 |
}
|
| 294 |
return input_dict
|
|
|
|
| 127 |
encoder = MultiCvtWithProjectionHead(config=config.encoder)
|
| 128 |
|
| 129 |
# Decoder:
|
| 130 |
+
config.decoder._attn_implementation = 'eager'
|
| 131 |
if decoder is None:
|
| 132 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 133 |
|
|
|
|
| 289 |
'decoder_token_type_ids': token_type_ids,
|
| 290 |
'decoder_position_ids': decoder_position_ids,
|
| 291 |
'encoder_outputs': encoder_outputs,
|
| 292 |
+
'past_key_values': past_key_values,
|
| 293 |
'use_cache': use_cache,
|
| 294 |
}
|
| 295 |
return input_dict
|