Neural Network Distillation

Overview of neural network distillation as done in “Distilling the Knowledge in a Neural Network” (Hinton et al., 2014).

Overview

What? Transferring knowledge from one classifier neural network f^* (the teacher model) to a different model f_\theta (the student model).

Teacher model f^*:

We don't need access to the architecture or parameters of the teacher model. We only need to be able to evaluate it to get logits or probabilities (e.g. through an API[1]).

Student model f_\theta:

Both models generate discrete probability distributions over the same set of classes (classification tasks, stochastic agent actions).
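For concreteness, here is a minimal (hypothetical) PyTorch setup; the architectures, sizes, and names are placeholders, and all that matters for distillation is that both models produce logits over the same classes:

```python
import torch
import torch.nn as nn

# Hypothetical setup (sizes and architectures are placeholders):
# a large teacher classifier and a much smaller student classifier
# over the same 10 classes. In practice we only need the teacher's
# logits or probabilities, e.g. obtained through an API.
teacher = nn.Sequential(nn.Linear(784, 1200), nn.ReLU(), nn.Linear(1200, 10))
student = nn.Sequential(nn.Linear(784, 30), nn.ReLU(), nn.Linear(30, 10))

x = torch.randn(8, 784)                             # a batch of inputs
teacher_probs = torch.softmax(teacher(x), dim=-1)   # distribution over the 10 classes
student_probs = torch.softmax(student(x), dim=-1)
```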

Primary objective

We want to train the student model to match the probability distributions of the teacher model:

\mathrm{softmax}(f_\theta(x)) \approx \mathrm{softmax}(f^*(x))

Distillation is done by training the student model to minimize the cross-entropy loss between the teacher's soft targets and the student's predictions, both computed with temperature scaling (at a high temperature, T > 1):

L_1(\theta) = \frac{1}{N} \sum_{i=1}^N H[ \overbrace{\mathrm{softmax}\left(\frac{f^*(x_i)}{T}\right)}^\text{soft target (teacher)} \| \overbrace{ \mathrm{softmax}\left( \frac{ f_\theta(x_i) }{T} \right) }^\text{scaled student prediction} ]

where T is the temperature (a hyperparameter).

The cross-entropy H(p \| q) is:

H(p \| q) = - \mathbb{E}_{y \sim p} \log q(y) = - \sum_y p(y) \log q(y)
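As a sketch, the soft-target loss L_1 could be computed like this in PyTorch (the default temperature and the tensor names are illustrative, not taken from the paper):

```python
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, T=4.0):
    """L_1: cross-entropy H[softmax(f*(x)/T) || softmax(f_theta(x)/T)], averaged over the batch."""
    soft_targets = F.softmax(teacher_logits / T, dim=-1)      # p: soft targets (teacher)
    log_student = F.log_softmax(student_logits / T, dim=-1)   # log q: scaled student prediction
    return -(soft_targets * log_student).sum(dim=-1).mean()   # -E_{y~p}[log q(y)]
```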

Secondary objective

We can, at the same time, train the student model to predict the ground-truth labels y_i (if available). This is done by introducing a secondary objective: the cross-entropy loss between the ground-truth labels (hard targets) and the student predictions, this time without temperature scaling (T = 1):

L_2(\theta) = \frac{1}{N} \sum_{i=1}^N H[ \overbrace{ \delta_{y_i} }^\text{hard targets (labels)} \| \overbrace{ \mathrm{softmax}\left( f_\theta(x_i) \right) }^\text{student prediction} ]

Combining the two terms yields the overall objective function:

L(\theta) = {T^2} L_1(\theta) + \lambda L_2(\theta)
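A sketch of the combined loss, reusing soft_target_loss from the snippet above (the T and \lambda values are placeholders):

```python
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, lam=0.5):
    # Soft-target term at temperature T, multiplied by T^2: the paper notes that
    # the soft-target gradients scale as 1/T^2, so this keeps their magnitude
    # comparable to the hard-label term when T changes.
    l1 = soft_target_loss(student_logits, teacher_logits, T)
    # Hard-label term: standard cross-entropy against y_i, without temperature scaling (T = 1).
    l2 = F.cross_entropy(student_logits, labels)
    return (T ** 2) * l1 + lam * l2
```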

Explanations

Soft-target

The student model is trained to match the predictions of the teacher model, p_i^* = \mathrm{softmax}(f^*(x_i)/T), called soft targets (as opposed to the ground-truth labels y_i, which are hard targets). A soft target carries much more information about what the teacher model has learned than a single hard label.

In particular, the low probability values of the soft targets contain valuable information about the learned teacher model f^* (the relative probabilities of the different classes). These are important to help the student model generalize.
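As a made-up illustration (the numbers are invented, not taken from the paper), compare a hard target with a plausible soft target for a handwritten "2":

```python
# Invented teacher probabilities for a handwritten "2" (10 MNIST-like classes 0-9):
hard_target = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]              # one-hot label "2"
soft_target = [1e-4, 1e-3, 0.9577, 0.02, 1e-5, 1e-4,      # most mass on "2", but "3" and "7"
               1e-6, 0.02, 1e-3, 1e-4]                     # get far more mass than "6":
                                                           # similarity structure that a
                                                           # one-hot label cannot express
```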

Temperature scaling

Why use temperature scaling?

The low probability values of the soft targets contain important information. However, they tend to be disregarded by the cross-entropy loss (because of the expectation \mathbb{E}_{y \sim p}). Using a high temperature (T > 1) smooths the distributions and encourages the student model to match the low probability values as well, not only the high probability ones.
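A quick numeric illustration of the effect, with arbitrary logits:

```python
import torch

logits = torch.tensor([6.0, 2.0, -1.0])
print(torch.softmax(logits, dim=-1))        # ~[0.981, 0.018, 0.001]  (T = 1: near one-hot)
print(torch.softmax(logits / 5.0, dim=-1))  # ~[0.590, 0.265, 0.145]  (T = 5: much softer)
```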

Which temperature are we talking about?

The different experiments in the paper mention using T = 20, T > 8, T \in [2.5, 4], and T \in \{1, 2, 5, 10\}.

Objective mixing

The paper mentions using \lambda = 0.5.

Other distillation methods

See the references for other distillation approaches.

The DeepSeek-R1 paper fine-tunes existing transformer-decoder language models (Qwen, Llama) on outputs of the DeepSeek-R1 model.

References

General:

For text:

Image classification:

In Diffusion models:


  1. You wouldn't steal a car, a handbag, a television, a baby or a helmet. You wouldn't steal logits, right?