Researchers from Apple recently published “Distillation Scaling Laws,” which offers the result that distillation does not improve upon standard fully supervised learning for LLMs given a sufficient model size and compute budget for the student. I found this incredibly surprising, as I have always understood distillation to be useful essentially everywhere as an implicit regularization mechanism: it helps a model avoid over-fitting to noise and learn generalizable structure instead.
In this post, I go over several recent theoretical interpretations of distillation and reconcile my understanding of distillation with some of the results in “Distillation Scaling Laws.” My leading personal theories for why distillation is not as effective for large-scale LLMs are the following:
As the number of training tokens increases, the mark of a strong LLM becomes its ability to memorize facts rather than learn structure.
Language data is lower-dimensional than image data, on which distillation is traditionally studied theoretically, and is therefore less prone to over-fitting to noise than image classification.
The unprecedented scale of problem difficulty in language modeling compared to image classification makes LLMs under-fit rather than over-fit.
The unprecedented model scale in LLMs makes distillation redundant.
The authors missed some key hyper-parameters that need to be tuned in order for distillation to be effective at large scales.
Preliminaries: What is Distillation?
The goal of a machine learning model is to generalize well to unseen data, but we cannot optimize for that objective directly; hence we typically supervise a model to perform well on training data. In knowledge distillation, we instead supervise a small student model to predict the same outputs as a large, powerful teacher model (as opposed to the ground-truth labels used in the typical fully supervised setting), in the hope that the student learns to generalize the same way the teacher does, which aligns better with our true goal of generalizing to unseen data. The general intuition is that the teacher’s outputs contain hidden dark knowledge (e.g. in image classification, the teacher reveals what the second- and third-most likely classes are) that can be transferred to the student to help it understand the task and generalize better.
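To make this concrete, here is a minimal sketch of the classic distillation objective in PyTorch, assuming we already have student and teacher logits for a batch of examples. The mixing weight lam and the temperature T are the standard knobs from the original formulation, not anything specific to the papers discussed below.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, lam=0.5):
    """Classic knowledge distillation objective (a minimal sketch).

    Mixes the usual cross-entropy on hard labels with a KL term that pulls
    the student's temperature-softened distribution toward the teacher's.
    The T**2 factor keeps gradient magnitudes comparable across temperatures.
    """
    # Standard supervised term on the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels)

    # Soft-target term: match the teacher's softened distribution.
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = F.kl_div(log_student, soft_targets, reduction="batchmean") * (T ** 2)

    # lam interpolates between pure supervised learning (lam = 1)
    # and pure distillation (lam = 0).
    return lam * hard_loss + (1 - lam) * soft_loss
```

Setting lam = 0 and T = 1 corresponds to the pure distillation objective with temperature one used in the “Distillation Scaling Laws” experiments discussed later.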
Since knowledge distillation was introduced in 2014, self-distillation has also been observed to work: a teacher and student with identical architectures and capacities still yield improved generalization. Moreover, repeated self-distillation was shown to improve generalization further. This challenged the intuition from the original knowledge distillation paper, since a teacher with the same capacity as the student ostensibly cannot hold any dark knowledge unavailable to the student. Hence, both knowledge distillation and self-distillation have been the subject of extensive theoretical investigation.
Knowledge distillation is typically theoretically studied under the setting of classification tasks, but empirically it has ended up being extremely useful in basically every machine learning setting.
Theoretical Interpretations
Over the years there have been various theoretical interpretations of why distillation is so effective. What implicit behavior does distillation introduce that improves generalization?
Self-Distillation as an Implicit Form of Occam’s Razor. In 2020, Mobahi et al. derived a closed-form solution for self-distillation in the case of kernel regression, which, thanks to the Neural Tangent Kernel, is quite similar to the case of an infinite-width two-layer neural network. Each round of self-distillation has a sparsifying effect on the closed-form solution, so that it effectively uses fewer kernel basis functions. This acts as an implicit form of Occam’s Razor, resulting in a simpler solution that generalizes better.
The authors note that excessive self-distillation will result in under-fitting as the solution will become excessively sparse. However, this was later shown to not be the case if the proper hyper-parameters are selected.
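As a toy illustration of this setting (my own sketch, not Mobahi et al.’s code), repeated self-distillation in kernel ridge regression simply means refitting the model on its own previous predictions; each refit shrinks the solution further, which is the sparsifying effect described above.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge

rng = np.random.default_rng(0)
X = np.linspace(0, 1, 50).reshape(-1, 1)
y = np.sin(2 * np.pi * X).ravel() + 0.3 * rng.standard_normal(50)  # noisy targets

targets = y.copy()
for round_idx in range(5):
    # Each round refits kernel ridge regression to the previous round's
    # predictions. The fitted function is progressively smoothed, i.e. it
    # relies on fewer effective kernel basis functions each round.
    model = KernelRidge(kernel="rbf", alpha=0.1, gamma=10.0)
    model.fit(X, targets)
    targets = model.predict(X)
    print(f"round {round_idx}: prediction norm = {np.linalg.norm(targets):.3f}")
```

In this toy, a smaller regularization strength alpha gives milder per-round shrinkage, which loosely matches the later finding that under-fitting can be avoided with the right hyper-parameters.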
Distillation In the Context of the Bias-Variance Trade-off. In 2020, Menon et al. approached distillation from a Bayesian perspective, noting that the ideal supervision signal is the true conditional distribution p(y|x), where y is the label and x is the input, rather than a one-hot vector of the ground-truth label. Training against this softer target moves the student to a better point on the bias-variance trade-off: it over-fits less at the risk of under-fitting slightly more. Hence, it is more important for the teacher to be well-calibrated, meaning its confidence accurately reflects how likely its predictions are to be correct, than for it to be accurate. Deep neural networks are notoriously poorly calibrated, but this can be ameliorated by applying temperature scaling to the teacher’s outputs.
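Roughly, the statistical core of this argument can be written in two lines (my paraphrase, with \ell a loss function, f the student, and p_t the teacher’s predicted distribution): the population risk averages the loss over the true label distribution, the one-hot empirical risk estimates that average by sampling a single label per input, and the distilled risk replaces the sample with the teacher’s estimate of p(y|x), trading a little bias for lower variance.

```latex
% Population risk: average the loss over the true conditional p(y|x).
R(f) = \mathbb{E}_{x}\Big[\sum_{y} p(y \mid x)\,\ell\big(f(x), y\big)\Big]

% One-hot empirical risk (unbiased, high variance) vs. distilled risk
% (lower variance, biased insofar as p_t differs from p).
\widehat{R}_{\mathrm{one\text{-}hot}}(f) = \frac{1}{n}\sum_{i=1}^{n} \ell\big(f(x_i), y_i\big),
\qquad
\widehat{R}_{\mathrm{distill}}(f) = \frac{1}{n}\sum_{i=1}^{n} \sum_{y} p_t(y \mid x_i)\,\ell\big(f(x_i), y\big)
```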
Distillation as Implicit Ensembling. In 2023, Allen-Zhu and Li related ensembling to distillation. They proved that under a data condition they call “multi-view data,” in which each label is associated with multiple features (e.g. the car label for an image classifier is associated with both wheels and a hood), neural networks tend to quickly fit one feature and correctly solve all the examples that contain it. They then over-fit to noise on the few examples that do not contain the fitted feature but only the other features. Since the model never learns to recognize the second feature and over-fits to noise instead, this provably results, for a certain class of neural networks, in perfect training accuracy but sub-par test accuracy.
Ensembling is therefore useful in forcing the model to use all the features present in “multi-view data,” and provably achieves perfect training and test accuracy on these types of datasets. Distillation and self-distillation both act as an implicit form of ensembling: the student keeps its own features and must also learn to incorporate the teacher’s features. Given enough rounds of distillation, such that each new model keeps learning features sufficiently different from its teacher’s, distillation implicitly achieves the same effect as ensembling. The authors also theoretically prove that distilled models can achieve perfect training and test accuracy on “multi-view data.”
Using the examples of a car having wheels and a hood, or a horse having a recognizable head and body, the authors argue that real-life image data fits their simplified description of “multi-view.” Empirically, with deep neural networks on real data, the authors also observe that self-distillation brings a performance boost similar to ensembling, and that ensembling distilled models does not increase performance further. This is consistent with their theory that distillation and ensembling perform similar functions.
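For reference, here is a minimal sketch of what explicit ensemble distillation looks like, which is the baseline these implicit-ensembling results are compared against; the teacher_models list, the student, and the optimizer are hypothetical stand-ins for whatever models you actually have. The soft targets are simply the average of the teachers’ predicted distributions.

```python
import torch
import torch.nn.functional as F

def ensemble_soft_targets(teacher_models, x, T=1.0):
    """Average the temperature-softened predictive distributions of several
    independently trained teachers (hypothetical models passed in by the caller)."""
    with torch.no_grad():
        probs = [F.softmax(m(x) / T, dim=-1) for m in teacher_models]
    return torch.stack(probs).mean(dim=0)

def ensemble_distillation_step(student, teacher_models, x, optimizer, T=1.0):
    # Train the student against the ensemble's averaged distribution; per
    # Allen-Zhu and Li, this transfers the union of features that the
    # individual teachers happened to learn.
    targets = ensemble_soft_targets(teacher_models, x, T)
    log_student = F.log_softmax(student(x) / T, dim=-1)
    loss = F.kl_div(log_student, targets, reduction="batchmean") * (T ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```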
Distillation Scaling Laws
The “Distillation Scaling Laws” paper’s core results are as follows:
Given the model size and number of tokens trained for both the student and the teacher, the final loss of the student can be predicted fairly accurately with a scaling law.
The authors find that distillation does not improve upon fully supervised learning for LLMs given sufficient model size and compute budget. Hence, distillation only makes sense if you already have a teacher model at hand (that was trained for some reason other than distillation) and you have a limited compute budget.
The first result is fairly unsurprising, but it speaks to the rigor and comprehensiveness of the experiments. While it is quite possible that the authors missed a hyper-parameter somewhere, for the most part I don’t expect obvious bugs or omissions, which matters given how surprising I found the second result.
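To give a flavor of what “predicting the final loss with a scaling law” involves, here is a toy sketch that fits a Chinchilla-style parametric form, L(N, D) = E + A/N^α + B/D^β, to synthetic (parameters, tokens, loss) measurements with scipy. The paper’s distillation law uses a different, more involved parameterization that also depends on the teacher, so treat this only as an illustration of the fitting procedure; all the numbers below are made up.

```python
import numpy as np
from scipy.optimize import curve_fit

def chinchilla_form(ND, E, A, B, alpha, beta):
    """Chinchilla-style loss surface: an irreducible term plus power-law
    penalties for limited parameters N and limited training tokens D.
    (Illustrative only; not the paper's distillation parameterization.)"""
    N, D = ND
    return E + A / N**alpha + B / D**beta

# Synthetic "observations": a grid of model sizes and token counts, with
# losses generated from made-up coefficients plus a little noise.
rng = np.random.default_rng(0)
N_grid, D_grid = np.meshgrid([1e8, 1e9, 1e10, 1e11], [1e9, 1e10, 1e11, 1e12])
N_obs, D_obs = N_grid.ravel(), D_grid.ravel()
true_params = (1.7, 400.0, 410.0, 0.34, 0.28)
L_obs = chinchilla_form((N_obs, D_obs), *true_params) + 0.01 * rng.standard_normal(N_obs.size)

# Fit the parametric form to the observations, then use it to predict
# the loss of an unseen (N, D) configuration.
popt, _ = curve_fit(chinchilla_form, (N_obs, D_obs), L_obs,
                    p0=[2.0, 300.0, 300.0, 0.3, 0.3], maxfev=50000)
print("fitted (E, A, B, alpha, beta):", np.round(popt, 3))
print("predicted loss at N=5e9, D=5e11:", chinchilla_form((5e9, 5e11), *popt))
```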
The second result contradicts basically all of my prior intuition about distillation. We have a technique that has time and again proven useful across a wide variety of machine learning settings and that comes with nice intuitive explanations backed by some theory, and yet it turns out to be mostly useless for large language models outside of limited compute budgets.
The second result can be visualized in the figure below, which shows the effect of distilling a teacher model into students of various sizes (y-axis) trained on various numbers of tokens (x-axis). The red regime is where fully supervised learning outperforms distillation; the blue regime indicates the opposite. We see that with enough training tokens, fully supervised learning always wins, and this happens more quickly for larger student models.
I haven’t trained any LLMs, so I can’t say with any confidence why distillation appears to no longer be useful in LLMs. Here are some possible explanations I could think of:
The nature of large language modeling. The premise behind distillation is that it acts as a regularizer that helps the model learn structures that generalize well. This aligns with the LLM generative task when there are not many training tokens, and in that regime distillation is empirically useful. However, once the number of training tokens grows past a certain point, it may be that the mark of a stronger LLM is its ability to memorize pure facts, like who the first president of the United States was. In the memorization regime, we wouldn’t expect the implicit regularization of distillation to help and could reasonably expect it to hurt, so fully supervised learning would outperform a distilled model.
The nature of language data. The distillation and ensembling paper points out that whether distillation is effective depends heavily on both the model and the type of data. While it is easy to imagine the concept of “multi-view” data being prevalent in image classification data, it is not so straightforward for language data. The concept of over-fitting to noise is also much more prevalent in image data, as the input is much higher dimensional. One interesting test of the hypothesis that the data modality is the problem would be to run the same distillation scaling experiments on vision-language tasks.
The scale of large language modeling. Several of the theoretical works assume that the teacher reaches perfect or near-perfect training accuracy on the task of image classification. This is significantly harder to achieve in language modeling, where the problem is far more complex. If the models are always under-fitting, the implicit regularization effects of distillation (e.g. moving towards a simpler solution or towards a higher-bias, lower-variance solution) become moot. It may also be the case that if we scaled image classification up to the size of the experiments in this paper, we would see diminishing returns for distillation there as well.
Model scaling. One of the more surprising results in deep learning is that highly over-parameterized models exhibit very strong implicit regularization, which causes deep models to generalize well. Just as ensembling and distillation are redundant with each other in image classification, it is possible that at a large enough scale the implicit regularization induced by over-parameterization makes all other implicit regularizers redundant. However, this explanation does not really account for the relationship between distillation’s effectiveness and the number of training tokens, and as far as I am aware, ensembling is still useful for LLMs.
Hyper-Parameter Tuning. There are two key hyper-parameters in distillation. One is the weight that linearly interpolates between the standard supervised objective and the distillation objective. The other is the temperature applied to the teacher’s outputs. Distillation can be quite sensitive to both, as shown by some of the aforementioned works on repeated self-distillation and the Bayesian perspective on distillation.
In this paper’s experiments, the temperature of the teacher model is set to one and the pure distillation objective is used. While the paper rigorously defends its choices, it is possible that under certain large-scale regimes, these two hyper-parameters need to be adjusted, perhaps even dynamically, during training.
Conclusion
For many deep learning phenomena, such as network over-parameterization, the double descent phenomenon, adversarial examples, and layer normalization, the research cycle goes something like this: some researchers empirically find a surprising or unintuitive result. Then, theory people construct a setting close enough to deep learning, say a two-layer network or a deep linear network on synthetic data, and provide a satisfying explanation that is provably correct there. Someone then comes in with more empirical results that poke holes in how the theoretical explanation translates to the real world, and the theory people go back to the drawing board.
A good experimental paper is thus one that provides surprising and experimentally rigorous results that force the theory people and the bloggers to rethink some of their assumptions. In this sense, “Distillation Scaling Laws” is a great experimental paper, and I’m looking forward to seeing what the theory people come up with to explain these results.