"""# `shared_space_config.py`

#### `*Config`
"""

from typing import Optional

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

"""`def make_shorthand`"""

def make_shorthand(model_cfg):
    """
    Takes an instance subencoder `*Config` and constructs a shorthand
    name for the model based on settings.
    """

    dense_str = str(model_cfg.num_dense_layers) + "mha + "

    if model_cfg.o_shared_dim is not None:
        o_str = "." + str(model_cfg.o_shared_dim)
    else:
        o_str = ""

    # If no output subspace is used, the output dimension is omitted from the name.
    attn_str = (
        dense_str
        + "mla."
        + str(model_cfg.q_shared_dim)
        + "."
        + str(model_cfg.kv_shared_dim)
        + o_str
    )

    # MLP Configuration
    if model_cfg.ffn_decompose:
        dense_str = (
            str(model_cfg.num_dense_layers)
            + "mlp."
            + str(model_cfg.intermediate_size)
            + " + "
        )

        mlp_str = (
            dense_str
            + str(model_cfg.num_hidden_layers - model_cfg.num_dense_layers)
            + "dcmp."
            + "x"
            + str(model_cfg.intermediate_size)
            + "."
            + str(model_cfg.ffn_rank)
        )
    else:
        mlp_str = "mlp." + str(model_cfg.intermediate_size)

    # Assemble string
    shorthand = (
        f"{attn_str} - {mlp_str} - "
        f"h{model_cfg.hidden_size} - l{model_cfg.num_hidden_layers}"
    )

    """
    The full run name used during training also includes the training settings:

    run_name = (
        f"{config['stats']['total_elements']} - "
        f"{attn_str} - {mlp_str} - "
        f"h{model_cfg.hidden_size} - l{model_cfg.num_hidden_layers} - "
        f"bs{ptrain_cfg['train_batch_size']} - lr{lr_str} - "
        f"seq{ptrain_cfg['max_seq_length']}"
    )
    """

    return shorthand
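
"""
A quick illustration of the shorthand format. This is a sketch with
hypothetical values; `make_shorthand` only reads attributes, so a bare
`SimpleNamespace` stands in for a full `*Config` here.
"""

if __name__ == "__main__":
    from types import SimpleNamespace

    _example_cfg = SimpleNamespace(
        num_dense_layers=2,      # two leading dense MHA layers
        q_shared_dim=128,        # shared query subspace rank
        kv_shared_dim=128,       # shared key/value subspace rank
        o_shared_dim=128,        # shared output subspace rank
        ffn_decompose=False,     # standard (non-decomposed) FFN
        ffn_rank=None,
        intermediate_size=2048,
        hidden_size=512,
        num_hidden_layers=12,
    )

    # Prints: 2mha + mla.128.128.128 - mlp.2048 - h512 - l12
    print(make_shorthand(_example_cfg))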


class SharedSpaceDecoderConfig(PretrainedConfig):
    r"""
    Configuration class for SharedSpaceDecoderConfig.

    Extends the HuggingFace `PretrainedConfig` to support architectural
    variations including:
    - Multi-Head Latent Attention (MLA)
    - Decomposed MLPs (low-rank FFNs)
    - Flexible attention backends (eager, flash, sdpa)
    - Explicit shared subspaces for Q, K, V, and O projections

    This config does not infer any defaults based on `hidden_size`. All
    dimensions and ranks must be explicitly specified. If required values are
    missing, a `ValueError` is raised during initialization.

    ----------------------
    Core Model Parameters:
    ----------------------
    - vocab_size (`int`) β€” Vocabulary size.
    - hidden_size (`int`) β€” Model hidden dimension.
    - num_hidden_layers (`int`) β€” Number of transformer blocks.
    - intermediate_size (`int`) β€” Feed-forward hidden dimension.
    - hidden_act (`str`) β€” Activation function.
    - hidden_dropout_prob (`float`) β€” Dropout after projections and FFNs.
    - attention_dropout_prob (`float`) β€” Dropout applied to attention scores.
    - max_position_embeddings (`int`) β€” Max sequence length.
    - initializer_range (`float`) β€” Stddev of weight init.

    - layer_norm_eps (`float`) β€” Epsilon for LayerNorm.
    - rms_norm_ps (`float`) β€” Epsilon for RMSNorm

    - classifier_dropout (`float` or None) β€” Dropout for final classifier.

    - vocab_subspace
    - vocab_rank

    ----------------------------------
    Multi-Head Latent Attention (MLA):
    ----------------------------------
    - num_attention_heads (`int`) β€” Number of attention heads.

    - q_shared_dim (`int`) β€” Rank of the shared query subspace.
    - kv_shared_dim (`int`) β€” Rank of the shared key/value subspace.

    - output_subspace (`bool`) β€” Whether to use a shared latent subspace for output projections.
    - o_shared_dim (`int`) β€” Rank of the shared output subspace (required if `output_subspace=True`).
    - qk_private_dim (`int`) β€” Query/key private dimension per head.
    - vo_private_dim (`int`) β€” Value/output private dimension per head.

    - rope_dims (`int`) β€” Number of head dimensions carrying RoPE.
    - nope_dims (`int`) β€” Non-positional encoding dimensions.
    - rope_theta (`float`) β€” Base frequency used for RoPE.
    - rope_scaling (`dict` or None) β€” HF-style scaling dict for RoPE.
    - attention_bias (`bool`) β€” Whether to include bias terms in Q/K/V projections.
    - num_dense_layers (`int`) β€” Number of leading layers that do not use
                                 subspaces for attention or FFNs.
    - attention_backend (`str`) β€” Must be one of `"eager"`, `"flash_attention_2"`, or `"sdpa"`.

    ----------------------
    Decomposed MLP (Low-Rank FFN):
    ----------------------
    - ffn_decompose (`bool`) β€” Whether to enable low-rank FFNs.
    - ffn_rank (`int`) β€” Rank of the shared FFN latent space (required if `ffn_decompose=True`).

    ----------------------
    Validation Behavior:
    ----------------------
    Raises `ValueError` at init time if:
    - FFN decomposition is enabled without specifying `ffn_rank`.
    - An unknown `attention_backend` is provided.
    """

    model_type = "shared_subspace_decoder"

    def __init__(
        self,

        # === Core Model ===
        vocab_size:         int = 30522,
        hidden_size:        int = 512,
        num_hidden_layers:  int = 12,

        intermediate_size:  int = 3072,

        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        max_position_embeddings: int = 2048,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        rms_norm_eps=1e-6,  # Common RMSNorm default; confirm against the config file.
        norm_type="layernorm", # Choice between "layernorm" and "rmsnorm"
        classifier_dropout=None,

        vocab_subspace=False,
        vocab_rank=None,
        tie_word_embeddings=True,

        # === Multi-Head Latent Attention ===
        num_attention_heads: int = 16,
        rope_dims:           int = 16,

        q_shared_dim:        Optional[int] = None,
        kv_shared_dim:       Optional[int] = None,

        o_shared_dim:        Optional[int] = None,  # If None, no output subspace is used

        # Private head dimensions
        qk_private_dim:      Optional[int] = None,  # Query/key private dimension per head
        vo_private_dim:      Optional[int] = None,  # Value/output private dimension per head
        nope_dims:           Optional[int] = None,  # Non-positional encoding dimensions

        attention_backend="eager",
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,

        # === MLA Composition ===
        num_dense_layers=12,  # dense MHA layers before MLA starts

        # === Decomposed MLP ===
        ffn_decompose=False,
        ffn_rank=None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)



        # === Core Model ===
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.rms_norm_eps = rms_norm_eps
        self.norm_type = norm_type
        self.classifier_dropout = classifier_dropout

        self.vocab_subspace = vocab_subspace
        self.vocab_rank = vocab_rank
        self.tie_word_embeddings = tie_word_embeddings

        # === MLA ===
        self.num_attention_heads = num_attention_heads
        self.rope_dims = rope_dims

        self.q_shared_dim = q_shared_dim
        self.kv_shared_dim = kv_shared_dim
        self.o_shared_dim = o_shared_dim

        # Private head dimensions
        self.qk_private_dim = qk_private_dim
        self.vo_private_dim = vo_private_dim
        self.nope_dims = nope_dims
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.num_dense_layers = num_dense_layers

        # === Decomposed FFN ===
        self.ffn_decompose = ffn_decompose
        self.ffn_rank = ffn_rank

        # === Attention backend ===
        self.attention_backend = attention_backend

        # === Validation ===
        # TODO - Somewhere during training these get instantiated with bad
        #        values...
        #self._validate()

        #print(f"  > SubEnc *Config.init: {make_shorthand(self)}\n")


    def _validate(self):
        # === Model ===
        if self.num_dense_layers > self.num_hidden_layers:
            raise ValueError("`num_dense_layers` must be <= `num_hidden_layers`")
        if self.vocab_subspace and self.vocab_rank is None:
            raise ValueError("`vocab_rank` must be set when `vocab_subspace=True`")

        # === MLA Validation ===
        # At least one of q_shared_dim or kv_shared_dim must be set if we have subspace layers
        if self.num_dense_layers < self.num_hidden_layers and self.q_shared_dim is None and self.kv_shared_dim is None:
            raise ValueError("At least one of q_shared_dim or kv_shared_dim must be set when there are subspace layers")
        
        # Validate that private dimensions are set
        if self.qk_private_dim is None or self.vo_private_dim is None:
            raise ValueError("Must set qk_private_dim and vo_private_dim")
        if self.nope_dims is None:
            raise ValueError("Must set nope_dims")

        # === Decomposed FFN ===
        if self.ffn_decompose and self.ffn_rank is None:
            raise ValueError("`ffn_rank` must be set when `ffn_decompose=True`")
        if self.ffn_decompose and self.num_dense_layers >= self.num_hidden_layers:
            raise ValueError("`ffn_decompose` was set but `num_dense` is >= number of layers")

        # === Attention Backend ===
        valid_backends = ["eager", "flash_attention_2", "sdpa"]
        if self.attention_backend not in valid_backends:
            raise ValueError(f"Unknown attention backend: {self.attention_backend}, options are {valid_backends}")
        
        # === Norm Type ===
        valid_norm_types = ["layernorm", "rmsnorm"]
        if self.norm_type not in valid_norm_types:
            raise ValueError(f"Unknown norm type: {self.norm_type}, options are {valid_norm_types}")

"""#### `get_config`"""

import json

def get_config(filename):
    """
    Loads the full training config from a JSON file, strictly checks the keys
    of its `"model"` section against the `SharedSpaceDecoderConfig` constructor
    arguments, and returns `(full_cfg, model_cfg)`.
    """

    # Load the config file.
    with open(filename) as f:
        full_cfg = json.load(f)

    # Strict key check on the model configuration.

    # Get the list of keys allowed / required by `*Config`
    valid_keys = SharedSpaceDecoderConfig.__init__.__code__.co_varnames
    # Remove `self` and `kwargs`
    valid_keys = set(valid_keys) - {"self", "kwargs"}

    # Compare the set of keys in the json file vs `*Config`
    extra_keys = set(full_cfg["model"]) - valid_keys
    missing_keys = valid_keys - set(full_cfg["model"])

    # If there are any keys in the `json` that aren't in `*Config`,
    if extra_keys:
        # List them for the user.
        raise ValueError(f"Unknown keys in config: {sorted(extra_keys)}")

    #  If the json config is missing required keys,
    if missing_keys:
        # List them for the user.
        raise ValueError(f"config json is missing: {sorted(missing_keys)}")

    # Will raise a TypeError, by design, if required arguments are missing.
    # The double asterisk unpacks the dictionary into keyword arguments, as
    # though all of the settings were written out individually.
    model_cfg = SharedSpaceDecoderConfig(**full_cfg["model"])

    return full_cfg, model_cfg
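
"""
Usage sketch for `get_config` (hypothetical file path and contents): the JSON
file is expected to contain a top-level `"model"` section whose keys exactly
match the `SharedSpaceDecoderConfig` constructor arguments. The example below
writes a deliberately incomplete config to show the strict key check at work.
"""

if __name__ == "__main__":
    import os
    import tempfile

    # A config that is missing most of the required model keys.
    _incomplete = {"model": {"hidden_size": 512, "num_hidden_layers": 12}}

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(_incomplete, f)
        _tmp_path = f.name

    try:
        get_config(_tmp_path)
    except ValueError as e:
        # Lists every missing constructor argument.
        print(e)
    finally:
        os.remove(_tmp_path)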