ffurfaro committed on
Commit 73bcbf2 · verified · 1 Parent(s): 5aafa2e

Upload model + init tptt code
README.md ADDED
@@ -0,0 +1,77 @@
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: mistralai/Mistral-7B-v0.3
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # Titans-v2-Mistral-7B-v0.3
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `mistralai/Mistral-7B-v0.3` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model list
41
+
42
+ Classic model parameters with LiZA injection:
43
+
44
+ | Subfolder | Max Self Attn Length | Mag Weight | Cross Gate | Max Chunk Size | Bidirectional | LoRA | Description |
45
+ |-------------------------------|----------------------|------------|------------|----------------|---------------|------|-------------------------------------------------------|
46
+ | delta_rule | 8192 (default) | 0.5 | False | 64 | False | Yes | Parallel linearized attention with delta_rule operator|
47
+ | delta_rule_gelu | 8192 (default) | 0.5 | False | 64 | False | Yes | Non-linear operator with gelu activation |
48
+ | delta_product | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with derivative trick |
49
+ | delta_product_r | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with rotative trick |
50
+ | delta_product_c | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with combined trick |
51
+
52
+ ## Usage
53
+
54
+ ```python
55
+ from transformers import AutoModelForCausalLM, AutoTokenizer
56
+
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ "ffurfaro/Titans-v2-Mistral-7B-v0.3",
59
+ subfolder="tptt_subfolder", # see in repo tree
60
+ trust_remote_code=True
61
+ )
62
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
63
+
64
+ prompt = "Your prompt here"
65
+ inputs = tokenizer(prompt, return_tensors="pt")
66
+ outputs = model.generate(**inputs, max_new_tokens=100)
67
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
68
+
69
+ ```
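+
+ On GPU you will likely want to load the weights in half precision. A minimal sketch, assuming an environment with `accelerate` installed (for `device_map="auto"`); the configs in this repo report `torch_dtype: bfloat16`, and the subfolder name is one of those from the model list above:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "ffurfaro/Titans-v2-Mistral-7B-v0.3",
+     subfolder="delta_rule",  # any subfolder from the model list above
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ ```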
70
+
71
+
72
+ ## Citation & Contact
73
+
74
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
75
+
76
+
77
+ ---
__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
3
+ """
4
+
5
+ from .configuration_tptt import (TpttConfig, generate_model_card,
6
+ parse_mode_name)
7
+ from .modeling_tptt import (LCache, LinearAttention, LinearAttentionOp,
8
+ LiZAttention, TpttModel, get_tptt_model,
9
+ load_tptt_safetensors, save_tptt_safetensors)
10
+ from .train_tptt import LiZACallback, SaveBestModelCallback
11
+
12
+ __all__ = [
13
+ "TpttConfig",
14
+ "TpttModel",
15
+ "get_tptt_model",
16
+ "LiZACallback",
17
+ "SaveBestModelCallback",
18
+ "LCache",
19
+ "LinearAttentionOp",
20
+ "LiZAttention",
21
+ "generate_model_card",
22
+ "LinearAttention",
23
+ "parse_mode_name",
24
+ "load_tptt_safetensors",
25
+ "save_tptt_safetensors",
26
+ ]
configuration_tptt.py ADDED
@@ -0,0 +1,318 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import AutoConfig, PretrainedConfig
15
+
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
+
22
+ def convert_sets_to_lists(obj):
23
+ """Convert sets to list for LoRA serialized config"""
24
+ if isinstance(obj, set):
25
+ return list(obj)
26
+ if isinstance(obj, dict):
27
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
28
+ if isinstance(obj, (list, tuple)):
29
+ return [convert_sets_to_lists(x) for x in obj]
30
+ return obj
31
+
32
+
33
+ class TpttConfig(PretrainedConfig):
34
+ """
35
+ Configuration class for the TPTT model.
36
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
37
+ """
38
+
39
+ model_type = "tptt"
40
+ auto_map = {
41
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
42
+ "AutoConfig": "configuration_tptt.TpttConfig",
43
+ }
44
+ architectures = ["TpttModel"]
45
+
46
+ RECURRENT_MODES = {
47
+ "delta_rule": {
48
+ "order": 1,
49
+ "gate_type": "k",
50
+ "linear": True,
51
+ "trick": "derivative",
52
+ },
53
+ "delta_rule_v": {
54
+ "order": 1,
55
+ "gate_type": "v",
56
+ "linear": True,
57
+ "trick": "derivative",
58
+ },
59
+ "delta_rule_kv": {
60
+ "order": 1,
61
+ "gate_type": "kv",
62
+ "linear": True,
63
+ "trick": "derivative",
64
+ },
65
+ "delta_rule_gelu": {
66
+ "order": 1,
67
+ "gate_type": "k",
68
+ "linear": False,
69
+ "trick": "derivative",
70
+ },
71
+ "delta_product": {
72
+ "order": 2,
73
+ "gate_type": "k",
74
+ "linear": True,
75
+ "trick": "derivative",
76
+ },
77
+ "delta_product_r": {
78
+ "order": 2,
79
+ "gate_type": "k",
80
+ "linear": True,
81
+ "trick": "rotative",
82
+ },
83
+ "delta_product_c": {
84
+ "order": 2,
85
+ "gate_type": "k",
86
+ "linear": True,
87
+ "trick": "combined",
88
+ },
89
+ } # Tested modes, see parse_mode_name if you want to add more
90
+
91
+ def __init__(
92
+ self,
93
+ base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
94
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
95
+ base_model_subfolder: Optional[str] = None,
96
+ name_or_path: Optional[str] = None,
97
+ model_task: str = "causal_lm",
98
+ target_modules_names: Optional[List[str]] = None,
99
+ operator_mode: str = "delta_rule",
100
+ use_linear_checkpoint: Optional[bool] = None,
101
+ max_self_attn_length: Optional[
102
+ int
103
+ ] = None, # unnecessary with SWA; otherwise 8192 is a standard value
104
+ base_scale_attn: bool = False,
105
+ mag_weight: float = 0.5, # if 1.0, use only linear operator
106
+ cross_gate: bool = False, # non-linear mixing strategy
107
+ max_chunk_size: int = 64, # 128 if adaptive chunking (longest)
108
+ linear_precision: Union[str, torch.dtype] = "float32",
109
+ lora_config: Optional[dict] = None, # only serialized accepted
110
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
111
+ bidirectional: bool = False, # if True, use bidirectional attention
112
+ pooling_config: Optional[Dict[str, Any]] = None,
113
+ **kwargs,
114
+ ):
115
+ # If base_model_config is provided, load it and merge with this config
116
+ if base_model_config is not None:
117
+ if isinstance(base_model_config, PretrainedConfig):
118
+ base_model_config = base_model_config.to_dict()
119
+ else:
120
+ # Load config from Hugging Face Hub or a local path
121
+ base_model_config = AutoConfig.from_pretrained(
122
+ base_model_name, **kwargs
123
+ ).to_dict()
124
+ # Merge all backbone fields into this config
125
+ for k, v in base_model_config.items():
126
+ setattr(self, k, v)
127
+
128
+ self.base_model_name = base_model_name
129
+ self.base_model_subfolder = base_model_subfolder
130
+ self.model_task = model_task
131
+
132
+ if name_or_path is not None:
133
+ self._name_or_path = name_or_path
134
+ else:
135
+ if "/" in base_model_name:
136
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
137
+ else:
138
+ self._name_or_path = "Titans-" + base_model_name
139
+
140
+ self.target_modules_names = target_modules_names or [
141
+ "attn",
142
+ "self_attn",
143
+ "attention",
144
+ ]
145
+ self.operator_mode = operator_mode
146
+
147
+ # Detect available memory on accelerator device
148
+ if torch.cuda.is_available():
149
+ _, total_mem = torch.cuda.mem_get_info()
150
+ else:
151
+ total_mem = psutil.virtual_memory().total
152
+ total_mem_gb = total_mem / BYTES_IN_GB
153
+
154
+ self.use_linear_checkpoint = (
155
+ total_mem_gb < 16
156
+ if use_linear_checkpoint is None
157
+ else use_linear_checkpoint
158
+ )
159
+
160
+ self.base_scale_attn = base_scale_attn
161
+ self.mag_weight = mag_weight
162
+ self.cross_gate = cross_gate
163
+ self.max_chunk_size = max_chunk_size
164
+ self.max_self_attn_length = max_self_attn_length
165
+ if isinstance(linear_precision, torch.dtype):
166
+ linear_precision = str(linear_precision).replace("torch.", "")
167
+ self.linear_precision = linear_precision
168
+
169
+ self.lora_config = lora_config
170
+ if lora_config is not None:
171
+ if hasattr(self.lora_config.get("peft_type"), "value"):
172
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
173
+ self.lora_config = convert_sets_to_lists(self.lora_config)
174
+
175
+ self.padding_side = padding_side
176
+ self.bidirectional = bidirectional
177
+ if self.bidirectional:
178
+ print("Bidirectional is enabled, need to be uncausal and unpadded.")
179
+ self.pooling_config = pooling_config
180
+
181
+ super().__init__(**kwargs) # flush inconsistent pretrained parameters
182
+ # Copy class attributes to instance for serialization (save dict)
183
+ self.model_type = self.__class__.model_type
184
+ self.auto_map = self.__class__.auto_map
185
+ self.architectures = self.__class__.architectures
186
+ # Padding side configuration if not set
187
+ if self.padding_side is None:
188
+ self.padding_side = "right"
189
+ logger.info("Warning: padding_side is None, defaulting to 'right'.")
190
+ # set recurrent configuration from operator mode
191
+ if operator_mode not in self.__class__.RECURRENT_MODES:
192
+ self.recurrent_config = parse_mode_name(operator_mode)
193
+ else:
194
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
195
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
196
+
197
+
198
+ TpttConfig.register_for_auto_class()
199
+
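+ # Construction sketch (comments only; the keyword arguments are the ones defined
+ # in TpttConfig.__init__ above, values are illustrative):
+ #
+ #   backbone_cfg = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.3")
+ #   config = TpttConfig(
+ #       base_model_config=backbone_cfg,   # merged field-by-field into this config
+ #       base_model_name="mistralai/Mistral-7B-v0.3",
+ #       operator_mode="delta_product",    # see RECURRENT_MODES / parse_mode_name
+ #       mag_weight=0.5,
+ #       max_chunk_size=64,
+ #   )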
200
+
201
+ def parse_mode_name(name: str) -> dict:
202
+ """Parse mode to recurrent config"""
203
+ if name.startswith("delta_product"):
204
+ parts = name.split("_")
205
+ # Prefix is always two words: 'delta' and 'product'
206
+ base_len = 2
207
+ order = 2
208
+ gate_type = "k"
209
+ linear = True
210
+ trick = "derivative"
211
+
212
+ idx = base_len
213
+ # Check for order (immediately after the prefix)
214
+ if len(parts) > idx and parts[idx].isdigit():
215
+ order = int(parts[idx])
216
+ idx += 1
217
+
218
+ remaining = parts[idx:]
219
+ # Trick (r/c) is always at the far right if present
220
+ if remaining and remaining[-1] in ("r", "c"):
221
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
222
+ remaining = remaining[:-1]
223
+ # 'gelu' comes just before the trick if present
224
+ if remaining and remaining[-1] == "gelu":
225
+ linear = False
226
+ remaining = remaining[:-1]
227
+ # If anything remains, it's the gate_type
228
+ if remaining:
229
+ gate_type = "_".join(remaining)
230
+ return {
231
+ "order": order,
232
+ "gate_type": gate_type,
233
+ "linear": linear,
234
+ "trick": trick,
235
+ }
236
+
237
+ # delta_rule[_gate][_gelu]
238
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
239
+ if m:
240
+ return {
241
+ "order": 1,
242
+ "gate_type": m.group(1) if m.group(1) else "k",
243
+ "linear": not bool(m.group(2)),
244
+ "trick": "derivative",
245
+ }
246
+ raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
+ def get_mode_name(
250
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
251
+ ) -> str:
252
+ """Get recurrent mode name from parameter"""
253
+ base = (
254
+ "delta_rule"
255
+ if order == 1
256
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
257
+ )
258
+ parts = []
259
+ if gate_type != "k":
260
+ parts.append(gate_type)
261
+ if not linear:
262
+ parts.append("gelu")
263
+ if order >= 2 and trick != "derivative":
264
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
265
+ return base + (("_" + "_".join(parts)) if parts else "")
266
+
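+ # Worked example (comments only): parse_mode_name and get_mode_name are inverses.
+ #   parse_mode_name("delta_product_kv_gelu_r")
+ #     -> {"order": 2, "gate_type": "kv", "linear": False, "trick": "rotative"}
+ #   get_mode_name(order=2, gate_type="kv", linear=False, trick="rotative")
+ #     -> "delta_product_kv_gelu_r"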
267
+
268
+ def render_template(template_path: str, variables: dict) -> str:
269
+ """Load and render a Jinja2 template from any file path."""
270
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
271
+ template = env.get_template(os.path.basename(template_path))
272
+ return template.render(**variables)
273
+
274
+
275
+ def write_model_card(output_path: str, content: str):
276
+ """Write the generated content into README.md."""
277
+ os.makedirs(output_path, exist_ok=True)
278
+ readme_path = os.path.join(output_path, "README.md")
279
+ with open(readme_path, "w", encoding="utf-8") as f:
280
+ f.write(content)
281
+
282
+
283
+ def generate_model_card(
284
+ output_path: str,
285
+ config: Union[dict, object],
286
+ template: Optional[
287
+ str
288
+ ], # can be "model_card" OR an absolute/relative path to a .md file
289
+ extra_variables: Optional[Dict] = None,
290
+ ):
291
+ """
292
+ Generate a README.md file from a Jinja2 template and a configuration.
293
+
294
+ - template can be either:
295
+ * a full path to a template file
296
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
297
+ """
298
+ if template is None:
299
+ template = "model_card_template" # default template name
300
+ # Locate the template
301
+ if os.path.exists(template): # direct file path provided
302
+ template_path = template
303
+ else:
304
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
305
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
306
+
307
+ if not os.path.exists(template_path):
308
+ raise FileNotFoundError(f"Template not found: {template_path}")
309
+
310
+ variables = {
311
+ "model_id": os.path.basename(output_path),
312
+ "config": config,
313
+ }
314
+ if extra_variables:
315
+ variables.update(extra_variables)
316
+
317
+ content = render_template(template_path, variables)
318
+ write_model_card(output_path, content)
lora_delta_product_m0.5_constant/README.md ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: mistralai/Mistral-7B-v0.3
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # lora_delta_product_m0.5_constant
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `mistralai/Mistral-7B-v0.3` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model Details
41
+
42
+ - **Architecture:** TpttModel
43
+ - **Base model:** mistralai/Mistral-7B-v0.3
44
+ - **LiZA config:** operator=delta_product, mag=0.5
45
+ - **LoRA config:** r=8, alpha=16, dropout=0.05
46
+ - **torch_dtype:** bfloat16
47
+
48
+ ## Usage
49
+
50
+
51
+ ```python
52
+ from transformers import AutoModelForCausalLM, AutoTokenizer
53
+
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ "ffurfaro/lora_delta_product_m0.5_constant",
56
+ trust_remote_code=True
57
+ )
58
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
59
+
60
+ prompt = "Your prompt here"
61
+ inputs = tokenizer(prompt, return_tensors="pt")
62
+ outputs = model.generate(**inputs, max_new_tokens=100)
63
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
64
+
65
+ ```
66
+
67
+ > [!IMPORTANT]
68
+ > You must specify the `subfolder` if the repo contains multiple models; see the homepage for details.
69
+
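+ For example, loading this variant from the parent repository (a sketch; the repo and subfolder names are as they appear in this model tree):
+
+ ```python
+ model = AutoModelForCausalLM.from_pretrained(
+     "ffurfaro/Titans-v2-Mistral-7B-v0.3",
+     subfolder="lora_delta_product_m0.5_constant",
+     trust_remote_code=True,
+ )
+ ```
+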
70
+ ## Training
71
+
72
+ - **Dataset:** yahma/alpaca-cleaned
73
+ - **Platform:** Kaggle
74
+ - **Hardware:** NVIDIA 2xT4
75
+ - **Batch size:** 1
76
+ - **Epochs:** 1.0
77
+ - **Learning rate (final):** N/A
78
+ - **Loss (final):** 1.2633397308452659
79
+ - **Training runtime:** 11512.6247 sec
80
+ - **Samples per second:** 0.174
81
+ - **Steps per second:** 0.174
82
+ - **Total FLOPs:** 5574366965268480.0
83
+ - **Gradient norm (final):** N/A
84
+
85
+ ## Evaluation
86
+
87
+ - **Metrics:** Training loss only (no eval yet; benchmark table coming soon: PiQA, ARC, HellaSwag, WinoGrande, GSM8K, MMLU)
88
+ - **Results:** Final training loss: 1.2633397308452659
89
+
90
+
91
+ ## Citation & Contact
92
+
93
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
94
+
95
+
96
+ ---
lora_delta_product_m0.5_constant/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa460de2d5833515a362ba04953808e73f83cd2f909de119d2605e18b79d8ec9
3
+ size 27298792
lora_delta_product_m0.5_constant/config.json ADDED
@@ -0,0 +1,86 @@
1
+ {
2
+ "architectures": [
3
+ "TpttModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_tptt.TpttConfig",
8
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel"
9
+ },
10
+ "base_model_name": "mistralai/Mistral-7B-v0.3",
11
+ "base_model_subfolder": null,
12
+ "base_scale_attn": false,
13
+ "bidirectional": false,
14
+ "cross_gate": false,
15
+ "head_dim": 128,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 4096,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 14336,
20
+ "linear_precision": "bfloat16",
21
+ "lora_config": {
22
+ "alpha_pattern": {},
23
+ "auto_mapping": null,
24
+ "base_model_name_or_path": null,
25
+ "bias": "none",
26
+ "eva_config": null,
27
+ "exclude_modules": null,
28
+ "fan_in_fan_out": false,
29
+ "inference_mode": false,
30
+ "init_lora_weights": true,
31
+ "layer_replication": null,
32
+ "layers_pattern": null,
33
+ "layers_to_transform": null,
34
+ "loftq_config": {},
35
+ "lora_alpha": 16,
36
+ "lora_bias": false,
37
+ "lora_dropout": 0.05,
38
+ "megatron_config": null,
39
+ "megatron_core": "megatron.core",
40
+ "modules_to_save": null,
41
+ "peft_type": "LORA",
42
+ "r": 8,
43
+ "rank_pattern": {},
44
+ "revision": null,
45
+ "target_modules": [
46
+ "q_proj",
47
+ "k_proj",
48
+ "o_proj",
49
+ "v_proj"
50
+ ],
51
+ "task_type": "CAUSAL_LM",
52
+ "use_dora": false,
53
+ "use_rslora": false
54
+ },
55
+ "mag_weight": 0.5,
56
+ "max_chunk_size": 32,
57
+ "max_position_embeddings": 32768,
58
+ "max_self_attn_length": null,
59
+ "model_task": "causal_lm",
60
+ "model_type": "tptt",
61
+ "num_attention_heads": 32,
62
+ "num_hidden_layers": 32,
63
+ "num_key_value_heads": 8,
64
+ "operator_mode": "delta_product",
65
+ "padding_side": "left",
66
+ "pooling_config": null,
67
+ "recurrent_config": {
68
+ "gate_type": "k",
69
+ "linear": true,
70
+ "order": 2,
71
+ "trick": "derivative"
72
+ },
73
+ "rms_norm_eps": 1e-05,
74
+ "rope_theta": 1000000.0,
75
+ "sliding_window": null,
76
+ "target_modules_names": [
77
+ "attn",
78
+ "self_attn",
79
+ "attention"
80
+ ],
81
+ "torch_dtype": "bfloat16",
82
+ "transformers_version": "4.49.0",
83
+ "use_cache": true,
84
+ "use_linear_checkpoint": true,
85
+ "vocab_size": 32768
86
+ }
lora_delta_product_m0.5_constant/configuration_tptt.py ADDED
@@ -0,0 +1,318 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import AutoConfig, PretrainedConfig
15
+
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
+
22
+ def convert_sets_to_lists(obj):
23
+ """Convert sets to list for LoRA serialized config"""
24
+ if isinstance(obj, set):
25
+ return list(obj)
26
+ if isinstance(obj, dict):
27
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
28
+ if isinstance(obj, (list, tuple)):
29
+ return [convert_sets_to_lists(x) for x in obj]
30
+ return obj
31
+
32
+
33
+ class TpttConfig(PretrainedConfig):
34
+ """
35
+ Configuration class for the TPTT model.
36
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
37
+ """
38
+
39
+ model_type = "tptt"
40
+ auto_map = {
41
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
42
+ "AutoConfig": "configuration_tptt.TpttConfig",
43
+ }
44
+ architectures = ["TpttModel"]
45
+
46
+ RECURRENT_MODES = {
47
+ "delta_rule": {
48
+ "order": 1,
49
+ "gate_type": "k",
50
+ "linear": True,
51
+ "trick": "derivative",
52
+ },
53
+ "delta_rule_v": {
54
+ "order": 1,
55
+ "gate_type": "v",
56
+ "linear": True,
57
+ "trick": "derivative",
58
+ },
59
+ "delta_rule_kv": {
60
+ "order": 1,
61
+ "gate_type": "kv",
62
+ "linear": True,
63
+ "trick": "derivative",
64
+ },
65
+ "delta_rule_gelu": {
66
+ "order": 1,
67
+ "gate_type": "k",
68
+ "linear": False,
69
+ "trick": "derivative",
70
+ },
71
+ "delta_product": {
72
+ "order": 2,
73
+ "gate_type": "k",
74
+ "linear": True,
75
+ "trick": "derivative",
76
+ },
77
+ "delta_product_r": {
78
+ "order": 2,
79
+ "gate_type": "k",
80
+ "linear": True,
81
+ "trick": "rotative",
82
+ },
83
+ "delta_product_c": {
84
+ "order": 2,
85
+ "gate_type": "k",
86
+ "linear": True,
87
+ "trick": "combined",
88
+ },
89
+ } # Tested modes, see parse_mode_name if you want to add more
90
+
91
+ def __init__(
92
+ self,
93
+ base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
94
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
95
+ base_model_subfolder: Optional[str] = None,
96
+ name_or_path: Optional[str] = None,
97
+ model_task: str = "causal_lm",
98
+ target_modules_names: Optional[List[str]] = None,
99
+ operator_mode: str = "delta_rule",
100
+ use_linear_checkpoint: Optional[bool] = None,
101
+ max_self_attn_length: Optional[
102
+ int
103
+ ] = None, # unnecessary with SWA; otherwise 8192 is a standard value
104
+ base_scale_attn: bool = False,
105
+ mag_weight: float = 0.5, # if 1.0, use only linear operator
106
+ cross_gate: bool = False, # non-linear mixing strategy
107
+ max_chunk_size: int = 64, # 128 if adaptive chunking (longest)
108
+ linear_precision: Union[str, torch.dtype] = "float32",
109
+ lora_config: Optional[dict] = None, # only serialized accepted
110
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
111
+ bidirectional: bool = False, # if True, use bidirectional attention
112
+ pooling_config: Optional[Dict[str, Any]] = None,
113
+ **kwargs,
114
+ ):
115
+ # If base_model_config is provided, load it and merge with this config
116
+ if base_model_config is not None:
117
+ if isinstance(base_model_config, PretrainedConfig):
118
+ base_model_config = base_model_config.to_dict()
119
+ else:
120
+ # Load config from Hugging Face Hub or a local path
121
+ base_model_config = AutoConfig.from_pretrained(
122
+ base_model_name, **kwargs
123
+ ).to_dict()
124
+ # Merge all backbone fields into this config
125
+ for k, v in base_model_config.items():
126
+ setattr(self, k, v)
127
+
128
+ self.base_model_name = base_model_name
129
+ self.base_model_subfolder = base_model_subfolder
130
+ self.model_task = model_task
131
+
132
+ if name_or_path is not None:
133
+ self._name_or_path = name_or_path
134
+ else:
135
+ if "/" in base_model_name:
136
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
137
+ else:
138
+ self._name_or_path = "Titans-" + base_model_name
139
+
140
+ self.target_modules_names = target_modules_names or [
141
+ "attn",
142
+ "self_attn",
143
+ "attention",
144
+ ]
145
+ self.operator_mode = operator_mode
146
+
147
+ # Detect available memory on accelerator device
148
+ if torch.cuda.is_available():
149
+ _, total_mem = torch.cuda.mem_get_info()
150
+ else:
151
+ total_mem = psutil.virtual_memory().total
152
+ total_mem_gb = total_mem / BYTES_IN_GB
153
+
154
+ self.use_linear_checkpoint = (
155
+ total_mem_gb < 16
156
+ if use_linear_checkpoint is None
157
+ else use_linear_checkpoint
158
+ )
159
+
160
+ self.base_scale_attn = base_scale_attn
161
+ self.mag_weight = mag_weight
162
+ self.cross_gate = cross_gate
163
+ self.max_chunk_size = max_chunk_size
164
+ self.max_self_attn_length = max_self_attn_length
165
+ if isinstance(linear_precision, torch.dtype):
166
+ linear_precision = str(linear_precision).replace("torch.", "")
167
+ self.linear_precision = linear_precision
168
+
169
+ self.lora_config = lora_config
170
+ if lora_config is not None:
171
+ if hasattr(self.lora_config.get("peft_type"), "value"):
172
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
173
+ self.lora_config = convert_sets_to_lists(self.lora_config)
174
+
175
+ self.padding_side = padding_side
176
+ self.bidirectional = bidirectional
177
+ if self.bidirectional:
178
+ print("Bidirectional is enabled, need to be uncausal and unpadded.")
179
+ self.pooling_config = pooling_config
180
+
181
+ super().__init__(**kwargs) # flush inconsistent pretrained parameters
182
+ # Copy class attributes to instance for serialization (save dict)
183
+ self.model_type = self.__class__.model_type
184
+ self.auto_map = self.__class__.auto_map
185
+ self.architectures = self.__class__.architectures
186
+ # Padding side configuration if not set
187
+ if self.padding_side is None:
188
+ self.padding_side = "right"
189
+ logger.info("Warning: padding_side is None, defaulting to 'right'.")
190
+ # set recurrent configuration from operator mode
191
+ if operator_mode not in self.__class__.RECURRENT_MODES:
192
+ self.recurrent_config = parse_mode_name(operator_mode)
193
+ else:
194
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
195
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
196
+
197
+
198
+ TpttConfig.register_for_auto_class()
199
+
200
+
201
+ def parse_mode_name(name: str) -> dict:
202
+ """Parse mode to recurrent config"""
203
+ if name.startswith("delta_product"):
204
+ parts = name.split("_")
205
+ # Prefix is always two words: 'delta' and 'product'
206
+ base_len = 2
207
+ order = 2
208
+ gate_type = "k"
209
+ linear = True
210
+ trick = "derivative"
211
+
212
+ idx = base_len
213
+ # Check for order (immediately after the prefix)
214
+ if len(parts) > idx and parts[idx].isdigit():
215
+ order = int(parts[idx])
216
+ idx += 1
217
+
218
+ remaining = parts[idx:]
219
+ # Trick (r/c) is always at the far right if present
220
+ if remaining and remaining[-1] in ("r", "c"):
221
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
222
+ remaining = remaining[:-1]
223
+ # 'gelu' comes just before the trick if present
224
+ if remaining and remaining[-1] == "gelu":
225
+ linear = False
226
+ remaining = remaining[:-1]
227
+ # If anything remains, it's the gate_type
228
+ if remaining:
229
+ gate_type = "_".join(remaining)
230
+ return {
231
+ "order": order,
232
+ "gate_type": gate_type,
233
+ "linear": linear,
234
+ "trick": trick,
235
+ }
236
+
237
+ # delta_rule[_gate][_gelu]
238
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
239
+ if m:
240
+ return {
241
+ "order": 1,
242
+ "gate_type": m.group(1) if m.group(1) else "k",
243
+ "linear": not bool(m.group(2)),
244
+ "trick": "derivative",
245
+ }
246
+ raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
+ def get_mode_name(
250
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
251
+ ) -> str:
252
+ """Get recurrent mode name from parameter"""
253
+ base = (
254
+ "delta_rule"
255
+ if order == 1
256
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
257
+ )
258
+ parts = []
259
+ if gate_type != "k":
260
+ parts.append(gate_type)
261
+ if not linear:
262
+ parts.append("gelu")
263
+ if order >= 2 and trick != "derivative":
264
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
265
+ return base + (("_" + "_".join(parts)) if parts else "")
266
+
267
+
268
+ def render_template(template_path: str, variables: dict) -> str:
269
+ """Load and render a Jinja2 template from any file path."""
270
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
271
+ template = env.get_template(os.path.basename(template_path))
272
+ return template.render(**variables)
273
+
274
+
275
+ def write_model_card(output_path: str, content: str):
276
+ """Write the generated content into README.md."""
277
+ os.makedirs(output_path, exist_ok=True)
278
+ readme_path = os.path.join(output_path, "README.md")
279
+ with open(readme_path, "w", encoding="utf-8") as f:
280
+ f.write(content)
281
+
282
+
283
+ def generate_model_card(
284
+ output_path: str,
285
+ config: Union[dict, object],
286
+ template: Optional[
287
+ str
288
+ ], # can be "model_card" OR an absolute/relative path to a .md file
289
+ extra_variables: Optional[Dict] = None,
290
+ ):
291
+ """
292
+ Generate a README.md file from a Jinja2 template and a configuration.
293
+
294
+ - template can be either:
295
+ * a full path to a template file
296
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
297
+ """
298
+ if template is None:
299
+ template = "model_card_template" # default template name
300
+ # Locate the template
301
+ if os.path.exists(template): # direct file path provided
302
+ template_path = template
303
+ else:
304
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
305
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
306
+
307
+ if not os.path.exists(template_path):
308
+ raise FileNotFoundError(f"Template not found: {template_path}")
309
+
310
+ variables = {
311
+ "model_id": os.path.basename(output_path),
312
+ "config": config,
313
+ }
314
+ if extra_variables:
315
+ variables.update(extra_variables)
316
+
317
+ content = render_template(template_path, variables)
318
+ write_model_card(output_path, content)
lora_delta_product_m0.5_constant/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.49.0"
4
+ }
lora_delta_product_m0.5_constant/modeling_tptt.py ADDED
@@ -0,0 +1,1501 @@
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
+ from transformers.configuration_utils import PretrainedConfig
35
+
36
+ from .configuration_tptt import TpttConfig
37
+
38
+ logger = logging.getLogger(__name__) # monitoring
39
+
40
+
41
+ class LCache:
42
+ """Cache for storing intermediate states of linear attention layers."""
43
+
44
+ def __init__(self):
45
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
46
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
47
+ {}
48
+ ) # recurrent states and qkv buffers
49
+
50
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
51
+ """Retrieve cached state for a given layer, or None if not present"""
52
+ return self.inputs_states.get(layer_idx, None)
53
+
54
+ def update(self, layer_idx: int, **kwargs):
55
+ """Detach all tensors to avoid retaining computation graphs"""
56
+ detached_kwargs = {
57
+ k: v.detach() if isinstance(v, torch.Tensor) else v
58
+ for k, v in kwargs.items()
59
+ }
60
+ # Update or create the state for the specified layer
61
+ if layer_idx in self.inputs_states:
62
+ self.inputs_states[layer_idx].update(detached_kwargs)
63
+ else:
64
+ self.inputs_states[layer_idx] = detached_kwargs
65
+
66
+ def reset(self):
67
+ """Clear all cached states and reset the token counter"""
68
+ self.inputs_states.clear()
69
+
70
+
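+ # Usage sketch (comments only; the state key name below is illustrative):
+ #   cache = LCache()
+ #   cache.update(layer_idx=0, recurrent_state=state)  # tensors are detached here
+ #   prev = cache[0]     # dict of cached tensors for layer 0, or None if absent
+ #   cache.reset()       # clear everything between independent sequences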
71
+ class CausalAvgPool1d(nn.Module):
72
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
73
+
74
+ def __init__(
75
+ self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
76
+ ):
77
+ super().__init__()
78
+ self.offsets = offsets
79
+ self.mode = mode
80
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """x: [B, S, F] → [B, S, F → output_size]"""
84
+ x_ = x.transpose(1, 2) # [B, F, S]
85
+ idxs = torch.tensor(self.offsets, device=x.device)
86
+ ksize = idxs.max() - idxs.min() + 1
87
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
88
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
89
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
90
+ pad_left = -idxs.min().item()
91
+ pad_right = (ksize - 1) - pad_left
92
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
93
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
94
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
95
+
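+ # Shape walk-through (comments only): for x of shape [B, S, F], the depthwise
+ # conv averages len(offsets) neighbouring timesteps (replicate padding keeps the
+ # sequence length S unchanged), then AdaptiveAvgPool1d maps the feature axis F
+ # down to `output_size`, giving an output of shape [B, S, output_size].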
96
+
97
+ class LinearAttention(nn.Module):
98
+ """
99
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
100
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ hidden_dim: int,
106
+ num_heads: int,
107
+ head_dim: Optional[int] = None,
108
+ num_key_value_heads: Optional[int] = None,
109
+ num_key_value_groups: Optional[int] = None,
110
+ bias: bool = True,
111
+ dropout: Optional[float] = None,
112
+ linear_precision: torch.dtype = torch.float32,
113
+ padding_side: str = "right",
114
+ shared_attn: bool = False, # shared attention
115
+ layer_idx: int = 0,
116
+ operator_mode: str = "delta_rule",
117
+ use_linear_checkpoint: bool = False,
118
+ recurrent_config: Optional[Dict[str, Any]] = None,
119
+ linear_cache: Optional[LCache] = None,
120
+ max_chunk_size: int = 64,
121
+ bidirectional: bool = False, # not used if causal
122
+ pooling_config: Optional[Dict[str, Any]] = None,
123
+ ):
124
+ super().__init__()
125
+ if pooling_config is None:
126
+ pooling_config = {
127
+ "offsets": (0, 1, 2),
128
+ "mode": "replicate",
129
+ }
130
+ self.hidden_dim = hidden_dim
131
+ self.num_heads = num_heads
132
+ self.head_dim = head_dim or hidden_dim // num_heads
133
+ self.num_key_value_heads = num_key_value_heads or num_heads
134
+ self.num_key_value_groups = num_key_value_groups or (
135
+ num_heads // (num_key_value_heads or num_heads)
136
+ )
137
+ self.scaling = self.head_dim**-0.5
138
+ self.linear_precision = linear_precision
139
+ self.padding_side = padding_side
140
+
141
+ self.shared_attn = shared_attn
142
+
143
+ if not shared_attn:
144
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
145
+ self.k_proj = nn.Linear(
146
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
147
+ )
148
+ self.v_proj = nn.Linear(
149
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
150
+ )
151
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
152
+
153
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
154
+
155
+ self.linear_operator = LinearAttentionOp(
156
+ layer_idx=layer_idx,
157
+ operator_mode=operator_mode,
158
+ use_linear_checkpoint=use_linear_checkpoint,
159
+ recurrent_config=recurrent_config,
160
+ max_chunk_size=max_chunk_size,
161
+ linear_cache=linear_cache,
162
+ linear_precision=linear_precision,
163
+ )
164
+ self.bidirectional = bidirectional
165
+ # Causal average pooling for gating
166
+ self.pooling_config = pooling_config
167
+ self.pool_g = CausalAvgPool1d(
168
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ x: Union[List[torch.Tensor], torch.Tensor],
174
+ attn_mask: Optional[torch.Tensor] = None,
175
+ out_proj: Optional[nn.Module] = None,
176
+ **kwargs: Any,
177
+ ) -> torch.Tensor:
178
+ """
179
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
180
+ """
181
+
182
+ if not self.shared_attn:
183
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
184
+ # Projections
185
+ q = self.q_proj(hidden_states)
186
+ k = self.k_proj(hidden_states)
187
+ v = self.v_proj(hidden_states)
188
+ out_proj = self.out_proj
189
+ else:
190
+ # Shared attention <=> no projections here
191
+ q, k, v = x[0], x[1], x[2]
192
+ out_proj = self.out_proj if out_proj is None else out_proj
193
+
194
+ # get dtype and device
195
+ final_dtype, final_device = q.dtype, q.device
196
+ # Masking if needed
197
+ if attn_mask is not None:
198
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
199
+
200
+ # Forget and write gating for the linear attention branch (loose terminology)
201
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
202
+
203
+ # Reshape for multi-head
204
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
205
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
206
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
207
+
208
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
209
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
210
+
211
+ # Repeat for GQA
212
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
213
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
214
+
215
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
216
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
217
+
218
+ ## DeltaNet-style: Silu activation and normalization
219
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
220
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
221
+
222
+ ## linear stability part
223
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
224
+
225
+ # Apply sigmoid to forget and write gates
226
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
227
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
228
+
229
+ # Convert to linear_precision (float32) for numerical stability and get model dtype
230
+ q, k, v, f_g, w_g = (
231
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
232
+ )
233
+ g = (f_g, w_g)
234
+
235
+ # Linear Attention Core, output: [B, H, S, d]
236
+ if self.bidirectional: # Work only with uncausal attention
237
+ # Forward direction
238
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
239
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
240
+ kwargs_bwd = kwargs.copy()
241
+ kwargs_bwd["use_cache"] = False
242
+ out_backward = self.linear_operator(
243
+ torch.flip(q, dims=[2]),
244
+ torch.flip(k, dims=[2]),
245
+ torch.flip(v, dims=[2]),
246
+ tuple(torch.flip(t, dims=[2]) for t in g),
247
+ **kwargs_bwd,
248
+ )
249
+ # Flip the output back to restore proper order
250
+ out_backward = torch.flip(out_backward, dims=[2])
251
+ # Fusion: here, simple addition
252
+ out = out_forward + out_backward
253
+ else:
254
+ out = self.linear_operator(q, k, v, g, **kwargs)
255
+
256
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
257
+ out = rearrange(out, "b h s d -> b s (h d)")
258
+ # Normalize output (RMS norm). Note: bidirectional compatibility
259
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
260
+ # Ensure dtype and device consistency
261
+ out = out.to(dtype=final_dtype, device=final_device)
262
+ # Apply output projection
263
+ out = out_proj(out) # [B, S, D]
264
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
265
+ # Apply dropout if specified
266
+ if self.dropout is not None:
267
+ out = self.dropout(out)
268
+ return out
269
+
270
+
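+ # Standalone usage sketch (comments only; dims follow the Mistral-7B config used
+ # in this repo, and the LinearAttentionOp internals are defined further below):
+ #   attn = LinearAttention(hidden_dim=4096, num_heads=32, num_key_value_heads=8)
+ #   y = attn(x)   # x: [B, S, 4096] -> y: [B, S, 4096]
+ # With shared_attn=True, no q/k/v/out projections are created here; the caller
+ # passes x = [q, k, v] plus an out_proj module (see LiZAttention below).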
271
+ class LiZAttention(nn.Module):
272
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
273
+
274
+ def __init__(
275
+ self,
276
+ base_attn: nn.Module,
277
+ layer_idx: int,
278
+ base_config: PretrainedConfig, # Backbone Config
279
+ linear_cache: Optional[LCache] = None,
280
+ operator_mode: str = "delta_rule",
281
+ use_linear_checkpoint: bool = False,
282
+ recurrent_config: Optional[Dict[str, Any]] = None,
283
+ max_self_attn_length: Optional[int] = None, # unnecessary with sliding-window attention
284
+ base_scale_attn: bool = False,
285
+ mag_weight: float = 0.5,
286
+ cross_gate: bool = False,
287
+ max_chunk_size: int = 64,
288
+ linear_precision: Union[str, torch.dtype] = "float32",
289
+ padding_side: str = "right", # for tokenizer
290
+ disable_linear_attn: bool = False,
291
+ bidirectional: bool = False, # if True, use bidirectional attention
292
+ pooling_config: Optional[Dict[str, Any]] = None,
293
+ ):
294
+ super().__init__()
295
+ if isinstance(linear_precision, str):
296
+ linear_precision = getattr(torch, linear_precision)
297
+ self.linear_precision = linear_precision
298
+ self.base_attn: nn.Module = base_attn
299
+ self.base_config = base_config
300
+ self.layer_idx = layer_idx
301
+ self.max_self_attn_length = max_self_attn_length
302
+ self.base_scale_attn = base_scale_attn
303
+ self.mag_weight = mag_weight
304
+ self.cross_gate = cross_gate
305
+ self.max_chunk_size = max_chunk_size
306
+ self.linear_precision = linear_precision
307
+ self.padding_side = padding_side
308
+ self.disable_linear_attn = disable_linear_attn
309
+
310
+ (
311
+ self.num_heads,
312
+ self.head_dim,
313
+ self.num_key_value_heads,
314
+ self.num_key_value_groups,
315
+ ) = self._get_attention_parameters(base_attn, base_config)
316
+ self.scaling = self.head_dim**-0.5
317
+
318
+ self.linear_attn = LinearAttention(
319
+ layer_idx=layer_idx,
320
+ shared_attn=True,
321
+ operator_mode=operator_mode,
322
+ use_linear_checkpoint=use_linear_checkpoint,
323
+ recurrent_config=recurrent_config,
324
+ hidden_dim=base_config.hidden_size,
325
+ num_heads=self.num_heads,
326
+ head_dim=self.head_dim,
327
+ num_key_value_heads=self.num_key_value_heads,
328
+ num_key_value_groups=self.num_key_value_groups,
329
+ linear_precision=linear_precision,
330
+ linear_cache=linear_cache,
331
+ max_chunk_size=max_chunk_size,
332
+ padding_side=padding_side,
333
+ bidirectional=bidirectional,
334
+ pooling_config=pooling_config,
335
+ )
336
+
337
+ def _get_attention_parameters(
338
+ self, base_attn: nn.Module, base_config: PretrainedConfig
339
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]:
340
+ """Retrieve the attention parameters from the base attention module."""
341
+ # first order base attention module and second order config
342
+ num_heads = (
343
+ getattr(base_attn, "num_heads", None)
344
+ or getattr(base_attn, "num_q_heads", None)
345
+ or getattr(base_config, "num_heads", None)
346
+ or getattr(base_config, "num_attention_heads", None)
347
+ )
348
+ head_dim = (
349
+ getattr(base_attn, "head_dim", None)
350
+ or getattr(base_attn, "attention_head_size", None)
351
+ or getattr(base_config, "head_dim", None)
352
+ or (
353
+ getattr(base_config, "hidden_size", None) // num_heads
354
+ if num_heads and getattr(base_config, "hidden_size", None)
355
+ else None
356
+ )
357
+ )
358
+ num_key_value_heads = (
359
+ getattr(base_attn, "num_kv_heads", None)
360
+ or getattr(base_attn, "num_k_heads", None)
361
+ or getattr(base_config, "num_key_value_heads", None)
362
+ or num_heads # fallback
363
+ )
364
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
365
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
366
+ )
367
+ return (
368
+ num_heads,
369
+ head_dim,
370
+ num_key_value_heads,
371
+ num_key_value_groups,
372
+ )
373
+
374
+ def _apply_shared_projections(
375
+ self, hidden_states: torch.Tensor
376
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
377
+ base_attn = self.base_attn
378
+ if hasattr(base_attn, "q_proj"):
379
+ # LLama, OLMO and Mistral style
380
+ q = base_attn.q_proj(hidden_states)
381
+ k = base_attn.k_proj(hidden_states)
382
+ v = base_attn.v_proj(hidden_states)
383
+ out_proj = base_attn.o_proj
384
+ elif hasattr(base_attn, "qkv_proj"):
385
+ # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
386
+ qkv = base_attn.qkv_proj(hidden_states)
387
+ q, k, v = split_qkv(base_attn, qkv)
388
+ out_proj = base_attn.out_proj
389
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
390
+ # GPT-2 style
391
+ qkv = base_attn.c_attn(hidden_states)
392
+ q, k, v = qkv.chunk(3, dim=-1)
393
+ out_proj = base_attn.c_proj
394
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
395
+ # BERT - ViT
396
+ q = base_attn.query(hidden_states)
397
+ k = base_attn.key(hidden_states)
398
+ v = base_attn.value(hidden_states)
399
+ out_proj = getattr(base_attn, "dense", None)  # or output.dense
400
+ else:
401
+ raise ValueError("Unsupported attention module: cannot find projections.")
402
+ # Ensure stability
403
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
404
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
405
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
406
+ return q, k, v, out_proj
407
+
408
+ def _process_self_attn(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ attention_mask: Optional[torch.Tensor],
412
+ kwargs,
413
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
414
+ """Process the self-attention part (with truncation)."""
415
+ if self.max_self_attn_length: # Not needed for SWA (nonparam memorize context)
416
+ hidden_states, attention_mask = truncate_attention_mask(
417
+ hidden_states, attention_mask, self.max_self_attn_length
418
+ )
419
+
420
+ if kwargs.get("position_embeddings", None) is not None:
421
+ cos, sin = kwargs["position_embeddings"]
422
+ cos = cos[:, -self.max_self_attn_length :]
423
+ sin = sin[:, -self.max_self_attn_length :]
424
+ kwargs["position_embeddings"] = (cos, sin)
425
+
426
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
427
+ # cache management
428
+ if (
429
+ len(kwargs["past_key_value"]) > self.layer_idx
430
+ and self.layer_idx == 0
431
+ ):
432
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
433
+
434
+ # Ensure attention mask is of the correct dtype and device
435
+ if attention_mask is not None:
436
+ attention_mask = attention_mask.to(
437
+ dtype=hidden_states.dtype, device=hidden_states.device
438
+ )
439
+ # Standard attention (mask and rotation is applied inside)
440
+ base_attn_outputs = self.base_attn(
441
+ hidden_states,
442
+ attention_mask=attention_mask,
443
+ **kwargs,
444
+ )
445
+
446
+ if isinstance(base_attn_outputs, tuple):
447
+ if len(base_attn_outputs) == 3:
448
+ o_base, attn_weights, present_key_value = base_attn_outputs
449
+ expected_attn_mode = 3
450
+ elif len(base_attn_outputs) == 2:
451
+ o_base, attn_weights = base_attn_outputs
452
+ present_key_value, expected_attn_mode = None, 2
453
+ else:
454
+ raise ValueError(
455
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
456
+ )
457
+ else:
458
+ o_base = base_attn_outputs
459
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
460
+ # Ensure stability
461
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
462
+ return o_base, attn_weights, present_key_value, expected_attn_mode
463
+
464
+ def _prepare_attn_mixin(
465
+ self,
466
+ o_lin: torch.Tensor,
467
+ o_base: torch.Tensor,
468
+ tensor_dtype: torch.dtype,
469
+ eps: float = 1e-5,
470
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
471
+ """Prepare linear attn for mixing with self attn."""
472
+ # Force cast typing, shape : [b n (h d)]
473
+ o_lin = o_lin.to(tensor_dtype)
474
+ o_base = o_base.to(tensor_dtype)
475
+ # feature scaling
476
+ if self.base_scale_attn:
477
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
478
+ o_lin = scaler * o_lin
479
+ return o_lin, o_base
480
+
481
+ def _apply_mag(
482
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
483
+ ) -> torch.Tensor:
484
+ """Apply the MAG strategy"""
485
+ # Left-Padding management
486
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
487
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
488
+ linear_attention, softmax_attention = (
489
+ linear_attention[:, -left_trunc:],
490
+ softmax_attention[:, -left_trunc:],
491
+ )
492
+ # NAM : Neural Attention Mixer (with graph forcing)
493
+ mag_weight = torch.tensor(
494
+ self.mag_weight,
495
+ dtype=softmax_attention.dtype,
496
+ device=softmax_attention.device,
497
+ )
498
+ softmax_weighted = (1 - mag_weight) * softmax_attention
499
+ linear_weighted = mag_weight * linear_attention
500
+ if self.cross_gate:
501
+ output_attention = (
502
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
503
+ ) # complex cross product (non-linear interaction)
504
+ else:
505
+ output_attention = softmax_weighted + linear_weighted # classic
506
+
507
+ if torch.allclose(softmax_weighted, output_attention):
508
+ logger.info(
509
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
510
+ self.layer_idx,
511
+ )
512
+ # Final output
513
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
514
+
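+ # MAG mixing written out (comments only): with w = mag_weight,
+ #   cross_gate=False: out = (1 - w) * softmax_attn + w * linear_attn
+ #   cross_gate=True:  out = (1 - w) * softmax_attn + w * linear_attn
+ #                           + (1 - w) * w * softmax_attn * linear_attn
+ # e.g. the default w = 0.5 blends both attention outputs equally.
+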
515
+ def forward(
516
+ self,
517
+ hidden_states: torch.Tensor,
518
+ attention_mask: Optional[torch.Tensor] = None,
519
+ **kwargs,
520
+ ) -> torch.Tensor:
521
+ """Mix linear and self attention forward"""
522
+ device = hidden_states.device
523
+ tensor_dtype = hidden_states.dtype
524
+ self.base_attn.to(device)
525
+
526
+ if self.training:
527
+ kwargs.pop("past_key_value", None)
528
+ kwargs["use_cache"] = False
529
+ elif "use_cache" not in kwargs:
530
+ kwargs.pop("past_key_value", None)
531
+ kwargs["use_cache"] = False
532
+
533
+ kwargs.pop("position_ids", None) # obsolete
534
+
535
+ # Apply shared projections
536
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
537
+
538
+ # Apply linear attention to hidden states
539
+ o_lin = self.linear_attn(
540
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
541
+ )
542
+
543
+ # Process self attn with truncation
544
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
545
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
546
+ )
547
+
548
+ # Prepare output mixing
549
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
550
+
551
+ # Apply Memory as Gate in self-attention (with length management and ablation)
552
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
553
+
554
+ # Return output following transformer convention
555
+ if expected_attn_mode == 3:
556
+ return out, attn_weights, present_key_value
557
+ if expected_attn_mode == 2:
558
+ return out, attn_weights
559
+ return out
560
+
561
+
562
+ def load_tptt_safetensors(
563
+ repo_or_path: str,
564
+ model: Union[PreTrainedModel, PeftModel],
565
+ subfolder: Optional[str] = None,
566
+ token: Optional[str] = None,
567
+ ) -> Union[PreTrainedModel, PeftModel]:
568
+ """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed."""
569
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
570
+ fname = "adapter_model.safetensors"
571
+ # subfolder management
572
+ if subfolder:
573
+ repo_or_path_norm = os.path.normpath(repo_or_path)
574
+ subfolder_norm = os.path.normpath(subfolder)
575
+ if not repo_or_path_norm.endswith(subfolder_norm):
576
+ fname = f"{subfolder}/{fname}" if subfolder else fname
577
+ # Find file path
578
+ if os.path.isdir(repo_or_path):
579
+ path = os.path.join(repo_or_path, fname)
580
+ if not os.path.exists(path):
581
+ return model
582
+ else:
583
+ if fname not in list_repo_files(repo_or_path, token=token):
584
+ return model
585
+ path = hf_hub_download(repo_or_path, fname, token=token)
586
+
587
+ # Load weights from safetensors
588
+ with safe_open(path, framework="pt") as f:
589
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
590
+
591
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
592
+ def adapt_keys(sd, model):
593
+ model_keys = list(model.state_dict().keys())
594
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
595
+ prefix = "tptt_model.base_model."
596
+ elif any(k.startswith("base_model.") for k in model_keys):
597
+ prefix = "base_model."
598
+ else:
599
+ prefix = ""
600
+
601
+ has_base_attn = any(".base_attn." in k for k in model_keys)
602
+
603
+ def adapt_key(k):
604
+ k_ = k if k.startswith(prefix) else prefix + k
605
+ # first, verify and modify base_attn (LiZA)
606
+ if ".base_attn." in k_ and not has_base_attn:
607
+ k_ = k_.replace(".base_attn.", ".")
608
+ # change LoRA if needed
609
+ if (
610
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
611
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
612
+ k_ = k_.replace(".weight", ".default.weight")
613
+ return k_
614
+
615
+ return {adapt_key(k): v for k, v in sd.items()}
616
+
617
+ state_dict = adapt_keys(state_dict, model)
618
+
619
+ # Cast tensors to the expected dtype of the model parameters
620
+ model_state_dict = model.state_dict()
621
+ for k, v in state_dict.items():
622
+ if k in model_state_dict:
623
+ expected_dtype = model_state_dict[k].dtype
624
+ if v.dtype != expected_dtype:
625
+ state_dict[k] = v.to(expected_dtype)
626
+
627
+ logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])
628
+
629
+ # Load into model
630
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
631
+ missing_lora = [k for k in missing if "lora" in k]
632
+ if missing_lora:
633
+ logger.warning("Missing keys: %s", missing_lora)
634
+ if unexpected:
635
+ logger.warning("Unexpected keys: %s", unexpected)
636
+ return model
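+ # Usage sketch (illustrative; the repo id and subfolder below are placeholders):
+ #   peft_model = load_tptt_safetensors(
+ #       "some-user/some-tptt-repo", peft_model, subfolder="adapter_subfolder"
+ #   )
+ # If "adapter_model.safetensors" is absent the model is returned unchanged;
+ # key prefixes ("base_model.", ".base_attn.", LoRA ".default") are adapted on the fly.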
637
+
638
+
639
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
640
+ model: nn.Module,
641
+ base_config: PretrainedConfig, # or LlamaConfig, MistralConfig, etc.
642
+ linear_cache: Optional[LCache] = None,
643
+ liza_attention: nn.Module = LiZAttention,
644
+ target_modules_names: Optional[list[str]] = None,
645
+ operator_mode: str = "delta_rule",
646
+ use_linear_checkpoint: bool = False,
647
+ recurrent_config: Optional[Dict[str, Any]] = None,
648
+ base_scale_attn: bool = False,
649
+ mag_weight: float = 0.5,
650
+ cross_gate: bool = False,
651
+ max_chunk_size: int = 64,
652
+ linear_precision: torch.dtype = torch.float32,
653
+ max_self_attn_length: Optional[int] = None, # unnecessary
654
+ padding_side: str = "right", # for tokenizer
655
+ bidirectional: bool = False, # if True, use bidirectional attention
656
+ pooling_config: Optional[Dict[str, Any]] = None,
657
+ **kwargs, # quickfix unexpected arguments
658
+ ) -> Tuple[PreTrainedModel, LCache]:
659
+ """Replace target modules in a model with LiZAttention."""
660
+ if target_modules_names is None:
661
+ target_modules_names = ["attn", "self_attn", "attention"]
662
+ # Find target modules by suffix (e.g., "attn", "attention")
663
+ target_modules_names = [
664
+ name
665
+ for name, _ in model.named_modules()
666
+ if any(name.endswith(suffix) for suffix in target_modules_names)
667
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
668
+ ]
669
+ if not target_modules_names:
670
+ raise ValueError(
671
+ f"Target modules '{target_modules_names}' not found in the model."
672
+ )
673
+ # Prepare recurrent config
674
+ linear_cache = linear_cache or LCache()
675
+ # Inject LiZAttention into the model
676
+ for name, _ in model.named_modules():
677
+ if name in target_modules_names:
678
+ parent = model
679
+ *path, last = name.split(".")
680
+ for p in path:
681
+ parent = getattr(parent, p)
682
+ layer_idx = extract_layer_idx(name)
683
+ setattr(
684
+ parent,
685
+ last,
686
+ liza_attention(
687
+ getattr(parent, last),
688
+ layer_idx=layer_idx,
689
+ base_config=base_config,
690
+ linear_cache=linear_cache,
691
+ operator_mode=operator_mode,
692
+ use_linear_checkpoint=use_linear_checkpoint,
693
+ recurrent_config=recurrent_config,
694
+ max_self_attn_length=max_self_attn_length,
695
+ base_scale_attn=base_scale_attn,
696
+ mag_weight=mag_weight,
697
+ cross_gate=cross_gate,
698
+ max_chunk_size=max_chunk_size,
699
+ linear_precision=linear_precision,
700
+ padding_side=padding_side,
701
+ bidirectional=bidirectional,
702
+ pooling_config=pooling_config,
703
+ ),
704
+ )
705
+ return model, linear_cache
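+ # Minimal injection sketch (illustrative; the base model name is a placeholder):
+ #   backbone = AutoModelForCausalLM.from_pretrained("some/base-model")
+ #   backbone, cache = get_tptt_model(
+ #       backbone, backbone.config, operator_mode="delta_rule", mag_weight=0.5
+ #   )
+ # Modules whose names end with "attn", "self_attn" or "attention" are then
+ # wrapped by LiZAttention, and the returned LCache holds their recurrent states.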
706
+
707
+
708
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
709
+ """Save trainable LoRA/Specific weights and adapting key names"""
710
+ # 1. Get the full state_dict
711
+ all_sd = model.state_dict()
712
+
713
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
714
+ trainable_keys = [
715
+ name for name, param in model.named_parameters() if param.requires_grad
716
+ ] # Also, you can manually select specific keys in model after load
717
+
718
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
719
+ to_save = {
720
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
721
+ for k in trainable_keys
722
+ }
723
+
724
+ # 4. Save the filtered adapters to a safetensors file
725
+ if to_save:
726
+ os.makedirs(os.path.dirname(path), exist_ok=True)
727
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
728
+ save_file(to_save, os.path.join(path, name))
729
+
730
+
731
+ class TpttModel(PreTrainedModel):
732
+ """
733
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
734
+ Handles only architecture and weights.
735
+ """
736
+
737
+ config_class = TpttConfig
738
+
739
+ def __init__(
740
+ self,
741
+ config: TpttConfig,
742
+ **kwargs,
743
+ ):
744
+ """
745
+ Initialize TpttModel with a given config and backbone.
746
+ Injects LiZA attention modules into the backbone.
747
+ """
748
+ super().__init__(config, **kwargs)
749
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
750
+
751
+ # 1. Load backbone (with subfolder management) :
752
+ kwargs_bb = kwargs.copy()
753
+ if config.base_model_subfolder is not None:
754
+ kwargs_bb["subfolder"] = config.base_model_subfolder
755
+ else:
756
+ kwargs_bb.pop("subfolder", None)
757
+
758
+ if config.model_task == "causal_lm":
759
+ tptt_model = AutoModelForCausalLM.from_pretrained(
760
+ config.base_model_name, **kwargs_bb
761
+ )
762
+ else:
763
+ tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)
764
+
765
+ # 2. Inject LiZA attention
766
+ self.linear_cache = LCache()
767
+ tptt_model, self.linear_cache = get_tptt_model(
768
+ tptt_model, config, self.linear_cache, **config.to_dict()
769
+ )
770
+
771
+ # 3. Apply LoRA/Specific if present and configured
772
+ if config.lora_config is not None:
773
+ lora_config_obj = LoraConfig(**config.lora_config)
774
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
775
+ else:
776
+ # Does not work if quantization is applied!
777
+ tptt_model = set_trainable_parameters(tptt_model)
778
+
779
+ # 4. Load safetensor if tptt/peft adaptor in repo
780
+ if repo_or_path:
781
+ tptt_model = load_tptt_safetensors(
782
+ repo_or_path,
783
+ tptt_model,
784
+ subfolder=kwargs.get("subfolder", None),
785
+ token=kwargs.get("token", None),
786
+ )
787
+ self.tptt_model = tptt_model
788
+
789
+ def forward(
790
+ self,
791
+ input_ids: Optional[torch.LongTensor] = None,
792
+ attention_mask: Optional[torch.Tensor] = None,
793
+ labels: Optional[torch.LongTensor] = None,
794
+ **kwargs,
795
+ ):
796
+ """Forward pass. All arguments are passed to the underlying base model."""
797
+ if self.training:
798
+ kwargs["use_cache"] = False
799
+ kwargs.pop("num_items_in_batch", None)
800
+ elif "use_cache" not in kwargs: # evaluation
801
+ kwargs.pop("num_items_in_batch", None)
802
+ kwargs["use_cache"] = False
803
+ return self.tptt_model(
804
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
805
+ )
806
+
807
+ def generate(self, *args, **kwargs):
808
+ """Delegate the generate call to the backbone model, which supports generation"""
809
+ return self.tptt_model.generate(*args, **kwargs)
810
+
811
+ def save_pretrained(self, path: str, **kwargs):
812
+ """Save model weights, config, and source code to the given path."""
813
+ # 0. Save complete tptt config (with or without LoRA)
814
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
815
+ self._adjust_save_strategy(path, **kwargs)
816
+ # 1. Save the actual trainable weights and adapt keys
817
+ save_tptt_safetensors(self, path)
818
+ # 2. Copy Python files for trust_remote_code
819
+ self._copy_source_files(path, **kwargs)
820
+
821
+ def _adjust_save_strategy(self, path: str, **kwargs):
822
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
823
+ if isinstance(self.tptt_model, PeftModel):
824
+ self.tptt_model.save_pretrained(path, **kwargs)
825
+ safetensor_path = os.path.join(path, "model.safetensors")
826
+ if os.path.exists(safetensor_path):
827
+ os.remove(safetensor_path)
828
+ adapter_path = os.path.join(path, "adapter_config.json")
829
+ if os.path.exists(adapter_path):
830
+ os.remove(adapter_path)
831
+
832
+ def _copy_source_files(self, target_path: str, **kwargs):
833
+ """Copy all .py files from package directory for trust_remote_code."""
834
+ src_dir = os.path.dirname(os.path.abspath(__file__))
835
+ dst_dir = (
836
+ f"./{str(Path(target_path).parts[0])}"
837
+ if kwargs.get("subfolder", False)
838
+ else target_path
839
+ )
840
+ for fname in os.listdir(src_dir):
841
+ if fname.endswith(".py"):
842
+ src = os.path.join(src_dir, fname)
843
+ dst = os.path.join(dst_dir, fname)
844
+ shutil.copy2(src, dst)
845
+
846
+ def retie_lm_after_load(self, **kwargs):
847
+ """Re-link lm_head after loading external weights."""
848
+ embed_lm = find_embedding_lm(self.tptt_model)
849
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
850
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
851
+ self.tptt_model.lm_head = nn.Linear(
852
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
853
+ )
854
+ if kwargs.get("tie_word_embeddings", True):
855
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
856
+ logger.info("Weights of lm_head have been shared with embedding.")
857
+ else:
858
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
859
+ logger.info("Weights of lm_head have been cloned from the embedding.")
860
+
861
+ @classmethod
862
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
863
+ """Custom from_pretrained that accepts the standard positional argument"""
864
+ config = kwargs.pop("config", None)
865
+ repo_or_path = (
866
+ pretrained_model_name_or_path
867
+ or kwargs.pop("pretrained_model_name_or_path", None)
868
+ or kwargs.pop("repo_or_path", None)
869
+ or (getattr(config, "_base_path", None) if config else None)
870
+ or (getattr(config, "_name_or_path", None) if config else None)
871
+ )
872
+
873
+ if config is None and repo_or_path is not None:
874
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
875
+ model = cls(config, *model_args, **kwargs)
876
+ model.retie_lm_after_load(**kwargs)
877
+ return model
878
+
879
+
880
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
881
+
882
+
883
+ class LinearAttentionOp(nn.Module):
884
+ """Base class for linear attention operators."""
885
+
886
+ def __init__(
887
+ self,
888
+ layer_idx: int,
889
+ operator_mode: str = "delta_rule",
890
+ use_linear_checkpoint: bool = False,
891
+ recurrent_config: Optional[dict] = None,
892
+ max_chunk_size: int = 64,
893
+ linear_cache: Optional[LCache] = None,
894
+ linear_precision: torch.dtype = torch.float32,
895
+ ):
896
+ super().__init__()
897
+ self.layer_idx = layer_idx
898
+ if recurrent_config is None:
899
+ operator_mode = "delta_rule" # force default operator mode if no config
900
+ recurrent_config = {
901
+ "order": 1,
902
+ "gate_type": "k",
903
+ "linear": True,
904
+ "trick": "derivative",
905
+ }
906
+ self.operator_mode = operator_mode
907
+ self.use_linear_checkpoint = use_linear_checkpoint
908
+
909
+ self.order = recurrent_config["order"]
910
+ self.gate_type = recurrent_config["gate_type"]
911
+ self.linear = recurrent_config["linear"]
912
+ self.trick = recurrent_config["trick"]
913
+
914
+ self.max_chunk_size = max_chunk_size
915
+ self.linear_cache = linear_cache or LCache()
916
+ self.linear_precision = linear_precision
917
+
918
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
919
+ """
920
+ Compute the gating tensor according to the gate_type.
921
+ """
922
+ if self.gate_type == "k":
923
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
924
+ if self.gate_type == "v":
925
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
926
+ if self.gate_type == "kv":
927
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
928
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
929
+
930
+ def get_cache(self, use_cache: bool) -> Tuple[
931
+ Optional[torch.Tensor],
932
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
933
+ ]:
934
+ """
935
+ Retrieve recurrent state and qkv buffers from the cache.
936
+ """
937
+ if not use_cache:
938
+ return None, None
939
+ last_state = self.linear_cache[self.layer_idx]
940
+ if last_state is not None:
941
+ recurrent_state = last_state.get("recurrent_state", None)
942
+ qkv_buffers = last_state.get("qkv", None)
943
+ else:
944
+ recurrent_state = None
945
+ qkv_buffers = None
946
+ return recurrent_state, qkv_buffers
947
+
948
+ def save_cache(
949
+ self,
950
+ use_cache: bool,
951
+ q: torch.Tensor,
952
+ k: torch.Tensor,
953
+ v: torch.Tensor,
954
+ gate: torch.Tensor,
955
+ state: torch.Tensor,
956
+ ) -> None:
957
+ """
958
+ Save the recurrent state and qkv buffers to the cache.
959
+ """
960
+ if not use_cache:
961
+ return
962
+ if self.order > 1:
963
+ qkv_buffers = (
964
+ q[:, :, -(self.order - 1) :, :],
965
+ k[:, :, -(self.order - 1) :, :],
966
+ v[:, :, -(self.order - 1) :, :],
967
+ gate[:, :, -(self.order - 1) :, :],
968
+ )
969
+ else:
970
+ qkv_buffers = None
971
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
972
+
973
+ def forward(
974
+ self,
975
+ q: torch.Tensor,
976
+ k: torch.Tensor,
977
+ v: torch.Tensor,
978
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
979
+ **kwargs,
980
+ ) -> torch.Tensor:
981
+ """
982
+ Forward pass for the attention operator.
983
+ """
984
+ # Ensure linear_precision for numerical stability (float32)
985
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
986
+ if isinstance(beta, (tuple, list)):
987
+ beta = tuple(b.to(self.linear_precision) for b in beta)
988
+ else:
989
+ beta = beta.to(self.linear_precision)
990
+
991
+ gate = self.compute_gate(beta)
992
+
993
+ # Retrieve cache if needed
994
+ use_cache = kwargs.get("use_cache", False)
995
+ use_checkpoint = not (use_cache) and self.use_linear_checkpoint
996
+ recurrent_state, qkvb = self.get_cache(use_cache)
997
+
998
+ if qkvb is not None and qkvb[0].shape == q.shape:
999
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
1000
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
1001
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
1002
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
1003
+ self.linear_precision
1004
+ )
1005
+
1006
+ output, state = self.chunk_delta_product_forward(
1007
+ q,
1008
+ k,
1009
+ v,
1010
+ gate,
1011
+ self.max_chunk_size,
1012
+ n=self.order,
1013
+ trick=self.trick,
1014
+ linear=self.linear,
1015
+ initial_state=recurrent_state,
1016
+ use_checkpoint=use_checkpoint,
1017
+ linear_precision=self.linear_precision,
1018
+ )
1019
+
1020
+ # Save cache if needed
1021
+ self.save_cache(use_cache, q, k, v, gate, state)
1022
+
1023
+ return output
1024
+
1025
+ @staticmethod
1026
+ def chunk_delta_product_forward(
1027
+ query: torch.Tensor,
1028
+ key: torch.Tensor,
1029
+ value: torch.Tensor,
1030
+ beta_gate: torch.Tensor,
1031
+ chunk_size: int,
1032
+ n: int = 1,
1033
+ trick: str = "derivative",
1034
+ linear: bool = True,
1035
+ initial_state: Optional[torch.Tensor] = None,
1036
+ use_checkpoint: bool = True,
1037
+ linear_precision: torch.dtype = torch.float32,
1038
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1039
+ """
1040
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1041
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
1042
+ """
1043
+
1044
+ # --- Main chunk_delta_product_forward logic ---
1045
+
1046
+ batch_size, num_heads, seq_len, head_dim = query.shape
1047
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1048
+ num_chunks = seq_len // chunk_size
1049
+
1050
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1051
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1052
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1053
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
1054
+
1055
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1056
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1057
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1058
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1059
+
1060
+ k_beta = k_chunks * beta_chunks
1061
+ v_beta = v_chunks * beta_chunks
1062
+
1063
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1064
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1065
+
1066
+ # size : N = chunk_size * n
1067
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1068
+
1069
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1070
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1071
+
1072
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1073
+ if initial_state is not None and initial_state.shape == state_shape:
1074
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1075
+ else:
1076
+ state = torch.full(
1077
+ state_shape,
1078
+ fill_value=1e-6, # stability if nonlinear activation
1079
+ device=query.device,
1080
+ dtype=linear_precision,
1081
+ )
1082
+
1083
+ output, final_state = sequential_delta_product_scan(
1084
+ q_chunks.to(dtype=linear_precision),
1085
+ w.to(dtype=linear_precision),
1086
+ u.to(dtype=linear_precision),
1087
+ n,
1088
+ linear,
1089
+ chunk_size,
1090
+ state.to(dtype=linear_precision),
1091
+ linear_precision=linear_precision,
1092
+ use_checkpoint=use_checkpoint,
1093
+ )
1094
+
1095
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1096
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1097
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
1098
+
1099
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
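+ # Shape sketch (illustrative): with seq_len = 128, chunk_size = 64 and n = 2
+ # (delta_product), num_chunks = 2 and each chunk carries 64 * 2 = 128 virtual
+ # tokens. idx_last_order = [1, 3, 5, ...] keeps only the highest-order virtual
+ # token per real position, so the output returns to [B, H, 128, head_dim] and
+ # the recurrent state handed back is the last accumulated [B, H, D, D] matrix.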
1100
+
1101
+
1102
+ def sequential_delta_product_scan(
1103
+ q_chunks: torch.Tensor,
1104
+ w: torch.Tensor,
1105
+ u: torch.Tensor,
1106
+ n_orders: int,
1107
+ linear_activation: bool,
1108
+ current_chunk_size: int,
1109
+ initial_recurrent_state: torch.Tensor,
1110
+ linear_precision: torch.dtype,
1111
+ use_checkpoint: bool,
1112
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1113
+ """
1114
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1115
+ Implements the per-token Householder state updates.
1116
+ """
1117
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1118
+ output_inner = torch.empty_like(q_chunks)
1119
+ # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
1120
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1121
+
1122
+ def process_one_chunk(
1123
+ q_chunk_params: torch.Tensor,
1124
+ w_chunk_params: torch.Tensor,
1125
+ u_chunk_params: torch.Tensor,
1126
+ h_0_base: torch.Tensor,
1127
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1128
+ """
1129
+ Process a single chunk (with per-token state for n_orders > 1).
1130
+ """
1131
+ o_intra_current_chunk = torch.zeros(
1132
+ batch,
1133
+ head,
1134
+ chunk_n_total,
1135
+ dim,
1136
+ device=q_chunk_params.device,
1137
+ dtype=linear_precision,
1138
+ )
1139
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1140
+ current_accumulated_state_per_token = (
1141
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1142
+ ) # [B, H, current_chunk_size, D, D]
1143
+
1144
+ for step in range(n_orders):
1145
+ idx_virtual_tokens = (
1146
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1147
+ * n_orders
1148
+ + step
1149
+ )
1150
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1151
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1152
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1153
+
1154
+ state_input_for_this_step = current_accumulated_state_per_token
1155
+
1156
+ ## BLAS/cuBLAS einsum equivalent: "bhcd,bhcde->bhce"
1157
+ k_trans_h_old = (
1158
+ torch.matmul(
1159
+ w_s.unsqueeze(-2),
1160
+ state_input_for_this_step,
1161
+ )
1162
+ .squeeze(-2)
1163
+ .to(dtype=linear_precision)
1164
+ )
1165
+
1166
+ u_val = u_s - k_trans_h_old
1167
+
1168
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1169
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1170
+ .squeeze(-2)
1171
+ .to(dtype=linear_precision)
1172
+ )
1173
+
1174
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1175
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1176
+ dtype=linear_precision
1177
+ )
1178
+
1179
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1180
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1181
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1182
+ dtype=linear_precision
1183
+ )
1184
+ # Return all needed for next chunk
1185
+ return (
1186
+ o_intra_current_chunk,
1187
+ o_inter_current_chunk,
1188
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1189
+ )
1190
+
1191
+ for chunk_idx_inner in range(num_chunks_inner):
1192
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1193
+ w_chunk_params = w[:, :, chunk_idx_inner]
1194
+ u_chunk_params = u[:, :, chunk_idx_inner]
1195
+
1196
+ # Checkpointed call if training
1197
+ call = (
1198
+ partial(checkpoint, use_reentrant=False)
1199
+ if use_checkpoint
1200
+ else lambda f, *a: f(*a)
1201
+ )
1202
+ o_intra, o_inter, h_0_base = call(
1203
+ process_one_chunk,
1204
+ q_chunk_params,
1205
+ w_chunk_params,
1206
+ u_chunk_params,
1207
+ h_0_base,
1208
+ )
1209
+ if not linear_activation: # nonlinear activation between chunks
1210
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1211
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1212
+
1213
+ return output_inner, h_0_base
1214
+
1215
+
1216
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1217
+ """Unlinear activation between chunk"""
1218
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1219
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1220
+ return (x / scale) * x_gelu
1221
+
1222
+
1223
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1224
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1225
+ batch_size, num_heads, _, head_dim = x.shape
1226
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1227
+
1228
+
1229
+ def expand_virtual_tokens(
1230
+ x: torch.Tensor, n: int, mode: str = "derivative"
1231
+ ) -> torch.Tensor:
1232
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1233
+ batch_size, num_heads, seq_len, head_dim = x.shape
1234
+ device, dtype = x.device, x.dtype
1235
+
1236
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1237
+ """Expand tokens using the derivative trick."""
1238
+ x_pad = torch.cat(
1239
+ [
1240
+ torch.zeros(
1241
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1242
+ ),
1243
+ x,
1244
+ ],
1245
+ dim=2,
1246
+ )
1247
+ coeffs = torch.tensor(
1248
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1249
+ device=device,
1250
+ dtype=dtype,
1251
+ )
1252
+ coeffs /= coeffs.norm(p=1)
1253
+ return (
1254
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1255
+ .flip(-1)
1256
+ .permute(0, 1, 2, 4, 3)
1257
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1258
+ )
1259
+
1260
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1261
+ """Expand tokens using the rotative trick."""
1262
+ d_parity = head_dim // 2
1263
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1264
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1265
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1266
+ if head_dim % 2:
1267
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1268
+ else:
1269
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1270
+ x_pairs = x_pairs.unsqueeze(3).expand(
1271
+ batch_size, num_heads, seq_len, n, d_parity, 2
1272
+ )
1273
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1274
+ x0r = x0 * cos - x1 * sin
1275
+ x1r = x0 * sin + x1 * cos
1276
+ rot = torch.stack([x0r, x1r], -1).reshape(
1277
+ batch_size, num_heads, seq_len, n, d_parity * 2
1278
+ )
1279
+ if head_dim % 2:
1280
+ last = (
1281
+ x[..., -1]
1282
+ .unsqueeze(-1)
1283
+ .unsqueeze(3)
1284
+ .expand(batch_size, num_heads, seq_len, n, 1)
1285
+ )
1286
+ rot = torch.cat([rot, last], -1)
1287
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1288
+
1289
+ if mode == "derivative":
1290
+ return derivative_expand(x)
1291
+ if mode == "rotative":
1292
+ return rotative_expand(x)
1293
+ if mode == "combined":
1294
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1295
+ raise ValueError(f"Unknown mode: {mode}")
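+ # Illustrative note on the derivative trick: each real token is expanded into n
+ # virtual tokens built from itself and the n - 1 previous tokens, weighted by
+ # L1-normalized alternating binomial coefficients (-1)^k * C(n-1, k):
+ #   n = 2 -> [0.5, -0.5]          (a first-order finite difference)
+ #   n = 3 -> [0.25, -0.5, 0.25]   (a second-order finite difference)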
1296
+
1297
+
1298
+ def extract_layer_idx(module_name: str) -> int:
1299
+ """Extract the layer index from a module name string."""
1300
+ match = re.search(r"\.(\d+)\.", module_name)
1301
+ if match:
1302
+ return int(match.group(1))
1303
+ return -1
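+ # Example (illustrative): extract_layer_idx("model.layers.3.self_attn") -> 3;
+ # -1 is returned when no ".<digits>." segment appears in the module name.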
1304
+
1305
+
1306
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1307
+ """Find the embedding weight in a model module."""
1308
+ for _, child in module.named_modules():
1309
+ if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
1310
+ return child.embed_tokens
1311
+ if hasattr(child, "token_embeddings") and hasattr(
1312
+ child.token_embeddings, "weight"
1313
+ ):
1314
+ return child.token_embeddings
1315
+ return None
1316
+
1317
+
1318
+ def set_trainable_parameters(
1319
+ model: PreTrainedModel, trainable_patterns: List[str] = None
1320
+ ) -> PreTrainedModel:
1321
+ """Freeze model parameters except trainable_patterns."""
1322
+ if trainable_patterns is None:
1323
+ trainable_patterns = [
1324
+ "q_proj",
1325
+ "k_proj",
1326
+ "v_proj",
1327
+ "o_proj",
1328
+ "qkv_proj",
1329
+ "out_proj",
1330
+ "c_attn",
1331
+ "c_proj",
1332
+ "query",
1333
+ "key",
1334
+ "value",
1335
+ ]
1336
+
1337
+ for name, param in model.named_parameters():
1338
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
1339
+
1340
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1341
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1342
+ return model
1343
+
1344
+
1345
+ def ensure_stability(
1346
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1347
+ ) -> torch.Tensor:
1348
+ """stability forcing"""
1349
+ dtype = tensor.dtype
1350
+ center = (max_val + min_val) / 2
1351
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1352
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1353
+ return tensor.to(dtype=dtype)
1354
+
1355
+
1356
+ def apply_linear_attention_mask(
1357
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1358
+ ) -> torch.Tensor:
1359
+ """Extract if padding --> [B,S]"""
1360
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1361
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1362
+ else:
1363
+ mask = attention_mask.squeeze(
1364
+ dim=tuple(
1365
+ i
1366
+ for i in range(1, attention_mask.dim())
1367
+ if attention_mask.shape[i] == 1
1368
+ )
1369
+ )
1370
+ # Ensure cast to the same dtype as v and convert to binary mask
1371
+ if not (
1372
+ mask.dtype == torch.bool
1373
+ or (
1374
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1375
+ and mask.max() <= 1
1376
+ and mask.min() >= 0
1377
+ )
1378
+ ):
1379
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1380
+ else:
1381
+ mask = mask.to(v.dtype)
1382
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1383
+ if padding_side == "left":
1384
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1385
+ else: # right padding
1386
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1387
+ return v * mask
1388
+
1389
+
1390
+ def truncate_attention_mask(
1391
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1392
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1393
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
1394
+ seq_dim = 1 # convention: (batch, seq, ...)
1395
+ seq_len = hidden_states.shape[seq_dim]
1396
+ if seq_len > max_length:
1397
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
1398
+ if attention_mask is not None:
1399
+ # mask [batch, seq]
1400
+ if attention_mask.dim() == 2:
1401
+ attention_mask = attention_mask[:, -max_length:]
1402
+ # mask [batch, seq, seq]
1403
+ elif attention_mask.dim() == 3:
1404
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
1405
+ # mask [batch, 1, seq, seq]
1406
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1407
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
1408
+ else:
1409
+ raise ValueError(
1410
+ "No dimension in attention_mask matches sequence length of hidden_states."
1411
+ )
1412
+ return hidden_states, attention_mask
1413
+
1414
+
1415
+ def fast_invert_matrix(
1416
+ tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
1417
+ ) -> torch.Tensor:
1418
+ """Equivalent to vectorized forward substitution applied to the identity matrix."""
1419
+ tri_tensor = tri_tensor.to(dtype=dtype).clone()
1420
+ chunk_size = tri_tensor.shape[-1]
1421
+
1422
+ for i in range(1, chunk_size):
1423
+ tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
1424
+ tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
1425
+ ).sum(-2)
1426
+
1427
+ tri_tensor = tri_tensor + torch.eye(
1428
+ chunk_size, dtype=dtype, device=tri_tensor.device
1429
+ )
1430
+ return tri_tensor.to(dtype=dtype)
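+ # Illustrative note: for a strictly lower-triangular T this accumulates the
+ # Neumann series I + T + T^2 + ..., i.e. (I - T)^(-1); e.g.
+ #   T = [[0, 0], [a, 0]]  ->  [[1, 0], [a, 1]]
+ # In chunk_delta_product_forward, T = -(k_beta @ k^T).tril(-1), so this yields
+ # the inverse of (I + tril(k_beta @ k^T, -1)) required by the delta rule.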
1431
+
1432
+
1433
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
1434
+ """Return the largest chunk_size <= chunk_size that divides total_l."""
1435
+ for c in range(min(chunk_size, total_l), 0, -1):
1436
+ if total_l % c == 0:
1437
+ return c
1438
+ return 1
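+ # Examples (illustrative): get_valid_chunk_size(128, 64) -> 64, while
+ # get_valid_chunk_size(100, 64) -> 50, the largest divisor of 100 not above 64.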
1439
+
1440
+
1441
+ ## RARELY USED
1442
+ def split_qkv(
1443
+ base_attn: nn.Module, qkv: torch.Tensor
1444
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1445
+ """Split the QKV tensor into separate Q, K, and V tensors."""
1446
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
1447
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
1448
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
1449
+ head_dim = getattr(base_attn, "head_dim", None)
1450
+
1451
+ if num_q_heads is None or num_k_heads is None or num_v_heads is None:
1452
+ raise ValueError(
1453
+ "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
1454
+ )
1455
+
1456
+ q_len = num_q_heads * head_dim
1457
+ k_len = num_k_heads * head_dim
1458
+ v_len = num_v_heads * head_dim
1459
+
1460
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
1461
+ return q, k, v
1462
+
1463
+
1464
+ ## OPTIONAL
1465
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
1466
+ """Match the size of tensor x along dimension dim to target_size by interpolation"""
1467
+ src_size = x.shape[dim]
1468
+ if src_size == target_size:
1469
+ return x
1470
+ x = torch.moveaxis(x, dim, -1)
1471
+ shape = x.shape
1472
+ if src_size < target_size:
1473
+ x = x.reshape(-1, 1, src_size)
1474
+ x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
1475
+ x = x.reshape(*shape[:-1], target_size)
1476
+ else:
1477
+ eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
1478
+ x = F.linear(x, eye) # pylint: disable=not-callable
1479
+ x = torch.moveaxis(x, -1, dim)
1480
+ return x
1481
+
1482
+
1483
+ def soft_clamp(
1484
+ x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
1485
+ ) -> torch.Tensor:
1486
+ """Differentiable clamping for stability"""
1487
+ dtype = x.dtype
1488
+ scale = (max_val - min_val) / 2
1489
+ center = (max_val + min_val) / 2
1490
+ return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
1491
+
1492
+
1493
+ def describe(x: torch.Tensor, name="tensor") -> None:
1494
+ """Prints the shape, min, max, mean, and std of a tensor."""
1495
+ stats = (x.min(), x.max(), x.mean(), x.std())
1496
+ print(
1497
+ f"{name} shape: {tuple(x.shape)}, "
1498
+ + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
1499
+ + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
1500
+ + f"dtype: {x.dtype}, device: {x.device}"
1501
+ )
lora_delta_product_m0.5_constant/runs/Aug25_08-56-48_bab62a6ed72c/events.out.tfevents.1756112223.bab62a6ed72c.35.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:029aaede7e00d361d1c1d38d545036397d5964c1a19ff5027e3347cfcd676387
3
+ size 115857
lora_delta_product_m0.5_constant/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
lora_delta_product_m0.5_constant/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
lora_delta_product_m0.5_constant/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37f00374dea48658ee8f5d0f21895b9bc55cb0103939607c8185bfd1c6ca1f89
3
+ size 587404
lora_delta_product_m0.5_constant/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_tptt.py ADDED
@@ -0,0 +1,1501 @@
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
+ from transformers.configuration_utils import PretrainedConfig
35
+
36
+ from .configuration_tptt import TpttConfig
37
+
38
+ logger = logging.getLogger(__name__) # monitoring
39
+
40
+
41
+ class LCache:
42
+ """Cache for storing intermediate states of linear attention layers."""
43
+
44
+ def __init__(self):
45
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
46
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
47
+ {}
48
+ ) # recurrent states and qkv buffers
49
+
50
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
51
+ """Retrieve cached state for a given layer, or None if not present"""
52
+ return self.inputs_states.get(layer_idx, None)
53
+
54
+ def update(self, layer_idx: int, **kwargs):
55
+ """Detach all tensors to avoid retaining computation graphs"""
56
+ detached_kwargs = {
57
+ k: v.detach() if isinstance(v, torch.Tensor) else v
58
+ for k, v in kwargs.items()
59
+ }
60
+ # Update or create the state for the specified layer
61
+ if layer_idx in self.inputs_states:
62
+ self.inputs_states[layer_idx].update(detached_kwargs)
63
+ else:
64
+ self.inputs_states[layer_idx] = detached_kwargs
65
+
66
+ def reset(self):
67
+ """Clear all cached states and reset the token counter"""
68
+ self.inputs_states.clear()
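+ # Usage sketch (illustrative; `state` stands for any tensor):
+ #   cache = LCache()
+ #   cache.update(0, recurrent_state=state, qkv=None)  # tensors are detached
+ #   prev = cache[0]["recurrent_state"]                # reused at the next step
+ #   cache.reset()                                     # clear between sequences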
69
+
70
+
71
+ class CausalAvgPool1d(nn.Module):
72
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
73
+
74
+ def __init__(
75
+ self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
76
+ ):
77
+ super().__init__()
78
+ self.offsets = offsets
79
+ self.mode = mode
80
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """x: [B, S, F] → [B, S, F → output_size]"""
84
+ x_ = x.transpose(1, 2) # [B, F, S]
85
+ idxs = torch.tensor(self.offsets, device=x.device)
86
+ ksize = idxs.max() - idxs.min() + 1
87
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
88
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
89
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
90
+ pad_left = -idxs.min().item()
91
+ pad_right = (ksize - 1) - pad_left
92
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
93
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
94
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
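+ # Illustrative note: `offsets` picks the relative taps averaged at each position;
+ # offsets=(-2, -1, 0) would average x[t-2], x[t-1], x[t] (replicate-padded at the
+ # boundary), while the default (0, 1, 2) averages x[t], x[t+1], x[t+2] before
+ # AdaptiveAvgPool1d reduces the feature dimension to `output_size`.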
95
+
96
+
97
+ class LinearAttention(nn.Module):
98
+ """
99
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
100
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ hidden_dim: int,
106
+ num_heads: int,
107
+ head_dim: Optional[int] = None,
108
+ num_key_value_heads: Optional[int] = None,
109
+ num_key_value_groups: Optional[int] = None,
110
+ bias: bool = True,
111
+ dropout: Optional[float] = None,
112
+ linear_precision: torch.dtype = torch.float32,
113
+ padding_side: str = "right",
114
+ shared_attn: bool = False, # shared attention
115
+ layer_idx: int = 0,
116
+ operator_mode: str = "delta_rule",
117
+ use_linear_checkpoint: bool = False,
118
+ recurrent_config: Optional[Dict[str, Any]] = None,
119
+ linear_cache: Optional[LCache] = None,
120
+ max_chunk_size: int = 64,
121
+ bidirectional: bool = False, # not used if causal
122
+ pooling_config: Optional[Dict[str, Any]] = None,
123
+ ):
124
+ super().__init__()
125
+ if pooling_config is None:
126
+ pooling_config = {
127
+ "offsets": (0, 1, 2),
128
+ "mode": "replicate",
129
+ }
130
+ self.hidden_dim = hidden_dim
131
+ self.num_heads = num_heads
132
+ self.head_dim = head_dim or hidden_dim // num_heads
133
+ self.num_key_value_heads = num_key_value_heads or num_heads
134
+ self.num_key_value_groups = num_key_value_groups or (
135
+ num_heads // (num_key_value_heads or num_heads)
136
+ )
137
+ self.scaling = self.head_dim**-0.5
138
+ self.linear_precision = linear_precision
139
+ self.padding_side = padding_side
140
+
141
+ self.shared_attn = shared_attn
142
+
143
+ if not shared_attn:
144
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
145
+ self.k_proj = nn.Linear(
146
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
147
+ )
148
+ self.v_proj = nn.Linear(
149
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
150
+ )
151
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
152
+
153
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
154
+
155
+ self.linear_operator = LinearAttentionOp(
156
+ layer_idx=layer_idx,
157
+ operator_mode=operator_mode,
158
+ use_linear_checkpoint=use_linear_checkpoint,
159
+ recurrent_config=recurrent_config,
160
+ max_chunk_size=max_chunk_size,
161
+ linear_cache=linear_cache,
162
+ linear_precision=linear_precision,
163
+ )
164
+ self.bidirectional = bidirectional
165
+ # Causal average pooling for gating
166
+ self.pooling_config = pooling_config
167
+ self.pool_g = CausalAvgPool1d(
168
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ x: Union[List[torch.Tensor], torch.Tensor],
174
+ attn_mask: Optional[torch.Tensor] = None,
175
+ out_proj: Optional[nn.Module] = None,
176
+ **kwargs: Any,
177
+ ) -> torch.Tensor:
178
+ """
179
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
180
+ """
181
+
182
+ if not self.shared_attn:
183
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
184
+ # Projections
185
+ q = self.q_proj(hidden_states)
186
+ k = self.k_proj(hidden_states)
187
+ v = self.v_proj(hidden_states)
188
+ out_proj = self.out_proj
189
+ else:
190
+ # Shared attention <=> no projections here
191
+ q, k, v = x[0], x[1], x[2]
192
+ out_proj = self.out_proj if out_proj is None else out_proj
193
+
194
+ # get dtype and device
195
+ final_dtype, final_device = q.dtype, q.device
196
+ # Masking if needed
197
+ if attn_mask is not None:
198
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
199
+
200
+ # Forget and Write Gating for linear attn (loose terminology)
201
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
202
+
203
+ # Reshape for multi-head
204
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
205
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
206
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
207
+
208
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
209
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
210
+
211
+ # Repeat for GQA
212
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
213
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
214
+
215
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
216
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
217
+
218
+ ## DeltaNet-style: Silu activation and normalization
219
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
220
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
221
+
222
+ ## linear stability part
223
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
224
+
225
+ # Apply sigmoid to forget and write gates
226
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
227
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
228
+
229
+ # Convert to linear_precision (float32) for numerical stability and get model dtype
230
+ q, k, v, f_g, w_g = (
231
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
232
+ )
233
+ g = (f_g, w_g)
234
+
235
+ # Linear Attention Core, output: [B, H, S, d]
236
+ if self.bidirectional: # Works only with non-causal attention
237
+ # Forward direction
238
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
239
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
240
+ kwargs_bwd = kwargs.copy()
241
+ kwargs_bwd["use_cache"] = False
242
+ out_backward = self.linear_operator(
243
+ torch.flip(q, dims=[2]),
244
+ torch.flip(k, dims=[2]),
245
+ torch.flip(v, dims=[2]),
246
+ tuple(torch.flip(t, dims=[2]) for t in g),
247
+ **kwargs_bwd,
248
+ )
249
+ # Flip the output back to restore proper order
250
+ out_backward = torch.flip(out_backward, dims=[2])
251
+ # Fusion: here, simple addition
252
+ out = out_forward + out_backward
253
+ else:
254
+ out = self.linear_operator(q, k, v, g, **kwargs)
255
+
256
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
257
+ out = rearrange(out, "b h s d -> b s (h d)")
258
+ # Normalize output (RMS norm). Note: bidirectional compatibility
259
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
260
+ # Ensure dtype and device consistency
261
+ out = out.to(dtype=final_dtype, device=final_device)
262
+ # Apply output projection
263
+ out = out_proj(out) # [B, S, D]
264
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
265
+ # Apply dropout if specified
266
+ if self.dropout is not None:
267
+ out = self.dropout(out)
268
+ return out
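+ # Standalone sketch (illustrative; not tied to any checkpoint):
+ #   attn = LinearAttention(hidden_dim=64, num_heads=4)   # head_dim -> 16
+ #   x = torch.randn(2, 16, 64)                           # [batch, seq, hidden]
+ #   y = attn(x)                                          # -> [2, 16, 64]
+ # With shared_attn=True the projections are skipped and x must be [q, k, v].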
269
+
270
+
271
+ class LiZAttention(nn.Module):
272
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
273
+
274
+ def __init__(
275
+ self,
276
+ base_attn: nn.Module,
277
+ layer_idx: int,
278
+ base_config: PretrainedConfig, # Backbone Config
279
+ linear_cache: Optional[LCache] = None,
280
+ operator_mode: str = "delta_rule",
281
+ use_linear_checkpoint: bool = False,
282
+ recurrent_config: Optional[Dict[str, Any]] = None,
283
+ max_self_attn_length: Optional[int] = None, # unnecessary
284
+ base_scale_attn: bool = False,
285
+ mag_weight: float = 0.5,
286
+ cross_gate: bool = False,
287
+ max_chunk_size: int = 64,
288
+ linear_precision: Union[str, torch.dtype] = "float32",
289
+ padding_side: str = "right", # for tokenizer
290
+ disable_linear_attn: bool = False,
291
+ bidirectional: bool = False, # if True, use bidirectional attention
292
+ pooling_config: Optional[Dict[str, Any]] = None,
293
+ ):
294
+ super().__init__()
295
+ if isinstance(linear_precision, str):
296
+ linear_precision = getattr(torch, linear_precision)
297
+ self.linear_precision = linear_precision
298
+ self.base_attn: nn.Module = base_attn
299
+ self.base_config = base_config
300
+ self.layer_idx = layer_idx
301
+ self.max_self_attn_length = max_self_attn_length
302
+ self.base_scale_attn = base_scale_attn
303
+ self.mag_weight = mag_weight
304
+ self.cross_gate = cross_gate
305
+ self.max_chunk_size = max_chunk_size
306
+ self.linear_precision = linear_precision
307
+ self.padding_side = padding_side
308
+ self.disable_linear_attn = disable_linear_attn
309
+
310
+ (
311
+ self.num_heads,
312
+ self.head_dim,
313
+ self.num_key_value_heads,
314
+ self.num_key_value_groups,
315
+ ) = self._get_attention_parameters(base_attn, base_config)
316
+ self.scaling = self.head_dim**-0.5
317
+
318
+ self.linear_attn = LinearAttention(
319
+ layer_idx=layer_idx,
320
+ shared_attn=True,
321
+ operator_mode=operator_mode,
322
+ use_linear_checkpoint=use_linear_checkpoint,
323
+ recurrent_config=recurrent_config,
324
+ hidden_dim=base_config.hidden_size,
325
+ num_heads=self.num_heads,
326
+ head_dim=self.head_dim,
327
+ num_key_value_heads=self.num_key_value_heads,
328
+ num_key_value_groups=self.num_key_value_groups,
329
+ linear_precision=linear_precision,
330
+ linear_cache=linear_cache,
331
+ max_chunk_size=max_chunk_size,
332
+ padding_side=padding_side,
333
+ bidirectional=bidirectional,
334
+ pooling_config=pooling_config,
335
+ )
336
+
337
+ def _get_attention_parameters(
338
+ self, base_attn: nn.Module, base_config: PretrainedConfig
339
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]:
340
+ """Retrieve the attention parameters from the base attention module."""
341
+ # look on the attention module first, then fall back to the backbone config
342
+ num_heads = (
343
+ getattr(base_attn, "num_heads", None)
344
+ or getattr(base_attn, "num_q_heads", None)
345
+ or getattr(base_config, "num_heads", None)
346
+ or getattr(base_config, "num_attention_heads", None)
347
+ )
348
+ head_dim = (
349
+ getattr(base_attn, "head_dim", None)
350
+ or getattr(base_attn, "attention_head_size", None)
351
+ or getattr(base_config, "head_dim", None)
352
+ or (
353
+ getattr(base_config, "hidden_size", None) // num_heads
354
+ if num_heads and getattr(base_config, "hidden_size", None)
355
+ else None
356
+ )
357
+ )
358
+ num_key_value_heads = (
359
+ getattr(base_attn, "num_kv_heads", None)
360
+ or getattr(base_attn, "num_k_heads", None)
361
+ or getattr(base_config, "num_key_value_heads", None)
362
+ or num_heads # fallback
363
+ )
364
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
365
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
366
+ )
367
+ return (
368
+ num_heads,
369
+ head_dim,
370
+ num_key_value_heads,
371
+ num_key_value_groups,
372
+ )
373
+
374
+ def _apply_shared_projections(
375
+ self, hidden_states: torch.Tensor
376
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
377
+ base_attn = self.base_attn
378
+ if hasattr(base_attn, "q_proj"):
379
+ # LLama, OLMO and Mistral style
380
+ q = base_attn.q_proj(hidden_states)
381
+ k = base_attn.k_proj(hidden_states)
382
+ v = base_attn.v_proj(hidden_states)
383
+ out_proj = base_attn.o_proj
384
+ elif hasattr(base_attn, "qkv_proj"):
385
+ # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
386
+ qkv = base_attn.qkv_proj(hidden_states)
387
+ q, k, v = split_qkv(base_attn, qkv)
388
+ out_proj = base_attn.out_proj
389
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
390
+ # GPT-2 style
391
+ qkv = base_attn.c_attn(hidden_states)
392
+ q, k, v = qkv.chunk(3, dim=-1)
393
+ out_proj = base_attn.c_proj
394
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
395
+ # BERT - ViT
396
+ q = base_attn.query(hidden_states)
397
+ k = base_attn.key(hidden_states)
398
+ v = base_attn.value(hidden_states)
399
+ out_proj = getattr(base_attn, "dense", None) # or output.dense
400
+ else:
401
+ raise ValueError("Unsupported attention module: cannot find projections.")
402
+ # Ensure stability
403
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
404
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
405
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
406
+ return q, k, v, out_proj
407
+
408
+ def _process_self_attn(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ attention_mask: Optional[torch.Tensor],
412
+ kwargs,
413
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
414
+ """Process the self-attention part (with truncation)."""
415
+ if self.max_self_attn_length: # Not needed for SWA (non-parametric context memorization)
416
+ hidden_states, attention_mask = truncate_attention_mask(
417
+ hidden_states, attention_mask, self.max_self_attn_length
418
+ )
419
+
420
+ if kwargs.get("position_embeddings", None) is not None:
421
+ cos, sin = kwargs["position_embeddings"]
422
+ cos = cos[:, -self.max_self_attn_length :]
423
+ sin = sin[:, -self.max_self_attn_length :]
424
+ kwargs["position_embeddings"] = (cos, sin)
425
+
426
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
427
+ # cache management
428
+ if (
429
+ len(kwargs["past_key_value"]) > self.layer_idx
430
+ and self.layer_idx == 0
431
+ ):
432
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
433
+
434
+ # Ensure attention mask is of the correct dtype and device
435
+ if attention_mask is not None:
436
+ attention_mask = attention_mask.to(
437
+ dtype=hidden_states.dtype, device=hidden_states.device
438
+ )
439
+ # Standard attention (mask and rotation is applied inside)
440
+ base_attn_outputs = self.base_attn(
441
+ hidden_states,
442
+ attention_mask=attention_mask,
443
+ **kwargs,
444
+ )
445
+
446
+ if isinstance(base_attn_outputs, tuple):
447
+ if len(base_attn_outputs) == 3:
448
+ o_base, attn_weights, present_key_value = base_attn_outputs
449
+ expected_attn_mode = 3
450
+ elif len(base_attn_outputs) == 2:
451
+ o_base, attn_weights = base_attn_outputs
452
+ present_key_value, expected_attn_mode = None, 2
453
+ else:
454
+ raise ValueError(
455
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
456
+ )
457
+ else:
458
+ o_base = base_attn_outputs
459
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
460
+ # Ensure stability
461
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
462
+ return o_base, attn_weights, present_key_value, expected_attn_mode
463
+
464
+ def _prepare_attn_mixin(
465
+ self,
466
+ o_lin: torch.Tensor,
467
+ o_base: torch.Tensor,
468
+ tensor_dtype: torch.dtype,
469
+ eps: float = 1e-5,
470
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
471
+ """Prepare linear attn for mixing with self attn."""
472
+ # Force cast typing, shape : [b n (h d)]
473
+ o_lin = o_lin.to(tensor_dtype)
474
+ o_base = o_base.to(tensor_dtype)
475
+ # feature scaling
476
+ if self.base_scale_attn:
477
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
478
+ o_lin = scaler * o_lin
479
+ return o_lin, o_base
480
+
481
+ def _apply_mag(
482
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
483
+ ) -> torch.Tensor:
484
+ """Apply the MAG strategy"""
485
+ # Left-Padding management
486
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
487
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
488
+ linear_attention, softmax_attention = (
489
+ linear_attention[:, -left_trunc:],
490
+ softmax_attention[:, -left_trunc:],
491
+ )
492
+ # NAM: Neural Attention Mixer (with graph forcing)
493
+ mag_weight = torch.tensor(
494
+ self.mag_weight,
495
+ dtype=softmax_attention.dtype,
496
+ device=softmax_attention.device,
497
+ )
498
+ softmax_weighted = (1 - mag_weight) * softmax_attention
499
+ linear_weighted = mag_weight * linear_attention
500
+ if self.cross_gate:
501
+ output_attention = (
502
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
503
+ ) # elementwise cross term (nonlinear interaction)
504
+ else:
505
+ output_attention = softmax_weighted + linear_weighted # classic
506
+
507
+ if torch.allclose(softmax_weighted, output_attention):
508
+ logger.info(
509
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
510
+ self.layer_idx,
511
+ )
512
+ # Final output
513
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
514
+
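+ # Minimal numeric sketch (illustrative only): with mag_weight=0.5 and
+ # cross_gate=False, out = 0.5 * softmax_attention + 0.5 * linear_attention.
+ # With cross_gate=True, an extra elementwise term
+ # 0.25 * softmax_attention * linear_attention is added.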
515
+ def forward(
516
+ self,
517
+ hidden_states: torch.Tensor,
518
+ attention_mask: Optional[torch.Tensor] = None,
519
+ **kwargs,
520
+ ) -> torch.Tensor:
521
+ """Mix linear and self attention forward"""
522
+ device = hidden_states.device
523
+ tensor_dtype = hidden_states.dtype
524
+ self.base_attn.to(device)
525
+
526
+ if self.training:
527
+ kwargs.pop("past_key_value", None)
528
+ kwargs["use_cache"] = False
529
+ elif "use_cache" not in kwargs:
530
+ kwargs.pop("past_key_value", None)
531
+ kwargs["use_cache"] = False
532
+
533
+ kwargs.pop("position_ids", None) # obsolete
534
+
535
+ # Apply shared projections
536
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
537
+
538
+ # Apply linear attention to hidden states
539
+ o_lin = self.linear_attn(
540
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
541
+ )
542
+
543
+ # Process self attn with truncation
544
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
545
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
546
+ )
547
+
548
+ # Prepare output mixing
549
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
550
+
551
+ # Apply Memory as Gate in self-attention (with length management and ablation)
552
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
553
+
554
+ # Return output following transformer convention
555
+ if expected_attn_mode == 3:
556
+ return out, attn_weights, present_key_value
557
+ if expected_attn_mode == 2:
558
+ return out, attn_weights
559
+ return out
560
+
561
+
562
+ def load_tptt_safetensors(
563
+ repo_or_path: str,
564
+ model: Union[PreTrainedModel, PeftModel],
565
+ subfolder: Optional[str] = None,
566
+ token: Optional[str] = None,
567
+ ) -> Union[PreTrainedModel, PeftModel]:
568
+ """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed."""
569
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
570
+ fname = "adapter_model.safetensors"
571
+ # subfolder management
572
+ if subfolder:
573
+ repo_or_path_norm = os.path.normpath(repo_or_path)
574
+ subfolder_norm = os.path.normpath(subfolder)
575
+ if not repo_or_path_norm.endswith(subfolder_norm):
576
+ fname = f"{subfolder}/{fname}" if subfolder else fname
577
+ # Find file path
578
+ if os.path.isdir(repo_or_path):
579
+ path = os.path.join(repo_or_path, fname)
580
+ if not os.path.exists(path):
581
+ return model
582
+ else:
583
+ if fname not in list_repo_files(repo_or_path, token=token):
584
+ return model
585
+ path = hf_hub_download(repo_or_path, fname, token=token)
586
+
587
+ # Load weights from safetensors
588
+ with safe_open(path, framework="pt") as f:
589
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
590
+
591
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
592
+ def adapt_keys(sd, model):
593
+ model_keys = list(model.state_dict().keys())
594
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
595
+ prefix = "tptt_model.base_model."
596
+ elif any(k.startswith("base_model.") for k in model_keys):
597
+ prefix = "base_model."
598
+ else:
599
+ prefix = ""
600
+
601
+ has_base_attn = any(".base_attn." in k for k in model_keys)
602
+
603
+ def adapt_key(k):
604
+ k_ = k if k.startswith(prefix) else prefix + k
605
+ # first, verify and modify base_attn (LiZA)
606
+ if ".base_attn." in k_ and not has_base_attn:
607
+ k_ = k_.replace(".base_attn.", ".")
608
+ # change LoRA if needed
609
+ if (
610
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
611
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
612
+ k_ = k_.replace(".weight", ".default.weight")
613
+ return k_
614
+
615
+ return {adapt_key(k): v for k, v in sd.items()}
616
+
617
+ state_dict = adapt_keys(state_dict, model)
618
+
619
+ # Cast tensors to the expected dtype of the model parameters
620
+ model_state_dict = model.state_dict()
621
+ for k, v in state_dict.items():
622
+ if k in model_state_dict:
623
+ expected_dtype = model_state_dict[k].dtype
624
+ if v.dtype != expected_dtype:
625
+ state_dict[k] = v.to(expected_dtype)
626
+
627
+ logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])
628
+
629
+ # Load into model
630
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
631
+ missing_lora = [k for k in missing if "lora" in k]
632
+ if missing_lora:
633
+ logger.warning("Missing keys: %s", missing_lora)
634
+ if unexpected:
635
+ logger.warning("Unexpected keys: %s", unexpected)
636
+ return model
637
+
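+ # Usage sketch (illustrative; "./ckpt" is a hypothetical local directory that
+ # contains an "adapter_model.safetensors" file):
+ #   model = load_tptt_safetensors("./ckpt", model)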
638
+
639
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
640
+ model: nn.Module,
641
+ base_config: PretrainedConfig, # or LlamaConfig, MistralConfig, etc.
642
+ linear_cache: Optional[LCache] = None,
643
+ liza_attention: nn.Module = LiZAttention,
644
+ target_modules_names: Optional[list[str]] = None,
645
+ operator_mode: str = "delta_rule",
646
+ use_linear_checkpoint: bool = False,
647
+ recurrent_config: Optional[Dict[str, Any]] = None,
648
+ base_scale_attn: bool = False,
649
+ mag_weight: float = 0.5,
650
+ cross_gate: bool = False,
651
+ max_chunk_size: int = 64,
652
+ linear_precision: torch.dtype = torch.float32,
653
+ max_self_attn_length: Optional[int] = None, # unnecessary
654
+ padding_side: str = "right", # for tokenizer
655
+ bidirectional: bool = False, # if True, use bidirectional attention
656
+ pooling_config: Optional[Dict[str, Any]] = None,
657
+ **kwargs, # quickfix unexpected arguments
658
+ ) -> Tuple[PreTrainedModel, LCache]:
659
+ """Replace target modules in a model with LiZAttention."""
660
+ if target_modules_names is None:
661
+ target_modules_names = ["attn", "self_attn", "attention"]
662
+ # Find target modules by suffix (e.g., "attn", "attention")
663
+ target_modules_names = [
664
+ name
665
+ for name, _ in model.named_modules()
666
+ if any(name.endswith(suffix) for suffix in target_modules_names)
667
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
668
+ ]
669
+ if not target_modules_names:
670
+ raise ValueError(
671
+ f"Target modules '{target_modules_names}' not found in the model."
672
+ )
673
+ # Prepare recurrent config
674
+ linear_cache = linear_cache or LCache()
675
+ # Inject LiZAttention into the model
676
+ for name, _ in model.named_modules():
677
+ if name in target_modules_names:
678
+ parent = model
679
+ *path, last = name.split(".")
680
+ for p in path:
681
+ parent = getattr(parent, p)
682
+ layer_idx = extract_layer_idx(name)
683
+ setattr(
684
+ parent,
685
+ last,
686
+ liza_attention(
687
+ getattr(parent, last),
688
+ layer_idx=layer_idx,
689
+ base_config=base_config,
690
+ linear_cache=linear_cache,
691
+ operator_mode=operator_mode,
692
+ use_linear_checkpoint=use_linear_checkpoint,
693
+ recurrent_config=recurrent_config,
694
+ max_self_attn_length=max_self_attn_length,
695
+ base_scale_attn=base_scale_attn,
696
+ mag_weight=mag_weight,
697
+ cross_gate=cross_gate,
698
+ max_chunk_size=max_chunk_size,
699
+ linear_precision=linear_precision,
700
+ padding_side=padding_side,
701
+ bidirectional=bidirectional,
702
+ pooling_config=pooling_config,
703
+ ),
704
+ )
705
+ return model, linear_cache
706
+
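+ # Usage sketch (illustrative; the backbone checkpoint name is an example):
+ #   backbone = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.3")
+ #   backbone, lcache = get_tptt_model(
+ #       backbone, backbone.config, operator_mode="delta_rule", mag_weight=0.5
+ #   )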
707
+
708
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
709
+ """Save trainable LoRA/Specific weights and adapting key names"""
710
+ # 1. Get the full state_dict
711
+ all_sd = model.state_dict()
712
+
713
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
714
+ trainable_keys = [
715
+ name for name, param in model.named_parameters() if param.requires_grad
716
+ ] # Also, you can manually select specific keys in model after load
717
+
718
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
719
+ to_save = {
720
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
721
+ for k in trainable_keys
722
+ }
723
+
724
+ # 4. Save the filtered adapters to a safetensors file
725
+ if to_save:
726
+ os.makedirs(path, exist_ok=True) # ensure the target directory exists
727
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
728
+ save_file(to_save, os.path.join(path, name))
729
+
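+ # Usage sketch (illustrative path): save_tptt_safetensors(model, "./out") writes
+ # the trainable adapter weights to "./out/adapter_model.safetensors".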
730
+
731
+ class TpttModel(PreTrainedModel):
732
+ """
733
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
734
+ Handles only architecture and weights.
735
+ """
736
+
737
+ config_class = TpttConfig
738
+
739
+ def __init__(
740
+ self,
741
+ config: TpttConfig,
742
+ **kwargs,
743
+ ):
744
+ """
745
+ Initialize TpttModel with a given config and backbone.
746
+ Injects LiZA attention modules into the backbone.
747
+ """
748
+ super().__init__(config, **kwargs)
749
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
750
+
751
+ # 1. Load backbone (with subfolder management) :
752
+ kwargs_bb = kwargs.copy()
753
+ if config.base_model_subfolder is not None:
754
+ kwargs_bb["subfolder"] = config.base_model_subfolder
755
+ else:
756
+ kwargs_bb.pop("subfolder", None)
757
+
758
+ if config.model_task == "causal_lm":
759
+ tptt_model = AutoModelForCausalLM.from_pretrained(
760
+ config.base_model_name, **kwargs_bb
761
+ )
762
+ else:
763
+ tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)
764
+
765
+ # 2. Inject LiZA attention
766
+ self.linear_cache = LCache()
767
+ tptt_model, self.linear_cache = get_tptt_model(
768
+ tptt_model, config, self.linear_cache, **config.to_dict()
769
+ )
770
+
771
+ # 3. Apply LoRA/Specific if present and configured
772
+ if config.lora_config is not None:
773
+ lora_config_obj = LoraConfig(**config.lora_config)
774
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
775
+ else:
776
+ # Doesn't work if quantization is applied !
777
+ tptt_model = set_trainable_parameters(tptt_model)
778
+
779
+ # 4. Load safetensor if tptt/peft adaptor in repo
780
+ if repo_or_path:
781
+ tptt_model = load_tptt_safetensors(
782
+ repo_or_path,
783
+ tptt_model,
784
+ subfolder=kwargs.get("subfolder", None),
785
+ token=kwargs.get("token", None),
786
+ )
787
+ self.tptt_model = tptt_model
788
+
789
+ def forward(
790
+ self,
791
+ input_ids: Optional[torch.LongTensor] = None,
792
+ attention_mask: Optional[torch.Tensor] = None,
793
+ labels: Optional[torch.LongTensor] = None,
794
+ **kwargs,
795
+ ):
796
+ """Forward pass. All arguments are passed to the underlying base model."""
797
+ if self.training:
798
+ kwargs["use_cache"] = False
799
+ kwargs.pop("num_items_in_batch", None)
800
+ elif "use_cache" not in kwargs: # evaluation
801
+ kwargs.pop("num_items_in_batch", None)
802
+ kwargs["use_cache"] = False
803
+ return self.tptt_model(
804
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
805
+ )
806
+
807
+ def generate(self, *args, **kwargs):
808
+ """Delegate the generate call to the backbone model, which supports generation"""
809
+ return self.tptt_model.generate(*args, **kwargs)
810
+
811
+ def save_pretrained(self, path: str, **kwargs):
812
+ """Save model weights, config, and source code to the given path."""
813
+ # 0. Save complete tptt config (with or without LoRA)
814
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
815
+ self._adjust_save_strategy(path, **kwargs)
816
+ # 1. Save trainable weights and adapt keys
817
+ save_tptt_safetensors(self, path)
818
+ # 2. Copy Python files for trust_remote_code
819
+ self._copy_source_files(path, **kwargs)
820
+
821
+ def _adjust_save_strategy(self, path: str, **kwargs):
822
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
823
+ if isinstance(self.tptt_model, PeftModel):
824
+ self.tptt_model.save_pretrained(path, **kwargs)
825
+ safetensor_path = os.path.join(path, "model.safetensors")
826
+ if os.path.exists(safetensor_path):
827
+ os.remove(safetensor_path)
828
+ adapter_path = os.path.join(path, "adapter_config.json")
829
+ if os.path.exists(adapter_path):
830
+ os.remove(adapter_path)
831
+
832
+ def _copy_source_files(self, target_path: str, **kwargs):
833
+ """Copy all .py files from package directory for trust_remote_code."""
834
+ src_dir = os.path.dirname(os.path.abspath(__file__))
835
+ dst_dir = (
836
+ f"./{str(Path(target_path).parts[0])}"
837
+ if kwargs.get("subfolder", False)
838
+ else target_path
839
+ )
840
+ for fname in os.listdir(src_dir):
841
+ if fname.endswith(".py"):
842
+ src = os.path.join(src_dir, fname)
843
+ dst = os.path.join(dst_dir, fname)
844
+ shutil.copy2(src, dst)
845
+
846
+ def retie_lm_after_load(self, **kwargs):
847
+ """Re-link lm_head after loading external weights."""
848
+ embed_lm = find_embedding_lm(self.tptt_model)
849
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
850
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
851
+ self.tptt_model.lm_head = nn.Linear(
852
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
853
+ )
854
+ if kwargs.get("tie_word_embeddings", True):
855
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
856
+ logger.info("Weights of lm_head have been shared with embedding.")
857
+ else:
858
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
859
+ logger.info("Weights of lm_head have been cloned from the embedding.")
860
+
861
+ @classmethod
862
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
863
+ """Custom from_pretrained that accepts the standard positional argument"""
864
+ config = kwargs.pop("config", None)
865
+ repo_or_path = (
866
+ pretrained_model_name_or_path
867
+ or kwargs.pop("pretrained_model_name_or_path", None)
868
+ or kwargs.pop("repo_or_path", None)
869
+ or (getattr(config, "_base_path", None) if config else None)
870
+ or (getattr(config, "_name_or_path", None) if config else None)
871
+ )
872
+
873
+ if config is None and repo_or_path is not None:
874
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
875
+ model = cls(config, *model_args, **kwargs)
876
+ model.retie_lm_after_load(**kwargs)
877
+ return model
878
+
879
+
880
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
881
+
882
+
883
+ class LinearAttentionOp(nn.Module):
884
+ """Base class for linear attention operators."""
885
+
886
+ def __init__(
887
+ self,
888
+ layer_idx: int,
889
+ operator_mode: str = "delta_rule",
890
+ use_linear_checkpoint: bool = False,
891
+ recurrent_config: Optional[dict] = None,
892
+ max_chunk_size: int = 64,
893
+ linear_cache: Optional[LCache] = None,
894
+ linear_precision: torch.dtype = torch.float32,
895
+ ):
896
+ super().__init__()
897
+ self.layer_idx = layer_idx
898
+ if recurrent_config is None:
899
+ operator_mode = "delta_rule" # force default operator mode if no config
900
+ recurrent_config = {
901
+ "order": 1,
902
+ "gate_type": "k",
903
+ "linear": True,
904
+ "trick": "derivative",
905
+ }
906
+ self.operator_mode = operator_mode
907
+ self.use_linear_checkpoint = use_linear_checkpoint
908
+
909
+ self.order = recurrent_config["order"]
910
+ self.gate_type = recurrent_config["gate_type"]
911
+ self.linear = recurrent_config["linear"]
912
+ self.trick = recurrent_config["trick"]
913
+
914
+ self.max_chunk_size = max_chunk_size
915
+ self.linear_cache = linear_cache or LCache()
916
+ self.linear_precision = linear_precision
917
+
918
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
919
+ """
920
+ Compute the gating tensor according to the gate_type.
921
+ """
922
+ if self.gate_type == "k":
923
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
924
+ if self.gate_type == "v":
925
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
926
+ if self.gate_type == "kv":
927
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
928
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
929
+
930
+ def get_cache(self, use_cache: bool) -> Tuple[
931
+ Optional[torch.Tensor],
932
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
933
+ ]:
934
+ """
935
+ Retrieve recurrent state and qkv buffers from the cache.
936
+ """
937
+ if not use_cache:
938
+ return None, None
939
+ last_state = self.linear_cache[self.layer_idx]
940
+ if last_state is not None:
941
+ recurrent_state = last_state.get("recurrent_state", None)
942
+ qkv_buffers = last_state.get("qkv", None)
943
+ else:
944
+ recurrent_state = None
945
+ qkv_buffers = None
946
+ return recurrent_state, qkv_buffers
947
+
948
+ def save_cache(
949
+ self,
950
+ use_cache: bool,
951
+ q: torch.Tensor,
952
+ k: torch.Tensor,
953
+ v: torch.Tensor,
954
+ gate: torch.Tensor,
955
+ state: torch.Tensor,
956
+ ) -> None:
957
+ """
958
+ Save the recurrent state and qkv buffers to the cache.
959
+ """
960
+ if not use_cache:
961
+ return
962
+ if self.order > 1:
963
+ qkv_buffers = (
964
+ q[:, :, -(self.order - 1) :, :],
965
+ k[:, :, -(self.order - 1) :, :],
966
+ v[:, :, -(self.order - 1) :, :],
967
+ gate[:, :, -(self.order - 1) :, :],
968
+ )
969
+ else:
970
+ qkv_buffers = None
971
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
972
+
973
+ def forward(
974
+ self,
975
+ q: torch.Tensor,
976
+ k: torch.Tensor,
977
+ v: torch.Tensor,
978
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
979
+ **kwargs,
980
+ ) -> torch.Tensor:
981
+ """
982
+ Forward pass for the attention operator.
983
+ """
984
+ # Ensure linear_precision for numerical stability (float32)
985
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
986
+ if isinstance(beta, (tuple, list)):
987
+ beta = tuple(b.to(self.linear_precision) for b in beta)
988
+ else:
989
+ beta = beta.to(self.linear_precision)
990
+
991
+ gate = self.compute_gate(beta)
992
+
993
+ # Retrieve cache if needed
994
+ use_cache = kwargs.get("use_cache", False)
995
+ use_checkpoint = not (use_cache) and self.use_linear_checkpoint
996
+ recurrent_state, qkvb = self.get_cache(use_cache)
997
+
998
+ if qkvb is not None and qkvb[0].shape == q.shape:
999
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
1000
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
1001
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
1002
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
1003
+ self.linear_precision
1004
+ )
1005
+
1006
+ output, state = self.chunk_delta_product_forward(
1007
+ q,
1008
+ k,
1009
+ v,
1010
+ gate,
1011
+ self.max_chunk_size,
1012
+ n=self.order,
1013
+ trick=self.trick,
1014
+ linear=self.linear,
1015
+ initial_state=recurrent_state,
1016
+ use_checkpoint=use_checkpoint,
1017
+ linear_precision=self.linear_precision,
1018
+ )
1019
+
1020
+ # Save cache if needed
1021
+ self.save_cache(use_cache, q, k, v, gate, state)
1022
+
1023
+ return output
1024
+
1025
+ @staticmethod
1026
+ def chunk_delta_product_forward(
1027
+ query: torch.Tensor,
1028
+ key: torch.Tensor,
1029
+ value: torch.Tensor,
1030
+ beta_gate: torch.Tensor,
1031
+ chunk_size: int,
1032
+ n: int = 1,
1033
+ trick: str = "derivative",
1034
+ linear: bool = True,
1035
+ initial_state: Optional[torch.Tensor] = None,
1036
+ use_checkpoint: bool = True,
1037
+ linear_precision: torch.dtype = torch.float32,
1038
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1039
+ """
1040
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1041
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
1042
+ """
1043
+
1044
+ # --- Main chunk_delta_product_forward logic ---
1045
+
1046
+ batch_size, num_heads, seq_len, head_dim = query.shape
1047
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1048
+ num_chunks = seq_len // chunk_size
1049
+
1050
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1051
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1052
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1053
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
1054
+
1055
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1056
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1057
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1058
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1059
+
1060
+ k_beta = k_chunks * beta_chunks
1061
+ v_beta = v_chunks * beta_chunks
1062
+
1063
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1064
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1065
+
1066
+ # size : N = chunk_size * n
1067
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1068
+
1069
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1070
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1071
+
1072
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1073
+ if initial_state is not None and initial_state.shape == state_shape:
1074
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1075
+ else:
1076
+ state = torch.full(
1077
+ state_shape,
1078
+ fill_value=1e-6, # stability if nonlinear activation
1079
+ device=query.device,
1080
+ dtype=linear_precision,
1081
+ )
1082
+
1083
+ output, final_state = sequential_delta_product_scan(
1084
+ q_chunks.to(dtype=linear_precision),
1085
+ w.to(dtype=linear_precision),
1086
+ u.to(dtype=linear_precision),
1087
+ n,
1088
+ linear,
1089
+ chunk_size,
1090
+ state.to(dtype=linear_precision),
1091
+ linear_precision=linear_precision,
1092
+ use_checkpoint=use_checkpoint,
1093
+ )
1094
+
1095
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1096
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1097
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
1098
+
1099
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1100
+
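+ # Illustrative shape walk-through (assumed sizes): with query of shape
+ # [B=1, H=32, S=256, D=128] and chunk_size=64, num_chunks=4; for n=2 each
+ # chunk processes 128 virtual tokens, and only the last-order virtual token
+ # per real token is kept in the final [B, H, S, D] output.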
1101
+
1102
+ def sequential_delta_product_scan(
1103
+ q_chunks: torch.Tensor,
1104
+ w: torch.Tensor,
1105
+ u: torch.Tensor,
1106
+ n_orders: int,
1107
+ linear_activation: bool,
1108
+ current_chunk_size: int,
1109
+ initial_recurrent_state: torch.Tensor,
1110
+ linear_precision: torch.dtype,
1111
+ use_checkpoint: bool,
1112
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1113
+ """
1114
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1115
+ Implements the per-token Householder state updates.
1116
+ """
1117
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1118
+ output_inner = torch.empty_like(q_chunks)
1119
+ # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
1120
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1121
+
1122
+ def process_one_chunk(
1123
+ q_chunk_params: torch.Tensor,
1124
+ w_chunk_params: torch.Tensor,
1125
+ u_chunk_params: torch.Tensor,
1126
+ h_0_base: torch.Tensor,
1127
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1128
+ """
1129
+ Process a single chunk (with per-token state for n_orders > 1).
1130
+ """
1131
+ o_intra_current_chunk = torch.zeros(
1132
+ batch,
1133
+ head,
1134
+ chunk_n_total,
1135
+ dim,
1136
+ device=q_chunk_params.device,
1137
+ dtype=linear_precision,
1138
+ )
1139
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1140
+ current_accumulated_state_per_token = (
1141
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1142
+ ) # [B, H, current_chunk_size, D, D]
1143
+
1144
+ for step in range(n_orders):
1145
+ idx_virtual_tokens = (
1146
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1147
+ * n_orders
1148
+ + step
1149
+ )
1150
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1151
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1152
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1153
+
1154
+ state_input_for_this_step = current_accumulated_state_per_token
1155
+
1156
+ ## BLAS/cuBLAS einsum "bhcd,bhcdd->bhcd"
1157
+ k_trans_h_old = (
1158
+ torch.matmul(
1159
+ w_s.unsqueeze(-2),
1160
+ state_input_for_this_step,
1161
+ )
1162
+ .squeeze(-2)
1163
+ .to(dtype=linear_precision)
1164
+ )
1165
+
1166
+ u_val = u_s - k_trans_h_old
1167
+
1168
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1169
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1170
+ .squeeze(-2)
1171
+ .to(dtype=linear_precision)
1172
+ )
1173
+
1174
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1175
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1176
+ dtype=linear_precision
1177
+ )
1178
+
1179
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1180
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1181
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1182
+ dtype=linear_precision
1183
+ )
1184
+ # Return all needed for next chunk
1185
+ return (
1186
+ o_intra_current_chunk,
1187
+ o_inter_current_chunk,
1188
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1189
+ )
1190
+
1191
+ for chunk_idx_inner in range(num_chunks_inner):
1192
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1193
+ w_chunk_params = w[:, :, chunk_idx_inner]
1194
+ u_chunk_params = u[:, :, chunk_idx_inner]
1195
+
1196
+ # Checkpointed call if training
1197
+ call = (
1198
+ partial(checkpoint, use_reentrant=False)
1199
+ if use_checkpoint
1200
+ else lambda f, *a: f(*a)
1201
+ )
1202
+ o_intra, o_inter, h_0_base = call(
1203
+ process_one_chunk,
1204
+ q_chunk_params,
1205
+ w_chunk_params,
1206
+ u_chunk_params,
1207
+ h_0_base,
1208
+ )
1209
+ if not linear_activation: # nonlinear activation between chunks
1210
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1211
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1212
+
1213
+ return output_inner, h_0_base
1214
+
1215
+
1216
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1217
+ """Unlinear activation between chunk"""
1218
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1219
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1220
+ return (x / scale) * x_gelu
1221
+
1222
+
1223
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1224
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1225
+ batch_size, num_heads, _, head_dim = x.shape
1226
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1227
+
1228
+
1229
+ def expand_virtual_tokens(
1230
+ x: torch.Tensor, n: int, mode: str = "derivative"
1231
+ ) -> torch.Tensor:
1232
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1233
+ batch_size, num_heads, seq_len, head_dim = x.shape
1234
+ device, dtype = x.device, x.dtype
1235
+
1236
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1237
+ """Expand tokens using the derivative trick."""
1238
+ x_pad = torch.cat(
1239
+ [
1240
+ torch.zeros(
1241
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1242
+ ),
1243
+ x,
1244
+ ],
1245
+ dim=2,
1246
+ )
1247
+ coeffs = torch.tensor(
1248
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1249
+ device=device,
1250
+ dtype=dtype,
1251
+ )
1252
+ coeffs /= coeffs.norm(p=1)
1253
+ return (
1254
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1255
+ .flip(-1)
1256
+ .permute(0, 1, 2, 4, 3)
1257
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1258
+ )
1259
+
1260
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1261
+ """Expand tokens using the rotative trick."""
1262
+ d_parity = head_dim // 2
1263
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1264
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1265
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1266
+ if head_dim % 2:
1267
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1268
+ else:
1269
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1270
+ x_pairs = x_pairs.unsqueeze(3).expand(
1271
+ batch_size, num_heads, seq_len, n, d_parity, 2
1272
+ )
1273
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1274
+ x0r = x0 * cos - x1 * sin
1275
+ x1r = x0 * sin + x1 * cos
1276
+ rot = torch.stack([x0r, x1r], -1).reshape(
1277
+ batch_size, num_heads, seq_len, n, d_parity * 2
1278
+ )
1279
+ if head_dim % 2:
1280
+ last = (
1281
+ x[..., -1]
1282
+ .unsqueeze(-1)
1283
+ .unsqueeze(3)
1284
+ .expand(batch_size, num_heads, seq_len, n, 1)
1285
+ )
1286
+ rot = torch.cat([rot, last], -1)
1287
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1288
+
1289
+ if mode == "derivative":
1290
+ return derivative_expand(x)
1291
+ if mode == "rotative":
1292
+ return rotative_expand(x)
1293
+ if mode == "combined":
1294
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1295
+ raise ValueError(f"Unknown mode: {mode}")
1296
+
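+ # Illustrative note: with n=2 and mode="derivative", the binomial coefficients
+ # [1, -1] are L1-normalized to [0.5, -0.5], so each real token expands into two
+ # virtual tokens proportional to x_t and x_{t-1} with opposite signs, i.e. a
+ # finite-difference-style expansion of order n-1.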
1297
+
1298
+ def extract_layer_idx(module_name: str) -> int:
1299
+ """Extract the layer index from a module name string."""
1300
+ match = re.search(r"\.(\d+)\.", module_name)
1301
+ if match:
1302
+ return int(match.group(1))
1303
+ return -1
1304
+
1305
+
1306
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1307
+ """Find the embedding weight in a model module."""
1308
+ for _, child in module.named_modules():
1309
+ if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
1310
+ return child.embed_tokens
1311
+ if hasattr(child, "token_embeddings") and hasattr(
1312
+ child.token_embeddings, "weight"
1313
+ ):
1314
+ return child.token_embeddings
1315
+ return None
1316
+
1317
+
1318
+ def set_trainable_parameters(
1319
+ model: PreTrainedModel, trainable_patterns: Optional[List[str]] = None
1320
+ ) -> PreTrainedModel:
1321
+ """Freeze model parameters except trainable_patterns."""
1322
+ if trainable_patterns is None:
1323
+ trainable_patterns = [
1324
+ "q_proj",
1325
+ "k_proj",
1326
+ "v_proj",
1327
+ "o_proj",
1328
+ "qkv_proj",
1329
+ "out_proj",
1330
+ "c_attn",
1331
+ "c_proj",
1332
+ "query",
1333
+ "key",
1334
+ "value",
1335
+ ]
1336
+
1337
+ for name, param in model.named_parameters():
1338
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
1339
+
1340
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1341
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1342
+ return model
1343
+
1344
+
1345
+ def ensure_stability(
1346
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1347
+ ) -> torch.Tensor:
1348
+ """stability forcing"""
1349
+ dtype = tensor.dtype
1350
+ center = (max_val + min_val) / 2
1351
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1352
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1353
+ return tensor.to(dtype=dtype)
1354
+
1355
+
1356
+ def apply_linear_attention_mask(
1357
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1358
+ ) -> torch.Tensor:
1359
+ """Extract if padding --> [B,S]"""
1360
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1361
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1362
+ else:
1363
+ mask = attention_mask.squeeze(
1364
+ dim=tuple(
1365
+ i
1366
+ for i in range(1, attention_mask.dim())
1367
+ if attention_mask.shape[i] == 1
1368
+ )
1369
+ )
1370
+ # Ensure cast to the same dtype as v and convert to binary mask
1371
+ if not (
1372
+ mask.dtype == torch.bool
1373
+ or (
1374
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1375
+ and mask.max() <= 1
1376
+ and mask.min() >= 0
1377
+ )
1378
+ ):
1379
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1380
+ else:
1381
+ mask = mask.to(v.dtype)
1382
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1383
+ if padding_side == "left":
1384
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1385
+ else: # right padding
1386
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1387
+ return v * mask
1388
+
1389
+
1390
+ def truncate_attention_mask(
1391
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1392
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1393
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
1394
+ seq_dim = 1 # convention: (batch, seq, ...)
1395
+ seq_len = hidden_states.shape[seq_dim]
1396
+ if seq_len > max_length:
1397
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
1398
+ if attention_mask is not None:
1399
+ # mask [batch, seq]
1400
+ if attention_mask.dim() == 2:
1401
+ attention_mask = attention_mask[:, -max_length:]
1402
+ # mask [batch, seq, seq]
1403
+ elif attention_mask.dim() == 3:
1404
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
1405
+ # mask [batch, 1, seq, seq]
1406
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1407
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
1408
+ else:
1409
+ raise ValueError(
1410
+ "No dimension in attention_mask matches sequence length of hidden_states."
1411
+ )
1412
+ return hidden_states, attention_mask
1413
+
1414
+
1415
+ def fast_invert_matrix(
1416
+ tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
1417
+ ) -> torch.Tensor:
1418
+ """Equivalent to vectorized forward substitution applied to the identity matrix."""
1419
+ tri_tensor = tri_tensor.to(dtype=dtype).clone()
1420
+ chunk_size = tri_tensor.shape[-1]
1421
+
1422
+ for i in range(1, chunk_size):
1423
+ tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
1424
+ tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
1425
+ ).sum(-2)
1426
+
1427
+ tri_tensor = tri_tensor + torch.eye(
1428
+ chunk_size, dtype=dtype, device=tri_tensor.device
1429
+ )
1430
+ return tri_tensor.to(dtype=dtype)
1431
+
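+ # Worked example: for a strictly lower-triangular input A, the loop builds
+ # (I - A)^-1 as a unit lower-triangular matrix. For a 2x2 block with
+ # A = [[0, 0], [a, 0]], the returned matrix is [[1, 0], [a, 1]].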
1432
+
1433
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
1434
+ """Return the largest chunk_size <= chunk_size that divides total_l."""
1435
+ for c in range(min(chunk_size, total_l), 0, -1):
1436
+ if total_l % c == 0:
1437
+ return c
1438
+ return 1
1439
+
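+ # Example: get_valid_chunk_size(100, 64) returns 50, the largest divisor of 100
+ # that does not exceed 64.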
1440
+
1441
+ ## Rarely needed helpers
1442
+ def split_qkv(
1443
+ base_attn: nn.Module, qkv: torch.Tensor
1444
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1445
+ """Split the QKV tensor into separate Q, K, and V tensors."""
1446
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
1447
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
1448
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
1449
+ head_dim = getattr(base_attn, "head_dim", None)
1450
+
1451
+ if num_q_heads is None or num_k_heads is None or num_v_heads is None:
1452
+ raise ValueError(
1453
+ "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
1454
+ )
1455
+
1456
+ q_len = num_q_heads * head_dim
1457
+ k_len = num_k_heads * head_dim
1458
+ v_len = num_v_heads * head_dim
1459
+
1460
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
1461
+ return q, k, v
1462
+
1463
+
1464
+ ## Optional helpers
1465
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
1466
+ """Match the size of tensor x along dimension dim to target_size by interpolation"""
1467
+ src_size = x.shape[dim]
1468
+ if src_size == target_size:
1469
+ return x
1470
+ x = torch.moveaxis(x, dim, -1)
1471
+ shape = x.shape
1472
+ if src_size < target_size:
1473
+ x = x.reshape(-1, 1, src_size)
1474
+ x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
1475
+ x = x.reshape(*shape[:-1], target_size)
1476
+ else:
1477
+ eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
1478
+ x = F.linear(x, eye) # pylint: disable=not-callable
1479
+ x = torch.moveaxis(x, -1, dim)
1480
+ return x
1481
+
1482
+
1483
+ def soft_clamp(
1484
+ x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
1485
+ ) -> torch.Tensor:
1486
+ """Differentiable clamping for stability"""
1487
+ dtype = x.dtype
1488
+ scale = (max_val - min_val) / 2
1489
+ center = (max_val + min_val) / 2
1490
+ return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
1491
+
1492
+
1493
+ def describe(x: torch.Tensor, name="tensor") -> None:
1494
+ """Prints the shape, min, max, mean, and std of a tensor."""
1495
+ stats = (x.min(), x.max(), x.mean(), x.std())
1496
+ print(
1497
+ f"{name} shape: {tuple(x.shape)}, "
1498
+ + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
1499
+ + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
1500
+ + f"dtype: {x.dtype}, device: {x.device}"
1501
+ )
train_tptt.py ADDED
@@ -0,0 +1,142 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments
2
+
3
+ """
4
+ Author : Fabien FURFARO
5
+ """
6
+
7
+ from typing import Optional, Union
8
+
9
+ from transformers import PreTrainedModel, TrainerCallback
10
+
11
+ from .modeling_tptt import LiZAttention
12
+
13
+
14
+ class LiZACallback(TrainerCallback):
15
+ """
16
+ TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
17
+
18
+ Modes:
19
+ - "gradual": linear interpolation from initial_weight to final_weight.
20
+ - "cyclic": alternate between values in weight_list at each step.
21
+ - "switch": alternately enable/disable linear attention at each step.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ model: PreTrainedModel,
27
+ mode: str = "gradual",
28
+ initial_weight: float = 0.0,
29
+ final_weight: float = 0.5,
30
+ transition_step: Union[int, tuple, list] = 100,
31
+ weight_list: Optional[list] = None,
32
+ switch_period: int = 1, # period for switching
33
+ ):
34
+ self.model = model
35
+ self.mode = mode
36
+
37
+ # Ensure initial_weight is a float scalar, not tuple/list
38
+ if isinstance(initial_weight, (tuple, list)):
39
+ initial_weight = initial_weight[0]
40
+ if isinstance(final_weight, (tuple, list)):
41
+ final_weight = final_weight[0]
42
+ self.initial_weight = float(initial_weight)
43
+ self.final_weight = float(final_weight)
44
+
45
+ # Ensure transition_step is an int scalar, not tuple/list
46
+ self.transition_step = ensure_int(transition_step)
47
+ if self.mode == "constant":
48
+ # For constant mode, transition_step is not used
49
+ self.initial_weight = self.final_weight
50
+ # For cyclic mode: ensure all weights are float scalars
51
+ if weight_list is not None:
52
+ self.weight_list = [
53
+ float(w[0]) if isinstance(w, (tuple, list)) else float(w)
54
+ for w in weight_list
55
+ ]
56
+ else:
57
+ self.weight_list = [self.initial_weight, self.final_weight]
58
+
59
+ # For switch_alternate mode
60
+ self.switch_period = int(switch_period)
61
+
62
+ def on_step_end(self, args, state, control, **kwargs):
63
+ current_step = state.global_step
64
+ transition_step = self.transition_step
65
+
66
+ # Ensure current_step and transition_step are plain ints
67
+ current_step = ensure_int(current_step)
68
+ transition_step = ensure_int(transition_step)
69
+
70
+ # Select mag_weight or enable/disable linear attention according to mode
71
+ if self.mode == "constant":
72
+ # Set mag_weight to final_weight for constant mode
73
+ weight = self.final_weight
74
+ for _, module in self.model.named_modules():
75
+ if isinstance(module, LiZAttention):
76
+ module.mag_weight = weight
77
+
78
+ elif self.mode == "gradual":
79
+ if current_step <= transition_step:
80
+ weight = self.initial_weight + (
81
+ self.final_weight - self.initial_weight
82
+ ) * (current_step / transition_step)
83
+ else:
84
+ weight = self.final_weight
85
+ for _, module in self.model.named_modules():
86
+ if isinstance(module, LiZAttention):
87
+ module.mag_weight = weight
88
+
89
+ elif self.mode == "cyclic":
90
+ idx = current_step % len(self.weight_list)
91
+ weight = self.weight_list[idx]
92
+ for _, module in self.model.named_modules():
93
+ if isinstance(module, LiZAttention):
94
+ module.mag_weight = weight
95
+
96
+ elif self.mode == "switch":
97
+ # Alternately enable/disable linear attention every switch_period steps
98
+ disable = (current_step // self.switch_period) % 2 == 0
99
+ for _, module in self.model.named_modules():
100
+ if isinstance(module, LiZAttention):
101
+ module.disable_linear_attn = disable
102
+
103
+ else:
104
+ raise ValueError(f"Unknown mode: {self.mode}")
105
+
106
+ def on_log(self, args, state, control, logs=None, **kwargs):
107
+ mag_weight = None
108
+ disable_linear_attn = None
109
+ # Log the current mag_weight and disable_linear_attn
110
+ for _, module in self.model.named_modules():
111
+ if isinstance(module, LiZAttention):
112
+ mag_weight = getattr(module, "mag_weight", None)
113
+ disable_linear_attn = getattr(module, "disable_linear_attn", None)
114
+ break
115
+ if mag_weight is not None and logs is not None:
116
+ logs["mag_weight"] = float(mag_weight)
117
+ if disable_linear_attn is not None and logs is not None:
118
+ logs["disable_linear_attn"] = not bool(disable_linear_attn)
119
+
120
+
121
+ def ensure_int(value: Union[int, tuple, list]) -> int:
122
+ """Ensure the value is a plain integer."""
123
+ if isinstance(value, (tuple, list)):
124
+ value = int(value[0])
125
+ if hasattr(value, "item"):
126
+ value = int(value.item())
127
+ return value
128
+
129
+
130
+ class SaveBestModelCallback(TrainerCallback):
131
+ """TrainerCallback to save the best model based on evaluation loss."""
132
+
133
+ def __init__(self):
134
+ self.best_metric = float("inf")
135
+
136
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
137
+ if metrics is not None and "eval_loss" in metrics:
138
+ if metrics["eval_loss"] < self.best_metric:
139
+ self.best_metric = metrics["eval_loss"]
140
+ control.should_save = True # Trigger save
141
+ else:
142
+ control.should_save = False # Skip save
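+ # Usage sketch (illustrative): trainer.add_callback(SaveBestModelCallback())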