import os
import sys

# Set cache locations from environment variables *before* importing any
# Hugging Face libraries, since those libraries read the variables at import time.
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/home/appuser/cache')
hf_home = os.environ.get('HF_HOME', '/home/appuser/hf_home')
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = hf_home
print(f"Using cache directories: {cache_dir} and {hf_home}", file=sys.stderr)

import json

import numpy as np
import torch
from flask import Flask, render_template, request, jsonify
from transformer_lens import HookedTransformer

app = Flask(__name__)
# Load the GPT-2 small model
try:
    print("Loading GPT-2 small model...", file=sys.stderr)
    model = HookedTransformer.from_pretrained(
        "gpt2",  # 'gpt2' is the correct model ID on Hugging Face
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        refactor_factored_attn_matrices=True,
    )
    print("Model loaded successfully!", file=sys.stderr)
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)
    # Continue with app setup even if the model fails to load,
    # so the app can start and show an error message.
    model = None
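
# Optional step, not in the original app: since this app only runs inference,
# one could put the model in eval mode and disable gradient tracking to reduce
# memory use. This is an assumption about intended usage, not original code.
if model is not None:
    model.eval()
    torch.set_grad_enabled(False)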
@app.route('/', methods=['GET', 'POST'])
def index():
    prediction = None
    text = ""
    head_contributions = None
    error_message = None

    if model is None:
        error_message = "Model failed to load. Please check the logs or try again later."
    elif request.method == 'POST':
        text = request.form.get('text', '')
        if text:
            try:
                # Tokenize the input text
                tokens = model.to_tokens(text, prepend_bos=True)

                # Run the model with cache to get intermediate activations
                logits, cache = model.run_with_cache(tokens)

                # Get logits for the last token
                last_token_logits = logits[0, -1]

                # Get the index of the token with the highest logit
                top_token_idx = torch.argmax(last_token_logits).item()

                # Get the logit value
                top_token_logit = last_token_logits[top_token_idx].item()

                # Get the probability
                probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                top_token_prob = probs[top_token_idx].item() * 100  # Convert to percentage

                # Get the token as a string
                top_token_str = model.to_string([top_token_idx])

                # Get attention head contributions for the top token
                head_contributions = calculate_head_contributions(cache, top_token_idx, model)

                prediction = {
                    'token': top_token_str,
                    'logit': top_token_logit,
                    'prob': top_token_prob,
                }
            except Exception as e:
                error_message = f"Error processing request: {str(e)}"
                print(f"Error in processing: {e}", file=sys.stderr)

    return render_template(
        'index.html',
        prediction=prediction,
        text=text,
        head_contributions=json.dumps(head_contributions) if head_contributions else None,
        error_message=error_message,
    )
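
# Illustrative only: a small helper, not part of the original app, showing the
# same unembedding/softmax step used in index() but for the k most likely
# tokens. The name `top_k_predictions` and the k=5 default are assumptions.
def top_k_predictions(text, k=5):
    """Return the k most likely next tokens with their probabilities (in %)."""
    tokens = model.to_tokens(text, prepend_bos=True)
    logits = model(tokens)
    last_token_logits = logits[0, -1]
    probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k)
    return [
        (model.to_string([idx.item()]), prob.item() * 100)
        for idx, prob in zip(top_indices, top_probs)
    ]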
def calculate_head_contributions(cache, token_idx, model):
    """Calculate the contribution of each attention head to the top token's logit."""
    contributions = []
    layer_total_contributions = []

    # Direction in the residual stream that corresponds to the token
    token_direction = model.W_U[:, token_idx].detach()

    # Calculate contributions for each head
    for layer in range(model.cfg.n_layers):
        # Output of each head at the last position
        z = cache["z", layer][0, -1]  # [head, d_head]

        # Apply the output (O) matrix for each head
        head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])  # [head, d_model]

        # Project onto the token direction to get the contribution to the logit
        head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)

        # Total contribution for this layer
        layer_total = head_contribs.sum().item()
        layer_total_contributions.append(layer_total)

        # Convert to a list for JSON serialization
        layer_contributions = head_contribs.detach().cpu().numpy().tolist()
        contributions.append(layer_contributions)

    # Total contribution across all heads
    total_contribution = sum(sum(layer_contrib) for layer_contrib in contributions)

    # Convert contributions to percentages of the total
    percentage_contributions = []
    for layer_contributions in contributions:
        percentage_layer = [(contrib / total_contribution) * 100 for contrib in layer_contributions]
        percentage_contributions.append(percentage_layer)

    # Per-layer contribution percentages
    layer_percentages = [(layer_total / total_contribution) * 100 for layer_total in layer_total_contributions]

    # Max and min values for normalization in the visualization
    all_contribs_pct = np.array(percentage_contributions).flatten()
    max_contrib = float(np.max(all_contribs_pct))
    min_contrib = float(np.min(all_contribs_pct))

    return {
        "contributions": percentage_contributions,
        "max_value": max_contrib,
        "min_value": min_contrib,
        "layer_contributions": layer_percentages,
    }
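
# Illustrative only: the original file imports `jsonify` but never uses it. A
# minimal JSON endpoint like the sketch below is one plausible use; the route
# name '/api/predict' is an assumption, not part of the original app.
@app.route('/api/predict', methods=['POST'])
def api_predict():
    if model is None:
        return jsonify({"error": "Model failed to load."}), 503
    data = request.get_json(silent=True) or {}
    text = data.get('text', '')
    if not text:
        return jsonify({"error": "No text provided."}), 400
    tokens = model.to_tokens(text, prepend_bos=True)
    logits, cache = model.run_with_cache(tokens)
    last_token_logits = logits[0, -1]
    top_token_idx = torch.argmax(last_token_logits).item()
    probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    return jsonify({
        "token": model.to_string([top_token_idx]),
        "logit": last_token_logits[top_token_idx].item(),
        "prob": probs[top_token_idx].item() * 100,
        "head_contributions": calculate_head_contributions(cache, top_token_idx, model),
    })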
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860, debug=False)