Grokking

Transformer Architecture

Model Simplifications

Simplifications in Attention

From the plot, it is clear that the attention from position 2 to itself is negligible. And, since the attention map is the output of a softmax, the attention from position 2 to position 0 and from position 2 to position 1 appear complementary, i.e.,

\begin{equation} (A_{2\to0}^{h}, A_{2\to1}^{h}, A_{2\to2}^{h}) \approx (A_{2\to0}^{h}, 1-A_{2\to0}^{h}, 0) \end{equation}

where \(A_{x_1 \to x_2}^{h}\) is the attention from position \(x_1\) to position \(x_2\) at attention head \(h\).
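As a rough numerical check, one can inspect cached attention patterns and confirm that the position 2 → 2 weights are near zero while the other two sum to one. The file name and the `[batch, head, query_pos, key_pos]` layout below are assumptions, not the actual artifacts of this notebook.

```python
import numpy as np

# Hypothetical cached attention patterns, shape [batch, head, query_pos, key_pos];
# position 2 is the '=' token that reads from the two input tokens.
attn = np.load("attention_patterns.npy")  # assumed file name, stand-in for the real cache

attn_from_eq = attn[:, :, 2, :]  # attention paid from position 2 to every position
print("mean A_{2->2}:", attn_from_eq[..., 2].mean())  # expected to be ~0
print("mean A_{2->0} + A_{2->1}:",
      (attn_from_eq[..., 0] + attn_from_eq[..., 1]).mean())  # expected to be ~1
```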

Role of skip connection around MLP

Key Equations

Derivation

Effective Weights

Using the derivation above and the model simplifications, we see that there are 3 effective weight matrices that fully determine the model's behaviour (a code sketch follows the list):

  • \begin{equation} W_{attn}^h = u^T W_Q^{hT} W_K^h W_E \end{equation}
  • \begin{equation} W_{neur}^h = W_{in} W_O^h W_V^h W_E \end{equation}
  • \begin{equation} W_{logit} = W_U W_{out} \end{equation}
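A minimal sketch of how these effective matrices could be assembled from raw model weights. The shapes and orientations noted in the comments, and taking \(u\) to be the residual-stream query vector at the final ('=') position, are assumptions; the real parameter names and layouts may differ.

```python
import numpy as np

# Assumed raw weights for one head h (shapes are illustrative):
#   W_E  : [d_model, d_vocab]       embedding
#   W_Q, W_K, W_V : [d_head, d_model] per-head projections
#   W_O  : [d_model, d_head]        per-head output projection
#   W_in : [d_mlp, d_model], W_out : [d_model, d_mlp], W_U : [d_vocab, d_model]
#   u    : [d_model]                residual-stream query vector at position 2
def effective_weights(W_E, W_Q, W_K, W_V, W_O, W_in, W_out, W_U, u):
    W_attn = u @ W_Q.T @ W_K @ W_E    # [d_vocab]       attention logit per input token
    W_neur = W_in @ W_O @ W_V @ W_E   # [d_mlp, d_vocab] token -> pre-ReLU neuron input
    W_logit = W_U @ W_out             # [d_vocab, d_mlp] neuron -> output logit
    return W_attn, W_neur, W_logit
```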

And the outputs of these effective weights are:

  • Attention pattern: \(A^h_0=\sigma(W_{attn}^h (t_0 - t_1))\) and \(A^h_1 = 1 - A^h_0\), where \(\sigma\) is the sigmoid (the two-way softmax applied to the logit difference)
  • Neuron activations: \(N=\mathrm{ReLU}\left(\sum_{h=0}^{3} (A^h_0 W_{neur}^h t_0 + A^h_1 W_{neur}^h t_1)\right) \)
  • Output logits: \(L=W_{logit}N + \mathrm{bias} \)

Here \(t_0\) and \(t_1\) are the one-hot encodings of the input tokens \(x\) and \(y\), respectively; this simplified forward pass is sketched below.
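Under the same shape assumptions as the previous sketch, the whole simplified forward pass fits in a few lines; the sigmoid implements the two-way softmax from the attention simplification above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def simplified_forward(heads, W_logit, bias, t0, t1):
    """heads: list of (W_attn_h, W_neur_h) pairs, one per attention head.
    t0, t1: one-hot encodings of the input tokens x and y."""
    pre_acts = 0.0
    for W_attn_h, W_neur_h in heads:
        a0 = sigmoid(W_attn_h @ (t0 - t1))  # attention paid to token x
        a1 = 1.0 - a0                       # attention paid to token y
        pre_acts = pre_acts + a0 * (W_neur_h @ t0) + a1 * (W_neur_h @ t1)
    N = np.maximum(pre_acts, 0.0)           # ReLU neuron activations
    return W_logit @ N + bias               # output logits
```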

Circuit and Feature Analysis

Circuit and Feature Analysis: Understanding Embedding Matrix

It is evident that the embedding matrix in the Fourier basis is very sparse. Apart from a few frequencies, the norms of the remaining Fourier components are essentially zero. These key frequencies are {1, 6, 37}.
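One way such a plot can be produced is to project the embedding onto an explicit Fourier basis over \(\mathbb{Z}_p\) and look at the per-direction norms. The modulus \(p = 113\), the file name, and the \([d_{model}, p]\) orientation of \(W_E\) below are assumptions.

```python
import numpy as np

p = 113  # assumed modulus of the modular-addition task

def fourier_basis(p):
    """Rows: constant, then cos/sin pairs for frequencies 1..p//2, each L2-normalised.
    Row 2k-1 is cos of frequency k, row 2k is sin of frequency k."""
    basis = [np.ones(p) / np.sqrt(p)]
    for k in range(1, p // 2 + 1):
        c = np.cos(2 * np.pi * k * np.arange(p) / p)
        s = np.sin(2 * np.pi * k * np.arange(p) / p)
        basis.append(c / np.linalg.norm(c))
        basis.append(s / np.linalg.norm(s))
    return np.stack(basis)  # [p, p] for odd p

F = fourier_basis(p)
W_E = np.load("embedding.npy")             # assumed [d_model, p] embedding of number tokens
norms = np.linalg.norm(W_E @ F.T, axis=0)  # norm of the embedding along each Fourier direction
top = np.argsort(norms)[::-1][:6]          # indices into the basis; 2k-1 / 2k <-> frequency k
print("largest Fourier components:", top, norms[top])
```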

Circuit and Feature Analysis: Understanding Neuron Activations

From the 2D Fourier basis plots of the first two neurons, we can see that the neuron activations are well approximated by a linear combination of \( 1 \), \( \cos(wx) \), \( \sin(wx) \), \( \cos(wy) \), \( \sin(wy) \), \( \cos(wx)\cos(wy) \), \( \cos(wx)\sin(wy) \), \( \sin(wx)\cos(wy) \), \( \sin(wx)\sin(wy) \). This can be confirmed further by averaging the norms over all neurons in the 2D Fourier basis.
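A sketch of that check, reusing `fourier_basis` from the previous snippet; the cached activation file name and its \([p, p, d_{mlp}]\) shape are assumptions.

```python
import numpy as np

# acts: assumed cached neuron activations of shape [p, p, d_mlp],
# indexed by the two input tokens (x, y).
acts = np.load("neuron_acts.npy")
p, _, d_mlp = acts.shape
F = fourier_basis(p)  # 1D Fourier basis from the previous sketch

# Change of basis: express activations in the 2D Fourier basis (outer product of F with itself).
acts_fourier = np.einsum("ax,by,xyn->abn", F, F, acts)  # [p, p, d_mlp]

# Average coefficient magnitude over all neurons; this should be concentrated on the
# constant term and the cos/sin terms of the key frequencies.
mean_norm = np.abs(acts_fourier).mean(axis=-1)  # [p, p]
```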

Neuron Clusters and Key Frequencies

Now, for each neuron, we find the frequency whose linear and quadratic terms best explain that neuron's activation, and plot the fraction of variance of the neuron activation that is explained by that frequency.
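A possible implementation of this per-neuron frequency attribution, operating on the 2D Fourier coefficients computed above; the exact variance-explained convention used in the original analysis may differ slightly.

```python
import numpy as np

def best_frequency(coef, p):
    """coef: one neuron's 2D Fourier coefficients, shape [p, p].
    Returns the frequency k whose linear and quadratic terms (cos/sin of kx, ky and
    their products) explain the largest fraction of the neuron's variance."""
    var = np.sum(coef ** 2) - coef[0, 0] ** 2          # total variance with the mean removed
    best_k, best_frac = None, 0.0
    for k in range(1, p // 2 + 1):
        idx = [0, 2 * k - 1, 2 * k]                    # constant, cos(k.), sin(k.) directions
        block = coef[np.ix_(idx, idx)]                 # linear + quadratic terms of frequency k
        frac = (np.sum(block ** 2) - coef[0, 0] ** 2) / var
        if frac > best_frac:
            best_k, best_frac = k, frac
    return best_k, best_frac

# Example usage over all neurons (acts_fourier from the previous sketch):
# freqs = [best_frequency(acts_fourier[:, :, n], p) for n in range(d_mlp)]
```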

If you look at these frequencies, they are the same frequencies that we obtained from the embedding matrix. Apart from these frequencies, we also have