import json
import os
import sys

import numpy as np
import torch
from flask import Flask, render_template, request
from transformer_lens import HookedTransformer

app = Flask(__name__)

# Use environment variables for cache locations
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/home/appuser/cache')
hf_home = os.environ.get('HF_HOME', '/home/appuser/hf_home')
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = hf_home
print(f"Using cache directories: {cache_dir} and {hf_home}", file=sys.stderr)

# Load GPT-2 small
try:
    print("Loading GPT-2 small model...", file=sys.stderr)
    model = HookedTransformer.from_pretrained(
        "gpt2",  # 'gpt2' is the Hugging Face model ID for GPT-2 small
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        refactor_factored_attn_matrices=True,
    )
    print("Model loaded successfully!", file=sys.stderr)
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)
    # Continue with app setup even if the model fails to load,
    # so the app can still start and show an error message.
    model = None


@app.route('/', methods=['GET', 'POST'])
def index():
    prediction = None
    text = ""
    head_contributions = None
    error_message = None

    if model is None:
        error_message = "Model failed to load. Please check the logs or try again later."
    elif request.method == 'POST':
        text = request.form.get('text', '')
        if text:
            try:
                # Tokenize the input text
                tokens = model.to_tokens(text, prepend_bos=True)

                # Run the model with caching to capture intermediate activations
                logits, cache = model.run_with_cache(tokens)

                # Logits for the last token position
                last_token_logits = logits[0, -1]

                # Index of the token with the highest logit
                top_token_idx = torch.argmax(last_token_logits).item()

                # Its logit value
                top_token_logit = last_token_logits[top_token_idx].item()

                # Its probability, as a percentage
                probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                top_token_prob = probs[top_token_idx].item() * 100

                # The token as a string
                top_token_str = model.to_string([top_token_idx])

                # Attention head contributions to the top token's logit
                head_contributions = calculate_head_contributions(cache, top_token_idx, model)

                prediction = {
                    'token': top_token_str,
                    'logit': top_token_logit,
                    'prob': top_token_prob,
                }
            except Exception as e:
                error_message = f"Error processing request: {str(e)}"
                print(f"Error in processing: {e}", file=sys.stderr)

    return render_template(
        'index.html',
        prediction=prediction,
        text=text,
        head_contributions=json.dumps(head_contributions) if head_contributions else None,
        error_message=error_message,
    )


def calculate_head_contributions(cache, token_idx, model):
    """Calculate the contribution of each attention head to the top token's logit.

    Note: this projects each head's output directly onto the unembedding
    direction, ignoring the final LayerNorm scaling, so the values are best
    read as relative rather than exact logit contributions.
    """
    contributions = []
    layer_total_contributions = []

    # Direction in the residual stream that corresponds to the token:
    # its column of the unembedding matrix W_U [d_model, d_vocab]
    token_direction = model.W_U[:, token_idx].detach()

    # Calculate contributions for each head
    for layer in range(model.cfg.n_layers):
        # Output of each head at the last position: [head, d_head]
        z = cache["z", layer][0, -1]

        # Apply each head's output matrix: [head, d_model]
        head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])

        # Project onto the token direction to get each head's logit contribution
        head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)

        # Total contribution for this layer
        layer_total = head_contribs.sum().item()
        layer_total_contributions.append(layer_total)

        # Convert to a list for JSON serialization
        layer_contributions = head_contribs.detach().cpu().numpy().tolist()
        contributions.append(layer_contributions)

    # Total contribution across all heads
    total_contribution = sum(sum(layer_contrib) for layer_contrib in contributions)

    # Convert contributions to percentages of the total
    percentage_contributions = []
    for layer_contributions in contributions:
        percentage_layer = [(contrib / total_contribution) * 100 for contrib in layer_contributions]
        percentage_contributions.append(percentage_layer)

    # Per-layer contribution percentages
    layer_percentages = [(layer_total / total_contribution) * 100 for layer_total in layer_total_contributions]

    # Max and min values for normalization in the visualization
    all_contribs_pct = np.array(percentage_contributions).flatten()
    max_contrib = float(np.max(all_contribs_pct))
    min_contrib = float(np.min(all_contribs_pct))

    return {
        "contributions": percentage_contributions,
        "max_value": max_contrib,
        "min_value": min_contrib,
        "layer_contributions": layer_percentages,
    }


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860, debug=False)
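
# Usage sketch (assumes the server is running on the host/port configured
# above; the prompt is an arbitrary example):
#
#   curl -X POST --data-urlencode "text=The Eiffel Tower is in" http://localhost:7860/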
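
# A minimal sanity check for calculate_head_contributions, left as a comment so
# it never runs in production. It assumes the model loaded successfully and can
# be pasted into a Python shell after importing this module:
#
#   tokens = model.to_tokens("The Eiffel Tower is in", prepend_bos=True)
#   logits, cache = model.run_with_cache(tokens)
#   top_idx = torch.argmax(logits[0, -1]).item()
#   result = calculate_head_contributions(cache, top_idx, model)
#   assert len(result["contributions"]) == model.cfg.n_layers    # one list per layer
#   assert len(result["contributions"][0]) == model.cfg.n_heads  # one value per head
#   # Per-layer percentages sum to 100 by construction (up to float error)
#   assert abs(sum(result["layer_contributions"]) - 100.0) < 1e-3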