End-to-end Algorithm Synthesis with Recurrent Networks: Logical Extrapolation Without Overthinking
Machine learning systems perform well on pattern matching tasks, but their
ability to perform algorithmic or logical reasoning is not well understood. One
important reasoning capability is logical extrapolation, in which models
trained only on small/simple reasoning problems can synthesize complex
algorithms that scale up to large/complex problems at test time. Logical
extrapolation can be achieved through recurrent systems, which can be iterated
many times to solve difficult reasoning problems. We observe that this approach
fails to scale to highly complex problems because behavior degenerates when
many iterations are applied -- an issue we refer to as "overthinking." We
propose a recall architecture that keeps an explicit copy of the problem
instance in memory so that it cannot be forgotten. We also employ a progressive
training routine that prevents the model from learning behaviors that are
specific to iteration number and instead pushes it to learn behaviors that can
be repeated indefinitely. These innovations prevent the overthinking problem,
and enable recurrent systems to solve extremely hard logical extrapolation
tasks, some requiring over 100K convolutional layers, without overthinking.
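The two ideas above can be illustrated with a minimal sketch. This is not the paper's implementation (the paper uses convolutional networks in PyTorch); it is a toy NumPy version in which `recurrent_step`, the weight matrix `W`, and the warm-up/tracked split are hypothetical stand-ins. The recall idea is the concatenation of the original problem instance to the features at every iteration; the progressive-training idea is running a random number of untracked warm-up iterations before the steps that would receive gradients, so no behavior can be tied to a specific iteration index.

```python
import numpy as np

rng = np.random.default_rng(0)

def recurrent_step(features, problem, W):
    # Recall: re-concatenate the original problem instance to the current
    # features, so the input can never be forgotten across iterations.
    x = np.concatenate([features, problem], axis=-1)
    return np.maximum(x @ W, 0.0)  # one hypothetical linear + ReLU block

d_feat, d_prob = 8, 4
W = rng.standard_normal((d_feat + d_prob, d_feat)) * 0.1
problem = rng.standard_normal((1, d_prob))   # the problem instance, kept in memory
features = np.zeros((1, d_feat))

# Progressive training (sketch): a random number of warm-up iterations that
# would carry no gradient, followed by a fixed number of tracked steps.
n_warmup = int(rng.integers(0, 10))
for _ in range(n_warmup):
    features = recurrent_step(features, problem, W)
for _ in range(3):  # the k steps that would be trained on
    features = recurrent_step(features, problem, W)

print(features.shape)  # -> (1, 8)
```

Because the same step function is applied at every iteration and always sees the original problem, the loop can in principle be unrolled far beyond its training depth at test time, which is the extrapolation behavior the abstract describes.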
Authors
Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang, Micah Goldblum, Tom Goldstein