# nsa-117m-byte-sft / configuration_nsa.py
# Uploaded by seconds-0 — commit 9a7a74a (verified): "Upload NSAForCausalLM"
# Remote code: configuration and modeling for NSA
from transformers import PretrainedConfig
class NSAConfig(PretrainedConfig):
    """Configuration for the NSA causal language model (``model_type = "nsa"``).

    Holds the transformer dimensions plus an ``nsa`` dict of sparse-attention
    settings. Any keyword arguments not named in ``__init__`` are forwarded
    unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "nsa"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        n_kv_groups: int = 1,
        d_k: int = 64,
        d_v: int = 64,
        max_position_embeddings: int = 2048,
        rope_theta: float = 10000,
        nsa: "dict | None" = None,
        **kwargs,
    ):
        """Store the model hyper-parameters on the config instance.

        Args:
            vocab_size: Vocabulary size.
            hidden_size: Model hidden dimension.
            num_hidden_layers: Number of transformer layers.
            num_attention_heads: Number of attention heads.
            n_kv_groups: Number of key/value groups. Also used as the default
                ``nsa["gqa_groups"]`` when ``nsa`` is not given.
            d_k: Presumably the per-head key/query dimension — confirm
                against the modeling code.
            d_v: Presumably the per-head value dimension — confirm against
                the modeling code.
            max_position_embeddings: Maximum number of positions.
            rope_theta: Rotary-embedding base (stored as given).
            nsa: Sparse-attention settings dict. When ``None``, a default dict
                is built (keys ``branches``, ``window``, ``gqa_groups``,
                ``block``, ``stride``, ``sel_block``, ``sel_top_n``). When a
                mapping is passed (even an empty one) it is kept as-is, but
                deep-copied so the config does not share state with the
                caller.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        if nsa is None:
            # A fresh dict per instance; gqa_groups tracks n_kv_groups.
            # Branch/key semantics are defined by the (unseen) modeling code.
            nsa = {
                "branches": ["cmp", "sel", "win"],
                "window": 512,
                "gqa_groups": n_kv_groups,
                "block": 32,
                "stride": 16,
                "sel_block": 64,
                "sel_top_n": 16,
            }
        else:
            # Deep-copy so edits to self.nsa (including the nested lists)
            # never leak back into the caller's dict or into other configs
            # constructed from the same dict.
            import copy

            nsa = copy.deepcopy(nsa)
        self.nsa = nsa