==================

Preface

I’m going to start blogging a bit more on day-to-day random issues that aren’t NDA’d topics, sort of the slog/grind of the day to day in the life of a scientist in the AI space.

Float16 explosions in inference and not training?

stream of consciousness: wut? why? this inference code works in float32 for 100% of the inputs. however, in float16 no problems for like 99% of inputs. but for 1% of inputs, it explodes/returns nulls?

very weird.

let’s rewrite the forward function of this embedder to find out what in the world is happening.

Rewriting our forward function to find out what in the friday night lights is happening

we’ll print/log each step and save to a np/pickle file the weights before and after given layers/activations.


    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
        position_encodings: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Run the encoder stack while tracing non-finite values layer by layer.

        Debug rewrite of the encoder forward pass. When
        ``self.monarch_mixer_sequence_mixing`` is set, each layer's sequence
        mixer and gated MLP are executed inline, one operation at a time, and
        every intermediate tensor is checked for NaN/Inf; summary statistics
        are printed around the MLP. When the flag is unset, the layers are
        called as whole modules on unpadded inputs with an ALiBi bias.

        Args:
            hidden_states: Input activations. The first two dimensions are
                read as ``(batch, seqlen)``.
            attention_mask: Padding mask. Entries equal to 1 get a zero bias,
                entries equal to 0 get a bias of -10000; its boolean form
                selects the tokens kept when unpadding.
            output_all_encoded_layers: If true, collect the hidden states
                after each layer; otherwise only the final hidden states are
                returned (as a one-element list).
            subset_mask: Optional mask used to index the final hidden states
                (Monarch path) or to build ``subset_idx`` for the last layer
                (attention path).
            position_encodings: Optional tensor added to the hidden states
                after every layer on the Monarch path.

        Returns:
            List of hidden-state tensors, as described for
            ``output_all_encoded_layers``.
        """
        # Function-local imports keep this debug rewrite self-contained; only
        # the names actually used below are imported.
        import torch
        from einops import rearrange

        def check_nan(tensor, location):
            """Print a report if `tensor` holds NaN or Inf; return True if so.

            Inf is reported as well because fp16 overflow first yields Inf and
            only turns into NaN in a later operation, so the Inf report points
            at the step where the explosion actually starts.
            """
            nan_count = torch.isnan(tensor).sum().item()
            inf_count = torch.isinf(tensor).sum().item()
            if nan_count:
                print(f"NaN detected in {location}: {nan_count} NaN values")
            if inf_count:
                print(f"Inf detected in {location}: {inf_count} Inf values")
            return bool(nan_count or inf_count)

        def print_stats(tensor, title):
            """Print mean, max magnitude and the share of entries with |x| > 5."""
            print(f"{title} stats:")
            print(f"Mean: {tensor.mean().item()}")
            print(f"Max abs: {tensor.abs().max().item()}")
            print(f"% > 5: {(tensor.abs() > 5).float().mean().item() * 100}%")

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]

        # Unpad inputs and mask. It will remove tokens that are padded.
        # Assume ntokens is total number of tokens (padded and non-padded)
        # and ntokens_unpad is total number of non-padded tokens.
        # Then unpadding performs the following compression of the inputs:
        # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
        if not self.monarch_mixer_sequence_mixing:
            hidden_states, indices, cu_seqlens, _ = bert_padding_module.unpad_input(
                hidden_states, attention_mask_bool)
        else:
            cu_seqlens = None
            indices = None

        # Add alibi matrix to extended_attention_mask
        if not self.monarch_mixer_sequence_mixing:
            if self._current_alibi_size < seqlen:
                # Rebuild the alibi tensor when needed
                warnings.warn(
                    f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
                )
                self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
            elif self.alibi.device != hidden_states.device:
                # Device catch-up
                self.alibi = self.alibi.to(hidden_states.device)
            alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
            attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
            alibi_attn_mask = attn_bias + alibi_bias
        else:
            alibi_attn_mask = None

        all_encoder_layers = []

        if self.monarch_mixer_sequence_mixing:
            for layer_idx, layer_module in enumerate(self.layer):

                # u is B L H
                print(f"starting: Layer {layer_idx} {layer_module}")
                check_nan(hidden_states, f"Layer {layer_idx} input")

                # Short aliases for the two sub-modules stepped through below.
                attn = layer_module.attention
                mlp = layer_module.mlp

                # ---- sequence mixer, one operation at a time ----
                u = hidden_states
                if attn.hyena_training_additions:
                    u = attn.layernorm(u)
                    check_nan(u, f"Layer {layer_idx} u = layer_module.attention.layernorm(u)")
                L = u.size(-2)

                u_orig = u
                u = attn.in_linear(u)
                check_nan(u, f"Layer {layer_idx}  u = self.in_linear(u)")
                u = rearrange(u, "b l d -> b d l")
                check_nan(u, f"Layer {layer_idx} u = rearrange(u, 'b l d -> b d l')")

                # short filter
                uc = attn.short_filter(u)[..., :L]
                check_nan(uc, f"Layer {layer_idx} layer_module.attention.short_filter(u)[..., :L]")

                x1, x2, v = uc.split(attn.d_model, dim=1)
                check_nan(x1, f"Layer {layer_idx} x1")
                check_nan(x2, f"Layer {layer_idx} x2")
                check_nan(v, f"Layer {layer_idx} v")

                # pre gating
                v = v * x1
                check_nan(v, f"Layer {layer_idx} v = v * x1")
                if attn.hyena_training_additions:
                    v = attn.drop(v)
                    check_nan(v, f"Layer {layer_idx} v = layer_module.attention.drop(v)")

                k = attn.filter_fn.filter(L, device=u.device)
                check_nan(k, f"Layer {layer_idx} k = layer_module.attention.filter_fn.filter(L, device=u.device)")
                k = rearrange(k, "c l d -> c d l")[0]  # `c` is always 1 by default
                check_nan(k, f"Layer {layer_idx} k = rearrange(k, 'c l d -> c d l')[0] # `c` is always 1 by default")

                if attn.bidirectional:
                    k_rev = attn.filter_fn.filter_rev(L, device=u.device)
                    check_nan(k_rev, f"Layer {layer_idx} k_rev = layer_module.attention.filter_fn.filter_rev(L, device=u.device)")
                    k_rev = rearrange(k_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
                    check_nan(k_rev, f"Layer {layer_idx} k_rev = rearrange(k_rev, 'c l d -> c d l')[0] # `c` is always 1 by default")
                else:
                    k_rev = None

                y = attn.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias=attn.filter_fn.bias[None, :, None])
                check_nan(y, f"Layer {layer_idx} y = layer_module.attention.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias= layer_module.attention.filter_fn.bias[None, :, None])")

                if attn.residual_long_conv:
                    k2 = attn.filter_fn2.filter(L, device=u.device)
                    check_nan(k2, f"Layer {layer_idx} k2 = layer_module.attention.filter_fn2.filter(L, device=u.device)")
                    k2 = rearrange(k2, "c l d -> c d l")[0]
                    check_nan(k2, f"Layer {layer_idx} rearrange(k2, 'c l d -> c d l')[0]")

                    if attn.bidirectional:
                        k2_rev = attn.filter_fn2.filter_rev(L, device=u.device)
                        check_nan(k2_rev, f"Layer {layer_idx} k2_rev = layer_module.attention.filter_fn2.filter_rev(L, device=u.device)")
                        k2_rev = rearrange(k2_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
                        check_nan(k2_rev, f"Layer {layer_idx} rearrange(k2_rev, 'c l d -> c d l')[0]")
                    else:
                        k2_rev = None

                    yu = attn.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias=attn.filter_fn2.bias[None, :, None])
                    check_nan(yu, f"Layer {layer_idx} yu = layer_module.attention.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias= layer_module.attention.filter_fn2.bias[None, :, None])")

                # post gating
                y = y * x2
                check_nan(y, f"Layer {layer_idx} y = y * x2")

                if attn.residual_long_conv:
                    y = y + yu
                    check_nan(y, f"Layer {layer_idx} y = y + yu")

                y = y.transpose(-1, -2)
                check_nan(y, f"Layer {layer_idx} y = y.transpose(-1, -2)")
                if attn.hyena_training_additions:
                    y = attn.drop(attn.act(y))
                    check_nan(y, f"Layer {layer_idx} y = layer_module.attention.drop(layer_module.attention.act(y))")
                y = attn.out_linear(y)
                check_nan(y, f"Layer {layer_idx} y = layer_module.attention.out_linear(y)")

                # ---- gated MLP, stepped through in place of a single
                # `layer_module.mlp(y)` call so each stage can be inspected ----
                hidden_states = y
                residual_connection = hidden_states
                # compute the activation
                hidden_states = mlp.gated_layers(hidden_states)
                check_nan(hidden_states, f"Layer {layer_idx} hidden_states = layer_module.mlp.gated_layers(hidden_states)")
                print_stats(hidden_states, f"Layer {layer_idx} after gated_layers")

                if mlp.is_padded:
                    gated = hidden_states[:, :, :mlp.config.intermediate_size]
                    check_nan(gated, f"Layer {layer_idx} gated = hidden_states[:, :, :layer_module.mlp.config.intermediate_size]")
                    non_gated = hidden_states[:, :, mlp.config.intermediate_size:]
                    check_nan(non_gated, f"Layer {layer_idx} non_gated = hidden_states[:, :, layer_module.mlp.config.intermediate_size:]")
                else:
                    gated = hidden_states[:, :mlp.config.intermediate_size]
                    check_nan(gated, f"Layer {layer_idx} gated = hidden_states[:, :layer_module.mlp.config.intermediate_size]")
                    non_gated = hidden_states[:, mlp.config.intermediate_size:]
                    check_nan(non_gated, f"Layer {layer_idx} non_gated = hidden_states[:, layer_module.mlp.config.intermediate_size:]")

                hidden_states = mlp.act(gated) * non_gated
                # Clamp before the checks below so an overflow in the gating
                # product does not propagate into the remaining steps.
                hidden_states = torch.clamp(hidden_states, min=-10000.0, max=10000.0)
                check_nan(hidden_states, f"Layer {layer_idx} hidden_states = layer_module.mlp.act(gated) * non_gated")
                print_stats(hidden_states, f"Layer {layer_idx} after activation and gating")

                hidden_states = mlp.dropout(hidden_states)
                check_nan(hidden_states, f"Layer {layer_idx} hidden_states = layer_module.mlp.dropout(hidden_states)")
                # multiply by the second matrix
                hidden_states = mlp.wo(hidden_states)
                check_nan(hidden_states, f"Layer {layer_idx} hidden_states = layer_module.mlp.wo(hidden_states)")
                print_stats(hidden_states, f"Layer {layer_idx} after wo")

                # add the residual connection and post-LN
                combine = hidden_states + residual_connection

                print(f"Layer {layer_idx} combine stats:")
                print(f"Mean: {combine.mean().item()}")
                print(f"Std: {combine.std().item()}")
                print(f"Max: {combine.abs().max().item()}")
                print(f"% of values > 5: {(combine.abs() > 5).float().mean().item() * 100}%")

                check_nan(combine, f"Layer {layer_idx} hidden_states + residual_connection")
                hidden_states = mlp.layernorm(combine)
                check_nan(hidden_states, f"Layer {layer_idx} hidden_states = layer_module.mlp.layernorm(combine)")

                if position_encodings is not None:
                    hidden_states = hidden_states + position_encodings
                    check_nan(hidden_states, f"Layer {layer_idx} after position encoding")

                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
                check_nan(hidden_states, "After subset mask")

        else:
            if subset_mask is None:
                for layer_module in self.layer:
                    hidden_states = layer_module(hidden_states,
                        cu_seqlens,
                        seqlen,
                        None,
                        indices,
                        attn_mask=attention_mask,
                        bias=alibi_attn_mask
                    )
                    if output_all_encoded_layers:
                        all_encoder_layers.append(hidden_states)

                # Pad inputs and mask. It will insert back zero-padded tokens.
                # Assume ntokens is total number of tokens (padded and non-padded)
                # and ntokens_unpad is total number of non-padded tokens.
                # Then padding performs the following de-compression:
                #     hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
                hidden_states = bert_padding_module.pad_input(
                    hidden_states, indices, batch, seqlen
                )
            else:
                # All layers but the last run on every unpadded token; the
                # last layer additionally receives `subset_idx`.
                # NOTE(review): the last layer's output is not appended here
                # when output_all_encoded_layers is true -- confirm intended.
                for i in range(len(self.layer) - 1):
                    layer_module = self.layer[i]
                    hidden_states = layer_module(hidden_states,
                                                    cu_seqlens,
                                                    seqlen,
                                                    None,
                                                    indices,
                                                    attn_mask=attention_mask,
                                                    bias=alibi_attn_mask)
                    if output_all_encoded_layers:
                        all_encoder_layers.append(hidden_states)
                subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                            as_tuple=False).flatten()

                hidden_states = self.layer[-1](hidden_states,
                                                cu_seqlens,
                                                seqlen,
                                                subset_idx=subset_idx,
                                                indices=indices,
                                                attn_mask=attention_mask,
                                                bias=alibi_attn_mask)

        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

And because one specific operation is abstracted, we’ll add a debug/pickler for it as well.

# Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers

import os
import pickle

import numpy as np
import torch
from einops import rearrange


def blockdiag_weight_to_dense_weight(weight):
    """Expand a stack of blocks into the dense block-diagonal matrix.

    Arguments:
        weight: (nblocks, out / nblocks, in / nblocks)
    Return:
        dense_weight: (out, in)
    """
    # One 2-D block per entry along the leading dimension.
    blocks = weight.unbind(dim=0)
    return torch.block_diag(*blocks)


def blockdiag_multiply_reference(x, weight):
    """Multiply `x` by a block-diagonal matrix, one einsum per call.

    Slow but straightforward implementation, intended as the reference to
    compare faster implementations against.
    Arguments:
        x: (..., n)
        weight: (nblocks, q, n / nblocks)
    Outputs:
        out: (..., nblocks * q)
    """
    feature_dim = x.shape[-1]
    nblocks, _, p = weight.shape
    assert nblocks * p == feature_dim

    # Split the last axis into one length-p chunk per block.
    per_block_in = rearrange(x, "... (nblocks p) -> ... nblocks p", nblocks=nblocks)
    # Apply every block's (q, p) matrix to its own chunk.
    per_block_out = torch.einsum("...kp, kqp -> ...kq", per_block_in, weight)
    # Concatenate the per-block outputs back into a single last axis.
    return rearrange(per_block_out, "... nblocks q -> ... (nblocks q)")


class BlockdiagMultiply(torch.autograd.Function):

    """Block-diagonal matmul through a single ``torch.bmm``, with debug hooks.

    This is a faster implementation, with careful memory copies for the fastest
    bmm performance.
    The backward pass is also written manually with careful memory copies.

    Two debugging additions sit on top of the plain product:

    * every forward call pickles the bmm input and the (pre-clamp) bmm output
      into ``dump_dir`` as NumPy arrays, numbered consecutively;
    * the output and both gradients are clamped to ``[-clamp_value,
      clamp_value]`` so an fp16 overflow cannot propagate to later layers.

    Arguments:
        x: (..., n)
        weight: (nblocks, q, n / nblocks)
    Outputs:
        out: (..., nblocks * q)
    """

    # Symmetric bound applied to the forward output and to both gradients.
    clamp_value = 10000.0
    # Directory receiving the per-call pickle dumps; set to None to disable.
    dump_dir = 'block_multiple_viz'

    @staticmethod
    def _next_dump_index(dump_dir):
        """Return one more than the highest index used by ``*_NNNN.pkl`` files."""
        indices = [-1]
        for fname in os.listdir(dump_dir):
            if not fname.endswith('.pkl'):
                continue
            try:
                indices.append(int(fname.split('_')[-1].split('.')[0]))
            except ValueError:
                # Unrelated .pkl file without a numeric suffix: ignore it
                # instead of aborting the forward pass.
                continue
        return max(indices) + 1

    @staticmethod
    def _to_numpy(tensor):
        """Detach `tensor` to a CPU NumPy array (bfloat16 is widened to float32)."""
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so `.numpy()` would raise; the
            # widening to float32 is exact.
            tensor = tensor.float()
        return tensor.numpy()

    @staticmethod
    def _dump_debug_tensors(x_reshaped, out):
        """Pickle the bmm input and output under the next free index.

        Layouts as written: ``x_reshaped`` is (nblocks, batch_dim, p) and
        ``out`` is (batch_dim, nblocks, q).
        """
        viz_dir = BlockdiagMultiply.dump_dir
        os.makedirs(viz_dir, exist_ok=True)
        next_num = BlockdiagMultiply._next_dump_index(viz_dir)

        x_path = os.path.join(viz_dir, f'x_reshaped_{next_num:04d}.pkl')
        out_path = os.path.join(viz_dir, f'out_{next_num:04d}.pkl')

        with open(x_path, 'wb') as f:
            pickle.dump(BlockdiagMultiply._to_numpy(x_reshaped), f)
        with open(out_path, 'wb') as f:
            pickle.dump(BlockdiagMultiply._to_numpy(out), f)

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        # int(...) because np.prod of an empty shape is the float 1.0, which
        # reshape/empty reject (x with a single dimension).
        batch_dim = int(np.prod(batch_shape))
        nblocks, q, p = weight.shape
        assert nblocks * p == n
        # (nblocks, batch_dim, p): one batched matmul per block.
        x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
        out = torch.empty(
            batch_dim, nblocks, q, device=x.device, dtype=x.dtype
        ).transpose(0, 1)
        # Result is written into `out`, then viewed as (batch_dim, nblocks, q).
        out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1)

        # Dump before clamping so the raw overflow stays visible offline.
        if BlockdiagMultiply.dump_dir is not None:
            BlockdiagMultiply._dump_debug_tensors(x_reshaped, out)

        # clamp to avoid overflow + see explosions in later layers
        bound = BlockdiagMultiply.clamp_value
        out = torch.clamp(out, min=-bound, max=bound)
        return out.reshape(*batch_shape, nblocks * q)

    # not used now..
    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = int(np.prod(batch_shape))  # see forward()
        nblocks, q, p = weight.shape
        assert nblocks * p == n
        bound = BlockdiagMultiply.clamp_value
        dx, dweight = None, None
        dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1)
        if ctx.needs_input_grad[0]:
            dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype)
            dx = (
                torch.bmm(dout_reshaped, weight.conj(), out=dx.transpose(0, 1))
                .transpose(0, 1)
                .reshape(*batch_shape, n)
            )
            dx = torch.clamp(dx, min=-bound, max=bound)
        if ctx.needs_input_grad[1]:
            x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
            dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj())
            dweight = torch.clamp(dweight, min=-bound, max=bound)
        return dx, dweight


blockdiag_multiply = BlockdiagMultiply.apply

Cool, that gave us the results we’re looking for.

specifically, a single layer was causing explosions, shooting values up to > 65k while mean/std were 0/10 or something like that.

what in the world.

let’s try to see what’s going on in tsne space.

load outputs + slap a tsne on it + viz


# Load one pair of dumps written by BlockdiagMultiply.forward.
# NOTE: pickle executes arbitrary code on load -- only open dumps you wrote yourself.
with open('x_reshaped_0009.pkl', 'rb') as f:
    x_reshaped = pickle.load(f)
with open('out_0009.pkl', 'rb') as f:
    out = pickle.load(f)

# The two dumps use different layouts: x_reshaped is (nblocks, batch_dim, p)
# while out is already (batch_dim, nblocks, q). Only x needs the transpose;
# transposing out as well would pair each input row with the wrong output row
# (and so with the wrong entry of clamped_mask).
x_2d = x_reshaped.transpose(1, 0, 2).reshape(-1, x_reshaped.shape[-1])  # (batch_dim * nblocks, p)
out_2d = out.reshape(-1, out.shape[-1])  # (batch_dim * nblocks, q)

# TSNE reduction (inputs and outputs are embedded independently)
tsne = TSNE(n_components=3, random_state=42)
x_tsne = tsne.fit_transform(x_2d)
out_tsne = tsne.fit_transform(out_2d)

# Create mask for clamped values: rows whose largest magnitude reaches the
# clamp bound applied after the bmm.
clamped_mask = np.abs(out_2d).max(axis=1) >= 10000

viz

# Create subplot figure
fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
    subplot_titles=('Input Values (TSNE)', 'Output Values (TSNE)')
)

normal_mask = ~clamped_mask

# One panel per tensor: (embedding, raw rows, subplot column, legend label).
panels = (
    (x_tsne, x_2d, 1, 'Input'),
    (out_tsne, out_2d, 2, 'Output'),
)

for coords, values, col, label in panels:
    # Rows that stayed inside the clamp range, coloured by mean magnitude.
    # The colour array is masked exactly like the coordinates so that each
    # point gets its own row's value.
    fig.add_trace(
        go.Scatter3d(
            x=coords[normal_mask, 0],
            y=coords[normal_mask, 1],
            z=coords[normal_mask, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=np.abs(values[normal_mask]).mean(axis=1),
                colorscale='Viridis',
                showscale=True
            ),
            name=f'Normal {label}'
        ),
        row=1, col=col
    )

    # Rows that hit the clamp bound, highlighted in red.
    fig.add_trace(
        go.Scatter3d(
            x=coords[clamped_mask, 0],
            y=coords[clamped_mask, 1],
            z=coords[clamped_mask, 2],
            mode='markers',
            marker=dict(
                size=8,
                color='red',
            ),
            name=f'Clamped {label}'
        ),
        row=1, col=col
    )

fig.update_layout(
    height=800,
    width=1600,
    title_text="TSNE Visualization with Highlighted Clamped Values",
    showlegend=True
)

fig.show()

what tokens cause this nasty bastard?

hacky pseudo code

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def test_individual_tokens(text):
    """Send every token of `text` to the embedding endpoint on its own.

    The text is tokenized (special tokens included), each token id is decoded
    back to a string and posted as a separate request. A token counts as
    problematic when the response is not OK.

    NOTE: `url` and `headers` are read from the enclosing scope; they are not
    defined in this snippet.

    Returns:
        A list of dicts with keys 'token_id', 'token_text' and 'position',
        one per problematic token, in input order.
    """
    tokens = tokenizer.encode(text, add_special_tokens=True)
    problematic_tokens = []

    for i, token in enumerate(tokens):
        test_text = tokenizer.decode([token])
        data = {
            "model": "super-duper-custom-model",
            "input": test_text
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))

        if not response.ok:
            problematic_tokens.append({
                'token_id': token,
                'token_text': test_text,
                'position': i
            })
            print(f"Found problematic token at position {i}: {test_text}")

    return problematic_tokens


# `text2` is the failing input under investigation (defined outside this snippet).
results = test_individual_tokens(text2)
print(f"\nTotal problematic tokens found: {len(results)}")

well that doesn’t make too much sense.

“owner”? the word “owner” causes the problem but not, specifically, “owners”? weird. either way, we can fix this problem by clamping after the bmm operator and checking to make sure the clamp doesn’t affect similarity too hard. (test not pictured, it fixes it, causes no issues, all is right in the world)

fin. that’s a wrap. welcome to the world of the odd. debugging precision issues in 2024.

join me next week for another deep dive into the world of the odd.