Fix CXRBertOutput errors
Browse files* The latest transformers complains that `CXRBertOutput` is not a dataclass: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/utils/generic.py#L354-L358
* And that it can't have more than one required field: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/utils/generic.py#L370-L371
- modeling_cxrbert.py +4 -2
modeling_cxrbert.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
| 4 |
# ------------------------------------------------------------------------------------------
|
| 5 |
|
|
|
|
| 6 |
from typing import Any, Optional, Tuple, Union
|
| 7 |
|
| 8 |
import torch
|
|
@@ -16,9 +17,10 @@ from .configuration_cxrbert import CXRBertConfig
|
|
| 16 |
|
| 17 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
| 18 |
|
|
|
|
| 19 |
class CXRBertOutput(ModelOutput):
|
| 20 |
-
last_hidden_state: torch.FloatTensor
|
| 21 |
-
logits: torch.FloatTensor
|
| 22 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
| 23 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 24 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
| 3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
| 4 |
# ------------------------------------------------------------------------------------------
|
| 5 |
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
from typing import Any, Optional, Tuple, Union
|
| 8 |
|
| 9 |
import torch
|
|
|
|
| 17 |
|
| 18 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
| 19 |
|
| 20 |
+
@dataclass
|
| 21 |
class CXRBertOutput(ModelOutput):
|
| 22 |
+
last_hidden_state: torch.FloatTensor = None
|
| 23 |
+
logits: torch.FloatTensor = None
|
| 24 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
| 25 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 26 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|