from transformers import PretrainedConfig


class NSAConfig(PretrainedConfig):
    """Configuration for an NSA model: three sparse-attention branches
    (compression "cmp", selection "sel", sliding window "win") on top of a
    standard GQA transformer stack."""

    model_type = "nsa"

    def __init__(
        self,
        vocab_size=50257,              # GPT-2 BPE vocabulary size
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        n_kv_groups=1,                 # key/value groups for GQA (1 = fully shared KV)
        d_k=64,                        # per-head query/key dimension
        d_v=64,                        # per-head value dimension
        max_position_embeddings=2048,
        rope_theta=10000,              # RoPE base frequency
        nsa=None,                      # branch settings dict; falls back to the defaults below
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.nsa = nsa or {
            "branches": ["cmp", "sel", "win"],  # compression, selection, sliding window
            "window": 512,                      # sliding-window width in tokens
            "gqa_groups": n_kv_groups,          # KV groups shared across the branches
            "block": 32,                        # compression block size
            "stride": 16,                       # compression stride (blocks overlap when stride < block)
            "sel_block": 64,                    # selection block size
            "sel_top_n": 16,                    # number of top-scoring selection blocks kept per query
        }
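
# Usage sketch (illustrative): construct the config with a few overrides and
# round-trip it through the JSON save/load machinery that PretrainedConfig
# provides. The directory name and override values below are hypothetical,
# chosen only for demonstration.
if __name__ == "__main__":
    config = NSAConfig(
        hidden_size=1024,
        num_attention_heads=16,
        n_kv_groups=2,
        nsa={
            "branches": ["cmp", "sel", "win"],
            "window": 256,        # narrower sliding window than the default 512
            "gqa_groups": 2,      # kept in sync with n_kv_groups above
            "block": 32,
            "stride": 16,
            "sel_block": 64,
            "sel_top_n": 8,
        },
    )
    # Custom attributes set in __init__ are serialized to config.json and
    # restored on load, since from_pretrained feeds them back into __init__.
    config.save_pretrained("./nsa-tiny")
    reloaded = NSAConfig.from_pretrained("./nsa-tiny")
    assert reloaded.nsa["window"] == 256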