Spaces: MechVis (Sleeping)
SaiMupparaju committed · Commit 03653db · 0 parent(s)

Initial commit for MechVis Hugging Face Space
Browse files
- .gitignore +26 -0
- 1_4_1_Indirect_Object_Identification_exercises.ipynb +0 -0
- Dockerfile +13 -0
- Procfile +1 -0
- README.md +80 -0
- README_HF.md +27 -0
- app.py +116 -0
- requirements.txt +5 -0
- space.yaml +7 -0
- templates/index.html +346 -0
.gitignore ADDED
@@ -0,0 +1,26 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+
+# Virtual environments
+venv/
+.env/
+
+# Environment variables
+.env
+
+# IDE files
+.vscode/
+.idea/
+
+# Jupyter Notebook
+.ipynb_checkpoints/
+
+# Miscellaneous
+.DS_Store
1_4_1_Indirect_Object_Identification_exercises.ipynb ADDED
The diff for this file is too large to render. See raw diff.
Dockerfile ADDED
@@ -0,0 +1,13 @@
+FROM python:3.9
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+COPY . /code
+
+EXPOSE 7860
+
+CMD ["python", "app.py"]
Procfile ADDED
@@ -0,0 +1 @@
+web: python app.py
README.md ADDED
@@ -0,0 +1,80 @@
+# MechVis: GPT-2 Attention Head Contribution Visualization
+
+[Hugging Face Space](https://huggingface.co/spaces/saivamsim26/mechvis)
+
+MechVis is a tool for visualizing how attention heads in GPT-2 small contribute to next token predictions. It provides a simple web interface where you can enter text, see what token the model predicts next, and visualize which attention heads contribute most to that prediction.
+
+This project is inspired by mechanistic interpretability research on language models, particularly studies of "indirect object identification" in GPT-2 small.
+
+## Features
+
+- Input any text prompt and see GPT-2's next token prediction
+- View a heatmap visualization of each attention head's contribution to the predicted token
+- Interactive tooltips showing exact contribution values for each head
+- Simple, clean web interface
+
+## Deployment on Hugging Face Spaces
+
+1. Create a new Space on Hugging Face:
+   - Go to https://huggingface.co/spaces
+   - Click "Create new Space"
+   - Choose "Docker" as the SDK
+   - Set the environment variables if needed
+
+2. Upload the following files to your Space:
+   - `app.py`
+   - `requirements.txt`
+   - `Dockerfile`
+   - Contents of the `templates/` directory
+
+The application will deploy automatically and be available at your Space's URL.
+
+## Local Development
+
+To run the application locally:
+
+```bash
+pip install -r requirements.txt
+python app.py
+```
+
+The application will be available at http://localhost:7860
+
+## How to Use
+
+1. Enter a text prompt in the input field
+2. Click "Predict Next Word"
+3. View the predicted token, its logit value, and probability
+4. Explore the heatmap visualization showing each attention head's contribution:
+   - Red cells indicate positive contributions to the predicted token
+   - Blue cells indicate negative contributions
+   - Hover over cells to see exact contribution values
+
+## Understanding the Visualization
+
+The visualization shows a 12×12 grid representing all attention heads in GPT-2 small, with:
+- Rows representing layers (0-11)
+- Columns representing heads within each layer (0-11)
+- Color intensity showing the magnitude of contribution
+
+This kind of visualization can help identify which attention heads matter most for specific prediction tasks. For example, research has shown that certain heads specialize in particular roles, such as:
+- Name mover heads (e.g., 9.9, 10.0, 9.6)
+- Induction heads (e.g., 5.5, 6.9)
+- S-inhibition heads (e.g., 7.3, 7.9, 8.6, 8.10)
+
+## Example Use Cases
+
+1. **Indirect Object Identification**: Try entering "When John and Mary went to the store, John gave a drink to" and see which heads contribute to predicting "Mary"
+
+2. **Induction Pattern Detection**: Enter repetitive sequences like "The capital of France is Paris. The capital of Germany is" to see induction heads activate
+
+3. **Exploration**: Try various prompts to see how different heads specialize in different linguistic patterns
+
+## References
+
+- [Transformer Lens](https://github.com/neelnanda-io/TransformerLens) - Library for transformer interpretability
+- [Indirect Object Identification](https://arxiv.org/abs/2211.00593) - Research on circuits in GPT-2 small
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
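The README's "Understanding the Visualization" section summarizes what `app.py` (added later in this commit) computes: each head's attention output is projected onto the predicted token's unembedding direction, giving that head's direct contribution to the token's logit. Below is a minimal standalone sketch of that computation, assuming `transformer_lens` is installed and that `gpt2-small` downloads on first use; it is an illustration, not part of the committed files.

```python
# Sketch: per-head direct logit attribution for the predicted next token,
# mirroring calculate_head_contributions in app.py. As in app.py, attention
# output biases (b_O) and final-LayerNorm scaling are ignored.
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)  # inference only
model = HookedTransformer.from_pretrained("gpt2-small")

prompt = "When John and Mary went to the store, John gave a drink to"
tokens = model.to_tokens(prompt, prepend_bos=True)
logits, cache = model.run_with_cache(tokens)

top_idx = logits[0, -1].argmax().item()
print("Predicted token:", repr(model.to_string([top_idx])))

# The predicted token's unembedding column is its direction in the residual stream.
direction = model.W_U[:, top_idx]  # [d_model]

for layer in range(model.cfg.n_layers):
    z = cache["z", layer][0, -1]                                 # [head, d_head]
    head_out = torch.einsum("hd,hdm->hm", z, model.W_O[layer])   # [head, d_model]
    contribs = head_out @ direction                              # [head] logit contributions
    best = contribs.argmax().item()
    print(f"layer {layer}: strongest head {layer}.{best} -> {contribs[best].item():.3f}")
```

On the IOI prompt above, the name mover heads the README lists (e.g., 9.9 and 10.0) should rank among the strongest positive contributors.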
README_HF.md ADDED
@@ -0,0 +1,27 @@
+# MechVis: GPT-2 Attention Head Visualization
+
+This interactive web app allows you to visualize how different attention heads in GPT-2 small contribute to next token predictions.
+
+## How to Use
+
+1. Enter text in the input field (e.g., "When John and Mary went to the store, John gave a drink to")
+2. Click "Predict Next Word"
+3. See what token GPT-2 predicts next and explore how each attention head contributes to that prediction
+
+## Features
+
+- Next token prediction with GPT-2 small
+- Interactive heatmap showing attention head contributions
+- Layer contribution analysis
+- Hover over cells to see exact contribution values
+
+## Examples to Try
+
+- **Indirect Object Identification**: "When John and Mary went to the store, John gave a drink to" (likely predicts "Mary")
+- **Induction Pattern**: "The capital of France is Paris. The capital of Germany is" (likely predicts "Berlin")
+
+## About
+
+This project uses [TransformerLens](https://github.com/neelnanda-io/TransformerLens) to access internal model activations and calculate how each attention head contributes to the final logit score of the predicted token.
+
+[GitHub Repository](https://github.com/saivamsim26/mechvis)
app.py ADDED
@@ -0,0 +1,116 @@
+import torch
+import numpy as np
+from flask import Flask, render_template, request, jsonify
+from transformer_lens import HookedTransformer
+import json
+
+app = Flask(__name__)
+
+# Load GPT-2 small model
+model = HookedTransformer.from_pretrained(
+    "gpt2-small",
+    center_unembed=True,
+    center_writing_weights=True,
+    fold_ln=True,
+    refactor_factored_attn_matrices=True,
+)
+
+@app.route('/', methods=['GET', 'POST'])
+def index():
+    prediction = None
+    text = ""
+    head_contributions = None
+
+    if request.method == 'POST':
+        text = request.form.get('text', '')
+
+        if text:
+            # Tokenize the input text
+            tokens = model.to_tokens(text, prepend_bos=True)
+
+            # Run the model with cache to get intermediate activations
+            logits, cache = model.run_with_cache(tokens)
+
+            # Get logits for the last token
+            last_token_logits = logits[0, -1]
+
+            # Get the index of the token with the highest logit
+            top_token_idx = torch.argmax(last_token_logits).item()
+
+            # Get the logit value
+            top_token_logit = last_token_logits[top_token_idx].item()
+
+            # Get the probability
+            probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
+            top_token_prob = probs[top_token_idx].item() * 100  # Convert to percentage
+
+            # Get the token as a string
+            top_token_str = model.to_string([top_token_idx])
+
+            # Get attention head contributions for the top token
+            head_contributions = calculate_head_contributions(cache, top_token_idx, model)
+
+            prediction = {
+                'token': top_token_str,
+                'logit': top_token_logit,
+                'prob': top_token_prob
+            }
+
+    return render_template('index.html', prediction=prediction, text=text, head_contributions=json.dumps(head_contributions) if head_contributions else None)
+
+def calculate_head_contributions(cache, token_idx, model):
+    """Calculate the contribution of each attention head to the top token's logit."""
+
+    # Get all head outputs for the last token
+    head_outputs_by_layer = []
+    contributions = []
+    layer_total_contributions = []
+
+    # Get the direction in the residual stream that corresponds to the token
+    token_direction = model.W_U[:, token_idx].detach()
+
+    # Calculate contributions for each head
+    for layer in range(model.cfg.n_layers):
+        # Get the output of each head at the last position
+        z = cache["z", layer][0, -1]  # [head, d_head]
+
+        # Apply the OV matrix for each head
+        head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])  # [head, d_model]
+
+        # Project onto the token direction to get contribution to the logit
+        head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)
+
+        # Calculate total contribution for this layer
+        layer_total = head_contribs.sum().item()
+        layer_total_contributions.append(layer_total)
+
+        # Convert to list for JSON serialization
+        layer_contributions = head_contribs.detach().cpu().numpy().tolist()
+        contributions.append(layer_contributions)
+
+    # Calculate total contribution across all heads
+    total_contribution = sum([sum(layer_contrib) for layer_contrib in contributions])
+
+    # Convert contributions to percentage of total
+    percentage_contributions = []
+    for layer_contributions in contributions:
+        percentage_layer = [(contrib / total_contribution) * 100 for contrib in layer_contributions]
+        percentage_contributions.append(percentage_layer)
+
+    # Calculate per-layer contribution percentages
+    layer_percentages = [(layer_total / total_contribution) * 100 for layer_total in layer_total_contributions]
+
+    # Get the max and min values for normalization in visualization
+    all_contribs_pct = np.array(percentage_contributions).flatten()
+    max_contrib = float(np.max(all_contribs_pct))
+    min_contrib = float(np.min(all_contribs_pct))
+
+    return {
+        "contributions": percentage_contributions,
+        "max_value": max_contrib,
+        "min_value": min_contrib,
+        "layer_contributions": layer_percentages
+    }
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=7860, debug=False)
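With the server running locally (`python app.py`), the route above can also be exercised from a script rather than a browser. A minimal sketch, assuming the `requests` package is available (it is not pinned in requirements.txt):

```python
# Sketch: POST the form field that app.py's index() reads, then check that the
# rendered HTML contains the "Prediction Results" block from templates/index.html.
import requests

resp = requests.post(
    "http://localhost:7860/",
    data={"text": "When John and Mary went to the store, John gave a drink to"},
)
resp.raise_for_status()
print("Prediction rendered:", "Prediction Results" in resp.text)
```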
requirements.txt ADDED
@@ -0,0 +1,5 @@
+flask==2.0.1
+torch==2.0.1
+numpy>=1.21.0
+transformer-lens==1.2.2
+gunicorn==20.1.0
space.yaml ADDED
@@ -0,0 +1,7 @@
+title: MechVis
+emoji: 📊
+colorFrom: indigo
+colorTo: purple
+sdk: docker
+app_port: 7860
+pinned: false
templates/index.html ADDED
@@ -0,0 +1,346 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>GPT-2 Next Word Prediction</title>
+    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
+    <script src="https://d3js.org/d3.v7.min.js"></script>
+    <style>
+        body {
+            padding: 40px;
+            font-family: system-ui, -apple-system, sans-serif;
+        }
+        .prediction {
+            margin-top: 30px;
+            padding: 20px;
+            background-color: #f8f9fa;
+            border-radius: 5px;
+        }
+        .token {
+            font-size: 1.2rem;
+            font-weight: bold;
+            background-color: #e9ecef;
+            padding: 5px 10px;
+            border-radius: 4px;
+            display: inline-block;
+            margin-bottom: 10px;
+        }
+        #visualization {
+            margin-top: 30px;
+            width: 100%;
+            overflow-x: auto;
+        }
+        .head-cell {
+            stroke: #ddd;
+            stroke-width: 1px;
+        }
+        .layer-label, .head-label {
+            font-size: 12px;
+            font-weight: bold;
+            text-anchor: middle;
+        }
+        .tooltip {
+            position: absolute;
+            background-color: rgba(255, 255, 255, 0.9);
+            border: 1px solid #ddd;
+            padding: 8px;
+            border-radius: 4px;
+            pointer-events: none;
+            font-size: 12px;
+        }
+        .visualization-container {
+            margin-top: 30px;
+            background-color: white;
+            border-radius: 5px;
+            padding: 20px;
+            box-shadow: 0 0 10px rgba(0,0,0,0.1);
+        }
+        .legend {
+            margin-top: 15px;
+            margin-bottom: 20px;
+        }
+        .legend-item {
+            display: inline-block;
+            margin-right: 20px;
+        }
+        .legend-color {
+            display: inline-block;
+            width: 20px;
+            height: 20px;
+            margin-right: 5px;
+            vertical-align: middle;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1 class="mb-4">GPT-2 Next Word Prediction</h1>
+
+        <div class="row">
+            <div class="col-md-12">
+                <form method="POST">
+                    <div class="mb-3">
+                        <label for="text" class="form-label">Input Text:</label>
+                        <textarea class="form-control" id="text" name="text" rows="3" placeholder="Enter text (e.g. 'When John and Mary went to the store, John gave a drink to')" required>{{ text }}</textarea>
+                    </div>
+                    <button type="submit" class="btn btn-primary">Predict Next Word</button>
+                </form>
+            </div>
+        </div>
+
+        {% if prediction %}
+        <div class="row">
+            <div class="col-md-12">
+                <div class="prediction">
+                    <h3>Prediction Results</h3>
+                    <p>Input text: <strong>{{ text }}</strong></p>
+                    <p>Next word: <span class="token">{{ prediction.token }}</span></p>
+                    <p>Logit value: <strong>{{ "%.4f"|format(prediction.logit) }}</strong></p>
+                    <p>Probability: <strong>{{ "%.2f"|format(prediction.prob) }}%</strong></p>
+                </div>
+            </div>
+        </div>
+
+        {% if head_contributions %}
+        <div class="row">
+            <div class="col-md-12">
+                <div class="visualization-container">
+                    <h3>Layer Contributions to Log Probability</h3>
+                    <p>This chart shows how each layer in GPT-2 contributes to the log probability of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                    <div id="layer-chart"></div>
+
+                    <h3>Attention Head Contributions</h3>
+                    <p>This visualization shows how each attention head in GPT-2 contributes to the prediction of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                    <div class="legend">
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #4575b4;"></div>
+                            <span>Negative contribution %</span>
+                        </div>
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #ffffbf;"></div>
+                            <span>Neutral (0%)</span>
+                        </div>
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #d73027;"></div>
+                            <span>Positive contribution %</span>
+                        </div>
+                    </div>
+
+                    <div id="visualization"></div>
+                </div>
+            </div>
+        </div>
+        {% endif %}
+        {% endif %}
+    </div>
+
+    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
+
+    {% if head_contributions %}
+    <script>
+        document.addEventListener('DOMContentLoaded', function() {
+            const headContributions = {{ head_contributions|safe }};
+
+            // Create layer contributions bar chart
+            const createLayerChart = () => {
+                const layerContribs = headContributions.layer_contributions;
+                const margin = { top: 40, right: 30, bottom: 50, left: 60 };
+                const width = Math.min(800, window.innerWidth - 100);
+                const height = 300;
+
+                const svg = d3.select("#layer-chart")
+                    .append("svg")
+                    .attr("width", width)
+                    .attr("height", height);
+
+                const g = svg.append("g")
+                    .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                // Create scales
+                const x = d3.scaleBand()
+                    .domain(d3.range(layerContribs.length))
+                    .range([0, width - margin.left - margin.right])
+                    .padding(0.1);
+
+                const y = d3.scaleLinear()
+                    .domain([
+                        Math.min(0, d3.min(layerContribs)),
+                        Math.max(0, d3.max(layerContribs))
+                    ])
+                    .nice()
+                    .range([height - margin.top - margin.bottom, 0]);
+
+                // Create color scale - positive is green, negative is purple
+                const colorScale = d3.scaleLinear()
+                    .domain([Math.min(0, d3.min(layerContribs)), 0, Math.max(0, d3.max(layerContribs))])
+                    .range(["#9467bd", "#f7f7f7", "#2ca02c"]);
+
+                // Create tooltip
+                const tooltip = d3.select("body")
+                    .append("div")
+                    .attr("class", "tooltip")
+                    .style("opacity", 0);
+
+                // Create bars
+                g.selectAll(".bar")
+                    .data(layerContribs)
+                    .join("rect")
+                    .attr("class", "bar")
+                    .attr("x", (d, i) => x(i))
+                    .attr("y", d => d >= 0 ? y(d) : y(0))
+                    .attr("width", x.bandwidth())
+                    .attr("height", d => Math.abs(y(0) - y(d)))
+                    .attr("fill", d => colorScale(d))
+                    .attr("stroke", "#555")
+                    .attr("stroke-width", 1)
+                    .on("mouseover", function(event, d) {
+                        d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                        tooltip.transition().duration(200).style("opacity", 1);
+                        tooltip.html(`Layer ${layerContribs.indexOf(d)}<br>Contribution: ${d.toFixed(2)}%`)
+                            .style("left", (event.pageX + 10) + "px")
+                            .style("top", (event.pageY - 28) + "px");
+                    })
+                    .on("mouseout", function() {
+                        d3.select(this).attr("stroke", "#555").attr("stroke-width", 1);
+                        tooltip.transition().duration(500).style("opacity", 0);
+                    });
+
+                // Add x-axis
+                g.append("g")
+                    .attr("transform", `translate(0,${y(0)})`)
+                    .call(d3.axisBottom(x).tickFormat(i => `L${i}`))
+                    .selectAll("text")
+                    .style("font-size", "12px");
+
+                // Add y-axis
+                g.append("g")
+                    .call(d3.axisLeft(y).tickFormat(d => `${d.toFixed(1)}%`))
+                    .selectAll("text")
+                    .style("font-size", "12px");
+
+                // Add title
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", 20)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "16px")
+                    .style("font-weight", "bold")
+                    .text("Layer Contributions to Log Probability (%)");
+
+                // Add x-axis label
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", height - 10)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "14px")
+                    .text("Layer");
+
+                // Add y-axis label
+                svg.append("text")
+                    .attr("transform", "rotate(-90)")
+                    .attr("x", -(height / 2))
+                    .attr("y", 15)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "14px")
+                    .text("Contribution %");
+            };
+
+            // Create head contributions heatmap
+            const createHeadHeatmap = () => {
+                // Define visualization parameters
+                const cellSize = 40;
+                const numLayers = headContributions.contributions.length;
+                const numHeads = headContributions.contributions[0].length;
+                const margin = { top: 60, right: 20, bottom: 20, left: 60 };
+                const width = cellSize * numHeads + margin.left + margin.right;
+                const height = cellSize * numLayers + margin.top + margin.bottom;
+
+                // Create SVG
+                const svg = d3.select("#visualization")
+                    .append("svg")
+                    .attr("width", width)
+                    .attr("height", height);
+
+                // Create a group for the heatmap
+                const g = svg.append("g")
+                    .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                // Create color scale
+                const colorScale = d3.scaleSequential(d3.interpolateRdBu)
+                    .domain([headContributions.max_value, headContributions.min_value]);
+
+                // Create tooltip
+                const tooltip = d3.select("body")
+                    .append("div")
+                    .attr("class", "tooltip")
+                    .style("opacity", 0);
+
+                // Create cells
+                for (let layer = 0; layer < numLayers; layer++) {
+                    for (let head = 0; head < numHeads; head++) {
+                        const contribution = headContributions.contributions[layer][head];
+
+                        g.append("rect")
+                            .attr("class", "head-cell")
+                            .attr("x", head * cellSize)
+                            .attr("y", layer * cellSize)
+                            .attr("width", cellSize)
+                            .attr("height", cellSize)
+                            .attr("fill", colorScale(contribution))
+                            .on("mouseover", function(event) {
+                                d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                                tooltip.transition().duration(200).style("opacity", 1);
+                                tooltip.html(`Layer ${layer}, Head ${head}<br>Contribution: ${contribution.toFixed(2)}%`)
+                                    .style("left", (event.pageX + 10) + "px")
+                                    .style("top", (event.pageY - 28) + "px");
+                            })
+                            .on("mouseout", function() {
+                                d3.select(this).attr("stroke", "#ddd").attr("stroke-width", 1);
+                                tooltip.transition().duration(500).style("opacity", 0);
+                            });
+                    }
+                }
+
+                // Add layer labels
+                for (let layer = 0; layer < numLayers; layer++) {
+                    g.append("text")
+                        .attr("class", "layer-label")
+                        .attr("x", -10)
+                        .attr("y", layer * cellSize + cellSize / 2)
+                        .attr("text-anchor", "end")
+                        .attr("dominant-baseline", "middle")
+                        .text(`L${layer}`);
+                }
+
+                // Add head labels
+                for (let head = 0; head < numHeads; head++) {
+                    g.append("text")
+                        .attr("class", "head-label")
+                        .attr("x", head * cellSize + cellSize / 2)
+                        .attr("y", -10)
+                        .attr("text-anchor", "middle")
+                        .attr("dominant-baseline", "central")
+                        .text(`H${head}`);
+                }
+
+                // Add title
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", 20)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "16px")
+                    .style("font-weight", "bold")
+                    .text("Head Contributions to Log Probability (%)");
+            };
+
+            // Create both visualizations
+            createLayerChart();
+            createHeadHeatmap();
+        });
+    </script>
+    {% endif %}
+</body>
+</html>