Transformers have become the state-of-the-art neural network architecture
across numerous domains of machine learning. This is partly due to their
celebrated ability to transfer and to learn in-context based on few examples.
Nevertheless, the mechanisms by which Transformers become in-context learners
remain poorly understood and are largely a matter of intuition. Here, we argue that
training Transformers on auto-regressive tasks can be closely related to
well-known gradient-based meta-learning formulations. We start by providing a
simple weight construction that shows the equivalence of data transformations
induced by 1) a single linear self-attention layer and by 2) gradient-descent
(GD) on a regression loss. Motivated by that construction, we show empirically
that, when self-attention-only Transformers are trained on simple regression
tasks, either the models learned by GD and the trained Transformers are strikingly
similar or, remarkably, the weights found by optimization match the construction. Thus we
show how trained Transformers implement gradient descent in their forward pass.
This allows us, at least in the domain of regression problems, to
mechanistically understand the inner workings of optimized Transformers that
learn in-context. Furthermore, we identify how Transformers surpass plain
gradient descent by learning an iterative curvature correction and by learning
linear models on deep data representations to solve non-linear regression tasks.
Finally, we
discuss intriguing parallels to a mechanism identified as crucial for
in-context learning, termed the induction head (Olsson et al., 2022), and show
how it could be understood as a special case of in-context learning by gradient
descent within Transformers.
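The equivalence between 1) and 2) can be illustrated with a minimal NumPy sketch. This is an illustrative assumption-laden toy, not the paper's exact weight construction: we assume scalar targets, an initial weight vector w0 = 0, and hand-set keys, queries, and values for a single linear (softmax-free) self-attention layer. Under these assumptions, the attention output for a query input equals the prediction after one gradient-descent step on the in-context regression loss.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 8, 3                       # number of in-context examples, input dimension
eta = 0.1                         # GD learning rate (arbitrary choice)

X = rng.normal(size=(N, d))       # in-context inputs x_j
w_true = rng.normal(size=d)
y = X @ w_true                    # in-context targets y_j (scalar outputs)
x_q = rng.normal(size=d)          # query input

# 1) One step of gradient descent on L(w) = 1/(2N) * sum_j (w . x_j - y_j)^2,
#    starting from w0 = 0:  w1 = w0 + (eta / N) * sum_j (y_j - w0 . x_j) * x_j
w1 = (eta / N) * (y @ X)
pred_gd = w1 @ x_q

# 2) A single linear self-attention layer (no softmax) reading the context,
#    with query q = x_q, keys k_j = x_j, values v_j = (eta / N) * y_j; its
#    output for the query token is sum_j (q . k_j) * v_j.
pred_attn = sum((x_q @ X[j]) * (eta / N) * y[j] for j in range(N))

# The two data transformations coincide.
assert np.allclose(pred_gd, pred_attn)
```

Both expressions reduce to (eta / N) * sum_j y_j (x_q . x_j), which is why the hand-set attention weights reproduce the gradient-descent update exactly in this linear setting.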

Authors

Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, Max Vladymyrov