# mechvis/app.py
import json
import os
import sys

# Configure cache locations BEFORE importing transformer_lens: the Hugging
# Face libraries resolve HF_HOME / TRANSFORMERS_CACHE at import time, so
# setting them afterwards would have no effect. The defaults point at
# directories writable by the container's non-root user.
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/home/appuser/cache')
hf_home = os.environ.get('HF_HOME', '/home/appuser/hf_home')
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = hf_home
print(f"Using cache directories: {cache_dir} and {hf_home}", file=sys.stderr)

import numpy as np
import torch
from flask import Flask, render_template, request
from transformer_lens import HookedTransformer

app = Flask(__name__)
# Load GPT-2 small with TransformerLens' standard weight preprocessing:
# fold_ln folds the LayerNorm parameters into adjacent weights, and the
# centering flags remove weight components that do not affect the output
# probabilities, which makes per-head attribution cleaner.
try:
    print("Loading GPT-2 small model...", file=sys.stderr)
    model = HookedTransformer.from_pretrained(
        "gpt2",  # 'gpt2' is the model ID on Hugging Face
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        refactor_factored_attn_matrices=True,
    )
    print("Model loaded successfully!", file=sys.stderr)
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)
    # Continue with app setup even if the model fails to load, so the app
    # can still start and show an error message instead of crashing.
    model = None
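# Optional smoke test (commented out, illustrative only): a short greedy
# generation is a quick way to confirm the weights loaded correctly.
# `generate` and its `max_new_tokens` argument are part of the
# HookedTransformer API.
# if model is not None:
#     print(model.generate("The quick brown fox", max_new_tokens=5),
#           file=sys.stderr)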
@app.route('/', methods=['GET', 'POST'])
def index():
    prediction = None
    text = ""
    head_contributions = None
    error_message = None

    if model is None:
        error_message = "Model failed to load. Please check the logs or try again later."
    elif request.method == 'POST':
        text = request.form.get('text', '')
        if text:
            try:
                # Tokenize the input text
                tokens = model.to_tokens(text, prepend_bos=True)
                # Run the model with cache to capture intermediate activations
                logits, cache = model.run_with_cache(tokens)
                # Logits at the final position predict the next token
                last_token_logits = logits[0, -1]
                # Index, logit, and probability of the top predicted token
                top_token_idx = torch.argmax(last_token_logits).item()
                top_token_logit = last_token_logits[top_token_idx].item()
                probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                top_token_prob = probs[top_token_idx].item() * 100  # as a percentage
                # The token as a string
                top_token_str = model.to_string([top_token_idx])
                # Per-head contributions to the top token's logit
                head_contributions = calculate_head_contributions(cache, top_token_idx, model)
                prediction = {
                    'token': top_token_str,
                    'logit': top_token_logit,
                    'prob': top_token_prob,
                }
            except Exception as e:
                error_message = f"Error processing request: {e}"
                print(f"Error in processing: {e}", file=sys.stderr)

    return render_template('index.html', prediction=prediction, text=text,
                           head_contributions=json.dumps(head_contributions) if head_contributions else None,
                           error_message=error_message)
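# Example request against a local run (illustrative values):
#   curl -X POST -d "text=The Eiffel Tower is in" http://localhost:7860/
# The rendered template is expected to show the predicted token, its logit
# and probability, plus the per-head contribution data passed in as JSON.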
def calculate_head_contributions(cache, token_idx, model):
    """Calculate the contribution of each attention head to the top token's logit."""
    contributions = []
    layer_total_contributions = []

    # Direction in the residual stream corresponding to the token: column
    # token_idx of the unembedding matrix W_U [d_model, d_vocab]
    token_direction = model.W_U[:, token_idx].detach()

    for layer in range(model.cfg.n_layers):
        # Per-head attention outputs at the last position: [head, d_head]
        z = cache["z", layer][0, -1]
        # Apply each head's output projection W_O[layer]: [head, d_head, d_model]
        head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])  # [head, d_model]
        # Project onto the token direction to get each head's logit contribution
        head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)
        # Track the layer total and store per-head values as a JSON-serializable list
        layer_total_contributions.append(head_contribs.sum().item())
        contributions.append(head_contribs.detach().cpu().numpy().tolist())
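    # Shape walkthrough for the loop above, assuming GPT-2 small's config
    # (12 layers, 12 heads, d_head=64, d_model=768):
    #   z                -> [12, 64]       head outputs at the final position
    #   model.W_O[layer] -> [12, 64, 768]  per-head output projections
    #   head_outputs     -> [12, 768]      each head's write to the residual stream
    #   head_contribs    -> [12]           dot products with token_direction [768]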
    # Total contribution across all heads; the percentages below assume this
    # is nonzero (true in practice for a top-predicted token)
    total_contribution = sum(sum(layer_contrib) for layer_contrib in contributions)

    # Express per-head and per-layer contributions as percentages of the total
    percentage_contributions = [
        [(contrib / total_contribution) * 100 for contrib in layer_contributions]
        for layer_contributions in contributions
    ]
    layer_percentages = [(layer_total / total_contribution) * 100
                         for layer_total in layer_total_contributions]

    # Max and min percentages, used to normalize the visualization's color scale
    all_contribs_pct = np.array(percentage_contributions).flatten()
    max_contrib = float(np.max(all_contribs_pct))
    min_contrib = float(np.min(all_contribs_pct))

    return {
        "contributions": percentage_contributions,
        "max_value": max_contrib,
        "min_value": min_contrib,
        "layer_contributions": layer_percentages,
    }
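# Note on the attribution above: it is direct logit attribution restricted to
# attention heads. Each head h in layer l contributes
# (z[l,h] @ W_O[l,h]) . W_U[:, t] to token t's logit; the direct path from the
# embeddings, the MLP outputs, and the b_O / b_U biases are omitted, and the
# final LayerNorm's normalization is ignored, so the percentages are an
# approximation relative to the attention-only total, not shares of the full
# logit.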
if __name__ == '__main__':
    # 7860 is the default port Hugging Face Spaces expects a web app to listen
    # on; binding 0.0.0.0 makes the server reachable from outside the container.
    app.run(host="0.0.0.0", port=7860, debug=False)