The Learned Kernel: How Transformers Actually Do In-Context Learning
Summary
Transformers achieve strong in-context learning (ICL) by learning the similarity function itself rather than relying on hand-tuned parameters or a fixed distance metric. This post, the third in a series, shows how attention in transformers functions as a form of kernel regression with a learned kernel. It progresses from Ridge Regression (linear kernel) to Kernel Ridge Regression with a hand-picked RBF kernel, and finally to a custom `PretrainedAttentionICL` model. Trained on thousands of diverse random functions, this model learns `W_Q`, `W_K`, and `W_V` matrices that define an RBF-like kernel. In the experiment, the pretrained attention model matches an optimally hand-tuned RBF kernel on an unseen `sin(x)` function with an MSE of 0.004, well ahead of linear ridge regression (0.171 MSE) and untrained attention (0.081 MSE).
Key takeaway
For AI engineers developing or deploying large language models, seeing transformer attention as kernel regression with a learned kernel takes much of the "magic" out of in-context learning. Diversity in pretraining is what lets a model learn a robust similarity function, and that similarity function is what lets it adapt to new tasks without gradient updates. This insight can guide architectural choices and pretraining strategies: favor broad task exposure over memorization of specific tasks to achieve strong generalization.
Key insights
Transformers excel at in-context learning because their attention mechanisms learn a similarity function during pretraining and then apply it to new context examples at inference time.
Principles
- Attention is kernel regression with a learned kernel.
- Pretraining enables meta-learning of similarity functions.
- Dual form of Ridge Regression reveals kernel structure.
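
The third principle rests on a standard identity, written out here as a short derivation. The notation is assumed for illustration and is not taken from the original post: $X$ is the $n \times d$ matrix of context inputs with rows $x_i^\top$, $y$ is the vector of context labels, and $\lambda > 0$ is the ridge penalty.

```latex
\begin{aligned}
\hat{w} &= (X^\top X + \lambda I_d)^{-1} X^\top y
         = X^\top (X X^\top + \lambda I_n)^{-1} y, \\
\hat{f}(x) &= x^\top \hat{w}
            = \sum_{i=1}^{n} \alpha_i \,\langle x_i, x \rangle,
\qquad \alpha = (K + \lambda I_n)^{-1} y,
\quad K_{ij} = \langle x_i, x_j \rangle .
\end{aligned}
```

In the dual form the prediction depends on the data only through inner products. Swapping $\langle x_i, x \rangle$ for a kernel $k(x_i, x)$, such as the RBF kernel $\exp(-\lVert x_i - x \rVert^2 / (2h^2))$, gives Kernel Ridge Regression. Attention in its Nadaraya-Watson form is a related but different estimator: it normalizes the similarity scores with a softmax instead of solving for $\alpha$.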
Method
Train an attention model on diverse tasks to learn Q, K, V matrices, which implicitly define a kernel. This learned kernel then performs in-context learning on new, unseen tasks without further weight updates.
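
The method above can be sketched with a deliberately tiny stand-in, which is not the post's `PretrainedAttentionICL` model. Instead of full `W_Q`, `W_K`, and `W_V` matrices, the sketch learns a single kernel parameter `theta` in the score `-theta * (x - x_i)^2`. The task family (random sinusoids), the hyperparameters, and every function name here are assumptions made for illustration.

```python
import math
import random

def predict(x, xs, ys, theta):
    """Attention readout with score -theta * (x - x_i)^2.

    Returns the softmax-weighted average of the labels and the weights.
    """
    scores = [-theta * (x - xi) ** 2 for xi in xs]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    pred = sum(w * y for w, y in zip(weights, ys))
    return pred, weights

def sample_task(rng):
    """Hypothetical task family: random sinusoids a * sin(b * x + c)."""
    a = rng.uniform(0.5, 1.5)
    b = rng.uniform(0.5, 2.0)
    c = rng.uniform(0.0, 2.0 * math.pi)
    return lambda x: a * math.sin(b * x + c)

def pretrain(theta_init=0.1, steps=3000, lr=0.1, n_context=20, seed=0):
    """SGD on log(theta): one freshly sampled task and query per step."""
    rng = random.Random(seed)
    log_theta = math.log(theta_init)
    for _ in range(steps):
        f = sample_task(rng)
        xs = [rng.uniform(-3.0, 3.0) for _ in range(n_context)]
        ys = [f(x) for x in xs]
        x_q = rng.uniform(-3.0, 3.0)
        theta = math.exp(log_theta)
        pred, w = predict(x_q, xs, ys, theta)
        err = pred - f(x_q)
        # d pred / d theta = -sum_i w_i * (y_i - pred) * (x_q - x_i)^2
        dpred = -sum(wi * (yi - pred) * (x_q - xi) ** 2
                     for wi, yi, xi in zip(w, ys, xs))
        # Chain rule through theta = exp(log_theta), clipped for stability.
        grad = 2.0 * err * dpred * theta
        grad = max(-1.0, min(1.0, grad))
        log_theta -= lr * grad
    return math.exp(log_theta)

def mse_on_sin(theta):
    """Evaluate on sin(x) with a fixed context; no weight updates here."""
    xs = [-3.0 + 6.0 * i / 19 for i in range(20)]
    ys = [math.sin(x) for x in xs]
    queries = [-2.5 + 5.0 * i / 49 for i in range(50)]
    errors = [(predict(x, xs, ys, theta)[0] - math.sin(x)) ** 2 for x in queries]
    return sum(errors) / len(errors)

theta_init = 0.1
theta_trained = pretrain(theta_init)
print("theta before/after:", theta_init, theta_trained)
print("MSE on sin(x) before/after:", mse_on_sin(theta_init), mse_on_sin(theta_trained))
```

The point of the sketch is the split between the two phases: `pretrain` adjusts the kernel across many tasks, while `mse_on_sin` only reads context points through that fixed kernel. The numbers it prints are for this toy setup and are not comparable to the MSE figures reported in the summary.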
In practice
- Implement attention as a learned Nadaraya-Watson estimator.
- Use `W_Q(x) @ W_K(x')^T` as the learned similarity score between a query `x` and a context point `x'`.
- Consider multi-head attention for ensemble kernels.
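
The first two points can be illustrated with a minimal sketch in plain Python. It shows attention as a Nadaraya-Watson estimator, and shows that a query-key dot product can reproduce RBF attention weights exactly. The feature map `features`, the bandwidth `h`, and the hand-set `W_Q` and `W_K` below are assumptions chosen for illustration; they are not weights learned by the post's model.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(scores, ys):
    """Nadaraya-Watson estimate: softmax-weighted average of context labels."""
    weights = softmax(scores)
    return sum(w * y for w, y in zip(weights, ys))

def rbf_scores(x, xs, h):
    """Hand-picked RBF similarity, written as log-weights."""
    return [-((x - xi) ** 2) / (2 * h * h) for xi in xs]

def features(x):
    """Hypothetical input embedding phi(x) = [x, x^2, 1]."""
    return [x, x * x, 1.0]

def project(vec, W):
    """Row vector times matrix: vec @ W."""
    return [sum(v * W[i][j] for i, v in enumerate(vec)) for j in range(len(W[0]))]

def qk_scores(x, xs, W_Q, W_K):
    """Query-key dot products (phi(x) W_Q) . (phi(x_i) W_K)."""
    q = project(features(x), W_Q)
    return [sum(a * b for a, b in zip(q, project(features(xi), W_K))) for xi in xs]

h = 0.3
# Hand-set projections: q = [x, 1] and k = [x'/h^2, -x'^2 / (2 h^2)].
W_Q = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
W_K = [[1.0 / h**2, 0.0], [0.0, -1.0 / (2 * h**2)], [0.0, 0.0]]

# Context: 20 evenly spaced samples of sin(x) on [-3, 3].
xs = [-3.0 + 6.0 * i / 19 for i in range(20)]
ys = [math.sin(x) for x in xs]

x_query = 1.0
pred_rbf = attend(rbf_scores(x_query, xs, h), ys)
pred_qk = attend(qk_scores(x_query, xs, W_Q, W_K), ys)
print(pred_rbf, pred_qk, math.sin(x_query))
```

The two predictions agree because the dot-product score differs from the RBF score only by `-x^2 / (2 h^2)`, which is the same for every context point and therefore cancels inside the softmax. This is one way to see why learned `W_Q` and `W_K` matrices can end up defining an RBF-like kernel.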
Topics
- In-Context Learning
- Transformers
- Kernel Regression
- Attention Mechanism
- Meta-learning
Best for: AI Engineer, Machine Learning Engineer, AI Researcher
Editorial summary, takeaway, and curation by AIssential. Original article published by Agus’s Substack.