anicolson committed on
Commit
0792c86
·
verified ·
1 Parent(s): 77dc7a7

Update modelling_multi.py

Browse files

Fixed issues with attn_implementation and decoder_inputs['past_key_values'].

Files changed (1) hide show
  1. modelling_multi.py +2 -1
modelling_multi.py CHANGED
@@ -123,6 +123,7 @@ class MultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
123
  encoder = MultiCvtWithProjectionHead(config=config.encoder)
124
 
125
  # Decoder:
 
126
  if decoder is None:
127
  decoder = transformers.BertLMHeadModel(config=config.decoder)
128
 
@@ -254,7 +255,7 @@ class MultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
254
  'decoder_input_ids': decoder_inputs['input_ids'],
255
  'decoder_token_type_ids': token_type_ids,
256
  'encoder_outputs': encoder_outputs,
257
- 'past_key_values': decoder_inputs['past_key_values'],
258
  'use_cache': use_cache,
259
  }
260
  return input_dict
 
123
  encoder = MultiCvtWithProjectionHead(config=config.encoder)
124
 
125
  # Decoder:
126
+ config.decoder._attn_implementation = 'eager'
127
  if decoder is None:
128
  decoder = transformers.BertLMHeadModel(config=config.decoder)
129
 
 
255
  'decoder_input_ids': decoder_inputs['input_ids'],
256
  'decoder_token_type_ids': token_type_ids,
257
  'encoder_outputs': encoder_outputs,
258
+ 'past_key_values': past_key_values,
259
  'use_cache': use_cache,
260
  }
261
  return input_dict