Pass throught _init_weights
Browse files- modeling_gptbert.py +1 -11
modeling_gptbert.py
CHANGED
|
@@ -653,17 +653,7 @@ class GptBertPreTrainedModel(PreTrainedModel):
|
|
| 653 |
raise NotImplementedError("Gradient checkpointing is not supported by this model")
|
| 654 |
|
| 655 |
def _init_weights(self, module):
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
if isinstance(module, nn.Linear):
|
| 659 |
-
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 660 |
-
if module.bias is not None:
|
| 661 |
-
module.bias.data.zero_()
|
| 662 |
-
elif isinstance(module, nn.Embedding):
|
| 663 |
-
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 664 |
-
elif isinstance(module, nn.LayerNorm):
|
| 665 |
-
module.bias.data.zero_()
|
| 666 |
-
module.weight.data.fill_(1.0)
|
| 667 |
|
| 668 |
|
| 669 |
class GptBertModel(GptBertPreTrainedModel):
|
|
|
|
| 653 |
raise NotImplementedError("Gradient checkpointing is not supported by this model")
|
| 654 |
|
| 655 |
def _init_weights(self, module):
|
| 656 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
|
| 659 |
class GptBertModel(GptBertPreTrainedModel):
|