Hierarchical Reasoning Model

Juan Vera

August 2025

Abstract

Reading Hierarchical Reasoning Model by Wang et al.

Reasoning remains a critical challenge in AI. Chain of Thought (CoT) reasoning in language models suffers from brittle task decomposition, extensive data requirements, and high latency (self-attention alone scales as $O(n^2 \cdot d)$ with sequence length $n$ and dimension $d$).

The brittleness stems from autoregressive inference: a language model samples as $P(y_1, y_2, \ldots, y_T) = \prod_{t=1}^T P(y_t \mid y_{<t})$, where $y_t$ is the $t$-th token in the sequence and $y_{<t}$ is the set of tokens before position $t$, so small hallucinations in early steps can compound into larger errors in downstream steps. See more by Yann LeCun here.

Directly inspired by the human brain, the HRM executes sequential reasoning in a single forward pass through two interdependent recurrent modules: a high-level module for slow, abstract planning and a low-level module for rapid, detailed computations.

With 27M parameters and ~1000 training examples, HRM beats larger transformer-based models on ARC-AGI-(1 & 2), Sudoku-Extreme, and Maze-Hard.

Introduction

Language models, despite being heavily over-parameterized, with many transformer blocks stacked on top of one another, are paradoxically shallow in computational terms.

This paper proves that language models can be simulated by uniform constant-depth threshold circuits, meaning they are in the class $\text{TC}^0$.

$\text{TC}^0$ is the class of functions computed by uniform families of Boolean circuits of constant depth and polynomial size, built from unbounded fan-in AND, OR, NOT, and majority (threshold) gates.

$\text{TC}^0 \subseteq \text{P}$, sitting low in the complexity hierarchy and widely believed to be a strict subset, meaning such models likely cannot solve all problems in $\text{P}$ (polynomial time); this places fundamental limits on what they can compute directly. There are simply some functions or patterns transformers cannot represent (more here).

LLMs are therefore not Turing-complete: a Turing machine can solve every problem in $\text{P}$ (and beyond), so these models cannot compute a significant portion of problems efficiently, or even at all.

Chain of Thought reasoning can also break down easily: a single misordered step can collapse the entire reasoning process.

In this paper, the authors explore latent reasoning, where the model reasons within its internal hidden state, aligning with evidence that the human mind reasons separately from language.

Recurrent Neural Networks use such a hidden state, but they are computationally expensive, not parallelizable, and suffer from the vanishing gradient problem due to backpropagation through time.

The human brain is a compelling blueprint, organizing computations hierarchically rather than sequentially, operating at different speeds, which can enable deep multi-stage reasoning in parallel.

The HRM is constructed similarly, with a high-level module designed for abstract, deliberate reasoning and a low-level module designed for rapid, detailed computations. This interplay avoids the premature convergence typical of recurrent models, a behavior the authors coin "hierarchical convergence".

The higher-level module only advances after the lower-level module has completed multiple steps and reached a stable state.

They also propose a one-step gradient approximation for training the HRM, which eliminates the need for backpropagation through time, decreasing the computational requirements of training the recurrent model.

Due to its architecture, the HRM offers excellent performance using only 1,000 training examples, without any pre-training or CoT supervised fine-tuning, even learning solutions to problems that are intractable for LLMs.

Architecture

The model is inspired by three principles:

  • Hierarchical Processing, where the brain processes information across hierarchical cortical regions, separating long-term slow reasoning from fast thinking.
  • Temporal Separation, where different regions of the brain operate at different frequencies (timescales).
  • Recurrent Connectivity, where many connections in the brain are recurrent, allowing for more context-sensitive reasoning.

The model consists of:

  • An input network, $f_I(\cdot; \theta_I)$
  • A low-level recurrent module, $f_L(\cdot; \theta_L)$
  • A high-level recurrent module, $f_H(\cdot; \theta_H)$
  • An output network, $f_O(\cdot; \theta_O)$

An inference pass over the HRM is done through $N$ high-level cycles of $T$ low-level timesteps, meaning $N \times T$ total timesteps.

$f_L$ and $f_H$ each keep a hidden state: $z_L^i$ for the low-level module and $z_H^i$ for the high-level module.

Given an input vector $x$, the HRM maps it to an output prediction vector $\hat{y}$ through the following process:

First, the input is projected by the input network: $\hat{x} = f_I(x; \theta_I)$

At each timestep $i$, the low-level module updates its hidden state conditioned on $z_H^i$, $z_L^i$, and $\hat{x}$, while $f_H(\cdot; \theta_H)$ updates its state only after a full cycle of $f_L(\cdot; \theta_L)$, using the low-level hidden state $z_L^i$ and $z_H^i$, without $\hat{x}$:

$$z_L^{i+1} = f_L(z_H^i, z_L^i, \hat{x}; \theta_L)$$
$$z_H^{i+1} = \begin{cases} f_H(z_H^i, z_L^i; \theta_H) & \text{if } i \bmod T = 0 \\ z_H^i & \text{otherwise} \end{cases}$$

In other words, unless $i$ is divisible by $T$, $f_H(\cdot; \theta_H)$ is not updated and simply carries its previous state forward; with $T = 4$, for example, the high-level state changes only once every four timesteps.

Once $N$ full cycles are completed, a prediction $\hat{y}$ is extracted from the high-level hidden state by the output network:

$$\hat{y} = f_O(z_H^{(N \times T)}; \theta_O)$$

The entire set of $N \times T$ timesteps represents a single forward pass through the HRM.

The HRM is designed to counteract the premature convergence seen in standard RNNs: the high-level module only advances after the low-level module has completed a full cycle of $T$ timesteps, which allows the low-level module to reach a "local equilibrium" before the high-level state updates and kicks off a fresh low-level computation.
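
To make the schedule concrete, here is a minimal PyTorch sketch of one forward pass following the equations above. The class and argument names are illustrative rather than the authors' implementation, and the high-level update here consumes the low-level state computed at the end of each cycle.

```python
import torch
import torch.nn as nn

class HRMSketch(nn.Module):
    """Illustrative HRM forward pass: N high-level cycles of T low-level steps."""

    def __init__(self, f_I, f_L, f_H, f_O, d_model, N=2, T=4):
        super().__init__()
        self.f_I, self.f_L, self.f_H, self.f_O = f_I, f_L, f_H, f_O
        self.N, self.T = N, T
        # learned initial hidden states z_L^0 and z_H^0
        self.z_L0 = nn.Parameter(torch.zeros(d_model))
        self.z_H0 = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        x_hat = self.f_I(x)                 # x_hat = f_I(x; theta_I)
        z_L = self.z_L0.expand_as(x_hat)
        z_H = self.z_H0.expand_as(x_hat)
        for i in range(1, self.N * self.T + 1):
            # low-level update at every timestep, conditioned on z_H and x_hat
            z_L = self.f_L(z_L, z_H, x_hat)
            if i % self.T == 0:
                # high-level update only at the end of each T-step cycle
                z_H = self.f_H(z_H, z_L)
        return self.f_O(z_H)                # y_hat = f_O(z_H^(N*T); theta_O)
```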

Approximate Gradient Methods

Recurrent models typically use backpropagation through time (BPTT) to compute gradients. BPTT requires storing the hidden states from the forward pass and combining them with gradients during the backward pass, which is expensive in memory: the cost is $O(T)$, growing linearly with the sequence length $T$, thereby forcing small batch sizes.

Consider the situation where you have an RNN, $\mathcal{R}$, that is run on the same input, $\hat{x}$, over $T$ timesteps, with a state update at every timestep.

Eventually, the hidden state will converge to a fixed point, $z^*$, where $\mathcal{R}(z^*, \hat{x}; \theta) = z^*$.

Meaning that after lengthy recursive application to the same input, the RNN will reach a fixed hidden state, or in the case of the HRM, a local equilibrium.

So if we consider the HRM, where the high-level state $z_H^k = f_H(z_H^{k-1}, z_L^*; \theta_H)$ serves as conditioning for the low-level module $f_L(z_H^k, z_L^i, \hat{x}; \theta_L)$, we can observe that within each cycle the low-level state converges toward a local fixed point $z_L^*$ satisfying $f_L(z_H^k, z_L^*, \hat{x}; \theta_L) = z_L^*$, and only then does the high-level module update its hidden state, $z_H^{k+1} = f_H(z_H^k, z_L^*; \theta_H)$; across cycles, the high-level state itself approaches a fixed point $z_H^*$.

This is because during each low-level cycle, $f_L(\cdot)$ is conditioned on a frozen high-level state from $f_H(\cdot)$ and a fixed input vector $\hat{x}$, meaning the only variable being updated is $z_L^i$; since the same map is applied recursively to its own output, it will (assuming the update is stable) eventually converge to a fixed point.
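
A toy numerical illustration of this convergence (not from the paper): iterate a simple recurrent map with frozen conditioning, with the weight matrix scaled to spectral norm below 1 so the map is a contraction, and watch the residual between successive states shrink toward zero.

```python
import torch

torch.manual_seed(0)
d = 8
W = torch.randn(d, d)
W = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)  # spectral norm 0.9 -> contraction
x_hat = torch.randn(d)                            # frozen conditioning (stands in for z_H and x_hat)

def f_L(z):
    # toy low-level update with fixed conditioning
    return torch.tanh(W @ z + x_hat)

z = torch.zeros(d)
for i in range(60):
    z_next = f_L(z)
    if i % 10 == 0:
        print(f"step {i:2d}  residual {torch.norm(z_next - z).item():.2e}")
    z = z_next
```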

Let's define $\mathcal{F}$ as the transformation containing one full cycle of low-level updates followed by the high-level update, $z_H^k = \mathcal{F}(z_H^{k-1}; \hat{x}, \theta)$, where $\theta$ collects the parameters of the modules involved.

The fixed point where we have this equilibrium can be written as $z_H^* = \mathcal{F}(z_H^*; \hat{x}, \theta)$.

$J_\mathcal{F} = \frac{\partial \mathcal{F}}{\partial z_H}$ is the Jacobian matrix of $\mathcal{F}$ with respect to $z_H$.

If the matrix $I - J_\mathcal{F}$ is invertible at $z_H^*$ and $\mathcal{F}$ is continuously differentiable, the implicit function theorem then allows for the computation of the exact gradient at that point without explicit backpropagation:

$$\frac{\partial z_H^*}{\partial \theta} = \left(I - J_\mathcal{F}\right)^{-1} \left.\frac{\partial \mathcal{F}}{\partial \theta}\right|_{z_H^*}$$

Computing the inverse of $I - J_\mathcal{F}$ is computationally expensive, so we can take the Neumann series expansion of the inverse:

$$\left(I - J_\mathcal{F}\right)^{-1} = \sum_{k=0}^\infty (J_\mathcal{F})^k = I + J_\mathcal{F} + J_\mathcal{F}^2 + \cdots$$

and truncate it at the first term ($k = 0$), i.e. $(I - J_\mathcal{F})^{-1} \approx I$, which gives the one-step approximation of the gradient at $z_H^*$: $\frac{\partial z_H^*}{\partial \theta} \approx \frac{\partial \mathcal{F}}{\partial \theta} \big|_{z_H^*}$.
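
In practice, this approximation can be realized by running every recurrent update except the final low-level and high-level steps without gradient tracking, so autograd only differentiates through the last update of each module. A hedged sketch (assuming the model exposes its modules and an init_states helper as below; not the reference implementation):

```python
import torch

def hrm_segment_one_step_grad(model, x, N=2, T=4):
    """One HRM segment with the one-step gradient approximation (sketch).

    Only the final f_L and f_H calls are tracked by autograd, so memory stays
    constant in N*T and no backpropagation through time is needed.
    """
    x_hat = model.f_I(x)
    z_L, z_H = model.init_states(x_hat)   # hypothetical helper returning (z_L^0, z_H^0)

    with torch.no_grad():                  # all but the last timestep: no graph is built
        for i in range(1, N * T):
            z_L = model.f_L(z_L, z_H, x_hat)
            if i % T == 0:
                z_H = model.f_H(z_H, z_L)

    # final low-level and high-level updates carry the gradient
    z_L = model.f_L(z_L, z_H, x_hat)
    z_H = model.f_H(z_H, z_L)
    return z_H, model.f_O(z_H)
```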

Deep Supervision

Given a data sample $(x, y)$, you run multiple forward passes ("segments") of the HRM, where $M$ is the total number of segments and $m$ indexes the $m$-th segment.

For each segment $m \in \{1, \ldots, M\}$, you compute an update, similar to standard gradient descent:

$$z^m, \hat{y}^m = \text{HRM}(z^{m-1}, x; \theta)$$
$$L^m \leftarrow \text{Loss}(\hat{y}^m, y)$$
$$\theta \leftarrow \theta - \eta \nabla_\theta L^m$$

with the caveat that $z^m$ is detached from the computation graph before the next segment, so it is not involved in the computation of gradients: since the gradient is approximated with a single step, the incoming state $z^{m-1}$ is treated as a constant.
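
A hedged sketch of the resulting deep-supervision loop, assuming the model exposes a segment method that takes and returns the carried state (as in the equations above) and an init_carry helper for $z^0$; both names are illustrative.

```python
import torch
import torch.nn.functional as F

def deep_supervision_step(model, optimizer, x, y, M=4):
    """Run M supervised segments on one (x, y) pair, updating theta after each (sketch)."""
    z = model.init_carry(x)                 # hypothetical helper producing z^0
    for m in range(1, M + 1):
        z, y_hat = model.segment(z, x)      # one segment (one-step gradient inside)
        loss = F.cross_entropy(y_hat, y)    # L^m = Loss(y_hat^m, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        z = z.detach()                      # z^m carries no gradient into segment m+1
    return loss.item()
```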

Adaptive Computational Time

They incorporate a halting strategy into the HRM via $Q$-learning, enabling the HRM to dynamically select the number of segments $M$ based on the complexity of the task.

The $Q$-head uses the final state of $f_H(\cdot)$ to predict the Q-values, $\hat{Q}^m = (\hat{Q}^m_{\text{halt}}, \hat{Q}^m_{\text{continue}})$, as:

$$\hat{Q}^m = \sigma(\theta_Q^\top z_H^{mNT})$$

where $\sigma$ is the sigmoid function used to derive the Q-values for halting and continuing.

Let:

  • $M_{\text{max}}$ be the maximum number of segments
  • $M_{\text{min}}$ be the minimum number of segments; to encourage exploration, with probability $\epsilon$ it is sampled uniformly from $\{2, \ldots, M_{\text{max}}\}$, and with probability $1 - \epsilon$ it is set to $1$.

The criteria for halting are:

  • the segment count surpasses the maximum threshold $M_{\text{max}}$, or
  • the halt value exceeds the continue value and the segment count has reached at least the minimum threshold $M_{\text{min}}$.

The Q-head is trained through Q-learning, defined over a Markov decision process with state space $S$, action space $A$, and reward function $R(s, a)$, where $s^m = \{z^0, \ldots, z^{mNT}\}$ is the state, $a^m \in \{\text{halt}, \text{continue}\}$ is the action, and $R(s^m, a^m)$ is the reward.

Once the model halts, it returns a prediction and receives a binary reward for that prediction (1 if correct, 0 otherwise). Continuing returns a reward of $0$.

The targets for each possible action, $\hat{G}^m = (\hat{G}^m_{\text{halt}}, \hat{G}^m_{\text{continue}})$, are:

$$\hat{G}^m_{\text{halt}} = \mathbf{1}\{\hat{y}^m = y\}$$
$$\hat{G}^m_{\text{continue}} = \begin{cases} \hat{Q}^{m+1}_{\text{halt}} & \text{if } m \geq M_{\text{max}} \\ \max\left(\hat{Q}^{m+1}_{\text{halt}}, \hat{Q}^{m+1}_{\text{continue}}\right) & \text{otherwise} \end{cases}$$

Meaning the halt target is simply the reward for the current prediction, while the continue target bootstraps from the next segment's Q-values: once the segment count reaches the maximum, only halting is considered; otherwise it takes the better of halting or continuing.

The loss function is then $L^m_{\text{ACT}} = \text{Loss}(\hat{y}^m, y) + \text{BCE}(\hat{Q}^m, \hat{G}^m)$.
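
A sketch of how the halting rule, Q-targets, and ACT loss fit together for one segment. All names are illustrative, the next segment's Q-values are passed in as q_next, and the binary reward is assumed to be exact-match correctness of the halted prediction.

```python
import torch
import torch.nn.functional as F

def act_step(q_m, q_next, m, M_min, M_max, seq_loss, prediction_correct):
    """Halting decision, Q-learning targets, and ACT loss for segment m (sketch).

    q_m, q_next: tensors of shape (2,) holding sigmoid Q-values
                 [Q_halt, Q_continue] at segments m and m + 1.
    """
    q_halt, q_continue = q_m[0], q_m[1]

    # halt when the segment cap is reached, or when halting wins after M_min segments
    halt = bool(m >= M_max or (q_halt > q_continue and m >= M_min))

    # targets: halting earns the binary prediction reward; continuing bootstraps
    # from the next segment's Q-values (only halting remains once m >= M_max)
    g_halt = torch.tensor(1.0 if prediction_correct else 0.0)
    g_continue = q_next[0] if m >= M_max else torch.max(q_next)

    targets = torch.stack([g_halt, g_continue]).detach()
    loss = seq_loss + F.binary_cross_entropy(q_m, targets)   # L_ACT^m
    return halt, loss
```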

The stability of Q-learning is questionable, but under some conditions—such as Post-Normalization and weight decay—stability can be achieved.

Architecture

The HRM is a sequence-to-sequence architecture: the input and output are both sequences of tokens, which are then mapped into vectors.

The model includes an embedding layer fIf_I that converts tokens into vectors, and an output head that transforms the hidden state of the final timestep into the output probability vector y^\hat{y}.

The low-level and high-level modules are implemented as encoder-only Transformers with identical architectures and dimensions. They include enhancements found in modern models, such as Rotary Position Embeddings (RoPE), Gated Linear Units (GLUs), and Post-RMSNorm.
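
As a rough picture of such a block (a sketch only: RoPE is omitted, and nn.RMSNorm requires a recent PyTorch), each sub-layer, attention and a gated MLP, is followed by an RMSNorm applied after the residual addition:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PostNormGLUBlock(nn.Module):
    """Encoder-only block with a gated MLP and Post-RMSNorm (illustrative sketch)."""

    def __init__(self, d_model=256, n_heads=4, d_ff=1024):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)   # GLU gate projection
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):                        # x: (batch, seq, d_model)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)             # post-norm: normalize after the residual
        mlp_out = self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
        x = self.norm2(x + mlp_out)
        return x
```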

Results

  • 40.3% accuracy on ARC-AGI-1
  • 5.0% accuracy on ARC-AGI-2
  • 55.0% accuracy on Sudoku-Extreme
  • 74.5% accuracy on Maze-Hard

These results beat DeepSeek R1, Claude 3.7, and o3-mini-high across all four benchmarks.