Update modeling_phi.py
modeling_phi.py CHANGED (+1 -1)
@@ -509,7 +509,7 @@ class PhiFlashAttention2(PhiAttention):
         value_states = value_states.to(target_dtype)
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=
+            query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
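The removed line ended in a dangling `softmax_scale=`, which is a syntax error; the fix passes `softmax_scale=None` explicitly, and flash-attn treats a `None` scale as the standard `1/sqrt(head_dim)` softmax scaling. A minimal sketch of that fallback behavior, for illustration only (the function below is hypothetical and not the helper defined in modeling_phi.py):

import math

import torch


def scaled_dot_product_attention(query, key, value, softmax_scale=None):
    # Hypothetical reference implementation, not the actual
    # _flash_attention_forward helper. It mirrors the convention the fix
    # relies on: a None scale falls back to 1/sqrt(head_dim), matching
    # eager attention.
    head_dim = query.size(-1)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) * softmax_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)

Passing `None` (rather than leaving the keyword empty) keeps the call valid Python while leaving the default scaling choice to the attention kernel.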