Since their introduction in 2017, transformers have emerged as a prominent force in the field of Machine Learning, revolutionizing the capabilities of major translation and autocomplete services.
Recently, the popularity of transformers has soared even higher with the advent of large language models such as OpenAI’s ChatGPT, GPT-4, and Meta’s LLaMA. These models, which have garnered immense attention and excitement, are all built on the foundation of the transformer architecture. By leveraging the power of transformers, they have achieved remarkable breakthroughs in natural language understanding and generation, bringing these capabilities to the general public.
Despite the many good resources which break down how transformers work, I found myself in a position where I understood how the mechanics worked mathematically, but found it difficult to explain intuitively how a transformer works. After conducting many interviews, speaking to my colleagues, and giving a lightning talk on the subject, it seems that many people share this problem!
In this blog post, I shall aim to provide a high-level explanation of how transformers work, avoiding confusing technical jargon, heavy mathematics, and comparisons with previous architectures, and using only small code sketches where they help to build intuition. Whilst I’ll try to keep things as simple as possible, this won’t be easy, as transformers are quite complex; nonetheless, I hope it will provide a better intuition of what they do and how they do it.
A transformer is a type of neural network architecture which is well suited for tasks that involve processing sequences as inputs. Perhaps the most common example of a sequence in this context is a sentence, which we can think of as an ordered set of words.
The aim of these models is to create a numerical representation for each element within a sequence; encapsulating essential information about the element and its neighbouring context. The resulting numerical representations can then be passed on to downstream networks, which can leverage this information to perform various tasks, including generation and classification.
By creating such rich representations, these models enable downstream networks to better understand the underlying patterns and relationships within the input sequence, which enhances their ability to generate coherent and contextually relevant outputs.
The key advantages of transformers lie in their ability to handle long-range dependencies within sequences and their efficiency, as they are able to process sequences in parallel. This is particularly useful for tasks such as machine translation, sentiment analysis, and text generation.
To feed an input into a transformer, we must first convert it into a sequence of tokens: integers that represent our input.
As transformers were first applied in the NLP domain, let’s consider this scenario first. The simplest way to convert a sentence into a series of tokens is to define a vocabulary which acts as a lookup table, mapping words to integers; we can reserve a specific number to represent any word which is not contained in this vocabulary, so that we can always assign an integer value.
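As a minimal sketch of this word-level lookup approach (the vocabulary and sentence here are made up purely for illustration):

```python
# A toy vocabulary; a real one would be built from a large corpus and be far
# larger. The id 0 is reserved for any word not present in the vocabulary.
vocab = {"<unk>": 0, "hello": 1, "there": 2, "the": 3, "weather": 4, "nice": 5}

def tokenise(sentence: str) -> list[int]:
    # Map each lowercased word to its integer id, falling back to <unk>.
    return [vocab.get(word, vocab["<unk>"]) for word in sentence.lower().split()]

print(tokenise("hello there the weather is nice"))  # [1, 2, 3, 4, 0, 5]
```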
In practice, this is a naïve way of encoding text, as words such as cat and cats are treated as completely different tokens, despite them being singular and plural descriptions of the same animal! To overcome this, different tokenisation strategies — such as byte-pair encoding — have been devised which break words up into smaller chunks before indexing them. Additionally, it is often useful to add special tokens to represent characteristics such as the start and end of a sentence, to provide additional context to the model.
Let’s consider the following example to better understand the tokenisation process.
“Hello there, isn’t the weather nice today in Drosval?”
Drosval is a name generated by GPT-4 using the following prompt: “Can you create a fictional place name that sounds like it could belong to David Gemmell’s Drenai universe?”; chosen deliberately as it shouldn’t appear in the vocabulary of any trained model.
Using the bert-base-uncased tokenizer from the transformers library, this is converted into a sequence of tokens.
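A minimal sketch of how this can be reproduced, assuming the transformers library is installed; the exact integers are determined by the tokenizer’s vocabulary:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

sentence = "Hello there, isn't the weather nice today in Drosval?"
token_ids = tokenizer(sentence)["input_ids"]
print(token_ids)  # the sequence of integer token ids for our sentence

# Map each id back to the chunk of text it represents.
print(tokenizer.convert_ids_to_tokens(token_ids))
```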
The integers that represent each word will change depending on the specific model training and tokenisation strategy; decoding these token ids, we can see the chunk of text that each one represents.
Interestingly, we can see that this is not the same as our input. Special tokens have been added, our contraction has been split into multiple tokens, and our fictional place name is represented by different ‘chunks’. As we used the ‘uncased’ model, we have also lost all capitalization context.
However, whilst we used a sentence for our example, transformers are not limited to text inputs; this architecture has also demonstrated good results on vision tasks. To convert an image into a sequence, the authors of ViT sliced the image into non-overlapping 16×16 pixel patches and flattened each patch into a vector, giving a sequence of patch vectors to pass into the model. If we were using a transformer in a recommender system, one approach could be to use the item ids of the last n items browsed by a user as the input to our network. If we can create a meaningful representation of input tokens for our domain, we can feed this into a transformer network.
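A rough sketch of this patching idea, using illustrative sizes and PyTorch purely for demonstration:

```python
import torch

# A dummy 3-channel image of shape (channels, height, width).
image = torch.randn(3, 224, 224)
patch_size = 16

# Carve the image into non-overlapping 16x16 patches...
patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
# ...then flatten each patch into a single vector, giving a sequence of patches.
patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3 * patch_size * patch_size)

print(patches.shape)  # torch.Size([196, 768]): a sequence of 196 patch vectors
```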
Embedding our tokens
Once we have a sequence of integers which represents our input, we can convert them into embeddings. Embeddings are a way of representing information that can be easily processed by machine learning algorithms; they aim to capture the meaning of the token being encoded in a compressed format, by representing the information as a sequence of numbers. Initially, embeddings are initialised as sequences of random numbers, and meaningful representations are learned during training. However, these embeddings have an inherent limitation: they do not take into account the context in which the token appears. There are two aspects to this.
Depending on the task, when we embed our tokens, we may also wish to preserve their ordering; this is especially important in domains such as NLP, otherwise we essentially end up with a bag-of-words approach. To overcome this, we apply positional encoding to our embeddings. Whilst there are multiple ways of creating positional embeddings, the main idea is that we have another set of embeddings which represent the position of each token in the input sequence, and these are combined with our token embeddings.
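As a rough sketch, using learned positional embeddings (one of several possible approaches, with sizes chosen purely for illustration), the two sets of embeddings are often combined by simple addition:

```python
import torch
import torch.nn as nn

vocab_size = 30522     # illustrative vocabulary size
max_seq_len = 512      # the longest sequence we expect to handle
embedding_dim = 768    # the size of each embedding vector

token_embedding = nn.Embedding(vocab_size, embedding_dim)
position_embedding = nn.Embedding(max_seq_len, embedding_dim)

token_ids = torch.tensor([[101, 7592, 2045, 102]])          # a toy batch of token ids
positions = torch.arange(token_ids.shape[1]).unsqueeze(0)   # [[0, 1, 2, 3]]

# Combine the token embeddings with the positional embeddings by adding them.
embeddings = token_embedding(token_ids) + position_embedding(positions)
print(embeddings.shape)  # torch.Size([1, 4, 768])
```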
The other issue is that tokens can have different meanings depending on the tokens that surround them. Consider the following sentences:
It’s dark, who turned off the light?
Wow, this parcel is really light!
Here, the word light is used in two different contexts, where it has completely different meanings! However, it is likely that — depending on the tokenisation strategy — the embedding will be the same. In a transformer, this is handled by its attention mechanism.
Perhaps the most important mechanism used by the transformer architecture is known as attention, which enables the network to understand which parts of the input sequence are the most relevant for the given task. For each token in the sequence, the attention mechanism identifies which other tokens are important for understanding the current token in the given context. Before we explore how this is implemented within a transformer, let’s start simple and try to understand what the attention mechanism is trying to achieve conceptually, to build our intuition.
One way to understand attention is to think of it as a method which replaces each token embedding with an embedding that includes information about its neighbouring tokens, instead of using the same embedding for a token regardless of its context. If we knew which tokens were relevant to the current token, one way of capturing this context would be to create a weighted average, or more generally a linear combination, of these embeddings.
Let’s consider a simple example of how this could look for one of the sentences we saw earlier. Before attention is applied, the embeddings in the sequence have no context of their neighbours; we can therefore think of the embedding for the word light as a linear combination in which light itself receives all of the weight and every other token receives none. In other words, the weight matrix starts out as the identity matrix. After applying our attention mechanism, we would like to learn a weight matrix in which larger weights are given to the embeddings that correspond to the most relevant parts of the sequence for our chosen token; this should ensure that the most important context is captured in the new embedding vector.
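To make this concrete, here is a toy illustration for the sentence about the parcel; the embedding values and weights are made up purely for illustration:

```python
import torch

tokens = ["wow", "this", "parcel", "is", "really", "light", "!"]
embeddings = torch.randn(len(tokens), 4)  # toy 4-dimensional embeddings

# Before attention: the weights for "light" form a row of the identity matrix,
# so its embedding is unchanged.
identity_weights = torch.zeros(len(tokens))
identity_weights[tokens.index("light")] = 1.0
print(identity_weights @ embeddings)  # identical to embeddings[5]

# After attention: weights we might hope to learn, emphasising the most
# relevant neighbouring tokens (here, "parcel" and "really").
learned_weights = torch.tensor([0.01, 0.01, 0.45, 0.01, 0.12, 0.39, 0.01])
contextualised_light = learned_weights @ embeddings  # a weighted average
print(contextualised_light)
```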
Embeddings which contain information about their current context are sometimes known as contextualised embeddings, and this is ultimately what we are trying to create.
Now that we have a high level understanding of what attention is trying to achieve, let’s explore how this is actually implemented in the following section.
There are multiple types of attention, and the main differences lie in the way that the weights used to perform the linear combination are calculated. Here, we shall consider scaled dot-product attention, as introduced in the original paper, as this is the most common approach. In this section, assume that all of our embeddings have been positionally encoded.
Recalling that our aim is to create contextualised embeddings using linear combinations of our original embeddings, let’s start simple and assume that we can encode all of the necessary information needed into our learned embedding vectors, and all we need to calculate are the weights.
To calculate the weights, we must first determine which tokens are relevant to each other. To achieve this, we need to establish a notion of similarity between two embeddings. One way to represent this similarity is by using the dot product, where we would like to learn embeddings such that higher scores indicate that two words are more similar.
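For example, with toy, made-up embedding vectors, we would like related tokens to produce a higher dot product than unrelated ones:

```python
import torch

# Made-up 3-dimensional embeddings, purely for illustration.
light_as_in_lamp = torch.tensor([0.9, 0.1, 0.8])
dark = torch.tensor([0.8, 0.2, 0.7])
parcel = torch.tensor([0.1, 0.9, 0.0])

print(torch.dot(light_as_in_lamp, dark))    # tensor(1.3000): more similar
print(torch.dot(light_as_in_lamp, parcel))  # tensor(0.1800): less similar
```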
As, for each token, we need to calculate its relevance to every other token in the sequence, we can generalise this to a matrix multiplication, which provides us with our weight matrix; these weights are often referred to as attention scores. To ensure that each row of our weight matrix sums to one, we also apply the SoftMax function. However, as matrix multiplications can produce arbitrarily large numbers, this could result in the SoftMax function returning very small gradients for large attention scores, which may lead to the vanishing gradient problem during training. To counteract this, the attention scores are divided by a scaling factor, usually the square root of the embedding dimension, before the SoftMax is applied.
Now, to get our contextualised embedding matrix, we can multiply our matrix of attention weights by our original embedding matrix; this is equivalent to taking linear combinations of our embeddings.
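Putting these steps together, a minimal sketch of this simplified attention calculation, in which the positionally encoded embeddings themselves are used to compute the weights, could look like the following:

```python
import torch
import torch.nn.functional as F

def simple_self_attention(embeddings: torch.Tensor) -> torch.Tensor:
    """Simplified scaled dot-product attention over a (seq_len, embedding_dim)
    matrix of positionally encoded embeddings."""
    d = embeddings.shape[-1]

    # Dot product of every embedding with every other embedding gives a
    # (seq_len, seq_len) matrix of attention scores.
    scores = embeddings @ embeddings.T

    # Divide by the square root of the embedding dimension so that large values
    # do not saturate the SoftMax, then normalise each row to sum to one.
    weights = F.softmax(scores / d**0.5, dim=-1)

    # Each output row is a linear combination of the original embeddings:
    # a contextualised embedding.
    return weights @ embeddings

# Toy example: a sequence of 5 tokens with 8-dimensional embeddings.
contextualised = simple_self_attention(torch.randn(5, 8))
print(contextualised.shape)  # torch.Size([5, 8])
```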