==================

this is sort of a strange blog bc it’s really just a stream of consciousness from a random exploration I had.

from what I understand, one of the complaints from people working at big labs is that there’s less curiosity-driven research and more bet-driven research. in this paradigm, you have a bet on a direction and you’re trying to make it work, so you go explore that direction instead of just wandering down a path that interests you. when you’re betting bagillions of dollars, this makes sense, but for little old me, I sometimes just like to wander.

in order for ideas/intuition to fully develop in my brain (which hopefully then compounds into better future ideas), I usually need time to take dumb ideas to their inevitable failure, otherwise the intuition doesn’t properly develop; even if the idea itself is dumber than a box of rocks. the cool thing about these random walks (and all research these days) is that o1/sonnet/r1/grok/etc have all sort of made this process much easier. e.g., instead of painstakingly coding up a new idea and then losing interest completely by the end of it, you just sort of “vibe research”: you control the generative path, but the llm implements the pieces of the idea that you want it to, and you see where it goes. so the idea starts at one end and sometimes lands at a completely unrelated other end.

ideagen

i was thinking about how to get smaller and smaller models that are as fast as humanly possible without something like distillation; rather, just focusing on the architecture. something that could potentially be used in conjunction with, like, a memory unit for an agent (we’ll get there in another blog) and is super fast at inference time. the idea sort of stemmed from 1) how to make attention faster and 2) if you’re ever to own your own personal AI running on your own hardware, with some kind of human/alien-like memory, you’d want to never have to reset a conversation/agentic flow; it would just sort of have this infinite memory reservoir of all your interactions with it. so basically it comes down to 1) what is Noam Shazeer doing architecturally, and 2) how can we get the attention to be faster.

asking that/those question(s) leads me to some form of Multi-Query Attention (MQA) and some form of compression (DeepSeek’s MLA-style kv-cache compression, for example). so let’s make MQA faster. ok, so what’s a small didly on how we can make MQA faster in the easiest way possible? i dunno, top of mind: make it low-rank or something. sure, why not. let’s see what kind of trouble we can get ourselves into. logically this will reduce expressiveness, but why the hell not just try it for fun.

when you think about it, it’s basically lora without the lora. it’ll be building the entire attention layer from a low-rank parameterization instead of factorizing frozen weights. almost like training a lora from scratch.
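
to make that concrete, here’s a tiny sketch of the parameter-count difference (dims are made up, purely for illustration):

import torch
import torch.nn as nn

d_model, d_k, rank = 768, 128, 16   # made-up dims, just for illustration

# full-rank projection: d_model * d_k = 98,304 params
W_full = nn.Parameter(torch.randn(d_model, d_k) * 0.02)

# low-rank parameterization trained from scratch ("lora without the lora"):
# d_model * rank + rank * d_k = 14,336 params (~7x fewer for this one projection)
W1 = nn.Parameter(torch.randn(d_model, rank) * 0.02)
W2 = nn.Parameter(torch.randn(rank, d_k) * 0.02)

x = torch.randn(1, 10, d_model)
y_full = x @ W_full       # (1, 10, d_k)
y_lr   = (x @ W1) @ W2    # (1, 10, d_k), same shape, but a rank-limited map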

hasn’t this been done before? probably, but who cares. i’m not even going to look into it because i just want to see where this line of thinking naturally goes. let’s see what happens.

mqlra: multi-query low-rank attention

MQA basically means you share a single key/value across all the heads (each head still gets its own query). its comparison to standard mha is below:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# toy setup so the snippets below actually run
B, T, d_model, d_k, n_heads = 2, 16, 64, 16, 4
x = torch.randn(B, T, d_model)

# Create separate learnable weight matrices per head for Q, K, V
W_Q = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
W_K = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
W_V = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])

def scaled_dot_product_attention(Q, K, V):
    """
    Q: (B, T, d_k)
    K: (B, T, d_k)
    V: (B, T, d_k)
    """
    # batch matrix multiply => (B, T, T)
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
    attn_weights = F.softmax(scores, dim=-1)  # (B, T, T)
    # multiply by V => (B, T, d_k)
    output = torch.bmm(attn_weights, V)
    return output, attn_weights

def standard_multi_head_attention(x, W_Q_list, W_K_list, W_V_list):
    """
    x: (B, T, d_model)
    Returns: (B, T, n_heads * d_k)
    """
    head_outputs = []
    for h in range(n_heads):
        Q = x @ W_Q_list[h]  # (B, T, d_k)
        K = x @ W_K_list[h]  # (B, T, d_k)
        V = x @ W_V_list[h]  # (B, T, d_k)

        out, _ = scaled_dot_product_attention(Q, K, V) 
        head_outputs.append(out)

    # Concatenate along the last dimension
    return torch.cat(head_outputs, dim=-1)  # (B, T, n_heads*d_k)

out_std = standard_multi_head_attention(x, W_Q, W_K, W_V)
print("Standard Multi-Head output shape:", out_std.shape)

# Create per-head Q, but only ONE K and V
WQ_mq = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads)
])
WK_mq = torch.nn.Parameter(torch.randn(d_model, d_k))  # Shared across heads
WV_mq = torch.nn.Parameter(torch.randn(d_model, d_k))  # Shared across heads

def multi_query_attention(x, WQ_list, WK, WV):
    """
    x: (B, T, d_model)
    Returns: (B, T, n_heads * d_k)  -- same shape as standard multi-head
    """
    K = x @ WK  # (B, T, d_k), shared
    V = x @ WV  # (B, T, d_k), shared
    
    head_outputs = []
    for h in range(n_heads):
        Q_h = x @ WQ_list[h]  # (B, T, d_k)
        out_h, _ = scaled_dot_product_attention(Q_h, K, V)
        head_outputs.append(out_h)
        
    return torch.cat(head_outputs, dim=-1)

out_mq = multi_query_attention(x, WQ_mq, WK_mq, WV_mq)
print("Multi-Query Attention output shape:", out_mq.shape)

ok, so now let’s take that same idea and just make these matrices low-rank.

class MQLRA(nn.Module):
    """
    Simplified Multi-Query Attention + Low-Rank factorization.
    
    - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
    - We have a *single* K factorization: (W1K, W2K)
    - We have a *single* V factorization: (W1V, W2V)
    - This implementation uses standard PyTorch operations for readability.
    
    Input shape:
      X: [B, L, D_in]
    
    We'll produce an attention output shape of:
      [B, L, D_out * n_heads]   (by concatenating the n_heads outputs).
    """
    def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
        super().__init__()
        self.D_in = D_in
        self.D_out = D_out
        self.n_heads = n_heads
        self.rank = rank
        self.causal = causal
        if scale is not None:
            self.scale = float(scale)
        else:
            self.scale = 1.0 / math.sqrt(D_out)
            
        # Q: separate for each head
        # Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
        self.W1Q_heads = nn.ParameterList([
            nn.Parameter(torch.randn(D_in, rank) * 0.02)
            for _ in range(n_heads)
        ])
        self.W2Q_heads = nn.ParameterList([
            nn.Parameter(torch.randn(rank, D_out) * 0.02)
            for _ in range(n_heads)
        ])
        
        # K: single
        self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)
        
        # V: single
        self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)
    
    def forward(self, X, attn_bias=None):
        """
        X => [B, L, D_in]
        attn_bias => optional, shape broadcastable to [B, n_heads, L, L] if needed
        Returns: [B, L, n_heads*D_out]
        """
        batch_size, seq_len, _ = X.shape
        device = X.device
        dtype = X.dtype
        
        # Compute K and V once (shared across all heads)
        # Low-rank projections for K, V
        partial_k = X @ self.W1K  # [B, L, rank]
        K = partial_k @ self.W2K  # [B, L, D_out]
        
        partial_v = X @ self.W1V  # [B, L, rank]
        V = partial_v @ self.W2V  # [B, L, D_out]
        
        # Prepare output tensor to accumulate results from all heads
        out = torch.zeros(batch_size, seq_len, self.n_heads * self.D_out, 
                          device=device, dtype=dtype)
        
        # Process each head separately
        for h in range(self.n_heads):
            # Low-rank projection for Q (head-specific)
            partial_q = X @ self.W1Q_heads[h]  # [B, L, rank]
            Q = partial_q @ self.W2Q_heads[h]  # [B, L, D_out]
            
            # Compute attention scores
            # [B, L, D_out] @ [B, D_out, L] -> [B, L, L]
            scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale
            
            # Apply causal mask if needed
            if self.causal:
                # Create causal mask and apply it
                mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
                scores.masked_fill_(mask, -float('inf'))
            
            # Apply attention bias if provided
            if attn_bias is not None:
                scores = scores + attn_bias
            
            # Softmax and apply attention
            attn_weights = torch.softmax(scores, dim=-1)  # [B, L, L]
            head_output = torch.bmm(attn_weights, V)  # [B, L, D_out]
            
            # Add to output tensor at the correct position
            out[:, :, h * self.D_out:(h + 1) * self.D_out] = head_output
        
        return out
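
quick shape sanity check on the class above (toy dims, nothing more):

# toy instantiation of MQLRA; shapes only, illustrative
mqlra = MQLRA(D_in=64, D_out=16, n_heads=4, rank=8, causal=True)
X = torch.randn(2, 16, 64)    # [B, L, D_in]
print(mqlra(X).shape)         # torch.Size([2, 16, 64]) == [B, L, n_heads * D_out]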

for a super easy implementation, we show how this would fit into Keller Jordan’s modded-nanogpt: we basically take our MQLRA, format it the way Keller formats his CausalSelfAttention class, and then drop it into the script.

# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng

@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.T,
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(x_s, dtype=torch.float32),
            scale_b=x.new_tensor(w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8

    return impl(x, w)

@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)

@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.T.contiguous().T,
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.T.contiguous(),
            grad_f8.T.contiguous().T,
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).T
        return grad_x, grad_w

    return impl(g, x_f8, w_f8)

@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)

def backward(ctx, grad_out: Tensor, *_):
    x_f8, w_f8 = ctx.saved_tensors
    x_s, w_s, grad_s = ctx.scales
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    return grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)

mm_op.register_autograd(backward, setup_context=setup_context)


# skipping the incredibly based Muon optimizer for brevity

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model

def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))

class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is not None:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        else: # skip mid-layers token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        self.c_fc = CastedLinear(dim, hdim)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x

# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
                                    use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)

        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
        return loss


class MQLRAFlexAttention(nn.Module):
    """
    Multi-Query Low-Rank Attention + FlexAttention
    that acts as a drop-in replacement for CausalSelfAttention.

    - We replicate the same structure:
      * a rotary embedding
      * a final CastedLinear (c_proj)
      * a lambdas parameter for combining 've' (value embeddings)
      * multi-head query, single K/V

    - We still do QK norm, then apply rotary, then call flex_attention.
    - The difference: we use (W1Q_h, W2Q_h) for Q per head, plus single (W1K,W2K)
      and (W1V,W2V) for K, V.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        max_seq_len: int,
        head_dim=128,
        rank=16,  # low-rank dimension
        scale=0.12,  # default scale from your code
    ):
        """
        Args:
          dim: model dimension (input and output).
          num_heads: number of heads.
          max_seq_len: for rotary embedding.
          head_dim: dimension per head (like in CausalSelfAttention).
          rank: low-rank dimension for Q, K, V factorization.
          scale: scale factor for attention logits, e.g. 0.12 in your code.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rank = rank
        self.scale = scale

        # total hidden dimension after merging heads
        hdim = num_heads * head_dim

        # optional initialization approach
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std

        # We keep lambdas for mixing in 've'
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))

        # Rotary embedding
        self.rotary = Rotary(head_dim, max_seq_len)

        # final projection after attention
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_()

        # Q: separate (W1Q, W2Q) for each head
        # shape: W1Q => [dim, rank], W2Q => [rank, head_dim]
        # We'll store them in ParameterList to keep them distinct per head.
        self.W1Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])
        self.W2Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])

        # K: single (W1K, W2K)
        self.W1K = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
        self.W2K = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))

        # V: single (W1V, W2V)
        self.W1V = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
        self.W2V = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))

    def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask):
        """
        x: [B, T, dim]
        ve: optional value-embedding, shape must broadcast with V
        block_mask: for flex_attention
        Returns:
          [B, T, dim] after c_proj
        """
        B, T, dim = x.shape

        # If your flex_attention kernel only works for B=1:
        assert B == 1, "Must use batch size = 1 for FlexAttention"

        # 1) Compute Q, K, V in multi-query style
        #    Q => [B, T, num_heads, head_dim]
        #    single K => broadcast to [B, T, num_heads, head_dim]
        #    single V => broadcast similarly

        # Q (separate per head)
        Q_list = []
        for h in range(self.num_heads):
            partial_q = x.matmul(self.W1Q_heads[h])   # [B, T, rank]
            q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, T, head_dim]
            Q_list.append(q_h)

        # stack => [B, T, num_heads, head_dim]
        q = torch.stack(Q_list, dim=2)

        # single K
        partial_k = x.matmul(self.W1K)   # [B, T, rank]
        K_ = partial_k.matmul(self.W2K)  # [B, T, head_dim]

        # expand to match num_heads => [B, T, num_heads, head_dim]
        k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # single V
        partial_v = x.matmul(self.W1V)   # [B, T, rank]
        v_ = partial_v.matmul(self.W2V)  # [B, T, head_dim]
        v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # 2) norm Q, K + rotary
        #    same approach as CausalSelfAttention: norm(q), norm(k), then rotary
        q = norm(q)
        k = norm(k)
        q = self.rotary(q)
        k = self.rotary(k)

        # 3) incorporate 've' into v if provided
        if ve is not None:
            # ve => shape something like [B, T, num_heads, head_dim] or broadcastable
            # in CausalSelfAttention, they do `v = lambdas[0]*v + lambdas[1]*ve.view_as(v)`
            # We do the same:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
        else:
            v = self.lambdas[0] * v

        # 4) call flex_attention
        #    flex_attention expects shape [B, n_heads, T, head_dim]
        #    so we do transpose(1,2)
        q_t = q.transpose(1, 2)  # => [B, num_heads, T, head_dim]
        k_t = k.transpose(1, 2)
        v_t = v.transpose(1, 2)

        # we have scale = self.scale, block_mask
        # from the snippet: y = flex_attention(q_t, k_t, v_t, block_mask, scale=0.12)
        y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale)

        # y_t => [B, num_heads, T, head_dim]
        # transpose back => [B, T, num_heads, head_dim]
        y = y_t.transpose(1, 2).contiguous()

        # 5) reshape => [B, T, (num_heads * head_dim)]
        y = y.view(B, T, self.num_heads * self.head_dim)

        # final projection => [B, T, dim]
        y = self.c_proj(y)
        return y
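
# NOTE: `checkpoint_block` used below isn't shown in this snippet. a minimal stand-in
# (my assumption, not the actual helper) that just wraps torch.utils.checkpoint:
from torch.utils.checkpoint import checkpoint

def checkpoint_block(fn, *args):
    # trade compute for memory: recompute the block's activations during backward
    return checkpoint(fn, *args, use_reentrant=False)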

# Manually add gradient checkpointing to Block
class MQLRAFlexBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = MQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
        
    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        # Use checkpoint instead of direct execution
        return checkpoint_block(self._forward, x, ve, x0, block_mask)
        
    def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x

class ModifiedGPTWithMQLRAFlex(nn.Module):
    """Modified GPT model that uses MqaLraFlashAttention with configurable rank"""
    
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, rank: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([MQLRAFlexBlock(model_dim, num_heads, max_seq_len, i, rank) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
                                    use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
        self._last_loss_value = None

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 64
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)

        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
        # If we're in training mode, track the loss components
        if self.training and dist.get_rank() == 0:
            # Store the loss components for analysis
            self._last_loss_value = float(loss.item())
        
        return loss

    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
        """Generate text from the model"""
        self.eval()
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop idx to manageable context if needed
                idx_cond = idx if idx.size(1) <= 1024 else idx[:, -1024:]
                
                # Get logits for next token prediction
                logits, _ = self._forward_generation(idx_cond)
                logits = logits[:, -1, :] / temperature  # Focus on last position
                
                # Apply top-k filtering
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('inf')
                
                # Sample from distribution
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                
                # Append to sequence
                idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

    def _forward_generation(self, input_seq):
        """Handle generation case where we only need logits"""
        # Make sure input is properly formatted
        if input_seq.ndim == 2:  # [B, T]
            batch_size = input_seq.size(0)
            if batch_size > 1:
                # Only use the first batch item for simplicity in generation
                input_seq = input_seq[0]
            else:
                input_seq = input_seq.view(-1)
        
        # Use a reasonable window size for generation
        sliding_window_num_blocks = torch.tensor(8, dtype=torch.int32, device=input_seq.device)
        
        # Ensure input is padded to BLOCK_SIZE multiple
        BLOCK_SIZE = 64  # Same as in create_blockmasks
        original_len = len(input_seq)
        remainder = original_len % BLOCK_SIZE
        if remainder != 0:
            pad_length = BLOCK_SIZE - remainder
            padding = torch.zeros(pad_length, dtype=input_seq.dtype, device=input_seq.device)
            input_seq = torch.cat([input_seq, padding])
        
        # Get token embeddings for the padded sequence
        x = x0 = norm(self.embed(input_seq)[None])  # use of norm here matches forward()
        
        # Create block masks for attention with padded sequence
        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, 
                    short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        
        # Get value embeddings for the padded sequence
        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        
        # Process through blocks like in forward method
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            
            # Critical fix: Only pass value embeddings to blocks that expect them
            current_ve = ve[i] if i < len(ve) and ve[i] is not None else None
            current_mask = block_masks[i] if i < len(block_masks) else None
            
            # Disable gradient checkpointing during generation
            if hasattr(self.blocks[i], 'use_checkpoint'):
                original_checkpoint = self.blocks[i].use_checkpoint
                self.blocks[i].use_checkpoint = False
                x = self.blocks[i](x, current_ve, x0, current_mask)
                self.blocks[i].use_checkpoint = original_checkpoint
            else:
                x = self.blocks[i](x, current_ve, x0, current_mask)
                
            if i < n:
                skip_connections.append(x)
        
        # Final norm and project to logits
        x = norm(x)
        logits = self.lm_head(x).float()
        
        # Return logits for generation and the hidden state
        # (truncate padded positions so the caller's logits[:, -1, :] is the last real token)
        return logits[:, :original_len, :], x
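
just to show how the pieces plug together, here’s an illustrative instantiation (the hyperparameters are assumptions for the sketch, not the repo’s actual training config):

# illustrative only: build the MQLRA variant the same way the baseline GPT is built.
# these hyperparameters are assumptions for the sketch, not modded-nanogpt's real config.
model = ModifiedGPTWithMQLRAFlex(
    vocab_size=50257, num_layers=12, num_heads=6,
    model_dim=768, max_seq_len=4096, rank=16,
).cuda()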

flash/flex attention

i was already deep down the rabbit hole, so i figured why not attempt to drop in flash/flex attention on top of the attention matmuls and see if we can get an even larger speed boost.

# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# assume we have the flash attention triton definition:
# *Experimental* implementation of FlashAttention in Triton.
# Tested with triton==2.0.0.dev20221202.
#...
#...
# flash_attn_func = FlashAttnFunc.apply
# from flash_attn_triton import flash_attn_func
class MQLRAFlashAttention(nn.Module):
    """
    Multi-Query Attention + Low-Rank factorization + FlashAttention.

    - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
    - We have a *single* K factorization: (W1K, W2K)
    - We have a *single* V factorization: (W1V, W2V)
    - Then we call the standard Triton FlashAttention kernel on the resulting Q, K, V.

    Input shape:
      X: [B, L, D_in]

    We'll produce an attention output shape of:
      [B, L, D_out * n_heads]   (by concatenating the n_heads outputs).

    We store parameters for Q in `self.W1Q_heads[i], self.W2Q_heads[i]`,
    and single K, V in `self.W1K, W2K, W1V, W2V`.
    """

    def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
        super().__init__()
        self.D_in = D_in
        self.D_out = D_out
        self.n_heads = n_heads
        self.rank = rank
        self.causal = causal
        if scale is not None:
            self.scale = float(scale)
        else:
            self.scale = 1.0 / math.sqrt(D_out)

        # Q: separate for each head
        # Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
        self.W1Q_heads = nn.ParameterList([
            nn.Parameter(torch.randn(D_in, rank) * 0.02)
            for _ in range(n_heads)
        ])
        self.W2Q_heads = nn.ParameterList([
            nn.Parameter(torch.randn(rank, D_out) * 0.02)
            for _ in range(n_heads)
        ])

        # K: single
        self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)

        # V: single
        self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)

    def forward(self, X, attn_bias=None):
        """
        X => [B, L, D_in]
        attn_bias => optional, shape broadcastable to [B, n_heads, L, L] if needed
        We'll produce out => [B, L, n_heads*D_out].

        Implementation steps:
          1) Build Q (n_heads) => shape [B, L, n_heads, D_out]
          2) Build K => shape [B, L, D_out], replicate to [B, L, 1, D_out]
          3) Build V => shape [B, L, D_out], replicate to [B, L, 1, D_out]
          4) Call flash-attn forward => out => shape [B, L, n_heads, D_out]
          5) Reshape to [B, L, n_heads*D_out].
        """
        B, L, _ = X.shape
        # Step 1) Q => [B, L, n_heads, D_out]
        # We'll compute each head's Q in Python. Then stack along "n_heads".
        #   partial_Q_h = X @ W1Q_heads[h] => [B,L,rank]
        #   Q_h = partial_Q_h @ W2Q_heads[h] => [B,L,D_out]
        Q_list = []
        for h in range(self.n_heads):
            partial_q = X.matmul(self.W1Q_heads[h])      # [B, L, rank]
            q_h       = partial_q.matmul(self.W2Q_heads[h])  # [B, L, D_out]
            Q_list.append(q_h)
        # stack => [B, L, n_heads, D_out]
        Q = torch.stack(Q_list, dim=2)

        # Step 2) K => [B,L,D_out], then expand to [B,L,1,D_out]
        partial_k = X.matmul(self.W1K)   # [B,L,rank]
        K_ = partial_k.matmul(self.W2K)  # [B,L,D_out]
        # For multi-query: we have a single K => we can just unsqueeze dim=2 => n_heads=1
        # But if we want flash to see n_heads == self.n_heads, we replicate:
        #   shape => [B, L, n_heads, D_out]
        # but logically the data is the same. We'll do a .unsqueeze(2).expand(...)
        # or we can do a .repeat_interleave, but that costs memory. Alternatively we can do
        # a trick: we pass n_heads=1 to flash, but that won't match Q's shape. So we must replicate:
        # This is the simplest approach:
        K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)

        # Step 3) V => same approach
        partial_v = X.matmul(self.W1V)   # [B,L,rank]
        V_ = partial_v.matmul(self.W2V)  # [B,L,D_out]
        V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)

        # Step 4) flash-attn => out => [B, L, n_heads, D_out]
        # The flash_attn code expects Q => [B, L, n_heads, headdim]
        # so we pass Q,K,V in that shape. We can pass an optional bias if it matches shape [B,n_heads,L,L].

        Q = Q.contiguous()
        K = K.contiguous()
        V = V.contiguous()
        print("Q shape:", Q.shape, "stride:", Q.stride())
        print("K shape:", K.shape, "stride:", K.stride())
        print("V shape:", V.shape, "stride:", V.stride())

        out = flash_attn_func(Q, K, V,
                              attn_bias,
                              self.causal,
                              self.scale)
        # out => [B, L, nheads, D_out]

        # Step 5) reshape => [B,L,n_heads*D_out]
        out = out.reshape(B, L, self.n_heads*self.D_out)
        return out


class MqaLraFlexAttention(nn.Module):
    """
    Multi-Query Attention + Low-Rank factorization + FlexAttention.

    - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h).
    - We have a *single* K factorization: (W1K, W2K)
    - We have a *single* V factorization: (W1V, W2V)
    - Then we call the Triton-based `flex_attention` kernel on the resulting Q, K, V,
      similar to how `CausalSelfAttention` uses flex_attention.

    Input shape:
      X: [B, L, D_in]

    Output shape:
      [B, L, D_out * n_heads]
    """

    def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None):
        super().__init__()
        self.D_in = D_in
        self.D_out = D_out
        self.n_heads = n_heads
        self.rank = rank
        self.causal = causal

        # If scale is not provided, use 1/sqrt(D_out) by default
        self.scale = float(scale) if scale is not None else (1.0 / math.sqrt(D_out))

        # Q: separate for each head
        # Each head: W1Q => [D_in, rank], W2Q => [rank, D_out]
        self.W1Q_heads = nn.ParameterList(
            nn.Parameter(torch.randn(D_in, rank) * 0.02)
            for _ in range(n_heads)
        )
        self.W2Q_heads = nn.ParameterList(
            nn.Parameter(torch.randn(rank, D_out) * 0.02)
            for _ in range(n_heads)
        )

        # K: single
        self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02)

        # V: single
        self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02)
        self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02)

    def forward(self, X, block_mask=None, attn_bias=None):
        """
        X: [B, L, D_in]
        block_mask: the `BlockMask` (if needed by flex_attention).
        attn_bias: optional, shape broadcastable to [B, n_heads, L, L] if needed.

        Returns: [B, L, n_heads * D_out]
        """

        B, L, _ = X.shape
        # If your flex_attention kernel only supports B=1, you may want:
        # assert B == 1, "We only support batch_size=1 for flex_attention."

        # 1) Compute Q per head => [B, L, n_heads, D_out]
        Q_list = []
        for h in range(self.n_heads):
            partial_q = X @ self.W1Q_heads[h]      # [B, L, rank]
            q_h       = partial_q @ self.W2Q_heads[h]  # [B, L, D_out]
            Q_list.append(q_h)
        Q = torch.stack(Q_list, dim=2)  # => [B, L, n_heads, D_out]

        # 2) Compute K => [B, L, D_out]
        partial_k = X @ self.W1K        # [B, L, rank]
        K_ = partial_k @ self.W2K       # [B, L, D_out]

        # Expand K => [B, L, n_heads, D_out] for multi-query
        K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)

        # 3) Compute V => [B, L, D_out]
        partial_v = X @ self.W1V        # [B, L, rank]
        V_ = partial_v @ self.W2V       # [B, L, D_out]

        V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out)

        # 4) Call flex_attention
        # flex_attention expects [B, n_heads, L, head_dim], so we transpose(1,2).
        # Then we get back [B, n_heads, L, head_dim], which we transpose back to [B, L, n_heads, head_dim].
        Q_t = Q.transpose(1, 2)  # => [B, n_heads, L, D_out]
        K_t = K.transpose(1, 2)  # => [B, n_heads, L, D_out]
        V_t = V.transpose(1, 2)  # => [B, n_heads, L, D_out]

        # flex attention here
        out_t = flex_attention(Q_t, K_t, V_t,
                               block_mask=block_mask,
                               scale=self.scale)
        # out_t => [B, n_heads, L, D_out]

        # 5) Transpose back => [B, L, n_heads, D_out], then flatten
        out = out_t.transpose(1, 2)  # => [B, L, n_heads, D_out]
        out = out.reshape(B, L, self.n_heads * self.D_out)

        return out

one more change:

i figured one simple change would be to fuse the k and v projections into single matrices (each initialized as the product of its two low-rank factors). that’s one matmul instead of two per projection, so it’s a bit faster in training (fewer kernel launches), though worth noting the fused matrices are only low-rank at init; after training they’re free to drift back toward full rank.

(alternatively, we could keep the original implementation and just parallelize the per-head Q for loop; a sketch of that is right below.) either way, i wanted to see what would happen.
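
for reference, a sketch of what parallelizing that per-head Q loop could look like (assuming the MQLRAFlexAttention parameters from above; this isn’t what the fused class below does):

# hypothetical batched replacement for the per-head Q loop.
# assumes `attn` is an MQLRAFlexAttention instance and `x` is [B, T, dim].
W1Q = torch.stack(list(attn.W1Q_heads))  # [num_heads, dim, rank]
W2Q = torch.stack(list(attn.W2Q_heads))  # [num_heads, rank, head_dim]
partial_q = torch.einsum("btd,hdr->bthr", x, W1Q)    # [B, T, num_heads, rank]
q = torch.einsum("bthr,hre->bthe", partial_q, W2Q)   # [B, T, num_heads, head_dim]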

class FusedMQLRAFlexAttention(nn.Module):
    """
    Multi-Query Low-Rank Attention + FlexAttention with fused K and V matrices
    for improved computational efficiency.
    
    This version maintains the low-rank factorization of Q matrices for each head,
    but uses fused matrices for K and V to reduce computation during the forward pass.
    
    The Q matrices still use the (W1Q_h, W2Q_h) factorization per head, but
    K and V use direct projection matrices.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        max_seq_len: int,
        head_dim=128,
        rank=16,  # low-rank dimension used for initialization and Q factorization
        scale=0.12,  # default scale from your code
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rank = rank
        self.scale = scale

        # total hidden dimension after merging heads
        hdim = num_heads * head_dim

        # initialization approach
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std

        # We keep lambdas for mixing in 've'
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))

        # Rotary embedding
        self.rotary = Rotary(head_dim, max_seq_len)

        # final projection after attention
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_()

        # Q: separate (W1Q, W2Q) for each head (keep low-rank factorization for Q)
        self.W1Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])
        self.W2Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])

        # Fused K: instead of W1K and W2K, use a single matrix initialized as their product
        W1K_init = torch.empty(dim, rank).uniform_(-bound, bound)
        W2K_init = torch.empty(rank, head_dim).uniform_(-bound, bound)
        self.K_matrix = nn.Parameter(W1K_init @ W2K_init)
        
        # Fused V: instead of W1V and W2V, use a single matrix initialized as their product
        W1V_init = torch.empty(dim, rank).uniform_(-bound, bound)
        W2V_init = torch.empty(rank, head_dim).uniform_(-bound, bound)
        self.V_matrix = nn.Parameter(W1V_init @ W2V_init)

    def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask):
        """
        x: [B, T, dim]
        ve: optional value-embedding, shape must broadcast with V
        block_mask: for flex_attention
        Returns:
          [B, T, dim] after c_proj
        """
        B, T, dim = x.shape

        # If your flex_attention kernel only works for B=1:
        assert B == 1, "Must use batch size = 1 for FlexAttention"

        # 1) Compute Q, K, V in multi-query style
        #    Q => [B, T, num_heads, head_dim]
        #    single K => broadcast to [B, T, num_heads, head_dim]
        #    single V => broadcast similarly

        # Q (separate per head with low-rank factorization)
        Q_list = []
        for h in range(self.num_heads):
            partial_q = x.matmul(self.W1Q_heads[h])   # [B, T, rank]
            q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, T, head_dim]
            Q_list.append(q_h)

        # stack => [B, T, num_heads, head_dim]
        q = torch.stack(Q_list, dim=2)

        # K using fused matrix - one matmul instead of two
        K_ = x.matmul(self.K_matrix)  # [B, T, head_dim]
        # expand to match num_heads => [B, T, num_heads, head_dim]
        k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # V using fused matrix - one matmul instead of two
        v_ = x.matmul(self.V_matrix)  # [B, T, head_dim]
        v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # 2) norm Q, K + rotary
        q = norm(q)
        k = norm(k)
        q = self.rotary(q)
        k = self.rotary(k)

        # 3) incorporate 've' into v if provided
        if ve is not None:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
        else:
            v = self.lambdas[0] * v

        # 4) call flex_attention
        q_t = q.transpose(1, 2)  # => [B, num_heads, T, head_dim]
        k_t = k.transpose(1, 2)
        v_t = v.transpose(1, 2)

        y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale)

        # y_t => [B, num_heads, T, head_dim]
        # transpose back => [B, T, num_heads, head_dim]
        y = y_t.transpose(1, 2).contiguous()

        # 5) reshape => [B, T, (num_heads * head_dim)]
        y = y.view(B, T, self.num_heads * self.head_dim)

        # final projection => [B, T, dim]
        y = self.c_proj(y)
        return y

class FusedMQLRAFlexBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = FusedMQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
        
    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        # Use checkpoint instead of direct execution
        return checkpoint_block(self._forward, x, ve, x0, block_mask)
        
    def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x

quick and dirty training attempt w/ FusedMQLRAFlexAttention.

basically, in order to roughly get the same performance as the original modded-nanogpt, we need to train for many more steps since 1) i don’t have an H100 and 1a) therefore can’t train with fp8, so basically we have to shrink the sequence length + increase the number of iterations (which sucks, but whatever, we’re just throwing boredom spaghetti at the wall here)

you can sort of train with the fused kernel, then run inference with the original kernel, or not, or not not.
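
if you did want to map a trained fused K matrix back to the factored (W1K, W2K) form for the original kernel, a truncated SVD is one way to sketch it (hypothetical; assumes `fused` is a trained FusedMQLRAFlexAttention):

# hypothetical: recover rank-r factors from the trained fused K projection via truncated SVD
U, S, Vh = torch.linalg.svd(fused.K_matrix.detach(), full_matrices=False)
r = fused.rank
W1K = U[:, :r] * S[:r]   # [dim, r]
W2K = Vh[:r, :]          # [r, head_dim]
# W1K @ W2K is the best rank-r approximation of K_matrix (same idea applies to V_matrix)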

nothing wild/ground breaking, just a bit of stream of consciousness.