Replicating The Circuit Kings
==================
let's get wild
- i write with zero autocorrect. literally freeballing. as i’ve said before, if the shoggoth will replicate me, may it do so in all my glory.
on to the show…
I started reading the Anthropic blog called “Circuit Tracing: Revealing Computational Graphs in Language Models” and thought it would be cool to replicate it on gpt2-small.
for those who didn’t read it, the abstract is as follows:
We introduce a method to uncover mechanisms underlying behaviors of language models. We produce graph descriptions of the model’s computation on prompts of interest by tracing individual computational steps in a “replacement model”. This replacement model substitutes a more interpretable component (here, a “cross-layer transcoder”) for parts of the underlying model (here, the multi-layer perceptrons) that it is trained to approximate. We develop a suite of visualization and validation tools we use to investigate these “attribution graphs” supporting simple behaviors of an 18-layer language model, and lay the groundwork for a companion paper applying these methods to a frontier model, Claude 3.5 Haiku.
building a replacement model
import sys
import os
import json
import glob
import math
import random
from typing import List, Tuple, Dict, Optional
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoTokenizer
from datasets import load_dataset
create some data with transformer lens
because the paper trains the transcoder on activations extracted from the underlying model, rather than extracting them on the fly during training we'll shortcut this by running the model's forward pass over a sample of data once and storing the activations in a dataset.
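the same activations can also be grabbed in one shot with transformer lens's run_with_cache (we'll lean on it again later when building the replacement model). a minimal sketch, assuming gpt2-small and a toy prompt (the output filename is arbitrary):
import numpy as np
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
tokens = model.to_tokens("the cat sat on the mat")
# cache only the residual-stream inputs and the MLP outputs
_, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: name.endswith("hook_resid_pre") or name.endswith("hook_mlp_out"),
)
# swap dots for underscores so the archive keys stay simple
np.savez_compressed(
    "example_batch.npz",
    **{k.replace(".", "_"): v.cpu().numpy() for k, v in cache.items()},
)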
def sample_fineweb_data(n_samples=5000):
"""
Streams the FineWeb dataset and yields up to n_samples items.
"""
ds_stream = load_dataset(
"HuggingFaceFW/fineweb",
name="CC-MAIN-2024-10",
split="train",
streaming=True
)
# ds_stream is an iterable dataset. We'll just take the first N examples
data_iter = iter(ds_stream)
for i in range(n_samples):
try:
item = next(data_iter)
yield item
except StopIteration:
break
model_name = "gpt2-small"
model = HookedTransformer.from_pretrained(
model_name,
device="cuda" if torch.cuda.is_available() else "cpu"
)
resid_cache = {}
mlp_out_cache = {}
def hook_resid_pre(activation, hook):
resid_cache[hook.name] = activation.detach().cpu()
def hook_mlp_out(activation, hook):
mlp_out_cache[hook.name] = activation.detach().cpu()
# Register hooks for each layer
for layer_idx in range(model.cfg.n_layers):
# e.g. 'blocks.0.hook_resid_pre'
resid_name = f"blocks.{layer_idx}.hook_resid_pre"
mlp_name = f"blocks.{layer_idx}.hook_mlp_out"
model.add_hook(resid_name, hook_resid_pre, "fwd")
model.add_hook(mlp_name, hook_mlp_out, "fwd")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 doesn't ship with a pad token, so we reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
device = "cuda" if torch.cuda.is_available() else "cpu"
def collect_hidden_states_and_save(
text_list,
batch_idx,
output_dir="clt_data",
max_seq_len=128,
):
"""
text_list: list of raw strings for a single batch
batch_idx: which batch index we are on
We'll tokenize, run a forward pass, and store the hidden states in a compressed .npz file.
"""
enc = tokenizer(
text_list,
padding=True,
truncation=True,
max_length=max_seq_len,
return_tensors="pt"
)
enc = {k: v.to(device) for k, v in enc.items()}
# Clear old caches
resid_cache.clear()
mlp_out_cache.clear()
_ = model(enc["input_ids"]) # triggers hooks
# Now gather the hidden states from resid_cache, mlp_out_cache
# They will have shape: [batch_size, seq_len, d_model]
# For each layer we have e.g. resid_cache["blocks.0.hook_resid_pre"]
# We can store them in an npz
layer_data = {}
for layer_idx in range(model.cfg.n_layers):
rname = f"blocks.{layer_idx}.hook_resid_pre"
mname = f"blocks.{layer_idx}.hook_mlp_out"
# shape: [batch, seq, d_model]
resid_arr = resid_cache[rname].numpy()
mlp_arr = mlp_out_cache[mname].numpy()
layer_data[f"resid_{layer_idx}"] = resid_arr
layer_data[f"mlp_{layer_idx}"] = mlp_arr
# Save to a single file
os.makedirs(output_dir, exist_ok=True)
out_path = os.path.join(output_dir, f"batch_{batch_idx}.npz")
np.savez_compressed(out_path, **layer_data)
print(f"Saved {out_path} with shape: {resid_arr.shape}")
def main_collect_5k(output_dir="clt_data"):
model_name = "gpt2-small"
model = HookedTransformer.from_pretrained(
model_name,
device="cuda" if torch.cuda.is_available() else "cpu"
)
resid_cache = {}
mlp_out_cache = {}
def hook_resid_pre(activation, hook):
resid_cache[hook.name] = activation.detach().cpu()
def hook_mlp_out(activation, hook):
mlp_out_cache[hook.name] = activation.detach().cpu()
# Register hooks for each layer
for layer_idx in range(model.cfg.n_layers):
# e.g. 'blocks.0.hook_resid_pre'
resid_name = f"blocks.{layer_idx}.hook_resid_pre"
mlp_name = f"blocks.{layer_idx}.hook_mlp_out"
model.add_hook(resid_name, hook_resid_pre, "fwd")
model.add_hook(mlp_name, hook_mlp_out, "fwd")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 doesn't ship with a pad token, so we reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16
buffer = []
batch_index = 0
from tqdm import tqdm
for doc_idx, record in tqdm(enumerate(sample_fineweb_data(n_samples=5000))):
text = record["text"]  # FineWeb stores the raw document under "text"
# Add to buffer
buffer.append(text)
if len(buffer) >= batch_size:
collect_hidden_states_and_save(buffer, batch_index, output_dir)
buffer = []
batch_index += 1
# leftover
if buffer:
collect_hidden_states_and_save(buffer, batch_index, output_dir)
# let's collect 5k samples
main_collect_5k()
create a dataset iterator
def normalize_vector(vec, eps=1e-9):
# vec shape: [d_model] or [batch, d_model]
# We'll do L2 norm along the last dim
norm = vec.norm(dim=-1, keepdim=True).clamp_min(eps)
return vec / norm
class CLTHiddenStateDataset(Dataset):
def __init__(self, data_dir="clt_data", layer_count=12):
self.files = sorted(glob.glob(f"{data_dir}/batch_*.npz"))
# We'll store (file_index, idx_in_file) in an index
self.index = []
self.layer_count = layer_count
# We'll parse each file once to see how many items it has
for fi, path in enumerate(self.files):
try:
with np.load(path) as npz:
# e.g. shape of resid_0 is [batch, seq, d_model]
shape = npz["resid_0"].shape
# shape = (b, s, d_model)
num_positions = shape[0] * shape[1] # b*s
for i in range(num_positions):
self.index.append((fi, i))
except Exception as e:
print(f"Error loading {path}: {e}")
def __len__(self):
return len(self.index)
def __getitem__(self, idx):
file_idx, pos_idx = self.index[idx]
path = self.files[file_idx]
with np.load(path) as npz:
# We'll reconstruct which (batch, seq) => (pos_idx // seq_len, pos_idx % seq_len).
# But we need shape info:
shape = npz["resid_0"].shape # [b, s, d]
b, s, d = shape
local_b = pos_idx // s
local_s = pos_idx % s
# Now gather resid and mlp for each layer
# We'll build lists of shape [d_model]
resid_list = []
mlp_list = []
for layer_idx in range(self.layer_count):
rname = f"resid_{layer_idx}"
mname = f"mlp_{layer_idx}"
# shape of r: [b, s, d]
# index out the single (batch, seq) position => shape [d_model]
r = npz[rname][local_b, local_s]
m = npz[mname][local_b, local_s]
# normalize per token
r = normalize_vector(torch.from_numpy(r).float())
m = normalize_vector(torch.from_numpy(m).float())
resid_list.append(r)
mlp_list.append(m)
return resid_list, mlp_list
cross-layer transcoder

class RectangleFunction(autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return ((x > -0.5) & (x < 0.5)).float()
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[(x <= -0.5) | (x >= 0.5)] = 0
return grad_input
class JumpReLUFunction(autograd.Function):
@staticmethod
def forward(ctx, x, log_threshold, bandwidth):
ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
threshold = torch.exp(log_threshold)
return x * (x > threshold).float()
@staticmethod
def backward(ctx, grad_output):
x, log_threshold, bandwidth_tensor = ctx.saved_tensors
bandwidth = bandwidth_tensor.item()
threshold = torch.exp(log_threshold)
x_grad = (x > threshold).float() * grad_output
threshold_grad = (
-(threshold / bandwidth)
* RectangleFunction.apply((x - threshold) / bandwidth)
* grad_output
)
return x_grad, threshold_grad, None # None for bandwidth
class JumpReLU(nn.Module):
def __init__(self, feature_size, bandwidth, device='cpu'):
super(JumpReLU, self).__init__()
self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
self.bandwidth = bandwidth
def forward(self, x):
return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
class StepFunction(autograd.Function):
@staticmethod
def forward(ctx, x, log_threshold, bandwidth):
ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
threshold = torch.exp(log_threshold)
return (x > threshold).float()
@staticmethod
def backward(ctx, grad_output):
x, log_threshold, bandwidth_tensor = ctx.saved_tensors
bandwidth = bandwidth_tensor.item()
threshold = torch.exp(log_threshold)
x_grad = torch.zeros_like(x)
threshold_grad = (
-(1.0 / bandwidth)
* RectangleFunction.apply((x - threshold) / bandwidth)
* grad_output
)
return x_grad, threshold_grad, None # None for bandwidth
def uniform_init(tensor, limit):
with torch.no_grad():
tensor.uniform_(-limit, limit)
class CrossLayerTranscoder(nn.Module):
def __init__(self,
layer_feature_list: List[int],
d_model: int,
bandwidth: float = 1.0,
device="cpu"):
"""
layer_feature_list: e.g. [128, 256, 128, ...], length = n_layers
specifying how many features for each layer
d_model: hidden dimension of the underlying Transformer
bandwidth: controls JumpReLU partial derivatives
device: 'cpu' or 'cuda'
"""
super().__init__()
self.n_layers = len(layer_feature_list)
self.layer_feature_list = layer_feature_list
self.d_model = d_model
self.bandwidth = bandwidth
self.device = device
# (1) ENCODERS:
# We'll store a separate W_enc[i] for each layer i, shape: [layer_feature_list[i], d_model].
# We'll place them in a nn.ParameterList so we can do "W_enc[i]" in forward code.
self.W_enc = nn.ParameterList()
for i, feat_count in enumerate(layer_feature_list):
limit = 1.0 / math.sqrt(feat_count)
param = nn.Parameter(torch.empty(feat_count, d_model, device=device))
uniform_init(param, limit)
self.W_enc.append(param)
# (2) DECODERS:
# For each (src -> tgt) with src<=tgt,
# we define a matrix [layer_feature_list[src], d_model].
# We'll store them in a single nn.ParameterList, but keep track in index_map.
self.W_dec = nn.ParameterList()
self.index_map = []
dec_limit = 1.0 / math.sqrt(self.n_layers * d_model)
idx_counter = 0
for src_layer in range(self.n_layers):
row = []
src_feat = layer_feature_list[src_layer]
for tgt_layer in range(self.n_layers):
if tgt_layer >= src_layer:
dec_param = nn.Parameter(torch.empty(src_feat, d_model, device=device))
uniform_init(dec_param, dec_limit)
self.W_dec.append(dec_param)
row.append(idx_counter)
idx_counter += 1
else:
row.append(None)
self.index_map.append(row)
# (3) JumpReLUs:
# Each layer gets its own JumpReLU, with per-feature learned thresholds.
self.jumps = nn.ModuleList()
for i, feat_count in enumerate(layer_feature_list):
self.jumps.append(JumpReLU(feat_count, bandwidth, device=device))
def forward(self, resid_streams: List[torch.Tensor]) -> List[torch.Tensor]:
"""
resid_streams[i]: shape [batch, seq, d_model], for layer i
Returns: list of length n_layers, each [batch, seq, d_model],
the reconstruction for each layer's MLP out
"""
batch_size, seq_len, _ = resid_streams[0].shape
all_activations = [] # a^ℓ for each layer
# 1) ENCODING
for i in range(self.n_layers):
x = resid_streams[i] # [batch, seq, d_model]
W_enc_mat = self.W_enc[i] # [layer_feature_list[i], d_model]
# => a_pre shape [batch, seq, layer_feature_list[i]]
a_pre = torch.einsum("bsd,nd->bsn", x, W_enc_mat)
# jump layer i
a_post = self.jumps[i](a_pre)
all_activations.append(a_post)
# 2) DECODING
# y^ℓ_hat = sum_{ℓ'<=ℓ} W_dec^(ℓ'->ℓ) * a^(ℓ')
mlp_recon = []
for tgt_layer in range(self.n_layers):
recon = torch.zeros(batch_size, seq_len, self.d_model, device=self.device)
for src_layer in range(tgt_layer+1):
dec_idx = self.index_map[src_layer][tgt_layer]
W_dec_mat = self.W_dec[dec_idx] # shape [layer_feature_list[src_layer], d_model]
a_src = all_activations[src_layer] # [batch, seq, layer_feature_list[src_layer]]
# => [batch, seq, d_model]
recon_part = torch.einsum("bsn,nd->bsd", a_src, W_dec_mat)
recon += recon_part
mlp_recon.append(recon)
return mlp_recon
def forward_with_preacts(self, resid_streams: List[torch.Tensor]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Same as forward(), but also returning the raw pre-activations a_pre for each layer.
"""
batch_size, seq_len, _ = resid_streams[0].shape
all_a_pre = []
all_a_post = []
for i in range(self.n_layers):
x = resid_streams[i]
W_enc_mat = self.W_enc[i]
a_pre = torch.einsum("bsd,nd->bsn", x, W_enc_mat)
a_post = self.jumps[i](a_pre)
all_a_pre.append(a_pre)
all_a_post.append(a_post)
mlp_recon = []
for tgt_layer in range(self.n_layers):
recon = torch.zeros(batch_size, seq_len, self.d_model, device=self.device)
for src_layer in range(tgt_layer+1):
dec_idx = self.index_map[src_layer][tgt_layer]
W_dec_mat = self.W_dec[dec_idx]
a_src = all_a_post[src_layer]
recon_part = torch.einsum("bsn,nd->bsd", a_src, W_dec_mat)
recon += recon_part
mlp_recon.append(recon)
return mlp_recon, all_a_pre
@classmethod
def from_hookedtransformer(cls,
hmodel,
layer_feature_list: List[int],
bandwidth=1.0,
device="cpu"):
"""
hmodel: a HookedTransformer (from transformer_lens)
layer_feature_list: e.g. [128, 256, ...], length = hmodel.cfg.n_layers
"""
L = hmodel.cfg.n_layers
d_model = hmodel.cfg.d_model
if len(layer_feature_list) != L:
raise ValueError(f"layer_feature_list must have length {L}, got {len(layer_feature_list)}")
return cls(
layer_feature_list=layer_feature_list,
d_model=d_model,
bandwidth=bandwidth,
device=device
)
so a cross-layer transcoder is a module that takes in a list of residual streams (one per layer) and returns a reconstruction of each layer's MLP output. as the paper says, its goal is to approximate the MLPs of the underlying model.
to re-explain an already well explained concept (from the blog):
- each feature reads from the residual stream of the layer it lives in using a linear encoder
- a given layer's features help reconstruct the MLP outputs of their own layer and of every later layer, using per-target linear decoders
- features are trained jointly, so the MLP output at layer ℓ is reconstructed from the features of layers 1 through ℓ (a small shape sketch follows the list)
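before training it, a minimal shape sketch (dummy residual streams rather than real activations; gpt2-small sizes assumed):
# toy example: 12 layers, d_model=768, 128 features per layer, batch=2, seq=4
clt_demo = CrossLayerTranscoder(layer_feature_list=[128] * 12, d_model=768)
dummy_resids = [torch.randn(2, 4, 768) for _ in range(12)]
recons = clt_demo(dummy_resids)                            # 12 tensors, each [2, 4, 768]
recons, preacts = clt_demo.forward_with_preacts(dummy_resids)
print(recons[5].shape, preacts[5].shape)                   # [2, 4, 768] and [2, 4, 128]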
training the cross-layer transcoder
to train the CLT, there are two loss functions:
- a reconstruction loss summed over all layers
- a sparsity penalty summed over all layers
the reconstruction loss is obvious, we want to reconstruct the MLP outputs from the features, and we use MSE to do this.
the sparsity penalty is a bit more complex, but the idea is that we want to encourage the model to use as few features as possible to reconstruct the MLP outputs (the combined objective is summarized after this list)
- this promotes interpretability: without this constraint, the model can use many features to reconstruct the MLP outputs, which makes it harder to attribute specific behavior to a single feature.
- networks exhibit polysemanticity, where a single unit can carry multiple unrelated meanings. the sparsity constraint pushes features toward monosemanticity, where a single feature has a single meaning.
- reducing noise in the eventual attribution graphs: having too many active features can lead to incredibly dense attribution graphs, which are harder to analyze.
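as a quick summary of what the code below implements (a simplified form of the paper's objective; the code averages activations over the batch), with ŷ^ℓ the CLT reconstruction of layer ℓ's MLP output and a_i^ℓ' the activation of feature i at layer ℓ':

$$\mathcal{L} \;=\; \sum_{\ell}\big\lVert \hat{y}^{\ell} - y^{\ell}\big\rVert_2^2 \;+\; \lambda \sum_{\ell' \le \ell}\sum_{i}\tanh\!\big(c\,\lVert W_{dec,i}^{\ell'\to\ell}\rVert\; a_i^{\ell'}\big)$$

λ is ramped linearly from 0 to its final value over training (get_sparsity_scale below), and the training loop also adds a small pre-activation penalty, $\sum \mathrm{ReLU}(-a^{\ell}_{pre})$, to discourage strongly negative pre-activations.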
def advanced_sparsity_loss(
clt: CrossLayerTranscoder,
preacts_list: List[torch.Tensor],
c: float = 1.0,
lambda_spars: float = 1e-3
):
"""
Example: sum_{src_layer, i} tanh( c * ||W_dec_{i}|| * mean(a_pre) )
We'll do a simplified version, ignoring multi-step or partial layering.
"""
device = preacts_list[0].device
L = clt.n_layers
penalty = torch.zeros((), device=device)
# for each layer i
for src_layer in range(L):
a_pre = preacts_list[src_layer] # shape [batch, seq, n_feat_this_layer]
# average activation across batch, seq => [n_feat_this_layer]
a_mean = a_pre.mean(dim=(0,1))
# for each tgt_layer >= src_layer
for tgt_layer in range(src_layer, L):
dec_idx = clt.index_map[src_layer][tgt_layer]
W_dec_mat = clt.W_dec[dec_idx] # shape [layer_feature_list[src_layer], d_model]
dec_norm = W_dec_mat.norm(dim=1) # shape [layer_feature_list[src_layer]]
raw_vals = c * dec_norm * a_mean
penalty_layer = torch.tanh(raw_vals).sum()
penalty += penalty_layer
return lambda_spars * penalty
def get_sparsity_scale(current_step, total_steps, lambda_final):
# linear ramp from 0 to lambda_final across total_steps
scale = min(1.0, current_step / (total_steps - 1))
return scale * lambda_final
def sum_of_relu_neg(preacts_list):
"""
preacts_list: a list of Tensors, each shape [batch, seq, n_features],
containing the pre-activation values for each layer.
Returns a scalar that is the sum over all layers, batches, and feature dims of
ReLU(-preactivation).
"""
total_loss = torch.tensor(0.0, device=preacts_list[0].device)
for a_pre in preacts_list:
# shape: [batch, seq, n_features]
negvals = torch.relu(-a_pre) # ReLU(-x) = max(0, -x)
total_loss += negvals.sum()
return total_loss
def train_clt(
clt,
dataloader,
num_epochs,
lambda_spars_final=1e-3,
preact_loss_coef=3e-6,
total_steps=None
):
"""
clt: CrossLayerTranscoder module (JumpReLU activations, uniform init)
dataloader: yields (resid_list, mlp_list) already normalized
"""
optimizer = torch.optim.Adam(clt.parameters(), lr=1e-4)
if total_steps is None:
total_steps = num_epochs * len(dataloader)
global_step = 0
for epoch in range(num_epochs):
print(f"Epoch {epoch} of {num_epochs}")
for batch_idx, (resid_batch_list, mlp_batch_list) in enumerate(dataloader):
print(f"Batch {batch_idx} of {len(dataloader)}")
batch_size = resid_batch_list[0].shape[0]
# to GPU
# remember to ensure these are normalized per token
for l in range(len(resid_batch_list)):
    # each element arrives from the DataLoader as [batch, d_model];
    # add a singleton seq dim so it matches the CLT's expected [batch, seq, d_model]
    resid_batch_list[l] = resid_batch_list[l].to(clt.device).unsqueeze(1)
    mlp_batch_list[l] = mlp_batch_list[l].to(clt.device).unsqueeze(1)
# forward pass => recon_list
recon_list, preactivations = clt.forward_with_preacts(resid_batch_list)
# forward_with_preacts() returns the reconstructions and the pre-activations a^ell
# pre-activation penalty: discourages strongly negative pre-activations
L_preact = preact_loss_coef * sum_of_relu_neg(preactivations)
# MSE
mse = torch.tensor(0.0, device=clt.device)
for i in range(len(recon_list)):
diff = recon_list[i] - mlp_batch_list[i]
mse += diff.pow(2).mean()
# advanced sparsity
current_lambda_spars = get_sparsity_scale(
global_step, total_steps, lambda_spars_final
)
L_spars = advanced_sparsity_loss(
clt,
preactivations, # the list of shape [batch, seq, n_features] for each layer
c=1.0,
lambda_spars=current_lambda_spars
)
# total loss = reconstruction MSE + sparsity penalty + pre-activation penalty
total_loss = mse + L_spars + L_preact
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
global_step += 1
print(f"Epoch {epoch} done. last total_loss={total_loss.item():.4f} MSE={mse.item():.4f}")
dataset = CLTHiddenStateDataset("clt_data")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
model_name = "gpt2-small"
hmodel = HookedTransformer.from_pretrained(model_name)
layer_feature_list = [128]*hmodel.cfg.n_layers # or something custom
clt = CrossLayerTranscoder.from_hookedtransformer(
hmodel, layer_feature_list, bandwidth=1.0, device="cuda"
)
train_clt(clt, dataloader, 2)
replacement modeling

basically, a replacement model swaps the MLPs in the underlying model out for the CLT, while adding an error-correction term to account for whatever the CLT fails to reconstruct (see the expression after the list).
anthropic frames this as re-writing the underlying model in terms of sparser, more interpretable units.
they describe the local replacement model as a fully connected neural network spanning across tokens, which is what lets us do interpretability research on it:
- effectively, it's a graph.
- it's the union of CLT features active at every token position.
- its weights are the summed interactions over all the linear paths from one feature to another.
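the error term for layer ℓ is just the gap between the true MLP output and the CLT reconstruction on this specific prompt, computed once and then frozen:

$$\mathrm{error}^{\ell}(x) \;=\; \mathrm{MLP}^{\ell}(x) - \hat{y}^{\ell}_{\mathrm{CLT}}(x), \qquad \text{replacement MLP output} \;=\; \hat{y}^{\ell}_{\mathrm{CLT}}(x) + \mathrm{error}^{\ell}(x)$$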
def get_layer_index_from_hook(hook_name: str):
"""
A helper that extracts the integer layer index from a name like
'blocks.3.hook_mlp_out' or 'blocks.3.hook_resid_mid', etc.
Returns None if not recognized.
"""
# e.g. "blocks.3.hook_mlp_out" -> 3
# naive approach:
if "blocks." not in hook_name:
return None
try:
after_blocks = hook_name.split("blocks.")[1]
i_str = after_blocks.split(".")[0] # e.g. "3"
return int(i_str)
except (ValueError, IndexError):
return None
class LocalReplacementModel(nn.Module):
"""
A local replacement model that:
1) Uses `base_transformer` for the general structure,
2) Freezes LN outputs & attention patterns with pre-recorded data,
3) Replaces each MLP with CLT + error corrections,
4) Records the cross-layer-transcoder's feature activations in self.clt_activations.
"""
def __init__(
self,
base_transformer: HookedTransformer,
clt: "CrossLayerTranscoder",
error_corrections: Dict[str, torch.Tensor],
ln_scales: Dict[str, torch.Tensor],
attn_patterns: Dict[str, torch.Tensor],
device="cuda"
):
super().__init__()
self.base = base_transformer
self.clt = clt
self.error_corrections = error_corrections # e.g. { "blocks.3.hook_mlp_out": tensor(...) }
self.ln_scales = ln_scales # e.g. { "blocks.3.ln1.hook_normalized": tensor(...) }
self.attn_patterns = attn_patterns # e.g. { "blocks.3.attn.hook_pattern": tensor(...) }
self.device = device
# We'll store the MLP inputs for each layer (the "hook_mlp_in" values)
self.mlp_inputs = [None]*clt.n_layers
# We'll also store the cross-layer feature activations for each layer
# (shape [batch, seq, n_features_of_layer]) after we run the MLP-out hook
self.clt_activations = [None]*clt.n_layers
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
"""
1. Freeze LN by injecting recorded LN outputs.
2. Freeze QK by injecting stored attention patterns.
3. Replace MLP with CLT + error corrections.
4. Store CLT feature activations in self.clt_activations.
"""
# ===== Define hooking functions =====
def freeze_ln_hook(activation, hook):
"""
Hook on e.g. "blocks.0.ln1.hook_normalized":
We replace the LN output with our pre-saved LN output (from an earlier pass).
"""
name = hook.name
if name in self.ln_scales:
return self.ln_scales[name].to(self.device)
return activation
def freeze_attn_pattern_hook(activation, hook):
"""
Hook on e.g. "blocks.0.attn.hook_pattern":
We replace the QK-softmax pattern with pre-saved patterns.
"""
name = hook.name
if name in self.attn_patterns:
return self.attn_patterns[name].to(self.device)
return activation
def mlp_in_hook(activation, hook):
"""
E.g. "blocks.3.hook_mlp_in" => store for layer 3
"""
layer_idx = get_layer_index_from_hook(hook.name)
if layer_idx is not None:
self.mlp_inputs[layer_idx] = activation.detach().clone()
return activation
def mlp_out_hook(activation, hook):
"""
E.g. "blocks.3.hook_mlp_out" => we do:
1) feed self.mlp_inputs[L] into the CLT for layer L,
2) add error correction,
3) also fill self.clt_activations.
"""
layer_idx = get_layer_index_from_hook(hook.name)
if layer_idx is None:
return activation
x = self.mlp_inputs[layer_idx]
if x is None:
return activation # fallback
# We'll feed "dummy_resids" so that only layer_idx's input is nonzero
dummy_resids = []
for i in range(self.clt.n_layers):
if i == layer_idx:
dummy_resids.append(x)
else:
dummy_resids.append(torch.zeros_like(x))
# forward_with_preacts returns the reconstructions plus the pre-activations a_pre
mlp_recon_list, all_a_pre = self.clt.forward_with_preacts(dummy_resids)
# mlp_recon_list => list of length n_layers, each shape [batch, seq, d_model]
# all_a_pre     => list of length n_layers, each shape [batch, seq, n_feat(i)]
# The reconstructed MLP-out for this layer
recon_layer = mlp_recon_list[layer_idx]
# Only layer_idx's input was nonzero in dummy_resids, so only store this layer's
# feature activations (apply its JumpReLU to turn pre-activations into activations);
# overwriting every layer here would clobber earlier layers with zeros.
self.clt_activations[layer_idx] = self.clt.jumps[layer_idx](all_a_pre[layer_idx]).detach().clone()
# Then add the "error correction," if any
err_key = hook.name # e.g. "blocks.3.hook_mlp_out"
if err_key in self.error_corrections:
e = self.error_corrections[err_key].to(self.device)
recon_layer = recon_layer + e
return recon_layer
# ===== Attach the hooks =====
hooks = []
# Freeze LN
for i in range(self.base.cfg.n_layers):
ln_name = f"blocks.{i}.ln1.hook_normalized"
if ln_name in self.base.hook_dict:
h_ln = self.base.add_hook(ln_name, freeze_ln_hook, "fwd")
hooks.append(h_ln)
# Freeze attention pattern
for i in range(self.base.cfg.n_layers):
attn_name = f"blocks.{i}.attn.hook_pattern"
if attn_name in self.base.hook_dict:
h_attn = self.base.add_hook(attn_name, freeze_attn_pattern_hook, "fwd")
hooks.append(h_attn)
# Intercept MLP in/out
for i in range(self.clt.n_layers):
in_name = f"blocks.{i}.hook_mlp_in"
out_name = f"blocks.{i}.hook_mlp_out"
if in_name in self.base.hook_dict:
hi = self.base.add_hook(in_name, mlp_in_hook, "fwd")
hooks.append(hi)
if out_name in self.base.hook_dict:
ho = self.base.add_hook(out_name, mlp_out_hook, "fwd")
hooks.append(ho)
# ===== Run forward pass =====
logits = self.base(tokens)
# add_hook doesn't return handles in transformer_lens, so clear everything we registered
self.base.reset_hooks()
return logits
def build_local_replacement_model_with_cache(
base_model: HookedTransformer,
clt: CrossLayerTranscoder,
prompt: str,
device="cuda"
):
"""
1) forward pass on base_model, store LN outputs, attention patterns, MLP in/out
2) compute error corrections
3) return a LocalReplacementModel that re-uses LN & attn
and MLP is replaced with CLT+error
"""
layer_count = base_model.cfg.n_layers
tokens = base_model.to_tokens(prompt, prepend_bos=True).to(device)
# -- Use a filter that picks up the hook names we want. --
# We want LN outputs, attn pattern, mlp_in, mlp_out
# For LN: "blocks.{i}.ln1.hook_normalized"
# For attn: "blocks.{i}.attn.hook_pattern"
# For MLP in/out: "blocks.{i}.hook_mlp_in" / "blocks.{i}.hook_mlp_out"
def activation_filter(name: str):
# Return True if we want to store this hook in the cache
# Return False otherwise
if ".ln1.hook_normalized" in name:
return True
if ".attn.hook_pattern" in name:
return True
if ".hook_mlp_in" in name:
return True
if ".hook_mlp_out" in name:
return True
return False
# run_with_cache returns (logits, cache).
# cache is an ActivationCache containing the stored activations
logits, cache = base_model.run_with_cache(
tokens,
return_type="logits", # or "none" if you don't need final logits
names_filter=activation_filter
)
ln_scales = {}
attn_patterns = {}
mlp_in_cache = {}
mlp_out_cache = {}
# read from cache for each layer's LN out, attn pattern, MLP in/out
for i in range(layer_count):
ln_key = f"blocks.{i}.ln1.hook_normalized"
attn_key = f"blocks.{i}.attn.hook_pattern"
in_key = f"blocks.{i}.hook_mlp_in"
out_key = f"blocks.{i}.hook_mlp_out"
if ln_key in cache:
ln_scales[ln_key] = cache[ln_key].detach().clone()
if attn_key in cache:
attn_patterns[attn_key] = cache[attn_key].detach().clone()
if in_key in cache:
mlp_in_cache[in_key] = cache[in_key].detach().clone()
if out_key in cache:
mlp_out_cache[out_key] = cache[out_key].detach().clone()
# 2) Build error corrections by comparing the model's MLP out to the CLT recon
error_corrections = {}
for i in range(clt.n_layers):
layer_in_name = f"blocks.{i}.hook_mlp_in"
layer_out_name = f"blocks.{i}.hook_mlp_out"
# Only do this if we actually have MLP in/out for that layer
if layer_in_name in mlp_in_cache and layer_out_name in mlp_out_cache:
mlp_in = mlp_in_cache[layer_in_name].to(device)
mlp_out_true = mlp_out_cache[layer_out_name].to(device)
# minimal forward in CLT for just this layer
dummy_inputs = [torch.zeros_like(mlp_in) for _ in range(clt.n_layers)]
dummy_inputs[i] = mlp_in
clt_outputs = clt(dummy_inputs)  # list of length n_layers, each [batch, seq, d_model]
clt_layer_out = clt_outputs[i] # [batch, seq, d_model]
diff = mlp_out_true - clt_layer_out
error_corrections[layer_out_name] = diff.detach().clone()
# 3) Build the final "frozen" local model
local_model = LocalReplacementModel(
base_transformer=base_model,
clt=clt,
error_corrections=error_corrections,
ln_scales=ln_scales,
attn_patterns=attn_patterns,
device=device
)
return local_model
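with all of that in place, building and sanity-checking a local replacement model looks roughly like this (using the clt and hmodel from earlier; the prompt is arbitrary):
device = "cuda" if torch.cuda.is_available() else "cpu"
prompt = "The capital of France is"
local_model = build_local_replacement_model_with_cache(hmodel, clt, prompt, device=device)
tokens = hmodel.to_tokens(prompt, prepend_bos=True).to(device)
logits = local_model(tokens)  # forward pass with MLPs swapped for CLT + error terms
# because the error terms are computed on this exact prompt, the replacement model's
# next-token prediction should match the base model's (up to numerical noise)
print(hmodel.tokenizer.decode(logits[0, -1].argmax().item()))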
Attribution Graphs

these are representations of how different computational components (specifically the interpretable features) contribute to the final output of the model.
graphs contain 4 node types:
- output: the final output tokens of the model
- input: the input embedding tokens to the model
- intermediate: clt features at each prompt token
- error: the remaining output unexplained by the CLT
edges represent linear attributions: each edge indicates a direct, linear influence of a source node on a target node. the idea is that the (pre-)activation of any feature or logit is decomposed into the sum of its incoming contributions from these edges, allowing a clear, linear causal interpretation.
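concretely, for any target node t (a feature or a logit), its pre-activation decomposes as a sum over its incoming edges:

$$\mathrm{preact}(t) \;=\; \sum_{s\,\in\,\{\text{features, embeddings, errors}\}} A_{s \to t}$$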
the graph is constructed by backward attribution via jacobians.
so we want to decompose the output into the sum of contributions from earlier components. because the influence is spread over many different paths, using backward jacobians allows one to quantify the sensitivity of the target’s pre-activation to changes in the source’s activation.
in symbols, the direct edge weight from a source feature s (with activation $a_s$, living at layer $\ell_s$) to a target t whose encoder reads at layer $\ell_t$ is, evaluated at the source's token position:

$$A_{s \to t} \;=\; a_s \sum_{\ell=\ell_s}^{\ell_t - 1} \big(W_{dec}^{(\ell_s \to \ell)}[s]\big)^{\top}\, \nabla_{\mathrm{resid}^{\ell}}\, \mathrm{preact}(t)$$

(for embedding and error nodes, drop the $a_s$ factor and dot their output vector with the gradient directly.)
for every target t in our graph, we inject its corresponding input vector into the residual stream at the appropriate layer/token position. we then perform a backward pass on the underlying model with two modifications: 1) stop-gradients are inserted at the non-linear parts of the model, and 2) frozen attention patterns are used, so that the only paths from source to target are the linear ones in the graph. for any source node, its contribution is computed as a dot product between its decoder vector and the gradient signal that flows from the target back to the source, and the result is scaled by the activation of the source node.
by computing these backward jacobians, we effectively obtain a set of linear influence weights that tell us how much each feature contributes to the target. by doing this, you can effectively map out a circuit (in the mech interp sense).
anthropic also gives a graph pruning algorithm, described with the following pseudocode:
function compute_normalized_adjacency_matrix(graph):
# Convert graph to adjacency matrix A
# A[j, i] = weight from i to j (note the transposition)
A = convert_graph_to_adjacency_matrix(graph)
A = absolute_value(A)
# Normalize each row to sum to 1
row_sums = sum(A, axis=1)
row_sums = maximum(row_sums, 1e-8) # Avoid division by zero
A = diagonal_matrix(1/row_sums) @ A
return A
function prune_nodes_by_indirect_influence(graph, threshold):
A = compute_normalized_adjacency_matrix(graph)
# Calculate the indirect influence matrix: B = (I - A)^-1 - I
# This is a more efficient way to compute A + A^2 + A^3 …
B = inverse(identity_matrix(size=A.shape[0]) - A) - identity_matrix(size=A.shape[0])
# Get weights for logit nodes.
# This is 0 if a node is a non-logit node and equal to the probability for logit nodes
logit_weights = get_logit_weights(graph)
# Calculate influence on logit nodes for each node
influence_on_logits = matrix_multiply(B, logit_weights)
# Sort nodes by influence
sorted_node_indices = argsort(influence_on_logits, descending=True)
# Calculate cumulative influence
cumulative_influence = cumulative_sum(
influence_on_logits[sorted_node_indices]) / sum(influence_on_logits)
# Keep nodes with cumulative influence up to threshold
nodes_to_keep = cumulative_influence <= threshold
# Create new graph with only kept nodes and their edges
return create_subgraph(graph, nodes_to_keep)
# Edge pruning by thresholded influence
function prune_edges_by_thresholded_influence(graph, threshold):
# Get normalized adjacency matrix
A = compute_normalized_adjacency_matrix(graph)
# Calculate influence matrix (as before)
B = estimate_indirect_influence(A)
# Get logit node weights (as before)
logit_weights = get_logit_weights(graph)
# Calculate node scores (influence on logits)
node_score = matrix_multiply(B, logit_weights)
# Edge score is weighted by the logit influence of the target node
edge_score = A * node_score[:, None]
# Calculate edges to keep based on thresholded cumulative score
sorted_edges = sort(edge_score.flatten(), descending=True)
cumulative_score = cumulative_sum(sorted_edges) / sum(sorted_edges)
threshold_index = index_where(cumulative_score >= threshold)
edge_mask = edge_score >= sorted_edges[threshold_index]
# Create new graph with pruned adjacency matrix
pruned_adjacency = A * edge_mask
return create_subgraph_from_adjacency(graph, pruned_adjacency)
my implementation of the attribution graph code is as follows:
class NodeType:
EMBEDDING = "embedding"
FEATURE = "feature"
ERROR = "error"
LOGIT = "logit"
class AttributionNode:
"""
A node in the attribution graph, e.g. a single feature in some layer,
or an embedding node, or a logit node, etc.
"""
def __init__(
self,
node_type: str,
name: str,
layer_idx: int = None,
context_pos: int = None,
logit_index: int = None
):
self.node_type = node_type
self.name = name
# Set layer_idx to a sentinel if not provided
if layer_idx is None:
if node_type == NodeType.EMBEDDING:
self.layer_idx = -1
elif node_type == NodeType.LOGIT:
self.layer_idx = 9999
else:
self.layer_idx = 0
else:
self.layer_idx = layer_idx
self.context_pos = context_pos
self.logit_index = logit_index
# Storing these as placeholders
self.activation = None
self.output_vector = None
self.input_vector = None
# Must be an integer for feature nodes:
self.feature_index = None # we must fill this in ourselves
# We'll store an integer ID in the graph
self.id = None
def __repr__(self):
return f"Node<{self.node_type}:{self.name}:{self.id}>"
class AttributionGraph:
"""
A container for all nodes and edges in the attribution graph.
"""
def __init__(self):
self.nodes = []
self.edges = {} # adjacency: edges[u] = list of (v, weight)
def add_node(self, node: AttributionNode):
node.id = len(self.nodes)
self.nodes.append(node)
self.edges[node.id] = []
return node.id
def add_edge(self, src_id: int, tgt_id: int, weight: float):
self.edges[src_id].append((tgt_id, weight))
def get_num_nodes(self):
return len(self.nodes)
def build_graph_nodes(
local_model, # e.g. a LocalReplacementModel
prompt: str,
top_k=3,
feature_threshold=0.01
):
"""
1) Forward pass => local_model.clt_activations is populated
2) Build an AttributionGraph
3) Add logit nodes for top_k final predictions
4) Add feature nodes for each cross-layer feature whose activation > threshold
"""
G = AttributionGraph()
tokens = local_model.base.to_tokens(prompt, prepend_bos=True).to(local_model.device)
with torch.no_grad():
logits = local_model(tokens) # fills local_model.clt_activations
# Add top-k logit nodes
final_logits = logits[0, -1, :]
probs = F.softmax(final_logits, dim=-1)
top_vals, top_inds = torch.topk(final_logits, top_k)
logit_node_ids = []
for rank in range(top_k):
tok_id = top_inds[rank].item()
p = probs[tok_id].item()
node = AttributionNode(
node_type=NodeType.LOGIT,
name=f"logit_{tok_id} (p={p:.3f})",
layer_idx=None, # sets layer_idx=9999
context_pos=tokens.shape[1]-1, # last position
logit_index=tok_id
)
nid = G.add_node(node)
logit_node_ids.append(nid)
# Create feature nodes from local_model.clt_activations
feature_nodes = {}
for i in range(local_model.clt.n_layers):
acts_i = local_model.clt_activations[i] # shape [1, seq, n_feat_i]
if acts_i is None:
continue
acts_i = acts_i[0] # shape [seq, n_feat_i] if batch=1
seq_len, n_feat_i = acts_i.shape
for pos in range(seq_len):
for feat_j in range(n_feat_i):
val = acts_i[pos, feat_j].item()
if val > feature_threshold:
# Build a feature node => layer i, position pos, feature index feat_j
node = AttributionNode(
node_type=NodeType.FEATURE,
name=f"feat_L{i}_f{feat_j}_pos{pos}",
layer_idx=i,
context_pos=pos,
logit_index=None
)
node.activation = val
# Key fix: assign the integer feature index
node.feature_index = feat_j # <--- CRUCIAL
node_id = G.add_node(node)
feature_nodes[(i, pos, feat_j)] = node_id
cache_dict = {
"logit_node_ids": logit_node_ids,
"feature_nodes": feature_nodes
}
return G, cache_dict
def compute_logit_injection_vector(local_model, vocab_idx):
"""
With logits = final_resid @ W_U, the gradient of
(logit[vocab_idx] - mean(logit)) w.r.t. final_resid is:
    injection_vec = W_U[:, vocab_idx] - W_U.mean(dim=1)
"""
W_U = local_model.base.W_U # e.g. unembedding weights shape [d_model, vocab_size]
d_model, vocab_size = W_U.shape
# minimal example:
w_target = W_U[:, vocab_idx] # shape [d_model]
w_mean = W_U.mean(dim=1) # shape [d_model]
injection_vec = (w_target - w_mean)
return injection_vec
def run_backward_with_injection(
local_model: LocalReplacementModel,
injection_vec: torch.Tensor,
layer_idx: int,
real_tokens: torch.Tensor,
token_pos: int = None,
freeze_ln=True,
freeze_attn=True,
):
"""
Runs a backward pass in 'local_model' (a LocalReplacementModel) with an
'injection_vec' placed into the chosen layer (layer_idx) residual stream.
We now pass 'real_tokens' so that the dimension of the residual
matches the dimension we used in the normal forward pass.
Args:
local_model: The local replacement model (with LN & attn patterns frozen).
injection_vec: a Tensor of shape [d_model] or [batch, seq, d_model],
specifying the "gradient signal" we want to inject.
layer_idx: which layer's residual we inject into.
real_tokens: the actual token IDs used in the normal forward pass
(must have the same shape as originally used by local_model).
token_pos: which token position to inject into. If None, we broadcast across all positions.
freeze_ln: whether to freeze LN denominators in the backward pass.
freeze_attn: whether to freeze QK patterns in the backward pass.
Returns:
grad_dict: a dictionary of shape {(layer_idx, pos): residual_grad_vector}
containing the gradient w.r.t. each layer's residual stream after injection.
"""
# (A) clamp layer_idx if needed
nL = local_model.base.cfg.n_layers
if layer_idx >= nL:
# Perhaps for logit nodes, set it to nL - 1
layer_idx = nL - 1
# 1) Zero out old gradients
local_model.zero_grad(set_to_none=True)
# 2) Optionally freeze LN or QK params so they do not accumulate gradient
# This ensures that no gradient accumulates in LN or QK parameters.
# Just be mindful that some models might have LN named differently (ln_f or ln_final).
if freeze_ln or freeze_attn:
for name, param in local_model.base.named_parameters():
if freeze_ln and ("ln" in name):
param.requires_grad_(False)
if freeze_attn and (".W_Q" in name or ".W_K" in name):
param.requires_grad_(False)
# 3) Hook the chosen layer's residual via a forward pass with the *real tokens*
storage = {}
def store_resid_hook(resid, hook):
    # resid shape is [batch, seq, d_model]; retain_grad so .grad is populated
    # on this non-leaf tensor after the backward pass
    resid.retain_grad()
    storage["resid"] = resid
    return resid
# By default for GPT-2 style, the "pre-MLP residual" is often "blocks.{layer_idx}.hook_resid_mid"
# or "blocks.{layer_idx}.hook_resid_pre". For GPT-2 from TransformerLens, it's typically:
# "blocks.{layer_idx}.hook_resid_mid"
# but the user's code might vary. We'll try "hook_resid_mid" here:
resid_name = f"blocks.{layer_idx}.hook_resid_mid"
if resid_name not in local_model.base.hook_dict:
# fallback or raise error
# e.g. check if "hook_resid_post" is present
alt_name = f"blocks.{layer_idx}.hook_resid_post"
if alt_name in local_model.base.hook_dict:
resid_name = alt_name
else:
raise KeyError(f"No valid mid/post resid hook for layer {layer_idx}")
# register the forward hook
handle = local_model.base.add_hook(resid_name, store_resid_hook, "fwd")
with torch.enable_grad():
# Perform the forward pass with the REAL tokens so shapes match
_ = local_model(real_tokens)
if handle is not None:
handle.remove()
if "resid" not in storage:
raise ValueError(
f"Could not find residual for layer {layer_idx} - check your "
f"hook name (is it 'hook_resid_mid' or 'hook_resid_pre'?)"
)
resid_L = storage["resid"] # shape [batch, seq, d_model]
# We want gradient to flow from the dummy loss -> resid_L
# 4) Construct the dummy loss
# If injection_vec is [d_model], we broadcast across [batch, seq, d_model].
shape = resid_L.shape
if injection_vec.dim() == 1 and injection_vec.shape[0] == shape[-1]:
expanded_injection = injection_vec.view(1, 1, -1).expand_as(resid_L)
elif injection_vec.shape == shape:
expanded_injection = injection_vec
else:
raise ValueError(f"injection_vec shape mismatch: got {injection_vec.shape}, "
f"but resid_L is {shape}.")
if token_pos is not None:
# zero out gradient for all positions except 'token_pos'
mask = torch.zeros_like(resid_L)
mask[:, token_pos, :] = 1.0
expanded_injection = expanded_injection * mask
dummy_loss = (resid_L * expanded_injection).sum()
dummy_loss.backward(retain_graph=True)
# 5) gather the gradient into a dictionary keyed by (layer_idx, token_pos),
# which is the format the edge-attribution code below looks up.
# Keep tensors on the model's device so downstream dot products work.
grad_dict = {}
if resid_L.grad is not None:
    g = resid_L.grad.detach()
    for pos in range(g.shape[1]):
        grad_dict[(layer_idx, pos)] = g[0, pos]  # shape [d_model], assuming batch=1
return grad_dict
def is_upstream_layer(src_node, tgt_node):
"""
Example helper: returns True if src_node is strictly earlier in
(layer, position) than tgt_node. You can define your own logic for
"which nodes can feed into which" in your graph.
"""
# e.g. if src_node.layer_idx < tgt_node.layer_idx, or if equal layer but earlier tokenpos
# etc. We'll just do a dummy check:
return (src_node.layer_idx < tgt_node.layer_idx) or (
src_node.layer_idx == tgt_node.layer_idx
and src_node.context_pos <= tgt_node.context_pos
)
def get_token_embedding(token_id, local_model):
"""
Suppose your model has a wte (word token embed) of shape [vocab_size, d_model].
Then the embedding is wte[token_id].
"""
return local_model.base.W_E[token_id]  # W_E has shape [d_vocab, d_model]
def compute_direct_edges_for_node(
G,
node_id: int,
local_model: nn.Module,
tokens: torch.Tensor,
freeze_ln=True,
freeze_attn=True,
epsilon=1e-8,
):
"""
For a given node in the attribution graph, we compute the direct edges from
*all* upstream nodes in G to `node_id`. We do a custom backward pass in the
local replacement model, with LN denominators and QK patterns frozen (stop-grad).
Pseudocode steps:
1) Build the injection vector for the `node_id`—this depends on whether it's a
Feature node, Logit node, etc.
2) Insert that injection vector into the residual stream at the correct layer
(or final-layer residual for a logit).
3) Run a custom backward pass that accumulates `.grad` in the model's residual
streams (and zero in LN denominators, etc.).
4) For each source node in G, compute the direct edge weight using either:
w = source_activation * sum_{k} [ W_dec^{(source)}^T * grad_for_layer_k * W_enc^{(target)} ]
or for embeddings / error nodes, a direct dot-product with the residual grad.
5) Add edge to G if abs(weight) > threshold.
"""
node = G.nodes[node_id]
node_type = node.node_type
# ---------------------------------------------------------------------
# 1) Build the injection vector for the "target node"
# ---------------------------------------------------------------------
if node_type == NodeType.LOGIT:
# For a logit node, commonly we do "gradient w.r.t. (logit_tok - mean_logit)".
target_logit_idx = node.logit_index # e.g. ID in vocab
# Suppose we have the final-layer residual dimension = d_model
# We make an injection vector of shape [d_model], or [1, d_model]
injection_vec = compute_logit_injection_vector(local_model, target_logit_idx)
elif node_type == NodeType.FEATURE:
# For a feature node, we typically want to inject that feature's input vector
# into the residual stream at the layer it reads from. That is:
# v_in^L = W_enc[featureID] (the encoder weights)
# For cross-layer transcoders, each feature has an encoder layer. We assume
#     node stores node.layer_idx, node.feature_index, etc.
layer_idx = node.layer_idx
feat_idx = node.feature_index  # index in that layer's set of features
injection_vec = local_model.clt.W_enc[layer_idx][feat_idx].detach().clone()
elif node_type == NodeType.ERROR:
# For error nodes, we might define injection_vec as the error node's
# "output vector" in the residual stream.
# e.g. v_out = MLP_out_true - CLT_out. We'll just pretend we have it stored:
injection_vec = node.output_vector.detach().clone()
elif node_type == NodeType.EMBEDDING:
# For an embedding node, we can do something like:
token_embed = get_token_embedding(node.token_id, local_model)
injection_vec = token_embed.detach().clone()
else:
raise ValueError(f"Unknown node_type = {node_type}")
# We'll reshape to [d_model] or [1,d_model] if needed
injection_vec = injection_vec.detach().view(-1)
injection_vec.requires_grad_(False)
# ---------------------------------------------------------------------
# 2) Insert the injection vector into the residual stream
# We'll do a custom backward pass that sets MLP out, LN denominators, QK patterns
# to no_grad or zero_grad, etc.
# ---------------------------------------------------------------------
# For example, if the node is a LOGIT node, we put injection_vec at the final-layer residual
# If it's a FEATURE in layer L, we put the injection_vec in the residual of layer L
# We'll store them in a 'tensor' that the model sees as the backward pass "incoming gradient".
# We'll define a helper function. In a real code, you'd define a specialized
# "run_backward_with_injection" that does your custom hooking logic:
grad_dict = run_backward_with_injection(
local_model=local_model,
injection_vec=injection_vec,
layer_idx=node.layer_idx if node_type == NodeType.FEATURE else local_model.base.cfg.n_layers,
# ^ if logit or error or embedding, we might set this to final layer
real_tokens=tokens,
freeze_ln=freeze_ln,
freeze_attn=freeze_attn,
)
# grad_dict: Suppose it returns a dict mapping {(layer_idx, token_pos): residual_grad_tensor}
# or something similar for each layer. You may also want the raw MLP-input grads in
# your local replacement model or the partial derivatives w.r.t. CLT decoders, etc.
# ---------------------------------------------------------------------
# 3) For each source node, compute direct edge weight
# ---------------------------------------------------------------------
# We'll show how to do it for a "feature" source node (the typical case).
# We'll also handle "embedding" and "error" node. (They differ because
# features have to multiply by their activation in this prompt.)
# We'll do a loop over all nodes in G. In practice, you might only want to
# handle nodes at earlier layers, or earlier token positions, etc.
for src_id, src_node in enumerate(G.nodes):
if src_id == node_id:
continue
if not is_upstream_layer(src_node, node): # you might have logic to skip obviously irrelevant nodes
continue
w = 0.0
if src_node.node_type == NodeType.FEATURE:
# The formula from the text is roughly:
#
# A_{s -> t} = a_s * sum_{ℓ in [src_layer..(tgt_layer-1)]} [
# W_dec(src_feat)ᵀ * grad[ℓ] * W_enc(tgt_feat)
# ]
#
# But we already have "grad[ℓ] * W_enc(tgt_feat)" from the injection pass,
# plus we can do "W_dec(src_feat)ᵀ . that" ...
# We'll do it more explicitly in code:
src_activation = src_node.activation # a_s from the prompt
s_layer = src_node.layer_idx
s_feat_idx = src_node.feature_index
# Grab decoders from local_model.clt for that feature
# For each layer in [s_layer..(node.layer_idx-1)], we can do something:
# out_vec = local_model.clt.W_dec[ index_map[s_layer][ℓ] ][ s_feat_idx ]
# partial = torch.dot( out_vec, grad_dict[ℓ, node.pos] ) # shape both [d_model]
# Summation of partial across the relevant layers.
sum_val = 0.0
# We'll define target_layer = node.layer_idx if node is feature
# or local_model.base.cfg.n_layers if node is logit
end_layer = (node.layer_idx if node_type == NodeType.FEATURE
else local_model.base.cfg.n_layers)
for mid_layer in range(s_layer, end_layer):
dec_idx = local_model.clt.index_map[s_layer][mid_layer]
out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # shape [d_model]
# Suppose grad_dict[(mid_layer, node.context_pos)] is the gradient for that token/resid
# We'll do a dot product:
if (mid_layer, src_node.context_pos) not in grad_dict:
continue # might not exist
grad_vec = grad_dict[(mid_layer, src_node.context_pos)]
partial = torch.dot(out_vec, grad_vec)
sum_val += partial.item()
# Now multiply by a_s:
w = src_activation * sum_val
elif src_node.node_type == NodeType.EMBEDDING:
# Then the direct edge is basically
# w = (Emb_src)^T grad[ src_layer, context_pos ]
# Embeddings read into layer 0 (or wherever your model does).
# Or you might store them differently.
# If your local model lumps them into the "resid at layer 0," you can do:
emb_vec = src_node.embedding_vector
layer_for_emb = 0 # typically we treat the embedding as "layer 0"
pos = src_node.context_pos
if (layer_for_emb, pos) in grad_dict:
grad_vec = grad_dict[(layer_for_emb, pos)]
w = torch.dot(emb_vec, grad_vec).item()
else:
w = 0.0
elif src_node.node_type == NodeType.ERROR:
# Error nodes have no upstream inputs. By definition in the text,
# they "pop out of nowhere." So typically we do NOT add edges in from anything.
# Conversely, if you're computing edges from error to *this node*, you do:
# w = error_vec^T gradVec
# for the relevant layer. Example:
err_out_vec = src_node.output_vector # shape [d_model]
err_layer = src_node.layer_idx
pos = src_node.context_pos
if (err_layer, pos) in grad_dict:
grad_vec = grad_dict[(err_layer, pos)]
w = torch.dot(err_out_vec, grad_vec).item()
else:
w = 0.0
# If abs(w) is big enough, we add an edge
if abs(w) > 1e-2: # TODO: change this threshold to 1e-4
G.add_edge(src_id, node_id, w)
# done. G is updated with edges from source_node -> node_id.
# You'd repeat this function for each node in the graph for which
# you want to discover incoming edges.
return G # done
def build_attribution_graph_for_prompt(local_model, prompt):
    # 1) build the nodes (this runs the forward pass and fills clt_activations)
    G, cache = build_graph_nodes(local_model, prompt)
    tokens = local_model.base.to_tokens(prompt, prepend_bos=True).to(local_model.device)
    # 2) for each "target node" in logit_node_ids, do a partial backward pass
    for lnid in cache["logit_node_ids"]:
        compute_direct_edges_for_node(G, lnid, local_model, tokens)
    # 3) for each feature node, compute its incoming edges the same way
    for (layer, pos, feat_i), node_id in cache["feature_nodes"].items():
        compute_direct_edges_for_node(G, node_id, local_model, tokens)
    # 4) done building adjacency
    return G
def build_adjacency_matrix(G: AttributionGraph):
"""
Returns adjacency matrix A of shape [N, N], where A[j,i] = sum of edges from i->j
This is typically used for the pruning step.
"""
N = G.get_num_nodes()
A = np.zeros((N,N), dtype=np.float32)
for i in range(N):
for (j, w) in G.edges[i]:
A[j,i] += w
return A
def prune_graph(
G: AttributionGraph,
logit_nodes: List[int],
threshold_nodes=0.8,
threshold_edges=0.98
):
"""
1) Build adjacency matrix
2) compute B = (I - A)^-1 - I
3) For each node, sum up influence on logit nodes
4) keep top X% => subgraph
5) then re-build adjacency, do edge-level pruning similarly.
"""
# 1) adjacency
A = build_adjacency_matrix(G)
N = A.shape[0]
# 2) compute B
# we might want to clamp negative edges or do absolute value. The paper does "abs + row-normalize"
# For simplicity, do exactly as the paper says: we'll do a quick absolute + row normalization:
A_abs = np.abs(A)
row_sums = A_abs.sum(axis=1, keepdims=True)
row_sums[row_sums < 1e-10] = 1.0
A_norm = A_abs / row_sums
# Then B = (I - A_norm)^-1 - I
I = np.eye(N, dtype=np.float32)
M = (I - A_norm)
M_inv = np.linalg.inv(M)
B = M_inv - I # shape [N, N]
# 3) logit influence
# We have logit_nodes; we can weight them by their probability, or just sum them
# Suppose we do a simple sum
logit_mask = np.zeros((N,), dtype=np.float32)
for lnid in logit_nodes:
logit_mask[lnid] = 1.0
# node_influence = B * logit_mask
# shape = [N,N], [N], => we do a matmul
node_influence = B @ logit_mask # shape [N]
# 4) sort nodes by influence
order = np.argsort(-node_influence) # descending
csum = np.cumsum(node_influence[order])
total = csum[-1] if csum[-1]>0 else 1.0
cutoff_val = threshold_nodes*total
keep_mask = np.zeros((N,), dtype=bool)
for idx in range(len(order)):
if csum[idx] <= cutoff_val:
keep_mask[order[idx]] = True
else:
break
# We now form a new subgraph with only those nodes. (Edges from or to "kept" nodes.)
# Then we do a second pass for edge pruning using the same logic but with B for edges.
# For brevity, we'll just do node-level here.
# In practice, you'd want to do the 2-step approach from the text.
new_G = AttributionGraph()
old_to_new = {}
for i in range(N):
if keep_mask[i]:
new_i = new_G.add_node(G.nodes[i])
old_to_new[i] = new_i
for i in range(N):
if keep_mask[i]:
for (j, w) in G.edges[i]:
if keep_mask[j]:
new_G.add_edge(old_to_new[i], old_to_new[j], w)
return new_G
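tying the pieces together, a rough end-to-end sketch (using the clt and hmodel from earlier; the prompt is arbitrary):
prompt = "The capital of France is"
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1) freeze LN / attention and swap MLPs for CLT + per-prompt error terms
local_model = build_local_replacement_model_with_cache(hmodel, clt, prompt, device=device)
# 2) build nodes (top-k logits + active CLT features) and attribute the edges
G = build_attribution_graph_for_prompt(local_model, prompt)
# 3) prune to the nodes that carry most of the influence on the logits
logit_ids = [n.id for n in G.nodes if n.node_type == NodeType.LOGIT]
pruned = prune_graph(G, logit_ids, threshold_nodes=0.8)
print(f"{G.get_num_nodes()} nodes before pruning, {pruned.get_num_nodes()} after")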
def compute_all_direct_edges_layer(
local_model,
tokens,
layer_idx,
target_nodes,
source_nodes,
threshold=1e-4,
freeze_ln=True,
freeze_attn=True,
):
"""
- Gathers the injection vectors for many 'target_nodes' that read from 'layer_idx'.
- Runs single_pass_grad_for_targets exactly ONCE, returning a 4D array of partial derivatives.
- For each source_node in `source_nodes`, does a dot product to get the direct edge
from source -> each target.
- Returns a list of (source_id, target_id, weight).
NOTE: This function is a simplified example. It only demonstrates:
- FEATURE nodes as targets, or LOGIT nodes as targets, for building injection vectors.
- EMBEDDING or FEATURE nodes as sources (we do a simple dot product).
- You can adapt it for your code (e.g. ERROR nodes, multi-layer partial sums, etc.).
"""
# 1) Build the injection matrix for the target_nodes
injection_list = []
# We'll keep track of which target_node corresponds to each row
# so we can output edges in the correct order
target_index_map = []
d_model = local_model.base.cfg.d_model
for i, tnode in enumerate(target_nodes):
if tnode.node_type == "feature":
# e.g. local_model.clt.W_enc[layer_idx][feat_idx]
feat_idx = tnode.feature_index
inj_vec = local_model.clt.W_enc[layer_idx][feat_idx].detach().clone()
elif tnode.node_type == "logit":
# logit injection
inj_vec = compute_logit_injection_vector(local_model, tnode.logit_index).detach().clone()
else:
# If you had embeddings or error nodes as targets, you'd define them here
raise ValueError(f"Unsupported target node type {tnode.node_type}")
injection_list.append(inj_vec)
target_index_map.append(i)
# shape = [n_targets, d_model]
injection_mat = torch.stack(injection_list, dim=0)
# 2) Single pass => gradient wrt this layer's residual
grad_wrt_resid = single_pass_grad_for_targets(
local_model=local_model,
tokens=tokens,
layer_idx=layer_idx,
target_injection_vectors=injection_mat,
freeze_ln=freeze_ln,
freeze_attn=freeze_attn,
)
# shape [batch, seq, d_model, n_targets]
# We'll assume batch=1 for simplicity
# (If batch>1, you might do a sum over all batch elements or handle them differently.)
grad_wrt_resid = grad_wrt_resid[0] # shape [seq, d_model, n_targets]
seq_len, d_model, n_targets = grad_wrt_resid.shape
# 3) Build edges from source_nodes to target_nodes
edges = [] # list of (src_id, tgt_id, weight)
for src_node in source_nodes:
src_id = src_node.id
if src_node.node_type == "embedding":
    # The embedding vector is W_E[token_id]; dot it with grad_wrt_resid
    # at the embedding node's context position.
    token_id = src_node.logit_index  # wherever the token id is stored on the node
    emb_vec = local_model.base.W_E[token_id]  # shape [d_model]
    # (if you stored an .embedding_vector on the node instead, use that)
    emb_vec = emb_vec.detach().clone()
pos = src_node.context_pos
if pos < 0 or pos >= seq_len:
# skip if out of range
continue
# shape [d_model, n_targets]
grad_slice = grad_wrt_resid[pos] # shape [d_model, n_targets]
# w = dot(emb_vec, grad_slice[:, i]) for each target i
# => shape [n_targets]
w_vals = torch.einsum("d, dn->n", emb_vec, grad_slice)
# w_vals is shape [n_targets]
# we can apply threshold, or store them all
for t_i, wval in enumerate(w_vals):
val = wval.item()
if abs(val) >= threshold:
tgt_id = target_nodes[t_i].id
edges.append((src_id, tgt_id, val))
elif src_node.node_type == "feature":
# Suppose we do a simpler, single-layer approach:
# direct edge ~ src_activation * dot(W_dec[src_feat], grad_slice)
# We'll find the correct position, assume it's src_node.context_pos
pos = src_node.context_pos
s_feat_idx = src_node.feature_index
# The src activation is typically stored in node.activation
src_activation = src_node.activation
# The "decoder" vector is local_model.clt.W_dec[index_map[src_layer][layer_idx]][s_feat_idx], etc.
# For simplicity, let's assume src_node.layer_idx == layer_idx
# in real code, you might do a loop from src_layer..layer_idx-1
if pos < 0 or pos >= seq_len:
continue
# shape [d_model, n_targets]
grad_slice = grad_wrt_resid[pos] # shape [d_model, n_targets]
# The decoder vector for that feature -> layer_idx
dec_idx = local_model.clt.index_map[src_node.layer_idx][layer_idx]
if dec_idx is None:
# skip if e.g. src_node.layer_idx>layer_idx
continue
out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # shape [d_model]
# Then do dot(out_vec, grad_slice) => shape [n_targets]
partial_vals = torch.einsum("d, dn->n", out_vec, grad_slice)
# Multiply by the activation
w_vals = src_activation * partial_vals # shape [n_targets]
for t_i, wval in enumerate(w_vals):
val = wval.item()
if abs(val) >= threshold:
tgt_id = target_nodes[t_i].id
edges.append((src_id, tgt_id, val))
else:
# If you had other node types (ERROR, etc.), do similarly
pass
return edges
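to make the wiring concrete, here's a rough, hypothetical driver for the helper above: bucket feature targets by the layer they read from, treat logit nodes as a final bucket, and let everything strictly upstream be a source. it reuses the graph conventions from this post; note that compute_all_direct_edges_layer matches string node types ("feature"/"logit"/"embedding") while the debug builder below uses the NodeType enum, so align those before trusting it.
def build_edges_with_layer_helper(G, local_model, tokens, threshold=1e-4):
    # illustrative driver, not part of the original pipeline
    n_layers = local_model.base.cfg.n_layers
    for layer_i in range(n_layers + 1):
        if layer_i < n_layers:
            targets = [n for n in G.nodes
                       if n.node_type == "feature" and n.layer_idx == layer_i]
        else:
            targets = [n for n in G.nodes if n.node_type == "logit"]
        if not targets:
            continue
        # anything strictly upstream can act as a source
        sources = [n for n in G.nodes
                   if n.node_type == "embedding"
                   or (n.node_type == "feature" and n.layer_idx < layer_i)]
        edges = compute_all_direct_edges_layer(
            local_model, tokens,
            layer_idx=min(layer_i, n_layers - 1),  # logit bucket reads the last residual
            target_nodes=targets, source_nodes=sources,
            threshold=threshold,
        )
        for src_id, tgt_id, w in edges:
            G.add_edge(src_id, tgt_id, w)
    return G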
def build_direct_edges_single_pass_debug(
G: AttributionGraph,
local_model: LocalReplacementModel,
tokens: torch.Tensor,
threshold=1e-4,
):
"""
Debug version: We add print statements to confirm the shape of each injection vector
for each target node. We also print the final stacked shape, to confirm it's
[n_targets, d_model].
"""
n_layers = local_model.base.cfg.n_layers
d_model = local_model.base.cfg.d_model
print(f"[DEBUG] n_layers={n_layers}, d_model={d_model}")
# Bucket target nodes by layer
layer_to_targets = {L: [] for L in range(n_layers+1)}
for node in G.nodes:
if node.node_type == NodeType.LOGIT:
layer_to_targets[n_layers].append(node)
elif node.node_type == NodeType.FEATURE:
layer_to_targets[node.layer_idx].append(node)
else:
pass # e.g. embedding/error => sources only
# One pass per layer
for layer_i in range(n_layers+1):
target_nodes = layer_to_targets[layer_i]
if len(target_nodes) == 0:
continue
print(f"\n[DEBUG] Layer={layer_i}, #targets={len(target_nodes)}")
# 2A) Build injection_mat => shape [n_targets, d_model]
injection_list = []
for idx, tnode in enumerate(target_nodes):
if tnode.node_type == NodeType.FEATURE:
feat_idx = tnode.feature_index
w_enc_matrix = local_model.clt.W_enc[layer_i] # shape [num_features, d_model], e.g. [128,768]
print(f" [DEBUG] Target node {idx}: FEATURE layer={layer_i}, feat_idx={feat_idx}")
print(f" w_enc_matrix.shape = {tuple(w_enc_matrix.shape)}")
# The injection vector must be shape [d_model], not bigger
inj_vec = w_enc_matrix[feat_idx] # shape => [768]
print(f" inj_vec.shape = {tuple(inj_vec.shape)}")
injection_list.append(inj_vec.detach().clone())
elif tnode.node_type == NodeType.LOGIT:
print(f" [DEBUG] Target node {idx}: LOGIT layer={layer_i}, logit_index={tnode.logit_index}")
inj_vec = compute_logit_injection_vector(local_model, tnode.logit_index)
print(f" inj_vec.shape = {tuple(inj_vec.shape)}")
injection_list.append(inj_vec.detach().clone())
else:
print(f" [DEBUG] Skipping target node {idx}: node_type={tnode.node_type}")
continue
if len(injection_list) == 0:
print(f" [DEBUG] No valid feature/logit targets for layer={layer_i}, skipping.")
continue
# shape => [n_targets, d_model]
injection_mat = torch.stack(injection_list, dim=0)
print(f"[DEBUG] injection_mat final shape => {tuple(injection_mat.shape)}")
# If layer_i == n_layers => clamp to n_layers-1
actual_layer = min(layer_i, n_layers - 1)
print(f"[DEBUG] actual_layer = {actual_layer}")
# 2B) single backward pass => gradient wrt resid at actual_layer
grad_wrt_resid = single_pass_grad_for_targets_debug(
local_model=local_model,
tokens=tokens,
layer_idx=actual_layer,
target_injection_vectors=injection_mat,
freeze_ln=True,
freeze_attn=True,
)
# => shape [batch=1, seq, d_model, n_targets]
grad_wrt_resid = grad_wrt_resid[0] # => [seq, d_model, n_targets]
seq_len, _, n_targets = grad_wrt_resid.shape
# 2C) For each source node...
        for src_node in G.nodes:
            src_id = src_node.id
if not is_upstream_layer_layer(src_node, layer_i):
continue
for t_idx, tgt_node in enumerate(target_nodes):
pos = tgt_node.context_pos
if (pos < 0) or (pos >= seq_len):
continue
grad_slice = grad_wrt_resid[pos, :, t_idx] # => shape [d_model]
w_val = 0.0
if src_node.node_type == NodeType.EMBEDDING:
emb_vec = src_node.output_vector # shape [d_model]
w_val = float(torch.dot(emb_vec, grad_slice))
elif src_node.node_type == NodeType.FEATURE:
a_s = src_node.activation
s_feat_idx = src_node.feature_index
s_layer = src_node.layer_idx
target_index = min(layer_i, n_layers - 1)
dec_idx = local_model.clt.index_map[s_layer][target_index]
if dec_idx is not None:
out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # [d_model]
partial = float(torch.dot(out_vec, grad_slice))
w_val = a_s * partial
if abs(w_val) >= threshold:
G.add_edge(src_id, tgt_node.id, w_val)
return G
def single_pass_grad_for_targets_debug(
local_model,
tokens,
layer_idx,
target_injection_vectors,
freeze_ln=True,
freeze_attn=True,
):
"""
Same as single_pass_grad_for_targets but with debug prints about shapes.
"""
d_model = local_model.base.cfg.d_model
print(f"\n[DEBUG single_pass_grad_for_targets] => layer_idx={layer_idx}")
print(f" target_injection_vectors.shape = {list(target_injection_vectors.shape)}")
local_model.zero_grad(set_to_none=True)
    # Optionally "freeze" LN / attention by disabling grads on their parameters.
    # NOTE: this only stops gradients flowing into those parameters; it does not
    # detach the LN denominators or the attention patterns themselves, so it is
    # only an approximation of the paper's frozen linearization.
    if freeze_ln or freeze_attn:
        for name, param in local_model.base.named_parameters():
            if freeze_ln and ("ln" in name):
                param.requires_grad_(False)
            if freeze_attn and (".W_Q" in name or ".W_K" in name):
                param.requires_grad_(False)
# Hook
storage = {}
def store_resid_hook(resid, hook):
storage["resid"] = resid
return resid
# E.g. blocks.{layer_idx}.hook_resid_mid
resid_name = f"blocks.{layer_idx}.hook_resid_mid"
if resid_name not in local_model.base.hook_dict:
# fallback
resid_name = f"blocks.{layer_idx}.hook_resid_post"
    # NOTE: add_hook doesn't return a handle in TransformerLens, so the removal
    # below is a no-op; fine for a one-off debug call, but repeated calls stack hooks.
    handle = local_model.base.add_hook(resid_name, store_resid_hook, "fwd")
with torch.enable_grad():
_ = local_model(tokens)
if handle is not None:
handle.remove()
if "resid" not in storage:
raise ValueError(f"Could not find residual for layer {layer_idx}: {resid_name}")
resid_L = storage["resid"]
print(f" resid_L.shape = {tuple(resid_L.shape)}")
# Now we verify the injection_vec shape is [n_targets, d_model].
n_targets = target_injection_vectors.shape[0]
if target_injection_vectors.ndim != 2 or target_injection_vectors.shape[1] != d_model:
raise RuntimeError(
f"[DEBUG ERROR] Expected injection_vectors shape [n_targets, {d_model}], "
f"but got {list(target_injection_vectors.shape)}"
)
    # 3) Build the partial losses: one scalar per target, linear in the hooked residual.
    # Transpose [n_targets, d_model] -> [d_model, n_targets] before broadcasting;
    # a bare .view() here would silently scramble the injection rows.
    injection_mat = target_injection_vectors.transpose(0, 1).reshape(1, 1, d_model, n_targets)
    expanded_resid = resid_L.unsqueeze(-1)    # => [batch, seq, d_model, 1]
    product = expanded_resid * injection_mat  # => [batch, seq, d_model, n_targets]
    partial_losses = product.sum(dim=[0, 1, 2])  # => shape [n_targets]
    # 4) Backprop each target with a unit grad_output to recover the full Jacobian
    outs = partial_losses  # shape [n_targets]
jac_list = []
for i in range(n_targets):
unit_vec = torch.zeros_like(outs) # shape [n_targets]
unit_vec[i] = 1.0
# Now partial derivative of outs·unit_vec => outs[i]
grad_i = torch.autograd.grad(
outputs=outs, # [n_targets]
inputs=resid_L, # [batch, seq, d_model]
grad_outputs=unit_vec, # also [n_targets]
retain_graph=True,
)[0] # => shape [batch, seq, d_model]
jac_list.append(grad_i.unsqueeze(-1)) # => [batch, seq, d_model, 1]
print(f"index={i}, grad_i.shape={tuple(grad_i.shape)}")
# Finally cat them => [batch, seq, d_model, n_targets]
grad_wrt_resid = torch.cat(jac_list, dim=-1)
print(f" => grad_wrt_resid.shape = {tuple(grad_wrt_resid.shape)}\n")
return grad_wrt_resid
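a quick wiring check worth running after this: because each partial loss is a plain dot product with the hooked residual (nothing downstream of it enters the loss), every Jacobian column should just reproduce its injection vector at every position. a tiny hedged check, assuming a local_model / tokens pair like the one built in the next section, and reusing a few encoder rows as arbitrary injections:
# pick a few arbitrary injection vectors (first 4 encoder rows of layer 0)
inj_mat = local_model.clt.W_enc[0][:4].detach().clone()   # shape [4, d_model]
grads = single_pass_grad_for_targets_debug(
    local_model, tokens, layer_idx=0,
    target_injection_vectors=inj_mat,
)
# grads: [batch, seq, d_model, n_targets]; each column at any position should equal its injection
print(torch.allclose(grads[0, 0], inj_mat.t(), atol=1e-5))  # expect True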
putting it together so far
#### collect 5k samples
main_collect_5k()
#### train the CLT
dataset = CLTHiddenStateDataset("clt_data")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
model_name = "gpt2-small"
hmodel = HookedTransformer.from_pretrained(model_name)
layer_feature_list = [128]*hmodel.cfg.n_layers # or something custom
cross_layer_transcoder = CrossLayerTranscoder.from_hookedtransformer(
hmodel, layer_feature_list, bandwidth=1.0, device="cuda"
)
train_clt(cross_layer_transcoder, dataloader, 2)
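training even this tiny CLT isn't instant, so i'd checkpoint it before moving on; assuming CrossLayerTranscoder is a plain nn.Module (the filename is just a placeholder):
# optional: save the trained CLT so the graph-building steps below can be rerun without retraining
torch.save(cross_layer_transcoder.state_dict(), "clt_gpt2_small.pt")
# later: cross_layer_transcoder.load_state_dict(torch.load("clt_gpt2_small.pt"))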
#### build the local replacement model
prompt = "Hello, I'd like you to help me plan my wedding please."
local_replacement_model = build_local_replacement_model_with_cache(hmodel, cross_layer_transcoder, prompt, device="cuda")
#### Let's compare outputs of the original model vs. the local replacement
with torch.no_grad():
tokens = hmodel.to_tokens(prompt, prepend_bos=True).to("cuda")
orig_logits = hmodel(tokens)
rep_logits = local_replacement_model(tokens)
print("Are the logits close?", torch.allclose(orig_logits, rep_logits, atol=1e-5))
tokens = hmodel.to_tokens(prompt, prepend_bos=True)
#### 1) Build node set
G, cache = build_graph_nodes(local_replacement_model, prompt, top_k=3)
#### 2) compute direct edges for each logit node or MLP node
build_direct_edges_single_pass_debug(G, local_replacement_model, tokens, threshold=1e-4)
#### 3) prune
pruned_G = prune_graph(G, cache["logit_node_ids"],
threshold_nodes=0.8,
threshold_edges=0.98)
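and a quick, hypothetical way to eyeball what survived, mirroring the access pattern the pruning helper uses (G.nodes as a list, G.edges[i] holding (target, weight) pairs); adjust if your graph container differs:
# how much is left after pruning, and which edges dominate?
print(f"pruned graph: {len(pruned_G.nodes)} nodes")
all_edges = [(i, j, w)
             for i in range(len(pruned_G.nodes))
             for (j, w) in pruned_G.edges[i]]
print(f"pruned graph: {len(all_edges)} edges")
for i, j, w in sorted(all_edges, key=lambda e: -abs(e[2]))[:10]:
    src, tgt = pruned_G.nodes[i], pruned_G.nodes[j]
    print(f"  {src.node_type} (layer {getattr(src, 'layer_idx', '-')}, pos {getattr(src, 'context_pos', '-')})"
          f" -> {tgt.node_type} (layer {getattr(tgt, 'layer_idx', '-')}, pos {getattr(tgt, 'context_pos', '-')})"
          f"  w={w:+.4f}")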
there ya go. some breakdown of anthropic’s paper with the assistance of some of my dearest friends: sonnet 3.7 thinking, o1-pro, gpt-4o, o3-mini-high, and long time homie, sonnet 3.5.