
Hacking on GPT

Note

This is ongoing documentation of my exploration of the GPT(-2) architecture. The primary goal is to answer the following questions:

  • Why does it work?
  • Can we do better?
  • Is the architecture theoretically motivated, and do we have reason to believe that it is a "best" learning algorithm?

The experiments are not intended to be scientific. Experiments are generally run at a tiny scale (n_embd = 768, n_layer = 12, T = 1024, varying batch size, 30M-100M parameters) on a single A100, H100, or GH200. Even at this scale, transformers are magic.

Tip

Here are some observations that I found interesting:

  • GPT is magic and feels almost perfect.
  • Attention weights can generally be reused (tied across layers) with negligible penalty (at this scale); tying MLP weights incurs a substantial penalty.
  • The structure of GPT reminds me almost of an advanced combinator calculus. If I had to prove the expressivity of the architecture, I would start there. This view also fits nicely with the interpretation of the FFN as a "persistent kvCache", where a block is just x = x + attn(LN(x), kvCache) and kvCache is learnable with length equal to the MLP hidden dimension (see the sketch after this list). Multi-head attention also has a nice interpretation from this point of view.
  • The mlp component in particular is quite flexible; it can be replaced by many fun variants, described below.
  • The output of attn may not need to be put into the residual stream. As long as it (and x) is fed as input into the mlp, the network still performs as well (at least at this scale), albeit slightly slower to train. (Relatedly, I suppose that a wide enough mlp can learn a copy/identity operation.)
  • On eliminating skip connections: we can learn to "gate" between passing the entire residual, a combination of the residual and block output, or exclusively the block output; however, this doesn't improve perplexity and also increases training time. Likewise, we can eschew skip connections entirely if we add an "identity loss" pushing our blocks to compute the identity function; it converges to the same place, but takes 4x longer to train, so I don't see the point. TL;DR: skip connections are very nice and also hard to interpret.
  • Throughout, I make an implicit assumption that the 'most natural' architecture is also the one best optimized by SGD (conditioned on some "good baseline initialization"). This is not obviously true, and may not be true at all.
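To make the "FFN as a persistent kvCache" view above concrete, here is a minimal sketch of my own (not one of the repo's modules; the real block also contains the usual token-to-token attention) where the MLP is replaced by attention over a learnable key/value bank:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PersistentKVBlock(nn.Module):
    """Illustrative only: x = x + attn(LN(x), kvCache), with a learnable kvCache
    of length 4 * n_embd standing in for the FFN's hidden units."""
    def __init__(self, n_embd=768, n_kv=4 * 768):
        super().__init__()
        self.ln = nn.LayerNorm(n_embd)
        self.q_proj = nn.Linear(n_embd, n_embd)
        # the "persistent kvCache": one learnable key and value per would-be MLP hidden unit
        self.k_cache = nn.Parameter(torch.randn(n_kv, n_embd) * 0.02)
        self.v_cache = nn.Parameter(torch.randn(n_kv, n_embd) * 0.02)

    def forward(self, x):                                  # x: (B, T, n_embd)
        q = self.q_proj(self.ln(x))                        # (B, T, n_embd)
        att = q @ self.k_cache.t() / q.size(-1) ** 0.5     # (B, T, n_kv)
        att = F.softmax(att, dim=-1)                       # with a GeLU here instead, this is just an FFN
        return x + att @ self.v_cache                      # (B, T, n_embd)
```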

Important

I started this project without much background knowledge of the literature (I am a cryptographer by training). The documentation for many experiments is also quite loose, since it was not intended to be shared publicly. Many early experiments are not documented at all. Most experiments are not rigorous and are for intuition only. Many of the observations may seem easy or, alternatively, surprising. On later perusal of the literature, it turns out most of them have already been discovered.

Useful commands

Note

Retroactive Note. Throughout, I will be adding informative retroactive notes, trying to explain and contextualize old experiments (if I remember them).

Important

TODO: Break up / organize this README into something more readable.

```bash
torchrun --standalone --nproc_per_node=1 train_gpt2_ben.py

pip freeze -l > requirements.txt

source setup.sh
screen -dm bash -c 'torchrun --standalone --nproc_per_node=8 train_gpt2.py > log/screen.txt 2>&1'

screen -r

screen -dm bash -c 'torchrun --standalone --nproc_per_node=1 train_gpt2_ben.py > screen.txt 2>&1'

screen -dm bash -c 'torchrun --standalone --nproc_per_node=1 train_sequential_ben.py > screen.txt 2>&1'

screen -dm bash -c 'torchrun --standalone --nproc_per_node=1 train_code_ben.py > screen.txt 2>&1'

screen -dm bash -c 'torchrun --standalone --nproc_per_node=1 train_code2_ben.py > screen.txt 2>&1'
```

Setup:

```bash
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
```

The original implementation / experimental setup is based on Karpathy's series of tutorials on building GPT from scratch.

Table of Contents

Important

Organization: Experiments are grouped by "topics" or "sections". These topics are ordered by recency, the most recent ones first.

My understanding of GPT has changed drastically over the course of these experiments. Please bear that in mind when visiting the older write-ups!

Notes

Using Attention to Compute the FFN

Hypothesis: The MLP/FFN component is just an instantiation of attention with "code tokens" (from the point of view of our combinator calculus intuition). The FFN basically checks which "code tokens" our "variable tokens" are positioned near (via just the dot product; perhaps we can generalize to multi-head asymmetric positioning, much like in attention). A code token is now a high-dimensional "concept" that can affect left-neighboring tokens by replacing them like a constant function (with itself, or some linear transformation of itself). (We can try allowing variable tokens to be applicators on code tokens as well, but that doesn't appear to do much, as the prior "code mode" experiments seem to demonstrate.) In this section, let's try to demonstrate this and merge the two components.

Implications if True: A unified architecture, better theoretical analyzability. Hopefully also insights into how to scale.

Note

In retrospect, FFN as Attention is an observation previously made in https://openreview.net/forum?id=HklJdaNYPH. I only discovered this when browsing the literature after running these experiments :)

Activation Functions

The attention layer uses a softmax, whereas the FFN uses a GeLU in our implementation. Can we unify these?

What happens if we replace GeLU with a softmax?
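For concreteness, here is a sketch of the variant being tested (the repo's MLPSoftmaxExecute is the actual module and may differ in details):

```python
import torch.nn as nn
import torch.nn.functional as F

class MLPSoftmax(nn.Module):
    """FFN with the GeLU swapped for a softmax over the hidden dimension (sketch)."""
    def __init__(self, n_embd=768, mlp_scale=4):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, mlp_scale * n_embd)
        self.c_proj = nn.Linear(mlp_scale * n_embd, n_embd)

    def forward(self, x):                 # x: (B, T, n_embd)
        h = self.c_fc(x)                  # (B, T, 4n)
        h = F.softmax(h, dim=-1)          # replaces F.gelu(h)
        return self.c_proj(h)             # (B, T, n_embd)
```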

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = MLPSoftmaxExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

What about replacing the GeLU with x*softmax(x)? Indeed, it seems mildly important that the "value matrix" of MLPs is "selective" (but selective on what?):

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 mlpalt
        inp = self.ln_2(x + attn) # (B, T, n)
        y = self.c_fc(inp) # (B, T, 4n)
        y = y*F.softmax(y, dim=-1) # (B, T, 4n)
        mlp = self.c_proj(y)
        return mlp + attn
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = MLPSoftmaxExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

In attention, what happens if we replace softmax with a ReLU or GeLU? What in the world is going on here? There should be no reason that the first token loss is improving, unless the k and v matrices themselves are learning MLP parameters.
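For reference, what "GeLU attention" means here, roughly (a sketch; the repo's GeLUAttention additionally carries a learnable bias, as shown in later snippets):

```python
import math
import torch.nn.functional as F

def gelu_attention(q, k, v, causal_mask):
    """q, k, v: (B, n_head, T, head_dim); causal_mask: (T, T) of 0/1.
    Elementwise GeLU on the scores instead of a softmax over keys."""
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))   # (B, n_head, T, T)
    att = F.gelu(att)                                         # no normalization across keys
    att = att.masked_fill(causal_mask == 0, 0.0)              # zero out future positions
    return att @ v
```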

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.attn = GeLUAttention(config)
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Now, let's refactor the FFN to look sort of like an attention layer, while keeping the usual softmax attention for the attention component. Note that initially I left out the bias, and the component was essentially a no-op. This should be mathematically identical -- so why is there a slight difference, more so than with weight tying? (TODO: figure out why. I am also surprised that GeLU attention didn't need a bias.)

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 ffn as attn
        inp = self.ln_2(x + attn) # (B, T, n)
        # kv = self.kv_attn(self.ffn_vars) # (1, NUM_MLP_TOKENS, 2n)
        # k,v = kv.split(self.n_embd, dim=2) # (1, NUM_MLP_TOKENS, n)
        k, v, bias = self.c_fc.weight.unsqueeze(0), self.ffn_vals, self.c_fc.bias.unsqueeze(0).unsqueeze(0)
        mlp = self.ffn_attn(inp, kvCache=(k,v,bias))
        return mlp + attn
----------------
 machine_modules
        self.attn = CausalSelfAttention(config)
        self.compiler = BenCompilerNoOp(config)
        self.execute = FFNAsAttentionExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

What happens if I do a "multi-head attention" here for the FFN? The outcome is not great, so it's possible the whole premise (of FFN as attn) is flawed:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 ffn as attn
        inp = self.ln_2(x + attn) # (B, T, n)
        k = self.k_proj(self.c_fc.weight.unsqueeze(0))
        v = self.v_proj(self.ffn_vals)
        bias = self.c_fc.bias.unsqueeze(0)
        mlp = self.ffn_attn(inp, kvCache=(k,v,bias))
        return mlp + attn
----------------
 machine_modules
        self.attn = CausalSelfAttention(config)
        self.compiler = BenCompilerNoOp(config)
        self.execute = FFNAsAttentionExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Note

It seems that there is some sort of substitution going on. If we imagine a lambda calculus expression λx. a (x b), x tells us where to substitute a later expression within the context a (x b). The FFN network here seems to play a similar role; perhaps it adds positional information to an input y such that y is placed exactly where x is placed in the expression above. Perhaps x itself is that positional information, so dot(x, b) should be large to indicate closeness. So, if we imagine a row of our transformer as being the current state of some expression tree, each MLP row potentially also represents an "unbound variable" in this expression that, when applied to an expression (according to the MLP query), essentially substitutes that expression into the appropriate locations. But it's not exactly a lambda calculus, because the expression is specified by the tree instead of by a term in a single node; in some sense, it's more reminiscent of the S combinator in the SKI combinator calculus.

So, maybe one way to think of this is: the attention component takes sibling concepts and adds them together. Then, if they are already positioned in the right place (positional information should not have changed, since the result is just a sum of two prior blocks), i.e. adjacent to an MLP substitution combinator, a substitution is performed (but a copy is also kept in the original location; can we get rid of that?). Each FFN row, I suspect, performs a single substitution to a single location, but we usually want to copy to multiple locations at once, which is why having multiple independent rows is important.

Of course, the FFN can add more than just positional information --- it probably also adds conceptual information. But, is conceptual and positional information one and the same? Suppose that they are.

In this theory, when a concept is copied, it is still treated as a single "applicator", but now it is acting on multiple locations in parallel. Perhaps it is worth splitting them up into multiple applicators / columns in our code (TODO). I.e., each attention head should "add" a new column of evaluation? And when there are too many columns, we prune them somehow...

Briefly, we rerun the experiment where the model is essentially vanilla GPT, but we pass the output of attention straight to the MLP without summing in the residual; it appears to destroy the performance of the FFN at first. (We have run a similar experiment before; see `11-mlponly-mlponattn-value`.)
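A sketch of what I take MLPOfAttnExecute to be doing (illustrative; the exact residual wiring in the repo may differ):

```python
def mlp_of_attn_block(x, attn, mlp, ln_1, ln_2):
    """x: (B, T, n_embd). The MLP consumes only the attention output,
    not ln(x + attn_out) as in the vanilla block."""
    attn_out = attn(ln_1(x))          # (B, T, n_embd)
    mlp_out = mlp(ln_2(attn_out))     # MLP input is attn_out alone
    return x + attn_out + mlp_out     # residual wiring shown here is one plausible choice
```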

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.attn = CausalSelfAttention(config)
        self.compiler = BenCompilerNoOp(config)
        self.execute = MLPOfAttnExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Now, we do "multi-head" MLP, but not for the values, which we just sum up over heads (instead of concatenating smaller vectors and then projecting). It is better than before but still a penalty compared to vanilla MLP. Perhaps this multihead thing is only useful if a softmax is used?

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 ffn as attn
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
k, v, kbias = kvCache
cacheT = k.size(1)
k = k.view(1, cacheT, self.n_head, C // self.n_head).transpose(1, 2)
v = v.unsqueeze(0)
att = (q @ k.transpose(-2, -1)) + kbias.unsqueeze(0)
att = self.gelu(att)
y = torch.matmul(att, v)
y = y.sum(dim=1).squeeze(1)
----------------
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = FFNAsAttentionExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Instead of GeLU, now let's use a softmax. This takes absurdly long per step, and I need to investigate why, but let's try it first:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 ffn as attn
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
k, v, kbias = kvCache
cacheT = k.size(1)
k = k.view(1, cacheT, self.n_head, C // self.n_head).transpose(1, 2)
v = v.unsqueeze(0)
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
y = y.sum(dim=1).squeeze(1)
----------------
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = FFNAsAttentionExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Let's just merge the two together, to break some ground and write some code, and see how it works. Surprisingly, the MLP gets better, but the attention gets worse! Perhaps this is because the softmax allows the MLP to steal some of attention's thunder.
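The "merge", as I understand it from the snippets in this section: the FFN's c_fc rows act as extra keys and the learned ffn_vals as extra values, concatenated onto the sequence's own keys and values inside a single softmax attention. A rough sketch (single-head for clarity; details are illustrative):

```python
import math
import torch
import torch.nn.functional as F

def merged_attention(q, k, v, k_ffn, v_ffn, causal_mask):
    """q, k, v: (B, T, n) from the sequence; k_ffn, v_ffn: (M, n) learned "FFN tokens".
    One softmax over [sequence keys + FFN keys]; causality applies to the sequence only."""
    B, T, n = q.shape
    M = k_ffn.size(0)
    k_all = torch.cat([k, k_ffn.unsqueeze(0).expand(B, M, n)], dim=1)   # (B, T+M, n)
    v_all = torch.cat([v, v_ffn.unsqueeze(0).expand(B, M, n)], dim=1)
    att = q @ k_all.transpose(-2, -1) / math.sqrt(n)                    # (B, T, T+M)
    full_mask = torch.cat([causal_mask[:T, :T],
                           torch.ones(T, M, device=q.device)], dim=1)   # FFN tokens always visible
    att = F.softmax(att.masked_fill(full_mask == 0, float('-inf')), dim=-1)
    return att @ v_all
```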

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Let's double the number of layers. Perhaps, by allocating more resources, the FFN doesn't steal all of the thunder:

Transformer, max LR 0.0003 n_layer 24
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Go back to 12 layers, and tie attention weights; I suspect this is a little worse! So there is probably a bottleneck in the number of heads:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

Increasing the number of heads (head size now goes down to 48, since we keep the embedding dimension the same), and no longer tying weights. It doesn't appear to make much of a difference; surprisingly, it is again the MLP contribution that improves rather than the attention contribution:

Transformer, max LR 0.0003 n_layer 12 
n_head 16 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

What if we double the learning rate (and set n_head back to 12)? It doesn't help much.

Transformer, max LR 0.0006 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Is there any natural way to prevent the FFN from stealing Attention's thunder?

While we think about that, I am curious what happens if I do GeLU attention. It seems to hurt more than it helps:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Adding the attention sink back:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
----------------
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + attn
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Adding a layer norm, because I presume the signal is perhaps too small; indeed it improves things slightly:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Switching back to GeLU, and keeping the layer norm. I like GeLU because I see no reason why the FFN and attention concepts need to compete with each other --- there should be multiple winners, so using a softmax seems like too strong an inductive bias. Let's also add an additional learnable bias term (one per head) before applying the GeLU. Unfortunately, the outcome is not great:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 gelucode
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att + self.lin_bias[:,:,:T,:T_with_cache]
att = self.gelu(att)
att = att.masked_fill(self.attn_bias[:,:,:T,:T_with_cache] == 0, 0.0)
y = att @ v
----------------
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

So let's go back to a softmax setting, and try to balance that out. Perhaps I should layer norm the learned k and v parameters:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (self.ln_3(k).to(torch.bfloat16), self.ln_4(v).to(torch.bfloat16), bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

In case the contribution from the non-MLP parameters is too small, let's in fact layer norm everything inside the attention code, so that everything has unit variance:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Before, we were normalizing the signal only after combining the heads. Now, let's normalize each head output individually; it doesn't seem to have much effect:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Remove the per-head LN, and now make sure to LN the q vector. It works quite well!

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Returning to the per-head regime, where we normalize only after splitting the key/query/value vectors into their multiple heads (computing this is excruciatingly slow, at least in our naive implementation). It seems unlikely to make a difference (mathematically, the variance shift is uniform anyway):

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Let's try GeLU one more time, but with the appropriate LN on the k, v, q vectors. I guess attention must really choose only a few to attend to...

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 gelucode
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att + self.lin_bias[:,:,:T,:T_with_cache]
att = self.gelu(att)
att = att.masked_fill(self.attn_bias[:,:,:T,:T_with_cache] == 0, 0.0)
y = att @ v
----------------
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Using RMSNorm for the kqv vectors, there is a penalty; the network seems quite sensitive to the choice of norm.

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=False
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Switching back to LayerNorm, but still keeping ELEMENTWISEAFFINE=False:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=False
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Now, using the x*softmax(x) activation:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 activation fn
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att + self.lin_bias[:,:,:T,:T_with_cache]
att = att.masked_fill(self.attn_bias[:,:,:T,:T_with_cache] == 0, float('-inf'))
att = att*F.softmax(att, dim=-1)
y = att @ v
----------------
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

It is a little suspicious that the first token loss gets better in this architecture. It turns out that in vanilla GPT, if we remove the FFN entirely but introduce an extra attention sink (k=0, v=0), we also recover the first token loss of vanilla GPT with the FFN present, which is very strange! I'm not sure why this is happening; it is really odd.
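For reference, this is roughly what "an extra attention sink (k=0, v=0)" means (an illustrative sketch, not the repo's exact code): one all-zero key/value pair is appended so that tokens can dump attention mass onto a slot that contributes nothing to the output.

```python
import math
import torch
import torch.nn.functional as F

def attention_with_zero_sink(q, k, v, causal_mask):
    """q, k, v: (B, n_head, T, d). Appends a single k=0, v=0 "sink" slot per head:
    tokens can dump attention mass onto it, and it contributes nothing to the output."""
    B, H, T, d = k.shape
    sink = torch.zeros(B, H, 1, d, device=k.device, dtype=k.dtype)
    k = torch.cat([sink, k], dim=2)                            # (B, H, T+1, d)
    v = torch.cat([sink, v], dim=2)
    att = q @ k.transpose(-2, -1) / math.sqrt(d)               # (B, H, T, T+1)
    mask = torch.cat([torch.ones(T, 1, device=q.device),       # sink is always visible
                      causal_mask[:T, :T]], dim=1)
    att = F.softmax(att.masked_fill(mask == 0, float('-inf')), dim=-1)
    return att @ v
```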

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=True
TIE_ATTN_WEIGHTS=False

caption

Note

TODO: spend more time looking into the above; it is really weird... I really don't understand why. It is surprising/suspicious that the first token loss gets better in this architecture. TODO: try training an MLP-only architecture and then see what it looks like.

Let's run a baseline; note that ATTN_LAYER_NORM=True (as in the previous experiment). It appears to match our merged performance at first, but later beats it a bit, while never surpassing the true baseline:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=True
TIE_ATTN_WEIGHTS=False

caption

And now back to a rerun of the original baseline; it turns out that setting ATTN_LAYER_NORM=True actually penalizes vanilla performance a bit, but only in the first few steps:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

For good measure:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=True
NUM_MLP_TOKENS=3072

caption

Without tying attention weights: TODO

Top-K instead of softmax

Instead of softmax, take top 128:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 activation fn
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
attinf = att.masked_fill(self.attn_bias[:,:,:T,:T_with_cache] == 0, float('-inf'))
_, top_k_indices = torch.topk(attinf, k=128, dim=-1)
att = torch.zeros_like(att).scatter_(-1, top_k_indices, att.gather(-1, top_k_indices))
y = att @ v
----------------
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

What about top 1:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 activation fn
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
attinf = att.masked_fill(self.attn_bias[:,:,:T,:T_with_cache] == 0, float('-inf'))
_, top_k_indices = torch.topk(attinf, k=1, dim=-1)
att = torch.zeros_like(att).scatter_(-1, top_k_indices, att.gather(-1, top_k_indices))
y = att @ v
----------------
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

Vanilla GPT, but take as output the V vector instead of the whole concept:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=True

caption

Merged GPT, but take as output the V vector instead of the whole concept:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
bias = self.c_fc.bias.unsqueeze(0).unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, bias)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=True
NUM_MLP_TOKENS=3072

caption

Perhaps what we should really do (when evaluating the FFN) is to apply the top-n_heads FFN concepts (i.e. copy them to at most n_heads locations), because we can't handle more than n_heads worth of them anyway (one can think of each head as a parallel evaluation). (Of course, if we copy twice, it blows up, so maybe we should use fewer. But 1 seems too few.) Is there a more differentiable version of top-k, for k=n_heads?

In this experiment, instead of computing the softmax normally (where one normalizes by \sum \exp x_i), we normalize by the sum of all the odds except those of the top k = n_heads/2 entries. (This implementation is really slow, but can certainly be made fast):
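A sketch of my reading of this modified normalization (the actual implementation may differ): the denominator sums exp(x_i) over all entries except the top k, so the top entries are not forced to compete with each other.

```python
import torch

def softmax_excluding_topk(scores, k):
    """scores: (..., N). Normalize exp(scores) by the sum of exp over all entries
    EXCEPT the top-k, so the top-k entries are not forced to compete with each other."""
    odds = torch.exp(scores - scores.amax(dim=-1, keepdim=True))        # stabilized exp
    topk_vals, _ = torch.topk(odds, k=k, dim=-1)
    denom = odds.sum(dim=-1, keepdim=True) - topk_vals.sum(dim=-1, keepdim=True)
    return odds / denom.clamp_min(1e-9)

# e.g. with n_head = 12, one would use k = n_head // 2 = 6
```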

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
k = self.c_fc.weight.unsqueeze(0)
v = self.ffn_vals
kvCache = (k, v, None)
y = self.ln_1(x)
attn, _ = self.attn(y, y, print_weights=print_weights, kvCache=kvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=True
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

Also, mlp concepts should not attend to each other; what happens if we add an extra error term such that the self-attention (amongst mlp concepts) tends to zero, making sure to add an attention sink?

Using K as the query vector when querying "FFN states"

First, a baseline for an attention-only transformer:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

Using k instead of q to query FFN concepts (and applying the attention independently):

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
q = self.c_fc.weight.unsqueeze(0)
v = self.ffn_vals
qvCache = (q, v, None)
y = self.ln_1(x)
attn = self.attn(y, y, print_weights=print_weights, qvCache=qvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

Additionally normalizing the final contributions from the two attention computations:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
q = self.c_fc.weight.unsqueeze(0)
v = self.ffn_vals
qvCache = (q, v, None)
y = self.ln_1(x)
attn = self.attn(y, y, print_weights=print_weights, qvCache=qvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

GeLU instead of softmax:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
q = self.c_fc.weight.unsqueeze(0)
v = self.ffn_vals
qvCache = (q, v, None)
y = self.ln_1(x)
attn = self.attn(y, y, print_weights=print_weights, qvCache=qvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

Adding a bias:

Transformer, max LR 0.0003 n_layer 12 
n_head 12 n_embd 768
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 block_logic
q = self.c_fc.weight.unsqueeze(0)
v = self.ffn_vals
qvCache = (q, v, None)
y = self.ln_1(x)
attn = self.attn(y, y, print_weights=print_weights, qvCache=qvCache)
newx = x + self.ln_2(attn)
----------------
========
REUSE_WEIGHTS=False
ELEMENTWISEAFFINE=True
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
NUM_MLP_TOKENS=3072

caption

Signal Propagation

So, generally speaking, perhaps one way to interpret the power of softmax is that within a single "subspace", it is prudent not to take the average of too many concepts (because then, the signal is lost). It suffices to "combine" just a few concepts. Or rather, perhaps each token specifies some other location to copy the applicator to, as this appears to be the role of the value matrix. (See next paragraph for importance of the value matrix.) Perhaps this is the role of every concept that is "attended to" -- that is, the applicator is the concept that exposes the k vector, and every applicator can be viewed as "copying" the subject to a position specified by v. Then, if heads use a softmax, each head copies the subject to essentially a single new position. Then, each attention application copies each concept to n_head new locations. Note that in vanilla GPT, a normalized signal is fed as input into the attention layer; so only the largest signals (in the residual) matter. So for "new" signals, perhaps we expect that they have a larger magnitude to account for a "growing residual", so that they are useful in the next step. Somehow, this gets learned? See one of the below experiments. In any case, perhaps we should expect these "new locations" to replace the old ones by virtue of simply being larger. (Update: it turns out that my expectation is potentially wrong?)

It also seems that the value matrix is important for the performance of the first column in a computation. Here is Vanilla GPT, replacing the value matrix with the identity. (Note that in the reusing weights regime, this penalty does not seem to show up, interestingly.)

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

In vanilla GPT, what happens if I normalize the output of attention and the output of the FFN before adding them back to the residual? Namely, x + self.ln_0(mlp) + self.ln_1(attn) instead of x + mlp + attn. I predict that this should drastically alter the behavior, since I suspect that the residual should grow as it is fed through more and more layers, and so the output of later layers must also grow to have an effect. (The prediction is wrong! The perplexity seems to behave the same; if anything, the perplexity of the first token gets better.)
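In code, the change being tested is roughly the following (a sketch; the layer-norm names are illustrative, the text calls the new ones ln_0 / ln_1):

```python
# Vanilla pre-LN wiring vs. the variant with normalized contributions (sketch).
def vanilla_block(self, x):
    attn = self.attn(self.ln_pre_attn(x))
    mlp = self.mlp(self.ln_pre_mlp(x + attn))
    return x + mlp + attn

def normalized_contrib_block(self, x):
    attn = self.attn(self.ln_pre_attn(x))
    mlp = self.mlp(self.ln_pre_mlp(x + attn))
    return x + self.ln_out_mlp(mlp) + self.ln_out_attn(attn)   # each new contribution is LayerNormed
```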

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

So let's try to explain the previous result, where the contribution of each layer does not grow, and is similar in magnitude to each other layer. It seems important that the residual contains the whole tree, and that new components are not weighted more than prior components.

When we tie weights, the curves diverge a little bit in the beginning:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=True

caption

In fact, according to 22-ln-res-debug, the norms of the q, k, and v vectors (averaged across n_heads, B, and T) actually seem rather structured; they follow a pattern as we get deeper into the network, and the resulting attention matrix looks similar:

@ 2000 train 3.9872 , allloss: 4.0237,
qnorm: ['6.1855', '9.7890', '17.6464', '17.4868', '14.8330', '11.0258', '9.4597', '9.0222', '8.9031', '8.6727', '8.6815', '7.4363']
knormList: ['5.8492', '9.0334', '16.8719', '16.8863', '14.7853', '9.5540', '7.9108', '7.3659', '6.8504', '6.4701', '6.2542', '5.7161']
vnormList: ['3.4351', '4.9895', '3.5287', '3.7144', '4.2510', '4.8346', '4.8180', '4.7209', '4.5903', '4.5837', '4.5443', '4.3687'], tok/sec: 136663.90, flops:103.38, batch-reuse:1

Returning to the question of how we can expect a network that does not "prioritize" new concepts to work --- let's try one more experiment where we LN the MLP output, the attention output, and the k, q, v vectors, and also tie attention weights:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=True
TIE_ATTN_WEIGHTS=True

caption

The prior experiment is surprisingly good, but not better than the vanilla, and maybe worse in the long term. And what if I now stop normalizing the inputs into the components (but continue to normalize the outputs and qkv)? It trains more easily at first, but later gets worse.

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=True
TIE_ATTN_WEIGHTS=True

caption

Now, setting ATTN_LAYER_NORM=False (so qkv is no longer normalized, but we still do x + LN(mlp(x + attn(x))) + LN(attn(x))), it again trains faster in the early stages, but later it loses in perplexity. In this case, note that the norms of the qkv vectors differ substantially between layers, in a way that I don't currently understand:

@ 7965 train 3.4692 , allloss: 3.4871, gap: 2.1040, dt: 549.24ms, gptloss: 3.4871, extraloss: 0.0000, qnorm: ['0.1751', '9.2063', '11.3693', '11.2771', '9.9027', '9.9097', '10.7547', '12.8102', '14.8945', '17.1459', '19.2298', '20.4718'], knormList: ['0.1695', '39.8808', '43.6074', '44.0956', '42.8097', '42.7953', '40.6297', '41.1652', '31.1243', '29.2986', '23.5657', '20.5410'], vnormList: ['0.1397', '6.5994', '6.6087', '7.9303', '8.0594', '8.4875', '9.1998', '11.1374', '12.4090', '14.0631', '15.7936', '17.3155'], tok/sec: 238642.54, flops:180.52, batch-reuse:1
Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=True

caption

Now, not tying weights (still x + LN(mlp(x + attn(x))) + LN(attn(x))), this actually is worse compared to the prior experiment:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

Recall that a pre-LN transformer should also outperform a post-LN transformer, the difference being that in a post-LN transformer the residual signal is normalized, whereas in a pre-LN world it is not. With a normalized residual, the gradient for deeper layers should be much larger than that of earlier layers. Here is post-LN, though I suspect my learning rate is already set too high:
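For reference, the two standard arrangements being compared (a sketch):

```python
# Pre-LN: the residual stream itself is never normalized.
def pre_ln_block(self, x):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x

# Post-LN: the residual is normalized after every sublayer.
def post_ln_block(self, x):
    x = self.ln_1(x + self.attn(x))
    x = self.ln_2(x + self.mlp(x))
    return x
```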

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.mlp = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

What about x = LN(x + LN(mlp(x + attn(x))) + LN(attn(x)))? I.e. let's do a post LN, but also normalize the new contributions of each layer. So we expect the gradient to slowly shrink for earlier layers, but at a "respectable pace". Without this shrinking effect, I don't see how we can use the same attention head across different layers, without mixing up signals (is this why we lose something when tying attention weights?). Our learning rate is still too high here...

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.mlp = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

Let's keep track of both pre-LN and post-LN residuals; as input to attn and the FFN, we feed the post-LN residual, but as the output of the network, let's use the pre-LN:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
self.attn = CausalSelfAttention(config)
self.mlp = VanillaExecute(config)
----------------
# x_normed: recent tokens are prioritized
attn, _metadata = self.attn(x_normed, x_normed, print_weights=print_weights)
attn = self.ln_5(attn)
mlp = self.ln_6(self.mlp(x_normed))
newx = x + mlp + attn
x_normed = self.ln_1(x_normed + mlp + attn)
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTN_LAYER_NORM=False
TIE_ATTN_WEIGHTS=False

caption

The takeaway: I guess historic tree state is important, and signal can be carried across layers. But that still doesn't answer: why does it still work if we apply an LN to the inputs of attn and the FFN? Perhaps the effect of the LayerNorm is small; as long as the two concepts are even slightly non-orthogonal, it doesn't matter what their magnitude is, perhaps? But this seems at odds with the (poor) performance of post-LN. For post-LN, though, a layer norm is applied multiple times, perhaps dividing by \sqrt(hidden_dim) n_layer times. For attn and the FFN, it is only applied once to the input.

What if, instead of LN, I keep things at \sqrt(hidden_dim) std? Perhaps that is functionally the same.

Tip

The above is the most recent experiment.

Inference Inspired Explorations

Motivation: I am curious what happens if we make the training process more inference-like; that is, during the forward pass, we allow the model to process one input token at a time and "think longer". After all, when a human being reads a sentence, they don't read the entire sentence at once in parallel (unlike a GPT); instead, they read one token at a time and probably add it to some hidden state. (So this is certainly very reminiscent of State-Space Models, or even selective ones like Mamba.) Perhaps this "thinking process" should be a first-order citizen when training, and perhaps the forward pass should be more sequential and much longer --- which is also reminiscent of chain-of-thought. The downside of this approach is, of course, that it neuters what is probably the biggest advantage of Transformers: that they can be easily parallelized during training.

Initial Design: Just to explore this notion, let's first build a model that takes input one token at a time. On seeing a new token, it runs that single embedded token through n_layer transformer blocks, where the attention is taken over all prior outputs of the transformer blocks, up to a context size T. On finishing the run-through, a loss is computed against the expected target (for this token) and added to the total loss. Then we repeat with the next token. So the attention matrix will be something like T x 1 in size once the context size is hit. To make this efficient, we implement a kvCache much like in inference.
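A minimal sketch of the design just described (illustrative pseudocode, not the repo's train_sequential_ben.py; model.embed, model.blocks, model.lm_head, and the block's kv_cache argument are assumed names):

```python
import torch.nn.functional as F

def sequential_forward(model, tokens, targets, max_cache=None):
    """tokens, targets: (B, T). Process one token at a time; every block attends over
    a shared cache of previous block outputs (so the cache grows to ~T * n_layer)."""
    B, T = tokens.shape
    cache = []                                    # list of (B, 1, n_embd) past block outputs
    total_loss = 0.0
    for t in range(T):
        h = model.embed(tokens[:, t:t + 1])       # (B, 1, n_embd)
        for block in model.blocks:
            h = block(h, kv_cache=cache)          # attend to the current token plus the cache
            cache.append(h)
        if max_cache is not None:
            cache = cache[-max_cache:]            # optional finite memory (see the eviction section)
        logits = model.lm_head(h)                 # (B, 1, vocab_size)
        total_loss = total_loss + F.cross_entropy(logits.squeeze(1), targets[:, t]) / T
    return total_loss
```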

Initial Forays

Important

The loss plotted in this section is somewhat different from the loss plotted normally. Now, we plot a variety of different losses:

  • train-lossbacked: this is the usual notion of loss, that we called loss.backward() on.
  • train-lasttoken: this is the loss of the last token in the context (averaged over batch).
  • train-gap: this is the difference between the loss of the last token and the loss of the first token. The rationale is to make it possible to compare the performance of models with different T. This doesn't make much of a difference if both T are very large, but if one T is very small, it seems to make a difference. (It turns out that, in retrospect, this "gap" has a name and is also known as an in-context learning score, i.e. see the Appendix of https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)

Plots with two subplots now also plot train-gap. Hopefully this is a good heuristic for quality of context-based learning.

Establishing a baseline: To start, I first re-run the vanilla GPT experiment with a slightly lower learning rate, extracting the loss of the last token as well as the gap between the loss of the last token and the loss of the first token:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(x)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 loss_fn
            trueloss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), targets.view(-1), ignore_index=-1, reduction='none') # (B*T)
            loss = trueloss.mean() + _xtraloss
            firstloss = trueloss.view(B, T)[:, 0].mean()
            trueloss = trueloss.view(B, T)[:, -1].mean()
            gap = firstloss - trueloss
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Fun Experiment 1: The initial experiment, where the loss is backpropagated through the entire network (see actual code). (Note that there are some fun pitfalls during implementation; in particular, is_causal should no longer be set to true. Note that attention weights are also tied.)

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 loss_fn
                trueloss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), targetj.view(-1), ignore_index=-1)
                loss += trueloss / T
                if j == 0:
                    firstLoss = trueloss.detach()
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

Compare with a baseline GPT with largely the same parameters (including T=16 and the same learning rate / printing only the loss of the last token); the curves are very similar. Oddly though, the "context gap" of funexperiment is much worse than that of the baseline GPT. This means that, somehow, funexperiment is making gains in the zero-context setting, perhaps because we are training it more / invoking those transformer blocks many more times. (Really, we want the opposite: we want to pick up more in-context learning more efficiently.) Also note the strange periodicity in our data; we should probably shuffle it...

Transformer, max LR 0.0003 n_layer 12
Minibatch=1024 T=16 AccumNum=3.0
Total Batch Size=49152
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 loss_fn
            trueloss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), targets.view(-1), ignore_index=-1, reduction='none') # (B*T)
            loss = trueloss.mean() + _xtraloss
            firstloss = trueloss.view(B, T)[:, 0].mean()
            trueloss = trueloss.view(B, T)[:, -1].mean()
            gap = firstloss - trueloss
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Let's increase the batch size for short-context vanilla GPT by 4x. The takeaway is that increasing the batch size can certainly change training dynamics drastically.

Transformer, max LR 0.0003 n_layer 12
Minibatch=3072 T=16 AccumNum=4.0
Total Batch Size=196608
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 loss_fn
            trueloss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), targets.view(-1), ignore_index=-1, reduction='none') # (B*T)
            loss = trueloss.mean() + _xtraloss
            firstloss = trueloss.view(B, T)[:, 0].mean()
            trueloss = trueloss.view(B, T)[:, -1].mean()
            gap = firstloss - trueloss
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Thoughts. Let's think about it for a bit. It's odd and also not odd at all that funexperiment behaves so similarly to vanilla GPT. I guess the extra thinking time doesn't help, and also that attention is able to pick out which embeddings are useful. Note that in funexperiment, the frontier is essentially one token or one application at a time, and moreover it is always the most recent token. I wonder if anything changes if we widen the frontier, or have historic tokens also be applicators once in a while.

Here, I detached the kvCache before feeding it through the network. The hypothesis was that "future influences" perhaps don't matter to old historic embeddings. It turns out that it is likely important to have "subsequent influences" on embeddings beyond the immediate loss for the next token, even when attention is shared.

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

A quick comparison with vanilla GPT with context_size = 16, and similarly sized batches (I think). Short context certainly does make a difference!

Transformer, max LR 0.0024 n_layer 6
Minibatch=3072 T=16 AccumNum=4.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(x)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Perhaps most of the damage is done because without large T, training is no longer particularly parallelised. What about T=32:

Transformer, max LR 0.0024 n_layer 6
Minibatch=1024 T=32 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(x)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Now, T=32 but we halve the total batch size, keeping everything else the same:

caption

Overall, it's unclear how fair these comparisons are. One explanation is that when T is large, more tokens on average have more context available, before comparing them with the target. So it might not speak to the actual quality of the model, but is rather a byproduct of the structure of our training setup.

Another version of the idea

Hypothesis: One issue, I suspect, is that the "computation frontier" is too small. Here, we only compute on one "applicator" token at a time. Moreover, new tokens have no ability to influence old tokens. I suspect we need to increase the size of the "computation frontier" so that it can act on new tokens and persist new knowledge in the "hidden state". In the following experiments, we continue to do the sequential evaluation, but after an embedding is evaluated against the target, it is now once again fed through further transformer blocks, and attention only includes the row of embeddings, not the historic blocks.

In the first experiment, attention weights are tied, and T=16.

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

Now, let's set is_causal=True again; I suspect that attending to future tokens, while no longer being penalized by a local loss, is perhaps hurting performance in this set of experiments. That turns out not to be true; in fact, it appears to have no effect at all:

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

I am curious what happens if we run frontier-1, but this time, for post-evaluation embeddings, we compare every token in the same "evaluated" row against the same target (see code). The outcome is not great.

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

A smarter computational frontier

If our hypothesis is true that "more thinking time" should lead to a stronger model architecture, it certainly hasn't shown up in the funexperiment set of experiments yet. I would have expected, for instance, the T=16 context funexperiment to outperform the corresponding vanilla gpt with T=16 (and all other parameters held the same), yet it doesn't. To fully rule it out, maybe we need to evaluate it on larger T, but the problem is that the attention window grows exorbitantly large (since it has size T*num_layers), so training takes very long and uses a lot of memory.

Instead of growing the attention window to size T*num_layers, we should really cap it at some size M, giving the machine essentially a finite memory and again giving it a state space model flavor. Thus the question essentially boils down to, on seeing a new token, which old embedding do we evict from the kvCache?

Hypothesis: In the following experiment, we try evicting the embedding that is "least attended to" by the newest token (summed across attention heads). There are some immediate potential problems (e.g. what if the next token cares a lot about it, but the current token doesn't, so we evict something we should have kept?) but experiments are cheap, so let's just try it out. Note that M=16 here; to fully simulate the old funexperiment without evicting, we should have set T=16 and M=16*n_layer, so already a lot of eviction is being done with respect to the n_layer factor. Note that the implementation is also really slow, and can probably be optimized later if it seems worth it to do so:

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=32 M=16 AccumNum=6.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption
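For concreteness, here is a minimal sketch of the least-attended eviction policy described above. The shapes and names are hypothetical; the real implementation lives in the training script and is, as noted, slow.

```python
import torch

def evict_least_attended(cache_k, cache_v, attn_weights):
    """Drop the cached position that the newest token attends to the least,
    summed across heads.

    cache_k, cache_v: (B, H, M, head_dim) cached keys / values
    attn_weights:     (B, H, M) attention of the newest query over the cache
    """
    per_position = attn_weights.sum(dim=1)            # (B, M): total attention per slot
    evict_idx = per_position.argmin(dim=-1)           # (B,): least-attended slot

    B, H, M, D = cache_k.shape
    keep = torch.ones(B, M, dtype=torch.bool, device=cache_k.device)
    keep[torch.arange(B), evict_idx] = False          # mask out the evicted slot

    keep = keep.unsqueeze(1).expand(B, H, M)
    new_k = cache_k[keep].view(B, H, M - 1, D)
    new_v = cache_v[keep].view(B, H, M - 1, D)
    return new_k, new_v
```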

Increasing T=64 and M=32:

Transformer, max LR 0.0003 n_layer 6
Minibatch=384 T=64 M=32 AccumNum=8.0
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
LOW_RANK_ATTN=False

caption

So clearly, the "cross sectional component" seems important to the performance of GPT. I wonder if this would change in the "reusing-weights" regime. Perhaps what is going on is that, in vanilla GPT, individual layers are tailoring their outputs to subsequent layers (and that is somehow very important to performance), whereas here that information is lost because of our kvCache eviction policy, and because we attend to all prior layers at once. Now, of course, this structure can be learned, but at the expense of extra parameters. In the vanilla layer-based model, each layer knows which layer comes next; if this were to be learned, the information would have to be put into (and later removed from) the residual stream. We can imagine a system where we do put the information into the residual stream and then use a MoE setup (for efficiency) to ensure that it is routed to the correct layer.

Selective Attention

Question: What happens if I only "let through" the "strongest" attention head? Measure strength by the amount that is put on the attention sink.
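A minimal sketch of what "letting through only the strongest head" could look like. Here I assume the attention-sink mass is already available per head and per position, and that "strongest" means the head placing the least mass on the sink; if the opposite convention is intended, flip the argmin to argmax. Names and shapes are hypothetical.

```python
import torch

def route_strongest_head(head_outputs, sink_mass):
    """Keep only one head's output per position, selected via attention-sink mass.

    head_outputs: (B, H, T, head_dim) per-head attention outputs
    sink_mass:    (B, H, T)           attention weight each head puts on the sink
    """
    strongest = sink_mass.argmin(dim=1)                              # (B, T)
    idx = strongest[:, None, :, None].expand(-1, 1, -1, head_outputs.size(-1))
    selected = torch.gather(head_outputs, dim=1, index=idx)          # (B, 1, T, head_dim)
    return selected.squeeze(1)                                       # (B, T, head_dim)
```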

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 attn_weights [TIE_ATTN_WEIGHTS]
                if TIE_ATTN_WEIGHTS:
                    # Tie model weights together
                    firstBlock = self.transformer.h[0]
                    for block in self.transformer.h:
                        block.attn.c_attn.weight = firstBlock.attn.c_attn.weight

----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=True
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

TODO: check if there is routing collapse, since we haven't added auxiliary loss or balancing bias term.

TODO: compare with running single attention head; if similar, probably there is routing collapse?

TODO: is this O projection matrix really necessary in vanilla GPT world?

Otherwise, barring training problems, it seems that the fact that the multiple attention heads are all "accumulating information in parallel" is very important for efficient convergence (though maybe it eventually does converge?).

Factuality

Hypothesis: By penalizing wrong answers, can we help the model learn better? The cross entropy already does this to some extent, but perhaps we want to exaggerate this, and more strongly penalize very bad answers.

It turns out that the model doesn't learn better, and given that I currently don't have a setup for evaluating hallucinations (if that's even well defined at this model size... these models are barely speaking English), I suspect any impact is hard to measure. This experiment adds another loss component, equal to the largest predicted probability that is not on the target, for each position:

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 attn_weights [TIE_ATTN_WEIGHTS]
                if TIE_ATTN_WEIGHTS:
                    # Tie model weights together
                    firstBlock = self.transformer.h[0]
                    for block in self.transformer.h:
                        block.attn.c_attn.weight = firstBlock.attn.c_attn.weight

----------------
 loss_fn
            _xtraloss = ((logsoftex.max(dim=-1)[0]).exp()).mean() # (1)
            trueloss = nllLoss(logsoft, targets.view(-1))
            loss = trueloss.mean() + _xtraloss
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption
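For reference, a self-contained sketch of the extra penalty above, assuming logsoftex in the snippet is the log-softmax with the target position masked out:

```python
import torch.nn.functional as F

def loss_with_wrong_answer_penalty(logits, targets):
    # logits: (B*T, vocab_size), targets: (B*T,)
    logsoft = F.log_softmax(logits, dim=-1)
    trueloss = F.nll_loss(logsoft, targets)

    # Mask out the target, then penalize the largest remaining probability.
    logsoftex = logsoft.scatter(1, targets.unsqueeze(1), float('-inf'))
    _xtraloss = logsoftex.max(dim=-1)[0].exp().mean()

    return trueloss + _xtraloss
```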

Hypothesis: Can we learn how bad answers are?

Self-Consistent Models

Try: (1) adding the previous output back into the residual stream; (2) trying the post-confidence loss again.

In the sequential model, let's see what happens when I add in the RMS of the previous output tokens and the RMS of the new embedded tokens:

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 M=16 AccumNum=6.0
Total Batch Size=49152
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=True
LOW_RANK_ATTN=False

caption

If we concat the two and then run it through a down projection:

Transformer, max LR 0.0003 n_layer 6
Minibatch=512 T=16 M=16 AccumNum=6.0
Total Batch Size=49152
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=True
LOW_RANK_ATTN=False

caption
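A minimal sketch of the concat-and-down-project combination used in the run above (hypothetical names):

```python
import torch
import torch.nn as nn

class ConcatDownProject(nn.Module):
    """Combine the previous output embedding with the newly embedded token by
    concatenating along the feature dim and projecting back to n_embd."""
    def __init__(self, n_embd):
        super().__init__()
        self.down = nn.Linear(2 * n_embd, n_embd)

    def forward(self, prev_out, new_emb):
        # prev_out, new_emb: (B, T, n_embd)
        return self.down(torch.cat((prev_out, new_emb), dim=-1))
```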

Hodgepodge of Explorations

Learning Programs

I am curious if prepending extra "learnable tokens" (i.e. playing the role of code) might help the network scale. With 768 prepended tokens, initialized with std 1, it seems to be a no-op:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 code_logic
        # Now, concat our code (NOTE: shoudl we add positional embeddings)
        code_expanded = self.code.unsqueeze(0).expand(b, -1, -1)
        tok_emb = torch.cat((code_expanded, tok_emb), dim=1)
----------------
 network_logic
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption
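For reference, a minimal sketch of the prepended learnable-token ("code") mechanism used in these experiments; the parameter name and sizes are hypothetical, and the actual concat is the code_logic shown in the settings above.

```python
import torch
import torch.nn as nn

n_code, n_embd = 768, 768                       # 768 or 128 prepended tokens, depending on the run
code = nn.Parameter(torch.empty(n_code, n_embd))
torch.nn.init.normal_(code, mean=0.0, std=1.0)  # std 1 here, std 5 in the next run

def prepend_code(tok_emb):
    # tok_emb: (B, T, n_embd) -> (B, n_code + T, n_embd)
    b = tok_emb.size(0)
    code_expanded = code.unsqueeze(0).expand(b, -1, -1)
    return torch.cat((code_expanded, tok_emb), dim=1)
```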

With 128 prepended tokens, initialized with std 5:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 code_logic
        # Now, concat our code (NOTE: shoudl we add positional embeddings)
        code_expanded = self.code.unsqueeze(0).expand(b, -1, -1)
        tok_emb = torch.cat((code_expanded, tok_emb), dim=1)
----------------
 network_logic
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

Instead of prepending, what if we directly add it to the input tokens? I set std back to 0.5 because otherwise it has trouble converging... This is definitely not great.

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 code_logic
        # Now, concat our code (NOTE: shoudl we add positional embeddings)
        code_expanded = self.code.unsqueeze(0)  #.expand(b, -1, -1)
        tok_emb = tok_emb + code_expanded[:, -t:, :]
----------------
 network_logic
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

Just for fun, we run the original 128-code std=5 experiment in conjunction with MLPConcat. I suspect that for these code experiments, the gradient is pretty tiny (should just print and check), because there just isn't really any immediate point in attending to them.

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 code_logic
        # Now, concat our code (NOTE: shoudl we add positional embeddings)
        code_expanded = self.code.unsqueeze(0).expand(b, -1, -1)
        tok_emb = torch.cat((code_expanded, tok_emb), dim=1)
----------------
 network_logic
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

I decided to run the signal gate design again. It doesn't perform as well, but it's not fatal; I suspect that learning in "superposition" is important (and somehow the network learns to "dwarf" earlier residuals?). TODO investigate

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
        self.throughput = nn.Parameter(torch.tensor(-2.0))
        torch.nn.init.normal_(self.throughput, mean=-2.0, std=0.02)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        tr = torch.sigmoid(self.throughput)
        newx = (1 - tr) * x + tr * machineOutput
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False
CODE_MODE=False

caption

Thoughts on "Mixture of Experts"

I finally read about MOE a bit more. A few questions:

What if we do routing as a function of just the pre-attention token x, and then send the attention output into the expert? According to my principle, these lookup tables are probably best bucketed by applicator (though I imagine that having both also works).

Alternatively, if we route only on the output of the attention layer, maybe we should send the output of each attention head to a different expert. This would be the opposite interpretation of the above paragraph; it would be odd if the look-ups went through together. Alternatively, we can implement the previous paragraph, and then call that expert n_heads times, once for each of the attention heads. (If the attention heads compute orthogonal things, then this probably doesn't matter so much. I suspect they will usually compute orthogonal things, but we can't rule out "similar" attention pairs being placed together...)
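A minimal sketch of the first variant: route on the pre-attention token x, and feed the attention output into the selected expert. Everything here (names, top-1 routing, the dense evaluation) is a hypothetical simplification, not the code used in the experiments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RouteByTokenMoE(nn.Module):
    def __init__(self, n_embd, n_experts, mlp_scale=4):
        super().__init__()
        self.router = nn.Linear(n_embd, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(2 * n_embd, mlp_scale * n_embd),   # takes x | attn(x)
                nn.GELU(),
                nn.Linear(mlp_scale * n_embd, n_embd),
            )
            for _ in range(n_experts)
        ])

    def forward(self, x, attn_out):
        # x, attn_out: (B, T, n_embd); top-1 routing decided by x alone.
        weights = F.softmax(self.router(x), dim=-1)                     # (B, T, E)
        top_w, top_i = weights.max(dim=-1)                              # (B, T)
        inp = torch.cat((x, attn_out), dim=-1)                          # (B, T, 2*n_embd)
        # Dense-but-simple for clarity (not efficient): run every expert,
        # then select the routed one per token.
        all_out = torch.stack([e(inp) for e in self.experts], dim=-2)   # (B, T, E, n_embd)
        idx = top_i[..., None, None].expand(-1, -1, 1, all_out.size(-1))
        out = torch.gather(all_out, dim=-2, index=idx).squeeze(-2)      # (B, T, n_embd)
        return top_w.unsqueeze(-1) * out
```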

For efficiency, we can also consider projecting down the keyspace into a smaller space. (This projection matrix is perhaps shared by all of the nodes?)

Honestly, it would be nice to be able to increase the learning rate for the MLP sections only.
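A quick sketch of how that could be done with optimizer parameter groups (assuming the MLP parameters can be identified by name; model here stands for the GPT module):

```python
import torch

mlp_params   = [p for n, p in model.named_parameters() if '.mlp.' in n]
other_params = [p for n, p in model.named_parameters() if '.mlp.' not in n]

optimizer = torch.optim.AdamW([
    {'params': other_params, 'lr': 6e-4},
    {'params': mlp_params,   'lr': 4 * 6e-4},  # e.g. 4x higher LR for the MLPs only
], weight_decay=0.1)
```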

Note

Retroactive Note 2/25 . I suspect we can reuse attention weights, but not reuse MLP weights. One question: If scaling transformers boils down to making the MLP as big as possible, assuming that the positional features of the language are limited in nature, then how come we don't train another model to "learn" the MLP truth table? Here, the inputs would be the applicator x, the attention output attn(x), and the output would be some replacement. Does the current recursive structure already do this sublearning problem?

I suspect that the answer is no, it does not already do this sublearning problem. The language of this inner learning problem is fundamentally different: different tokens (i.e. the bits corresponding to x | attn(x), instead of language tokens), and thus a different interpreter. It is computation in an entirely different language.

One fun experiment could be to train an inner LLM on this inner language. But I can see why an MLP is sufficient as an approximator.

It also begs the question: why not tokenize the original text into smaller tokens? It is not obvious to me that the resulting language would be simpler, since the "code" (i.e. the true input) is still the same, so the interpreter must have the same overall complexity. In some sense, it may have increased complexity, if the input representation is less efficient (i.e. not a nice encoding, spending lots of bits on rare words). If the vocab size decreases, I imagine there is a trade-off between the mlp size (smaller) and the attention complexity / depth of computation (because positionally, different tokens may end up in many more associations than before, e.g. a character attending to every character of its word). We do need to make sure that there are enough attention parameters to fully apply all possible relationships, because the process is destructive and subtrees are overwritten (or are they? Recall that we have the residual and also no-op attention, i.e. attention sinks).

I wonder if these days, LLMs are good at copying text. I.e. consider a prompt that dumps in 3 paragraphs, and then says "Copy the above text, verbatim", or even something like "Copy the above text, verbatim, but add the letter e in between every character." If not, maybe consider an attention mechanism that allows pointing to a whole swathe of text (text specialized manipulation), and perhaps instead of summing it up, we maintain pointers somehow; then, when these pointers are fed into the MLP or MLP alternative, the MLP acts on the data that is pointed to, almost like an internal function call (except here it is just a dictionary lookup, for now). Would that be better? (TODO: implement)

On smaller tokenizations and MLP size, and learning context (Code Mode)

A potential question is whether there is a setting (of hyperparameters) such that we don't have to set extraordinarily large MLP sizes; that is, supposing that the complexity of the target function itself remains unchanged, can we build networks with small MLPs that can still learn the target? Well, the complexity of the computation would then need to be expressed in the language, i.e. input as code, and I don't see how something like tokenizing it to be smaller, or increasing the size of the attention layers, would help with that. If tokenized poorly, I suspect it will require more attention and mlp complexity (and depth) to reconstruct "true tokens", and then require at that point the same attention/mlp resources as before.

This still begs the question: imagine a design where, in addition to learning parameters for the interpreter, our system also learns a "context" written in the same "language" as the input (perhaps encoding parts of the target function f). We can imagine the input as a "string", for instance, and the context as tokens for the python language; certainly learning a python interpreter is easier than learning an english interpreter. How come trying to learn such a context didn't work in our experiment above? (If it worked, it would be nice: the attention and mlp components could be smaller, while maintaining the same expressivity.)

Perhaps the problem previously is that our "contextual language" is as hard to learn as the "target language", or, in other words, no one is actually teaching our system to speak the easier "contextual language" (i.e. python code examples in this metaphor), and instead all of the training examples are in the target language (i.e. english completions). I am not sure this reason is perfect, though: surely, it's possible to learn both a python interpreter, and python code for understanding english, when fed training examples of english? (Does the function itself truly already need to know how to speak the "contextual language", in order to learn context?)

I thought about it for a long time, and frankly, I think it should work. I see no reason (at the current point) why learning a context would not also be a way to scale, and I think that series of experiments is worth probing at.

Some more thoughts. First, the gradient right now at those "learnable" context tokens may be very small. (Should check it experimentally.) That may be one reason it doesn't help much in the current form. Second, perhaps the effect is more prevalent if I reduce the embedding dimension n_embd, or vocab size. That is, perhaps contextual learning is not yet good enough that adding context would help. Third, perhaps the resource trade-offs are just too bad. This technique would require much wider and deeper networks; it's probably easier just to scale MLPs instead.

In general these experiments have not panned out. Here is an example where we interleave "dummy columns" in between real tokens, where the starting "dummy tokens" are trainable parameters. At some point this becomes a hindrance to the in-context gap, though it doesn't appear to be reflected in the actual loss (strange!):

Transformer, max LR 0.0003 n_layer 12
Minibatch=8 T=1024 AccumNum=16.0
Total Batch Size=131072
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 init_logic [CODE_MODE]
        elif isinstance(module, GPT) and CODE_MODE:
            torch.nn.init.normal_(module.code, mean=0.0, std=5)
----------------
 code_logic [CODE_MODE]
        if CODE_MODE:
            code_expanded = self.code[-t:].unsqueeze(0).expand(b, -1, -1) # same code for every batch
            z = torch.stack((tok_emb, code_expanded), dim=2) # (b, t, 2, n_embd)
            tok_emb = z.reshape(b, 2*t, -1) # (b, 2t, n_embd)
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK_ROUTER=False
TIE_ATTN_WEIGHTS=True
TIE_MLP_WEIGHTS=False

caption

On attention

Hypothesis: If I fix the Key and Query matrices (ignoring the value matrix / using an identity value matrix, and skipping the projection step, for simplicity), can the remainder of the network compensate and still learn effectively? I suspect that the answer is yes, based on my intuition that attention is simply a module that puts "nearby tokens" together. Fixing K and Q, perhaps the rest of the network will learn to put tokens directly in the correct space (conditioned on K and Q), instead of needing an extra learning step.

A couple of experiments to run here:

  • Reuse attention weights, do not reuse MLP weights
  • Choose K and Q smartly (in particular taking into account the multiple heads, the component for each head should be different), and fix them. (Surely we should pick it in some structured way? With multihead attention it's not obvious to me.)

Somewhat as expected, reusing attention weights does not fundamentally impact the performance, although there is a noticeable penalty; so, given the current MLP size, perhaps it is worth increasing the size of the attention component in a subsequent experiment. The current experiment ties the attention weights together (there's no bias):

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False
CODE_MODE=False
TIE_ATTN_WEIGHTS=True

caption

Compare that to a world where we tie MLP weights (but leave attention untouched). This one is surprisingly good to me: perhaps the attention setup is able to shoulder some of the MLP load:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 mlp_weights [TIE_MLP_WEIGHTS]
                if TIE_MLP_WEIGHTS:
                    # Only works with BenBlock, set module to be the same
                    firstBlock = self.transformer.h[0]
                    for block in self.transformer.h:
                        block.execute = firstBlock.execute
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=True

caption

Now, let's investigate what happens if we fix the attention matrices. Recall that I'm guessing that fixing them should not drastically impact model performance. So, let's set the K,Q,V matrices to be fixed and initialized according to a Gaussian with sqrt(1/n_embd) std. (I'm not sure if V should be fixed, but let's try this first because it requires less code modification.) Unfortunately, the outcome seems to be that we have neutered the network somewhat:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 fixed_attn [NO_GRAD_ATTN]
        if NO_GRAD_ATTN:
            for block in self.transformer.h:
                block.attn.c_attn.weight.requires_grad = False
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Now, we let V be free, and only fix K and Q:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 fixed_attn [NO_GRAD_ATTN]
        if NO_GRAD_ATTN:
            for block in self.transformer.h:
                block.attn.c_attn.weight.requires_grad = False
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption
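A side note on implementation: with the fused c_attn weight, requires_grad=False freezes Q, K, and V together. One way to freeze only the K and Q blocks while letting V train is to zero their gradient slices with a hook; this sketch assumes the [Q; K; V] layout along the output dimension and is not necessarily what the run above did.

```python
def freeze_q_and_k(block, n_embd):
    # Zero the gradient of the Q and K rows of the fused c_attn weight, leaving V free.
    # (Decoupled weight decay would still shrink the frozen rows unless they are
    # excluded from the decay group.)
    def zero_qk_grad(grad):
        grad = grad.clone()
        grad[: 2 * n_embd] = 0.0
        return grad
    block.attn.c_attn.weight.register_hook(zero_qk_grad)

# for block in model.transformer.h:
#     freeze_q_and_k(block, config.n_embd)
```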

Maybe K and Q have to be chosen together, in some structured way. Fixing only K:

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 fixed_attn [NO_GRAD_ATTN]
        if NO_GRAD_ATTN:
            for block in self.transformer.h:
                block.attn.c_attn.weight.requires_grad = False
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Compare with omitting attn entirely.

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = NoAttnExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(x)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

Note, in the previous experiment I also fixed a bug with VanillaExecute where I was passing the layer norm of the residual into the mlp input (i.e. ln(ln(x) + attn) instead of ln(x + attn)). Hopefully it doesn't change too much (it should only affect the more recent experiments); let's check the vanilla experiment again. Surprisingly, it seems to have no effect at all.

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(x)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
TIE_ATTN_WEIGHTS=False
TIE_MLP_WEIGHTS=False

caption

TODO: I forgot to run the combination of fixed attention and reused weights.

On the Residual

Attention Sinks

Q: why does the magnitude of the residual matter so much? Layer-norming it completely destroys performance... Even halving it completely destroys it. (I imagine that it then gets smaller and smaller every step, converging to zero.)

New experiment! I'm trying to figure out if we can avoid passing the entirety of the residual along. Note that all-layer loss is on, and reusing weights is true. We have added a "Zero Sink" as well (TODO: figure out how to make this computation more efficient).
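A minimal sketch of what I mean by a "Zero Sink": append a constant zero logit to the attention scores so a head can dump probability mass onto a position that contributes nothing. Shapes and names are hypothetical, and this is the slow, explicit version (hence the TODO about efficiency).

```python
import torch
import torch.nn.functional as F

def attention_with_zero_sink(q, k, v):
    # q, k, v: (B, H, T, head_dim)
    T = q.size(-2)
    scale = 1.0 / (q.size(-1) ** 0.5)
    scores = (q @ k.transpose(-2, -1)) * scale                          # (B, H, T, T)
    causal = torch.ones(T, T, device=q.device).tril().bool()
    scores = scores.masked_fill(~causal, float('-inf'))

    # The sink: one extra column of zero logits.
    zero_col = torch.zeros(*scores.shape[:-1], 1, device=q.device, dtype=scores.dtype)
    weights = F.softmax(torch.cat((scores, zero_col), dim=-1), dim=-1)

    # Drop the sink column; whatever mass it absorbed simply vanishes.
    return weights[..., :-1] @ v
```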

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

If we turn reusing weights off,


self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

Let me see if I can extract resw efficiently by commandeering a direction in the embedding dimensionality. The hope is that by using resw, we can zero out the contribution from the residual (noop) completely if the attention is high enough. We keep this optimization for all future attention sink experiments: (This is not a fair comparison, because we also use xWeights now)

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights*y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

Now, use MLPConcat instead:

self.compiler = BenCompilerNoOp(config)
self.execute = BenExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights*y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

What if I don't layer norm the attention before feeding it into the mlp? (Otherwise same as above)

loss plot

Should really find a way to start with the residual with weight 1 (instead of 0)...

Let's try the same sinkgate but with vanilla execution instead:

========
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights*y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

Just for fun... Let me turn REUSE_WEIGHTS on and all_layer_loss off, and feed in the entire residual again. It looks almost identical to 16-sinkgate-vanilla-2, which has reuse_weights off and all_layer_loss on and the weighted residual. (TODO: why?)

========
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

Now, let's turn on reuse weights but also do the xWeights again.

========
self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights*y + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=True
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True

loss plot

The Identity Mechanism

First, we reduce the learning rate, and run the vanilla experiment

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
max_lr = 0.25*6e-4
min_lr = max_lr * 0.1
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=False

loss plot

Now, we run an experiment where we do the weighting thing again...

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * x + (1-xWeights)*machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss plot

Surprisingly, if I get rid of the (1 - xWeights), things are substantially worse. I suspect that this is because the mlp component is additive? Or is the attn component itself just too large (because of the value matrix)? I suspect that it is entirely because of the MLP (overpowering the residual component).

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss plot

Now, we switch to BenExecute, which currently looks like

class BenExecute(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    
    def forward(self, program, attn):
        return self.mlp(program, attn) # self.ln_2(attn)

Note that we do not layernorm the attn signal, in hopes that if it is attenuated, the mlp will recognize that and attenuate its own output as well.

self.compiler = BenCompilerNoOp(config)
self.execute = BenExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss plot

Next, we add the layer norm in BenExecute and the (1-xWeights) back in. The motivation is to hopefully increase the residual faster, and to make it easier for the attention component to zero itself out. Surprisingly, this absolutely fails.

self.compiler = BenCompilerNoOp(config)
self.execute = BenExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * x + (1-xWeights)*machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss

Why did the previous fail? Is it because we put the layer norm back? Or is it because we added the (1-xWeights)? Let's remove the layer norm again from BenExecute. It works now, but note the strange stretch where the standard deviation of the output went out of control:

loss

That seems to be why. Now, it seems that mlpconcat is not very good, but in prior experiments it has usually converged to the same place as vanilla transformers, just more slowly. But it fits my mental model better, so let's stick with it for now, and try to raise the learning rate:

Let's see if this is still stable, if we return the learning rate to the original of 6e-4; also we switch from cosine similarity to a norm of the difference:

loss

The norm of the output in the previous experiment was very high. What if we layer norm it? It starts off bad but eventually converges. Also note the horizontal section in the beginning while we wait for fracRes to converge to some high number like 0.8:

class BenExecute(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    
    def forward(self, program, attn):
        return self.ln_2(self.mlp(program, attn))

loss

Hmmmmm. I suspect such a program is just too complicated to learn? What happens if I push the identity loss much higher? It is rather unstable... I'm not sure it ever converges at this training rate.

_xtraloss = _xtraloss + 2*torch.linalg.norm(_x - _in, dim=-1).mean()

loss

What happens if I switch it to the 1-norm? I also remove the layer norm wrapping the output.

_xtraloss = _xtraloss + torch.linalg.norm(_x - _in, dim=-1, ord=1).mean()

loss

Let's do something fun, setting n_layer=48, and using the infinity norm: (Note that when I use ord=2, it doesn't even make it past a loss of 10. This somehow feels like an important observation, but at this point I don't know what to make of it.)

_xtraloss = _xtraloss + torch.linalg.norm(_x - _in, dim=-1, ord=float('inf')).mean()
self.compiler = BenCompilerNoOp(config)
self.execute = BenExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * x + (1-xWeights)*machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=8
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss

It's a little screwy! It seems to be getting caught when activating the effects of attention. Set n_layer=12 just to speed it up a little bit, and also newx = xWeights * y + (1 - xWeights) * machineOutput. This is awful, which is odd, because usually layer norming x (y = LN(x)) would make the standard deviation of x larger, not smaller.

loss

Let's try the same thing but drop the learning rate lower. Note that ord=2 in this experiment. (The outcome is horrible.) (12 layers)

Experiment description: Transformer, max LR 0.00015
self.compiler = BenCompilerNoOp(config)
self.execute = BenExecute(config) 
========
y = self.ln_1(x)
attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = xWeights * y + (1-xWeights)*machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=2
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

loss

Let me try the same thing with ord = inf. It is infinitely better. (Why??) But it cannot seem to get the residual fraction to be high. (12 layers)

loss

So let me swap back the x for y and set ord=2. This is pushing up the residual fraction again. But, it diverges.

Transformer, max LR 0.00015
Setting:
========
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
========
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
========
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + torch.linalg.norm(_x - _in, dim=-1, ord=2).mean()
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=2
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Same thing but ord=inf again. Still having trouble getting the residual fraction up:

Transformer, max LR 0.00015
Setting:
========
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
========
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
========
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + torch.linalg.norm(_x - _in, dim=-1, ord=float('inf')).mean()
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=2
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Same thing but with cosine similarity. This one is much nicer. (Similar to the y + ord=inf variant.)

Transformer, max LR 0.00015
Setting:
========
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
========
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
========
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + (1 - F.cosine_similarity(_x, _in, dim=-1).mean())
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=2
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Same thing, but MLP_SCALE back to 4.

Transformer, max LR 0.00015 n_layer 12
Setting:
==machine code======
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
==machine modules======
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
==block logic======
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
==loss computation======
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + (1 - F.cosine_similarity(_x, _in, dim=-1).mean())
========
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Reuse weights back to false. This will converge to 17-identity-test.

Transformer, max LR 0.00015 n_layer 12
Setting:
==machine code======
class BenExecute(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)

    
    def forward(self, program, attn):
        return self.mlp(program, attn)
==machine modules======
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
==block logic======
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
==loss computation======
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + (1 - F.cosine_similarity(_x, _in, dim=-1).mean())
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Same thing, but with higher learning rate. And n_layers to 8.

Transformer, max LR 0.0006 n_layer 8
Setting:
==machine code======
class BenExecute(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)

    
    def forward(self, program, attn):
        return self.mlp(program, attn)
==machine modules======
        self.compiler = BenCompilerNoOp(config)
        self.execute = BenExecute(config)
==block logic======
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = xWeights * x + (1 - xWeights) * machineOutput
==loss computation======
                x, metadata = block(x,print_weights=print_weights,step=i)
                _x_total = x
                _in = x.detach()
                _x, _ = block(_in,print_weights=False) # Do again... lol
                _xtraloss = _xtraloss + (1 - F.cosine_similarity(_x, _in, dim=-1).mean())
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=True
IDENTITY_LOSS=True

caption

Next steps: increase the learning rate. What happens if I turn x to y? Also, our hope was that removing the residual would "clarify" the signal (because that's how computation works). Do we see improvement in that regard (and is perplexity the best way to measure that)?

An environmental point of view

First, we check whether absolute-valuing the residual at every layer changes performance compared to vanilla GPT. (Yes, it does; it's a consistent penalty, kind of similar to reusing weights, I think.)

Transformer, max LR 0.0006 n_layer 8
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = (x + machineOutput).abs()
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

Compare to (something is off; why is this worse than 13-baseline? Because we only use 8 layers instead of 12, lol):

Note

Retroactive Note 2/25 . There's another bug where I fed in as input to mlp, mlp(ln(ln(x) + attn)) instead of mlp(ln(x + attn)). Hopefully this does not make too much difference; in the most generic experiment, indeed it does not (see 19-vanilla).

Transformer, max LR 0.0006 n_layer 8
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

We can potentially think of our network as defining, recursively, a series of "contexts" -- where a context consists of an input string (i.e. comprised of T token embeddings), and an output target (usually computed by "inverting" the network, which is implicitly done during back-propagation). Each context thus defines a "gap", which is the loss, or distance between the input and the target. In a theory of life, for instance, we can posit that "organisms" would like to fill each gap.

Now, let's try to motivate the attention mechanism from this point of view. The MLP is easy to motivate: it is essentially memorizing a truth table of inputs and outputs. The issue with the MLP (and the universal approximation theorem) is that, of course, it does not scale. The attention mechanism, on the other hand, allows us to scale, in a computational way. It turns out that, by only memorizing a few things and having this "tape" or "graph structure", much like in a Turing Machine, we can simulate any computation in the world; the input embeddings then, in some sense, become the code, and the base rules become the interpreter. The problem is that our interpreter must remain fairly complex: the language is very high level, and very expressive in an immediate way. Yet, at the same time, we don't want to memorize too much in order to build the interpreter. I posit that MLPs both map input tokens to more expressive "program" tokens and also memorize "base building block" functions, whereas the attention component "positions" related tokens together, like a "conditional application" of function tokens to data tokens. In other words, it is some expressive variant of a combinator calculus, which is somehow the perfect trade-off between memorization and expressiveness of the language that is being interpreted (if there is a trade-off at all...)

An organism then comprises an attention component ("where do I feed") and an MLP component ("what do I extract"). Conceivably, we can have an organism that feeds everywhere, but it should be out-competed by an organism that feeds selectively on topical tokens. (How do multiple attention heads build into this?) Similarly, if we have two organisms in sequence that are not attuned to the overall gap, there should be no stability and it is unlikely for any of them to fill the gap. Instead, one organism should try to fill the gap first the best it can, and the remaining gap (its output, and the target) can subsequently be filled by a second organism. (This may be the role of skip connections.)

We may even posit a much more fluid network dynamics, where organisms have not only a choice of which of T tokens to feed-on, but also where they are located in the network; moreover, some organisms may be copies of each other, and perhaps this would facilitate solutions to the Prisoner's Dilemma. (As far as I can imagine, gradient descent rewards selfish behavior only, by design. Cooperation seems only to arise by accident, when organisms are not yet fully attuned? Or does cooperation arise because behavior is never fully discrete, so with 1% chance both organisms cooperate, and this eventually dominates? We can think of it being computed in superposition. But the contribution of cooperation amongst multiple organisms is going to be miniscule, and hard to discover, I'm not sure. Perhaps once loss plateaus, it too will eventually have its time to shine, no matter how small... Indeed perhaps cooperation does arise by chance. But certainly it may not arise in unstable environments.)

If some organisms share code, perhaps cooperation arises faster and easier? This feels less fundamental. But in any case, let me implement some "block routing" mechanism for each layer, which sends the signal to one of n_layer organisms (and in the whole network there are only n_layer organisms) according to a softmax. To start, we skip the softmax, and just apply every block at every layer. Note that this is quite bad! (Note the bug with _xtotal; see 18-router-3 instead, which likely won't be better than that, probably similar.)

Transformer, max LR 0.0006 n_layer 8
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 network_logic
                _finx, metadata = block(x,print_weights=print_weights,step=i)
                for j in range(self.config.n_layer):
                    if i == j:
                        continue
                    b = self.transformer.h[j]
                    _x, _metadata = b(x,print_weights=False,step=i)
                    _finx = _finx + _x
                x = _finx
                _xtotal = x
                # x, metadata = block(x,print_weights=print_weights,step=i)
                # _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

Now, try to learn a weighted block, i.e., depending on the layer, a block may decide not to contribute. Still no good! (We didn't backprop properly...) (Note the bug; see 18-router-3 instead, which likely won't be better than that, probably similar.)

Transformer, max LR 0.0006 n_layer 8
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
        newx = torch.sigmoid(self.routes[step]) * newx
----------------
 network_logic
                _finx, metadata = block(x,print_weights=print_weights,step=i)
                for j in range(self.config.n_layer):
                    if i == j:
                        continue
                    b = self.transformer.h[j]
                    _x, _metadata = b(x,print_weights=False,step=i)
                    _finx = _finx + _x
                x = _finx
                _xtotal = x
                # x, metadata = block(x,print_weights=print_weights,step=i)
                # _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption

What if we build a proper softmax router:

Transformer, max LR 0.0006 n_layer 8
Setting:
==details======
 machine_code
class BenExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLPConcat(config)
        # self.ln_2 = nn.LayerNorm(config.n_embd, elementwise_affine=ELEMENTWISEAFFINE)
    def forward(self, program, attn):
        return self.mlp(program, attn)
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = VanillaExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
----------------
 network_logic
                routes = F.softmax(self.router[i], dim=-1)
                routes.requires_grad_(True)
                routes.retain_grad()
                x.requires_grad_(True)
                x.retain_grad()
                # print(f"routes grad {routes.grad}")
                # print(f"x grad {x.grad}")
                _finx, metadata = block(x,print_weights=print_weights,step=i)
                _finx = routes[i] * _finx
                for j in range(self.config.n_layer):
                    if i == j:
                        continue
                    b = self.transformer.h[j]
                    _x, _metadata = b(x,print_weights=False,step=i)
                    _finx = _finx + routes[j] * _x
                x = _finx
                _x_total = x
                metadata[f"routes_{i}"] = routes
                # x, metadata = block(x,print_weights=print_weights,step=i)
                # _x_total = x
----------------
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False

caption
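
For reference, here is a hypothetical cleaned-up sketch of the per-layer softmax router above. The Block stand-ins and sizes are placeholders, not the training script's code; the point is that if the router logits live in an nn.Parameter, gradients reach them without any retain_grad() bookkeeping, and since each real block adds its own residual internally, the softmax-weighted sum works out to x plus a convex combination of the block outputs.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RoutedBlocks(nn.Module):
    """Apply every block at every layer, weighted by a learned per-layer softmax."""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        n = len(self.blocks)
        self.router = nn.Parameter(torch.zeros(n, n))  # one logit per (layer, block) pair

    def forward(self, x):
        for i in range(len(self.blocks)):
            routes = F.softmax(self.router[i], dim=-1)   # (n_layer,)
            x = sum(routes[j] * blk(x) for j, blk in enumerate(self.blocks))
        return x

# usage with stand-in blocks (a real block would be attention + MLP with its residual):
blocks = [nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)]
out = RoutedBlocks(blocks)(torch.randn(2, 8, 16))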

(Deprecated) Attempts at building a more perfect network

Note

Retroactive Note: 3/25: I forget what the design was exactly, but it can be found in the git history. Certainly my understanding of GPT has evolved a lot since this set of experiments.

This is how I would currently design our network. The assumption is that, if a token is not an applicator, i.e. it is not applying itself to anything, it should just be shoveled up another level / compute a no-op, because there is no application to be done. The only issue may be that of the layer norm... maybe resx should add x instead and not y. Generally we should ensure that the bias is initialized to zero, and that M starts by computing the identity. Should M compute a basic matrix rotation, or should it compute a full-blown MLP? I think it should maybe compute a full-blown MLP; currently, it doesn't, which is maybe why the residual is useful, as now we can apply the nonlinearity in the next layer. (But how do we initialize an MLP to identity? Is this why we need the residual? Qn: is the attention residual important? I posit that the answer is No, based on prior experiments where I fed attn directly into MLP as input and never added it to the residual stream. E.g. img/14-nodiagonal-noattn.jpg). (Resx is broken; it needs to be initialized, so just use x for now.) Hmmm... it's not great, it's not horrible.

y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
hiddenBias, fParams, bParams = self.compiler(y)
machineOutput = self.execute(attn, fParams, bParams, hiddenBias)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=True
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=True

loss plot

Does more parameters help?

y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
hiddenBias, fParams, bParams = self.compiler(y)
machineOutput = self.execute(attn, fParams, bParams, hiddenBias)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=True
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=128
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=True

loss plot

Note that, without zero-ing the diagonal, performance is actually kind of horrible:

y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
hiddenBias, fParams, bParams = self.compiler(y)
machineOutput = self.execute(attn, fParams, bParams, hiddenBias)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=True
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False

loss plot

Zero'ing the diagonal... wait, it is incredibly fragile. Sometimes it works, sometimes it doesn't; I really have no idea what is going on:

y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
hiddenBias, fParams, bParams = self.compiler(y)
machineOutput = self.execute(attn, fParams, bParams, hiddenBias)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=True
EXTRACT_SELF_CONTRIBUTION=False

loss plot
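
For concreteness, a minimal sketch of what "zero-ing the diagonal" amounts to, assuming it is done the same way as the DELETE_SELF_CONTRIBUTION mask that appears later (a causal mask built with diagonal=-1):

import torch

block_size = 6
# causal mask that excludes the diagonal: each token attends only to strictly
# earlier positions, never to itself
no_self_mask = torch.tril(torch.ones(block_size, block_size), diagonal=-1).bool()
print(no_self_mask.int())
# Caveat: the first row is all zeros (the first token has nothing left to attend to),
# so its attention output needs special handling, e.g. a zero output or an attention sink.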

Why might a machine that only executes matrix multiplications (+ bias) be as expressive as one that executes an entire MLP? I guess the matrix multiplication (+ bias) is fed into the next layer anyway, which has an mlp/nonlinearity?

Traditionally, with the residual, if the desired mapping is H(x), we can imagine the block learning F(x) = H(x) - x instead, which should not be any harder; and when the desired mapping is the identity H(x) = x, F is just zero, which is easier to learn.

Also, all things considered, it seems quite reasonable to learn some function f(x, attn) instead, maybe by feeding it directly into mlp(x || attn). If attn is small, perhaps we want this to return the identity, and then it is easier to directly learn x = x + mlp(x || attn). But it seems like a chicken-and-egg problem: to replace x entirely, mlp(x || attn) needs to learn -x.

An alternative is to initialize the matrix M in M@x+b to the identity matrix, and b to the 0 vector; but what about the non-linearity?
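
As a quick sketch of that alternative (names here are purely illustrative), the identity initialization would look like this; the open question about where the non-linearity goes remains:

import torch
import torch.nn as nn

n_embd = 768
M = nn.Linear(n_embd, n_embd)   # computes M @ x + b
with torch.no_grad():
    M.weight.copy_(torch.eye(n_embd))  # M = identity
    M.bias.zero_()                     # b = 0

x = torch.randn(4, n_embd)
assert torch.allclose(M(x), x, atol=1e-6)  # the layer starts out as a no-op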

Let me try just concatenating: x + mlp(LN(x) || attn(LN(x))). This doesn't perform particularly well, or any better, but it doesn't perform worse either, since it does converge.

(compiler just outputs y, execute outputs mlp(y || attn))
y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False

loss plot
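
For clarity, here is a plausible sketch of what the MLPConcat / "mlp(y || attn)" variant computes; the exact widths and GELU flavor are guesses, not the actual implementation:

import torch
import torch.nn as nn

class MLPConcat(nn.Module):
    """MLP that reads the concatenation of the normalized residual and the attention output."""
    def __init__(self, n_embd, scale=4):
        super().__init__()
        self.c_fc = nn.Linear(2 * n_embd, scale * n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(scale * n_embd, n_embd)

    def forward(self, y, attn):
        return self.c_proj(self.gelu(self.c_fc(torch.cat([y, attn], dim=-1))))

# block logic is then roughly: x = x + mlp_concat(ln_1(x), attn_out)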

Doing due diligence, same experiment as before, but MLP_SCALE=8. It doesn't appear to improve the base model substantially.

(compiler just outputs y, execute outputs mlp(y || attn))
y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=8
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False

loss plot

Let's try a large mlp scale with the vanilla transformer. Note that it doesn't have the same impact as 14-mlpmatrix-moreinner, but the latter has 700M parameters versus our 520M parameters. This one still doesn't appear to improve the base model.

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=32
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False

loss plot

I am curious how large we can go; can we get to 700M parameters? Let's set MLP_SCALE=64, giving 973,701,120 parameters. TL;DR: the MLP is enough, we don't need our fancy compiler stuff. (Would it help with reusing parameters?)

self.compiler = BenCompilerNoOp(config)
self.execute = VanillaExecute(config)
y = self.ln_1(x)
attn, resx, scores = self.attn(y, y)
program = self.compiler(y)
machineOutput = self.execute(program, attn)
x = x + machineOutput
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=64
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=4096
MLPMAT_INNER_SIZE=64
DELETE_SELF_CONTRIBUTION=False
EXTRACT_SELF_CONTRIBUTION=False

loss plot

Takeaways so far

Attention acts as a positional applicator, and should be allowed to no-op. It essentially computes neighbors(x).

MLP(x) * y applies x to y = neighbors(x) in a way specified by x. The result is added to the computation graph (x).

I.e. imagine left application (mirroring traditional visualizations of combinator calculus, so below, x is called on y)

        ?
    y       x
        prevy   prevx
            ppy     ppx
                pppy    pppx

What about multiple arguments? I.e. traditionally

        ?
     ?      z
  ?     y  
S   x

reduces to

      ?
  ?       ?
x  y    x   z

Here, we have something like

    ?
z       ?
    y      ?
        x     S

and what maybe happens is that x is just added to the residual, and so is y, while they wait for the final argument, and then finally MLP(S+x+y)*z + (S+x+y) computes the application (But in this case, why do we need to add S+x+y back to the residual? Maybe to defer it, and try again, to make optimization easier? The more times the better, I guess? But then z should be added back in at the next layer. And it is, through the residual from the previous column.) Now, why might we want to do MLP(x)*(attn(x) + x), instead of just MLP(x)*attn(x)? Well, maybe the effect of MLP in prior layers was incomplete, so we want to make sure that MLP can keep cleaning it up?

I really do think we should use zero'd attention, but it is simply too slow, and the non-zero'd version is a good approximation.

Tinkering with the MLP, and More

One goal is to figure out how to re-use weights with minimal penalty, since that would be very nice.

Note

Retroactive 3/25: Again, my interpretation has changed much since then; I am no longer particularly attached to needing to tie MLP weights together.

First, let me try a new version of all-train, where we compute a loss at every layer, but only propagate the loss one layer deep each time. This works really poorly (somewhat expected, since we're not reusing weights, so how can they hope to learn anything quickly?). This shows that poor optimizability is hard to distinguish from poor expressivity.

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))

loss plot
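
To spell out the mechanism (a hedged sketch; blocks, lm_head, and targets stand in for the corresponding pieces of the model): each block's output is scored against the targets and then detached before the next block sees it, so no gradient flows more than one block back.

import torch
import torch.nn.functional as F

def one_layer_deep_loss(x, blocks, lm_head, targets):
    losses = []
    for block in blocks:
        x = block(x)
        logits = lm_head(x)
        losses.append(F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)))
        x = x.detach()  # cut the graph: each block only sees its own layer's loss
    return x, torch.stack(losses).mean()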

Even reusing weights, it is horrible:

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))

loss plot

I guess it is important for backprop to be somewhat deep. Let me remove the detach(), and stop reusing weights (and interestingly, it's not any slower than with the detach(), TODO why is that?). Generally, computing loss at every layer is simply worse than computing only at one layer.

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))

loss plot

Let's talk about mlp matrices. Here, we map y to a (CxC)-matrix M (using the appropriate decomposition), and then directly apply M to the attention output. In this experiment, we use an inner matrix of size 48x48.

y = self.ln_1(x)
attn = self.attn(y)
mlp, bias = self.fatmlp(y)
M = self.matrixfromparams(mlp)
x = M @ attn + bias + x
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot
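
The decomposition behind matrixfromparams is not shown in this README. One construction that would be consistent with MATRIX_NUM_PARAMS == MLPMAT_INNER_SIZE**2 is sketched below: a shared down-projection C -> k, a per-token k x k inner matrix predicted by fatmlp, and a shared up-projection k -> C. Treat this as a hypothetical reconstruction, not the actual code.

import torch
import torch.nn as nn

class MatrixFromParams(nn.Module):
    def __init__(self, n_embd, inner_size):
        super().__init__()
        self.inner_size = inner_size
        self.down = nn.Linear(n_embd, inner_size, bias=False)  # shared across tokens
        self.up = nn.Linear(inner_size, n_embd, bias=False)

    def forward(self, params, v):
        # params: (B, T, inner_size**2) predicted per token; v: (B, T, C), e.g. the attention output
        B, T, _ = params.shape
        W = params.view(B, T, self.inner_size, self.inner_size)
        h = self.down(v)                          # (B, T, k)
        h = torch.einsum('btij,btj->bti', W, h)   # per-token k x k matrix multiply
        return self.up(h)                         # (B, T, C): the "M @ attn" term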

Let's get rid of the bias and increase the size of the inner matrix to MLPMAT_INNER_SIZE = 128. This isn't very good at all! The bias seems to be quite important.

y = self.ln_1(x)
attn = self.attn(y)
mlp, bias = self.fatmlp(y)
M = self.matrixfromparams(mlp)
x = M @ attn + x
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Let's add the bias back but keep MLPMAT_INNER_SIZE = 128. This does well. Is this because of more parameters, or due to the more expressive MLP? Why is bias so important? Maybe it allows for replacing the applicator.

y = self.ln_1(x)
attn = self.attn(y)
mlp, bias = self.fatmlp(y)
M = self.matrixfromparams(mlp)
x = M @ attn + bias + x
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

One question is what happens if we keep the same size for fatmlp (i.e., are the parameters more useful somewhere else?).

One question (TODO) is how to encode multiple applications, in other words, a copy operation.

Multiple applications (x + y)? Apply to the same attendees? Hmm. There should really be one applicator per column. And in my opinion, a column should comprise either an applicator or an applicatee; a token should not be able to be BOTH (unless in some parallel superposition?) (Perhaps dot product attn(x) with x to see if it is a no-op, first starting in the no-value regime).

Let's try to get rid of the residual... we revert to mlp*attn just to make this experiment fast. Note that sometimes the dot product is negative (why?). Note that the value matrix is set to false, which means that some of the words we are attending to have embeddings pointing opposite to the applicator's. It turns out this is completely broken (and surely the same thing is achieved just by forwarding along attn...)

y = self.ln_1(x)
attn = self.attn(y, y) 
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) 
mlp = self.mlp(y)
# NOTE sometimes negative dot product (why?)
app = (torch.sigmoid(app) - 0.5) * 2  # app is -1 or 1
x = mlp * attn + app * x
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

On that note, I am curious whether attn alone is enough of a residual (answer: it is not). The plot looks similar to if we had used y instead of attn; maybe attn is destructive because of the layer norm, not because it is an attention component. (Perhaps mlp is wiping it out.)

y = self.ln_1(x)
attn = self.attn(y, y)
mlp = self.mlp(y)
x = mlp * attn + attn
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

We should probably normalize app... else everything probably has large dot product (and I suspect layer norm doesn't normalize it enough). (Um, what did I do) (Also, instead of sigmoid, should just use tanh)

y = self.ln_1(x)
attn = self.attn(y, y) 
siz = torch.linalg.vecdot(y, y,dim=-1).unsqueeze(-1)
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) / siz
# print(app[-1,-1,-1].item())
mlp = self.mlp(y)
# NOTE sometimes negative dot product (why?)
app = (torch.sigmoid(app) - 0.5) * 2  # app is -1 or 1
x = mlp * attn + app * x
========
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Trying app selection with our weird matrix conglomerate (this is exceedingly slow). It is simultaneously promising and unpromising, and I'm not patient enough to run it all the way (we should really optimize the app computation if we are going to run more experiments).

y = self.ln_1(x)
attn = self.attn(y,y)
siz = torch.linalg.vecdot(y, y,dim=-1).unsqueeze(-1) # (B, T, 1)
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) / siz
app = (torch.sigmoid(app) - 0.5) * 2  # [-1, 1]
m, bias = self.fatmlp(y)
M = self.applymat(m, attn) #(B, T, 3*C), (B, T, C) -> (B, T, C)
x = M + app * bias + app * x
======
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

Honestly, we should also flip the app on the bias. Ideally app is 0 to 1. (Also how do we support this computation with the value matrix?) Well, this turns out to be strictly worse than mlpmatrix-moreinner...

y = self.ln_1(x)
attn = self.attn(y,y)
siz = torch.linalg.vecdot(y, y,dim=-1).unsqueeze(-1) # (B, T, 1)
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) / siz # may be greater than 1
app = (torch.sigmoid(torch.abs(app)) - 0.5) * 2  # [0, 1]
m, bias = self.fatmlp(y)
M = self.applymat(m, attn) #(B, T, 3*C), (B, T, C) -> (B, T, C)
x = (1-app) * (M + bias) + app * x
======
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

TODO: pull app out directly from the KQ matrix; I think I have a bug somewhere.

Let's experiment with mlpmatrix-moreinner some more. Here, we want to know why axm works but here we need the bias. What if we multiply the bias by attn?

y = self.ln_1(x)
attn = self.attn(y,y)
siz = torch.linalg.vecdot(y, y,dim=-1).unsqueeze(-1) # (B, T, 1)
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) / siz # may be greater than 1
app = (torch.sigmoid(torch.abs(app)) - 0.5) * 2  # [0, 1]
m, bias = self.fatmlp(y)
M = self.applymat(m, attn) #(B, T, 3*C), (B, T, C) -> (B, T, C)
x = M + bias*attn + x
======
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

Let's experiment with mlpmatrix-moreinner some more...

y = self.ln_1(x)
attn = self.attn(y,y)
siz = torch.linalg.vecdot(y, y,dim=-1).unsqueeze(-1) # (B, T, 1)
app = torch.linalg.vecdot(attn, y,dim=-1).unsqueeze(-1) / siz # may be greater than 1
app = (torch.sigmoid(torch.abs(app)) - 0.5) * 2  # [0, 1]
m, bias = self.fatmlp(y)
M = self.applymat(m, attn) #(B, T, 3*C), (B, T, C) -> (B, T, C)
x = M + bias + x
======
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

I am curious what happens if we omit attn entirely from the residual stream in the vanilla architecture. It eventually does converge! It is just slightly harder to optimize. (So clearly MLP needs to map the output of attention only.) (Note that in the log, perc(-) starts off near 1 and then gets smaller and smaller, closer to 0.6; this corresponds to the number of (B,T) s.t. the sum of self attention scores (over the 12 heads) is less than 1. This number starts off huge, but eventually it seems that it is getting smaller; so, the overall sum is getting larger, so it is slowly learning to attend more and more to the self? Confusing.)

Experiment Name: 15-baseline
y = self.ln_1(x)
attn = self.attn(y)
x = x + self.mlp(self.ln_2(x + attn))
======
DELETE_SELF_CONTRIBUTION=False
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

Again, let me omit the diagonal from attn, but here for vanilla transformer. It works well (but unclear if the gains are substantial, or increasing.)

y = self.ln_1(x)
attn = self.attn(y)
x = x + attn
x = x + self.mlp(self.ln_2(x))
======
DELETE_SELF_CONTRIBUTION=True
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

What if we do both, i.e. omit the diagonal, and also exclude attn from the residual? This converges, and eventually does even slightly better than the original...

y = self.ln_1(x)
attn, _ = self.attn(y, y)
x = x + self.mlp(self.ln_2(x + attn))
======
DELETE_SELF_CONTRIBUTION=True
MEASURE_SELF_CONTRIBUTION=True
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

Note that at this point I also silently normalized the VALUE_MATRIX=True code to divide the sum of values by the number of heads.

And again, what happens if we omit the residual entirely from the mlp? We have tried this before, and it takes longer to converge (but it should converge?). But this also breaks our intuition of needing to run self.mlp on the applicator. So maybe this one shouldn't converge at all. After the experiment, the results are:

y = self.ln_1(x)
attn, scores = self.attn(y, y)
x = x + self.mlp(attn)
======
DELETE_SELF_CONTRIBUTION=True
MEASURE_SELF_CONTRIBUTION=True
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

(TODO: I did not run this experiment.) And now, instead of feeding attn as input into the mlp, we apply our mlp-matrix construction instead. We've run this experiment before, except now delete-self-contribution is true.

y = self.ln_1(x)
attn, scores = self.attn(y, y)
m, bias = self.fatmlp(y)
M = self.applymat(m, attn)
x = x + M + bias
======
DELETE_SELF_CONTRIBUTION=True
MEASURE_SELF_CONTRIBUTION=True
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MLPMAT_INNER_SIZE=128

loss plot

I am curious if the last layer of the MLP needs a bias in vanilla transformers:

========
y = self.ln_1(x)
attn, score = self.attn(y, y)
x = x + attn
x = x + self.mlp(self.ln_2(x))
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
MEASURE_SELF_CONTRIBUTION=False
NEW_ALL_LAYER_LOSS=False
MATRIX_NUM_PARAMS=16384
MLPMAT_INNER_SIZE=128
DELETE_SELF_CONTRIBUTION=False

loss plot

Messing with Signal Propagation

One question is, how come we don't just sum together the entire context window? Well, the thing is that some tokens attend to other tokens selectively, and not others. So it could be that [A B] does something, but [C B] does not, and we should really only sum tokens together when we know it (potentially) does something. (Or, does attention compute instead whether A and B are positioned together / near each other?) (In that case, this whole key query interpretation is kind of ridiculous, and can maybe be replaced by inverses of the position embedding.)

So, continuing this line of thought, if attention computes whether two tokens are positioned adjacent to each other (which would be a function of their positions only and not necessarily their semantic content, unless the two are conflated, which they may well be in natural language), then mlp should compute its replacement, namely, if C = [A B], then MLP should learn C. This has the property that it describes both the inverse computation and the forward computation; i.e. if we are trying to compress the output of some computation, I think that this compression map can be learned exactly by learning C (i.e. learning some language for the forward computation). [Not going to lie, this is somewhat trippy.]

Importantly, it could also be that [A B] is a no-op, that is, it is already in its most simplified form, and so [A B] = [A B]. It could also be that [A B C D E] = [A B' C D E] is mostly a no-op, but there is some small operation. But, it seems that [A B] should fully replace the separate A and B from the previous layer (i.e. applying positional properties should not result in a no-op). The no-op could be implemented using a gate much like in an LSTM. Alternatively, a simple addition may be enough (not sure how to do a full replacement). So, I think, from these vague "principles", I expect the following architecture to be ideal (in terms of expressiveness, and interpretability, but not sure about optimizability):

x = self.attn(x, x) + x # (is this +x necessary? Or +res? Or nothing?)
midx = x
mlp = self.mlp(x)
gate = torch.sigmoid(self.mlp2(x))
x = x*gate + mlp
newres = x
x = RMSNorm(x, elementwise_affine=ELEMENTWISEAFFINE)

The RMSNorm may not even be necessary. Is this additional mlp necessary, or can it just be learned via the embedding matrix? Was the residual previously necessary only to allow for implementing a no-op? (Previously, the norm of attn(x) is 10x that of mlp(x), but 1/4 of that of res.)

Messing with signal propagation

First, we try it without all layer loss. Note that it is exceptionally important to normalize the input to the MLP, or else the gradient quickly approaches NaN. During optimization, it is also important that the MLP gets a piece of the real residual signal, without it being fully destroyed by attention (why?); alternatively, we can use all layer loss.

attn = self.attn(x, x) + x
y = self.ln(attn)
mlp = self.mlp(y)
midx = mlp
y = torch.sigmoid(self.mlp2(mlp))
x = mlp + res*(1 - y)
newres = x
x = self.ln(x)

Frankly, it's not clear how well the previous run worked... but the curve is interesting; it certainly converges.

loss plot

Let's test vanilla architectural changes. Consider transformers as conditional summation. The hypothesis is that this should perform just as well as our usual mlp*attn architecture:

class Gate(DualModule):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        # smoother ReLU, also maybe helps dead neurons empirically?
        # (should just delete dead neurons)
        self.gelu = nn.GELU(approximate='tanh')  # should just use 'none' if not trying to copy GPT2
    
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = x.sum(dim=-1, keepdim=True) # sum over the last dimension
        x = torch.sigmoid(x)
        return x
attn = self.attn(x, x)
mlp = self.gate(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

But unfortunately, it doesn't! In fact, the curve looks less nice than the prior experiment, and I'm not even sure that it converges. So we really lose a lot of expressivity; the hypothesis is just wrong. To verify that it isn't some issue with some other part of the code, let's just run the original mlp*attn architecture again. It seems fine, no issues:

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

I wonder how much of the gate experiment is just a loss of parameters, i.e. if we add in the same number of parameters and then do a summation, does it help? It turns out to make no difference at all, compared to the 12-gate-axg experiment.

class Gate(DualModule):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        # smoother ReLU, also maybe helps dead neurons empirically?
        # (should just delete dead neurons)
        self.gelu = nn.GELU(approximate='tanh')  # should just use 'none' if not trying to copy GPT2
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
    
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = x.sum(dim=-1, keepdim=True) # sum over the last dimension
        x = torch.sigmoid(x)
        return x
attn = self.attn(x, x)
mlp = self.gate(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

This points to the conclusion that the mlp does not just serve as a gate; it is in fact memorizing useful information. This raises the question: is attention itself serving as a gate? No, it seems that the attention component is not just functioning as a gate. (The result reminds me of what happens when removing attention entirely.)

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
gate = attn.sum(dim=-1, keepdim=True)
y = mlp*gate
x = y + res
newres = x
x = self.ln(x)

loss plot

Let me also feed the gate through a sigmoid to make sure we are coming to the correct conclusion. The conclusion remains unchanged:

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
gate = torch.sigmoid(attn.sum(dim=-1, keepdim=True))
y = mlp*gate
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's keep trying random things. First, let's try adding attn back into the residual, in hopes that it makes optimization faster (it does not make optimization faster):

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + attn + res
newres = x
x = self.ln(x)

loss plot

Is this really the best version of this architecture? Let's check if giving mlp access to attn helps with anything... No, it in fact penalizes things. (wtf) Probably the more pure the residual signal, the better at first...

attn = self.attn(x, x)
mlp = self.mlp(x + attn)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's check if giving mlp only access to attn hurts anything... (It hurts things a little bit)

attn = self.attn(x, x)
mlp = self.mlp(attn)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

TODO: original transformer baseline:

attn = self.attn(x, x)
mlp = self.mlp(self.ln(attn + res))
midx = mlp
y = mlp
x = y + attn + res
newres = x
x = self.ln(x)

loss plot

Let's train a larger model just for fun (value matrix is set to True)

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

TODO train a baseline GPT-medium (copy Karpathy's code)

Let's "rotate" the mlp before applying it (this is probably redundant, because the last layer of the MLP probably already does this). It doesn't seem to work super well, but it does converge (does it get better?) (value matrix is set to True)

attn = self.attn(x, x)
mlp = self.rotator(self.mlp(x))
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's try the rotator again but with more warmup time (100 steps instead of 10). It still doesn't seem to bring any improvement. (value matrix is set to True)

attn = self.attn(x, x)
mlp = self.rotator(self.mlp(x))
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's set n_embd = 1296 (and also increase warmupsteps to 100 from 10). A larger embedding dimension is not substantially better! (TODO try normal transformers; is it substantially better there?) (value matrix is set to True)

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's go back to n_embd = 768 (and also increase warmupsteps to 100 from 10). Optimization is a little slower but it ends up in the same place. (value matrix is set to True)

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let me try to reproduce axm again... wait, why is it different? (value matrix is set to True) (Warmup steps 10).

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Maybe I have to turn off the value matrix... Indeed it was because of the value matrix. (Warmup steps 10). Yes this is closer (discrepancy may be due to RMSNorm, I forget now.)

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + res
newres = x
x = self.ln(x)

loss plot

Let's send warmup steps back to 100 and turn the value matrix back on just for faster training. What happens if we use x instead of res? It is simply a lost cause, optimization-wise, even with the learning rate set to 1/4 of the usual (1.5e-4 instead of 6e-4). (It may eventually converge to the same place, I think, but I don't have enough patience to find out.) TODO: why is this harder to optimize?

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*attn
x = y + x
newres = x
x = self.ln(x)

loss plot

I suspect that the mlp is deleting "semantic dimensions" of note (but this only works if dimensions align...). Does the size of mlp output scale with the size of the residual that we are trying to attenuate? What if we try to help it along, directly have it attenuate res (hopefully), and add attn separately? There appears to be no tangible benefit.

attn = self.attn(x, x)
mlp = self.mlp(x)
midx = mlp
y = mlp*res
x = y + res + attn
newres = x
x = self.ln(x)

loss plot

Let's revisit vanilla attention, with independent blocks, copying a bit from the original Karpathy code. Wait, it does much better! Why?

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))

loss plot

Let's make sure that our multiplication schematic is still competitive when not reusing weights. Vanilla is still true. Unlike previously, this seems to be a complete no-op. (Maybe this only helps in the re-using weight regime?)

y = self.ln_1(x)
x = x + self.attn(y)*self.mlp(y)

loss plot

Let's turn on the original, reusing weights.

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4

loss plot

Clearly reusing weights is what is wrong. So it seems like we probably need to increase the size of the mlp matrix correspondingly (from 4 to something larger, i.e. 4*NUMLAYERS to match the original number of weights) if we do decide to reuse weights.

Let's turn on multiplication while re-using weights again, just to see if multiplication helps in the re-using weights regime. (Yes, it does improve optimization a little bit.) Note that removing the value matrix in fact also improves things substantially during optimization (but is slower):

y = self.ln_1(x)
x = x + self.attn(y)*self.mlp(y)
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=4

loss plot loss plot

Now, setting MLP_SCALE=4*12=48, to match the weights lost due to removal. Our hypothesis is wrong, and increasing the size of the MLP doesn't really help!

y = self.ln_1(x)
x = x + self.attn(y)*self.mlp(y)
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=48

loss plot

To make sure, let's try MLP_SCALE=48 for addition too. The performance penalty from reusing weights persists:

x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
======== 
VALUEMATRIX=True
REUSE_WEIGHTS=True
MLP_SCALE=48

loss plot

So somehow, it seems that we want to learn a different MLP for each layer. TODO

Returning to the non-reusing-weights regime, what happens if we specifically have the mlp negate embeddings that contribute to attention? This performs decidedly worse:

y = self.ln_1(x)
mlp=self.mlp(y)
x = x + self.attn(y,mlp*y)
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

I don't understand though; isn't that mathematically equivalent to (note that this experiment, the no-value no-sharing-weights, is the best yet):

y = self.ln_1(x)
mlp=self.mlp(y)
x = x + self.attn(y,y)*mlp
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

No, they are not equivalent; the difference is that mlp also has dimension (B, T, C). When we do the element-wise multiplication, in the first experiment, attn(y, mlp*y), each value contribution is multiplied by the source token's MLP output, whereas when we do attn(y, y)*mlp, the final pooled output is weighted by the destination token's MLP output.
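
A tiny numerical check of this point (single head, no masking, random weights, purely for illustration):

import torch

B, T, C = 1, 4, 8
y = torch.randn(B, T, C)
mlp = torch.randn(B, T, C)
a = torch.softmax(torch.randn(B, T, T), dim=-1)   # stand-in attention weights

source_scaled = a @ (mlp * y)   # attn(y, mlp*y): each value y_j is scaled by the source token's mlp_j
dest_scaled = (a @ y) * mlp     # attn(y, y)*mlp: the pooled output is scaled by the destination token's mlp_i
print(torch.allclose(source_scaled, dest_scaled))  # False in general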

Now let's try: (TODO try putting the second LN back in and running attn on the output of mlp). One hypothesis is that MLP exists exclusively to "negate" certain privileged dimensions of the residual, and maybe this architecture would facilitate that. But clearly, it does not facilitate it; it seems to be a long-term penalty:

y = self.ln_1(x)
mlp=self.mlp(y)
x = x + mlp*x + self.attn(y,y)
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Orthogonally, I would like to get rid of the residual. Ideally, attn(x) can learn to replace it, but the problem is that it runs itself through a softmax, so the weight of the identity will always decrease. We need to allow the weight of the original to stay the same (i.e. in computation, we need to allow union in addition to replacement, and replacement is handled by the MLP). Can we avoid running the attn through softmax? (This doesn't work.)

y = self.ln_1(x)
mlp=self.mlp(y)
x = 2*self.attn(y,y)*mlp
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Let's try running attention through a sigmoid instead. (It doesn't work.)

y = self.ln_1(x)
mlp=self.mlp(y)
x = self.sigmoidattn(y,y)*mlp
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot
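
For reference, a hedged sketch of what sigmoidattn plausibly computes (I am guessing at the details): standard causal attention with the row-wise softmax replaced by an element-wise sigmoid, so the weights no longer compete for a budget of 1. The sigmoiddiagattn variant below would additionally pin the diagonal weights to 1.

import torch

def sigmoid_attention(q, k, v):
    # q, k, v: (B, H, T, d)
    T = q.size(-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    scores = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)   # (B, H, T, T)
    scores = scores.masked_fill(~mask, float('-inf'))
    weights = torch.sigmoid(scores)   # sigmoid(-inf) == 0, so causality is preserved
    return weights @ v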

Adding x back: in any case, it seems extremely important that self.attn always picks "the best ones to combine" (maybe to prevent too much information pollution). Otherwise, there seems to be no incentive not to push every weight to 1. (TODO: what if we set every weight to 1?)

y = self.ln_1(x)
mlp=self.mlp(y)
x = self.sigmoidattn(y,y)*mlp+x
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Just for fun, force the attention diagonal to always be 1 (keeping the sigmoid):

y = self.ln_1(x)
mlp=self.mlp(y)
x = self.sigmoiddiagattn(y,y)*mlp+x
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Removing the +x:

y = self.ln_1(x)
mlp=self.mlp(y)
x = self.sigmoiddiagattn(y,y)*mlp
======== 
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Alright, now switching gears entirely back to something more complicated. This is my best blind guess at an architecture currently. The outcome is that it converges to the same place but is decidedly slower to optimize in the very beginning:

y = self.ln_1(x)
newemb = self.attn(y) + x
mlp=self.mlp(self.ln_2(newemb))
x = mlp*newemb + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Does delaying the input of mlp by a layer change anything at all? It optimizes faster, but then seems to fail and get worse later, even decongealing after step 4000. So input delay to MLP doesn't matter and maybe it in fact hurts, but maybe ensuring there is a large residual component to multiply into mlp helps with optimization early on but maybe not later? (So mlp maybe isn't attenuating x? I'm thoroughly confused.)

y = self.ln_1(x)
newemb = self.attn(y) + x
mlp=self.mlp(y)
x = mlp*newemb + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Another experiment. Note that here, I forgot to pass the residual into the mlp. It doesn't really converge to the same spot...

y = self.ln_1(x)
newemb = self.attn(y)
mlp=self.mlp(self.ln_2(newemb))
x = mlp*newemb + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Let's feed the residual back in. Compared to the first "complicated" experiment, here the mlp only multiplies self.attn(y), and not x; it is worse than the first complicated experiment. Compared to the axm experiments, here the mlp additionally gets an un-attenuated attn component as input, which is clearly a penalty. So for some reason, it is important to explicitly attenuate only the attention component, and not the residual.

y = self.ln_1(x)
newemb = self.attn(y)
mlp=self.mlp(self.ln_2(newemb + x))
x = mlp*newemb + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Experiment 5:

y = self.ln_1(x)
newemb = self.attn(y)
mlp=self.mlp(self.ln_2(newemb))
x = mlp*(newemb + x) + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Experiment 6:

y = self.ln_1(x)
newemb = self.attn(y)
mlp=self.mlp(self.ln_2(newemb))
x = mlp*(x) + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

The takeaways:

  • (Experiment 2) mlp(x) * (self.attn(x) + x) seems to optimize faster than (Experiment 1) mlp(self.attn(x)+x) * (self.attn(x) + x) at first, but later it really fails.

  • (Experiment 3) mlp(self.attn(x)) * (self.attn(x)) only is atrocious, worse than both.

  • (Experiment 4) mlp(self.attn(x)+x) * (self.attn(x)) is overall just slightly worse than Experiment 1, though it has a tiny optimization bump in the very beginning...

  • (Experiment 5) mlp(self.attn(x)) * (self.attn(x) + x) optimizes slowly at the beginning, and doesn't look like it will converge to a good place (TODO run for longer)...

  • (Experiment 6) mlp(self.attn(x)) * (x) is more stable than 5 but is even worse than Experiment 3.

  • All are worse than (13-baseline-axm) mlp(x) * self.attn(x).

Takeaways: the MLP needs to see the residual, not just attention. It doesn't care about the current attention component at all. It also shouldn't multiply the residual itself (it should multiply attention only); that converges faster, but really odd things happen later down the road...

More messing with signal propagation

Let's remove the 1s from the diagonal of attention. The motivation here is, we are already copying x, there is no need to copy it again (but wait, what if there is no application to do, i.e. it is not near anything. I suspect this will not do as well. Instead, we should do the next experiment.) (I accidentally deleted this run...)

self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size),diagonal=-1).bool().view(1, 1, config.block_size, config.block_size))
========
y = self.ln_1(x)
attn = self.attn(y)
mlp=self.mlp(y)
x = mlp*attn + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Now, we use the usual attention causality (lower triangular mask = True), but when summing things up, we omit the value contribution from the identity token (so the mlp(x) does not apply/multiply on it). This also requires turning the value matrix off... The result is equivalent to the one before! (What does this mean?)

y = y*self.nodiagonal[:,:,:T,:T] # delete the self contribution
========
y = self.ln_1(x)
attn = self.attn(y)
mlp=self.mlp(y)
x = mlp*attn + x
========
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

In any case, the takeaway is that in this architecture, it is important to run mlp*x as well. Maybe, because we always add the residual back to itself, mlp*x is performing leftover cleanup of the previous residual, over and over and over again.

What happens if I add a constant contribution from the identity token? I.e. mlp(x)*(attn+x). That was the 13-complicated-2 experiment from above. It does amazingly in the beginning and then odd things happen. Experiment 1 mlp(self.attn(x)+x) * (self.attn(x) + x) does slightly less well in the beginning, but eventually seems to converge? Let me try to run this experiment again.

Note in 13-complicated-2 it is attn + x but here it is attn + y (layer normed). As we found previously, it really does hurt long term performance, a little bit!

y = self.ln_1(x)
attn = self.attn(y)
mlp = self.mlp(y)
x = mlp*(attn+y) + x
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

In an ideal world, we zero out the self contribution in attention, and add in a whole unit of it afterwards:

y = self.ln_1(x)
attn = self.zerodattn(y)
mlp = self.mlp(y)
x = mlp*(attn+y) + x
========
VALUEMATRIX=False
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

Note that, in general, multiplication seems slightly worse than addition.

Unrelatedly, let me try to do the "DELETE_SELF_CONTRIBUTION" faster in the value matrix regime. (Our implementation is not faster.) But note that DELETE_SELF_CONTRIBUTION has no effect at all:

y = self.ln_1(x)
attn = self.attn(y)
mlp = self.mlp(y)
x = mlp*(attn) + x
========
DELETE_SELF_CONTRIBUTION=True
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4

loss plot

TODO when doing addition attention, what happens if I run mlp(x) instead of mlp(LN(attn + res))?

TODO is mlp allowing us to attenuate res? I.e. deletion instead of replacement?

Early fun:

Exploring Unconventional Loss Fns

Note

Notes in retrospect: This series of experiments probably makes more sense to do on a token-by-token basis, not a layer-by-layer basis. Here, I am conflating the language, and the computation. (Yes, this series of experiments fails. It is not very motivated.)

Hypothesis: One of my early motivations was to somehow make the model "self-consistent". That is, putting on an RL hat, perhaps we can interpret subsequent layers as giving "feedback" to earlier layers (in retrospect, this turns out to be a bad way of reasoning about it). In any case, imagine a world where a block wants to output a signal such that the subsequent block minimizes its "surprise" (i.e. -log(1 - Pr[max])).

We start off by computing an additional loss at every layer, and then later appending it to the final loss. Setting block_loss = -log(1 - pr[max]) (the lower the confidence, the closer the block_loss is to 0; the higher the confidence, the more negative the block_loss (so when added to the loss, the loss gets lower, which is "good").). Special care is taken to compute the "evaluator" block under torch.no_grad.

It doesn't work very well at all! (Notes in retrospect: Graph lost to time.) One observed problem is that the confidence is being pushed very high on early layers, with no bearing on the final layer that outputs. Then, we could maybe investigate some early termination technique that samples from the embeddings as soon as the confidence is high enough.

1-noise

Hypothesis: I suspect one issue was that the block_loss was too big relative to the real signal; let us normalize it (i.e. for each layer, -1 * crossentropy / n_layers), but it turns out that this loss signal is still very noisy, so learning is not very good.

2-test

Hypothesis: In this one, we set losses += _block_loss / self.config.n_layer, where _block_loss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), _targets.view(-1)) and then _block_loss = torch.log(1 - torch.exp(-1*_block_loss)), namely it is positive feedback only. As we can see, it doesn't ever get to good training error, but it is better than noise.
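
Assembled into one loop (layer_logits and targets are placeholders for the per-layer logits and the usual targets), the term looks like the sketch below: exp(-cross_entropy) is the geometric-mean probability assigned to the targets, so log(1 - p) is always non-positive and only ever pushes the loss down.

import torch
import torch.nn.functional as F

def all_layer_feedback_loss(layer_logits, targets, n_layer):
    losses = 0.0
    for _logits in layer_logits:
        _block_loss = F.cross_entropy(_logits.view(-1, _logits.size(-1)), targets.view(-1))
        _block_loss = torch.log(1 - torch.exp(-1 * _block_loss))  # log(1 - p) <= 0
        losses = losses + _block_loss / n_layer
    return losses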

Note

Retroactive Note: I forgot what was going on here.

3-test

Hypothesis: Continuing on this "self-reward using confidence as heuristic" idea, we want to incentivize high confidence (else loss accumulates forever?) whilst penalizing wrong answers. When there is no environmental feedback (i.e. when targets is None), the loss should just be the self-confidence. Whenever confidence is high, there is some probability of terminating the line of thought (which is good), yet also a chance of accumulating loss in prior steps. Then the model should learn to be confident early.

loss_ = (xe * _confidence * _mask_BT).mean()

Note

Retroactive Note: I again don't remember what is going on, but we did try this early termination thing, and the conclusion is that it is relatively pointless. In retrospect, it reminds me of Universal Transformers (excluding the per-layer loss idea).

4-test

Hypothesis: Previously, we evaluated the output of each layer against the target. Now, we don't evaluate the true loss against the target until the network is actually ready to output:

xe_factor = ((xe - 1) * _just_triggered + 1)
loss_ = (xe_factor * _confidence * _mask_BT).mean()

This has the effect of punishing confidence early on. Unfortunately, the result is still subpar.

5-test

Note

Retroactive Note: Another unmotivated experiment in retrospect, but it has the property of having the first graph that I saved. I remain curious how these ideas would play out at a token generation level (if not already used in these chain-of-thought settings or during post-training).

Unfortunately, the graphs are (initially) not very well documented; orange usually reflects the current experiment, and blue is a relevant baseline, but for these older experiments, I didn't write down what it was (since this was not intended to be a reference).

Hypothesis: Let's rethink using confidence of subsequent layers to penalize / reward earlier computation:

  • Certainly, one should punish confidently wrong answers. But what if there is no target? Ask the next layer if wrong or not wrong.

  • On confidence: should we reward confidence? (By my simulatability theory, predictable actions are not interesting to me. So if the robot's action was predictable to me, that action is not interesting. But this seems to be different. There is a distinction between predictability and distinguishability. Also, "did I expect this", from a verifier's point of view, is different from "would I have done the same thing", because note that in self talk, the answer to the latter is always "yes". Maybe heuristically, if I am highly confident, then I did expect it: I know how to act in return to maximize the true reward; if I am not at all highly confident, then I did not expect the answer at all (it doesn't look like gibberish either), and have no clue and no confidence in how to act. Thus, we should reward low confidence, or punish high confidence.)

loss_ = (xe_factor_prev * _confidence * _mask_BT_prev).mean()

Result:

loss plot

(Strangely linear, but also outputs "the" a hundred times. On further debugging, penalizing confidence appears to cause this behavior. Why?)

6-test

A few experiments:

  • 6-test-1: Rerun the "confidence of target" experiment with GPT learning rate.

  • 0-noearly: the same as vanilla GPT (but reusing weights) without our early termination mechanism.

  • 0-original does not re-use any weights, and needs a smaller learning rate to converge properly.

Does confidence reinforcement even make sense at a high level? Recall what each layer does:

  • the attention module: each embedding is the weighted sum of multiple embeddings from the previous layer in its context window, computed according to some "attention matrix".

  • the MLP module is essentially a fact retrieval system; the 4A x A weight matrix + ReLU (or other non-linearity) can perhaps be thought of as evaluating a giant set of 4A "if statements" on A-dimensional embeddings; the A x 4A projection matrix then perhaps "adds" facts to the original embedding (via the residual) depending on which "if statements" passed (a rough sketch of this reading follows below).
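
A rough numerical sketch of that reading (random weights, dimensions only; nothing here is the trained model):

import torch
import torch.nn.functional as F

A = 768                                          # embedding dimension
W_in = torch.randn(4 * A, A) / A ** 0.5          # 4A x A: the "if statement" conditions
W_out = torch.randn(A, 4 * A) / (4 * A) ** 0.5   # A x 4A: the facts to add back
x = torch.randn(A)                               # one token embedding
fired = F.relu(W_in @ x)                         # which conditions passed, and how strongly
x = x + W_out @ fired                            # the passing facts are added via the residual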

Our general hypothesis in this section was that good training data is only one part of learning; acts of "self-reflection" or "self-consistency" are also very important to learning. (Somehow, the model should make two predictions and check whether they are consistent, or be able to evaluate its own quality/consistency independently of generating predictions.)

Note that the subsequent logit it generates is indeed such an assessment. The problem with using confidence as part of the loss is that...

A stream of   text that is
  of     text that is   next
  of     text that is   next

Well, we are in some sense in the wrong namespace; why are we using steps of the computation, to reward overall behavior? The computation hasn't even had time to finish.

Note that the residual connection and c_proj are extremely important, and I do not know why. The value matrix does not seem so important. (Perhaps we can get rid of c_proj?)

What if we reward tokens that are equal to the next token in the previous layer? Does it still have the same namespace issue?

Note

Retroactive Note. I see no reason why intermediate computations should reflect the structure of "next token prediction."

Experimenting with Signal Propagation

8-experiments

Note

Retroactive Note. A few experiments are lost, and at some point I started experimenting with multiple copies of the attn and mlp layers, as well as element-wise multiplication instead of addition. Throughout, "all layer loss" refers to evaluating the output of each layer against the target in addition to the usual loss; later, we will see that this is both slower and slightly counterproductive (but the results should still generalize to the usual notion of loss). Initially, I thought all layer loss would help with signal propagation and lessen the need for residuals (indeed it does).

Note

Retroactive Note. At some point here, I realized by accident that doing x = x + attn(LN(x)) * mlp(LN(x)) converges faster than the standard architecture, though the perplexity is a bit worse. For good measure, I retroactively ran the following comparison since I cannot find the original comparison (the purple here is 8 layers instead of 12):

Transformer, max LR 0.0006 n_layer 12
Setting:
==details======
 machine_code
class MultExecute(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = MLP(config)
    def forward(self, program, attn):
        return self.mlp(program) * attn
----------------
 machine_modules
        self.compiler = BenCompilerNoOp(config)
        self.execute = MultExecute(config)
----------------
 block_logic
        y = self.ln_1(x)
        attn, xWeights, scores = self.attn(y, y, print_weights=print_weights)
        program = self.compiler(y)
        machineOutput = self.execute(program, attn)
        newx = x + machineOutput
========
VALUEMATRIX=True
REUSE_WEIGHTS=False
MLP_SCALE=4
ATTENTION_SINK=False
ATTENTION_MASK =False
IDENTITY_LOSS=False
CODE_MODE=False

caption

Also, note that the performance of attn()*mlp() is even more comparable to GPT if every block reuses weights, which is a regime that I was initially fascinated with (and forgot to turn off in some early experiments).

Let me emphasize, even in retrospect, how odd it is that x = x + mlp(ln(x))*attn(ln(x)) works so well (at least at this scale, and this early in the training process). It really messes with my intuition for interpreting these networks, and may be a clue into how these networks work.

Hypothesis: On seeing the surprising efficacy of multiplying the attention signal into the mlp signal, let's try other combinations. [Experiments lost.]

The problem with x = x*attn2(LN(x)) + attn(LN(x)), x = x + mlp(LN(x)) seems to be that the residuals are blowing up the more we train it. Same with x = x*attn2(LN(x)) + attn(LN(x)), x = x * mlp2(LN(x)) + mlp(LN(x)). (All Layer Loss). We see residuals like

SIZE COMPARISON prev 2.8568525314331055 next 1.1331876516342163
SIZE COMPARISON prev 2.7836289405822754 next 1.121557354927063
SIZE COMPARISON prev 3.347931385040283 next 1.1270164251327515
SIZE COMPARISON prev 5.639773368835449 next 1.1290345191955566
SIZE COMPARISON prev 15.21493911743164 next 1.1317602396011353
SIZE COMPARISON prev 60.20911407470703 next 1.1316640377044678
SIZE COMPARISON prev 284.3788757324219 next 1.1313221454620361
SIZE COMPARISON prev 1445.6865234375 next 1.1314541101455688
SIZE COMPARISON prev 7582.7626953125 next 1.1323127746582031
SIZE COMPARISON prev 40300.05078125 next 1.133501410484314
SIZE COMPARISON prev 216047.5625 next 1.1349073648452759

Note

Retroactive Note. I believe next refers to the output residual fed through a LayerNorm, and prev refers directly to the size of the input residual. Each row corresponds to a layer.
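A rough sketch of how such magnitudes might be logged (my actual logging code is not reproduced here; rms is a stand-in for whatever norm I was printing):

```python
import torch

def rms(t: torch.Tensor) -> float:
    return t.float().pow(2).mean().sqrt().item()

def log_size_comparison(x_in: torch.Tensor, x_out: torch.Tensor, ln_f) -> None:
    # prev: raw residual entering the block; next: output residual after a LayerNorm
    print(f"SIZE COMPARISON prev {rms(x_in)} next {rms(ln_f(x_out))}")
```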

For something like x = LN(x) + self.attn(LN(x)), x = x + self.mlp(LN(x)) (all layer loss) (note, using LN(x) instead of the residual), we see residual magnitudes like

SIZE COMPARISON prev 18.42753028869629 next 0.8354735374450684
SIZE COMPARISON prev 12.65165901184082 next 0.8403578996658325
SIZE COMPARISON prev 12.434549331665039 next 0.8322869539260864
SIZE COMPARISON prev 13.066632270812988 next 0.8329403400421143
SIZE COMPARISON prev 12.981801986694336 next 0.8330138921737671
SIZE COMPARISON prev 12.918157577514648 next 0.8329639434814453
SIZE COMPARISON prev 12.923635482788086 next 0.8329555988311768
SIZE COMPARISON prev 12.929137229919434 next 0.8329416513442993
SIZE COMPARISON prev 12.931181907653809 next 0.832944393157959
SIZE COMPARISON prev 12.930807113647461 next 0.8329448699951172
SIZE COMPARISON prev 12.930171966552734 next 0.8329458236694336
SIZE COMPARISON prev 12.92995834350586 next 0.832942008972168

Note

Retroactive Note 2/25. Note the unhealthy obsession with shared weights between blocks, and "all layer loss". Eventually I will come to the conclusion that weights should not be shared, if only because separate per-layer weights mimic a world with a much fatter MLP, which I currently think is the best way to scale.

I thought shared weights were useful because they better fit this "combinator calculus" mode of thinking that I had (and still have). There, the combinators don't change, only the tokens and how they are arranged. I suspect, however, that in reality, for a language as expressive as English, different combinators will often be applied at every level; so while in an ideal world every layer could share the same mlp matrix, in practice it makes sense to optimize the size of the mlp matrix and distribute knowledge across combinators tailored to specific contexts (i.e. layers).

For comparison, using vanilla GPT (but with shared weights), and all layer loss, we see by step 1100 (it is also slowly shrinking over more steps)

SIZE COMPARISON prev 0.8780975341796875 next 1.1142359972000122
SIZE COMPARISON prev 1.2143694162368774 next 1.110939383506775
SIZE COMPARISON prev 1.652571439743042 next 1.1084067821502686
SIZE COMPARISON prev 2.119270086288452 next 1.1072710752487183
SIZE COMPARISON prev 2.6001691818237305 next 1.1067460775375366
SIZE COMPARISON prev 3.0836727619171143 next 1.1065006256103516
SIZE COMPARISON prev 3.564718246459961 next 1.1063616275787354
SIZE COMPARISON prev 4.042917251586914 next 1.1062694787979126
SIZE COMPARISON prev 4.51833438873291 next 1.1062010526657104
SIZE COMPARISON prev 4.991661548614502 next 1.1061511039733887
SIZE COMPARISON prev 5.463232040405273 next 1.1061171293258667
SIZE COMPARISON prev 5.933814525604248 next 1.1060969829559326

and by step 4250:

SIZE COMPARISON prev 2.485067367553711 next 1.3869950771331787
SIZE COMPARISON prev 3.4810280799865723 next 1.3849396705627441
SIZE COMPARISON prev 4.4364542961120605 next 1.3832675218582153
SIZE COMPARISON prev 5.394400596618652 next 1.3827459812164307
SIZE COMPARISON prev 6.328531265258789 next 1.3822684288024902
SIZE COMPARISON prev 7.25205135345459 next 1.3819595575332642
SIZE COMPARISON prev 8.162321090698242 next 1.381672739982605
SIZE COMPARISON prev 9.063191413879395 next 1.3814265727996826
SIZE COMPARISON prev 9.955363273620605 next 1.3812041282653809
SIZE COMPARISON prev 10.840224266052246 next 1.3810186386108398
SIZE COMPARISON prev 11.719019889831543 next 1.3808491230010986

Most of the contribution seems to come from the attention step. In particular, by step 799 for the same experiment as above:

SIZE COMPARISON prev 0.9725479483604431 mid 0.05692768841981888 next 1.1076934337615967
SIZE COMPARISON prev 1.208902359008789 mid 1.1301352977752686 next 1.109763503074646
SIZE COMPARISON prev 1.4090938568115234 mid 1.3024053573608398 next 1.111130714416504
SIZE COMPARISON prev 1.5330252647399902 mid 1.396193027496338 next 1.1153910160064697
SIZE COMPARISON prev 1.7066140174865723 mid 1.5469791889190674 next 1.11979341506958
SIZE COMPARISON prev 2.1395297050476074 mid 1.9626308679580688 next 1.1215941905975342
SIZE COMPARISON prev 3.0498666763305664 mid 2.85837459564209 next 1.1217771768569946
SIZE COMPARISON prev 4.754446029663086 mid 4.55536413192749 next 1.1214094161987305
SIZE COMPARISON prev 7.8321757316589355 mid 7.630302429199219 next 1.121267318725586
SIZE COMPARISON prev 13.233830451965332 mid 13.03246784210205 next 1.1213200092315674
SIZE COMPARISON prev 22.627168655395508 mid 22.428388595581055 next 1.1217448711395264
SIZE COMPARISON prev 38.77415466308594 mid 38.57970428466797 next 1.1223392486572266

Note

Retroactive Note 2/25. I forget what mid stands for; it's probably the magnitude of x + attn(LN(x)) or something similar.

Detour: We briefly check whether LayerNorm's learnable scale/shift parameters are actually necessary.

The answer is, probably not! For x = x + attn(LN(x)), x = x + MLP(LN(x)), without learnable parameters, by step 499: [Note, graph is lost and needs to be regenerated]

SIZE COMPARISON prev 1.3563252687454224 mid 0.4865388572216034 next 1.0006482601165771
SIZE COMPARISON prev 1.7924420833587646 mid 1.4573729038238525 next 1.0006499290466309
SIZE COMPARISON prev 2.284381628036499 mid 1.9278755187988281 next 1.0006506443023682
SIZE COMPARISON prev 2.821475028991699 mid 2.4290308952331543 next 1.0006510019302368
SIZE COMPARISON prev 3.3831615447998047 mid 2.9755988121032715 next 1.000651240348816
SIZE COMPARISON prev 3.9586188793182373 mid 3.545346736907959 next 1.0006513595581055
SIZE COMPARISON prev 4.539757251739502 mid 4.125799179077148 next 1.000651478767395
SIZE COMPARISON prev 5.122259140014648 mid 4.709800720214844 next 1.0006515979766846
SIZE COMPARISON prev 5.703899383544922 mid 5.293674945831299 next 1.0006515979766846
SIZE COMPARISON prev 6.283486843109131 mid 5.875919342041016 next 1.0006515979766846
SIZE COMPARISON prev 6.860944747924805 mid 6.455946922302246 next 1.0006515979766846
SIZE COMPARISON prev 7.435906887054443 mid 7.033243179321289 next 1.0006515979766846
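In PyTorch, dropping the learnable scale/shift just means constructing the LayerNorm with elementwise_affine=False, e.g.:

```python
import torch
import torch.nn as nn

n_embd = 768  # the embedding width used throughout these experiments

# LayerNorm with elementwise_affine=False has no learnable scale/shift:
# it only normalizes each token's activation to mean 0 and unit variance.
ln = nn.LayerNorm(n_embd, elementwise_affine=False)

x = torch.randn(4, 1024, n_embd)
y = ln(x)  # same shape, normalized per token, no learned gamma/beta applied
```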

Hypothesis: Maybe attn(LN(x)) need not be fed additively into the input of the MLP; perhaps element-wise multiplicative-ness is enough.

If we use x = x*attn(LN(x)), x = x + MLP(LN(x)), as previously seen, the residuals blow up. For LN without learnable parameters (all layer loss, shared weights), by step 3800:

SIZE COMPARISON prev 0.9520694017410278 mid 0.05660898983478546 next 1.0006455183029175
SIZE COMPARISON prev 1.1140660047531128 mid 1.0007506608963013 next 1.0006468296051025
SIZE COMPARISON prev 1.1631948947906494 mid 0.9358635544776917 next 1.000647783279419
SIZE COMPARISON prev 2.937286853790283 mid 2.5033302307128906 next 1.0006508827209473
SIZE COMPARISON prev 23.77895736694336 mid 23.278114318847656 next 1.0006515979766846
SIZE COMPARISON prev 211.52383422851562 mid 211.03501892089844 next 1.0006515979766846
SIZE COMPARISON prev 1932.359130859375 mid 1931.881103515625 next 1.0006515979766846
SIZE COMPARISON prev 17739.107421875 mid 17738.642578125 next 1.0006515979766846
SIZE COMPARISON prev 162568.78125 mid 162568.328125 next 1.0006515979766846
SIZE COMPARISON prev 1477890.75 mid 1477890.25 next 1.0006515979766846
SIZE COMPARISON prev 13287474.0 mid 13287474.0 next 1.0006515979766846
SIZE COMPARISON prev 117777120.0 mid 117777120.0 next 1.0006515979766846

Perhaps this means that attn(x) outputs very small numbers and the network is trying to compensate; generally the outcome is very confusing.

Note

Retroactive Note 2/25. Throughout, I was rather perturbed by the fact that we could not feed the skip connections through a LayerNorm... what difference could it possibly make? (It clearly makes a big difference.) In retrospect, it certainly attenuates the effect of any loss computed relative to early layers via skip connections.

Hypothesis: I wonder if x = x + attn(ln(x)) is at heart performing a "substitution" into the residual / outbound signal. And x + mlp(ln(x)) evaluates an if statement and essentially does a dictionary lookup. Maybe a better algorithm should be

y = x + attn(ln(x))
x = ln(x) + mlp(ln(y))

noticing here that we don't add y back in, i.e. it does not contribute to the residual signal?

Well, it turns out this performs poorly at first, similarly to x = ln(x) + attn(ln(x)):

loss plot

What about

y = x + attn(ln(x))
x = x + mlp(ln(y))

The performance seems truly worse:

loss plot

so it is really quite important that the logits directly get the output of the attention layer (instead of solely feeding x + attn(ln(x)) into the mlp and never feeding attn(ln(x)) to the output).

Note

Retroactive Note 2/25. But note that doing x = x + attn(ln(x))*mlp(ln(x)) or x = x + mlp(ln(x) || attn(ln(x))) seems to do just fine. Also, I run a similar experiment later where x = x + mlp(ln(x + attn(ln(x)))) --- without reusing weights --- and it seems to converge just fine (see 15-baseline). So it's not clear what the story is here.
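As a quick sketch of that last variant (ln_1, ln_2, attn, mlp stand for the usual sub-modules; whether 15-baseline used one or two LayerNorms here isn't recorded):

```python
def sequential_block_forward(x, ln_1, ln_2, attn, mlp):
    """Sketch of x = x + mlp(ln(x + attn(ln(x)))): only the MLP output is
    added to the residual stream; the attention output is consumed by the
    MLP. All sub-module names are illustrative."""
    return x + mlp(ln_2(x + attn(ln_1(x))))
```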

Note that for double attention x = x*attn2(ln(x)) + attn(ln(x)), x = x + mlp(ln(x)), the residuals still blow up:

@ 949 train 5.2751 , allloss: 64.1261, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9999e-04, norm:7.9968, dt: 1769.77ms, tok/sec: 74061.71, flops:30.47, batch-reuse:1
SIZE COMPARISON prev 1.0659868717193604 next 1.0947023630142212
SIZE COMPARISON prev 1.798586368560791 next 1.1034561395645142
SIZE COMPARISON prev 2.1790177822113037 next 1.107276439666748
SIZE COMPARISON prev 2.6335883140563965 next 1.105468511581421
SIZE COMPARISON prev 4.98344612121582 next 1.1039302349090576
SIZE COMPARISON prev 14.011809349060059 next 1.1064351797103882
SIZE COMPARISON prev 56.432098388671875 next 1.106866478919983
SIZE COMPARISON prev 249.68592834472656 next 1.1070586442947388
SIZE COMPARISON prev 1164.308837890625 next 1.1072487831115723
SIZE COMPARISON prev 5611.69482421875 next 1.1074464321136475
SIZE COMPARISON prev 27626.546875 next 1.1076771020889282
SIZE COMPARISON prev 138020.09375 next 1.1078872680664062

For x = x + attn(ln(x)), x = x * mlp(ln(x)), it still blows up, but a little slower:

@ 949 train 5.3970 , allloss: 65.0632, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9999e-04, norm:8.4142, dt: 1606.70ms, tok/sec: 81578.15, flops:35.93, batch-reuse:1
SIZE COMPARISON prev 0.6984971761703491 mid 0.6713976860046387 next 1.0006380081176758
SIZE COMPARISON prev 0.9332523345947266 mid 0.7942314147949219 next 1.0006457567214966
SIZE COMPARISON prev 1.5732736587524414 mid 1.0734591484069824 next 1.0006494522094727
SIZE COMPARISON prev 2.531172752380371 mid 1.7262670993804932 next 1.0006508827209473
SIZE COMPARISON prev 4.456131935119629 mid 2.7139768600463867 next 1.0006513595581055
SIZE COMPARISON prev 7.704921245574951 mid 4.647169589996338 next 1.0006515979766846
SIZE COMPARISON prev 14.081094741821289 mid 7.897726058959961 next 1.0006515979766846
SIZE COMPARISON prev 25.32900619506836 mid 14.263557434082031 next 1.0006515979766846
SIZE COMPARISON prev 47.37249755859375 mid 25.50238800048828 next 1.0006515979766846
SIZE COMPARISON prev 87.77811431884766 mid 47.53529357910156 next 1.0006515979766846
SIZE COMPARISON prev 166.31536865234375 mid 87.93147277832031 next 1.0006515979766846
SIZE COMPARISON prev 313.864501953125 mid 166.46051025390625 next 1.0006515979766846
...
@ 3499 train 4.1087 , allloss: 50.1864, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9982e-04, norm:7.7492, dt: 1606.40ms, tok/sec: 81593.42, flops:35.93, batch-reuse:1
SIZE COMPARISON prev 0.9271938800811768 mid 0.6703000068664551 next 1.000638723373413
SIZE COMPARISON prev 1.5912479162216187 mid 1.1378358602523804 next 1.000647783279419
SIZE COMPARISON prev 2.6288633346557617 mid 1.8008546829223633 next 1.0006506443023682
SIZE COMPARISON prev 7.2939934730529785 mid 2.786419630050659 next 1.0006515979766846
SIZE COMPARISON prev 53.24920654296875 mid 7.418212890625 next 1.0006517171859741
SIZE COMPARISON prev 499.09014892578125 mid 53.380401611328125 next 1.0006515979766846
SIZE COMPARISON prev 4886.146484375 mid 499.2305908203125 next 1.0006515979766846
SIZE COMPARISON prev 48644.1953125 mid 4886.2861328125 next 1.0006517171859741
SIZE COMPARISON prev 489250.40625 mid 48644.3359375 next 1.0006515979766846
SIZE COMPARISON prev 4965097.5 mid 489250.53125 next 1.0006515979766846
SIZE COMPARISON prev 50753656.0 mid 4965097.5 next 1.0006515979766846
SIZE COMPARISON prev 521954624.0 mid 50753656.0 next 1.0006515979766846

Feeding in residual directly into mlp without attn first

Experiment: What about

y = x + attn(ln(x))
x = y + mlp(ln(x))

i.e. how important is it that the output of attention gets fed into the MLP? It turns out this works surprisingly well; attention and the MLP seem to act as additive, independent components (a minimal sketch follows the log below):

loss plot

@ 4099 train 4.1100 , allloss: 50.7263, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9975e-04, norm:4.9580, dt: 1626.09ms, tok/sec: 80605.68, flops:35.50, batch-reuse:1
SIZE COMPARISON prev 1.7276793718338013 mid 0.828337550163269 next 1.0006475448608398
SIZE COMPARISON prev 2.3453421592712402 mid 1.9509013891220093 next 1.0006494522094727
SIZE COMPARISON prev 3.0417964458465576 mid 2.5768418312072754 next 1.0006505250930786
SIZE COMPARISON prev 3.7537474632263184 mid 3.24953031539917 next 1.0006508827209473
SIZE COMPARISON prev 4.469127655029297 mid 3.953681230545044 next 1.0006511211395264
SIZE COMPARISON prev 5.175540447235107 mid 4.6598124504089355 next 1.0006513595581055
SIZE COMPARISON prev 5.872206687927246 mid 5.358777046203613 next 1.0006513595581055
SIZE COMPARISON prev 6.5586748123168945 mid 6.048742294311523 next 1.000651478767395
SIZE COMPARISON prev 7.23720645904541 mid 6.730024337768555 next 1.0006515979766846
SIZE COMPARISON prev 7.907989501953125 mid 7.40459680557251 next 1.0006515979766846
SIZE COMPARISON prev 8.572383880615234 mid 8.07198715209961 next 1.0006515979766846
SIZE COMPARISON prev 9.231319427490234 mid 8.7335205078125 next 1.0006515979766846
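For concreteness, a minimal sketch of this block (attn and mlp stand for the usual sub-modules and are assumed to return plain tensors):

```python
def parallel_block_forward(x, ln, attn, mlp):
    """Sketch of the variant above: y = x + attn(ln(x)); x = y + mlp(ln(x)).
    Since both branches read the same normalized residual, the update reduces
    to x + attn(ln(x)) + mlp(ln(x)). A single ln is used here for brevity;
    the original may have applied separate norms to the same x."""
    y = ln(x)
    return x + attn(y) + mlp(y)
```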

Question: What happens if we remove the "all layer loss", and compute loss as per the usual method?

loss plot

It is, in fact, better; the all layer loss is utterly useless.

Note

Retroactive Note 2/25. Finally, hopefully these old experiments stop using this "all_layer_loss", which had the additional property of 3x-ing the training time.

But note that the standard deviation of the signal does not blow up as much, weirdly:

@ 4399 train 3.7913 , allloss: 3.7913, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9971e-04, norm:0.6440, dt: 411.78ms, tok/sec: 318308.01, flops:140.18, batch-reuse:1
SIZE COMPARISON prev 2.829397439956665 mid 2.238248825073242 next 1.0006511211395264
SIZE COMPARISON prev 2.9154844284057617 mid 2.7668395042419434 next 1.0006510019302368
SIZE COMPARISON prev 2.937965154647827 mid 2.9097206592559814 next 1.0006510019302368
SIZE COMPARISON prev 2.847860336303711 mid 2.835615634918213 next 1.0006510019302368
SIZE COMPARISON prev 2.772952079772949 mid 2.7625179290771484 next 1.0006508827209473
SIZE COMPARISON prev 2.7469887733459473 mid 2.7138304710388184 next 1.0006508827209473
SIZE COMPARISON prev 2.775296211242676 mid 2.707706928253174 next 1.0006508827209473
SIZE COMPARISON prev 2.8555359840393066 mid 2.748220682144165 next 1.0006508827209473
SIZE COMPARISON prev 2.987206220626831 mid 2.838392496109009 next 1.0006510019302368
SIZE COMPARISON prev 3.1633992195129395 mid 2.9731945991516113 next 1.0006511211395264
SIZE COMPARISON prev 3.425197124481201 mid 3.1695663928985596 next 1.000651240348816
SIZE COMPARISON prev 3.9492619037628174 mid 3.533745765686035 next 1.0006513595581055

Weighting the attn(x) component more

What happens if we 2x the attention component, i.e. x = x + 2*attn(ln(x)) + mlp(ln(x))?
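Concretely, a minimal sketch of what is being varied here (sub-modules passed in; attn_scale is the only change relative to the block above):

```python
def scaled_block_forward(x, ln, attn, mlp, attn_scale=2.0):
    """Sketch of x = x + attn_scale*attn(ln(x)) + mlp(ln(x)); attn_scale is
    an illustrative knob (2.0 for this experiment, 0.5 for the one below)."""
    y = ln(x)
    return x + attn_scale * attn(y) + mlp(y)
```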

loss plot

@ 2449 train 4.2471 , allloss: 4.2471, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9991e-04, norm:0.7454, dt: 411.36ms, tok/sec: 318631.23, flops:140.32, batch-reuse:1
SIZE COMPARISON prev 3.474045753479004 mid 2.998915672302246 next 1.0006513595581055
SIZE COMPARISON prev 3.860865592956543 mid 3.8047633171081543 next 1.0006513595581055
SIZE COMPARISON prev 3.176105499267578 mid 3.17417573928833 next 1.0006511211395264
SIZE COMPARISON prev 2.7696032524108887 mid 2.7842602729797363 next 1.0006510019302368
SIZE COMPARISON prev 2.601201057434082 mid 2.6080410480499268 next 1.0006508827209473
SIZE COMPARISON prev 2.55460524559021 mid 2.5488529205322266 next 1.0006508827209473
SIZE COMPARISON prev 2.535594940185547 mid 2.500566244125366 next 1.0006508827209473
SIZE COMPARISON prev 2.5360474586486816 mid 2.458188056945801 next 1.0006508827209473
SIZE COMPARISON prev 2.598814010620117 mid 2.465818405151367 next 1.0006508827209473
SIZE COMPARISON prev 2.779794216156006 mid 2.578892707824707 next 1.0006510019302368
SIZE COMPARISON prev 3.112057685852051 mid 2.8377723693847656 next 1.0006511211395264
SIZE COMPARISON prev 3.612070083618164 mid 3.259913444519043 next 1.0006513595581055

It is worse... why? But it eventually converges, so it doesn't really matter. What if we 0.5x it?

loss plot

@ 1999 train 4.1020 , allloss: 4.1020, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9994e-04, norm:0.6649, dt: 422.70ms, tok/sec: 310084.43, flops:136.56, batch-reuse:1
val: loaded 100000000 tokens (first shard)
val: 1 epoch (1 shard) = 12207 mini-batches
validation loss: 4.2778
SIZE COMPARISON prev 1.7215176820755005 mid 1.1063441038131714 next 1.0006499290466309
SIZE COMPARISON prev 2.114750385284424 mid 1.966301441192627 next 1.0006506443023682
SIZE COMPARISON prev 1.981388807296753 mid 2.1172070503234863 next 1.0006502866744995
SIZE COMPARISON prev 1.8269994258880615 mid 1.9413310289382935 next 1.00065016746521
SIZE COMPARISON prev 1.6911734342575073 mid 1.7758653163909912 next 1.0006499290466309
SIZE COMPARISON prev 1.6034241914749146 mid 1.6564953327178955 next 1.0006498098373413
SIZE COMPARISON prev 1.5692648887634277 mid 1.589362621307373 next 1.0006496906280518
SIZE COMPARISON prev 1.5858174562454224 mid 1.5691030025482178 next 1.0006496906280518
SIZE COMPARISON prev 1.6530719995498657 mid 1.5916287899017334 next 1.0006498098373413
SIZE COMPARISON prev 1.7865718603134155 mid 1.6776154041290283 next 1.0006500482559204
SIZE COMPARISON prev 2.007704973220825 mid 1.845852017402649 next 1.000650405883789
SIZE COMPARISON prev 2.3220303058624268 mid 2.0987911224365234 next 1.0006507635116577

It is identical... So it seems that attention is not particularly useful for early gains; it's really the MLP that matters. If we skip out on attention completely, i.e. x = x + mlp(ln(x)):

loss plot

@ 1699 train 5.8143 , allloss: 5.8143, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9996e-04, norm:0.2052, dt: 272.28ms, tok/sec: 481386.85, flops:212.00, batch-reuse:1
SIZE COMPARISON prev 3.032069206237793 mid 0.034484002739191055 next 1.0006475448608398
SIZE COMPARISON prev 4.604878902435303 mid 3.032069206237793 next 1.0006513595581055
SIZE COMPARISON prev 4.4857869148254395 mid 4.604878902435303 next 1.0006513595581055
SIZE COMPARISON prev 4.478532791137695 mid 4.4857869148254395 next 1.0006513595581055
SIZE COMPARISON prev 4.600770473480225 mid 4.478532791137695 next 1.000651478767395
SIZE COMPARISON prev 4.848290920257568 mid 4.600770473480225 next 1.0006513595581055
SIZE COMPARISON prev 5.205094337463379 mid 4.848290920257568 next 1.0006513595581055
SIZE COMPARISON prev 5.654345989227295 mid 5.205094337463379 next 1.0006515979766846
SIZE COMPARISON prev 6.184695720672607 mid 5.654345989227295 next 1.0006513595581055
SIZE COMPARISON prev 6.788667678833008 mid 6.184695720672607 next 1.000651478767395
SIZE COMPARISON prev 7.458611965179443 mid 6.788667678833008 next 1.0006515979766846
SIZE COMPARISON prev 8.18480110168457 mid 7.458611965179443 next 1.0006517171859741

It is worse! Whew. Our efforts are validated. But note that the residual still grows... Here, prev is the size of newres (i.e. the output), x is the size of ln(prev) (i.e. the LayerNorm of the output), and mid is the size of prevres (as opposed to prevres + attn), i.e. mid = x, the input. So the MLP is definitely adding a component in terms of magnitude growth (perhaps over and over again).

Note

Retroactive Note 2/25. Apologies for the strange names, they are artifacts of earlier experiments, when I decided to compute a final LayerNorm within the block, for some reason.

Back to multiplying attn(x)*mlp(x)

Question: What happens (for fun) if we do x = x + self.attn(ln(x))*self.mlp(ln(x)):

loss plot

@ 6349 train 3.8610 , allloss: 3.8610, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9941e-04, norm:0.7167, dt: 409.91ms, tok/sec: 319757.28, flops:140.82, batch-reuse:1
SIZE COMPARISON prev 4.092894554138184 mid 4.092894554138184 next 1.0006513595581055
SIZE COMPARISON prev 4.809565544128418 mid 4.809565544128418 next 1.0006513595581055
SIZE COMPARISON prev 4.6843671798706055 mid 4.6843671798706055 next 1.0006513595581055
SIZE COMPARISON prev 4.60699462890625 mid 4.60699462890625 next 1.0006513595581055
SIZE COMPARISON prev 4.567028999328613 mid 4.567028999328613 next 1.0006513595581055
SIZE COMPARISON prev 4.567410945892334 mid 4.567410945892334 next 1.0006513595581055
SIZE COMPARISON prev 4.6050896644592285 mid 4.6050896644592285 next 1.0006513595581055
SIZE COMPARISON prev 4.691596508026123 mid 4.691596508026123 next 1.000651478767395
SIZE COMPARISON prev 4.89960241317749 mid 4.89960241317749 next 1.0006513595581055
SIZE COMPARISON prev 5.414493083953857 mid 5.414493083953857 next 1.0006515979766846
SIZE COMPARISON prev 6.540674686431885 mid 6.540674686431885 next 1.0006515979766846
SIZE COMPARISON prev 8.01706314086914 mid 8.01706314086914 next 1.0006517171859741
rank 0 sample 0: A Poem for you! Roses are red, Potatoes are 

It seems to be the best yet.

Note

Retroactive Note 2/25. Important caveat: these experiments are still in the "reusing weights" regime, i.e., the entire block is reused in every layer. This architecture is somewhat worse when we don't reuse weights. For why I was reusing weights here, see the earlier comment on combinator calculi. At some point we do stop reusing weights, thankfully, because it's unclear how well results in this regime generalize.

Other designs

Question: What happens if we only feed in attn(x) + mlp(x) into the attn and mlp components of the next layer, and not the res? Namely, y = attn(ln(x)) + mlp(ln(x)), x = x + attn(prev_y) + mlp(prev_y). (This one crashes).

On the Value Matrix

Hypothesis: I suspect the value matrix isn't necessary, and is perhaps just an optimization (it doesn't fit into my interpretability intuition). What happens if we remove the value matrix, instead just use the identity matrix, and then sum together the outputs of all of the heads?

It is almost as good, but definitely not as good (I wonder if this is because we also got rid of that one projection matrix):

loss plot

@ 4699 train 3.9265 , allloss: 3.9265, confloss: 0.0000, targetloss: 0.0000, earlystop: 0.000, earlystopdict: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], lr:5.9967e-04, norm:0.6336, dt: 1324.04ms, tok/sec: 98993.98, flops:42.88, batch-reuse:1
INFO nextres 33.397308349609375 attn*mlp 33.40809631347656 layernormed 1.0006517171859741
	 attn_hist -29.4375<tensor([ 45., 134., 166., 263., 119.,  41.])>29.25 mlp_hist -10.4375<tensor([ 24., 123., 402., 207.,  10.,   2.], dtype=torch.bfloat16)>15.0625
INFO nextres 193.25694274902344 attn*mlp 195.94659423828125 layernormed 1.0006517171859741
	 attn_hist -75.375<tensor([  1.,  11., 467., 249.,  38.,   2.])>72.75 mlp_hist -58.5<tensor([ 74., 122., 255., 280.,  33.,   5.], dtype=torch.bfloat16)>80.0
INFO nextres 120.01077270507812 attn*mlp 82.8893814086914 layernormed 1.0006517171859741
	 attn_hist -52.5<tensor([ 14., 545., 197.,  10.,   1.,   1.])>116.25 mlp_hist -18.125<tensor([  7.,   3., 272., 278., 190.,  17.], dtype=torch.bfloat16)>6.28125
INFO nextres 71.63121032714844 attn*mlp 55.632476806640625 layernormed 1.0006517171859741
	 attn_hist -45.75<tensor([ 14., 526., 212.,  12.,   2.,   2.])>102.75 mlp_hist -13.0625<tensor([  7.,  24., 362., 226., 137.,  12.], dtype=torch.bfloat16)>5.21875
INFO nextres 44.969993591308594 attn*mlp 30.2081298828125 layernormed 1.0006517171859741
	 attn_hist -49.875<tensor([  8., 436., 308.,  12.,   2.,   2.])>106.5 mlp_hist -7.6875<tensor([ 11., 276., 183., 214.,  74.,  10.], dtype=torch.bfloat16)>3.8125
INFO nextres 29.09838104248047 attn*mlp 18.429052352905273 layernormed 1.0006517171859741
	 attn_hist -45.375<tensor([ 27., 598., 130.,  10.,   1.,   2.])>109.5 mlp_hist -4.8125<tensor([155., 192., 158., 181.,  71.,  11.], dtype=torch.bfloat16)>2.78125
INFO nextres 19.71997833251953 attn*mlp 12.138340950012207 layernormed 1.0006517171859741
	 attn_hist -35.8125<tensor([132., 586.,  42.,   6.,   1.,   1.])>112.5 mlp_hist -3.65625<tensor([210., 152., 167., 177.,  56.,   6.], dtype=torch.bfloat16)>2.609375
INFO nextres 15.027619361877441 attn*mlp 7.450224876403809 layernormed 1.0006517171859741
	 attn_hist -34.5<tensor([121., 542.,  94.,   9.,   1.,   1.])>99.0 mlp_hist -3.25<tensor([166., 180., 143., 212.,  57.,  10.], dtype=torch.bfloat16)>2.09375
INFO nextres 13.641393661499023 attn*mlp 5.161359786987305 layernormed 1.0006517171859741
	 attn_hist -36.75<tensor([ 47., 202., 477.,  38.,   2.,   2.])>73.875 mlp_hist -3.1875<tensor([ 13., 253., 178., 238.,  70.,  16.], dtype=torch.bfloat16)>1.609375
INFO nextres 15.402214050292969 attn*mlp 6.81701135635376 layernormed 1.0006517171859741
	 attn_hist -36.5625<tensor([ 28., 136., 418., 123.,  58.,   5.])>44.25 mlp_hist -4.125<tensor([  4.,  26.,  81., 354., 270.,  33.], dtype=torch.bfloat16)>1.5703125
INFO nextres 19.019399642944336 attn*mlp 8.641536712646484 layernormed 1.0006517171859741
	 attn_hist -30.0<tensor([ 65., 130., 418.,  95.,  55.,   5.])>43.5 mlp_hist -5.03125<tensor([ 24., 156., 308.,  90., 113.,  78.], dtype=torch.bfloat16)>5.65625
INFO nextres 20.741458892822266 attn*mlp 7.100666046142578 layernormed 1.0006517171859741
	 attn_hist -34.3125<tensor([ 21., 186., 411., 108.,  38.,   4.])>50.625 mlp_hist -5.34375<tensor([ 23., 328., 159.,  73.,  66., 120.], dtype=torch.bfloat16)>8.1875

Note

Retroactive Note 2/25. This experiment seems worth re-running. I recall concluding that the value matrix is unnecessary, though it does speed up the forward pass at no cost to the perplexity.

Some sketchy thoughts on the combinator calculus idea

Recall that in combinator calculi, we need to be able to (1) copy arguments (inverting this operation is kind of the point of compression / learning a computation) and (2) apply arguments to each other (programmability). When inverting, I suspect it is just pattern match replacement (is this the MLP?). Also, we observe that magnitude is somehow important for the encoding. Is it true that Attention combines tokens into single embeddings? Is this some manifestation of the programmability notion? There is this other thought I had, where I thought learning should be about learning to invert a computation graph, but here, structurally, everything seems arranged along the forward direction, and not about learning inverses. (For a long time, I thought that perhaps the MLP is learning "inverse mappings" of outputs to their inputs, under some ground truth function f. Perhaps this is still true, or perhaps computation backwards can still be described as computation forwards, or, gradient descent does everything for us.) Attention takes a weighted sum of prior tokens.

Somehow, addition feels like an application / one-step evaluation / perhaps it refers to the depth of the tree, and each embedding dimension is like a possible subtree at each level. But more likely, perhaps attention gives the tree structure.

Note

Retroactive Note 2/25. I'm not really sure what I was saying above, but I do think that attention specifies how tokens are arranged relative to each other, much like expressions arranged in a tree. And the MLP specifies the outcome of applying a token x to adjacent siblings attn(x).

Somehow, a sequence of embeddings represents code. The attention component learns how individual tokens (read subtrees) programmatically act on other tokens (other subtrees). (But how do we interpret the value matrix?) Note that addition does not distinguish between left and right subtrees, or have an order, so how come it works?

In any case, here's a general hand-wavy framework:

  • Backprop performs memorization of substitution rules.
  • regularization through LN limits "how much" we can memorize (limiting standard deviation).
  • forward pass performs the actual computation.

Let's try again:

  • the embedding itself encodes things like its position in the tree... and also its subtree... which is itself a sum of embeddings... (how is this possible? We only give it a 1D positional embedding.)
  • The MLP maps nodes (i.e. subtrees) to compressed inverses.
  • Attention then joins more nodes together as appropriate. (Why not join all nodes together? Well, maybe MLP cannot distinguish such a big sum? It shouldn't actually matter..., MLP should be able to distinguish. So why?) (If we sum everything together at every T, then the MLP triggers for every location T.) So somehow, we want the MLP to trigger selectively and in the right order.

Lots of Questions: How come the MLP can't pick out specific patterns by itself? Why does it need the attention... What does element-wise multiplication mean, between two embeddings? (i.e. why does attn(x)*mlp(x) work.) A masked embedding? Attention is the mask (does it generate only boolean outputs)? MLP is the mask? Why do we need to mask the output of the mlp? What does the magnitude of an embedding mean?

Some Potential Answers/Ideas: An embedding does not see its sibling trees. Attention computes (given the embeddings) which other nodes each one acts on / is connected to, outputting a parent node (the sum of the child nodes). The MLP layer takes as input an embedding (subtree) and inverts it (ideally to something more compressed, or closer to the output). But the result is added on. Why is it added on? It should instead replace the embedding (but maybe not the whole thing, if only a subtree was inverted?)

Other Ideas: Try to do termination when it no longer updates / when it converges.

Getting rid of All Layer Loss

Note that tacking on the "all layer loss" generally seems to make things slightly worse, and it also takes forever to train. So there isn't any reason to do it if we have skip connections. The magnitude of the MLP output, however, is smaller under an all layer loss. A comparison of (true) loss:

testname = "10-resmlp-single-axm-novalue-copy-alllayer"
basename = "10-resmlp-single-axm-novalue"

loss plot

RMSNorm

Now, going back to single layer loss, using the axm architecture, what does it look like with RMSNorm? Answer: it is pretty similar. (Graph is lost.)

On the inputs to attention and MLP

Note

Retrospective Note 2/25. This series of experiments starts digging into modifications to how we use the MLP and attention components.

Hypothesis: Maybe the MLP performs a copy of the embedding in a memorized location, i.e. it rotates it, or adds a positional embedding (and maybe this is why the residual is important? To make it easy to copy?) Let's try x = x + mlp(x + ln(attn(ln(x)))).

(The outcome is lost).

Question: What about a multiplicative position embedding?

Position embeddings are definitely additive. If we switch from x = tok_emb + pos_emb to x = tok_emb * pos_emb, the capability of the model to compute is completely destroyed, probably because the input signal is also being randomized:

loss plot
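For reference, a self-contained sketch of the two ways of combining token and position embeddings (dimensions illustrative, not the exact repo code):

```python
import torch
import torch.nn as nn

vocab_size, block_size, n_embd = 50257, 1024, 768
wte = nn.Embedding(vocab_size, n_embd)   # token embedding table
wpe = nn.Embedding(block_size, n_embd)   # learned position embedding table

idx = torch.randint(0, vocab_size, (4, 128))   # (B, T) dummy token ids
pos = torch.arange(idx.size(1))                # (T,)

x_add = wte(idx) + wpe(pos)   # standard additive combination
x_mul = wte(idx) * wpe(pos)   # the multiplicative variant tried here; training collapses
```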

Question: What if we feed attn*x into the mlp?

If we do y = self.attn(ln(x)), x=x+self.mlp(ln(x)*y)*y, it is worse, but not terrible --- it probably also converges:

loss plot

Adjusting the input again, y = self.attn(ln(x)), x=x+self.mlp(ln(x)+y)*y:

loss plot

In conclusion, we probably want to feed just x into the mlp, without disturbing that signal at all. (Why?)

Note

Retrospective Note 2/25. The line of thought here is a little piecemeal, in part because this journal is partial and I didn't write down the outcome of many experiments. Also, I wish I had run some of these for longer; it's hard to tell after only a few hundred steps.

Hypothesis: Perhaps (this is far-fetched) the residual is only important because of the initial positional embedding.

Let's see what happens if we don't feed in the residual to future layers, but we do add the positional embedding back in at every layer. It's not good! The residual is clearly important for reasons other than the initial positional embedding:

loss plot

Hypothesis: Perhaps the network still works if we omit the residual when giving input to the attention layer. The goal is to figure out where the residual is important. Namely:

What if we feed in the previous y to attn:

attn = self.attn(y, y)
mlp = self.mlp(ln(x))
y = attn*mlp
x = x + y

This crashes because values quickly go to infinity or zero.

Hypothesis: Perhaps the residual is not important as an output of the attention layer; instead of using the value matrix (or an identity matrix), what if we give attention the output of the mlp instead?

Namely, let attn = KQ(x) @ mlp(x). Well, it turns out worse. So somehow the fact that we are doing a linear combination of x and not mlp(x) is important. But note, mlp is not introduced anywhere else! (We took it away from the attn*mlp term.) And it still does seem to be able to learn, at least to a degree. I.e.

mlp = self.mlp(ln(x))
attn = self.attn(ln(x), mlp)
y = attn
x = x + y

loss plot

Note

Retrospective Note 2/25. In fact, I recall that, to some degree, omitting the MLP entirely and only having attention still yields a network that can "learn", but it learns much worse.

At this point, I'm trying almost arbitrary inputs to each component, trying to see what sticks. I believe all of these are still in the regime of reusing weights, unfortunately.

Question: What if we do y = KQ(mlp(x)) @ x? Then, as before, it performs worse, but it is not fatal:

mlp = self.mlp(ln(x))
attn = self.attn(mlp,ln(x))
y = attn
x = x + y

loss plot

Hypothesis: The above experiments can probably be explained by getting rid of mlp entirely; that is, attn doesn't care at all about getting the output of mlp as input.

Indeed, both of the above experiments look similar to this one! So the non-linearity of the mlp is certainly essential, but its benefit does not show up immediately. It's not conclusive, but I'll take it for now.

attn = self.attn(ln(x), ln(x))
y = attn
x = x + y

loss plot

Note

Retrospective Note 2/25. Note that the syntax self.attn(x, y) refers to using KQ(x) @ y in place of using the value matrix. It is motivated by observing that the perplexity of KQ(x) @ x is similar to that of the usual KQ @ V attention. (Of course, this takes longer to train.) Regarding multi-head attention, no output matrix is used when the second argument of self.attn is specified; we just sum over all n_head contributions. A sketch follows below.
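A hypothetical reconstruction of that two-argument attention (whether the real module sliced the second argument per head isn't recorded; here each head applies its own causal attention weights to the full second argument, and head outputs are summed):

```python
import math
import torch
import torch.nn as nn

class TwoArgAttention(nn.Module):
    """Sketch of self.attn(kq_input, v_input): keys/queries come from the
    first argument, the second argument plays the role of the values directly
    (no value matrix), and head outputs are summed rather than concatenated
    and projected (no output matrix). Illustrative, not the exact module."""
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.q = nn.Linear(n_embd, n_embd, bias=False)
        self.k = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, kq_input, v_input):
        B, T, C = kq_input.shape
        q = self.q(kq_input).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k(kq_input).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)   # (B, nh, T, T)
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=kq_input.device))
        att = att.masked_fill(~mask, float("-inf")).softmax(dim=-1)  # causal weights
        y = att @ v_input.unsqueeze(1)   # (B, nh, T, T) @ (B, 1, T, C) -> (B, nh, T, C)
        return y.sum(dim=1)              # sum head contributions; no output projection
```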

Question: For good measure, what if we run attn on mlp only? This is much worse than no-mlp, and also much worse than self.attn(mlp, x). Of all of these experiments (without including the mlp component elsewhere), self.attn(mlp, x) is for some reason the best (and self.attn(mlp, mlp) is the worst).

mlp = self.mlp(RMSNorm(x))
attn = self.attn(mlp, mlp)
y = attn
x = x + y

loss plot

Question: Now, adding the mlp back into its usual position, perhaps it doesn't matter what I feed into attention; what if I give it mlp for computing k and q (and x itself for the values)? Well, it seems to start off pretty reasonably; it is still worse than our baseline, but certainly not fatal:

mlp = self.mlp(rmsN(x))
attn = self.attn(mlp, rmsN(x))
y = attn*mlp
x = x + y

loss plot

Note

Retrospective Note 2/25. The baseline here (and for most of these early experiments) is probably (a) reusing weights, and (b) x = x + attn(ln(x))*mlp(ln(x)). We only stop using (b) as part of the baseline when we stop re-using weights, because in the reusing weights regime (at this size), it is really quite good.

Directly comparing the previous experiment (now blue) to the vanilla experiment (orange, described below):

mlp = self.mlp(RMS(x))
attn = self.attn(RMS(x), RMS(x))
y = attn*mlp
x = x + y

loss plot

So, yeah. I'm stumped. The best still seems to be x = x + attn(ln(x))*mlp(ln(x)). Nothing seems to be able to improve on that.

Question: What happens again when the mlp only gets the output of attn as input? I recall that the mlp really wants to see x as input, so this should behave poorly, if I remember correctly.

Oh hey, it's surprisingly not so bad, but it is definitely worse:

attn = self.attn(RMS(x), RMS(x))
mlp = self.mlp(attn)
y = attn*mlp
x = x + y

loss plot

Hypothesis: Perhaps the fact that we are multiplying attn*mlp is helping obfuscate this whole relationship; maybe when we do attn+mlp, mlp will more desperately want x as input.

Indeed, when doing x=x+attn(RMS(x))+mlp(attn), the outcome is worse:

attn = self.attn(RMS(x), RMS(x))
mlp = self.mlp(attn)
y = attn+mlp
x = x + y

loss plot

For Completeness: For good measure, let's run the vanilla GPT with RMSNorm and no value matrix:

attn = self.attn(RMS(x), RMS(x))
mlp = self.mlp(RMS(x))
y = attn+mlp
x = res + y

loss plot

Note

Retrospective Note 2/25. How to interpret this series of experiments (previous and subsequent) still eludes me. It seems very... out there.

In general, it seems important to feed the whole x into the mlp and attention layers. In some sense, x represents an entire embedding, whereas attn and mlp perhaps generate only part of it.

Hypothesis: Now, one last time, why is the residual so important? If I omit the residual entirely, it should be bad.

Indeed it is:

attn = self.attn(RMS(x), RMS(x))
mlp = self.mlp(RMS(x))
y = attn*mlp
x = y

loss plot

Out of curiosity: It doesn't work particularly well, but what happens if I turn "All Layer Loss" back on, still leaving out the residual?

Surprisingly, removing the residual is fine if we do attn*mlp, as long as we compute the "all layer loss". It does seem to converge much more slowly: (Also, if we do attn + mlp, what happens? Todo.)

attn = self.attn(RMS(x), RMS(x))
mlp = self.mlp(RMS(x))
y = attn*mlp
x = y

loss plot

Attempt at a "how it should be" architecture from current intuition

Note

Retrospective Note 2/25. My interpretation of "how GPT should be" has certainly evolved a lot since these experiments. But it is fun to see what I thought when I first started this project.

I do think that from current principles, it should be as follows (with the residual, because the mlp is allowed to be a "no-op"). But, concretely, this doesn't work as well for some reason (maybe it converges over time?):

attn = self.attn(x, x)
mlp = self.mlp(attn)
y = mlp
x = y + res
newres = x
x = RMSNorm(x, ELEMENTWISEAFFINE={ELEMENTWISEAFFINE}), 

loss plot

Or, maybe even y*res? Nope, the loss here is stuck and doesn't change, so there is some numerical problem.

attn = self.attn(x, x)
mlp = self.mlp(attn)
y = mlp
x = y * res
newres = x
x = RMSNorm(x, ELEMENTWISEAFFINE={ELEMENTWISEAFFINE}), 

loss plot

Let's try this one again but train it longer. It does eventually converge to the same place, I think!

attn = self.attn(x, x)
mlp = self.mlp(attn)
y = mlp
x = y + res
newres = x
x = RMSNorm(x, ELEMENTWISEAFFINE={ELEMENTWISEAFFINE}), 

loss plot

Generally, we need a way for mlp to express a no-op. Also, one question is why taking the attention out actually makes it train faster.

Early Experiments

Lost to time. These included a lot of basic explorations of removing the residual, mlp(x)*attn(x), etc.

Appendix

Brainstorming

Data

It would be nice to somehow emphasize "High Importance" datapoints; i.e. predicting the next "and" or "the" or "or" is far less important/impactful than predicting the next "Yes" or mathematical formula. Despite the difference in impact, the loss penalizes both of them the same way. Many errors in the former category come from fundamental entropy of the source text, whereas errors in the latter category are true errors.
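As a hypothetical sketch of what such a weighting could look like (nothing like this is implemented here, and the weighting scheme itself is the open question):

```python
import torch
import torch.nn.functional as F

def weighted_next_token_loss(logits, targets, token_weights):
    """Per-token importance weighting of the LM loss (hypothetical).
    logits: (B, T, V); targets: (B, T); token_weights: (V,), larger for
    "high importance" targets (e.g. Yes/No, numbers) and smaller for
    connectives like "and"/"the"."""
    per_token = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none")
    w = token_weights[targets.view(-1)]
    return (w * per_token).sum() / w.sum()
```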

I am also curious what happens in an architecture that generates, say, a chunk (paragraph) of text all at the same time; i.e. say we tack on 128 extra positions with dummy inputs, and moreover say that these positions do not have the future masked out, and now their outputs are the output.

On scaling

What happens if I feed prior embeddings as input into subsequent calls to the language model (instead of this autoregressive structure)?

On Model Evals

Sometimes I use Deepseek R1 to answer questions about math and can't help but think that it is better than e.g. OpenAI's current offerings, or 3.7 thinking mode. Usually when I ask questions I ask a few models at the same time. Yet, the benchmarks appear to tell a different story. How in the world should we measure this stuff?

One broad idea is to train "teachable models" that 1) learn by querying the dataset, and 2) are intentionally limited in what they know. For example, it'd be nice to have a dataset comprising the knowledge known by a third grader, and then train a model on that. Now, to evaluate a more complex model, we can see how quickly or how well it can "teach" the smaller model to obtain good performance. The approach overall seems non-trivial, but at least from first principles, even part 1) seems fundamental to learning in the most natural sense of the word; learning should be an active, adaptive process, not a passive one like pre-training today.