KL Cache Trick for Full-Vocab Distillation

Caching the teacher contribution to full-vocab KL when the student-side loss head is fixed.
Distillation
LLM
MTP
Authors

Jonathan Chang

gpt-5.5-xhigh

Published

May 12, 2026

I asked Codex to optimize some full-vocab distillation code. During a 60-hour /goal session spent profiling and optimizing training MFU, it found a small KL cache trick: cache the teacher's contribution to the student hidden-state gradient, not the teacher hidden state itself.

This is an extension of the hidden-state caching idea in DeepSeek-V4: move teacher work out of the student update. DeepSeek-V4’s version caches teacher_hidden; here we cache teacher_probs @ W_S, a vector of dimension d_S, where d_S is the student hidden dimension and W_S is the student LM head.

KL Cache Trick

Let W_T be the teacher LM head (shape V × d_T), W_S the fixed student-side loss head (shape V × d_S), and V the shared vocabulary size.

The usual full-vocab distillation code for one token is:

teacher_logits = teacher_hidden @ W_T.T
teacher_probs = softmax(teacher_logits)

student_logits = student_hidden @ W_S.T
student_log_probs = log_softmax(student_logits)

loss = (teacher_probs * (log(teacher_probs) - student_log_probs)).sum()
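
For concreteness, here is a minimal runnable version of that per-token loss, assuming PyTorch; the dimensions, seed, and variable names are illustrative toy values, not anything from a real training setup. The later snippets in this note reuse these toy tensors to check each step of the derivation numerically.

import torch

torch.manual_seed(0)
V, d_T, d_S = 32, 16, 8                      # toy vocab and hidden sizes
W_T = torch.randn(V, d_T)                    # teacher LM head
W_S = torch.randn(V, d_S)                    # fixed student-side loss head
teacher_hidden = torch.randn(d_T)
student_hidden = torch.randn(d_S, requires_grad=True)

teacher_probs = torch.softmax(teacher_hidden @ W_T.T, dim=-1)
student_log_probs = torch.log_softmax(student_hidden @ W_S.T, dim=-1)

# Full-vocab forward KL(teacher || student) for one token.
kl = (teacher_probs * (teacher_probs.log() - student_log_probs)).sum()
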
The Math

Split the teacher-only term out:

loss = (teacher_probs * log(teacher_probs)).sum() \
    - (teacher_probs * student_log_probs).sum()

The first term is the teacher’s negative entropy; it does not depend on the student, so for training the student we can drop it:

loss = -(teacher_probs * student_log_probs).sum()

Now expand student_log_probs:

student_log_probs = student_logits - logsumexp(student_logits)

So:

loss = -(teacher_probs * (student_logits - logsumexp(student_logits))).sum()

Distribute, using the fact that teacher_probs sums to 1:

loss = logsumexp(student_logits) - (teacher_probs * student_logits).sum()
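
Continuing the toy example above, here is a quick numeric sanity check of this rewrite; the constant being dropped is exactly the teacher-only term from before:

student_logits = student_hidden @ W_S.T
teacher_only = (teacher_probs * teacher_probs.log()).sum()   # constant w.r.t. the student
rewritten = torch.logsumexp(student_logits, dim=-1) - (teacher_probs * student_logits).sum()
assert torch.allclose(kl, rewritten + teacher_only, atol=1e-5)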

The last term is where the trick is. Since:

student_logits = student_hidden @ W_S.T

then:

(teacher_probs * student_logits).sum()

is the same as:

dot(student_hidden, teacher_probs @ W_S)

Define the cached vector:

teacher_cache = teacher_probs @ W_S

or, substituting the teacher computation:

teacher_cache = softmax(teacher_hidden @ W_T.T) @ W_S

Now the student-time loss is just:

student_logits = student_hidden @ W_S.T
loss = logsumexp(student_logits) - dot(student_hidden, teacher_cache)

Same hidden-state update, no teacher LM-head projection or teacher softmax during student training.
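
Continuing the same toy example, the cached form reproduces the rewritten loss exactly:

teacher_cache = teacher_probs @ W_S          # shape (d_S,), can be computed offline
cached_loss = torch.logsumexp(student_logits, dim=-1) - student_hidden @ teacher_cache
assert torch.allclose(cached_loss, rewritten, atol=1e-5)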

Gradient check:

The gradient of the log-sum-exp term with respect to student_hidden is:

grad_logsumexp = student_probs @ W_S

where student_probs = softmax(student_logits). And:

teacher_cache = teacher_probs @ W_S

So:

grad_student_hidden = student_probs @ W_S - teacher_cache
grad_student_hidden = (student_probs - teacher_probs) @ W_S

This is exactly the gradient of the original full-vocab KL with respect to student_hidden, so the student update is unchanged.

The math allows us to precompute and cache:

teacher_cache = softmax(teacher_hidden @ W_T.T) @ W_S

And the student-time loss becomes:

student_logits = student_hidden @ W_S.T
loss = logsumexp(student_logits) - dot(student_hidden, teacher_cache)
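
And, still with the toy tensors from above, a quick autograd check that the student hidden-state gradient really is unchanged:

(g_kl,) = torch.autograd.grad(kl, student_hidden, retain_graph=True)
(g_cached,) = torch.autograd.grad(cached_loss, student_hidden, retain_graph=True)
student_probs = torch.softmax(student_logits, dim=-1)
assert torch.allclose(g_kl, g_cached, atol=1e-5)
assert torch.allclose(g_cached, (student_probs - teacher_probs) @ W_S, atol=1e-5)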

Intuition

The teacher distribution is a soft target over vocabulary tokens. Each row of the student LM head W_S is the direction in student hidden space associated with one vocabulary token. So teacher_probs @ W_S is the teacher-weighted average of those directions.

The usual KL loss says: make student_hidden score the teacher-preferred tokens highly. This rewrite says the same thing in hidden space: make student_hidden align with the teacher’s average output direction, while logsumexp(student_logits) keeps the scores normalized.

Implications

There are a few interesting things about this trick:

  • This works even if teacher and student have different hidden sizes; the cache lives in the student hidden dimension d_S.
  • This removes the need to load the teacher LM head during student training, so it can sometimes save you from OOM.

However, there are also many constraints:

  • This only works when teacher and student have the same vocab.
  • The offline cache only works if the student LM head is frozen during training.
  • The cache can only be used for students with this particular LM head.
  • This only works for forward KL, not reverse KL.

Why not reverse KL?

Forward KL works because the teacher distribution weights the student log-probabilities:

loss = -(teacher_probs * student_log_probs).sum()

After expanding student_log_probs, the teacher side appears as:

teacher_probs @ W_S

Reverse KL has the opposite weighting:

loss = (student_probs * (student_log_probs - teacher_log_probs)).sum()

The teacher-dependent term is:

-(student_probs * teacher_log_probs).sum()

Here the weights are student_probs, which depend on the current student_hidden. So the teacher side cannot be precomputed into a fixed hidden-sized vector independent of the student.

The trick itself is most useful when the student LM head is frozen and loading the teacher LM head during training would cut into your available memory, maximum batch size, and training throughput. In that setting you can precompute and cache teacher_cache instead of teacher_hidden.
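
As a rough sketch of what that might look like (PyTorch; the helper names, shapes, and chunk size are hypothetical, not from any particular codebase):

import torch

@torch.no_grad()
def precompute_teacher_cache(teacher_hidden, W_T, W_S, chunk=4096):
    # teacher_hidden: (N, d_T); W_T: (V, d_T); W_S: (V, d_S)  ->  cache: (N, d_S)
    out = []
    for h in teacher_hidden.split(chunk):
        probs = torch.softmax(h @ W_T.T, dim=-1)   # (chunk, V), never written to disk
        out.append(probs @ W_S)                    # (chunk, d_S)
    return torch.cat(out)

def cached_distill_loss(student_hidden, W_S, teacher_cache):
    # Equals the full-vocab forward KL up to the constant teacher entropy term.
    logits = student_hidden @ W_S.T                                   # (N, V)
    return (torch.logsumexp(logits, dim=-1)
            - (student_hidden * teacher_cache).sum(-1)).mean()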

If the student LM head is not frozen, you can still cache teacher_hidden and recompute teacher_cache at each iteration. But once gradients have to flow through W_S, teacher_probs must be kept around for the backward pass, which makes this memory-intensive and not really practical.

One special use case I can think of where it could actually save computation:

In a modified MTP setup where several auxiliary prediction states share the same output head and distill from the same main-head future-token distribution, teacher_cache = teacher_probs @ W_S can be computed once and reused across those losses. This is related to the self-distillation setup in Self-Distillation for Multi-Token Prediction, where MTP heads share the output head with the main head and distill from detached main-head logits. Again, this only works if the shared student output head is frozen.
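
A minimal sketch of that reuse, assuming K auxiliary prediction states that all predict the same target tokens through one shared, frozen head W_S (names are illustrative):

import torch

def mtp_distill_loss(aux_hiddens, W_S, teacher_cache):
    # aux_hiddens: K tensors of shape (N, d_S); teacher_cache: (N, d_S),
    # computed once from the detached main-head distribution and reused K times.
    total = 0.0
    for h in aux_hiddens:
        lse = torch.logsumexp(h @ W_S.T, dim=-1)                      # (N,)
        total = total + (lse - (h * teacher_cache).sum(-1)).mean()
    return total / len(aux_hiddens)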

Conclusion

This is not a very practical trick. But I still find it interesting that gpt-5.5-xhigh discovered this idea on its own while optimizing the training loop.

Acknowledgement

Thanks to Dmitrii Emelianenko for the discussion that connected this trick to MTP-style auxiliary prediction heads.

References

Cite this note

@misc{chang2026klcachetrick,
  title = {KL Cache Trick for Full-Vocab Distillation},
  author = {Jonathan Chang and {gpt-5.5-xhigh}},
  year = {2026},
  url = {https://jonathanc.net/blog/kl-cache-trick}
}

Reference BibTeX

@misc{deepseekai2026deepseekv4,
  title = {DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence},
  author = {DeepSeek-AI},
  year = {2026},
  url = {https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro}
}

@misc{zhao2026selfdistillationmultitokenprediction,
  title = {Self-Distillation for Multi-Token Prediction},
  author = {Guoliang Zhao and Ruobing Xie and An Wang and Shuaipeng Li and Huaibing Xie and Xingwu Sun},
  year = {2026},
  eprint = {2603.23911},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL},
  url = {https://arxiv.org/abs/2603.23911}
}