Patrick Haller
committed on
Commit
·
d4d56eb
1
Parent(s):
8ac80ca
Making configurable
Browse files
configuration_hf_alibaba_nlp_gte.py
CHANGED
|
@@ -116,6 +116,8 @@ class GteConfig(PretrainedConfig):
|
|
| 116 |
use_memory_efficient_attention=False,
|
| 117 |
logn_attention_scale=False,
|
| 118 |
logn_attention_clip1=False,
|
|
|
|
|
|
|
| 119 |
**kwargs,
|
| 120 |
):
|
| 121 |
super().__init__(**kwargs)
|
|
@@ -142,4 +144,7 @@ class GteConfig(PretrainedConfig):
|
|
| 142 |
self.unpad_inputs = unpad_inputs
|
| 143 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 144 |
self.logn_attention_scale = logn_attention_scale
|
| 145 |
-
self.logn_attention_clip1 = logn_attention_clip1
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
use_memory_efficient_attention=False,
|
| 117 |
logn_attention_scale=False,
|
| 118 |
logn_attention_clip1=False,
|
| 119 |
+
add_pooling_layer=True,
|
| 120 |
+
num_labels=0,
|
| 121 |
**kwargs,
|
| 122 |
):
|
| 123 |
super().__init__(**kwargs)
|
|
|
|
| 144 |
self.unpad_inputs = unpad_inputs
|
| 145 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
| 146 |
self.logn_attention_scale = logn_attention_scale
|
| 147 |
+
self.logn_attention_clip1 = logn_attention_clip1
|
| 148 |
+
|
| 149 |
+
self.add_pooling_layer = add_pooling_layer
|
| 150 |
+
self.num_labels = num_labels
|
modeling_hf_alibaba_nlp_gte.py
CHANGED
|
@@ -970,8 +970,9 @@ class GteForSequenceClassification(GtePreTrainedModel):
|
|
| 970 |
def __init__(self, config: GteConfig):
|
| 971 |
super().__init__(config)
|
| 972 |
self.config = config
|
| 973 |
-
self.num_labels =
|
| 974 |
-
|
|
|
|
| 975 |
|
| 976 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 977 |
self.loss_function = nn.MSELoss()
|
|
@@ -1010,7 +1011,10 @@ class GteForSequenceClassification(GtePreTrainedModel):
|
|
| 1010 |
output_attentions=output_attentions,
|
| 1011 |
output_hidden_states=output_hidden_states,
|
| 1012 |
)
|
| 1013 |
-
|
|
|
|
|
|
|
|
|
|
| 1014 |
|
| 1015 |
logits = self.score(hidden_states)
|
| 1016 |
|
|
|
|
| 970 |
def __init__(self, config: GteConfig):
|
| 971 |
super().__init__(config)
|
| 972 |
self.config = config
|
| 973 |
+
self.num_labels = config.num_labels
|
| 974 |
+
assert config.num_labels > 0, "num_labels should be greater than 0 for sequence classification"
|
| 975 |
+
self.model = GteModel(config, add_pooling_layer=config.add_pooling_layer)
|
| 976 |
|
| 977 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 978 |
self.loss_function = nn.MSELoss()
|
|
|
|
| 1011 |
output_attentions=output_attentions,
|
| 1012 |
output_hidden_states=output_hidden_states,
|
| 1013 |
)
|
| 1014 |
+
if self.config.add_pooling_layer:
|
| 1015 |
+
hidden_states = transformer_outputs.pooler_output
|
| 1016 |
+
else:
|
| 1017 |
+
hidden_states = transformer_outputs.last_hidden_state[:, 0]
|
| 1018 |
|
| 1019 |
logits = self.score(hidden_states)
|
| 1020 |
|