Morton-Li committed
Commit 21b44e5 · 1 Parent(s): dc2223b

Update the model version and fix issues.

Files changed (5):
  1. QiDeBERTa.py +570 -99
  2. README.md +1 -1
  3. config.json +7 -6
  4. Configuration.py → configuration.py +17 -9
  5. tokenizer.py +52 -76
QiDeBERTa.py CHANGED
@@ -1,68 +1,473 @@
 
1
  from typing import Optional, Tuple
2
 
3
  import torch
4
- from torch.nn import Module, Linear, Parameter, LayerNorm, Embedding, Dropout, ModuleList, MultiheadAttention
5
  from torch.nn.functional import gelu
6
  from transformers import DebertaV2PreTrainedModel
7
- from transformers.modeling_outputs import BaseModelOutput, TokenClassifierOutput, MaskedLMOutput, SequenceClassifierOutput
8
- from transformers.models.deberta_v2.modeling_deberta_v2 import build_relative_position, DebertaV2Layer
 
9
 
10
- from .Configuration import QiDeBERTaConfig
11
 
12
 
13
  class QiDeBERTaEmbeddings(Module):
14
  """Construct the embeddings from word, position and token_type embeddings."""
15
-
16
  def __init__(
17
  self,
18
  pad_token_id: int,
19
- hidden_size: int,
20
  vocab_size: int,
21
  ):
22
  super().__init__()
23
- self.word_embeddings = Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
24
 
25
- self.LayerNorm = LayerNorm(normalized_shape=hidden_size, eps=1e-7)
26
- self.dropout = Dropout(p=0.1)
 
 
 
27
 
28
- def forward(self, input_ids=None, mask=None, inputs_embeds=None):
29
- if inputs_embeds is None:
30
- inputs_embeds = self.word_embeddings(input_ids)
31
 
32
- embeddings = inputs_embeds
 
 
 
33
 
34
- embeddings = self.LayerNorm(embeddings)
 
 
 
35
 
36
- if mask is not None:
37
- if mask.dim() != embeddings.dim():
38
- if mask.dim() == 4:
39
- mask = mask.squeeze(1).squeeze(1)
40
- mask = mask.unsqueeze(2)
41
- mask = mask.to(embeddings.dtype)
42
 
43
- embeddings = embeddings * mask
 
 
44
 
45
- embeddings = self.dropout(embeddings)
46
- return embeddings
 
47
 
48
 
49
- class QiDeBERTaEncoder(Module):
50
- """Modified BertEncoder with relative position bias support"""
51
 
52
- def __init__(self, config):
53
  super().__init__()
54
 
55
- self.layer = ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
56
 
57
- self.max_relative_positions = config.max_position_embeddings
58
 
59
- self.position_buckets = config.position_buckets
60
 
61
- pos_ebd_size = self.position_buckets * 2
62
 
63
- self.rel_embeddings = Embedding(pos_ebd_size, config.hidden_size)
64
 
65
- self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
66
 
67
  self.gradient_checkpointing = False
68
 
@@ -72,12 +477,10 @@ class QiDeBERTaEncoder(Module):
72
  rel_embeddings = self.LayerNorm(rel_embeddings)
73
  return rel_embeddings
74
 
75
- def get_attention_mask(self, attention_mask):
76
- if attention_mask.dim() <= 2:
77
- extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
78
- attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
79
- elif attention_mask.dim() == 3:
80
- attention_mask = attention_mask.unsqueeze(1)
81
 
82
  return attention_mask
83
 
@@ -94,6 +497,7 @@ class QiDeBERTaEncoder(Module):
94
  self,
95
  hidden_states,
96
  attention_mask,
 
97
  ):
98
  attention_mask = self.get_attention_mask(attention_mask)
99
  relative_pos = self.get_rel_pos(hidden_states)
@@ -109,34 +513,35 @@ class QiDeBERTaEncoder(Module):
109
  layer_module.__call__,
110
  next_kv,
111
  attention_mask,
112
- None,
113
  relative_pos,
114
  rel_embeddings,
115
- True,
116
  )
117
  else:
118
  output_states, attn_weights = layer_module(
119
- next_kv,
120
- attention_mask,
121
- query_states=None,
122
  relative_pos=relative_pos,
123
  rel_embeddings=rel_embeddings,
124
- output_attentions=True,
125
  )
126
 
127
- all_attentions = all_attentions + (attn_weights,)
 
128
 
129
  all_hidden_states = all_hidden_states + (output_states,)
130
 
131
  next_kv = output_states
132
 
133
- return BaseModelOutput(
134
- last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
 
 
135
  )
136
 
137
 
138
  class QiDeBERTaBase(DebertaV2PreTrainedModel):
139
- VERSION = '1.0.1'
140
  config_class = QiDeBERTaConfig
141
  base_model_prefix = 'qideberta'
142
  _encoder_layer_path = ''
@@ -171,7 +576,7 @@ class QiDeBERTaBase(DebertaV2PreTrainedModel):
171
  else:
172
  encoder_layer.requires_grad_(requires_grad=True)
173
 
174
- def freeze_embed_layer(self, freeze: bool = True):
175
  """
176
  Freeze the embedding layer
177
  :param freeze:
@@ -198,10 +603,26 @@ class QiDeBERTa(QiDeBERTaBase):
198
 
199
  self.embeddings = QiDeBERTaEmbeddings(
200
  pad_token_id=config.pad_token_id,
201
- hidden_size=config.hidden_size,
202
  vocab_size=config.vocab_size,
203
  )
204
- self.encoder = QiDeBERTaEncoder(config)
205
  # Initialize weights and apply final processing
206
  self.post_init()
207
 
@@ -213,84 +634,59 @@ class QiDeBERTa(QiDeBERTaBase):
213
 
214
  def forward(
215
  self,
216
- input_ids: Optional[torch.Tensor] = None,
217
  attention_mask: Optional[torch.Tensor] = None,
218
- inputs_embeds: Optional[torch.Tensor] = None,
219
- deep_recurrent_refinement_steps: int = 0,
220
  ) -> BaseModelOutput:
221
  """
222
  Forward pass of the model
223
 
224
- :param input_ids:
225
  :param attention_mask: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
226
  - 1 for tokens that are **not masked**,
227
  - 0 for tokens that are **masked**.
228
- :param inputs_embeds:
229
- :param deep_recurrent_refinement_steps:
230
  :return:
231
  """
232
- if input_ids is not None and inputs_embeds is not None:
233
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
234
- if input_ids is not None:
235
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
236
- input_shape = input_ids.size()
237
- elif inputs_embeds is not None:
238
- input_shape = inputs_embeds.size()[:-1]
239
- else:
240
- raise ValueError("You have to specify either input_ids or inputs_embeds")
241
-
242
- device = input_ids.device if input_ids is not None else inputs_embeds.device
243
 
244
  if attention_mask is None:
245
  attention_mask = torch.ones(input_shape, device=device)
246
 
247
- embedding_output = self.embeddings(
248
  input_ids=input_ids,
249
  mask=attention_mask,
250
- inputs_embeds=inputs_embeds,
251
  )
252
 
253
  encoder_outputs = self.encoder(
254
  hidden_states=embedding_output,
255
  attention_mask=attention_mask,
 
256
  )
257
- encoded_layers = encoder_outputs.hidden_states
258
-
259
- if deep_recurrent_refinement_steps > 1:
260
- hidden_states = encoded_layers[-2]
261
- layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
262
- query_states = encoded_layers[-1]
263
- rel_embeddings = self.encoder.get_rel_embedding()
264
- attention_mask = self.encoder.get_attention_mask(attention_mask)
265
- rel_pos = self.encoder.get_rel_pos(embedding_output)
266
- for layer in layers[1:]:
267
- query_states = layer(
268
- hidden_states,
269
- attention_mask,
270
- output_attentions=False,
271
- query_states=query_states,
272
- relative_pos=rel_pos,
273
- rel_embeddings=rel_embeddings,
274
- )
275
- encoded_layers.append(query_states)
276
 
277
  return BaseModelOutput(
278
  last_hidden_state=encoder_outputs.last_hidden_state,
279
  hidden_states=encoder_outputs.hidden_states,
280
- attentions=encoder_outputs.attentions,
 
 
281
  )
282
 
283
 
284
  class QiDeBERTaMLMHead(Module):
285
  def __init__(
286
  self,
287
- hidden_size: int,
288
- vocab_size: int
 
289
  ):
290
  super().__init__()
291
- self.dense = Linear(in_features=hidden_size, out_features=hidden_size)
292
 
293
- self.LayerNorm = LayerNorm(normalized_shape=hidden_size, eps=1e-7, elementwise_affine=True)
294
 
295
  self.bias = Parameter(torch.zeros(vocab_size))
296
 
@@ -302,26 +698,101 @@ class QiDeBERTaMLMHead(Module):
302
  if module.bias is not None:
303
  module.bias.data.zero_()
304
 
305
  def forward(self, hidden_states: torch.Tensor, word_embeddings: Embedding):
306
  hidden_states = self.dense(hidden_states)
307
- hidden_states = gelu(hidden_states)
308
  hidden_states = self.LayerNorm(hidden_states)
309
  hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
310
  return hidden_states
311
 
312
 
313
  class QiDeBERTaForMaskedLM(QiDeBERTaBase):
314
  _tied_weights_keys = ["mlm_head.weight", "qideberta.embeddings.word_embeddings.weight"]
315
  _encoder_layer_path = 'qideberta.encoder'
316
  _embedding_layer_path = 'qideberta.embeddings'
317
- task_head = 'mlm_head'
318
 
319
  def __init__(self, config: QiDeBERTaConfig):
320
  super().__init__(config)
321
  self.qideberta = QiDeBERTa(config=config)
322
  self.mlm_head = QiDeBERTaMLMHead(
323
- hidden_size=config.hidden_size,
324
- vocab_size=config.vocab_size
 
325
  )
326
 
327
  self.post_init()
@@ -346,7 +817,7 @@ class QiDeBERTaForMaskedLM(QiDeBERTaBase):
346
  deep_recurrent_refinement_steps=0,
347
  )
348
 
349
- prediction_scores = self.mlm_head(hidden_states=outputs.last_hidden_state, word_embeddings=self.qideberta.embeddings.word_embeddings)
350
 
351
  return MaskedLMOutput(
352
  logits=prediction_scores,
 
1
+ from dataclasses import dataclass
2
  from typing import Optional, Tuple
3
 
4
  import torch
5
+ from torch.nn import Module, Linear, Parameter, LayerNorm, Embedding, Dropout, ModuleList, MultiheadAttention, functional
6
  from torch.nn.functional import gelu
7
  from transformers import DebertaV2PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutput as EncoderOutput, MaskedLMOutput
9
+ from transformers.models.deberta_v2.modeling_deberta_v2 import build_relative_position, scaled_size_sqrt, build_rpos
10
+ from transformers.utils import ModelOutput
11
 
12
+ from .configuration import QiDeBERTaConfig
13
+
14
+
15
+ @dataclass
16
+ class BaseModelOutput(ModelOutput):
17
+ last_hidden_state: Optional[torch.FloatTensor] = None
18
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
19
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
20
+ embedding_output: Optional[torch.FloatTensor] = None
21
+ token_embeddings: Optional[torch.FloatTensor] = None
22
 
23
 
24
  class QiDeBERTaEmbeddings(Module):
25
  """Construct the embeddings from word, position and token_type embeddings."""
 
26
  def __init__(
27
  self,
28
  pad_token_id: int,
29
+ d_model: int,
30
  vocab_size: int,
31
+ layer_norm_eps: float,
32
+ hidden_dropout_prob: float,
33
+ ):
34
+ super().__init__()
35
+ self.word_embeddings = Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=pad_token_id)
36
+ self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
37
+ self.dropout = Dropout(p=hidden_dropout_prob)
38
+
39
+ def forward(self, input_ids: torch.Tensor, mask: torch.Tensor):
40
+ inputs_embeds = self.word_embeddings(input_ids)
41
+ embeddings = self.LayerNorm(inputs_embeds)
42
+ if mask.dim() != embeddings.dim():
43
+ if mask.dim() == 4:
44
+ mask = mask.squeeze(1).squeeze(1)
45
+ mask = mask.unsqueeze(2)
46
+ mask = mask.to(embeddings.dtype)
47
+ return self.dropout(embeddings * mask), inputs_embeds
48
+
49
+
50
+ class QiDeBERTaDisentangledSelfAttention(Module):
51
+ """
52
+ Disentangled self-attention module
53
+ """
54
+ def __init__(
55
+ self,
56
+ num_heads: int,
57
+ d_model: int,
58
+ share_att_key: bool,
59
+ relative_attention: bool,
60
+ max_position_embeddings: int,
61
+ hidden_dropout_prob: float,
62
+ attention_probs_dropout_prob: float,
63
+ pos_att_type: Optional[list] = None,
64
+ position_buckets: int = -1,
65
+ max_relative_positions: int = -1,
66
  ):
67
  super().__init__()
68
+ self.num_attention_heads = num_heads
69
+ self.attention_head_size = d_model // num_heads
70
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
71
+ self.query_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
72
+ self.key_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
73
+ self.value_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
74
+
75
+ self.share_att_key = share_att_key
76
+ self.pos_att_type = pos_att_type if pos_att_type is not None else []
77
+ self.relative_attention = relative_attention
78
+
79
+ if self.relative_attention:
80
+ self.position_buckets = position_buckets
81
+ self.max_relative_positions = max_relative_positions
82
+ if self.max_relative_positions < 1:
83
+ self.max_relative_positions = max_position_embeddings
84
+ self.pos_ebd_size = self.max_relative_positions
85
+ if self.position_buckets > 0:
86
+ self.pos_ebd_size = self.position_buckets
87
+
88
+ self.pos_dropout = Dropout(p=hidden_dropout_prob)
89
+
90
+ if not self.share_att_key:
91
+ if "c2p" in self.pos_att_type:
92
+ self.pos_key_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
93
+ if "p2c" in self.pos_att_type:
94
+ self.pos_query_proj = Linear(in_features=d_model, out_features=self.all_head_size)
95
+
96
+ self.dropout = Dropout(p=attention_probs_dropout_prob)
97
 
98
+ @staticmethod
99
+ def transpose_for_scores(x, attention_heads) -> torch.Tensor:
100
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
101
+ x = x.view(new_x_shape)
102
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
103
 
104
+ def forward(
105
+ self,
106
+ hidden_states,
107
+ attention_mask,
108
+ output_attentions=False,
109
+ relative_pos=None,
110
+ rel_embeddings=None,
111
+ ):
112
+ """
113
+ Call the module
114
 
115
+ Args:
116
+ hidden_states (`torch.FloatTensor`):
117
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
118
+ *Attention(Q,K,V)*
119
 
120
+ attention_mask (`torch.BoolTensor`):
121
+ An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
122
+ sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
123
+ th token.
124
 
125
+ output_attentions (`bool`, *optional*):
126
+ Whether return the attention matrix.
 
 
 
 
127
 
128
+ relative_pos (`torch.LongTensor`):
129
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
130
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
131
 
132
+ rel_embeddings (`torch.FloatTensor`):
133
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
134
+ \\text{max_relative_positions}\\), *hidden_size*].
135
 
136
 
137
+ """
138
+ query_layer = self.transpose_for_scores(self.query_proj(hidden_states), self.num_attention_heads)
139
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
140
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
141
+
142
+ rel_att = None
143
+ # Take the dot product between "query" and "key" to get the raw attention scores.
144
+ scale_factor = 1
145
+ if "c2p" in self.pos_att_type:
146
+ scale_factor += 1
147
+ if "p2c" in self.pos_att_type:
148
+ scale_factor += 1
149
+ scale = scaled_size_sqrt(query_layer, scale_factor)
150
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
151
+ if self.relative_attention:
152
+ rel_embeddings = self.pos_dropout(rel_embeddings)
153
+ rel_att = self.disentangled_attention_bias(
154
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
155
+ )
156
+
157
+ if rel_att is not None:
158
+ attention_scores = attention_scores + rel_att
159
+ attention_scores = attention_scores
160
+ attention_scores = attention_scores.view(
161
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
162
+ )
163
 
164
+ attention_mask = attention_mask.bool()
165
+ attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
166
+ # bsz x height x length x dimension
167
+ attention_probs = functional.softmax(attention_scores, dim=-1)
168
+
169
+ attention_probs = self.dropout(attention_probs)
170
+ context_layer = torch.bmm(
171
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
172
+ )
173
+ context_layer = (
174
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
175
+ .permute(0, 2, 1, 3)
176
+ .contiguous()
177
+ )
178
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
179
+ context_layer = context_layer.view(new_context_layer_shape)
180
+
181
+ return (context_layer, attention_probs) if output_attentions else (context_layer, None)
182
+
183
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
184
+ if relative_pos is None:
185
+ relative_pos = build_relative_position(
186
+ query_layer,
187
+ key_layer,
188
+ bucket_size=self.position_buckets,
189
+ max_position=self.max_relative_positions,
190
+ )
191
+ if relative_pos.dim() == 2:
192
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
193
+ elif relative_pos.dim() == 3:
194
+ relative_pos = relative_pos.unsqueeze(1)
195
+ # bsz x height x query x key
196
+ elif relative_pos.dim() != 4:
197
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
198
+
199
+ att_span = self.pos_ebd_size
200
+ relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
201
+
202
+ rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
203
+ if self.share_att_key:
204
+ pos_query_layer = self.transpose_for_scores(
205
+ self.query_proj(rel_embeddings), self.num_attention_heads
206
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
207
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
208
+ query_layer.size(0) // self.num_attention_heads, 1, 1
209
+ )
210
+ else:
211
+ if "c2p" in self.pos_att_type:
212
+ pos_key_layer = self.transpose_for_scores(
213
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
214
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
215
+ if "p2c" in self.pos_att_type:
216
+ pos_query_layer = self.transpose_for_scores(
217
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
218
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
219
+
220
+ score = 0
221
+ # content->position
222
+ if "c2p" in self.pos_att_type:
223
+ scale = scaled_size_sqrt(pos_key_layer, scale_factor)
224
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
225
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
226
+ c2p_att = torch.gather(
227
+ c2p_att,
228
+ dim=-1,
229
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
230
+ )
231
+ score += c2p_att / scale.to(dtype=c2p_att.dtype)
232
+
233
+ # position->content
234
+ if "p2c" in self.pos_att_type:
235
+ scale = scaled_size_sqrt(pos_query_layer, scale_factor)
236
+ r_pos = build_rpos(
237
+ query_layer,
238
+ key_layer,
239
+ relative_pos,
240
+ self.max_relative_positions,
241
+ self.position_buckets,
242
+ )
243
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
244
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
245
+ p2c_att = torch.gather(
246
+ p2c_att,
247
+ dim=-1,
248
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
249
+ ).transpose(-1, -2)
250
+ score += p2c_att / scale.to(dtype=p2c_att.dtype)
251
+
252
+ return score
253
+
254
+
255
+ class QiDeBERTaSelfOutput(Module):
256
+ def __init__(
257
+ self,
258
+ d_model: int,
259
+ layer_norm_eps: float,
260
+ hidden_dropout_prob: float,
261
+ ):
262
  super().__init__()
263
+ self.dense = Linear(in_features=d_model, out_features=d_model)
264
+ self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
265
+ self.dropout = Dropout(p=hidden_dropout_prob)
266
+
267
+ def forward(self, hidden_states, input_tensor):
268
+ hidden_states = self.dense(hidden_states)
269
+ hidden_states = self.dropout(hidden_states)
270
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
271
+ return hidden_states
272
+
273
+
274
+ class QiDeBERTaAttention(Module):
275
+ def __init__(
276
+ self,
277
+ num_heads: int,
278
+ d_model: int,
279
+ share_att_key: bool,
280
+ relative_attention: bool,
281
+ max_position_embeddings: int,
282
+ hidden_dropout_prob: float,
283
+ attention_probs_dropout_prob: float,
284
+ layer_norm_eps: float,
285
+ pos_att_type: Optional[list] = None,
286
+ position_buckets: int = -1,
287
+ max_relative_positions: int = -1,
288
+ ):
289
+ super().__init__()
290
+ self.self = QiDeBERTaDisentangledSelfAttention(
291
+ num_heads=num_heads,
292
+ d_model=d_model,
293
+ share_att_key=share_att_key,
294
+ relative_attention=relative_attention,
295
+ max_position_embeddings=max_position_embeddings,
296
+ hidden_dropout_prob=hidden_dropout_prob,
297
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
298
+ pos_att_type=pos_att_type,
299
+ position_buckets=position_buckets,
300
+ max_relative_positions=max_relative_positions,
301
+ )
302
+ self.output = QiDeBERTaSelfOutput(
303
+ d_model=d_model,
304
+ layer_norm_eps=layer_norm_eps,
305
+ hidden_dropout_prob=hidden_dropout_prob,
306
+ )
307
+
308
+ def forward(
309
+ self,
310
+ hidden_states,
311
+ attention_mask,
312
+ output_attentions: bool = False,
313
+ relative_pos=None,
314
+ rel_embeddings=None,
315
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
316
+ self_output, att_matrix = self.self(
317
+ hidden_states=hidden_states,
318
+ attention_mask=attention_mask,
319
+ output_attentions=output_attentions,
320
+ relative_pos=relative_pos,
321
+ rel_embeddings=rel_embeddings,
322
+ )
323
+ attention_output = self.output(hidden_states=self_output, input_tensor=hidden_states)
324
+
325
+ return (attention_output, att_matrix) if output_attentions else (attention_output, None)
326
+
327
+
328
+ class QiDeBERTaIntermediate(Module):
329
+ def __init__(
330
+ self,
331
+ d_model: int,
332
+ d_ff: int,
333
+ ):
334
+ super().__init__()
335
+ self.dense = Linear(in_features=d_model, out_features=d_ff)
336
+
337
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
338
+ hidden_states = self.dense(hidden_states)
339
+ hidden_states = functional.gelu(hidden_states)
340
+ return hidden_states
341
+
342
+
343
+ class QiDeBERTaOutput(Module):
344
+ def __init__(
345
+ self,
346
+ d_ff: int,
347
+ d_model: int,
348
+ layer_norm_eps: float,
349
+ hidden_dropout_prob: float,
350
+ ):
351
+ super().__init__()
352
+ self.dense = Linear(in_features=d_ff, out_features=d_model)
353
+ self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
354
+ self.dropout = Dropout(p=hidden_dropout_prob)
355
+
356
+ def forward(self, hidden_states, input_tensor):
357
+ hidden_states = self.dense(hidden_states)
358
+ hidden_states = self.dropout(hidden_states)
359
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
360
+ return hidden_states
361
 
 
362
 
363
+ class QiDeBERTaLayer(Module):
364
+ def __init__(
365
+ self,
366
+ num_heads: int,
367
+ d_model: int,
368
+ d_ff: int,
369
+ share_att_key: bool,
370
+ relative_attention: bool,
371
+ max_position_embeddings: int,
372
+ hidden_dropout_prob: float,
373
+ attention_probs_dropout_prob: float,
374
+ layer_norm_eps: float,
375
+ pos_att_type: Optional[list] = None,
376
+ position_buckets: int = -1,
377
+ max_relative_positions: int = -1,
378
+ ):
379
+ super().__init__()
380
+ self.attention = QiDeBERTaAttention(
381
+ num_heads=num_heads,
382
+ d_model=d_model,
383
+ share_att_key=share_att_key,
384
+ relative_attention=relative_attention,
385
+ max_position_embeddings=max_position_embeddings,
386
+ hidden_dropout_prob=hidden_dropout_prob,
387
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
388
+ layer_norm_eps=layer_norm_eps,
389
+ pos_att_type=pos_att_type,
390
+ position_buckets=position_buckets,
391
+ max_relative_positions=max_relative_positions,
392
+ )
393
+ self.intermediate = QiDeBERTaIntermediate(
394
+ d_model=d_model,
395
+ d_ff=d_ff
396
+ )
397
+ self.output = QiDeBERTaOutput(
398
+ d_ff=d_ff,
399
+ d_model=d_model,
400
+ layer_norm_eps=layer_norm_eps,
401
+ hidden_dropout_prob=hidden_dropout_prob,
402
+ )
403
 
404
+ def forward(
405
+ self,
406
+ hidden_states,
407
+ attention_mask,
408
+ relative_pos=None,
409
+ rel_embeddings=None,
410
+ output_attentions: bool = False,
411
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
412
+ attention_output, att_matrix = self.attention(
413
+ hidden_states=hidden_states,
414
+ attention_mask=attention_mask,
415
+ output_attentions=output_attentions,
416
+ relative_pos=relative_pos,
417
+ rel_embeddings=rel_embeddings,
418
+ )
419
+ intermediate_output = self.intermediate(attention_output)
420
+ layer_output = self.output(intermediate_output, attention_output)
421
 
422
+ return (layer_output, att_matrix) if output_attentions else (layer_output, None)
423
 
 
424
 
425
+ class QiDeBERTaEncoder(Module):
426
+ """Modified BertEncoder with relative position bias support"""
427
+ def __init__(
428
+ self,
429
+ num_layers: int,
430
+ num_heads: int,
431
+ d_model: int,
432
+ d_ff: int,
433
+ share_att_key: bool,
434
+ relative_attention: bool,
435
+ max_position_embeddings: int,
436
+ hidden_dropout_prob: float,
437
+ attention_probs_dropout_prob: float,
438
+ layer_norm_eps: float,
439
+ pos_att_type: Optional[list] = None,
440
+ position_buckets: int = -1,
441
+ max_relative_positions: int = -1,
442
+ ):
443
+ super().__init__()
444
+
445
+ self.layer = ModuleList([
446
+ QiDeBERTaLayer(
447
+ num_heads=num_heads,
448
+ d_model=d_model,
449
+ d_ff=d_ff,
450
+ share_att_key=share_att_key,
451
+ relative_attention=relative_attention,
452
+ max_position_embeddings=max_position_embeddings,
453
+ hidden_dropout_prob=hidden_dropout_prob,
454
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
455
+ layer_norm_eps=layer_norm_eps,
456
+ pos_att_type=pos_att_type,
457
+ position_buckets=position_buckets,
458
+ max_relative_positions=max_relative_positions,
459
+ )
460
+ for _ in range(num_layers)
461
+ ])
462
+
463
+ self.max_relative_positions = max_position_embeddings
464
+
465
+ self.position_buckets = position_buckets
466
+
467
+ pos_ebd_size = position_buckets * 2
468
+ self.rel_embeddings = Embedding(num_embeddings=pos_ebd_size, embedding_dim=d_model)
469
+
470
+ self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps, elementwise_affine=True)
471
 
472
  self.gradient_checkpointing = False
473
 
 
477
  rel_embeddings = self.LayerNorm(rel_embeddings)
478
  return rel_embeddings
479
 
480
+ @staticmethod
481
+ def get_attention_mask(attention_mask):
482
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
483
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
 
 
484
 
485
  return attention_mask
486
 
 
497
  self,
498
  hidden_states,
499
  attention_mask,
500
+ output_attentions: bool = True,
501
  ):
502
  attention_mask = self.get_attention_mask(attention_mask)
503
  relative_pos = self.get_rel_pos(hidden_states)
 
513
  layer_module.__call__,
514
  next_kv,
515
  attention_mask,
 
516
  relative_pos,
517
  rel_embeddings,
518
+ output_attentions,
519
  )
520
  else:
521
  output_states, attn_weights = layer_module(
522
+ hidden_states=next_kv,
523
+ attention_mask=attention_mask,
 
524
  relative_pos=relative_pos,
525
  rel_embeddings=rel_embeddings,
526
+ output_attentions=output_attentions,
527
  )
528
 
529
+ if output_attentions:
530
+ all_attentions = all_attentions + (attn_weights,)
531
 
532
  all_hidden_states = all_hidden_states + (output_states,)
533
 
534
  next_kv = output_states
535
 
536
+ return EncoderOutput(
537
+ last_hidden_state=output_states,
538
+ hidden_states=all_hidden_states,
539
+ attentions=all_attentions if output_attentions else None
540
  )
541
 
542
 
543
  class QiDeBERTaBase(DebertaV2PreTrainedModel):
544
+ VERSION = '1.1.0'
545
  config_class = QiDeBERTaConfig
546
  base_model_prefix = 'qideberta'
547
  _encoder_layer_path = ''
 
576
  else:
577
  encoder_layer.requires_grad_(requires_grad=True)
578
 
579
+ def freeze_encoder_embed_layer(self, freeze: bool = True):
580
  """
581
  Freeze the embedding layer
582
  :param freeze:
 
603
 
604
  self.embeddings = QiDeBERTaEmbeddings(
605
  pad_token_id=config.pad_token_id,
606
+ d_model=config.d_model,
607
  vocab_size=config.vocab_size,
608
+ layer_norm_eps=config.layer_norm_eps,
609
+ hidden_dropout_prob=config.hidden_dropout_prob,
610
+ )
611
+ self.encoder = QiDeBERTaEncoder(
612
+ num_layers=config.num_layers,
613
+ num_heads=config.num_heads,
614
+ max_position_embeddings=config.max_position_embeddings,
615
+ position_buckets=config.position_buckets,
616
+ d_model=config.d_model,
617
+ d_ff=config.d_ff,
618
+ layer_norm_eps=config.layer_norm_eps,
619
+ share_att_key=config.share_att_key,
620
+ relative_attention=config.relative_attention,
621
+ hidden_dropout_prob=config.hidden_dropout_prob,
622
+ attention_probs_dropout_prob=config.attention_probs_dropout_prob,
623
+ pos_att_type=config.pos_att_type,
624
+ max_relative_positions=config.max_relative_positions,
625
  )
 
626
  # Initialize weights and apply final processing
627
  self.post_init()
628
 
 
634
 
635
  def forward(
636
  self,
637
+ input_ids: torch.Tensor,
638
  attention_mask: Optional[torch.Tensor] = None,
639
+ output_attentions: bool = True,
 
640
  ) -> BaseModelOutput:
641
  """
642
  Forward pass of the model
643
 
644
+ :param input_ids: Token indices of input sequence tokens in the vocabulary. (batch_size, sequence_length)
645
  :param attention_mask: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
646
  - 1 for tokens that are **not masked**,
647
  - 0 for tokens that are **masked**.
648
+ (batch_size, sequence_length)
649
+ :param output_attentions:
650
  :return:
651
  """
652
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
653
+ input_shape = input_ids.size()
654
+ device = input_ids.device
655
 
656
  if attention_mask is None:
657
  attention_mask = torch.ones(input_shape, device=device)
658
 
659
+ embedding_output, token_embeddings = self.embeddings(
660
  input_ids=input_ids,
661
  mask=attention_mask,
 
662
  )
663
 
664
  encoder_outputs = self.encoder(
665
  hidden_states=embedding_output,
666
  attention_mask=attention_mask,
667
+ output_attentions=output_attentions,
668
  )
669
 
670
  return BaseModelOutput(
671
  last_hidden_state=encoder_outputs.last_hidden_state,
672
  hidden_states=encoder_outputs.hidden_states,
673
+ attentions=encoder_outputs.attentions if output_attentions else None,
674
+ embedding_output=embedding_output,
675
+ token_embeddings=token_embeddings, # [B, L, H]
676
  )
677
 
678
 
679
  class QiDeBERTaMLMHead(Module):
680
  def __init__(
681
  self,
682
+ d_model: int,
683
+ vocab_size: int,
684
+ layer_norm_eps: float,
685
  ):
686
  super().__init__()
687
+ self.dense = Linear(in_features=d_model, out_features=d_model)
688
 
689
+ self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps, elementwise_affine=True)
690
 
691
  self.bias = Parameter(torch.zeros(vocab_size))
692
 
 
698
  if module.bias is not None:
699
  module.bias.data.zero_()
700
 
701
+ def _initialize_weights(self, module):
702
+ if getattr(module, "_is_hf_initialized", False):
703
+ return
704
+ self._init_weights(module)
705
+ module._is_hf_initialized = True
706
+
707
  def forward(self, hidden_states: torch.Tensor, word_embeddings: Embedding):
708
  hidden_states = self.dense(hidden_states)
709
+ hidden_states = functional.gelu(hidden_states)
710
  hidden_states = self.LayerNorm(hidden_states)
711
  hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
712
  return hidden_states
713
 
714
 
715
+ class QiDeBERTaClassificationHead(Module):
716
+ def __init__(
717
+ self,
718
+ d_model: int,
719
+ num_labels: int,
720
+ hidden_dropout_prob: float,
721
+ ):
722
+ super().__init__()
723
+
724
+ self.dropout = Dropout(p=hidden_dropout_prob)
725
+ self.classifier = Linear(in_features=d_model, out_features=num_labels)
726
+
727
+ @staticmethod
728
+ def _init_weights(module):
729
+ """Initialize the weights."""
730
+ if isinstance(module, Linear):
731
+ module.weight.data.normal_(mean=0.0, std=0.02)
732
+ if module.bias is not None:
733
+ module.bias.data.zero_()
734
+
735
+ def _initialize_weights(self, module):
736
+ if getattr(module, "_is_hf_initialized", False):
737
+ return
738
+ self._init_weights(module)
739
+ module._is_hf_initialized = True
740
+
741
+ def forward(self, hidden_states: torch.Tensor):
742
+ dropped = self.dropout(hidden_states)
743
+ logits = self.classifier(dropped)
744
+
745
+ return logits
746
+
747
+
748
+ class ContextPooler(Module):
749
+ def __init__(
750
+ self,
751
+ pooler_hidden_size: int = 1024,
752
+ pooler_mode: str = 'token', # mean, max, attn, token
753
+ ):
754
+ super().__init__()
755
+ if pooler_mode not in ['mean', 'max', 'attn', 'token']:
756
+ raise ValueError(f'Invalid pooler mode: {pooler_mode}')
757
+
758
+ self.dense = Linear(in_features=pooler_hidden_size, out_features=pooler_hidden_size)
759
+ self.pooler_mode = pooler_mode
760
+ if self.pooler_mode == 'attn':
761
+ self.attn = MultiheadAttention(embed_dim=pooler_hidden_size, num_heads=pooler_hidden_size, batch_first=True)
762
+ self.LayerNorm = LayerNorm(normalized_shape=pooler_hidden_size, eps=1e-7)
763
+
764
+ def forward(self, hidden_states):
765
+ if self.pooler_mode == 'attn':
766
+ query = hidden_states[:, 0:1, :]
767
+ attn_output, _ = self.attn(query, hidden_states, hidden_states) # [batch_size, 1, hidden_size]
768
+ attn_output = attn_output.squeeze(1) # [batch_size, hidden_size]
769
+ context_token = attn_output + hidden_states[:, 0] # residual connection
770
+ context_token = self.LayerNorm(context_token) # normalize only for the 'attn' mode
771
+ elif self.pooler_mode == 'mean':
772
+ context_token = hidden_states.mean(dim=1) # mean representation over all tokens
773
+ elif self.pooler_mode == 'max':
774
+ context_token = hidden_states.max(dim=1).values # element-wise max over all tokens
775
+ elif self.pooler_mode == 'token':
776
+ context_token = hidden_states[:, 0] # take the first token's representation
777
+
778
+ pooled_output = self.dense(context_token)
779
+ pooled_output = gelu(pooled_output)
780
+ return pooled_output
781
+
782
+
783
  class QiDeBERTaForMaskedLM(QiDeBERTaBase):
784
  _tied_weights_keys = ["mlm_head.weight", "qideberta.embeddings.word_embeddings.weight"]
785
  _encoder_layer_path = 'qideberta.encoder'
786
  _embedding_layer_path = 'qideberta.embeddings'
787
+ task_head = ['mlm_head']
788
 
789
  def __init__(self, config: QiDeBERTaConfig):
790
  super().__init__(config)
791
  self.qideberta = QiDeBERTa(config=config)
792
  self.mlm_head = QiDeBERTaMLMHead(
793
+ d_model=config.d_model,
794
+ vocab_size=config.vocab_size,
795
+ layer_norm_eps=config.layer_norm_eps,
796
  )
797
 
798
  self.post_init()
 
817
  deep_recurrent_refinement_steps=0,
818
  )
819
 
820
+ prediction_scores = self.mlm_head(hidden_states=outputs.last_hidden_state, word_embeddings=self.get_output_embeddings())
821
 
822
  return MaskedLMOutput(
823
  logits=prediction_scores,
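To see the refactored pieces above working together, here is a minimal smoke test of the new `QiDeBERTa` forward pass. It is a hedged sketch, not part of the commit: the tiny config values are hypothetical, the relative imports are flattened to plain module imports, and it assumes the defaults of the `QiDeBERTaConfig` fields this diff does not show (e.g. `position_buckets`, `pos_att_type`) as well as the transformers version pinned in config.json.

```python
# Sketch: exercise QiDeBERTa end to end with a throwaway config (values are hypothetical).
# Assumes QiDeBERTa.py and configuration.py from this repo are importable as plain modules.
import torch

from configuration import QiDeBERTaConfig
from QiDeBERTa import QiDeBERTa

config = QiDeBERTaConfig(d_model=128, num_layers=2, num_heads=4, d_ff=256, vocab_size=25500)
model = QiDeBERTa(config)
model.eval()

input_ids = torch.randint(low=5, high=config.vocab_size, size=(2, 16))  # [batch, seq_len]
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)

print(outputs.last_hidden_state.shape)  # [2, 16, 128]
print(outputs.token_embeddings.shape)   # new field: the raw word-embedding lookup, [2, 16, 128]
print(len(outputs.attentions))          # one attention tensor per encoder layer
```

The new `token_embeddings` and `embedding_output` fields come from the reworked `QiDeBERTaEmbeddings.forward`, which now returns both the normalized, masked embeddings and the plain embedding lookup.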
README.md CHANGED
@@ -98,7 +98,7 @@ texts = [
98
  "我爱北京天安门,天安门上太阳升。"
99
  ]
100
 
101
- outputs = model(**tokenizer(texts, padding=True)) # BaseModelOutput[last_hidden_state, hidden_states, attentions]
102
  ```
103
 
104
  ## Citation
 
98
  "我爱北京天安门,天安门上太阳升。"
99
  ]
100
 
101
+ outputs = model(**tokenizer(texts, padding=True, return_tensors='pt')) # BaseModelOutput[last_hidden_state, hidden_states, attentions]
102
  ```
103
 
104
  ## Citation
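The README hunk above only changes the call site (adding `return_tensors='pt'` so the model receives tensors rather than Python lists). For context, a hedged sketch of the loading code that snippet presumes; the repo id is a placeholder, and `trust_remote_code=True` is inferred from the `auto_map` entries in config.json.

```python
# Sketch of the setup around the README snippet (repo id is a placeholder, not taken from the diff).
from transformers import AutoModel, AutoTokenizer

repo_id = "<namespace>/QiDeBERTa"  # hypothetical: substitute the actual Hub repo id

# trust_remote_code is needed because config.json maps AutoModel/AutoConfig to the custom
# QiDeBERTa.py / configuration.py modules shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

texts = [
    "我爱北京天安门,天安门上太阳升。",
]

# padding=True pads to the longest sequence in the batch; return_tensors='pt' (the README fix)
# makes the tokenizer return torch tensors that can be unpacked straight into the model call.
outputs = model(**tokenizer(texts, padding=True, return_tensors='pt'))
print(outputs.last_hidden_state.shape)
```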
config.json CHANGED
@@ -5,25 +5,26 @@
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "auto_map": {
8
- "AutoConfig": "Configuration.QiDeBERTaConfig",
9
  "AutoModel": "QiDeBERTa.QiDeBERTa",
10
  "AutoModelForMaskedLM": "QiDeBERTa.QiDeBERTaForMaskedLM"
11
  },
12
  "bos_token_id": 1,
 
 
 
13
  "eos_token_id": 2,
14
  "hidden_act": "gelu",
15
  "hidden_dropout_prob": 0.1,
16
- "hidden_size": 768,
17
  "initializer_range": 0.02,
18
- "intermediate_size": 3072,
19
  "layer_norm_eps": 1e-07,
20
  "mask_token_id": 4,
21
  "max_position_embeddings": 512,
22
  "max_relative_positions": -1,
23
  "model_type": "QiDeBERTa",
24
  "norm_rel_ebd": "layer_norm",
25
- "num_attention_heads": 12,
26
- "num_hidden_layers": 12,
27
  "pad_token_id": 3,
28
  "pooler_hidden_size": 768,
29
  "pooler_mode": "token",
@@ -36,7 +37,7 @@
36
  "relative_attention": true,
37
  "share_att_key": true,
38
  "torch_dtype": "float32",
39
- "transformers_version": "4.50.0",
40
  "unk_token_id": 0,
41
  "vocab_size": 25500
42
  }
 
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "auto_map": {
8
+ "AutoConfig": "configuration.QiDeBERTaConfig",
9
  "AutoModel": "QiDeBERTa.QiDeBERTa",
10
  "AutoModelForMaskedLM": "QiDeBERTa.QiDeBERTaForMaskedLM"
11
  },
12
  "bos_token_id": 1,
13
+ "classifier_num_labels": -1,
14
+ "d_ff": 3072,
15
+ "d_model": 768,
16
  "eos_token_id": 2,
17
  "hidden_act": "gelu",
18
  "hidden_dropout_prob": 0.1,
 
19
  "initializer_range": 0.02,
 
20
  "layer_norm_eps": 1e-07,
21
  "mask_token_id": 4,
22
  "max_position_embeddings": 512,
23
  "max_relative_positions": -1,
24
  "model_type": "QiDeBERTa",
25
  "norm_rel_ebd": "layer_norm",
26
+ "num_heads": 12,
27
+ "num_layers": 12,
28
  "pad_token_id": 3,
29
  "pooler_hidden_size": 768,
30
  "pooler_mode": "token",
 
37
  "relative_attention": true,
38
  "share_att_key": true,
39
  "torch_dtype": "float32",
40
+ "transformers_version": "4.52.4",
41
  "unk_token_id": 0,
42
  "vocab_size": 25500
43
  }
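A quick way to confirm the key renames in config.json (`hidden_size`→`d_model`, `num_attention_heads`→`num_heads`, `num_hidden_layers`→`num_layers`, `intermediate_size`→`d_ff`) is to load the file directly; a small sketch, assuming the file sits in the working directory:

```python
# Sanity-check the renamed config.json keys against the values shown in the diff.
import json

with open("config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

# Old keys are gone; the new keys carry the same values as before the rename.
assert "hidden_size" not in cfg and cfg["d_model"] == 768
assert cfg["num_heads"] == 12 and cfg["num_layers"] == 12 and cfg["d_ff"] == 3072
assert cfg["d_model"] % cfg["num_heads"] == 0   # head dimension must divide evenly

# Fields introduced or bumped by this commit.
print(cfg["classifier_num_labels"])   # -1
print(cfg["transformers_version"])    # 4.52.4
```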
Configuration.py → configuration.py RENAMED
@@ -3,14 +3,20 @@ from transformers import PretrainedConfig
3
 
4
  class QiDeBERTaConfig(PretrainedConfig):
5
  model_type = "QiDeBERTa"
6
 
7
  def __init__(
8
  self,
9
  vocab_size=25500,
10
- hidden_size=1024,
11
- num_hidden_layers=24,
12
- num_attention_heads=16,
13
- intermediate_size=4096,
14
  hidden_act="gelu",
15
  hidden_dropout_prob=0.1,
16
  attention_probs_dropout_prob=0.1,
@@ -20,6 +26,7 @@ class QiDeBERTaConfig(PretrainedConfig):
20
  relative_attention=True,
21
  max_relative_positions=-1,
22
  norm_rel_ebd='layer_norm',
 
23
  unk_token_id=0,
24
  bos_token_id=1,
25
  eos_token_id=2,
@@ -34,10 +41,10 @@ class QiDeBERTaConfig(PretrainedConfig):
34
  ):
35
  super().__init__(**kwargs)
36
 
37
- self.hidden_size = hidden_size
38
- self.num_hidden_layers = num_hidden_layers
39
- self.num_attention_heads = num_attention_heads
40
- self.intermediate_size = intermediate_size
41
  self.hidden_act = hidden_act
42
  self.hidden_dropout_prob = hidden_dropout_prob
43
  self.attention_probs_dropout_prob = attention_probs_dropout_prob
@@ -45,6 +52,7 @@ class QiDeBERTaConfig(PretrainedConfig):
45
  self.initializer_range = initializer_range
46
  self.relative_attention = relative_attention
47
  self.max_relative_positions = max_relative_positions
 
48
  self.unk_token_id = unk_token_id
49
  self.bos_token_id = bos_token_id
50
  self.eos_token_id = eos_token_id
@@ -63,5 +71,5 @@ class QiDeBERTaConfig(PretrainedConfig):
63
  self.vocab_size = vocab_size
64
  self.layer_norm_eps = layer_norm_eps
65
 
66
- self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
67
  self.pooler_mode = pooler_mode
 
3
 
4
  class QiDeBERTaConfig(PretrainedConfig):
5
  model_type = "QiDeBERTa"
6
+ attribute_map = {
7
+ "hidden_size": "d_model",
8
+ "num_attention_heads": "num_heads",
9
+ "num_hidden_layers": "num_layers",
10
+ "intermediate_size": "d_ff",
11
+ }
12
 
13
  def __init__(
14
  self,
15
  vocab_size=25500,
16
+ d_model=1024,
17
+ num_layers=24,
18
+ num_heads=16,
19
+ d_ff=4096,
20
  hidden_act="gelu",
21
  hidden_dropout_prob=0.1,
22
  attention_probs_dropout_prob=0.1,
 
26
  relative_attention=True,
27
  max_relative_positions=-1,
28
  norm_rel_ebd='layer_norm',
29
+ classifier_num_labels=-1,
30
  unk_token_id=0,
31
  bos_token_id=1,
32
  eos_token_id=2,
 
41
  ):
42
  super().__init__(**kwargs)
43
 
44
+ self.d_model = d_model
45
+ self.num_layers = num_layers
46
+ self.num_heads = num_heads
47
+ self.d_ff = d_ff
48
  self.hidden_act = hidden_act
49
  self.hidden_dropout_prob = hidden_dropout_prob
50
  self.attention_probs_dropout_prob = attention_probs_dropout_prob
 
52
  self.initializer_range = initializer_range
53
  self.relative_attention = relative_attention
54
  self.max_relative_positions = max_relative_positions
55
+ self.classifier_num_labels = classifier_num_labels
56
  self.unk_token_id = unk_token_id
57
  self.bos_token_id = bos_token_id
58
  self.eos_token_id = eos_token_id
 
71
  self.vocab_size = vocab_size
72
  self.layer_norm_eps = layer_norm_eps
73
 
74
+ self.pooler_hidden_size = kwargs.get("pooler_hidden_size", d_model)
75
  self.pooler_mode = pooler_mode
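The new `attribute_map` is what keeps the rename backwards compatible: `PretrainedConfig` resolves reads and writes of the old attribute names to the new ones. A short sketch of that behaviour, assuming the remaining defaults defined in the file:

```python
# Sketch: the attribute_map added above aliases the old field names to the new ones.
from configuration import QiDeBERTaConfig

config = QiDeBERTaConfig(d_model=768, num_layers=12, num_heads=12, d_ff=3072)

# Reads through the old names resolve to the renamed attributes...
assert config.hidden_size == config.d_model == 768
assert config.num_attention_heads == config.num_heads == 12
assert config.num_hidden_layers == config.num_layers == 12
assert config.intermediate_size == config.d_ff == 3072

# ...and writes through the old names land on the new attributes as well.
config.hidden_size = 1024
assert config.d_model == 1024
```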
tokenizer.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
- from typing import Optional, Dict, Any, List, Tuple, Union
3
 
4
  import sentencepiece
5
- import torch
6
- from torch.nn.utils.rnn import pad_sequence
7
  from transformers import PreTrainedTokenizer
8
  from transformers.models.deberta_v2.tokenization_deberta_v2 import SPMTokenizer
9
- from transformers.tokenization_utils_base import TextInput, BatchEncoding
 
10
 
11
 
12
  class QiDeBERTaTokenizer(PreTrainedTokenizer):
@@ -55,6 +55,7 @@ class QiDeBERTaTokenizer(PreTrainedTokenizer):
55
  **kwargs,
56
  )
57
  self._tokenizer.special_tokens = self.all_special_tokens
 
58
 
59
  @property
60
  def vocab_size(self):
@@ -66,63 +67,66 @@ class QiDeBERTaTokenizer(PreTrainedTokenizer):
66
 
67
  def __call__(
68
  self,
69
- texts: Union[str, list[str]],
 
 
 
70
  add_special_tokens: bool = True,
71
- padding: bool = False,
72
  return_attention_mask: bool = True,
73
- ) -> BatchEncoding:
74
- """
75
- Encode the input texts and return token ids and an attention mask
76
- :return:
77
- """
78
- if isinstance(texts, str):
79
- texts = [texts]
80
- if (isinstance(texts, list) and all(isinstance(text, str) for text in texts)) is not True:
81
- raise ValueError(
82
- f"Input must be a string or a list of strings, but got {type(texts)}"
83
- )
84
-
85
- if len(texts) > 1 and all(len(text) == len(texts[0]) for text in texts) is False and padding is False:
86
- # Torch does not support variable-length tensors, so padding is required
87
- print(f'[Warning] The input texts are not the same length, padding is required.')
88
- padding = True
89
-
90
- if padding:
91
- token_ids = pad_sequence(
92
- sequences=[torch.LongTensor(input_id) for input_id in self.encode(texts=texts, add_special_tokens=add_special_tokens)],
93
- batch_first=True,
94
- padding_value=self.processor().pad_id()
95
- )
96
- else:
97
- token_ids = torch.LongTensor(self.encode(texts=texts, add_special_tokens=add_special_tokens))
98
-
99
- if return_attention_mask:
100
- return BatchEncoding(
101
- data={
102
- 'input_ids': token_ids,
103
- 'attention_mask': token_ids != self.processor().pad_id()
104
- }
105
- )
106
- else:
107
- return BatchEncoding(
108
- data={
109
- 'input_ids': token_ids
110
- }
111
- )
112
 
113
  def get_vocab(self):
114
  vocab = self.vocab.copy()
115
  vocab.update(self.get_added_vocab())
116
  return vocab
117
 
118
- def _tokenize(self, text: str) -> List[str]:
119
  """Take as input a string and return a list of strings (tokens) for words/sub-words"""
120
  if self.do_lower_case:
121
  text = text.lower()
122
  return self._tokenizer.tokenize(text)
123
 
124
- def tokenize(self, text: TextInput, **kwargs) -> List[str]:
125
- return super().tokenize(text, **kwargs)[1:]
 
126
 
127
  def _convert_token_to_id(self, token: str):
128
  """Converts a token (str) in an id using the vocab."""
@@ -229,34 +233,6 @@ class QiDeBERTaTokenizer(PreTrainedTokenizer):
229
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
230
  return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)
231
 
232
- def encode(self, texts: str|list[str], add_special_tokens: bool = True) -> list[int]|list[list[int]]:
233
- """
234
- Encode the text
235
- :param texts:
236
- :return:
237
- """
238
- if isinstance(texts, str):
239
- return [self._tokenizer.spm.bos_id()] + self._tokenizer.spm.encode_as_ids(texts)[1:] + [self._tokenizer.spm.eos_id()] if add_special_tokens else self._tokenizer.spm.encode_as_ids(texts)[1:]
240
- elif isinstance(texts, list):
241
- return [
242
- [self._tokenizer.spm.bos_id()] + ids[1:] + [self._tokenizer.spm.eos_id()] if add_special_tokens else ids[1:]
243
- for ids in self._tokenizer.spm.encode_as_ids(texts)
244
- ]
245
-
246
- def decode(self, token_ids: list[int] or list[list[int]]) -> list[str] or list[list[str]]:
247
- """
248
- Decode token ids back into text
249
- :param token_ids:
250
- :return:
251
- """
252
- if token_ids and isinstance(token_ids[0], list):
253
- return [
254
- self._tokenizer.spm.DecodeIds(input=ids)
255
- for ids in token_ids
256
- ]
257
- elif token_ids and isinstance(token_ids[0], int):
258
- return self._tokenizer.spm.DecodeIds(input=token_ids)
259
-
260
  def _get_bos_piece(self) -> str:
261
  """
262
  Get the BOS piece
 
1
  import os
2
+ from typing import Optional, Dict, Any, Tuple
3
 
4
  import sentencepiece
5
+ from torch import TensorType
 
6
  from transformers import PreTrainedTokenizer
7
  from transformers.models.deberta_v2.tokenization_deberta_v2 import SPMTokenizer
8
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, TruncationStrategy
9
+ from transformers.utils import PaddingStrategy
10
 
11
 
12
  class QiDeBERTaTokenizer(PreTrainedTokenizer):
 
55
  **kwargs,
56
  )
57
  self._tokenizer.special_tokens = self.all_special_tokens
58
+ self.space_token_id = self._tokenizer.spm.PieceToId('▁')
59
 
60
  @property
61
  def vocab_size(self):
 
67
 
68
  def __call__(
69
  self,
70
+ text: TextInput|PreTokenizedInput|list[TextInput]|list[PreTokenizedInput],
71
+ text_pair: Optional[TextInput|PreTokenizedInput|list[TextInput]|list[PreTokenizedInput]] = None,
72
+ text_target: Optional[TextInput|PreTokenizedInput|list[TextInput]|list[PreTokenizedInput]] = None,
73
+ text_pair_target: Optional[TextInput|PreTokenizedInput|list[TextInput]|list[PreTokenizedInput]] = None,
74
  add_special_tokens: bool = True,
75
+ padding: bool|str|PaddingStrategy = False,
76
+ truncation: Optional[bool|str|TruncationStrategy] = None,
77
+ max_length: Optional[int] = None,
78
+ stride: int = 0,
79
+ is_split_into_words: bool = False,
80
+ pad_to_multiple_of: Optional[int] = None,
81
+ padding_side: Optional[str] = None,
82
+ return_tensors: str|TensorType = 'pt',
83
+ return_token_type_ids: bool = False,
84
  return_attention_mask: bool = True,
85
+ return_overflowing_tokens: bool = False,
86
+ return_special_tokens_mask: bool = False,
87
+ return_offsets_mapping: bool = False,
88
+ return_length: bool = False,
89
+ verbose: bool = True,
90
+ **kwargs,
91
+ ):
92
+ return super().__call__(
93
+ text=text,
94
+ text_pair=text_pair,
95
+ text_target=text_target,
96
+ text_pair_target=text_pair_target,
97
+ add_special_tokens=add_special_tokens,
98
+ padding=padding,
99
+ truncation=truncation,
100
+ max_length=max_length,
101
+ stride=stride,
102
+ is_split_into_words=is_split_into_words,
103
+ pad_to_multiple_of=pad_to_multiple_of,
104
+ padding_side=padding_side,
105
+ return_tensors=return_tensors,
106
+ return_token_type_ids=return_token_type_ids,
107
+ return_attention_mask=return_attention_mask,
108
+ return_overflowing_tokens=return_overflowing_tokens,
109
+ return_special_tokens_mask=return_special_tokens_mask,
110
+ return_offsets_mapping=return_offsets_mapping,
111
+ return_length=return_length,
112
+ verbose=verbose,
113
+ **kwargs,
114
+ )
115
 
116
  def get_vocab(self):
117
  vocab = self.vocab.copy()
118
  vocab.update(self.get_added_vocab())
119
  return vocab
120
 
121
+ def _tokenize(self, text: str) -> list[str]:
122
  """Take as input a string and return a list of strings (tokens) for words/sub-words"""
123
  if self.do_lower_case:
124
  text = text.lower()
125
  return self._tokenizer.tokenize(text)
126
 
127
+ def tokenize(self, text: TextInput, **kwargs) -> list[str]:
128
+ result = super().tokenize(text, **kwargs)
129
+ return result[1:] if result[0] == '▁' else result
130
 
131
  def _convert_token_to_id(self, token: str):
132
  """Converts a token (str) in an id using the vocab."""
 
233
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
234
  return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)
235
 
236
  def _get_bos_piece(self) -> str:
237
  """
238
  Get the BOS piece