import gradio as gr
import torch
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForCausalLM
import html
import time
DEFAULT_MODEL = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CACHE = {}
# ---------------- Fullscreen helper (injected JS per plot) ----------------
def fullscreen_plot_js(plot_id):
    # Original inline script was lost; minimal sketch using the browser Fullscreen API
    # on the element whose id matches the Gradio elem_id.
    safe_id = plot_id.replace("-", "_")
    return f"""<script>
function fullscreen_{safe_id}() {{
  var el = document.getElementById("{plot_id}");
  if (el && el.requestFullscreen) {{ el.requestFullscreen(); }}
}}
</script>"""
def fullscreen_button_html(plot_id, label="Full Screen"):
    # wrapper HTML: JS function + button that triggers it
    safe_id = plot_id.replace("-", "_")
    return fullscreen_plot_js(plot_id) + f'<button onclick="fullscreen_{safe_id}()">{label}</button>'
# ---------------- CORE UTILS ----------------
def load_model(model_name):
if model_name in MODEL_CACHE:
return MODEL_CACHE[model_name]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, output_attentions=True, output_hidden_states=True
).to(DEVICE)
model.eval()
MODEL_CACHE[model_name] = (model, tokenizer)
return model, tokenizer
def softmax(x):
e = np.exp(x - np.max(x))
return e / e.sum(axis=-1, keepdims=True)
def safe_tokens(tokens):
return " ".join([f"[{html.escape(t)}]" for t in tokens])
def compute_pca(hidden_layer):
try:
return PCA(n_components=2).fit_transform(hidden_layer)
    except Exception:
seq = hidden_layer.shape[0]
dim0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
dim1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
return np.vstack([dim0, dim1]).T
def fig_attention(matrix, tokens, title):
fig = px.imshow(matrix, x=tokens, y=tokens, title=title,
labels={"x": "Key token", "y": "Query token", "color": "Attention"})
fig.update_layout(height=420)
return fig
def fig_pca(points, tokens, highlight=None, title="PCA"):
fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title)
fig.update_traces(textposition="top center", marker=dict(size=10))
if highlight is not None:
fig.add_trace(go.Scatter(
x=[points[highlight, 0]],
y=[points[highlight, 1]],
mode="markers+text",
text=[tokens[highlight]],
marker=dict(size=18, color="red")
))
fig.update_layout(height=420)
return fig
def fig_probs(tokens, scores):
fig = go.Figure()
fig.add_trace(go.Bar(x=tokens, y=scores))
fig.update_layout(title="Next-token probabilities", height=380)
return fig
# ---------------- ANALYSIS CORE ----------------
def analyze_text(text, model_name, simple):
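    # One forward pass: gather tokens, per-layer attentions and hidden states,
    # 2-D PCA projections, top next-token probabilities, the most active
    # final-layer neurons, and per-block residual norms into a single dict.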
if not text.strip():
return {"error": "Please enter text."}
try:
model, tokenizer = load_model(model_name)
except Exception as e:
return {"error": f"Failed to load model: {e}"}
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
try:
with torch.no_grad():
out = model(**inputs)
except Exception as e:
return {"error": f"Model error: {e}"}
input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
tokens = tokenizer.convert_ids_to_tokens(input_ids)
attentions = [a[0].cpu().numpy() for a in out.attentions]
hidden = [h[0].cpu().numpy() for h in out.hidden_states]
logits = out.logits[0].cpu().numpy()
# PCA per layer
pca_layers = [compute_pca(h) for h in hidden]
# top-k
last = logits[-1]
probs = softmax(last)
idx = np.argsort(probs)[-20:][::-1]
    top_tokens = [tokenizer.decode([int(i)]) for i in idx]
top_scores = probs[idx].tolist()
default_layer = len(attentions) - 1
default_head = 0
# neuron explorer
neuron_info = []
try:
last_h = hidden[-1]
mean_act = np.abs(last_h).mean(axis=0)
top_neurons = np.argsort(mean_act)[-24:][::-1]
for n in top_neurons:
vals = last_h[:, n]
top_ix = np.argsort(np.abs(vals))[-5:][::-1]
neuron_info.append({
"neuron": int(n),
"top_tokens": [(tokens[i], float(vals[i])) for i in top_ix]
})
    except Exception:
neuron_info = []
# residual decomposition (safe)
residuals = compute_residuals_safe(model, inputs)
return {
"tokens": tokens,
"attentions": attentions,
"hidden": hidden,
"pca": pca_layers,
"logits": logits,
"top_tokens": top_tokens,
"top_scores": top_scores,
"default_layer": default_layer,
"default_head": default_head,
"neuron_info": neuron_info,
"residuals": residuals,
"token_display": safe_tokens(tokens),
"explanation": explain(simple)
}
def explain(s):
if s:
return (
"π§ **Simple mode:**\n"
"- The model cuts text into small pieces (tokens).\n"
"- It looks at which tokens matter (attention).\n"
"- It builds an internal map (PCA) of meanings.\n"
"- Then it guesses the next token.\n"
)
return (
"π¬ **Technical mode:**\n"
"Showing tokens, attention (queryβkey), PCA projections, logits, "
"neuron activations, and layerwise residual contributions.\n"
)
# ---------------- RESIDUAL DECOMPOSITION SAFE ----------------
def compute_residuals_safe(model, inputs):
"""
Guaranteed safe residual norms for GPT-2-style blocks.
Will NEVER crash. Returns None if not applicable.
"""
if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
return None
try:
blocks = model.transformer.h
wte = model.transformer.wte
        x = wte(inputs["input_ids"]).to(DEVICE)
        # GPT-2 adds learned positional embeddings to the token embeddings; include
        # them when available so the per-block norms match the real forward pass.
        wpe = getattr(model.transformer, "wpe", None)
        if wpe is not None:
            pos = torch.arange(x.shape[1], device=DEVICE).unsqueeze(0)
            x = x + wpe(pos)
attn_norms = []
mlp_norms = []
for block in blocks:
try:
ln1 = block.ln_1(x)
attn_out = block.attn(ln1)[0]
x = x + attn_out
ln2 = block.ln_2(x)
mlp_out = block.mlp(ln2)
x = x + mlp_out
# detach to avoid requires_grad warning
attn_norms.append(float(torch.norm(attn_out.detach()).cpu()))
mlp_norms.append(float(torch.norm(mlp_out.detach()).cpu()))
            except Exception:
# fallback safe zero
attn_norms.append(0.0)
mlp_norms.append(0.0)
# normalize lengths safely
L = min(len(attn_norms), len(mlp_norms))
return {
"attn": attn_norms[:L],
"mlp": mlp_norms[:L],
}
    except Exception:
return None
# ---------------- ACTIVATION PATCHING (SAFE VERSION) ----------------
def activation_patch(tokens, model_name, layer, pos, from_pos, scale=1.0):
"""
Safe activation patching (never crashes, only works for GPT-2 style).
"""
    try:
        model, tokenizer = load_model(model_name)
    except Exception as e:
        return {"error": f"Model load error: {e}"}
if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
return {"error": "Model not compatible with patching."}
text = " ".join(tokens)
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
blocks = model.transformer.h
wte = model.transformer.wte
ln_f = model.transformer.ln_f if hasattr(model.transformer, "ln_f") else None
lm_head = model.lm_head
with torch.no_grad():
        x = wte(inputs["input_ids"]).to(DEVICE)
        # Add GPT-2's learned positional embeddings (when present) so the manual
        # forward pass matches the real one.
        wpe = getattr(model.transformer, "wpe", None)
        if wpe is not None:
            pos_ids = torch.arange(x.shape[1], device=DEVICE).unsqueeze(0)
            x = x + wpe(pos_ids)
hidden_layers = [x.clone().cpu().numpy()[0]]
for b in blocks:
ln1 = b.ln_1(x)
a = b.attn(ln1)[0]
x = x + a
ln2 = b.ln_2(x)
m = b.mlp(ln2)
x = x + m
hidden_layers.append(x.clone().cpu().numpy()[0])
if layer >= len(hidden_layers):
return {"error": "Layer out of range."}
seq_len = hidden_layers[layer].shape[0]
if pos >= seq_len or from_pos >= seq_len:
return {"error": "Position out of range."}
patch_vec = torch.tensor(hidden_layers[layer][from_pos], dtype=torch.float32).to(DEVICE) * float(scale)
# re-run with patch
with torch.no_grad():
        x = wte(inputs["input_ids"]).to(DEVICE)
        if wpe is not None:
            x = x + wpe(pos_ids)
        for i, b in enumerate(blocks):
            # Patch at the *input* of block `layer` so it replaces the same
            # residual-stream snapshot that hidden_layers[layer] was taken from.
            if i == layer:
                x[0, pos, :] = patch_vec
            ln1 = b.ln_1(x)
            a = b.attn(ln1)[0]
            x = x + a
            ln2 = b.ln_2(x)
            m = b.mlp(ln2)
            x = x + m
final = ln_f(x) if ln_f else x
logits = lm_head(final)[0, -1, :].cpu().numpy()
probs = softmax(logits)
idx = np.argsort(probs)[-20:][::-1]
tt = [tokenizer.decode([int(i)]) for i in idx]
ss = probs[idx].tolist()
return {"tokens": tt, "scores": ss}
# ---------------- GRADIO UI ----------------
with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
gr.Markdown("# π§ Full LLM Visualizer (Advanced)")
gr.Markdown("Fully stable build with attention, PCA, neuron explorer, residuals, activation-patching")
# Panel 1
with gr.Row():
with gr.Column(scale=3):
model_name = gr.Textbox(label="Model", value=DEFAULT_MODEL)
input_text = gr.Textbox(label="Input", value="Hello world", lines=3)
simple = gr.Checkbox(label="Explain simply", value=True)
run_btn = gr.Button("Run", variant="primary")
gr.Markdown("Presets:")
with gr.Row():
gr.Button("Greeting").click(lambda: "Hello! How are you?", None, input_text)
gr.Button("Story").click(lambda: "Once upon a time there was a robot.", None, input_text)
gr.Button("Question").click(lambda: "Why is the sky blue?", None, input_text)
with gr.Column(scale=2):
token_display = gr.Markdown()
explanation_md = gr.Markdown()
model_info = gr.Markdown()
# Panel 2
with gr.Row():
with gr.Column():
layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
attn_plot = gr.Plot(elem_id="attn_plot")
attn_fs = gr.HTML(fullscreen_button_html("attn_plot"))
with gr.Column():
pca_plot = gr.Plot(elem_id="pca_plot")
pca_fs = gr.HTML(fullscreen_button_html("pca_plot"))
step_attn_plot = gr.Plot(elem_id="step_attn_plot")
step_fs = gr.HTML(fullscreen_button_html("step_attn_plot"))
probs_plot = gr.Plot(elem_id="probs_plot")
probs_fs = gr.HTML(fullscreen_button_html("probs_plot"))
    # Panel 3 – Residuals
residual_plot = gr.Plot(elem_id="residual_plot")
residual_fs = gr.HTML(fullscreen_button_html("residual_plot"))
    # Panel 4 – Neuron explorer
with gr.Row():
neuron_find_btn = gr.Button("Find neurons")
neuron_idx = gr.Number(label="Neuron index", value=0)
neuron_table = gr.Dataframe(headers=["token", "activation"], interactive=False)
    # Panel 5 – Activation Patching
with gr.Row():
patch_layer = gr.Slider(0, 0, value=0, step=1, label="Patch layer")
patch_pos = gr.Slider(0, 0, value=0, step=1, label="Target token position")
patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
patch_scale = gr.Number(label="Scale", value=1.0)
patch_btn = gr.Button("Run patch")
patch_output = gr.Plot(elem_id="patch_plot")
patch_fs = gr.HTML(fullscreen_button_html("patch_plot"))
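    # Shared state: holds the full analyze_text() result so slider and button
    # callbacks can re-render plots without re-running the model.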
state = gr.State()
# ---- RUN ANALYSIS ----
def run_app(text, model, simp):
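        # Returns a dict of component -> gr.update(...) so one callback can
        # refresh every panel (plots, sliders, tables, shared state) at once.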
res = analyze_text(text, model, simp)
if "error" in res:
return {
token_display: gr.update(value=""),
explanation_md: gr.update(value=res["error"]),
model_info: gr.update(value=f"Model: {model}"),
attn_plot: gr.update(value=None),
pca_plot: gr.update(value=None),
probs_plot: gr.update(value=None),
layer_slider: gr.update(maximum=0, value=0),
head_slider: gr.update(maximum=0, value=0),
token_step: gr.update(maximum=0, value=0),
residual_plot: gr.update(value=None),
neuron_table: gr.update(value=[]),
patch_layer: gr.update(maximum=0),
patch_pos: gr.update(maximum=0),
patch_from: gr.update(maximum=0),
state: res,
step_attn_plot: gr.update(value=None),
patch_output: gr.update(value=None),
}
tokens = res["tokens"]
L = len(res["attentions"])
H = res["attentions"][0].shape[0]
T = len(tokens) - 1
residual_fig = None
if res["residuals"]:
attn_vals = res["residuals"]["attn"]
ml_vals = res["residuals"]["mlp"]
Lmin = min(len(attn_vals), len(ml_vals))
df = pd.DataFrame({
"layer": list(range(Lmin)),
"attention": attn_vals[:Lmin],
"mlp": ml_vals[:Lmin]
})
fig = go.Figure()
fig.add_trace(go.Bar(x=df["layer"], y=df["attention"], name="Attention norm"))
fig.add_trace(go.Bar(x=df["layer"], y=df["mlp"], name="MLP norm"))
fig.update_layout(barmode="group", height=360)
residual_fig = fig
return {
token_display: gr.update(value=f"**Tokens:** {res['token_display']}"),
explanation_md: gr.update(value=res["explanation"]),
            model_info: gr.update(value=f"Model: {model} • layers: {L} • heads: {H} • tokens: {len(tokens)}"),
            attn_plot: gr.update(value=fig_attention(
                res["attentions"][res["default_layer"]][res["default_head"]],
                tokens, f"Layer {res['default_layer']} Head {res['default_head']}")),
            pca_plot: gr.update(value=fig_pca(
                res["pca"][res["default_layer"]], tokens,
                title=f"PCA Layer {res['default_layer']}")),
probs_plot: gr.update(value=fig_probs(res["top_tokens"], res["top_scores"])),
layer_slider: gr.update(maximum=L-1, value=res["default_layer"]),
head_slider: gr.update(maximum=H-1, value=res["default_head"]),
token_step: gr.update(maximum=T, value=0),
residual_plot: gr.update(value=residual_fig),
neuron_table: gr.update(value=[[t, round(v,4)] for t,v in res["neuron_info"][0]["top_tokens"]] if res["neuron_info"] else []),
patch_layer: gr.update(maximum=L-1, value=0),
patch_pos: gr.update(maximum=T, value=0),
patch_from: gr.update(maximum=T, value=0),
state: res,
step_attn_plot: gr.update(value=None),
patch_output: gr.update(value=None),
}
run_btn.click(
run_app,
inputs=[input_text, model_name, simple],
outputs=[
token_display, explanation_md, model_info,
attn_plot, pca_plot, probs_plot,
layer_slider, head_slider, token_step,
residual_plot, neuron_table,
patch_layer, patch_pos, patch_from,
state, step_attn_plot, patch_output
]
)
# ---- SLIDER UPDATES ----
def update_view(res, layer, head, tok):
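        # Clamp the requested layer/head/token indices to valid ranges, then
        # rebuild the attention heatmap, the PCA scatter (highlighting the
        # selected token), and the per-token attention bar chart.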
if not res or "error" in res:
return {
attn_plot: gr.update(value=None),
pca_plot: gr.update(value=None),
step_attn_plot: gr.update(value=None),
}
tokens = res["tokens"]
layer = min(max(0, layer), len(res["attentions"]) - 1)
head = min(max(0, head), res["attentions"][0].shape[0] - 1)
tok = min(max(0, tok), len(tokens) - 1)
att = fig_attention(res["attentions"][layer][head], tokens, f"Layer {layer} Head {head}")
pts = res["pca"][layer]
pca_fig = fig_pca(pts, tokens, highlight=tok, title=f"PCA Layer {layer}")
row = res["attentions"][layer][head][tok]
step_fig = go.Figure([go.Bar(x=tokens, y=row)])
step_fig.update_layout(title=f"Token {tok} attends to")
return {
attn_plot: gr.update(value=att),
pca_plot: gr.update(value=pca_fig),
step_attn_plot: gr.update(value=step_fig)
}
layer_slider.change(update_view, [state, layer_slider, head_slider, token_step],
[attn_plot, pca_plot, step_attn_plot])
head_slider.change(update_view, [state, layer_slider, head_slider, token_step],
[attn_plot, pca_plot, step_attn_plot])
token_step.change(update_view, [state, layer_slider, head_slider, token_step],
[attn_plot, pca_plot, step_attn_plot])
# ---- NEURON EXPLORER ----
def neuron_auto(res):
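        # Flatten the precomputed top-neuron summaries into (token, activation) rows.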
if not res or "neuron_info" not in res:
return gr.update(value=[])
rows = []
for item in res["neuron_info"]:
for t, v in item["top_tokens"]:
rows.append([t, round(v,4)])
df = pd.DataFrame(rows, columns=["token","activation"]).drop_duplicates().head(24)
return gr.update(value=df.values.tolist())
neuron_find_btn.click(neuron_auto, [state], [neuron_table])
def neuron_manual(res, idx):
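        # Inspect a single neuron of the final hidden layer: list the tokens with
        # the largest-magnitude activations along that dimension.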
if not res or "hidden" not in res:
return gr.update(value=[])
try:
idx = int(idx)
        except Exception:
return gr.update(value=[])
last = res["hidden"][-1]
if idx >= last.shape[1]:
return gr.update(value=[])
vals = last[:, idx]
tokens = res["tokens"]
pairs = sorted([(tokens[i], float(vals[i])) for i in range(len(tokens))],
key=lambda x: -abs(x[1]))[:12]
return gr.update(value=[[t, round(v,4)] for t,v in pairs])
neuron_idx.change(neuron_manual, [state, neuron_idx], [neuron_table])
# ---- ACTIVATION PATCHING ----
def patch_run(res, L, P, FP, S, model):
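        # Run activation patching with the chosen layer/positions/scale and plot
        # the patched next-token distribution.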
if not res or "tokens" not in res:
return gr.update(value=None)
out = activation_patch(res["tokens"], model, int(L), int(P), int(FP), float(S))
if "error" in out:
return gr.update(value=None)
fig = fig_probs(out["tokens"], out["scores"])
return gr.update(value=fig)
patch_btn.click(patch_run,
[state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
[patch_output])
demo.launch()