smcleish committed on
Commit fd0abc3 · verified · 1 Parent(s): ab85171

Upload RavenForCausalLM

config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "activation_checkpoint_impl": "per-iteration",
3
+ "architecture_class_name": "RecurrentGPT",
4
+ "architectures": [
5
+ "RavenForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "raven_config_minimal.RavenConfig",
9
+ "AutoModelForCausalLM": "raven_modeling_minimal.RavenForCausalLM"
10
+ },
11
+ "bias": false,
12
+ "block_class_name": "SandwichBlock",
13
+ "block_size": 1024,
14
+ "bos_token_id": 65504,
15
+ "effective_expected_depth": 56,
16
+ "eos_token_id": 65505,
17
+ "head_dim": 64,
18
+ "init_orthogonal": false,
19
+ "init_strategy": "takase",
20
+ "init_values": {
21
+ "embed_scale": 1.0,
22
+ "embedding": 0.008703882797784892,
23
+ "out_proj": 0.0005356869554443541,
24
+ "std": 0.008703882797784892
25
+ },
26
+ "injection_type": "linear",
27
+ "intermediate_size": 8192,
28
+ "max_position_embeddings": 131072,
29
+ "mean_backprop_depth": 8,
30
+ "mean_recurrence": 8,
31
+ "mlp_class_name": "GatedMLP",
32
+ "model_type": "huginn_raven",
33
+ "n_embd": 2048,
34
+ "n_heads": 32,
35
+ "n_layers": 14,
36
+ "n_layers_in_coda": 4,
37
+ "n_layers_in_prelude": 4,
38
+ "n_layers_in_recurrent_block": 6,
39
+ "nonlin_name": "SiLU",
40
+ "norm_class_name": "RMSNorm_llama",
41
+ "norm_eps": 1e-05,
42
+ "num_key_value_heads": 8,
43
+ "pad_token_id": 65509,
44
+ "padded_vocab_size": 128256,
45
+ "padding_multiple": 4096,
46
+ "qk_bias": false,
47
+ "rope_base": 500000.0,
48
+ "rope_scaling": {
49
+ "factor": 32.0,
50
+ "high_freq_factor": 4.0,
51
+ "low_freq_factor": 1.0,
52
+ "original_max_position_embeddings": 8192,
53
+ "rope_type": "llama3"
54
+ },
55
+ "rope_theta": 500000.0,
56
+ "sampling_scheme": "poisson-lognormal-filling",
57
+ "state_init": "like-init",
58
+ "test_time_noise": 0,
59
+ "test_time_noise_type": "fixed",
60
+ "tie_embeddings": false,
61
+ "tie_word_embeddings": false,
62
+ "torch_dtype": "float32",
63
+ "transformers_version": "4.53.1",
64
+ "vocab_size": 128256
65
+ }
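Because the auto_map above routes AutoConfig and AutoModelForCausalLM to the raven_config_minimal.py and raven_modeling_minimal.py files added in this commit, the checkpoint has to be loaded with trust_remote_code enabled. A minimal loading sketch (the repository id is a placeholder, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/raven-checkpoint"  # placeholder: substitute the actual repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,     # needed for the custom RavenForCausalLM class
    torch_dtype=torch.float32,  # matches "torch_dtype" in config.json
)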
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 65504,
4
+ "eos_token_id": 65505,
5
+ "pad_token_id": 65509,
6
+ "transformers_version": "4.53.1"
7
+ }
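generation_config.json only pins the special token ids (bos 65504, eos 65505, pad 65509), which generate() picks up automatically. A generation sketch, assuming `model` and `tokenizer` from the loading example above; num_steps is an extra keyword accepted by the forward pass in raven_modeling_minimal.py and sets the test-time recurrence budget:

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    num_steps=32,  # test-time recurrence budget (config.json's mean_recurrence is 8)
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))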
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86e5f63b6525276ec1086a9911688514cedf352e083f37ab5404ab906fe909f4
3
+ size 4490249800
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3af08a58f6006152d03f9d61b282cf651ac601e4c84eae73afd7994a9829ed3
3
+ size 1050673280
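The two .safetensors entries are Git LFS pointer files: each records the sha256 and byte size of the real weight shard rather than the weights themselves. A sketch for verifying a downloaded shard against its pointer (expected hash copied from the oid line above):

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file so multi-gigabyte shards never need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "d3af08a58f6006152d03f9d61b282cf651ac601e4c84eae73afd7994a9829ed3"
assert sha256_of("model-00002-of-00002.safetensors") == expected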
model.safetensors.index.json ADDED
@@ -0,0 +1,96 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 1385228288,
4
+ "total_size": 5540913152
5
+ },
6
+ "weight_map": {
7
+ "lm_head.weight": "model-00002-of-00002.safetensors",
8
+ "transformer.adapter.weight": "model-00001-of-00002.safetensors",
9
+ "transformer.coda.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
10
+ "transformer.coda.0.attn.proj.weight": "model-00001-of-00002.safetensors",
11
+ "transformer.coda.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
12
+ "transformer.coda.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
13
+ "transformer.coda.0.norm_1.weight": "model-00001-of-00002.safetensors",
14
+ "transformer.coda.0.norm_2.weight": "model-00001-of-00002.safetensors",
15
+ "transformer.coda.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
16
+ "transformer.coda.1.attn.proj.weight": "model-00001-of-00002.safetensors",
17
+ "transformer.coda.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
18
+ "transformer.coda.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
19
+ "transformer.coda.1.norm_1.weight": "model-00001-of-00002.safetensors",
20
+ "transformer.coda.1.norm_2.weight": "model-00001-of-00002.safetensors",
21
+ "transformer.coda.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
22
+ "transformer.coda.2.attn.proj.weight": "model-00001-of-00002.safetensors",
23
+ "transformer.coda.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
24
+ "transformer.coda.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
25
+ "transformer.coda.2.norm_1.weight": "model-00001-of-00002.safetensors",
26
+ "transformer.coda.2.norm_2.weight": "model-00001-of-00002.safetensors",
27
+ "transformer.coda.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
28
+ "transformer.coda.3.attn.proj.weight": "model-00001-of-00002.safetensors",
29
+ "transformer.coda.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
30
+ "transformer.coda.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
31
+ "transformer.coda.3.norm_1.weight": "model-00001-of-00002.safetensors",
32
+ "transformer.coda.3.norm_2.weight": "model-00001-of-00002.safetensors",
33
+ "transformer.core_block.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
34
+ "transformer.core_block.0.attn.proj.weight": "model-00001-of-00002.safetensors",
35
+ "transformer.core_block.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
36
+ "transformer.core_block.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
37
+ "transformer.core_block.0.norm_1.weight": "model-00001-of-00002.safetensors",
38
+ "transformer.core_block.0.norm_2.weight": "model-00001-of-00002.safetensors",
39
+ "transformer.core_block.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
40
+ "transformer.core_block.1.attn.proj.weight": "model-00001-of-00002.safetensors",
41
+ "transformer.core_block.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
42
+ "transformer.core_block.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
43
+ "transformer.core_block.1.norm_1.weight": "model-00001-of-00002.safetensors",
44
+ "transformer.core_block.1.norm_2.weight": "model-00001-of-00002.safetensors",
45
+ "transformer.core_block.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
46
+ "transformer.core_block.2.attn.proj.weight": "model-00001-of-00002.safetensors",
47
+ "transformer.core_block.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
48
+ "transformer.core_block.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
49
+ "transformer.core_block.2.norm_1.weight": "model-00001-of-00002.safetensors",
50
+ "transformer.core_block.2.norm_2.weight": "model-00001-of-00002.safetensors",
51
+ "transformer.core_block.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
52
+ "transformer.core_block.3.attn.proj.weight": "model-00001-of-00002.safetensors",
53
+ "transformer.core_block.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
54
+ "transformer.core_block.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
55
+ "transformer.core_block.3.norm_1.weight": "model-00001-of-00002.safetensors",
56
+ "transformer.core_block.3.norm_2.weight": "model-00001-of-00002.safetensors",
57
+ "transformer.core_block.4.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
58
+ "transformer.core_block.4.attn.proj.weight": "model-00001-of-00002.safetensors",
59
+ "transformer.core_block.4.mlp.fc.weight": "model-00001-of-00002.safetensors",
60
+ "transformer.core_block.4.mlp.proj.weight": "model-00001-of-00002.safetensors",
61
+ "transformer.core_block.4.norm_1.weight": "model-00001-of-00002.safetensors",
62
+ "transformer.core_block.4.norm_2.weight": "model-00001-of-00002.safetensors",
63
+ "transformer.core_block.5.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
64
+ "transformer.core_block.5.attn.proj.weight": "model-00001-of-00002.safetensors",
65
+ "transformer.core_block.5.mlp.fc.weight": "model-00001-of-00002.safetensors",
66
+ "transformer.core_block.5.mlp.proj.weight": "model-00001-of-00002.safetensors",
67
+ "transformer.core_block.5.norm_1.weight": "model-00001-of-00002.safetensors",
68
+ "transformer.core_block.5.norm_2.weight": "model-00001-of-00002.safetensors",
69
+ "transformer.ln_f.weight": "model-00001-of-00002.safetensors",
70
+ "transformer.prelude.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
71
+ "transformer.prelude.0.attn.proj.weight": "model-00001-of-00002.safetensors",
72
+ "transformer.prelude.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
73
+ "transformer.prelude.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
74
+ "transformer.prelude.0.norm_1.weight": "model-00001-of-00002.safetensors",
75
+ "transformer.prelude.0.norm_2.weight": "model-00001-of-00002.safetensors",
76
+ "transformer.prelude.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
77
+ "transformer.prelude.1.attn.proj.weight": "model-00001-of-00002.safetensors",
78
+ "transformer.prelude.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
79
+ "transformer.prelude.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
80
+ "transformer.prelude.1.norm_1.weight": "model-00001-of-00002.safetensors",
81
+ "transformer.prelude.1.norm_2.weight": "model-00001-of-00002.safetensors",
82
+ "transformer.prelude.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
83
+ "transformer.prelude.2.attn.proj.weight": "model-00001-of-00002.safetensors",
84
+ "transformer.prelude.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
85
+ "transformer.prelude.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
86
+ "transformer.prelude.2.norm_1.weight": "model-00001-of-00002.safetensors",
87
+ "transformer.prelude.2.norm_2.weight": "model-00001-of-00002.safetensors",
88
+ "transformer.prelude.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
89
+ "transformer.prelude.3.attn.proj.weight": "model-00001-of-00002.safetensors",
90
+ "transformer.prelude.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
91
+ "transformer.prelude.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
92
+ "transformer.prelude.3.norm_1.weight": "model-00001-of-00002.safetensors",
93
+ "transformer.prelude.3.norm_2.weight": "model-00001-of-00002.safetensors",
94
+ "transformer.wte.weight": "model-00001-of-00002.safetensors"
95
+ }
96
+ }
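The index file maps every parameter name to the shard that stores it (here only lm_head.weight lives in the second shard), and transformers reads this file to resolve the sharded checkpoint at load time. A small inspection sketch:

import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_parameters"])  # 1385228288
print(index["weight_map"]["lm_head.weight"])  # model-00002-of-00002.safetensors

# Count how many tensors each shard holds.
counts: dict[str, int] = {}
for name, shard in index["weight_map"].items():
    counts[shard] = counts.get(shard, 0) + 1
print(counts)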
raven_config_minimal.py ADDED
@@ -0,0 +1,99 @@
1
+ """A HuggingFace-style model configuration."""
2
+
3
+ from transformers import PretrainedConfig
4
+ from math import sqrt
5
+
6
+
7
+ class RavenConfig(PretrainedConfig):
8
+ model_type = "huginn_raven"
9
+ keys_to_ignore_at_inference = [""]
10
+ attribute_map = {"num_attention_heads": "n_heads", "hidden_size": "n_embd", "num_hidden_layers": "n_layers"}
11
+
12
+ def __init__(
13
+ self,
14
+ n_embd: int = 5280,
15
+ n_heads: int = 55,
16
+ n_layers: int = 8, # total of prelude + recurrent + coda
17
+ block_size: int = 4096,
18
+ vocab_size: int = 65536,
19
+ padding_multiple: int = 4096,
20
+ tie_embeddings: bool = True,
21
+ intermediate_size: int = 17920,
22
+ bias: bool = False,
23
+ architecture_class_name: str = "RecurrentGPT",
24
+ block_class_name: str = "SandwichBlock",
25
+ norm_class_name: str = "RMSNorm_llama",
26
+ norm_eps: float = 0.000001,
27
+ mlp_class_name: str = "GatedMLP",
28
+ nonlin_name: str = "SiLU",
29
+ init_strategy: str = "takase",
30
+ init_orthogonal: bool = False,
31
+ state_init: str = "like-init",
32
+ injection_type: str = "linear",
33
+ n_layers_in_recurrent_block: int = 4,
34
+ mean_recurrence: int = 32,
35
+ sampling_scheme: str = "poisson-lognormal-filling",
36
+ mean_backprop_depth: int = 8,
37
+ n_layers_in_prelude: int = 2,
38
+ n_layers_in_coda: int = 2,
39
+ qk_bias: bool = True,
40
+ activation_checkpoint_impl: str = "per-iteration",
41
+ rope_base: float = 50_000,
42
+ torch_dtype: str = "bfloat16",
43
+ transformers_version: str = "4.47.1",
44
+ **kwargs,
45
+ ):
46
+ self.n_embd = n_embd
47
+ self.n_heads = n_heads
48
+ self.n_layers = n_layers
49
+ self.block_size = block_size
50
+ self.vocab_size = self.padded_vocab_size = vocab_size
51
+ self.padding_multiple = padding_multiple
52
+ self.tie_embeddings = tie_embeddings
53
+ self.intermediate_size = intermediate_size
54
+ self.bias = bias
55
+ self.architecture_class_name = architecture_class_name
56
+ self.block_class_name = block_class_name
57
+ self.norm_class_name = norm_class_name
58
+ self.norm_eps = norm_eps
59
+ self.mlp_class_name = mlp_class_name
60
+ self.nonlin_name = nonlin_name
61
+ self.init_strategy = init_strategy
62
+ self.init_orthogonal = init_orthogonal
63
+ self.state_init = state_init
64
+ self.injection_type = injection_type
65
+ self.n_layers_in_recurrent_block = n_layers_in_recurrent_block
66
+ self.mean_recurrence = mean_recurrence
67
+ self.sampling_scheme = sampling_scheme
68
+ self.mean_backprop_depth = mean_backprop_depth
69
+ self.n_layers_in_prelude = n_layers_in_prelude
70
+ self.n_layers_in_coda = n_layers_in_coda
71
+ self.qk_bias = qk_bias
72
+ self.activation_checkpoint_impl = activation_checkpoint_impl
73
+ self.rope_base = rope_base
74
+ self.torch_dtype = torch_dtype # Added from JSON
75
+ self.transformers_version = transformers_version # Added from JSON
76
+ # inference
77
+ self.test_time_noise = 0
78
+ self.test_time_noise_type = "fixed"
79
+ # Derived
80
+ self.num_key_value_heads = n_heads
81
+ self.num_attention_heads = n_heads
82
+ self.head_dim = n_embd // n_heads
83
+ self.effective_expected_depth = (
84
+ self.n_layers_in_prelude + self.n_layers_in_coda + self.n_layers_in_recurrent_block * self.mean_recurrence
85
+ )
86
+ self.init_values = {
87
+ "std": sqrt(2 / (5 * self.n_embd)),
88
+ "out_proj": sqrt(2 / (5 * self.n_embd)) / sqrt(2 * self.effective_expected_depth),
89
+ "embedding": sqrt(2 / (5 * self.n_embd)),
90
+ "embed_scale": sqrt(self.n_embd),
91
+ }
92
+
93
+ super().__init__(
94
+ # pad_token_id=65509,
95
+ # bos_token_id=65504,
96
+ # eos_token_id=65505,
97
+ tie_word_embeddings=tie_embeddings,
98
+ **kwargs,
99
+ )
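The derived fields at the end of RavenConfig explain two values serialized in config.json: with n_layers_in_prelude=4, n_layers_in_coda=4, n_layers_in_recurrent_block=6 and mean_recurrence=8, effective_expected_depth is 4 + 4 + 6 * 8 = 56, and head_dim is 2048 // 32 = 64. A quick check, assuming raven_config_minimal.py is importable from the working directory (the constructor defaults describe a larger variant, so the relevant values are passed explicitly):

from raven_config_minimal import RavenConfig

cfg = RavenConfig(
    n_embd=2048,
    n_heads=32,
    n_layers=14,
    n_layers_in_prelude=4,
    n_layers_in_coda=4,
    n_layers_in_recurrent_block=6,
    mean_recurrence=8,
)
assert cfg.effective_expected_depth == 56
assert cfg.head_dim == 64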
raven_modeling_minimal.py ADDED
@@ -0,0 +1,1572 @@
1
+ """Modeling file for HF compatibility and zero-shot experiments."""
2
+
3
+ import torch
4
+ import math
5
+
6
+ from torch import Tensor
7
+ from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
+ from torch.nn.attention import bias as attn_bias
9
+ from dataclasses import dataclass
10
+ from typing import Union, Optional, Any
11
+
12
+
13
+ from .raven_config_minimal import RavenConfig
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+
16
+ ###################### Huggingface Glue code I ##################################################################
17
+ from transformers import PreTrainedModel, GenerationMixin
18
+ from transformers.utils import ModelOutput
19
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
20
+
21
+ import torch.nn.functional as F
22
+ from transformers import GenerationConfig
23
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
24
+
25
+ # torch.backends.cuda.enable_math_sdp(False)
26
+
27
+
28
+ class RavenPreTrainedModel(PreTrainedModel):
29
+ config_class = RavenConfig
30
+ base_model_prefix = "model"
31
+ supports_gradient_checkpointing = True
32
+ _no_split_modules = ["SandwichBlock"]
33
+ _skip_keys_device_placement = ["past_key_values"]
34
+ _tied_weights_keys = ["lm_head.weight"]
35
+ _supports_flash_attn_2 = True
36
+ _supports_sdpa = True
37
+ _supports_cache_class = True
38
+ _supports_quantized_cache = False
39
+ _supports_static_cache = True
40
+ _tp_plan = {}
41
+
42
+ def _init_weights(self, module):
43
+ if not torch.rand((1,)).is_meta:
44
+ print("Random Initialization not implemented.")
45
+
46
+
47
+ @dataclass
48
+ class CausalLMOutputRecurrentLatents(ModelOutput):
49
+ loss: Optional[torch.Tensor] = None
50
+ log_ppl: Optional[torch.Tensor] = None
51
+ logits: Optional[torch.Tensor] = None
52
+ past_key_values: Optional[Cache] = None
53
+ latent_states: Optional[torch.Tensor] = None
54
+ hidden_states: Optional[torch.Tensor] = None
55
+ attention_maps: Optional[dict[int, torch.Tensor]] = None
56
+ stats: Optional[dict] = None
57
+
58
+
59
+ ###################### Minimal implementation from here ############################################################
60
+
61
+
62
+ class RMSNorm(torch.nn.Module):
63
+ """Saner dtype handling and slightly better for fusion"""
64
+
65
+ def __init__(self, dim: int, eps: float = 1e-6):
66
+ super().__init__()
67
+ self.eps = eps
68
+ self.weight = torch.nn.Parameter(torch.ones(dim))
69
+
70
+ def _norm(self, x):
71
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
72
+
73
+ def forward(self, x):
74
+ with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
75
+ return self._norm(x.float()).type_as(x) * self.weight
76
+
77
+ def reset_parameters(self) -> None:
78
+ torch.nn.init.ones_(self.weight)
79
+
80
+
81
+ class HuginnDynamicCache(DynamicCache):
82
+ def __init__(self, lookup_strategy: str = "full") -> None:
83
+ super().__init__()
84
+ self._seen_tokens = 0
85
+ self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
86
+ self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
87
+ # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
88
+ # the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
89
+ # per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
90
+ # Also, it is critical that the head indices do not overlap with the recurrent iteration indices
91
+ self.lookup_strategy = lookup_strategy
92
+
93
+ def update(
94
+ self,
95
+ key_states: torch.Tensor,
96
+ value_states: torch.Tensor,
97
+ step_idx_tensor: torch.Tensor,
98
+ lookup_strategy: Optional[str] = None,
99
+ ) -> tuple[torch.Tensor, torch.Tensor]:
100
+ step_idx: int = int(step_idx_tensor) # todo: fix dicts with tensor step_idx, currently the memberships fail
101
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
102
+ if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
103
+ if "compress-s" in self.lookup_strategy:
104
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
105
+ new_step_idx = (step_idx - 2) % compression_stage + 2
106
+ elif "compress-anchor" in self.lookup_strategy:
107
+ if step_idx - 2 < 4 * 8: # anchor onto first 8 recurrence steps # noqa: SIM108
108
+ new_step_idx = step_idx
109
+ else: # then re-use the next 4 KV states = one recurrence for all future recurrence
110
+ new_step_idx = 34 + (step_idx - 34) % 4
111
+ # print(step_idx, new_step_idx)
112
+ else: # compress-r
113
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
114
+ new_step_idx = (step_idx - 2) // compression_stage + 2
115
+ step_idx = new_step_idx
116
+ # Init
117
+ if step_idx not in self.key_cache:
118
+ self.key_cache[step_idx] = {}
119
+ self.value_cache[step_idx] = {}
120
+ # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
121
+ if step_idx == 0:
122
+ self._seen_tokens += key_states.shape[-2]
123
+ # Add entries to cache
124
+ for idx, entry in enumerate(key_states.unbind(dim=-2)):
125
+ if "compress-" not in self.lookup_strategy:
126
+ assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
127
+ self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
128
+ for idx, entry in enumerate(value_states.unbind(dim=-2)):
129
+ self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
130
+
131
+ # Materialize past state based on lookup strategy:
132
+ if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
133
+ # All entries are present, materialize cache as normal
134
+ return (
135
+ torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
136
+ torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
137
+ )
138
+ else: # some entries were not previously computed
139
+ if lookup_strategy.startswith("latest-m4"):
140
+ latest_keys = []
141
+ latest_values = []
142
+ for token_pos in range(self._seen_tokens):
143
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
144
+ if step_idx >= 2:
145
+ # Find valid steps for this token position
146
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
147
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
148
+ else:
149
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
150
+ latest_keys.append(self.key_cache[max_step][token_pos])
151
+ latest_values.append(self.value_cache[max_step][token_pos])
152
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
153
+ elif lookup_strategy.startswith("available-m4"):
154
+ latest_keys = []
155
+ latest_values = []
156
+ for token_pos in range(self._seen_tokens):
157
+ if token_pos in self.key_cache[step_idx]:
158
+ step = step_idx
159
+ else:
160
+ # Find valid steps for this token position
161
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
162
+ step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
163
+ latest_keys.append(self.key_cache[step][token_pos])
164
+ latest_values.append(self.value_cache[step][token_pos])
165
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
166
+ elif lookup_strategy.startswith("always-last-m4"):
167
+ latest_keys = []
168
+ latest_values = []
169
+ for token_pos in range(self._seen_tokens):
170
+ # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
171
+ if step_idx >= 2:
172
+ # Find valid steps for this token position
173
+ valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
174
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
175
+ else:
176
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
177
+ latest_keys.append(self.key_cache[max_step][token_pos])
178
+ latest_values.append(self.value_cache[max_step][token_pos])
179
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
180
+ elif lookup_strategy.startswith("skip"):
181
+ existing_keys = []
182
+ existing_values = []
183
+ for token_pos in range(self._seen_tokens):
184
+ if token_pos in self.key_cache[step_idx]:
185
+ existing_keys.append(self.key_cache[step_idx][token_pos])
186
+ existing_values.append(self.value_cache[step_idx][token_pos])
187
+ return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
188
+ elif lookup_strategy.startswith("randomized"): # sanity check
189
+ rand_keys = []
190
+ rand_values = []
191
+ for token_pos in range(self._seen_tokens):
192
+ if step_idx < 2: # For prelude steps
193
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
194
+ else: # Get all steps from same block position
195
+ curr_modulo = (step_idx - 2) % 4 + 2
196
+ valid_steps = [
197
+ s
198
+ for s in range(2, step_idx + 1)
199
+ if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
200
+ ]
201
+ max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
202
+ rand_keys.append(self.key_cache[max_step][token_pos])
203
+ rand_values.append(self.value_cache[max_step][token_pos])
204
+ return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
205
+ else:
206
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
207
+
208
+ def reset(self) -> None:
209
+ """Reset the cache state."""
210
+ self._seen_tokens = 0
211
+ self.key_cache.clear()
212
+ self.value_cache.clear()
213
+
214
+ def clear_last_k_entries(self, k: int = 0):
215
+ """Partially clear cache."""
216
+ assert self._seen_tokens >= k
217
+ self._seen_tokens = self._seen_tokens - k
218
+ # self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
219
+ self.key_cache = {
220
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
221
+ for step, cache in self.key_cache.items()
222
+ }
223
+ self.value_cache = {
224
+ step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
225
+ for step, cache in self.value_cache.items()
226
+ }
227
+
228
+ def get_seq_length(self, step_idx: int = 0) -> int:
229
+ return self._seen_tokens
230
+
231
+ def get_memory_usage(self) -> float:
232
+ total_bytes = 0
233
+ # For each recurrent step/layer index
234
+ for step_idx in self.key_cache:
235
+ # Get the sequence cache for this step
236
+ key_seq_cache = self.key_cache[step_idx]
237
+ for seq_idx in key_seq_cache:
238
+ key_tensor = key_seq_cache[seq_idx]
239
+ # Add memory of the key tensors, assuming the value cache is the same size
240
+ total_bytes += key_tensor.nelement() * key_tensor.element_size()
241
+ return total_bytes * 2 / (1024 * 1024)
242
+
243
+
244
+ class HuginnStaticCache(Cache):
245
+ """Static Cache for the recurrent model"""
246
+
247
+ is_compileable = False # this is todo
248
+
249
+ def __init__(
250
+ self,
251
+ max_length: int,
252
+ max_num_steps: int,
253
+ num_heads: int,
254
+ hidden_dim: int,
255
+ batch_size: int = 1,
256
+ lookup_strategy: str = "full",
257
+ device: Optional[Union[torch.device, str]] = None,
258
+ dtype: torch.dtype = torch.float32,
259
+ ) -> None:
260
+ super().__init__()
261
+ self._seen_tokens = 0
262
+ self.max_length = max_length
263
+ self.lookup_strategy = lookup_strategy
264
+
265
+ # Adjust max_num_steps based on compression strategy
266
+ if "compress-" in lookup_strategy:
267
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
268
+ if "compress-s" in lookup_strategy:
269
+ # For modulo compression (s), we need steps for 0,1 + compressed steps
270
+ self.max_num_steps = 4 + compression_stage
271
+ else:
272
+ # For relative compression, we need steps for 0,1 + compressed steps
273
+ self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
274
+ else:
275
+ self.max_num_steps = max_num_steps
276
+
277
+ # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
278
+ device = torch.device(device) if device is not None else None
279
+ cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
280
+
281
+ self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
282
+ self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
283
+ self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
284
+ # Mark tensors as static for compile
285
+ torch._dynamo.mark_static_address(self.key_cache)
286
+ torch._dynamo.mark_static_address(self.value_cache)
287
+ torch._dynamo.mark_static_address(self.valid_mask)
288
+
289
+ def update(
290
+ self,
291
+ key_states: torch.Tensor,
292
+ value_states: torch.Tensor,
293
+ step_idx: torch.Tensor,
294
+ lookup_strategy: Optional[str] = None,
295
+ ) -> tuple[torch.Tensor, torch.Tensor]:
296
+ if step_idx == 0:
297
+ self._seen_tokens += key_states.shape[-2]
298
+
299
+ # Adjust step_idx for compression
300
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
301
+ if "compress-" in lookup_strategy and step_idx > 1:
302
+ compression_stage = int(lookup_strategy.split("compress-")[1][1:])
303
+ if "compress-s" in lookup_strategy:
304
+ step_idx = (step_idx - 2) % compression_stage + 2
305
+ else:
306
+ step_idx = (step_idx - 2) // compression_stage + 2
307
+
308
+ start_idx = self._seen_tokens - key_states.shape[-2]
309
+
310
+ indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
311
+ self.key_cache[step_idx].index_copy_(2, indices, key_states)
312
+ self.value_cache[step_idx].index_copy_(2, indices, value_states)
313
+ self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True
314
+
315
+ # Return based on lookup strategy
316
+ if lookup_strategy == "full":
317
+ return (
318
+ self.key_cache[step_idx, :, :, : self._seen_tokens],
319
+ self.value_cache[step_idx, :, :, : self._seen_tokens],
320
+ )
321
+ elif lookup_strategy.startswith("latest-m4"):
322
+ if step_idx >= 2:
323
+ pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
324
+ pattern_valid = self.valid_mask[pattern_steps]
325
+ max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
326
+ return (
327
+ self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
328
+ self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
329
+ )
330
+ return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
331
+ step_idx, :, :, : self._seen_tokens
332
+ ]
333
+ elif lookup_strategy == "skip":
334
+ valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
335
+ return (
336
+ self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
337
+ self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
338
+ )
339
+ elif lookup_strategy.startswith("randomized"):
340
+ if step_idx < 2:
341
+ max_step = step_idx
342
+ else:
343
+ curr_modulo = (step_idx - 2) % 4 + 2
344
+ valid_steps = (
345
+ torch.where(
346
+ (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
347
+ )[0]
348
+ + 2
349
+ )
350
+ rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
351
+ max_step = valid_steps[rand_idx]
352
+ return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
353
+ else:
354
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
355
+
356
+ def reset(self) -> None:
357
+ self._seen_tokens = 0
358
+ self.key_cache.zero_()
359
+ self.value_cache.zero_()
360
+ self.valid_mask.zero_()
361
+
362
+ def get_seq_length(self, step_idx: int = 0) -> int:
363
+ return self._seen_tokens
364
+
365
+ def get_memory_usage(self) -> float:
366
+ return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
367
+
368
+
369
+ ValidCache = HuginnDynamicCache | HuginnStaticCache
370
+
371
+
372
+ class CausalSelfAttention(torch.nn.Module):
373
+ def __init__(self, config: RavenConfig) -> None:
374
+ super().__init__()
375
+ self.config = config
376
+ self.n_head = config.num_attention_heads
377
+ self.n_kv_heads = config.num_key_value_heads
378
+ self.head_dim = config.n_embd // self.n_head
379
+
380
+ shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
381
+ self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
382
+ self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
383
+ if config.qk_bias:
384
+ self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
385
+ self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
386
+
387
+ def forward(
388
+ self,
389
+ x: Tensor,
390
+ freqs_cis: Tensor,
391
+ block_idx: torch.Tensor,
392
+ mask: Optional[BlockMask] = None,
393
+ past_key_values: Optional[ValidCache] = None,
394
+ ) -> Tensor:
395
+ B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
396
+ q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
397
+ q = q.view(B, S, self.n_head, self.head_dim)
398
+ k = k.view(B, S, self.n_kv_heads, self.head_dim)
399
+ v = v.view(B, S, self.n_kv_heads, self.head_dim)
400
+ # bias?
401
+ if self.config.qk_bias:
402
+ q_bias, k_bias = self.qk_bias.split(1, dim=0)
403
+ q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
404
+
405
+ q = q.transpose(1, 2) # (B, nh, S, hs)
406
+ k = k.transpose(1, 2)
407
+ v = v.transpose(1, 2)
408
+
409
+ # apply rotary
410
+ cos, sin = freqs_cis
411
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
412
+
413
+ if past_key_values is not None:
414
+ k, v = past_key_values.update(k, v, block_idx)
415
+
416
+ if mask is not None:
417
+ y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
418
+ else:
419
+ if q.shape[2] < k.shape[2]:
420
+ if q.shape[2] > 1:
421
+ bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
422
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0, enable_gqa=True)
423
+ else:
424
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, enable_gqa=True)
425
+ else:
426
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True, enable_gqa=True)
427
+ y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
428
+ return self.proj(y)
429
+
430
+
431
+ class GatedMLP(torch.nn.Module):
432
+ def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
433
+ super().__init__()
434
+ in_features = config.n_embd if in_features == 0 else in_features
435
+ self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
436
+
437
+ self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
438
+ self.nonlin = torch.nn.SiLU()
439
+
440
+ def forward(self, x: Tensor) -> Tensor:
441
+ # modified to single FC layer to improve parallelism
442
+ x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
443
+ x = self.nonlin(x_fc_1) * x_fc_2
444
+ return self.proj(x)
445
+
446
+
447
+ class SandwichBlock(torch.nn.Module):
448
+ expanded = False
449
+
450
+ def __init__(self, config: RavenConfig, layer_id: int) -> None:
451
+ super().__init__()
452
+ self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
453
+ self.attn = CausalSelfAttention(config)
454
+ self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
455
+ self.mlp = GatedMLP(config)
456
+ self.layer_id = layer_id
457
+
458
+ def forward(
459
+ self,
460
+ x: Tensor,
461
+ freqs_cis: Tensor,
462
+ step_idx: int,
463
+ mask: Optional[BlockMask] = None,
464
+ past_key_values: Optional[ValidCache] = None,
465
+ ) -> Tensor:
466
+ attn_out = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values)
467
+ x = attn_out + x
468
+ x = self.mlp(self.norm_2(x)) + x
469
+ return x
470
+
471
+
472
+ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
473
+
474
+ def __init__(
475
+ self,
476
+ config: RavenConfig,
477
+ ) -> None:
478
+ super().__init__(config)
479
+ self.config = config
480
+
481
+ # Transformer layers
482
+ prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
483
+ adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
484
+ core_block = torch.nn.ModuleList(
485
+ SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
486
+ for i in range(config.n_layers_in_recurrent_block)
487
+ )
488
+ o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
489
+ coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
490
+
491
+ self.transformer = torch.nn.ModuleDict(
492
+ dict(
493
+ wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
494
+ prelude=prelude,
495
+ adapter=adapter,
496
+ core_block=core_block,
497
+ coda=coda,
498
+ ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
499
+ )
500
+ )
501
+ self.emb_scale = config.init_values["embed_scale"]
502
+ # Head
503
+ self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
504
+ if self.config.tie_embeddings:
505
+ self.tie_weights()
506
+ # rope
507
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
508
+
509
+ def get_input_embeddings(self):
510
+ return self.transformer.wte
511
+
512
+ def get_output_embeddings(self):
513
+ return self.lm_head
514
+
515
+
516
+ def compile_mask(
517
+ self,
518
+ input_ids: torch.Tensor,
519
+ attention_mask: Optional[torch.Tensor] = None,
520
+ past_key_values: Optional[ValidCache] = None,
521
+ pad_token_id=65509,
522
+ ) -> Optional[BlockMask]:
523
+ batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
524
+
525
+ # If no padding and no attention mask, no need for a mask
526
+ if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
527
+ return None
528
+
529
+ if past_key_values is not None and seq_len == 1:
530
+ return None
531
+
532
+ # Get total sequence length including cache
533
+ cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
534
+ kv_length = cache_len + seq_len
535
+
536
+ if attention_mask is None:
537
+
538
+ def mask_mod(b, h, q_idx, kv_idx):
539
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id)
540
+ else:
541
+
542
+ def mask_mod(b, h, q_idx, kv_idx):
543
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
544
+
545
+ kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
546
+ if kv_length == 0:
547
+ kv_length = seq_len # prefill
548
+ block_mask = create_block_mask(
549
+ mask_mod,
550
+ B=batch_size,
551
+ H=None,
552
+ Q_LEN=seq_len,
553
+ KV_LEN=kv_length,
554
+ device=input_ids.device,
555
+ )
556
+
557
+ # # Define mask_mod function
558
+ # def mask_mod(b, h, q_idx, kv_idx):
559
+ # # Always apply causal constraint
560
+ # is_causal = q_idx >= kv_idx
561
+
562
+ # # Handle cache vs current tokens
563
+ # is_cache = kv_idx < cache_len
564
+ # current_idx = kv_idx - cache_len
565
+
566
+ # # For cache: always valid; For current: check padding
567
+ # not_pad = input_ids[b, current_idx] != pad_token_id
568
+ # valid = is_cache | not_pad
569
+
570
+ # # Apply attention mask if provided
571
+ # if attention_mask is not None:
572
+ # q_idx_curr = q_idx - cache_len
573
+ # attn_valid = attention_mask[b, q_idx_curr, current_idx]
574
+ # valid = valid & (is_cache | attn_valid)
575
+
576
+ # return is_causal & valid
577
+
578
+ # def mask_mod(b, h, q_idx, kv_idx):
579
+ # is_causal = q_idx >= kv_idx
580
+ # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
581
+ # current_idx = kv_idx - cache_len
582
+
583
+ # is_valid = (~is_current) | (
584
+ # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
585
+ # )
586
+
587
+ # return is_causal & is_valid
588
+
589
+ # # Define mask_mod function
590
+ # def mask_mod(b, h, q_idx, kv_idx):
591
+ # # Always apply causal constraint
592
+ # is_causal = q_idx >= kv_idx
593
+
594
+ # # Handle cache vs current tokens
595
+ # is_cache = kv_idx < cache_len
596
+ # current_idx = kv_idx - cache_len
597
+ # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
598
+
599
+ # # For cache: always valid; For current: check padding
600
+ # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
601
+ # valid = is_cache | (not_pad & in_bounds)
602
+
603
+ # # Apply attention mask if provided
604
+ # if attention_mask is not None:
605
+ # q_idx_curr = q_idx - cache_len
606
+ # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
607
+ # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
608
+ # valid = valid & (is_cache | attn_valid)
609
+
610
+ # return is_causal & valid
611
+
612
+ # Create block mask
613
+ block_mask = create_block_mask(
614
+ mask_mod,
615
+ B=batch_size,
616
+ H=None,
617
+ Q_LEN=seq_len,
618
+ KV_LEN=kv_length,
619
+ device=input_ids.device,
620
+ )
621
+
622
+ return block_mask
623
+
624
+ def forward(
625
+ self,
626
+ input_ids: torch.Tensor,
627
+ input_embeds: Optional[torch.Tensor] = None,
628
+ input_states: Optional[torch.Tensor] = None,
629
+ attention_mask: Optional[torch.Tensor] = None, # binary mask of shape q x kv, True=valid position
630
+ position_ids: Optional[torch.Tensor] = None,
631
+ labels: Optional[torch.Tensor] = None,
632
+ num_steps: Optional[torch.Tensor] = None,
633
+ past_key_values: Optional[ValidCache] = None,
634
+ output_details: dict = {
635
+ "return_logits": True,
636
+ "return_latents": True,
637
+ "return_head": False,
638
+ "return_stats": False,
639
+ },
640
+ use_cache: bool = False,
641
+ cache_position: Optional[torch.Tensor] = None,
642
+ init_scale: float = 1.0,
643
+ **kwargs,
644
+ ) -> CausalLMOutputRecurrentLatents:
645
+ # Support multiple position formats:
646
+ if position_ids is None and cache_position is None:
647
+ position_ids = torch.arange(input_ids.shape[1], device=self.device).unsqueeze(0)
648
+ elif cache_position is not None:
649
+ position_ids = cache_position.unsqueeze(0)
650
+
651
+ if input_embeds is None:
652
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
653
+
654
+ if self.emb_scale != 1:
655
+ input_embeds = input_embeds * self.emb_scale # type: ignore
656
+
657
+ if use_cache and past_key_values is None:
658
+ past_key_values = HuginnDynamicCache()
659
+
660
+ prepared_attn_mask = None # self.compile_mask(input_ids, attention_mask, past_key_values)
661
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
662
+
663
+ freqs_cis = self.rotary_emb(input_embeds, position_ids)
664
+
665
+ # Non-recurrent prelude
666
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
667
+ block_idx += 1
668
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
669
+
670
+ # Main recurrence
671
+ x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
672
+ input_embeds, # type: ignore # mystery typing error
673
+ input_states,
674
+ freqs_cis,
675
+ block_idx,
676
+ prepared_attn_mask,
677
+ past_key_values,
678
+ num_steps,
679
+ init_scale,
680
+ )
681
+ latent_states = x.clone().detach()
682
+
683
+ # Coda layers
684
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
685
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
686
+ block_idx -= 1
687
+ x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
688
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
689
+
690
+ # Prediction head, assuming labels really are labels and not equal to input_ids
691
+ if labels is not None:
692
+ logits = self.lm_head(x).float()
693
+ loss = torch.nn.functional.cross_entropy(
694
+ logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
695
+ )
696
+ log_ppl = loss.clone().detach().exp()
697
+ else:
698
+ logits = self.lm_head(x)#.float()
699
+ loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
700
+
701
+ return CausalLMOutputRecurrentLatents(
702
+ loss=loss,
703
+ log_ppl=log_ppl,
704
+ logits=logits if output_details["return_logits"] else None,
705
+ past_key_values=past_key_values,
706
+ hidden_states=x if output_details["return_head"] else None,
707
+ latent_states=latent_states if output_details["return_latents"] else None,
708
+ stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
709
+ if output_details["return_stats"]
710
+ else None,
711
+ )
712
+
713
+ @torch._dynamo.disable(recursive=False) # type: ignore
714
+ def iterate_forward(
715
+ self,
716
+ input_embeds: torch.Tensor,
717
+ input_states: torch.Tensor,
718
+ freqs_cis,
719
+ block_idx: torch.Tensor,
720
+ mask: Optional[BlockMask],
721
+ past_key_values: Optional[ValidCache] = None,
722
+ num_steps: Optional[torch.Tensor] = None,
723
+ init_scale: float = 1.0,
724
+ ):
725
+ x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
726
+ if num_steps is None:
727
+ num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
728
+ elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
729
+ num_steps_no_grad, num_steps_with_grad = num_steps
730
+ else:
731
+ num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
732
+
733
+ with torch.no_grad():
734
+ # ultra annoying in ddp due to
735
+ # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
736
+ # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
737
+ # and all parameters are always used
738
+ for no_grad_step in range(num_steps_no_grad):
739
+ xk = x
740
+ x, block_idx = self.core_block_forward(
741
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
742
+ )
743
+
744
+ for grad_step in range(num_steps_with_grad):
745
+ xk = x
746
+ x, block_idx = self.core_block_forward(
747
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
748
+ )
749
+ return x, num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
750
+
751
+ def core_block_forward(
752
+ self,
753
+ x,
754
+ input_embeds,
755
+ freqs_cis,
756
+ mask: Optional[BlockMask],
757
+ past_key_values,
758
+ block_idx: torch.Tensor,
759
+ current_step: int | Tensor,
760
+ ):
761
+ x = self._maybe_inject_noise(x, current_step)
762
+ x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
763
+ for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
764
+ block_idx += 1
765
+ x = block(x, freqs_cis, block_idx, mask, past_key_values)
766
+ return x, block_idx
767
+
768
+ @torch.no_grad()
769
+ def iterate_one_step(
770
+ self,
771
+ input_embeds,
772
+ input_states,
773
+ position_ids: Optional[torch.Tensor] = None,
774
+ cache_position: Optional[torch.Tensor] = None,
775
+ block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
776
+ attention_mask: Optional[BlockMask] = None,
777
+ past_key_values: Optional[ValidCache] = None,
778
+ current_step: int = 0,
779
+ ):
780
+ if position_ids is None and cache_position is None:
781
+ freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
782
+ elif position_ids is not None:
783
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
784
+ elif cache_position is not None:
785
+ freqs_cis = self.freqs_cis[:, cache_position]
786
+ x, block_idx = self.core_block_forward(
787
+ input_states,
788
+ input_embeds,
789
+ freqs_cis,
790
+ attention_mask,
791
+ past_key_values,
792
+ block_idx,
793
+ current_step=current_step,
794
+ )
795
+ return x, block_idx, current_step + 1
796
+
797
+ def predict_from_latents(
798
+ self,
799
+ latents,
800
+ attention_mask: Optional[BlockMask] = None,
801
+ position_ids: Optional[torch.Tensor] = None,
802
+ cache_position: Optional[torch.Tensor] = None,
803
+ past_key_values: Optional[ValidCache] = None,
804
+ ):
805
+ if position_ids is None and cache_position is None:
806
+ freqs_cis = self.freqs_cis[:, : latents.shape[1]]
807
+ elif position_ids is not None:
808
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
809
+ elif cache_position is not None:
810
+ freqs_cis = self.freqs_cis[:, cache_position]
811
+ x = self.transformer.ln_f(latents) # type: ignore # types broken in 2.6+
812
+ # Coda layers
813
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
814
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
815
+ block_idx -= 1
816
+ x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
817
+ x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
818
+
819
+ logits = self.lm_head(x).float()
820
+
821
+ return CausalLMOutputRecurrentLatents(
822
+ loss=torch.as_tensor(0.0),
823
+ log_ppl=torch.as_tensor(0.0),
824
+ logits=logits,
825
+ past_key_values=past_key_values,
826
+ latent_states=x,
827
+ )
828
+
829
+ def embed_inputs(
830
+ self,
831
+ input_ids: torch.Tensor,
832
+ attention_mask: Optional[torch.Tensor] = None,
833
+ position_ids: Optional[torch.Tensor] = None,
834
+ past_key_values: Optional[ValidCache] = None,
835
+ use_cache: bool = False,
836
+ cache_position: Optional[torch.Tensor] = None,
837
+ **kwargs,
838
+ ) -> tuple[torch.Tensor, torch.Tensor]:
839
+ # Support multiple position formats:
840
+ if position_ids is None and cache_position is None:
841
+ freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
842
+ elif position_ids is not None:
843
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
844
+ elif cache_position is not None:
845
+ freqs_cis = self.freqs_cis[:, cache_position]
846
+
847
+ input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
848
+ prepared_attn_mask = self.compile_mask(input_ids, attention_mask)
849
+
850
+ if self.emb_scale != 1:
851
+ input_embeds = input_embeds * self.emb_scale # type: ignore
852
+
853
+ if use_cache and past_key_values is None:
854
+ past_key_values = HuginnDynamicCache()
855
+
856
+ block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
857
+ # Non-recurrent prelude
858
+ for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
859
+ block_idx += 1
860
+ input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
861
+ return input_embeds, block_idx
862
+
863
+ @torch._dynamo.disable(recursive=False) # type: ignore
864
+ def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
865
+ """Outputs are long tensors so that they can be passed through compiled functions"""
866
+ t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
867
+ s = self.config.mean_backprop_depth
868
+ if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
869
+ # these values are only the mean TFLOPs of the randomized sampler
870
+ # Note that this clause also breaks the contract, and returns ints in meta tensor mode
871
+ return t, s # type: ignore
872
+ if self.training:
873
+ sigma = 0.5
874
+ mu = math.log(t + s) - (sigma**2 / 2)
875
+ rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
876
+ p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
877
+ n = torch.clamp(p - s, min=0)
878
+ k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
879
+ else:
880
+ n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
881
+
882
+ return n.to(dtype=torch.long), k.to(dtype=torch.long)
883
+
884
+ def initialize_state(self, input_embeds, scale: float = 1.0):
885
+ x = torch.randn_like(input_embeds)
886
+ std = self.config.init_values["std"] * scale
887
+ if std > 0:
888
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
889
+ if self.emb_scale != 1:
890
+ x = x * self.emb_scale
891
+ else:
892
+ x.zero_()
893
+ return x
894
+
895
+ def _maybe_inject_noise(self, x, current_step, renorm=False):
896
+ if self.config.test_time_noise > 0:
897
+ n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
898
+ if self.config.test_time_noise_type == "geom":
899
+ step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
900
+ x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
901
+ elif self.config.test_time_noise_type == "sqrt":
902
+ step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
903
+ x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
904
+ elif self.config.test_time_noise_type == "line":
905
+ noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
906
+ x = x * (1 - noise) + torch.randn_like(x) * noise
907
+ elif self.config.test_time_noise_type == "chi":
908
+ noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
909
+ x = x * (1 - noise) + torch.randn_like(x) * noise
910
+ elif self.config.test_time_noise_type == "fixed":
911
+ x = x * (1 - n) + torch.randn_like(x) * n
912
+ else:
913
+ raise ValueError()
914
+
915
+ if renorm:
916
+ x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
917
+ return x
918
+
919
+ def prepare_inputs_for_generation(
920
+ self,
921
+ input_ids: torch.Tensor,
922
+ past_key_values: Optional[Cache] = None,
923
+ attention_mask: Optional[torch.Tensor] = None,
924
+ inputs_embeds: Optional[torch.FloatTensor] = None,
925
+ cache_position: Optional[torch.Tensor] = None,
926
+ cache_lookup_strategy: str = "full",
927
+ **kwargs,
928
+ ):
929
+ model_inputs = {}
930
+ model_inputs["cache_position"] = cache_position
931
+ current_input_length = input_ids.shape[1]
932
+
933
+ if past_key_values is not None:
934
+ if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
935
+ assert past_key_values.get_seq_length() == 0 # only replace empty caches
936
+ # Need to use custom cache, detect and replace HF cache if generate injects it
937
+ if isinstance(past_key_values, StaticCache):
938
+ past_key_values = HuginnStaticCache(
939
+ max_length=getattr(self.generation_config, "max_length", self.config.block_size),
940
+ max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
941
+ num_heads=self.config.num_key_value_heads,
942
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
943
+ dtype=torch.bfloat16,
944
+ device=input_ids.device,
945
+ lookup_strategy=cache_lookup_strategy,
946
+ )
947
+ else:
948
+ past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
949
+ model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
950
+ input_ids = input_ids[:, cache_position] # type: ignore
951
+
952
+ model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
953
+ if cache_position is None:
954
+ position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
955
+ model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
956
+ memory_format=torch.contiguous_format
957
+ ) # some form of position_ids is a critical argument for the model to correctly apply rope!
958
+
959
+ # forward all other entries
960
+ for key, value in kwargs.items():
961
+ if key not in model_inputs:
962
+ model_inputs[key] = value
963
+ return model_inputs
964
+
965
+ @torch.no_grad()
966
+ def generate(self, *args, **kwargs):
967
+ """Dispatcher - use HF generate in all normal cases."""
968
+ self.generation_config = args[1] if len(args) > 1 else self.generation_config
969
+ if any(k in kwargs for k in ("criterion", "exit_threshold")):
970
+ # print("Dispatching to custom generate_adaptive function call")
971
+ return self.generate_with_adaptive_compute(*args, **kwargs)
972
+ elif "continuous_compute" in kwargs:
973
+ # print("Dispatching to custom generate_minimal function call")
974
+ return self.generate_minimal(*args, **kwargs)
975
+ else:
976
+ return super().generate(*args, **kwargs)
977
+
978
+ @torch.no_grad()
979
+ def _prep_generate_args(
980
+ self,
981
+ input_ids: torch.Tensor,
982
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
983
+ cache_lookup_strategy: str = "full",
984
+ model_kwargs: Optional[dict] = None,
985
+ ):
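+ # Shared setup for the custom generate paths: resolves max_new_tokens from max_new_tokens /
+ # max_length, and swaps in a Huginn cache (dynamic by default, static only when
+ # cache_implementation requests it).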
986
+ # Setup
+ model_kwargs = model_kwargs if model_kwargs is not None else {}
987
+ if generation_config is None:
988
+ generation_config: GenerationConfig = self.generation_config # type: ignore
989
+ if "max_new_tokens" in model_kwargs:
990
+ max_new_tokens = model_kwargs["max_new_tokens"]
991
+ if "max_length" in model_kwargs:
992
+ max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
993
+ else:
994
+ max_length = model_kwargs.get("max_length", generation_config.max_length)
995
+ max_new_tokens = max_length - input_ids.shape[1]
996
+
997
+ if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
998
+ model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
999
+ else:
1000
+ model_kwargs["past_key_values"] = HuginnStaticCache(
1001
+ max_length=max_length,
1002
+ max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
1003
+ num_heads=self.config.num_key_value_heads,
1004
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
1005
+ batch_size=input_ids.shape[0],
1006
+ dtype=torch.bfloat16,
1007
+ device=input_ids.device,
1008
+ lookup_strategy=cache_lookup_strategy,
1009
+ )
1010
+ model_kwargs["use_cache"] = True
1011
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1012
+ return model_kwargs, generation_config, max_new_tokens
1013
+
1014
+ @torch.no_grad()
1015
+ def generate_minimal(
1016
+ self,
1017
+ input_ids: torch.Tensor,
1018
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1019
+ tokenizer=None,
1020
+ streamer=None,
1021
+ continuous_compute=False, # warm-start state / continuous CoT
1022
+ init_scale: float = 1.0,
1023
+ cache_lookup_strategy: str = "full",
1024
+ **model_kwargs,
1025
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1026
+ """Minimal single-sequence generation. Template for more complicated generate tasks"""
1027
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1028
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1029
+ )
1030
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1031
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1032
+
1033
+ # Set up continuous compute if enabled
1034
+ if continuous_compute:
1035
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1036
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1037
+
1038
+ # Generate tokens
1039
+ batch_size = input_ids.shape[0]
1040
+ for _ in range(max_new_tokens):
1041
+ # Forward pass
1042
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1043
+ outputs = self(**model_inputs, init_scale=init_scale)
1044
+
1045
+ # Get next token
1046
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
1047
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1048
+
1049
+ # Append token to sequence
1050
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1051
+
1052
+ if streamer:
1053
+ streamer.put(next_token.cpu())
1054
+
1055
+ # Update model kwargs
1056
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1057
+ if continuous_compute:
1058
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1059
+
1060
+ if stop_tokens is not None:
1061
+ for i in range(batch_size):
1062
+ if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
1063
+ unfinished_sequences[i] = 0
1064
+ if "stopping_criteria" in model_kwargs:
1065
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1066
+ if unfinished_sequences.max() == 0:
1067
+ break
1068
+
1069
+ if streamer:
1070
+ streamer.end()
1071
+
1072
+ if generation_config.return_dict_in_generate:
1073
+ return GenerateDecoderOnlyOutput(
1074
+ sequences=input_ids, # type: ignore
1075
+ scores=None,
1076
+ logits=None,
1077
+ attentions=None,
1078
+ hidden_states=None,
1079
+ past_key_values=model_kwargs.get("past_key_values"),
1080
+ )
1081
+ return input_ids
1082
+
1083
+ @torch.no_grad()
1084
+ def generate_with_adaptive_compute(
1085
+ self,
1086
+ input_ids: torch.Tensor,
1087
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1088
+ tokenizer=None,
1089
+ streamer=None,
1090
+ continuous_compute=False, # warm-start state / continuous CoT
1091
+ criterion="none", # off by default, turn on by choosing an exit criterion
1092
+ exit_threshold: Union[str, float, int] = "auto",
1093
+ init_scale: float = 1.0,
1094
+ cache_lookup_strategy: str = "full",
1095
+ **model_kwargs,
1096
+ ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1097
+ """
1098
+ Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1099
+ For batches, on each token, we iterate until the entire batch finishes.
1100
+ """
1101
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1102
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1103
+ )
1104
+ max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
1105
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1106
+ logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
1107
+ batch_size = input_ids.shape[0]
1108
+ compute_steps = []
1109
+
1110
+ # Set up continuous compute if enabled
1111
+ if continuous_compute:
1112
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1113
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1114
+
1115
+ # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1116
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1117
+
1118
+ # Generate tokens
1119
+ for token_step in range(max_new_tokens):
1120
+ # Adaptive compute forward
1121
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1122
+ aux_inputs = {
1123
+ k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
1124
+ }
1125
+ embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
1126
+ current_latents = (
1127
+ self.initialize_state(embedded_inputs, scale=init_scale)
1128
+ if not continuous_compute
1129
+ else model_kwargs["input_states"]
1130
+ )
1131
+
1132
+ # Initialize criterion tracking for each sequence in batch
1133
+ exit_values_per_seq = [[] for _ in range(batch_size)]
1134
+ compute_steps_per_seq = [0] * batch_size
1135
+ exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
1136
+
1137
+ # Set up criterions based on selected strategy
1138
+ if criterion == "entropy-diff":
1139
+ entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
1140
+ exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1141
+ elif criterion == "latent-diff":
1142
+ exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1143
+ elif "kl" in criterion:
1144
+ V = self.config.padded_vocab_size
1145
+ log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
1146
+ if criterion == "minp-kl":
1147
+ exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
1148
+ else:
1149
+ exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
1150
+ elif criterion == "argmax-stability":
1151
+ stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
1152
+ current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
1153
+ exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1154
+ elif criterion == "none":
1155
+ exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
1156
+ else:
1157
+ raise ValueError("Invalid adaptive compute strategy.")
1158
+
1159
+ next_token_logits = None
1160
+
1161
+ # Iterate through compute steps
1162
+ for compute_step in range(max_steps):
1163
+ prev_latents = current_latents.clone()
1164
+ current_latents, block_idx, _ = self.iterate_one_step(
1165
+ embedded_inputs,
1166
+ current_latents,
1167
+ block_idx=block_idx,
1168
+ **aux_inputs,
1169
+ current_step=compute_step,
1170
+ )
1171
+
1172
+ if token_step > 0:  # do not exit during the prefill token's forward pass
1173
+ # Check exit condition for each sequence in batch
1174
+ if criterion == "entropy-diff":
1175
+ prev_entropy = entropy
1176
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1177
+ logits: torch.Tensor = outputs.logits # type: ignore
1178
+ probs = F.softmax(logits[:, -1, :], dim=-1)
1179
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1180
+ exit_values = (entropy - prev_entropy).abs()
1181
+ elif criterion == "latent-diff":
1182
+ norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
1183
+ exit_values = norm_diff.mean(dim=-1)
1184
+ elif "kl" in criterion:
1185
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1186
+ logits: torch.Tensor = outputs.logits # type: ignore
1187
+ prev_log_probs = log_probs
1188
+ if criterion == "minp-kl":
1189
+ probs = F.softmax(logits[:, -1, :].float(), dim=-1)
1190
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1191
+ probs_mask = probs < (0.1 * max_probs)
1192
+ masked_probs = probs.clone()
1193
+ masked_probs[probs_mask] = 1 / V
1194
+ probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1195
+ log_probs = probs.log()
1196
+ else:
1197
+ log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
1198
+ exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1199
+ elif criterion == "argmax-stability":
1200
+ prev_argmax = current_argmax
1201
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1202
+ logits: torch.Tensor = outputs.logits # type: ignore
1203
+ current_argmax = logits[:, -1, :].argmax(dim=-1)
1204
+ stable_for_n_steps = torch.where(
1205
+ current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
1206
+ )
1207
+ exit_values = stable_for_n_steps
1208
+ elif criterion == "none":
1209
+ exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
1210
+
1211
+ # Record values and check exits for each sequence
1212
+ for i in range(batch_size):
1213
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1214
+ exit_values_per_seq[i].append(exit_values[i].item())
1215
+
1216
+ # Check for new exits, respecting unfinished_sequences
1217
+ new_exits = (
1218
+ exit_values < exit_threshold
1219
+ if criterion != "argmax-stability"
1220
+ else exit_values >= exit_threshold
1221
+ )
1222
+ new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1223
+
1224
+ if new_exits.any():
1225
+ exit_reached = exit_reached | new_exits
1226
+ if criterion == "latent-diff":
1227
+ # Normally we don't compute the output for latent-diff, but when there is an exit,
1228
+ # we need to compute and save the output
1229
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1230
+ logits: torch.Tensor = outputs.logits # type: ignore
1231
+ if next_token_logits is None:
1232
+ next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1233
+ else:
1234
+ for i in range(batch_size):
1235
+ if new_exits[i]:
1236
+ next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
1237
+ for i in range(batch_size):
1238
+ if new_exits[i]:
1239
+ compute_steps_per_seq[i] = compute_step + 1
1240
+
1241
+ # If all sequences have exited or finished, break early
1242
+ if (exit_reached | ~unfinished_sequences.bool()).all():
1243
+ break
1244
+ # for-else clause: runs only when the compute loop exhausts max_steps without an early break
1245
+ else:
1246
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1247
+
1248
+ # For sequences that didn't exit early, use the final logits
1249
+ if next_token_logits is None:
1250
+ next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1251
+ else:
1252
+ for i in range(batch_size):
1253
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1254
+ next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1255
+ compute_steps_per_seq[i] = max_steps
1256
+
1257
+ # Save latent states for continuous compute if enabled
1258
+ if continuous_compute:
1259
+ model_kwargs["input_states"] = current_latents[:, -1:, :]
1260
+
1261
+ # Record compute steps for this token generation
1262
+ compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
1263
+
1264
+ # Sample or select next token based on generation config
1265
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1266
+
1267
+ # Append token to sequence
1268
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1269
+
1270
+ if streamer:
1271
+ streamer.put(next_token.cpu())
1272
+
1273
+ # Update model kwargs for next iteration
1274
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1275
+
1276
+ # Check for stop tokens and update unfinished sequences
1277
+ for i in range(batch_size):
1278
+ if (
1279
+ unfinished_sequences[i].bool()
1280
+ and stop_tokens is not None
1281
+ and next_token[i, 0].item() in stop_tokens
1282
+ ):
1283
+ unfinished_sequences[i] = 0
1284
+
1285
+ # Apply any custom stopping criteria
1286
+ if "stopping_criteria" in model_kwargs:
1287
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1288
+
1289
+ # Break if all sequences are finished
1290
+ if unfinished_sequences.max() == 0:
1291
+ break
1292
+
1293
+ if streamer:
1294
+ streamer.end()
1295
+
1296
+ if generation_config.return_dict_in_generate:
1297
+ return GenerateDecoderOnlyOutput(
1298
+ sequences=input_ids, # type: ignore
1299
+ scores=compute_steps, # type: ignore
1300
+ logits=None,
1301
+ attentions=None,
1302
+ hidden_states=None,
1303
+ past_key_values=model_kwargs.get("past_key_values"),
1304
+ )
1305
+ return input_ids
1306
+
1307
+ def _get_stops(self, generation_config, tokenizer, model_kwargs):
1308
+ stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1309
+ if generation_config.eos_token_id is not None:
1310
+ stop_tokens.add(generation_config.eos_token_id)
1311
+ if "stopping_criteria" in model_kwargs and tokenizer is None:
1312
+ tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1313
+ if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1314
+ for s in generation_config.stop_strings:
1315
+ token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1316
+ stop_tokens.add(token_id)
1317
+ return torch.tensor(list(stop_tokens))
1318
+
1319
+ def _sample_next_token(self, next_token_logits, generation_config):
1320
+ """Helper function to sample the next token."""
1321
+ if generation_config.do_sample:
1322
+ if generation_config.temperature:
1323
+ next_token_logits = next_token_logits.float() / generation_config.temperature
1324
+
1325
+ probs = F.softmax(next_token_logits, dim=-1)
1326
+
1327
+ # Apply top_k
1328
+ if generation_config.top_k:
1329
+ top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1330
+ min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1331
+ probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1332
+
1333
+ # Apply top_p (nucleus sampling)
1334
+ if generation_config.top_p:
1335
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1336
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1337
+
1338
+ # Create mask for probs to keep
1339
+ remove_indices = cumulative_probs > generation_config.top_p
1340
+ remove_indices[:, 0] = False # Keep at least the top probability
1341
+
1342
+ # Convert sorted indices mask back to original indices mask
1343
+ mask = torch.zeros_like(probs, dtype=torch.bool)
1344
+ for i in range(probs.shape[0]):
1345
+ mask[i, sorted_indices[i, remove_indices[i]]] = True
1346
+
1347
+ probs = torch.where(mask, torch.zeros_like(probs), probs)
1348
+
1349
+ # Apply min_p
1350
+ if generation_config.min_p:
1351
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1352
+ min_p_threshold = generation_config.min_p * max_probs
1353
+ probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1354
+
1355
+ # Renormalize probabilities
1356
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1357
+
1358
+ # Sample from the distribution
1359
+ return torch.multinomial(probs, num_samples=1)
1360
+ else:
1361
+ return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1362
+
1363
+ @torch.no_grad()
1364
+ def generate_speculative(
1365
+ self,
1366
+ input_ids: torch.Tensor,
1367
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1368
+ tokenizer=None,
1369
+ streamer=None,
1370
+ continuous_compute=False, # warm-start state / continuous CoT
1371
+ init_scale: float = 1.0,
1372
+ cache_lookup_strategy: str = "full",
1373
+ draft_steps=32,
1374
+ lookahead_for_draft=8,
1375
+ verification_threshold=1,
1376
+ num_steps: int = 32, # intercept deliberately
1377
+ **model_kwargs,
1378
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1379
+ """Batched speculative decoding with per-sequence acceptance."""
1380
+ assert lookahead_for_draft > 0
1381
+ pad_id = 65509
1382
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1383
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1384
+ )
1385
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1386
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1387
+
1388
+ # Set up continuous compute if enabled
1389
+ if continuous_compute:
1390
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1391
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1392
+
1393
+ tokens_generated = 0
1394
+ # Prefill cache with full num_steps
1395
+ if model_kwargs["past_key_values"].get_seq_length() == 0:
1396
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1397
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1398
+ next_token = self._sample_next_token(
1399
+ outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
1400
+ )
1401
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1402
+ tokens_generated += 1
1403
+ if streamer:
1404
+ streamer.put(next_token.cpu())
1405
+ model_kwargs["cache_position"] = torch.as_tensor(
1406
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1407
+ )
1408
+ if continuous_compute:
1409
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1410
+
1411
+ # Generate tokens
1412
+ batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
1413
+ accepted_tokens = []
1414
+
1415
+ while tokens_generated < max_new_tokens:
1416
+ ### Run the next draft ####
1417
+ drafted_inputs = input_ids.clone()
1418
+ current_len = input_ids.shape[1]
1419
+
1420
+ for _ in range(lookahead_for_draft):
1421
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1422
+ outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
1423
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
1424
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1425
+ drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
1426
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
1427
+ if continuous_compute:
1428
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1429
+
1430
+ model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)
1431
+
1432
+ ## Verify drafted tokens ###
1433
+ model_kwargs["cache_position"] = torch.arange(
1434
+ current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
1435
+ )
1436
+ model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
1437
+ outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
1438
+ verified_next_token_preds = outputs.logits.argmax(dim=-1)
1439
+
1440
+ if verification_threshold >= 1:
1441
+ mismatched_tokens = (
1442
+ verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
1443
+ )
1444
+ not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
1445
+ else:
1446
+ verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
1447
+ verified_probs = F.softmax(verified_logits, dim=-1)
1448
+ drafted_token_probs = torch.gather(
1449
+ verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
1450
+ ).squeeze(-1)
1451
+ max_probs = verified_probs.max(dim=-1)[0]
1452
+ verification_passed = drafted_token_probs >= verification_threshold * max_probs
1453
+ not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)
1454
+
1455
+ # Per-sequence acceptance handling
1456
+ acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)
1457
+
1458
+ # Build next_tokens for each sequence
1459
+ next_tokens_batch = []
1460
+ for i in range(batch_size):
1461
+ seq_acceptance = acceptance_lengths[i].item()
1462
+ if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
1463
+ # Accept up to mismatch + sample final token
1464
+ accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1465
+ final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
1466
+ final_token = self._sample_next_token(final_token_logits, generation_config)
1467
+ seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
1468
+ else:
1469
+ # Accept all drafted tokens
1470
+ seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
1471
+ next_tokens_batch.append(seq_tokens)
1472
+
1473
+ # Clean up KV cache - only if any sequence had mismatches
1474
+ if not_all_matched.any():
1475
+ min_first_mismatch = first_mismatch.min().item()
1476
+ model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)
1477
+
1478
+ # Concatenate accepted tokens to input_ids
1479
+ batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
1480
+ max_len = max(batch_accepted_counts)
1481
+ padded_tokens = [
1482
+ torch.cat(
1483
+ [
1484
+ tokens,
1485
+ pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
1486
+ ],
1487
+ dim=-1,
1488
+ )
1489
+ if tokens.shape[1] < max_len
1490
+ else tokens
1491
+ for tokens in next_tokens_batch
1492
+ ]
1493
+ next_tokens = torch.cat(padded_tokens, dim=0)
1494
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
1495
+
1496
+ accepted_tokens.append(batch_accepted_counts)
1497
+ tokens_generated += max(batch_accepted_counts)
1498
+
1499
+ if streamer:
1500
+ streamer.put(next_tokens_batch[0].cpu())
1501
+
1502
+ model_kwargs["cache_position"] = torch.as_tensor(
1503
+ [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
1504
+ )
1505
+ if continuous_compute:
1506
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1507
+
1508
+ # Check stopping conditions
1509
+ if stop_tokens is not None:
1510
+ for i in range(batch_size):
1511
+ if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
1512
+ unfinished_sequences[i] = 0
1513
+ if "stopping_criteria" in model_kwargs:
1514
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1515
+ if unfinished_sequences.max() == 0:
1516
+ break
1517
+
1518
+ if streamer:
1519
+ streamer.end()
1520
+
1521
+ # Cut off extraneous parts of the sequence per batch element
1522
+ if stop_tokens is not None:
1523
+ for i in range(batch_size):
1524
+ stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
1525
+ if len(stop_positions) > 0:
1526
+ input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
1527
+ # Trim tensor to remove columns that are pad_id across all sequences
1528
+ non_pad_mask = input_ids != pad_id
1529
+ last_real_token = non_pad_mask.any(dim=0).nonzero()
1530
+ if len(last_real_token) > 0:
1531
+ input_ids = input_ids[:, : last_real_token[-1].item() + 1]
1532
+
1533
+ if generation_config.return_dict_in_generate:
1534
+ return GenerateDecoderOnlyOutput(
1535
+ sequences=input_ids, # type: ignore
1536
+ scores=accepted_tokens, # type: ignore
1537
+ logits=None,
1538
+ attentions=None,
1539
+ hidden_states=None,
1540
+ past_key_values=model_kwargs.get("past_key_values"),
1541
+ )
1542
+ return input_ids
1543
+
1544
+ def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
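+ # Collects per-token diagnostics: entropy of the output distribution, the absolute and relative
+ # norm of (x - latent_states), and the number of no-grad / with-grad recurrence steps taken.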
1545
+ probs = torch.softmax(logits.float(), dim=-1)
1546
+ prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1547
+ residual_diff = (x - latent_states).norm(dim=-1)
1548
+ rel_residual = residual_diff / latent_states.norm(dim=-1)
1549
+ stats = {
1550
+ "entropy": prob_entropy,
1551
+ "residual_diff": residual_diff,
1552
+ "rel_residual": rel_residual,
1553
+ "num_steps_no_grad": num_steps_no_grad,
1554
+ "num_steps_with_grad": num_steps_with_grad,
1555
+ }
1556
+ return stats
1557
+
1558
+
1559
+ #################################### HF registration ############################################################
1560
+
1561
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
1562
+
1563
+ # Register the custom classes for auto_class loading (used when loading with trust_remote_code)
1564
+ RavenConfig.register_for_auto_class()
1565
+
1566
+ RavenForCausalLM.register_for_auto_class("AutoModel")
1567
+ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
1568
+
1569
+ # Older-style explicit registration, kept so AutoConfig/AutoModel/AutoModelForCausalLM also resolve in-process
1570
+ AutoConfig.register("huginn_raven", RavenConfig)
1571
+ AutoModel.register(RavenConfig, RavenForCausalLM)
1572
+ AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
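+ # Loading sketch (repo id is a placeholder; trust_remote_code is needed because the classes are
+ # provided via the auto_class registration above):
+ #   from transformers import AutoModelForCausalLM
+ #   model = AutoModelForCausalLM.from_pretrained("<repo_id>", trust_remote_code=True)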