anicolson commited on
Commit
6671e0f
·
verified ·
1 Parent(s): ac09e83

Update modelling_longitudinal.py

Browse files

Fixed issues with attn_implementation and decoder_inputs['past_key_values'].

Files changed (1) hide show
  1. modelling_longitudinal.py +2 -1
modelling_longitudinal.py CHANGED
@@ -127,6 +127,7 @@ class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
127
  encoder = MultiCvtWithProjectionHead(config=config.encoder)
128
 
129
  # Decoder:
 
130
  if decoder is None:
131
  decoder = transformers.BertLMHeadModel(config=config.decoder)
132
 
@@ -288,7 +289,7 @@ class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
288
  'decoder_token_type_ids': token_type_ids,
289
  'decoder_position_ids': decoder_position_ids,
290
  'encoder_outputs': encoder_outputs,
291
- 'past_key_values': decoder_inputs['past_key_values'],
292
  'use_cache': use_cache,
293
  }
294
  return input_dict
 
127
  encoder = MultiCvtWithProjectionHead(config=config.encoder)
128
 
129
  # Decoder:
130
+ config.decoder._attn_implementation = 'eager'
131
  if decoder is None:
132
  decoder = transformers.BertLMHeadModel(config=config.decoder)
133
 
 
289
  'decoder_token_type_ids': token_type_ids,
290
  'decoder_position_ids': decoder_position_ids,
291
  'encoder_outputs': encoder_outputs,
292
+ 'past_key_values': past_key_values,
293
  'use_cache': use_cache,
294
  }
295
  return input_dict