consistency
Browse files
model.py
CHANGED
|
@@ -23,7 +23,7 @@ class BertClassifier(nn.Module):
|
|
| 23 |
output = self.bert(input_ids, attention_mask=attention_mask)
|
| 24 |
logits = self.classifier(output.pooler_output)
|
| 25 |
loss = None
|
| 26 |
-
if labels
|
| 27 |
loss_fct = nn.CrossEntropyLoss()
|
| 28 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 29 |
return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=output.hidden_states,attentions=output.attentions)
|
|
|
|
| 23 |
output = self.bert(input_ids, attention_mask=attention_mask)
|
| 24 |
logits = self.classifier(output.pooler_output)
|
| 25 |
loss = None
|
| 26 |
+
if labels:
|
| 27 |
loss_fct = nn.CrossEntropyLoss()
|
| 28 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 29 |
return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=output.hidden_states,attentions=output.attentions)
|