mariamkhmahran
committed on
Commit
·
2e5bb45
1 Parent(s):
fb1c644
upload model
Browse files
- components/activation.py +16 -0
- components/feed_forward.py +22 -0
- components/layer_norm.py +26 -0
- components/multi_head_attention.py +57 -0
- components/transformer_block.py +57 -0
- config.json +10 -0
- gpt_model.py +123 -0
- model_896_14_8_256.pth +3 -0
components/activation.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+class GELU(nn.Module):
+    '''
+    GELU (Gaussian Error Linear Unit) activation function (tanh approximation).
+    '''
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(
+            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+            (x + 0.044715 * torch.pow(x, 3))
+        ))
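For context, the formula above is the tanh approximation of GELU, so it should agree with PyTorch's built-in nn.GELU(approximate='tanh'). A minimal sketch of that check (standalone, assuming only PyTorch is installed):

import torch
import torch.nn as nn

# The hand-rolled tanh GELU above should match the built-in tanh approximation
# up to floating-point tolerance.
x = torch.randn(4, 8)
manual = 0.5 * x * (1 + torch.tanh(
    torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
))
builtin = nn.GELU(approximate="tanh")(x)
print(torch.allclose(manual, builtin, atol=1e-6))  # expected: True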
components/feed_forward.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+from components.activation import GELU
+
+class FeedForward(nn.Module):
+    '''
+    Feed-forward neural network with GELU activation.
+    Expands the embedding dimension by a factor of 4, applies GELU, and projects back to emb_dim,
+    processing each token position independently after the attention sub-layer.
+    '''
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            GELU(),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
components/layer_norm.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+class LayerNorm(nn.Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.eps = 1e-6  # small value to avoid division by zero
+        self.scale = nn.Parameter(torch.ones(emb_dim))   # trainable scale parameter
+        self.shift = nn.Parameter(torch.zeros(emb_dim))  # trainable shift parameter
+
+    def forward(self, x):
+        '''
+        In this implementation of Layer Normalization, the normalization is applied along
+        the last dimension of the input tensor x, which represents the embedding dimension (dim=-1).
+        Normalizing over the embedding dimension ensures that each token is treated independently,
+        preventing one token from influencing another.
+
+        For Transformer models, input data typically has the following shape:
+        [batch_size, seq_len, emb_dim]
+        '''
+
+        mean = x.mean(dim=-1, keepdim=True)
+        var = x.var(dim=-1, keepdim=True, unbiased=False)  # unbiased=False: population variance, no Bessel's correction
+        norm_x = (x - mean) / torch.sqrt(var + self.eps)
+
+        return self.scale * norm_x + self.shift
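As a sanity check, this implementation should agree with torch.nn.LayerNorm when given the same eps. A minimal sketch (the import path assumes the repo root is on sys.path):

import torch
import torch.nn as nn

from components.layer_norm import LayerNorm

x = torch.randn(2, 5, 896)             # [batch_size, seq_len, emb_dim]
custom = LayerNorm(896)                # scale=ones, shift=zeros at init
builtin = nn.LayerNorm(896, eps=1e-6)  # weight=ones, bias=zeros at init
print(torch.allclose(custom(x), builtin(x), atol=1e-5))  # expected: True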
components/multi_head_attention.py
ADDED
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+        super().__init__()
+        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+        self.d_out = d_out
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+    def forward(self, x):
+        b, num_tokens, _ = x.shape
+
+        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        queries = self.W_query(x)
+        values = self.W_value(x)
+
+        # We implicitly split the matrix by adding a `num_heads` dimension
+        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+        keys = keys.transpose(1, 2)
+        queries = queries.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+        # Original mask truncated to the number of tokens and converted to boolean
+        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+        # Use the mask to fill attention scores
+        attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Shape: (b, num_tokens, num_heads, head_dim)
+        context_vec = (attn_weights @ values).transpose(1, 2)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+        context_vec = self.out_proj(context_vec)  # optional projection
+
+        return context_vec, attn_weights
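A minimal shape walk-through using the values from config.json (emb_dim=896 split across n_heads=14 gives head_dim=64); the import path assumes the repo root is on sys.path:

import torch

from components.multi_head_attention import MultiHeadAttention

attn = MultiHeadAttention(d_in=896, d_out=896, context_length=256,
                          dropout=0.2, num_heads=14, qkv_bias=True)
attn.eval()                   # disable dropout for a deterministic pass
x = torch.randn(2, 256, 896)  # [batch, num_tokens, emb_dim]
context, weights = attn(x)
print(context.shape)          # torch.Size([2, 256, 896])
print(weights.shape)          # torch.Size([2, 14, 256, 256]), one causal attention map per head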
components/transformer_block.py
ADDED
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+from components.feed_forward import FeedForward
+from components.multi_head_attention import MultiHeadAttention
+
+class TransformerBlock(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = MultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            context_length=cfg["context_length"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForward(cfg)
+        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
+        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+
+    def forward(self, x):
+        '''
+        The transformer block consists of two main components:
+        - Multi-Head Self-Attention → Captures relationships between tokens.
+        - Feedforward Neural Network (FFN) → Processes each token independently after attention.
+
+        Both sub-layers use a pre-norm layout: the input is first normalized with LayerNorm,
+        passed through the sub-layer (attention or FFN), and the result is added back to the
+        unnormalized input through a skip connection.
+
+        Dropout is applied to each sub-layer's output before it is added to the skip
+        connection. This helps to prevent overfitting and improves generalization.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, emb_dim].
+
+        Returns:
+            tuple: Output tensor of shape [batch_size, seq_len, emb_dim] and the attention weights.
+        '''
+
+        shortcut = x
+        x = self.norm1(x)
+        x, _attn_weights = self.att(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut
+
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut
+
+        return x, _attn_weights
+
+
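A short usage sketch with a cfg dict mirroring config.json (hypothetical values copied from that file; the import path assumes the repo root is on sys.path). Note that the block returns a tuple of the output and the attention weights:

import torch

from components.transformer_block import TransformerBlock

cfg = {"emb_dim": 896, "context_length": 256, "n_heads": 14,
       "drop_rate": 0.2, "qkv_bias": True}
block = TransformerBlock(cfg)
block.eval()
x = torch.randn(2, 256, 896)
out, attn = block(x)
print(out.shape)   # torch.Size([2, 256, 896]) -- the residual connections preserve the shape
print(attn.shape)  # torch.Size([2, 14, 256, 256])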
config.json
ADDED
@@ -0,0 +1,10 @@
+{
+    "model_type": "gpt_and_prejudice",
+    "vocab_size": 50257,
+    "context_length": 256,
+    "emb_dim": 896,
+    "n_heads": 14,
+    "n_layers": 8,
+    "drop_rate": 0.2,
+    "qkv_bias": true
+}
gpt_model.py
ADDED
@@ -0,0 +1,123 @@
+
+import torch
+import torch.nn as nn
+from typing import Dict, Any, Optional
+
+from components.transformer_block import TransformerBlock
+from components.layer_norm import LayerNorm
+
+class InterventionPlan:
+    """
+    Hook object consulted during forward() to optionally replace activations.
+    Override any of the methods below in your experiment code.
+    """
+    def maybe_replace_resid_pre(self, layer_idx: int, x: torch.Tensor) -> torch.Tensor:
+        return x
+    def maybe_replace_resid_post(self, layer_idx: int, x: torch.Tensor) -> torch.Tensor:
+        return x
+    # Optional: only works if your blocks expose per-head z or mlp outputs.
+    def maybe_replace_head_z(self, layer_idx: int, z: torch.Tensor) -> torch.Tensor:
+        return z
+    def maybe_replace_mlp_out(self, layer_idx: int, h: torch.Tensor) -> torch.Tensor:
+        return h
+
+class GPTModel(nn.Module):
+    def __init__(self, cfg: Dict[str, Any]):
+        super().__init__()
+        self.cfg = cfg
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
+        )
+        self.final_norm = LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    @torch.no_grad()
+    def cache_forward(self, in_idx: torch.Tensor):
+        """Run a forward pass with caching enabled (no interventions)."""
+        return self.forward(in_idx, enable_cache=True)
+
+    def forward(
+        self,
+        in_idx: torch.Tensor,
+        enable_cache: bool = False,
+        intervention_plan: Optional[InterventionPlan] = None,
+        output_hidden_states: bool = False,
+        output_attentions_weights: bool = False,
+        # Backward-compat args (ignored if a plan is provided)
+        intervene_layer: Optional[int] = None,
+        edited_hidden: Optional[torch.Tensor] = None,
+    ):
+        """
+        Mechanistic-interpretability-friendly forward.
+        Returns: logits, (optional) cache, (optional) hidden_states, (optional) attn_weights
+        Cache keys: resid_pre[L], resid_post[L], attn_weights[L]
+        """
+        B, T = in_idx.shape
+        device = in_idx.device
+
+        tok_embeds = self.tok_emb(in_idx)                           # [B, T, d]
+        pos_embeds = self.pos_emb(torch.arange(T, device=device))   # [T, d]
+        x = self.drop_emb(tok_embeds + pos_embeds)
+
+        cache: Dict[str, Dict[int, torch.Tensor]] = {}
+        if enable_cache:
+            cache = {"resid_pre": {}, "resid_post": {}, "attn_weights": {}}
+
+        hidden_states = []
+        attention_weights_per_layer = []
+
+        # Fall back to the legacy single-layer intervention if no plan is provided
+        legacy_layer = intervene_layer if (intervention_plan is None) else None
+        legacy_edit = edited_hidden if (intervention_plan is None) else None
+
+        for L, block in enumerate(self.trf_blocks):
+            if enable_cache:
+                cache["resid_pre"][L] = x.detach()
+
+            # Entry-point intervention
+            if intervention_plan is not None:
+                x = intervention_plan.maybe_replace_resid_pre(L, x)
+            elif legacy_layer is not None and legacy_edit is not None and L == legacy_layer:
+                x = legacy_edit  # Inject causal intervention
+
+            # Run block (assumed to return (x_out, attn_weights))
+            block_out = block(x)
+            if isinstance(block_out, tuple) and len(block_out) == 2:
+                x, attn_w = block_out
+            else:
+                x = block_out
+                attn_w = None
+
+            if output_attentions_weights and attn_w is not None:
+                attention_weights_per_layer.append(attn_w.detach())
+            if enable_cache and attn_w is not None:
+                cache["attn_weights"][L] = attn_w.detach()
+
+            # Exit-point intervention (rarely used, but handy)
+            if intervention_plan is not None:
+                x = intervention_plan.maybe_replace_resid_post(L, x)
+
+            if output_hidden_states:
+                hidden_states.append(x.clone())
+
+            if enable_cache:
+                cache["resid_post"][L] = x.detach()
+
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+
+        outputs = (logits,)
+        if enable_cache:
+            outputs += (cache,)
+        if output_hidden_states:
+            outputs += (hidden_states,)
+        if output_attentions_weights:
+            outputs += (attention_weights_per_layer,)
+        return outputs if len(outputs) > 1 else outputs[0]
model_896_14_8_256.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2ff54930487bdf5583980796eeee05854d761d6926c5a568f8a31b9eaddcc27
+size 2011729853
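A minimal loading sketch tying the files together (assumptions: the repo root is the working directory, and model_896_14_8_256.pth stores a plain state_dict; if the checkpoint wraps the weights differently, adjust the torch.load handling accordingly):

import json
import torch

from gpt_model import GPTModel

with open("config.json") as f:
    cfg = json.load(f)

model = GPTModel(cfg)
state_dict = torch.load("model_896_14_8_256.pth", map_location="cpu")  # assumed to be a raw state_dict
model.load_state_dict(state_dict)
model.eval()

# Dummy token ids inside the context window; a real run would use a GPT-2 BPE tokenizer (vocab_size=50257).
in_idx = torch.randint(0, cfg["vocab_size"], (1, 16))
logits, cache = model.cache_forward(in_idx)  # forward pass with activation caching, no interventions
print(logits.shape)                          # torch.Size([1, 16, 50257])
print(sorted(cache.keys()))                  # ['attn_weights', 'resid_post', 'resid_pre']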