Update modelling_multi.py
Browse filesFixed issues with attn_implementation and decoder_inputs['past_key_values'].
- modelling_multi.py +2 -1
modelling_multi.py
CHANGED
|
@@ -123,6 +123,7 @@ class MultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 123 |
encoder = MultiCvtWithProjectionHead(config=config.encoder)
|
| 124 |
|
| 125 |
# Decoder:
|
|
|
|
| 126 |
if decoder is None:
|
| 127 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 128 |
|
|
@@ -254,7 +255,7 @@ class MultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 254 |
'decoder_input_ids': decoder_inputs['input_ids'],
|
| 255 |
'decoder_token_type_ids': token_type_ids,
|
| 256 |
'encoder_outputs': encoder_outputs,
|
| 257 |
-
'past_key_values':
|
| 258 |
'use_cache': use_cache,
|
| 259 |
}
|
| 260 |
return input_dict
|
|
|
|
| 123 |
encoder = MultiCvtWithProjectionHead(config=config.encoder)
|
| 124 |
|
| 125 |
# Decoder:
|
| 126 |
+
config.decoder._attn_implementation = 'eager'
|
| 127 |
if decoder is None:
|
| 128 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 129 |
|
|
|
|
| 255 |
'decoder_input_ids': decoder_inputs['input_ids'],
|
| 256 |
'decoder_token_type_ids': token_type_ids,
|
| 257 |
'encoder_outputs': encoder_outputs,
|
| 258 |
+
'past_key_values': past_key_values,
|
| 259 |
'use_cache': use_cache,
|
| 260 |
}
|
| 261 |
return input_dict
|