To understand the changes made here, we first need to discuss the Key-Value Cache. Inside the transformer there are three vectors critical for attention to work: the key, the value, and the query. At a high level, attention is how we pass critical information about previous tokens to the current token so that it can predict the next one. In single-head self-attention, we multiply the query vector of the current token with the key vectors of the previous tokens and then normalize the result (this normalized matrix is called the attention pattern). We then multiply the value vectors by the attention pattern to get the update for each token. This update is added to the current token's embedding so that it has the context to determine what comes next.
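As a minimal sketch of that single-head computation (NumPy, with names like `attend` and the scaling choice being mine rather than anything from the paper):

```python
import numpy as np

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

def attend(q, K, V):
    # q: (d,) query of the current token; K, V: (t, d) keys/values of prior tokens.
    scores = K @ q / np.sqrt(K.shape[-1])   # one score per previous token
    pattern = softmax(scores)               # the "attention pattern"
    return pattern @ V                      # weighted sum of value vectors, to be
                                            # added back to the token's embedding
```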
We create a new attention pattern for every token we generate, so while the query changes with each new token, the keys and values of the previous tokens stay constant. Consequently, current architectures reduce compute by caching the key and value vectors as they are generated by each successive round of attention. This cache is called the Key-Value Cache.
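Continuing the toy sketch above (reusing the assumed `attend` helper), a decode step only has to compute the key and value for the newest token and append them to the cache:

```python
import numpy as np

def decode_step(x_t, W_q, W_k, W_v, k_cache, v_cache):
    # x_t: embedding of the current token; W_q, W_k, W_v: projection matrices.
    q = x_t @ W_q
    k_cache.append(x_t @ W_k)   # the only new key computed this step
    v_cache.append(x_t @ W_v)   # the only new value computed this step
    return attend(q, np.stack(k_cache), np.stack(v_cache))
```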
While encoder-only and encoder-decoder transformer models have had their successes, the authors posit that this autoregressive setup, and the generation speed it gives these models, is the reason decoder-only models are the most commonly used today.
To understand the YOCO architecture, we first need to understand how it lays out its layers.
In the first half of the model, one type of attention generates the vectors needed to fill the KV Cache. Once we cross into the second half, the model uses that KV Cache exclusively for its key and value vectors, now generating the output token embeddings.
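As a rough mental model (a sketch in my own notation, not the authors' code, with `kv_proj` standing in for whatever projection produces the shared keys and values), the forward pass looks something like this:

```python
def yoco_forward(x, self_decoder_layers, cross_decoder_layers, kv_proj):
    # First half (the self-decoder): efficient self-attention layers that each
    # keep only a small, constant-size state.
    for layer in self_decoder_layers:
        x = layer(x)

    # The KV Cache is produced once, from the self-decoder's output.
    keys, values = kv_proj(x)

    # Second half: every layer reuses the same cached keys and values
    # via cross-attention instead of building its own cache.
    for layer in cross_decoder_layers:
        x = layer(x, keys, values)
    return x
```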
This new architecture requires two types of attention — efficient self-attention and cross-attention. We’ll go into each below.
Efficient Self-Attention (ESA) is designed to achieve constant inference memory. Put differently, we want the cache complexity to depend not on the input length but on the number of layers in our block. In the equations below, the authors leave ESA abstract (it can be implemented in more than one way, as we will see), but the rest of the self-decoder is the same regardless of which implementation is chosen.
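Written out, the self-decoder block looks like the following (my reconstruction of the paper's pre-norm residual form, which the walkthrough below follows):

```latex
\begin{aligned}
Y^{l}   &= \mathrm{ESA}\big(\mathrm{LN}(X^{l})\big) + X^{l} \\
X^{l+1} &= \mathrm{SwiGLU}\big(\mathrm{LN}(Y^{l})\big) + Y^{l}
\end{aligned}
```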
Let's go through the equation step by step. X^l is our token embedding and Y^l is an intermediary variable used to generate the next token embedding X^(l+1). In the equation, ESA is Efficient Self-Attention, LN is the layer normalization function (here always Root Mean Square Norm, or RMSNorm), and SwiGLU is defined by the below:
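Reconstructing that definition with the conventional weight names W_G, W_1, and W_2 (the names are assumed, not copied from the paper's figure):

```latex
\mathrm{SwiGLU}(X) = \big(\mathrm{swish}(X W_G) \odot (X W_1)\big)\, W_2,
\qquad \mathrm{swish}(x) = x \cdot \sigma(x)
```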
Here swish(x) = x * sigmoid(x), applied to X * W_G, where W_G is a trainable weight matrix. We then find the element-wise product (Hadamard product) between that result and X * W_1 before multiplying the whole product by W_2. The goal with SwiGLU is to get an activation function that conditionally passes different amounts of information through the layer to the next token.
Now that we see how the self-decoder works, let’s go into the two ways the authors considered implementing ESA.
First, they considered what is called Gated Retention. Retention and self-attention are admittedly very similar; the authors of the “Retentive Network: A Successor to Transformer for Large Language Models” paper say the key difference lies in the activation function: retention removes the softmax, which allows for a recurrent formulation. They use this recurrent formulation, along with its parallelizability, to drive memory efficiencies.
To dive into the mathematical details:
We have our typical Q, K, and V matrices, each formed by multiplying the input by its own learnable weight matrix. We then take the Hadamard product of the weighted query and key matrices with Θ, which encodes relative position, while the D matrix handles causal masking (stopping future tokens from interacting with current tokens) and applies the exponential decay that takes the place of the softmax activation.
Gated Retention is distinct from plain retention via the γ value: here the matrix Wγ makes the decay depend on the input, allowing our ESA to be data-driven.
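Putting those pieces together, the parallel form looks roughly like this (my reconstruction from the RetNet and YOCO papers; I am omitting the per-head normalization and output gating they also apply):

```latex
\begin{aligned}
Q &= (X W_Q) \odot \Theta, \qquad K = (X W_K) \odot \bar{\Theta}, \qquad V = X W_V,
   \qquad \Theta_n = e^{i n \theta} \\
\gamma &= \operatorname{sigmoid}(X W_{\gamma})^{1/\tau}, \qquad
   D_{nm} = \begin{cases} \prod_{i=m+1}^{n} \gamma_i & n \ge m \\ 0 & n < m \end{cases} \\
\mathrm{gRet}(X) &= \big(Q K^{\top} \odot D\big)\, V
\end{aligned}
```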
Sliding Window ESA introduces the idea of limiting how many tokens the attention window pays attention to. While in regular self-attention every previous token is attended to in some way (even if its weight is effectively zero), in sliding-window ESA we choose some constant C that caps the size of these matrices. This means that at inference time the KV cache can be reduced to a constant size.
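In the toy decoding sketch from earlier (same assumed helpers), that amounts to evicting anything older than the window before attending:

```python
import numpy as np

C = 128  # the constant window size

def windowed_decode_step(x_t, W_q, W_k, W_v, k_cache, v_cache):
    q = x_t @ W_q
    k_cache.append(x_t @ W_k)
    v_cache.append(x_t @ W_v)
    del k_cache[:-C], v_cache[:-C]   # keep only the last C keys and values
    return attend(q, np.stack(k_cache), np.stack(v_cache))
```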
To again dive into the math:
We have our matrices being scaled by their corresponding weights. Next, we compute each head much as in multi-head attention, where B acts both as a causal mask and as the constraint that only the C most recent tokens are attended to.
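Concretely, one plausible reading of that head computation (reconstructed from the description above, with C as the window size):

```latex
\mathrm{head} = \operatorname{softmax}\!\big(Q K^{\top} + B\big)\, V,
\qquad
B_{ij} = \begin{cases} 0 & i - C < j \le i \\ -\infty & \text{otherwise} \end{cases}
```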