Fix Flash with batch inputs
modeling_eurobert.py  +1 -1

@@ -526,7 +526,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
         if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
             mask = self.mask_converter.to_4d(attention_mask, attention_mask.shape[1], inputs_embeds.dtype)
         else:
-            mask = None
+            mask = attention_mask
 
         hidden_states = inputs_embeds
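With flash_attention_2 selected, mask was previously left as None, so the attention computation received no information about padding and batched (padded) inputs could produce incorrect hidden states for the shorter sequences. The fix forwards the 2-D attention_mask instead, which is the format the Flash Attention 2 code path consumes; only the eager/SDPA paths need the 4-D additive mask built by to_4d. Below is a minimal sketch of the scenario this fixes, assuming the EuroBERT/EuroBERT-210m checkpoint, a CUDA device, and flash-attn installed; the model id and dtype are illustrative, not part of this commit.

import torch
from transformers import AutoModel, AutoTokenizer

# Load EuroBERT with the Flash Attention 2 code path (assumed checkpoint).
tokenizer = AutoTokenizer.from_pretrained("EuroBERT/EuroBERT-210m")
model = AutoModel.from_pretrained(
    "EuroBERT/EuroBERT-210m",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")
model.eval()

# A batch where the shorter sentence gets padded. Before this fix, the padding
# positions were not masked under flash_attention_2, so the short sentence's
# embeddings could differ from the result of encoding it on its own.
batch = tokenizer(
    ["A short sentence.", "A much longer sentence that forces padding tokens into the batch."],
    padding=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    hidden = model(**batch).last_hidden_state  # padding is now masked correctly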