davda54 commited on
Commit
790a5cc
·
verified ·
1 Parent(s): 3e445a9

updated MLM predictions

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +4 -0
modeling_gptbert.py CHANGED
@@ -778,6 +778,10 @@ class GptBertForMaskedLM(GptBertModel):
778
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
779
  masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
780
 
 
 
 
 
781
  if not return_dict:
782
  output = (
783
  subword_prediction,
 
778
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
779
  masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
780
 
781
+ bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
782
+ bos_logits[:, :, self.config.bos_token_id] = 1.0
783
+ subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
784
+
785
  if not return_dict:
786
  output = (
787
  subword_prediction,