In recent years, the evolution of large language models has skyrocketed. BERT became one of the most popular and efficient models, making it possible to solve a wide range of NLP tasks with high accuracy. After BERT, a number of other models appeared on the scene, demonstrating outstanding results as well.
An obvious trend is that, over time, large language models (LLMs) tend to become more complex, with the number of parameters and the amount of training data growing exponentially. Research in deep learning has shown that such scaling usually leads to better results. Unfortunately, the machine learning world is already struggling with the consequences: scalability has become the main obstacle to effectively training, storing and using these models.
To address this issue, special techniques have been developed for compressing LLMs. The objectives of compression algorithms are decreasing training time, reducing memory consumption, or accelerating model inference. The three most common compression techniques used in practice are the following:
- Knowledge distillation involves training a smaller model that tries to reproduce the behaviour of a larger model.
- Quantization is the process of reducing the memory needed to store the numbers representing a model’s weights.
- Pruning refers to discarding a model’s least important weights.
In this article, we will look at the distillation mechanism applied to BERT, which led to a new model called DistilBERT. Note that the techniques discussed below can be applied to other NLP models as well.
The goal of distillation is to create a smaller model that can imitate a larger one. In practice, this means that if the large model makes a prediction, the smaller model is expected to make a similar prediction.
To achieve this, the larger model (BERT in our case) needs to be pretrained first. Then an architecture for the smaller model has to be chosen. To increase the chances of successful imitation, it is usually recommended that the smaller model have an architecture similar to the larger model’s, with a reduced number of parameters. Finally, the smaller model learns from the predictions made by the larger model on a certain dataset. For this objective, it is vital to choose an appropriate loss function that will help the smaller model learn better.
In distillation notation, the larger model is called a teacher and the smaller model is referred to as a student.
Generally, the distillation procedure is applied during pretraining, but it can be applied during fine-tuning as well.
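As a rough illustration of this general procedure, here is a minimal PyTorch sketch of a single teacher–student training step. This is not the authors’ actual training code: `teacher`, `student`, `optimizer` and `input_ids` are placeholders, and the loss is a generic soft-target cross-entropy. The teacher stays frozen while the student is pushed towards the teacher’s predictions.

```python
import torch
import torch.nn.functional as F

def distillation_step(teacher, student, optimizer, input_ids):
    # Placeholder models: both map token ids to logits of the same shape.
    teacher.eval()
    with torch.no_grad():                     # the teacher is frozen
        teacher_logits = teacher(input_ids)

    student_logits = student(input_ids)       # the student is being trained

    # Generic distillation objective: push the student's predicted
    # distribution towards the teacher's (cross-entropy with soft targets).
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    loss = -(teacher_probs * F.log_softmax(student_logits, dim=-1)).sum(dim=-1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```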
DistilBERT learns from BERT and updates its weights by using a loss function that consists of three components:
- Masked language modeling (MLM) loss
- Distillation loss
- Similarity loss
Below, we are going to discuss these loss components and understand why each of them is necessary. Nevertheless, before diving deeper, it is important to understand a concept called temperature in the softmax activation function, which is used in the DistilBERT loss function.
A softmax transformation is often used as the last layer of a neural network. Softmax normalizes all model outputs so that they sum up to 1 and can be interpreted as probabilities.
There is a variant of the softmax formula in which all the outputs of the model are divided by a temperature parameter T:
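In standard notation, for output logits $z_1, \dots, z_n$ this can be written as:

$$p_i = \frac{\exp(z_i / T)}{\sum_{j=1}^{n} \exp(z_j / T)}$$

Setting T = 1 recovers the usual softmax.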
The temperature T controls the smoothness of the output distribution:
- If T > 1, then the distribution becomes smoother.
- If T = 1, then the distribution is the same as if the normal softmax were applied.
- If T < 1, then the distribution becomes sharper.
To make things clear, let us look at an example. Consider a classification task with 5 labels in which a neural network has produced 5 values indicating the confidence of an input object belonging to each class. Applying softmax with different values of T results in different output distributions.
The greater the temperature is, the smoother the probability distribution becomes.
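To reproduce this behaviour in code, the following minimal sketch (with arbitrary logit values chosen purely for illustration) applies the temperature softmax to 5 model outputs for several values of T:

```python
import torch
import torch.nn.functional as F

# Five arbitrary logits, e.g. raw confidences for 5 classification labels.
logits = torch.tensor([5.0, 2.0, 1.0, 0.5, 0.1])

for T in (0.5, 1.0, 2.0, 5.0):
    probs = F.softmax(logits / T, dim=-1)          # temperature softmax
    print(f"T = {T}:", [round(p, 3) for p in probs.tolist()])
```

With T below 1 almost all of the probability mass concentrates on the first label, while larger values of T spread it more evenly across all five labels.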
Masked language modeling loss
Similar to its teacher (BERT), during pretraining the student (DistilBERT) learns language by making predictions for the masked language modeling task. After producing a prediction for a certain token, the predicted probability distribution is compared to the one-hot encoded probability distribution of the teacher’s model.
The one-hot encoded distribution designates a probability distribution where the probability of the most likely token is set to 1 and the probabilities of all other tokens are set to 0.
As in most language models, the cross-entropy loss is calculated between the predicted and target distributions, and the weights of the student model are updated through backpropagation.
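Following the description above, where the target is the one-hot distribution built from the teacher’s most likely token, a minimal sketch of this loss for a single masked position could look as follows (the tiny vocabulary and logit values are illustrative):

```python
import torch
import torch.nn.functional as F

# Illustrative logits over a tiny vocabulary for a single masked position.
teacher_logits = torch.tensor([[2.0, 0.5, 4.0, 1.0]])   # shape: (1, vocab_size)
student_logits = torch.tensor([[1.5, 0.2, 3.0, 0.8]], requires_grad=True)

# One-hot target: the index of the teacher's most likely token.
target = teacher_logits.argmax(dim=-1)                   # tensor([2])

# Cross-entropy between the student's predicted distribution and this target.
mlm_loss = F.cross_entropy(student_logits, target)
```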
Distillation loss
In principle, it would be possible to use only the student loss (the masked language modeling loss above) to train the student model. However, in many cases this might not be enough. The common problem with using only the student loss lies in its softmax transformation, in which the temperature T is set to 1. In practice, the resulting distribution with T = 1 tends to assign a probability close to 1 to one of the possible labels, while the probabilities of all other labels stay close to 0.
Such a situation does not align well with cases where two or more classification labels are valid for a particular input: the softmax layer with T = 1 is very likely to exclude all valid labels but one, pushing the probability distribution close to a one-hot encoded distribution. This results in a loss of potentially useful information that could be learned by the student model, making it less diverse.
That is why the authors of the paper introduce the distillation loss, in which softmax probabilities are calculated with a temperature T > 1, making it possible to smooth the probabilities and thus let the student take several possible answers into consideration.
In the distillation loss, the same temperature T is applied to both the student and the teacher, and the one-hot encoding of the teacher’s distribution is no longer used.
Instead of the cross-entropy loss, it is possible to use the KL divergence loss.
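Here is a minimal sketch of the distillation component, with the same temperature T applied to the logits of both models (the value of T and the logits are illustrative), including the KL divergence alternative mentioned above:

```python
import torch
import torch.nn.functional as F

T = 2.0   # temperature > 1 smooths both distributions

teacher_logits = torch.tensor([[2.0, 0.5, 4.0, 1.0]])
student_logits = torch.tensor([[1.5, 0.2, 3.0, 0.8]], requires_grad=True)

teacher_probs = F.softmax(teacher_logits / T, dim=-1)          # soft targets
student_log_probs = F.log_softmax(student_logits / T, dim=-1)

# Cross-entropy with the teacher's soft targets ...
distill_loss_ce = -(teacher_probs * student_log_probs).sum(dim=-1).mean()

# ... or the KL divergence mentioned above (differs only by a constant).
distill_loss_kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
```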
Similarity loss
The researchers also state that it is beneficial to add a cosine similarity loss between the hidden state embeddings of the student and the teacher.
This way, the student is likely not only to reproduce masked tokens correctly but also to construct embeddings that are similar to those of the teacher. It also helps preserve the same relations between embeddings in both models’ embedding spaces.
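A minimal sketch of such a cosine loss between the hidden-state embeddings of the two models (the hidden size of 768 matches BERT, but the tensors here are random placeholders):

```python
import torch
import torch.nn as nn

# Placeholder hidden states for 8 token positions, hidden size 768 as in BERT.
teacher_hidden = torch.randn(8, 768)
student_hidden = torch.randn(8, 768, requires_grad=True)

# With a target of 1, CosineEmbeddingLoss pushes each pair of vectors
# to point in the same direction (cosine similarity close to 1).
cosine_loss_fn = nn.CosineEmbeddingLoss()
similarity_loss = cosine_loss_fn(student_hidden, teacher_hidden,
                                 torch.ones(student_hidden.size(0)))
```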
Triple loss
Finally, a linear combination of all three loss functions is calculated, which defines the loss function of DistilBERT. Based on the loss value, backpropagation is performed on the student model to update its weights.
Interestingly, among the three loss components, the masked language modeling loss has the least impact on the model’s performance. The distillation loss and similarity loss have a much higher impact.
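Putting the three components together, the final objective is a weighted sum. The sketch below reuses mlm_loss, distill_loss_kl and similarity_loss from the previous sketches; the weights are placeholders, not the coefficients used by the authors:

```python
# Illustrative weights for the linear combination (not the paper's coefficients),
# reusing mlm_loss, distill_loss_kl and similarity_loss from the sketches above.
alpha_mlm, alpha_distill, alpha_cos = 1.0, 1.0, 1.0

total_loss = (alpha_mlm * mlm_loss
              + alpha_distill * distill_loss_kl
              + alpha_cos * similarity_loss)

total_loss.backward()   # backpropagation updates the student's weights
```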
The inference process in DistilBERT works in the same way as training, with one subtlety: the softmax temperature T is set to 1. This is done to obtain probabilities close to those calculated by BERT.
In general, DistilBERT uses the same architecture as BERT except for these changes:
- DistilBERT has only half of BERT’s layers. Each layer of the student is initialized by taking one BERT layer out of two (see the sketch after this list).
- Token-type embeddings are removed.
- The dense layer which is applied to the hidden state of the [CLS] token for a classification task is removed.
- For more robust performance, the authors use the best ideas proposed in RoBERTa:
– using dynamic masking
– removing the next sentence prediction objective
– training on larger batches
– applying the gradient accumulation technique for optimized gradient computations
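A minimal sketch of the layer-selection idea from the list above: the student’s layers are initialized by keeping one teacher layer out of two. The helper below is hypothetical and works on a generic nn.ModuleList; the real mapping between BERT and DistilBERT parameter names is more involved.

```python
import copy
import torch.nn as nn

def init_student_layers(teacher_layers: nn.ModuleList) -> nn.ModuleList:
    """Hypothetical helper: keep one teacher layer out of two (layers 0, 2, 4, ...)."""
    return nn.ModuleList(
        [copy.deepcopy(teacher_layers[i]) for i in range(0, len(teacher_layers), 2)]
    )

# Toy example with 12 generic "layers": the student keeps 6 of them.
teacher_layers = nn.ModuleList([nn.Linear(768, 768) for _ in range(12)])
student_layers = init_student_layers(teacher_layers)
print(len(student_layers))   # 6
```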
The hidden size (768) in DistilBERT is the same as in BERT. The authors reported that reducing it does not lead to considerable improvements in terms of computational efficiency. According to them, reducing the total number of layers has a much higher impact.
DistilBERT is trained on the same corpus of data as BERT, which consists of BooksCorpus (800M words) and English Wikipedia (2,500M words).
The key performance parameters of BERT and DistilBERT were compared on several of the most popular benchmarks. Here are the facts that are important to retain:
- During inference, DistilBERT is 60% faster than BERT.
- DistilBERT has 44M fewer parameters and in total is 40% smaller than BERT.
- DistilBERT retains 97% of BERT’s performance.
DistilBERT marked a huge step in BERT’s evolution by significantly compressing the model while achieving comparable performance on various NLP tasks. Additionally, DistilBERT weighs only 207 MB, which makes it easier to integrate on devices with restricted capacities. Knowledge distillation is not the only technique that can be applied: DistilBERT can be further compressed with quantization or pruning algorithms.