updated MLM predictions
Browse files- modeling_gptbert.py +4 -0
modeling_gptbert.py
CHANGED
|
@@ -778,6 +778,10 @@ class GptBertForMaskedLM(GptBertModel):
|
|
| 778 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 779 |
masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
if not return_dict:
|
| 782 |
output = (
|
| 783 |
subword_prediction,
|
|
|
|
| 778 |
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 779 |
masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 780 |
|
| 781 |
+
bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
|
| 782 |
+
bos_logits[:, :, self.config.bos_token_id] = 1.0
|
| 783 |
+
subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
|
| 784 |
+
|
| 785 |
if not return_dict:
|
| 786 |
output = (
|
| 787 |
subword_prediction,
|