I asked Codex to optimize some full-vocab distillation code. During a 60-hour /goal session of profiling and optimizing the training MFU, it found a small KL cache trick: cache the teacher contribution to the student hidden-state gradient, not the teacher hidden state.
This is an extension of the hidden-state caching idea in DeepSeek-V4: move teacher work out of the student update. DeepSeek-V4’s version caches teacher_hidden; here, we cache teacher_probs @ W_S, which has dimension d_S, where d_S is the student hidden dimension and W_S is the student LM head.
KL Cache Trick
Let W_T be the teacher LM head and W_S be the fixed student-side loss head.
The usual full-vocab distillation code for one token is:
teacher_logits = teacher_hidden @ W_T.T
teacher_probs = softmax(teacher_logits)
student_logits = student_hidden @ W_S.T
student_log_probs = log_softmax(student_logits)
loss = (teacher_probs * (log(teacher_probs) - student_log_probs)).sum()
The Math
Split the teacher-only term out:
loss = (teacher_probs * log(teacher_probs)).sum() \
       - (teacher_probs * student_log_probs).sum()
The first term does not depend on the student, so for training the student we can drop it:
loss = -(teacher_probs * student_log_probs).sum()
Now expand student_log_probs:
student_log_probs = student_logits - logsumexp(student_logits)
So:
loss = -(teacher_probs * (student_logits - logsumexp(student_logits))).sum()
Distribute:
loss = logsumexp(student_logits) - (teacher_probs * student_logits).sum()
The last term is where the trick is. Since:
student_logits = student_hidden @ W_S.T
then:
(teacher_probs * student_logits).sum()
is the same as:
dot(student_hidden, teacher_probs @ W_S)
Define the cached vector:
teacher_cache = teacher_probs @ W_S
or, substituting the teacher computation:
teacher_cache = softmax(teacher_hidden @ W_T.T) @ W_S
Now the student-time loss is just:
student_logits = student_hidden @ W_S.T
loss = logsumexp(student_logits) - dot(student_hidden, teacher_cache)
Same hidden-state update, no teacher LM-head projection or teacher softmax during student training.
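As a sanity check on the derivation, here is a small NumPy sketch (my own, not from the actual training code) confirming that the cached loss differs from the full forward KL only by the teacher-entropy constant, and that the gradients with respect to student_hidden match exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
V, d_T, d_S = 32, 8, 6           # toy vocab and hidden sizes
W_T = rng.normal(size=(V, d_T))  # teacher LM head
W_S = rng.normal(size=(V, d_S))  # student LM head (frozen)
teacher_hidden = rng.normal(size=d_T)
student_hidden = rng.normal(size=d_S)

def softmax(x):
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

teacher_probs = softmax(teacher_hidden @ W_T.T)

# Full forward KL, computed the usual way
student_logits = student_hidden @ W_S.T
student_log_probs = student_logits - logsumexp(student_logits)
kl = (teacher_probs * (np.log(teacher_probs) - student_log_probs)).sum()

# Cached form: one d_S-sized vector replaces the V-sized teacher distribution
teacher_cache = teacher_probs @ W_S  # shape (d_S,)
cached_loss = logsumexp(student_logits) - student_hidden @ teacher_cache

# They differ only by the teacher entropy, which is constant w.r.t. the student
teacher_entropy = -(teacher_probs * np.log(teacher_probs)).sum()
assert np.allclose(kl, cached_loss - teacher_entropy)

# Gradients w.r.t. student_hidden agree exactly
student_probs = softmax(student_logits)
grad_full = (student_probs - teacher_probs) @ W_S
grad_cached = student_probs @ W_S - teacher_cache
assert np.allclose(grad_full, grad_cached)
```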
Gradient check:
The gradient of the log-sum-exp term is:
grad_logsumexp = student_probs @ W_S
and:
teacher_cache = teacher_probs @ W_S
So:
grad_student_hidden = student_probs @ W_S - teacher_cache
grad_student_hidden = (student_probs - teacher_probs) @ W_S
The math allows us to precompute and cache:
teacher_cache = softmax(teacher_hidden @ W_T.T) @ W_SAnd the student-time loss becomes:
student_logits = student_hidden @ W_S.T
loss = logsumexp(student_logits) - dot(student_hidden, teacher_cache)
Intuition
The teacher distribution is a soft target over vocabulary tokens. The student LM head turns each token into a direction in student hidden space. So teacher_probs @ W_S is the teacher-weighted average of those directions.
The usual KL loss says: make student_hidden score the teacher-preferred tokens highly. This rewrite says the same thing in hidden space: make student_hidden align with the teacher’s average output direction, while logsumexp(student_logits) keeps the scores normalized.
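A tiny numeric illustration of that weighted-average view (my example, with a made-up 3-token vocab): teacher_cache is literally the convex combination of the student head's rows under the teacher distribution.

```python
import numpy as np

# 3-token vocab, d_S = 2; each row of W_S is one token's direction
W_S = np.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])
teacher_probs = np.array([0.5, 0.25, 0.25])

# teacher_cache = teacher-weighted average of the rows of W_S
teacher_cache = teacher_probs @ W_S
assert np.allclose(teacher_cache,
                   0.5 * W_S[0] + 0.25 * W_S[1] + 0.25 * W_S[2])
```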
Implications
There are a few interesting things about this trick:
- This works even if teacher and student have different hidden sizes.
- This removes the need to load the teacher LM head during student training, so it can sometimes save you from OOM.
However, there are also many constraints:
- This only works when teacher and student have the same vocab.
- The offline cache only works if the student LM head is frozen during training.
- The cache can only be used for students with this particular LM head.
- This only works for forward KL, not reverse KL.
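Under those constraints, the pipeline splits into an offline cache pass and a cheap student step. Here is a hypothetical sketch (function names are illustrative, not from any actual codebase), which also checks that the cached loss equals mean forward KL plus the dropped teacher-entropy constant:

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def logsumexp(x):
    m = x.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True)))[..., 0]

def build_cache(teacher_hidden, W_T, W_S):
    # Offline pass: store one d_S-sized vector per token
    # instead of a V-sized teacher distribution.
    teacher_probs = softmax(teacher_hidden @ W_T.T)  # (T, V)
    return teacher_probs @ W_S                       # (T, d_S)

def student_loss(student_hidden, W_S, teacher_cache):
    # Student step: no teacher LM-head projection, no teacher softmax.
    student_logits = student_hidden @ W_S.T          # (T, V)
    return (logsumexp(student_logits)
            - (student_hidden * teacher_cache).sum(-1)).mean()

# Toy check against the full forward KL
rng = np.random.default_rng(0)
T, V, d_T, d_S = 4, 16, 8, 6
W_T, W_S = rng.normal(size=(V, d_T)), rng.normal(size=(V, d_S))
th, sh = rng.normal(size=(T, d_T)), rng.normal(size=(T, d_S))

cache = build_cache(th, W_T, W_S)
loss = student_loss(sh, W_S, cache)

p = softmax(th @ W_T.T)
log_q = sh @ W_S.T - logsumexp(sh @ W_S.T)[:, None]
kl = (p * (np.log(p) - log_q)).sum(-1).mean()
entropy = -(p * np.log(p)).sum(-1).mean()
assert np.allclose(loss, kl + entropy)
```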
Why not reverse KL?
Forward KL works because the teacher distribution weights the student log-probabilities:
loss = -(teacher_probs * student_log_probs).sum()
After expanding student_log_probs, the teacher side appears as:
teacher_probs @ W_S
Reverse KL has the opposite weighting:
loss = (student_probs * (student_log_probs - teacher_log_probs)).sum()
The teacher-dependent term is:
-(student_probs * teacher_log_probs).sum()
This term is weighted by student_probs, which depend on the current student_hidden. So the teacher side cannot be precomputed into a fixed hidden-sized vector independent of the student.
This can be useful when the student LM head is frozen and loading the teacher LM head during training would cut into your available memory, maximum batch size, and training throughput. You can then use this trick to precompute and cache teacher_cache instead of teacher_hidden.
If the student LM head is not frozen, you can still cache teacher_hidden and compute a fresh teacher_cache at each iteration. But if the gradient has to flow through W_S, teacher_probs must be kept around for backprop, which makes this memory-intensive and impractical.
One special use case I can think of where it could save computation is this:
In a modified MTP setup where several auxiliary prediction states share the same output head and distill from the same main-head future-token distribution, teacher_cache = teacher_probs @ W_S can be reused across those losses. This is related to the self-distillation setup in Self-Distillation for Multi-Token Prediction, where MTP heads share the output head with the main head and distill from detached main-head logits. Again, this only works if the student LM heads are shared and frozen.
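A minimal sketch of that reuse (my construction from the setup described above; aux_states and the single shared frozen head are assumptions, not details from the paper): the cache is built once from the target distribution and shared by every auxiliary loss.

```python
import numpy as np

def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

rng = np.random.default_rng(0)
V, d_S, k = 16, 6, 3
W_S = rng.normal(size=(V, d_S))       # shared, frozen output head
logits = rng.normal(size=V)           # stand-in for detached main-head logits
teacher_probs = np.exp(logits - logsumexp(logits))
aux_states = rng.normal(size=(k, d_S))  # k auxiliary prediction states

# Built once, reused by all k auxiliary losses
teacher_cache = teacher_probs @ W_S   # shape (d_S,)
losses = [logsumexp(h @ W_S.T) - h @ teacher_cache for h in aux_states]
```

Each entry of losses is the cross-entropy of that auxiliary state against the shared target, so the per-state cost never touches the V-sized teacher distribution again.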
Conclusion
This is not a very practical trick. But I still find it interesting that gpt-5.5-xhigh discovered this idea on its own while optimizing the training loop.
Acknowledgement
Thanks to Dmitrii Emelianenko for the discussion that connected this trick to MTP-style auxiliary prediction heads.
References
- DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence, DeepSeek-AI, 2026. Model page.
- Self-Distillation for Multi-Token Prediction, Zhao et al., 2026.
Cite this note
@misc{chang2026klcachetrick,
title = {KL Cache Trick for Full-Vocab Distillation},
author = {Jonathan Chang and {gpt-5.5-xhigh}},
year = {2026},
url = {https://jonathanc.net/blog/kl-cache-trick}
}
Reference BibTeX
@misc{deepseekai2026deepseekv4,
title={DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence},
author={DeepSeek-AI},
year={2026},
url={https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro},
}
@misc{zhao2026selfdistillationmultitokenprediction,
title={Self-Distillation for Multi-Token Prediction},
author={Guoliang Zhao and Ruobing Xie and An Wang and Shuaipeng Li and Huaibing Xie and Xingwu Sun},
year={2026},
eprint={2603.23911},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2603.23911},
}