duzx16 committed
Commit eb55ff0 · 1 parent: 9692905
Add empty_init option

Files changed: modeling_chatglm.py (+37 −14)
modeling_chatglm.py
CHANGED
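For context on the diff below: with `empty_init=True` (the new default) every weight-bearing submodule is built through `torch.nn.utils.skip_init`, which constructs the module without running its random weight initialization, so a checkpoint load right afterwards wastes no time on values that get overwritten; with `empty_init=False` construction goes through the new `default_init` helper, i.e. a plain constructor call with normal initialization. The following is a minimal standalone sketch of that selection pattern, not part of the file itself; the `make_linear` helper is illustrative only.

import torch
from torch.nn.utils import skip_init


def default_init(cls, *args, **kwargs):
    # Plain construction: the module runs its usual reset_parameters().
    return cls(*args, **kwargs)


def make_linear(in_features, out_features, empty_init=True):
    # Same selection pattern the commit adds to SelfAttention, GLU and the
    # ChatGLM model classes.
    init_method = skip_init if empty_init else default_init
    return init_method(torch.nn.Linear, in_features, out_features)


fast = make_linear(4096, 4096, empty_init=True)     # weights allocated but left uninitialized
normal = make_linear(4096, 4096, empty_init=False)  # ordinary random initialization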
@@ -346,10 +346,18 @@ def attention_fn(
     return outputs
 
 
+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()
 
         self.layer_id = layer_id
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
 
         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )
 
-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):
 
 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func
 
@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )
 
         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )
 
     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
 
-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )
 
         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
 
         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.position_encoding_2d = config.position_encoding_2d
 
-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)
 
-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,