mariamkhmahran committed
Commit 2e5bb45 · 1 Parent(s): fb1c644

upload model

components/activation.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+ class GELU(nn.Module):
+     '''
+     GELU (Gaussian Error Linear Unit) activation function, using the tanh approximation.
+     '''
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return 0.5 * x * (1 + torch.tanh(
+             torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+             (x + 0.044715 * torch.pow(x, 3))
+         ))
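As a quick sanity check (a sketch, not part of the commit; it assumes the repo root is on the Python path so `components` imports as a package), this custom GELU should agree with PyTorch's built-in tanh approximation up to floating-point tolerance:

```python
import torch
import torch.nn as nn

from components.activation import GELU

x = torch.randn(4, 8)
custom = GELU()(x)
builtin = nn.GELU(approximate="tanh")(x)  # same tanh-based approximation

print(torch.allclose(custom, builtin, atol=1e-5))  # True
```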
components/feed_forward.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import torch.nn as nn
+
+ from .activation import GELU  # package-relative import
+
+ class FeedForward(nn.Module):
+     '''
+     Position-wise feed-forward network with GELU activation. Within a transformer block:
+     - Multi-Head Self-Attention → Captures relationships between tokens.
+     - Feedforward Neural Network (FFN) → Processes each token independently after attention.
+     '''
+
+     def __init__(self, cfg):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),  # expand to 4x the embedding dim
+             GELU(),
+             nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),  # project back down
+         )
+
+     def forward(self, x):
+         return self.layers(x)
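For illustration (a minimal sketch; the cfg values here are assumed, and only `emb_dim` is read by this module), the FFN expands to 4× the embedding dimension and projects back, so the input shape is preserved:

```python
import torch
from components.feed_forward import FeedForward

cfg = {"emb_dim": 896}       # only the key FeedForward actually uses
ffn = FeedForward(cfg)

x = torch.randn(2, 16, 896)  # [batch_size, seq_len, emb_dim]
out = ffn(x)
print(out.shape)             # torch.Size([2, 16, 896])
```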
components/layer_norm.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import torch.nn as nn
+
+ class LayerNorm(nn.Module):
+     def __init__(self, emb_dim):
+         super().__init__()
+         self.eps = 1e-6                                  # small constant to avoid division by zero
+         self.scale = nn.Parameter(torch.ones(emb_dim))   # trainable scale parameter
+         self.shift = nn.Parameter(torch.zeros(emb_dim))  # trainable shift parameter
+
+     def forward(self, x):
+         '''
+         Layer normalization is applied along the last dimension of the input tensor x,
+         which is the embedding dimension (dim=-1). Normalizing over the embedding dimension
+         means each token is normalized independently, so one token's statistics do not
+         influence another's.
+
+         For transformer models, the input typically has shape:
+             [batch_size, seq_len, emb_dim]
+         '''
+
+         mean = x.mean(dim=-1, keepdim=True)
+         var = x.var(dim=-1, keepdim=True, unbiased=False)  # unbiased=False: divide by N, i.e. no Bessel correction
+         norm_x = (x - mean) / torch.sqrt(var + self.eps)
+
+         return self.scale * norm_x + self.shift
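A minimal check of the behavior (again a sketch with assumed values): since `scale` and `shift` are initialized to ones and zeros, the output of a fresh LayerNorm should have roughly zero mean and unit variance per token:

```python
import torch
from components.layer_norm import LayerNorm

ln = LayerNorm(emb_dim=896)
x = torch.randn(2, 16, 896) * 3 + 5            # arbitrary mean and variance

out = ln(x)
print(out.mean(dim=-1).abs().max())            # ~0 for every token
print(out.var(dim=-1, unbiased=False).mean())  # ~1 on average
```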
components/multi_head_attention.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+         super().__init__()
+         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+         self.d_out = d_out
+         self.num_heads = num_heads
+         self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+         self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+         self.dropout = nn.Dropout(dropout)
+         self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+     def forward(self, x):
+         b, num_tokens, _ = x.shape
+
+         keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+         queries = self.W_query(x)
+         values = self.W_value(x)
+
+         # We implicitly split the matrix by adding a `num_heads` dimension
+         # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+         values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+         # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+         keys = keys.transpose(1, 2)
+         queries = queries.transpose(1, 2)
+         values = values.transpose(1, 2)
+
+         # Compute scaled dot-product attention (aka self-attention) with a causal mask
+         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+         # Original mask truncated to the number of tokens and converted to boolean
+         mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+         # Use the mask to fill attention scores
+         attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Shape: (b, num_tokens, num_heads, head_dim)
+         context_vec = (attn_weights @ values).transpose(1, 2)
+
+         # Combine heads, where self.d_out = self.num_heads * self.head_dim
+         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+         context_vec = self.out_proj(context_vec)  # output projection
+
+         return context_vec, attn_weights
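A usage sketch with the dimensions from this repo's config (values assumed from config.json): the module returns both the combined context vectors and the per-head attention weights, and the causal mask keeps all probability mass at or before each position:

```python
import torch
from components.multi_head_attention import MultiHeadAttention

mha = MultiHeadAttention(
    d_in=896, d_out=896, context_length=256, dropout=0.0, num_heads=14)
mha.eval()  # deterministic (dropout off)

x = torch.randn(2, 10, 896)  # [batch, tokens, emb_dim]
context, attn_weights = mha(x)
print(context.shape)         # torch.Size([2, 10, 896])
print(attn_weights.shape)    # torch.Size([2, 14, 10, 10])

# Causal mask: entries above the diagonal receive zero weight.
print(attn_weights[0, 0].triu(diagonal=1).abs().max())  # tensor(0.)
```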
components/transformer_block.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+
+ from .feed_forward import FeedForward            # package-relative imports
+ from .multi_head_attention import MultiHeadAttention
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.att = MultiHeadAttention(
+             d_in=cfg["emb_dim"],
+             d_out=cfg["emb_dim"],
+             context_length=cfg["context_length"],
+             num_heads=cfg["n_heads"],
+             dropout=cfg["drop_rate"],
+             qkv_bias=cfg["qkv_bias"])
+         self.ff = FeedForward(cfg)
+         self.norm1 = nn.LayerNorm(cfg["emb_dim"])
+         self.norm2 = nn.LayerNorm(cfg["emb_dim"])
+         self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+     def forward(self, x):
+         '''
+         The transformer block consists of two main components:
+         - Multi-Head Self-Attention → Captures relationships between tokens.
+         - Feedforward Neural Network (FFN) → Processes each token independently after attention.
+
+         Both sub-layers use a pre-norm residual (skip) connection: the input is first
+         normalized with LayerNorm, passed through the sub-layer, and the result is added
+         back to the original, unnormalized input.
+
+         Dropout is applied to each sub-layer's output before it is added to the skip
+         connection. This helps to prevent overfitting and improves generalization.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape [batch_size, seq_len, emb_dim].
+
+         Returns:
+             Tuple of the output tensor of shape [batch_size, seq_len, emb_dim] and the
+             attention weights from this block's attention layer.
+         '''
+
+         shortcut = x
+         x = self.norm1(x)
+         x, _attn_weights = self.att(x)
+         x = self.drop_shortcut(x)
+         x = x + shortcut
+
+         shortcut = x
+         x = self.norm2(x)
+         x = self.ff(x)
+         x = self.drop_shortcut(x)
+         x = x + shortcut
+
+         return x, _attn_weights
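A sketch of running a single block with the same keys config.json defines (the values below are copied from that file; the input length is arbitrary):

```python
import torch
from components.transformer_block import TransformerBlock

cfg = {
    "emb_dim": 896, "context_length": 256, "n_heads": 14,
    "drop_rate": 0.2, "qkv_bias": True,
}
block = TransformerBlock(cfg)
block.eval()  # turn off dropout

x = torch.randn(1, 32, 896)
out, attn = block(x)
print(out.shape, attn.shape)  # torch.Size([1, 32, 896]) torch.Size([1, 14, 32, 32])
```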
config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "model_type": "gpt_and_prejudice",
+     "vocab_size": 50257,
+     "context_length": 256,
+     "emb_dim": 896,
+     "n_heads": 14,
+     "n_layers": 8,
+     "drop_rate": 0.2,
+     "qkv_bias": true
+ }
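The config maps directly onto the `cfg` dictionary the modules expect, so it can be loaded as-is (path assumed relative to the repo root):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

print(cfg["emb_dim"], cfg["n_heads"], cfg["n_layers"])  # 896 14 8
```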
gpt_model.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ import torch.nn as nn
+ from typing import Dict, Any, Optional
+
+ from components.transformer_block import TransformerBlock
+ from components.layer_norm import LayerNorm
+
+ class InterventionPlan:
+     """
+     Hook object consulted during forward() to optionally replace activations.
+     Override any of the methods below in your experiment code.
+     """
+     def maybe_replace_resid_pre(self, layer_idx: int, x: torch.Tensor) -> torch.Tensor:
+         return x
+     def maybe_replace_resid_post(self, layer_idx: int, x: torch.Tensor) -> torch.Tensor:
+         return x
+     # Optional: only works if your blocks expose per-head z or MLP outputs.
+     def maybe_replace_head_z(self, layer_idx: int, z: torch.Tensor) -> torch.Tensor:
+         return z
+     def maybe_replace_mlp_out(self, layer_idx: int, h: torch.Tensor) -> torch.Tensor:
+         return h
+
+ class GPTModel(nn.Module):
+     def __init__(self, cfg: Dict[str, Any]):
+         super().__init__()
+         self.cfg = cfg
+         self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+         self.drop_emb = nn.Dropout(cfg["drop_rate"])
+         self.trf_blocks = nn.Sequential(
+             *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
+         )
+         self.final_norm = LayerNorm(cfg["emb_dim"])
+         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+     @torch.no_grad()
+     def cache_forward(self, in_idx: torch.Tensor):
+         """Run a forward pass with caching enabled (no interventions)."""
+         return self.forward(in_idx, enable_cache=True)
+
+     def forward(
+         self,
+         in_idx: torch.Tensor,
+         enable_cache: bool = False,
+         intervention_plan: Optional[InterventionPlan] = None,
+         output_hidden_states: bool = False,
+         output_attentions_weights: bool = False,
+         # Backward-compat args (ignored if a plan is provided)
+         intervene_layer: Optional[int] = None,
+         edited_hidden: Optional[torch.Tensor] = None,
+     ):
+         """
+         Mechanistic-interpretability-friendly forward pass.
+         Returns: logits, (optional) cache, (optional) hidden_states, (optional) attn_weights.
+         Cache keys: resid_pre[L], resid_post[L], attn_weights[L].
+         """
+         B, T = in_idx.shape
+         device = in_idx.device
+
+         tok_embeds = self.tok_emb(in_idx)                          # [B, T, d]
+         pos_embeds = self.pos_emb(torch.arange(T, device=device))  # [T, d]
+         x = self.drop_emb(tok_embeds + pos_embeds)
+
+         cache: Dict[str, Dict[int, torch.Tensor]] = {}
+         if enable_cache:
+             cache = {"resid_pre": {}, "resid_post": {}, "attn_weights": {}}
+
+         hidden_states = []
+         attention_weights_per_layer = []
+
+         # Fall back to the legacy single-layer intervention if no plan is provided
+         legacy_layer = intervene_layer if (intervention_plan is None) else None
+         legacy_edit = edited_hidden if (intervention_plan is None) else None
+
+         for L, block in enumerate(self.trf_blocks):
+             if enable_cache:
+                 cache["resid_pre"][L] = x.detach()
+
+             # Entry-point intervention
+             if intervention_plan is not None:
+                 x = intervention_plan.maybe_replace_resid_pre(L, x)
+             elif legacy_layer is not None and legacy_edit is not None and L == legacy_layer:
+                 x = legacy_edit  # Inject causal intervention
+
+             # Run block (assumed to return (x_out, attn_weights))
+             block_out = block(x)
+             if isinstance(block_out, tuple) and len(block_out) == 2:
+                 x, attn_w = block_out
+             else:
+                 x = block_out
+                 attn_w = None
+
+             if output_attentions_weights and attn_w is not None:
+                 attention_weights_per_layer.append(attn_w.detach())
+             if enable_cache and attn_w is not None:
+                 cache["attn_weights"][L] = attn_w.detach()
+
+             # Exit-point intervention (rarely used, but handy)
+             if intervention_plan is not None:
+                 x = intervention_plan.maybe_replace_resid_post(L, x)
+
+             if output_hidden_states:
+                 hidden_states.append(x.clone())
+
+             if enable_cache:
+                 cache["resid_post"][L] = x.detach()
+
+         x = self.final_norm(x)
+         logits = self.out_head(x)
+
+         outputs = (logits,)
+         if enable_cache:
+             outputs += (cache,)
+         if output_hidden_states:
+             outputs += (hidden_states,)
+         if output_attentions_weights:
+             outputs += (attention_weights_per_layer,)
+         return outputs if len(outputs) > 1 else outputs[0]
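A hedged end-to-end sketch (random weights and random token ids, purely to show the interfaces): build the model from config.json, run a cached forward pass, and apply a simple InterventionPlan that zero-ablates the residual stream entering one layer. The layer index is arbitrary here.

```python
import json
import torch

from gpt_model import GPTModel, InterventionPlan

with open("config.json") as f:
    cfg = json.load(f)

model = GPTModel(cfg)
model.eval()

tokens = torch.randint(0, cfg["vocab_size"], (1, 32))  # random ids for illustration

# Cached forward: logits plus resid_pre / resid_post / attn_weights per layer.
logits, cache = model.cache_forward(tokens)
print(logits.shape)                  # torch.Size([1, 32, 50257])
print(cache["resid_post"][0].shape)  # torch.Size([1, 32, 896])

# Example intervention: zero out the residual stream entering layer 3.
class ZeroResidPre(InterventionPlan):
    def maybe_replace_resid_pre(self, layer_idx, x):
        return torch.zeros_like(x) if layer_idx == 3 else x

with torch.no_grad():
    ablated_logits = model(tokens, intervention_plan=ZeroResidPre())
print(ablated_logits.shape)          # torch.Size([1, 32, 50257])
```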
model_896_14_8_256.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2ff54930487bdf5583980796eeee05854d761d6926c5a568f8a31b9eaddcc27
+ size 2011729853
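The weights file is stored via Git LFS. Assuming it holds a plain state_dict for GPTModel (an assumption; it could instead be wrapped in a larger checkpoint dict), loading would look roughly like:

```python
import json
import torch

from gpt_model import GPTModel

with open("config.json") as f:
    cfg = json.load(f)

model = GPTModel(cfg)

# map_location keeps this working on CPU-only machines.
state_dict = torch.load("model_896_14_8_256.pth", map_location="cpu")
model.load_state_dict(state_dict)  # adjust if the .pth wraps the weights in a checkpoint dict
model.eval()
```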