Multi-Query Low-Rank Attention
==================
this is sort of a strange blog bc it’s really just a stream of consciousness from a random exploration I had.
from what I understand, one of the complaints of individuals working at big labs is that there’s less ‘curiosity’-driven research and more bet-driven research. in this paradigm, you have a bet on a direction and you’re trying to make it work, so you go explore that direction instead of just wandering down a path that interests you. when you’re betting bagillions of dollars, this makes sense, but for little old me, I sometimes just like to wander.
in order for ideas/intuition to fully develop in my brain (which hopefully then compounds into better future ideas), I usually need time to take dumb ideas to their inevitable failure, or the intuition doesn’t properly develop; even if the idea itself is dumber than a box of rocks. the cool thing about these random walks (and all research these days) is that o1/sonnet/r1/grok/etc have made this process much easier. eg, instead of painstakingly coding up a new idea and then losing interest completely by the end of it, you just sort of “vibe research”. you control the generative path, but the llm implements the pieces of the idea that you want it to, and you see where it goes. so the idea starts on one end and sometimes ends at a completely unrelated one.
ideagen
i was thinking about how to get smaller and smaller models that are as fast as humanly possible without something like distillation; rather, just focusing on the architecture. something that could potentially be used in conjunction with, say, a memory unit for an agent (we’ll get there in another blog) and is super fast at inference time. the idea stemmed from 1) how to make attention faster and 2) if you’re ever to own your own personal AI running on your own hardware, with some kind of human/alien-like memory, you’d want to never have to reset a conversation/agentic flow; it would just have this infinite memory reservoir of all your interactions with it. so basically it comes down to 1) what is Noam Shazeer doing architecturally, and 2) how can we get the attention to be faster.
asking that/those question(s) leads me to some form of Multi-Query Attention (MQA) and some form of compression (like DeepSeek’s MLA, which compresses the kv cache). so let’s make MQA faster. ok, so what’s a small diddly on how we can make MQA faster in the easiest way possible? i dunno, top of mind: make it low-rank or something. sure, why not. let’s see what kind of trouble we can get ourselves into. logically this will reduce the expressiveness, but why the hell not just try it for fun.
when you think about it, it’s basically lora without the lora: we build the entire attention layer from a low-rank parameterization instead of factorizing frozen weights. almost like training a lora from scratch.
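to make that concrete, here’s roughly the distinction i mean (a quick sketch with made-up dims; none of this is the actual peft/lora API):
import torch

dim, out_dim, r = 768, 128, 16

# lora: a frozen full-rank weight plus a trainable low-rank delta
W_frozen = torch.randn(dim, out_dim)                # pretrained, stays fixed
A = torch.nn.Parameter(torch.randn(dim, r) * 0.02)
B_mat = torch.nn.Parameter(torch.zeros(r, out_dim))
lora_weight = W_frozen + A @ B_mat                  # what lora effectively applies

# "lora without the lora": no frozen base, the low-rank product *is* the layer
W1 = torch.nn.Parameter(torch.randn(dim, r) * 0.02)
W2 = torch.nn.Parameter(torch.randn(r, out_dim) * 0.02)
scratch_weight = W1 @ W2                            # rank <= r by construction, trained from scratch

print("full-rank params:", dim * out_dim)           # 98304
print("low-rank params: ", dim * r + r * out_dim)   # 14336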
hasn’t this been done before? probably, but who cares. i’m not even going to look into it because i just want to see where this line of thinking naturally goes. let’s see what happens.
mqlra: multi-query low-rank attention
MQA basically means you share a single key/value projection across all heads, while each head keeps its own query. its comparison to standard mha is below:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# toy sizes so these standalone snippets actually run
B, T = 2, 16           # batch size, sequence length
d_model, d_k = 64, 16  # model dim, per-head dim
n_heads = 4
x = torch.randn(B, T, d_model)

# Create separate learnable weight matrices per head for Q, K, V
W_Q = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
W_K = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
W_V = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
def scaled_dot_product_attention(Q, K, V):
"""
Q: (B, T, d_k)
K: (B, T, d_k)
V: (B, T, d_k)
"""
# batch matrix multiply => (B, T, T)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1) # (B, T, T)
# multiply by V => (B, T, d_k)
output = torch.bmm(attn_weights, V)
return output, attn_weights
def standard_multi_head_attention(x, W_Q_list, W_K_list, W_V_list):
"""
x: (B, T, d_model)
Returns: (B, T, n_heads * d_k)
"""
head_outputs = []
for h in range(n_heads):
Q = x @ W_Q_list[h] # (B, T, d_k)
K = x @ W_K_list[h] # (B, T, d_k)
V = x @ W_V_list[h] # (B, T, d_k)
out, _ = scaled_dot_product_attention(Q, K, V)
head_outputs.append(out)
# Concatenate along the last dimension
return torch.cat(head_outputs, dim=-1) # (B, T, n_heads*d_k)
out_std = standard_multi_head_attention(x, W_Q, W_K, W_V)
print("Standard Multi-Head output shape:", out_std.shape)
# Create per-head Q, but only ONE K and V
WQ_mq = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
WK_mq = torch.nn.Parameter(torch.randn(d_model, d_k)) # Shared across heads
WV_mq = torch.nn.Parameter(torch.randn(d_model, d_k)) # Shared across heads
def multi_query_attention(x, WQ_list, WK, WV):
"""
x: (B, T, d_model)
Returns: (B, T, n_heads * d_k) -- same shape as standard multi-head
"""
K = x @ WK # (B, T, d_k), shared
V = x @ WV # (B, T, d_k), shared
head_outputs = []
for h in range(n_heads):
Q_h = x @ WQ_list[h] # (B, T, d_k)
out_h, _ = scaled_dot_product_attention(Q_h, K, V)
head_outputs.append(out_h)
return torch.cat(head_outputs, dim=-1)
out_mq = multi_query_attention(x, WQ_mq, WK_mq, WV_mq)
print("Multi-Query Attention output shape:", out_mq.shape)
ok, so now let’s take that same idea and just make these matrices low-rank.
class MQLRA(nn.Module):
"""
Simplified Multi-Query Attention + Low-Rank factorization.
- We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
- We have a *single* K factorization: (W1K, W2K)
- We have a *single* V factorization: (W1V, W2V)
- This implementation uses standard PyTorch operations for readability.
Input shape:
X: [B, L, D_in]
We'll produce an attention output shape of:
[B, L, D_out * n_heads] (by concatenating the n_heads outputs).
"""
def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
super().__init__()
self.D_in = D_in
self.D_out = D_out
self.n_heads = n_heads
self.rank = rank
self.causal = causal
if scale is not None:
self.scale = float(scale)
else:
self.scale = 1.0 / math.sqrt(D_out)
# Q: separate for each head
# Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
self.W1Q_heads = nn.ParameterList([
nn.Parameter(torch.randn(D_in, rank) * 0.02)
for _ in range(n_heads)
])
self.W2Q_heads = nn.ParameterList([
nn.Parameter(torch.randn(rank, D_out) * 0.02)
for _ in range(n_heads)
])
# K: single
self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)
# V: single
self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)
def forward(self, X, attn_bias=None):
"""
X => [B, L, D_in]
        attn_bias => optional, shape broadcastable to [B, n_heads, L, L] if needed
Returns: [B, L, n_heads*D_out]
"""
batch_size, seq_len, _ = X.shape
device = X.device
dtype = X.dtype
# Compute K and V once (shared across all heads)
# Low-rank projections for K, V
partial_k = X @ self.W1K # [B, L, rank]
K = partial_k @ self.W2K # [B, L, D_out]
partial_v = X @ self.W1V # [B, L, rank]
V = partial_v @ self.W2V # [B, L, D_out]
# Prepare output tensor to accumulate results from all heads
out = torch.zeros(batch_size, seq_len, self.n_heads * self.D_out,
device=device, dtype=dtype)
# Process each head separately
for h in range(self.n_heads):
# Low-rank projection for Q (head-specific)
partial_q = X @ self.W1Q_heads[h] # [B, L, rank]
Q = partial_q @ self.W2Q_heads[h] # [B, L, D_out]
# Compute attention scores
# [B, L, D_out] @ [B, D_out, L] -> [B, L, L]
scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale
# Apply causal mask if needed
if self.causal:
# Create causal mask and apply it
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
scores.masked_fill_(mask, -float('inf'))
# Apply attention bias if provided
if attn_bias is not None:
scores = scores + attn_bias
# Softmax and apply attention
attn_weights = torch.softmax(scores, dim=-1) # [B, L, L]
head_output = torch.bmm(attn_weights, V) # [B, L, D_out]
# Add to output tensor at the correct position
out[:, :, h * self.D_out:(h + 1) * self.D_out] = head_output
return out
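# smoke test of the standalone MQLRA module above (toy sizes; just checking the output
# shape and that the low-rank factorization actually shrinks the projections)
mqlra = MQLRA(D_in=64, D_out=16, n_heads=4, rank=8, causal=True)
y = mqlra(torch.randn(2, 32, 64))
print(y.shape)  # torch.Size([2, 32, 64]) -> [B, L, n_heads * D_out]
print("full-rank Q/K/V params:", 64 * 16 * (4 + 2))                           # 6144
print("low-rank  Q/K/V params:", sum(p.numel() for p in mqlra.parameters()))  # 3840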
for a super easy integration, we show how this would fit into Keller Jordan’s modded-nanogpt: we basically take our MQLRA, format it the way Keller formats his CausalSelfAttention class, and then simply drop it into the script.
# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous()
x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
assert x.ndim == w.ndim == 2
assert x.shape[1] == w.shape[1]
assert x.device == w.device
assert x.is_contiguous() and w.is_contiguous()
return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
@torch.compile
def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
assert grad.is_contiguous()
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
grad_x = torch._scaled_mm(
grad_f8,
w_f8.T.contiguous().T,
out_dtype=torch.bfloat16,
scale_a=grad_inv_s,
scale_b=w_inv_s,
use_fast_accum=False,
)
# faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
grad_w = torch._scaled_mm(
x_f8.T.contiguous(),
grad_f8.T.contiguous().T,
out_dtype=torch.float32,
scale_a=x_inv_s,
scale_b=grad_inv_s,
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
def backward(ctx, grad_out: Tensor, *_):
x_f8, w_f8 = ctx.saved_tensors
x_s, w_s, grad_s = ctx.scales
grad_x, grad_w = torch.ops.nanogpt.mm_backward(
grad_out, x_f8, w_f8, x_s, w_s, grad_s
)
return grad_x, grad_w, None, None, None
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
*_, x_s, w_s, grad_s = inputs
_, x_f8, w_f8 = output
ctx.save_for_backward(x_f8, w_f8)
ctx.scales = x_s, w_s, grad_s
ctx.set_materialize_grads(False)
mm_op.register_autograd(backward, setup_context=setup_context)
# skipping the incredibly based Muon optimizer for brevity
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model
def norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
super().__init__(in_features, out_features, bias=False)
self.use_fp8 = use_fp8
self.x_s = x_s
self.w_s = w_s
self.grad_s = grad_s
def reset_parameters(self) -> None:
std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
bound = (3 ** 0.5) * std
with torch.no_grad():
self.weight.uniform_(-bound, bound)
def forward(self, x: Tensor):
if self.use_fp8 and self.training:
_x = x.flatten(0, -2)
out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
return out.reshape(*x.shape[:-1], -1)
else:
return F.linear(x, self.weight.type_as(x))
class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
super().__init__()
# half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
t = torch.arange(max_seq_len, dtype=torch.float32)
theta = torch.einsum("i,j -> ij", t, angular_freq)
self.cos = nn.Buffer(theta.cos(), persistent=False)
self.sin = nn.Buffer(theta.sin(), persistent=False)
def forward(self, x_BTHD: Tensor):
assert self.cos.size(0) >= x_BTHD.size(-3)
cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat((y1, y2), 3).type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
hdim = num_heads * head_dim
std = 0.5 * (dim ** -0.5)
bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
# merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
# https://x.com/hi_tysam/status/1879699187107033311
self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
self.rotary = Rotary(head_dim, max_seq_len)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
B, T = x.size(0), x.size(1) # batch size, sequence length
assert B == 1, "Must use batch size = 1 for FlexAttention"
q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
q, k = norm(q), norm(k) # QK norm @Grad62304977
q, k = self.rotary(q), self.rotary(k)
if ve is not None:
v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
else: # skip mid-layers token value embeddings by @YouJiacheng
v = self.lambdas[0] * v
# scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
# inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, dim: int):
super().__init__()
hdim = 4 * dim
self.c_fc = CastedLinear(dim, hdim)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
def forward(self, x: Tensor):
x = self.c_fc(x)
x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
super().__init__()
# skip attention of blocks.7 (the 8th layer) by @YouJiacheng
self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
self.mlp = MLP(dim)
self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
x = self.lambdas[0] * x + self.lambdas[1] * x0
if self.attn is not None:
x = x + self.attn(norm(x), ve, block_mask)
x = x + self.mlp(norm(x))
return x
# -----------------------------------------------------------------------------
# The main model
def next_multiple_of_n(v: float | int, *, n: int):
return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
class GPT(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
super().__init__()
self.embed = nn.Embedding(vocab_size, model_dim)
# token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
# value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
# suggested to me by @Grad62304977. this originates from Karpathy's experiments.
self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
self.lm_head.weight.detach().zero_() # @Grad62304977
# Add learnable skip connection weights for decoder layers
assert num_layers % 2 == 0
self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
BLOCK_SIZE = 128
docs = (input_seq == 50256).cumsum(0)
def document_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[q_idx] == docs[kv_idx]
return causal_mask & document_mask
def dense_to_ordered(dense_blockmask: Tensor):
num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
# manual block mask creation by @YouJiacheng
assert len(input_seq) % BLOCK_SIZE == 0
NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
causal_blockmask_any = block_idx[:, None] >= block_idx
causal_blockmask_all = block_idx[:, None] > block_idx
docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
blockmask_any = causal_blockmask_any & document_blockmask_any
blockmask_all = causal_blockmask_all & document_blockmask_all
partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
def build_bm(window_size_blocks: Tensor) -> BlockMask:
return BlockMask.from_kv_blocks(
torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
partial_kv_indices,
torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
full_kv_indices,
BLOCK_SIZE=BLOCK_SIZE,
mask_mod=document_causal,
)
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
assert input_seq.ndim == 1
ve = [value_embed(input_seq) for value_embed in self.value_embeds]
# 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
assert len(ve) == len(self.blocks)
long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
assert len(block_masks) == len(self.blocks)
x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
# U-net design by @brendanh0gan
skip_connections = []
n = len(self.skip_weights)
for i in range(len(self.blocks)):
if i >= n:
x = x + self.skip_weights[i - n] * skip_connections.pop()
x = self.blocks[i](x, ve[i], x0, block_masks[i])
if i < n:
skip_connections.append(x)
x = norm(x)
logits = self.lm_head(x).float()
# @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
return loss
class MQLRAFlexAttention(nn.Module):
"""
Multi-Query Low-Rank Attention + FlexAttention
that acts as a drop-in replacement for CausalSelfAttention.
- We replicate the same structure:
* a rotary embedding
* a final CastedLinear (c_proj)
* a lambdas parameter for combining 've' (value embeddings)
* multi-head query, single K/V
- We still do QK norm, then apply rotary, then call flex_attention.
- The difference: we use (W1Q_h, W2Q_h) for Q per head, plus single (W1K,W2K)
and (W1V,W2V) for K, V.
"""
def __init__(
self,
dim: int,
num_heads: int,
max_seq_len: int,
head_dim=128,
rank=16, # low-rank dimension
        scale=0.12,  # default scale used in modded-nanogpt's CausalSelfAttention
):
"""
Args:
dim: model dimension (input and output).
num_heads: number of heads.
max_seq_len: for rotary embedding.
head_dim: dimension per head (like in CausalSelfAttention).
rank: low-rank dimension for Q, K, V factorization.
            scale: scale factor for attention logits, e.g. 0.12 as in modded-nanogpt's CausalSelfAttention.
"""
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.rank = rank
self.scale = scale
# total hidden dimension after merging heads
hdim = num_heads * head_dim
# optional initialization approach
std = 0.5 * (dim ** -0.5)
bound = (3 ** 0.5) * std
# We keep lambdas for mixing in 've'
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
# Rotary embedding
self.rotary = Rotary(head_dim, max_seq_len)
# final projection after attention
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_()
# Q: separate (W1Q, W2Q) for each head
# shape: W1Q => [dim, rank], W2Q => [rank, head_dim]
# We'll store them in ParameterList to keep them distinct per head.
self.W1Q_heads = nn.ParameterList([
nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
for _ in range(num_heads)
])
self.W2Q_heads = nn.ParameterList([
nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
for _ in range(num_heads)
])
# K: single (W1K, W2K)
self.W1K = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
self.W2K = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
# V: single (W1V, W2V)
self.W1V = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
self.W2V = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask):
"""
x: [B, T, dim]
ve: optional value-embedding, shape must broadcast with V
block_mask: for flex_attention
Returns:
[B, T, dim] after c_proj
"""
B, T, dim = x.shape
# If your flex_attention kernel only works for B=1:
assert B == 1, "Must use batch size = 1 for FlexAttention"
# 1) Compute Q, K, V in multi-query style
# Q => [B, T, num_heads, head_dim]
# single K => broadcast to [B, T, num_heads, head_dim]
# single V => broadcast similarly
# Q (separate per head)
Q_list = []
for h in range(self.num_heads):
partial_q = x.matmul(self.W1Q_heads[h]) # [B, T, rank]
q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, T, head_dim]
Q_list.append(q_h)
# stack => [B, T, num_heads, head_dim]
q = torch.stack(Q_list, dim=2)
# single K
partial_k = x.matmul(self.W1K) # [B, T, rank]
K_ = partial_k.matmul(self.W2K) # [B, T, head_dim]
# expand to match num_heads => [B, T, num_heads, head_dim]
k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)
# single V
partial_v = x.matmul(self.W1V) # [B, T, rank]
v_ = partial_v.matmul(self.W2V) # [B, T, head_dim]
v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)
# 2) norm Q, K + rotary
# same approach as CausalSelfAttention: norm(q), norm(k), then rotary
q = norm(q)
k = norm(k)
q = self.rotary(q)
k = self.rotary(k)
# 3) incorporate 've' into v if provided
if ve is not None:
# ve => shape something like [B, T, num_heads, head_dim] or broadcastable
# in CausalSelfAttention, they do `v = lambdas[0]*v + lambdas[1]*ve.view_as(v)`
# We do the same:
v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
else:
v = self.lambdas[0] * v
# 4) call flex_attention
# flex_attention expects shape [B, n_heads, T, head_dim]
# so we do transpose(1,2)
q_t = q.transpose(1, 2) # => [B, num_heads, T, head_dim]
k_t = k.transpose(1, 2)
v_t = v.transpose(1, 2)
# we have scale = self.scale, block_mask
# from the snippet: y = flex_attention(q_t, k_t, v_t, block_mask, scale=0.12)
y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale)
# y_t => [B, num_heads, T, head_dim]
# transpose back => [B, T, num_heads, head_dim]
y = y_t.transpose(1, 2).contiguous()
# 5) reshape => [B, T, (num_heads * head_dim)]
y = y.view(B, T, self.num_heads * self.head_dim)
# final projection => [B, T, dim]
y = self.c_proj(y)
return y
# Manually add gradient checkpointing to Block
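# (assumed helper -- the original snippet never shows it: presumably a thin wrapper around
#  torch.utils.checkpoint so each block's activations get recomputed during backward
#  instead of stored, trading compute for memory)
from torch.utils.checkpoint import checkpoint

def checkpoint_block(fn, *args):
    return checkpoint(fn, *args, use_reentrant=False)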
class MQLRAFlexBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int):
super().__init__()
# skip attention of blocks.7 (the 8th layer) by @YouJiacheng
self.attn = MQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None
self.mlp = MLP(dim)
self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
# Use checkpoint instead of direct execution
return checkpoint_block(self._forward, x, ve, x0, block_mask)
def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
x = self.lambdas[0] * x + self.lambdas[1] * x0
if self.attn is not None:
x = x + self.attn(norm(x), ve, block_mask)
x = x + self.mlp(norm(x))
return x
class ModifiedGPTWithMQLRAFlex(nn.Module):
"""Modified GPT model that uses MqaLraFlashAttention with configurable rank"""
def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, rank: int):
super().__init__()
self.embed = nn.Embedding(vocab_size, model_dim)
# token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
# value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
self.blocks = nn.ModuleList([MQLRAFlexBlock(model_dim, num_heads, max_seq_len, i, rank) for i in range(num_layers)])
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
# suggested to me by @Grad62304977. this originates from Karpathy's experiments.
self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
self.lm_head.weight.detach().zero_() # @Grad62304977
# Add learnable skip connection weights for decoder layers
assert num_layers % 2 == 0
self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
self._last_loss_value = None
def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
BLOCK_SIZE = 64
docs = (input_seq == 50256).cumsum(0)
def document_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[q_idx] == docs[kv_idx]
return causal_mask & document_mask
def dense_to_ordered(dense_blockmask: Tensor):
num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
# manual block mask creation by @YouJiacheng
assert len(input_seq) % BLOCK_SIZE == 0
NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
causal_blockmask_any = block_idx[:, None] >= block_idx
causal_blockmask_all = block_idx[:, None] > block_idx
docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
blockmask_any = causal_blockmask_any & document_blockmask_any
blockmask_all = causal_blockmask_all & document_blockmask_all
partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
def build_bm(window_size_blocks: Tensor) -> BlockMask:
return BlockMask.from_kv_blocks(
torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
partial_kv_indices,
torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
full_kv_indices,
BLOCK_SIZE=BLOCK_SIZE,
mask_mod=document_causal,
)
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
assert input_seq.ndim == 1
ve = [value_embed(input_seq) for value_embed in self.value_embeds]
# 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
assert len(ve) == len(self.blocks)
long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
assert len(block_masks) == len(self.blocks)
x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
# U-net design by @brendanh0gan
skip_connections = []
n = len(self.skip_weights)
for i in range(len(self.blocks)):
if i >= n:
x = x + self.skip_weights[i - n] * skip_connections.pop()
x = self.blocks[i](x, ve[i], x0, block_masks[i])
if i < n:
skip_connections.append(x)
x = norm(x)
logits = self.lm_head(x).float()
# @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
# If we're in training mode, track the loss components
if self.training and dist.get_rank() == 0:
# Store the loss components for analysis
self._last_loss_value = float(loss.item())
return loss
def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
"""Generate text from the model"""
self.eval()
with torch.no_grad():
for _ in range(max_new_tokens):
# Crop idx to manageable context if needed
idx_cond = idx if idx.size(1) <= 1024 else idx[:, -1024:]
# Get logits for next token prediction
logits, _ = self._forward_generation(idx_cond)
logits = logits[:, -1, :] / temperature # Focus on last position
# Apply top-k filtering
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
# Sample from distribution
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# Append to sequence
idx = torch.cat((idx, idx_next), dim=1)
return idx
def _forward_generation(self, input_seq):
"""Handle generation case where we only need logits"""
# Make sure input is properly formatted
if input_seq.ndim == 2: # [B, T]
batch_size = input_seq.size(0)
if batch_size > 1:
# Only use the first batch item for simplicity in generation
input_seq = input_seq[0]
else:
input_seq = input_seq.view(-1)
# Use a reasonable window size for generation
sliding_window_num_blocks = torch.tensor(8, dtype=torch.int32, device=input_seq.device)
# Ensure input is padded to BLOCK_SIZE multiple
BLOCK_SIZE = 64 # Same as in create_blockmasks
original_len = len(input_seq)
remainder = original_len % BLOCK_SIZE
if remainder != 0:
pad_length = BLOCK_SIZE - remainder
padding = torch.zeros(pad_length, dtype=input_seq.dtype, device=input_seq.device)
input_seq = torch.cat([input_seq, padding])
# Get token embeddings for the padded sequence
x = x0 = norm(self.embed(input_seq)[None]) # use of norm here matches forward()
# Create block masks for attention with padded sequence
long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm,
short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
# Get value embeddings for the padded sequence
ve = [value_embed(input_seq) for value_embed in self.value_embeds]
ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
# Process through blocks like in forward method
skip_connections = []
n = len(self.skip_weights)
for i in range(len(self.blocks)):
if i >= n:
x = x + self.skip_weights[i - n] * skip_connections.pop()
# Critical fix: Only pass value embeddings to blocks that expect them
current_ve = ve[i] if i < len(ve) and ve[i] is not None else None
current_mask = block_masks[i] if i < len(block_masks) else None
# Disable gradient checkpointing during generation
if hasattr(self.blocks[i], 'use_checkpoint'):
original_checkpoint = self.blocks[i].use_checkpoint
self.blocks[i].use_checkpoint = False
x = self.blocks[i](x, current_ve, x0, current_mask)
self.blocks[i].use_checkpoint = original_checkpoint
else:
x = self.blocks[i](x, current_ve, x0, current_mask)
if i < n:
skip_connections.append(x)
# Final norm and project to logits
x = norm(x)
        logits = self.lm_head(x).float()
        # trim any padding added above so generate() samples from the last real token
        logits = logits[:, :original_len, :]
# Return logits for generation and the hidden state
return logits, x
flash/flex attention
i was already deep down the rabbit hole, so i figured why not attempt to drop flash/flex attention in on top of the attention matmuls and see if we can get an even larger speed boost.
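one layout gotcha worth flagging before the code: the two kernels expect heads on different axes (at least in the versions i was poking at), which is why the modules below shuffle dims around.
q = torch.randn(1, 1024, 6, 128)   # [B, L, n_heads, head_dim]
q_for_flash = q                    # flash_attn_func takes [B, L, n_heads, head_dim]
q_for_flex = q.transpose(1, 2)     # flex_attention takes [B, n_heads, L, head_dim]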

# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# assume we have the flash attention triton definition:
# *Experimental* implementation of FlashAttention in Triton.
# Tested with triton==2.0.0.dev20221202.
#...
#...
# flash_attn_func = FlashAttnFunc.apply
# from flash_attn_triton import flash_attn_func
class MQLRAFlashAttention(nn.Module):
"""
Multi-Query Attention + Low-Rank factorization + FlashAttention.
- We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
- We have a *single* K factorization: (W1K, W2K)
- We have a *single* V factorization: (W1V, W2V)
- Then we call the standard Triton FlashAttention kernel on the resulting Q, K, V.
Input shape:
X: [B, L, D_in]
We'll produce an attention output shape of:
[B, L, D_out * n_heads] (by concatenating the n_heads outputs).
We store parameters for Q in `self.W1Q_heads[i], self.W2Q_heads[i]`,
and single K, V in `self.W1K, W2K, W1V, W2V`.
"""
def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
super().__init__()
self.D_in = D_in
self.D_out = D_out
self.n_heads = n_heads
self.rank = rank
self.causal = causal
if scale is not None:
self.scale = float(scale)
else:
self.scale = 1.0 / math.sqrt(D_out)
# Q: separate for each head
# Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
self.W1Q_heads = nn.ParameterList([
nn.Parameter(torch.randn(D_in, rank) * 0.02)
for _ in range(n_heads)
])
self.W2Q_heads = nn.ParameterList([
nn.Parameter(torch.randn(rank, D_out) * 0.02)
for _ in range(n_heads)
])
# K: single
self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)
# V: single
self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)
def forward(self, X, attn_bias=None):
"""
X => [B, L, D_in]
        attn_bias => optional, shape broadcastable to [B, n_heads, L, L] if needed
We'll produce out => [B, L, n_heads*D_out].
Implementation steps:
1) Build Q (n_heads) => shape [B, L, n_heads, D_out]
2) Build K => shape [B, L, D_out], replicate to [B, L, 1, D_out]
3) Build V => shape [B, L, D_out], replicate to [B, L, 1, D_out]
4) Call flash-attn forward => out => shape [B, L, n_heads, D_out]
5) Reshape to [B, L, n_heads*D_out].
"""
B, L, _ = X.shape
# Step 1) Q => [B, L, n_heads, D_out]
# We'll compute each head's Q in Python. Then stack along "n_heads".
# partial_Q_h = X @ W1Q_heads[h] => [B,L,rank]
# Q_h = partial_Q_h @ W2Q_heads[h] => [B,L,D_out]
Q_list = []
for h in range(self.n_heads):
partial_q = X.matmul(self.W1Q_heads[h]) # [B, L, rank]
q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, L, D_out]
Q_list.append(q_h)
# stack => [B, L, n_heads, D_out]
Q = torch.stack(Q_list, dim=2)
# Step 2) K => [B,L,D_out], then expand to [B,L,1,D_out]
partial_k = X.matmul(self.W1K) # [B,L,rank]
K_ = partial_k.matmul(self.W2K) # [B,L,D_out]
# For multi-query: we have a single K => we can just unsqueeze dim=2 => n_heads=1
# But if we want flash to see n_heads == self.n_heads, we replicate:
# shape => [B, L, n_heads, D_out]
# but logically the data is the same. We'll do a .unsqueeze(2).expand(...)
# or we can do a .repeat_interleave, but that costs memory. Alternatively we can do
# a trick: we pass n_heads=1 to flash, but that won't match Q's shape. So we must replicate:
# This is the simplest approach:
K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)
# Step 3) V => same approach
partial_v = X.matmul(self.W1V) # [B,L,rank]
V_ = partial_v.matmul(self.W2V) # [B,L,D_out]
V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)
# Step 4) flash-attn => out => [B, L, n_heads, D_out]
# The flash_attn code expects Q => [B, L, n_heads, headdim]
# so we pass Q,K,V in that shape. We can pass an optional bias if it matches shape [B,n_heads,L,L].
Q = Q.contiguous()
K = K.contiguous()
V = V.contiguous()
print("Q shape:", Q.shape, "stride:", Q.stride())
print("K shape:", K.shape, "stride:", K.stride())
print("V shape:", V.shape, "stride:", V.stride())
out = flash_attn_func(Q, K, V,
attn_bias,
self.causal,
self.scale)
# out => [B, L, nheads, D_out]
# Step 5) reshape => [B,L,n_heads*D_out]
out = out.reshape(B, L, self.n_heads*self.D_out)
return out
class MqaLraFlexAttention(nn.Module):
"""
Multi-Query Attention + Low-Rank factorization + FlexAttention.
- We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
- We have a *single* K factorization: (W1K, W2K)
- We have a *single* V factorization: (W1V, W2V)
- Then we call the Triton-based `flex_attention` kernel on the resulting Q, K, V,
similar to how `CausalSelfAttention` uses flex_attention.
Input shape:
X: [B, L, D_in]
Output shape:
[B, L, D_out * n_heads]
"""
def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
super().__init__()
self.D_in = D_in
self.D_out = D_out
self.n_heads = n_heads
self.rank = rank
self.causal = causal
# If scale is not provided, use 1/sqrt(D_out) by default
self.scale = float(scale) if scale is not None else (1.0 / math.sqrt(D_out))
# Q: separate for each head
# Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
self.W1Q_heads = nn.ParameterList(
nn.Parameter(torch.randn(D_in, rank) * 0.02)
for _ in range(n_heads)
)
self.W2Q_heads = nn.ParameterList(
nn.Parameter(torch.randn(rank, D_out) * 0.02)
for _ in range(n_heads)
)
# K: single
self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)
# V: single
self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)
def forward(self, X, block_mask=None, attn_bias=None):
"""
X: [B, L, D_in]
block_mask: the `BlockMask` (if needed by flex_attention).
        attn_bias: optional, shape broadcastable to [B, n_heads, L, L] if needed.
Returns: [B, L, n_heads * D_out]
"""
B, L, _ = X.shape
# If your flex_attention kernel only supports B=1, you may want:
# assert B == 1, "We only support batch_size=1 for flex_attention."
# 1) Compute Q per head => [B, L, n_heads, D_out]
Q_list = []
for h in range(self.n_heads):
partial_q = X @ self.W1Q_heads[h] # [B, L, rank]
q_h = partial_q @ self.W2Q_heads[h] # [B, L, D_out]
Q_list.append(q_h)
Q = torch.stack(Q_list, dim=2) # => [B, L, n_heads, D_out]
# 2) Compute K => [B, L, D_out]
partial_k = X @ self.W1K # [B, L, rank]
K_ = partial_k @ self.W2K # [B, L, D_out]
# Expand K => [B, L, n_heads, D_out] for multi-query
K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)
# 3) Compute V => [B, L, D_out]
partial_v = X @ self.W1V # [B, L, rank]
V_ = partial_v @ self.W2V # [B, L, D_out]
V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)
# 4) Call flex_attention
# flex_attention expects [B, n_heads, L, head_dim], so we transpose(1,2).
# Then we get back [B, n_heads, L, head_dim], which we transpose back to [B, L, n_heads, head_dim].
Q_t = Q.transpose(1, 2) # => [B, n_heads, L, D_out]
K_t = K.transpose(1, 2) # => [B, n_heads, L, D_out]
V_t = V.transpose(1, 2) # => [B, n_heads, L, D_out]
# flex attention here
out_t = flex_attention(Q_t, K_t, V_t,
block_mask=block_mask,
scale=self.scale)
# out_t => [B, n_heads, L, D_out]
# 5) Transpose back => [B, L, n_heads, D_out], then flatten
out = out_t.transpose(1, 2) # => [B, L, n_heads, D_out]
out = out.reshape(B, L, self.n_heads * self.D_out)
return out
one more change:
i figured one simple change would be to collapse the low-rank k and v factor pairs into single matrices, each initialized as the product of its two factors. this is a bit more efficient and a bit faster during training (fewer kernel launches), though note that once fused, k/v become full [dim, head_dim] parameters, so the rank constraint only holds at initialization.
(the per-head q loop in the original implementation could also be parallelized into a single batched matmul in the forward pass.) either way, i wanted to see what would happen; a tiny init-equivalence check is below.
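here’s that check: at init the fused matrix is literally the product of the two low-rank factors, so the fused and factorized parameterizations start out identical (toy dims, stand-ins for W1K/W2K):
w1 = torch.empty(768, 16).uniform_(-0.02, 0.02)   # stand-in for W1K
w2 = torch.empty(16, 128).uniform_(-0.02, 0.02)   # stand-in for W2K
fused = torch.nn.Parameter(w1 @ w2)               # same recipe K_matrix uses below
x_tok = torch.randn(4, 768)
assert torch.allclose(x_tok @ fused, (x_tok @ w1) @ w2, atol=1e-5)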
class FusedMQLRAFlexAttention(nn.Module):
"""
Multi-Query Low-Rank Attention + FlexAttention with fused K and V matrices
for improved computational efficiency.
This version maintains the low-rank factorization of Q matrices for each head,
but uses fused matrices for K and V to reduce computation during the forward pass.
The Q matrices still use the (W1Q_h, W2Q_h) factorization per head, but
K and V use direct projection matrices.
"""
def __init__(
self,
dim: int,
num_heads: int,
max_seq_len: int,
head_dim=128,
rank=16, # low-rank dimension used for initialization and Q factorization
        scale=0.12,  # default scale used in modded-nanogpt's CausalSelfAttention
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.rank = rank
self.scale = scale
# total hidden dimension after merging heads
hdim = num_heads * head_dim
# initialization approach
std = 0.5 * (dim ** -0.5)
bound = (3 ** 0.5) * std
# We keep lambdas for mixing in 've'
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
# Rotary embedding
self.rotary = Rotary(head_dim, max_seq_len)
# final projection after attention
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_()
# Q: separate (W1Q, W2Q) for each head (keep low-rank factorization for Q)
self.W1Q_heads = nn.ParameterList([
nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
for _ in range(num_heads)
])
self.W2Q_heads = nn.ParameterList([
nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
for _ in range(num_heads)
])
# Fused K: instead of W1K and W2K, use a single matrix initialized as their product
W1K_init = torch.empty(dim, rank).uniform_(-bound, bound)
W2K_init = torch.empty(rank, head_dim).uniform_(-bound, bound)
self.K_matrix = nn.Parameter(W1K_init @ W2K_init)
# Fused V: instead of W1V and W2V, use a single matrix initialized as their product
W1V_init = torch.empty(dim, rank).uniform_(-bound, bound)
W2V_init = torch.empty(rank, head_dim).uniform_(-bound, bound)
self.V_matrix = nn.Parameter(W1V_init @ W2V_init)
def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask):
"""
x: [B, T, dim]
ve: optional value-embedding, shape must broadcast with V
block_mask: for flex_attention
Returns:
[B, T, dim] after c_proj
"""
B, T, dim = x.shape
# If your flex_attention kernel only works for B=1:
assert B == 1, "Must use batch size = 1 for FlexAttention"
# 1) Compute Q, K, V in multi-query style
# Q => [B, T, num_heads, head_dim]
# single K => broadcast to [B, T, num_heads, head_dim]
# single V => broadcast similarly
# Q (separate per head with low-rank factorization)
Q_list = []
for h in range(self.num_heads):
partial_q = x.matmul(self.W1Q_heads[h]) # [B, T, rank]
q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, T, head_dim]
Q_list.append(q_h)
# stack => [B, T, num_heads, head_dim]
q = torch.stack(Q_list, dim=2)
# K using fused matrix - one matmul instead of two
K_ = x.matmul(self.K_matrix) # [B, T, head_dim]
# expand to match num_heads => [B, T, num_heads, head_dim]
k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)
# V using fused matrix - one matmul instead of two
v_ = x.matmul(self.V_matrix) # [B, T, head_dim]
v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)
# 2) norm Q, K + rotary
q = norm(q)
k = norm(k)
q = self.rotary(q)
k = self.rotary(k)
# 3) incorporate 've' into v if provided
if ve is not None:
v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
else:
v = self.lambdas[0] * v
# 4) call flex_attention
q_t = q.transpose(1, 2) # => [B, num_heads, T, head_dim]
k_t = k.transpose(1, 2)
v_t = v.transpose(1, 2)
y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale)
# y_t => [B, num_heads, T, head_dim]
# transpose back => [B, T, num_heads, head_dim]
y = y_t.transpose(1, 2).contiguous()
# 5) reshape => [B, T, (num_heads * head_dim)]
y = y.view(B, T, self.num_heads * self.head_dim)
# final projection => [B, T, dim]
y = self.c_proj(y)
return y
class FusedMQLRAFlexBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int):
super().__init__()
# skip attention of blocks.7 (the 8th layer) by @YouJiacheng
self.attn = FusedMQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None
self.mlp = MLP(dim)
self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
# Use checkpoint instead of direct execution
return checkpoint_block(self._forward, x, ve, x0, block_mask)
def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
x = self.lambdas[0] * x + self.lambdas[1] * x0
if self.attn is not None:
x = x + self.attn(norm(x), ve, block_mask)
x = x + self.mlp(norm(x))
return x
quick and dirty training attempt w/ FusedMQLRAFlexAttention.
basically, in order to roughly match the performance of the original modded-nanogpt run, we need to train for many more steps, since 1) i don’t have an H100 and 1a) therefore can’t train with float8, so we have to shrink the sequence length + increase the number of iterations (which sucks, but whatever, we’re just throwing boredom spaghetti at the wall here).
you can sort of train with the fused kernel, then run inference with the original kernel, or not, or not not.
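for a rough sense of scale, this is how i’d instantiate the two models side by side (dims follow what i believe are the modded-nanogpt defaults of 12 layers / 6 heads / 768 dim; the max_seq_len and rank here are just picks for this experiment):
baseline = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=8192)
mqlra_gpt = ModifiedGPTWithMQLRAFlex(vocab_size=50257, num_layers=12, num_heads=6,
                                     model_dim=768, max_seq_len=8192, rank=16)
print("baseline params:", round(sum(p.numel() for p in baseline.parameters()) / 1e6, 1), "M")
print("mqlra params:   ", round(sum(p.numel() for p in mqlra_gpt.parameters()) / 1e6, 1), "M")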

nothing wild/groundbreaking, just a bit of stream of consciousness.