2026-04-05
Thinking While Listening: Fast-Slow Recurrence for Long-Horizon Sequential Modeling
Shota Takashiro, Masanori Koyama, Takeru Miyato
Problem
Long-horizon sequential modeling requires maintaining an internal state that summarizes past information while processing new observations and selectively retaining task-relevant signals. This is a known challenge for RNNs (Bengio et al., 1994; Pascanu et al., 2013) and Transformers (Tay et al., 2021; Liu et al., 2023; Sinha et al., 2025).
Recent latent recurrent models (Geiping et al., 2025; Hao et al., 2024b; Gladstone et al., 2025; Miyato et al., 2025; Darlow et al., 2025) iterate a recurrent block in latent space for test-time compute scaling. However, when applied to sequential inputs, these methods either (1) reinitialize the latent state at each turn and append decoded summaries to a growing context window, so context length grows with the sequence, or (2), as in Coconut (Hao et al., 2024b), keep the conditioning context growing even in latent space. Either way, long-horizon integration becomes unreliable.
Perceiver (Jaegle et al., 2021) performs RNN-style latent updates but lacks an inner reasoning loop that runs faster than the observation rate. Adaptive Computation Time (Graves, 2017) separates fast and slow dynamics but does not carry the latent state across loops; it produces the next initial state as an ensemble over trajectories. HRM/TRM (Wang et al., 2025; Jolicoeur-Martineau, 2025) couple fast and slow dynamics but operate only on stationary inputs. CTM (Darlow et al., 2025) handles sequential input but reads out explicit pairwise synchronization strengths integrated over time rather than the latent state directly.
Architecture
FSRM (Fast-Slow Recurrent Model) reformulates latent recurrence as a coupled dynamical system with two timescales sharing a single time axis $t$:
- Fast time $t$: latent reasoning process $X(t) \in \mathbb{R}^{K \times D}$
- Slow time $s = \lfloor t / T \rfloor$: observation process $C(t) = \text{Encoder}(O_s)$
- $T$ = number of latent reasoning iterations per observation (inner-loop depth)
- Latent state $X(0) \sim \mathcal{N}(0, I)$ is never reset during an episode
The core update rule:
\[C(t) = \text{Encoder}(O_s) \quad \text{for } s = \lfloor t/T \rfloor \in \{0, \ldots, S-1\}\] \[X(t+1) = \Pi\bigl(X(t) + \gamma \, F\bigl(X(t), C(t); \theta\bigr)\bigr) \quad \text{for } t \in \{0, \ldots, ST-1\}\] where $\Pi(x) = x / \|x\|$ normalizes tokens to the unit sphere, $\gamma > 0$ is a learnable step size, and $F$ is a shared recurrent module. Predictions use $\text{Decoder}(X(sT))$.
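The update rule above can be sketched in a few lines of numpy, with hypothetical `F` and `encoder` callables standing in for the shared recurrent module and observation encoder:

```python
import numpy as np

def normalize_tokens(x):
    # Pi(x): project each latent token back onto the unit sphere
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def fast_slow_rollout(observations, F, encoder, T, gamma, K, D, rng):
    # X(0) ~ N(0, I), projected to the sphere; never reset during the episode
    X = normalize_tokens(rng.standard_normal((K, D)))
    states = []
    for O_s in observations:          # slow time s: one observation per step
        C = encoder(O_s)              # C(t) is held fixed for T fast iterations
        for _ in range(T):            # fast time t: latent reasoning loop
            X = normalize_tokens(X + gamma * F(X, C))
        states.append(X)              # a Decoder would read X(sT) here
    return states
```

Note how the single latent state `X` threads through the whole episode; only `C` changes at the slow rate.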
Fast module F (AKOrN-based)
Building on Miyato et al. (2025) with attention + MLP extension:
\[F(X, C)_i := \Omega_i x_i + \text{Proj}_{x_i}\bigl(J(X, C)_i\bigr)\] where:
- $x_i$ is the $i$-th latent token
- $\text{Proj}_x(v) = v - (v^\top x)\, x$ projects onto the tangent of the unit sphere
- $\Omega_i$ is an anti-symmetric matrix
- $J$ is attention-based (GTA positional embeddings from Miyato et al., 2023) followed by a ReLU MLP
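The two geometric ingredients (the anti-symmetric rotation and the tangent-space projection) can be sketched in numpy. The $\Omega = A - A^\top$ parameterization is a standard construction and an assumption here, not a detail taken from the paper:

```python
import numpy as np

def make_antisymmetric(A):
    # Omega = A - A^T is anti-symmetric for any square A, so x^T Omega x = 0:
    # the Omega term only rotates x and cannot change its norm to first order.
    return A - A.T

def proj_tangent(v, x):
    # Proj_x(v) = v - (v^T x) x : remove the radial component of v so the
    # update stays in the tangent space of the unit sphere at x.
    return v - (v @ x) * x
```

Both properties (anti-symmetry and orthogonality of the projected vector to `x`) are easy to verify numerically.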
class FastLayer:
    def forward(self, state, cond, T):
        for t in range(T):
            state = state + gamma * J(state, cond)  # J includes attention, MLP, projection, and Omega
            state = state / norm(state)             # Pi: renormalize each token to the unit sphere
        return state
class FastSlow:
    def forward(self, tokens, T):
        c = embed(tokens)             # (L, K, C): one conditioning block per observation
        x = randn(K, C)               # latent state X(0); never reset within an episode
        logits = []
        for s in range(len(tokens)):  # slow loop over observations
            x = fast1(x, c[s], T)     # T fast iterations on a fixed observation
            logits.append(classifier(x))
        return logits
Multi-layer variant (Dyck task)
For hierarchical structure, the two-stage architecture cascades fast processes:
\[X^{(0)}(t) = C(t), \qquad X^{(\ell)}(t+1) = \Pi\Bigl(X^{(\ell)}(t) + \gamma \, F\bigl(X^{(\ell)}(t), X^{(\ell-1)}(t)\bigr)\Bigr)\]for $\ell = 1, 2$. Uses a history queue of size $H = 4$ feeding into the second layer.
Energy-like scalar
\[E(X) = -\frac{1}{2} \sum_i x_i^\top J(X, C)_i\] This tracks the self-organization level of the latent population (related to Kuramoto, 1984). The model learns to align interpretable structural properties with this energy.
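A minimal sketch of this readout, assuming the coupling field $J(X, C)$ has already been evaluated (`J_out` is a hypothetical precomputed array):

```python
import numpy as np

def energy(X, J_out):
    # E(X) = -1/2 * sum_i x_i^T J(X, C)_i, with J_out = J(X, C) precomputed.
    # More negative values mean the latent tokens are aligned with their
    # coupling field, i.e. a higher degree of self-organization.
    return -0.5 * np.sum(X * J_out)
```

The elementwise product plus sum computes all the per-token inner products $x_i^\top J(X,C)_i$ at once.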
Maze-specific variant
For the maze task, $J$ is wrapped with a GRU for temporal continuity: $J_{\text{GRU}}(X, C)_i = \text{GRU}(J(X, C)_i; Z_i)$, where the GRU hidden state $Z$ is updated at every fast step.
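A rough sketch of this wrapping, with assumed shapes and `torch.nn.GRUCell` standing in for the paper's GRU (the exact handling of $Z$ is not specified, so treat this as illustrative):

```python
import torch

class JWithGRU(torch.nn.Module):
    """Wrap the coupling module J's output in a GRU for temporal continuity.
    Assumed shapes: X, C are (K, D); Z is a per-token hidden state."""
    def __init__(self, dim, J):
        super().__init__()
        self.J = J                               # underlying coupling module
        self.gru = torch.nn.GRUCell(dim, dim)    # input_size = hidden_size = dim
        self.Z = None                            # hidden state, updated each fast step
    def forward(self, X, C):
        y = self.J(X, C)                         # (K, D) coupling field
        if self.Z is None:
            self.Z = torch.zeros_like(y)
        self.Z = self.gru(y, self.Z)             # treat the K tokens as a batch
        return self.Z
```

Calling the module repeatedly advances `Z` once per fast step, matching the description above.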
Training
Supervised tasks (Maze, Dyck)
- Optimizer: AdamW
- Batch size: 256
- Scheduler: cosine
- Gradient clipping: 0.1 (Maze), 1.0 (Dyck)
- Weight decay: 0.1 (Maze), 0.01 (Dyck)
| Setting | Maze | Dyck |
|---|---|---|
| Epochs | 300 | 30 |
| Learning rate | $1 \times 10^{-3}$ | $5 \times 10^{-3}$ |
| Dataset size | 45K train / 5K val | 10K train / 1K val |
| Internal steps $T$ | 5 | 5 |
| Hidden dim (channels) | 64 | 256 |
| Attention heads | 4 | 4 |
| Oscillator dim | 4 | 4 |
| Initial $\gamma$ | 0.1 | 0.1 |
| Parameters | 1.16M | 1.41M |
RL tasks (MiniGrid / PPO)
- Optimizer: Adam, lr $2.5 \times 10^{-4}$
- Total timesteps: $1 \times 10^6$
- Environments: 32 parallel, 96 steps/env, 4 update epochs
- Discount factor: $\gamma = 0.995$, GAE $\lambda = 0.95$, clip coefficient $= 0.1$
- Entropy coefficient: $1 \times 10^{-2}$
- Internal steps $T$: 10 (gradient truncated to last 5 iterations)
- Parameters: 1.19M (channels=512, heads=4, oscillator dim=2)
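The gradient truncation can be sketched as a no-grad prefix followed by a differentiable suffix (hypothetical helper; the names are not from the paper, only the T=10 / last-5 split is):

```python
import torch

def fast_loop_truncated(X, C, F, gamma, T=10, grad_steps=5):
    """Run T fast iterations but backprop only through the last grad_steps."""
    with torch.no_grad():
        for _ in range(T - grad_steps):        # no-grad prefix: cheap and stable
            X = X + gamma * F(X, C)
            X = X / X.norm(dim=-1, keepdim=True)
    X = X.detach()                             # cut the graph at the boundary
    for _ in range(grad_steps):                # differentiable suffix
        X = X + gamma * F(X, C)
        X = X / X.norm(dim=-1, keepdim=True)
    return X
```

Truncating the inner loop this way bounds the backprop depth per observation while still letting gradients reach the recurrent module's parameters.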
Evaluation
1. Egocentric Maze
Train on $19 \times 19$ mazes, evaluate OOD on $39 \times 39$ mazes. The model receives a stream of egocentric $7 \times 7$ observations from a navigator following the right-hand rule, and must output the shortest path.
| Model | ID Acc ($19 \times 19$) | OOD Acc ($39 \times 39$) |
|---|---|---|
| LSTM (2.86M) | ~75% | ~20% |
| Mamba-2 (1.25M) | ~90% | ~25% |
| Transformer (1.73M) | ~95% | ~30% |
| Looped TF | ~95% | ~30% |
| S5 | ~85% | ~20% |
| CTM | ~95% | ~25% |
| FSRM (1.16M) | ~97% | ~60% |
Effect of $T$: OOD accuracy improves monotonically from 0.306 ($T$=1) → 0.595 ($T$=4) → 0.673 ($T$=8), saturating beyond $T$=8.
Weight sharing ablation (Table 6): recurrent $T$=5 achieves 0.612 OOD vs non-recurrent 5-layer stack at 0.509 and 10-layer stack at 0.445. Weight sharing is critical.
2. Dyck-(30, 5)
Train on random Dyck strings of length $\leq 40$ tokens ($k=30$ bracket types, max depth $m=5$). OOD evaluation on 1-regular runs up to length 2,560.
- All baselines (LSTM, Mamba-2, Transformer) decay to chance level (0.5) on long OOD sequences
- FSRM maintains ~100% accuracy over indefinitely long sequences on OOD
- Against frontier LLMs (GPT-5.1 Thinking, Claude Opus 4.5, Gemini 3.0 Pro) given the explicit stack algorithm in a prompt, FSRM maintains >90% accuracy at lengths 2 orders of magnitude beyond training range (up to $10^5$), while all LLMs degrade rapidly
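For reference, the explicit stack algorithm handed to the LLM baselines amounts to an online bracket checker; the `(type, is_open)` token encoding below is an assumption, not the paper's I/O format:

```python
def dyck_prefix_valid(tokens, k=30, m=5):
    """Check a bracket stream online, as a stack automaton would.
    Tokens are (type, is_open) pairs with type in 0..k-1; returns a per-step
    validity flag. A prefix becomes (and stays) invalid if a close mismatches
    the top of the stack or an open exceeds the max depth m of Dyck-(k, m)."""
    stack, flags = [], []
    ok = True
    for typ, is_open in tokens:
        if ok:
            if is_open:
                stack.append(typ)
                ok = len(stack) <= m      # enforce the depth bound
            else:
                ok = bool(stack) and stack.pop() == typ
        flags.append(ok)
    return flags
```

The state the model must carry is exactly this stack: bracket identities ordered by depth, bounded by $m$.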
Emergent structure: layer-1 energy spikes with new bracket types and stabilizes per-type; layer-2 energy organizes by stack depth regardless of bracket identity. PCA projections form crisp, spring-shaped manifolds indexed by depth.
3. MiniGrid RL
Zero-shot transfer from simple ID environments to larger OOD environments. Sparse reward (only at episode end).
DoorKey: train $5 \times 5$ → OOD $6 \times 6$, $8 \times 8$, $16 \times 16$. FSRM outperforms LSTM, Mamba-2, Transformer-XL across all sizes.
MultiRoom: train N2-S4 → OOD N4-S5, N6. FSRM matches or exceeds baselines.
LavaCrossing: train S9N1 → OOD S9N2, S9N3, S11N5. FSRM achieves best average success rate.
Results averaged over 5 seeds. Latent states self-organize around key events (key pickup, door opening, goal discovery) in PCA space.
Ablation: Fast module F
On the maze task, replacing AKOrN with a Transformer block + RMSNorm works competitively but slightly worse. LSTM as fast module performs noticeably worse, suggesting the self-organizing mechanism is important.
Parameter efficiency
FSRM is roughly an order of magnitude smaller than the baselines. Scaling channels from 64 → 128 → 256 moves OOD accuracy non-monotonically (0.692 → 0.670 → 0.727), with the widest model best, suggesting the gains are not from reduced overfitting alone.
Inference cost
~3× slower than Mamba/Transformer at $T$=5 on a single NVIDIA GH200 (unoptimized PyTorch, no torch.compile). Cost scales linearly with $T$.
Reproduction Guide
No public code release as of the paper date (April 2026). Key implementation details for reimplementation:
# Core dependencies
pip install torch numpy gym-minigrid
# The architecture is relatively simple — main components:
# 1. Self-attention with GTA positional embeddings (Miyato et al., 2023)
# 2. Anti-symmetric rotation Omega
# 3. Sphere normalization Pi(x) = x / ||x||
# 4. Coupled fast-slow loop with learnable gamma
Key pseudocode from the paper for the fast module J:
class J:
    def __init__(self, dim, oscillator_dim):
        self.attn = SelfAttention(dim)           # with GTA positional embeddings
        self.mlp = MLP(dim)                      # ReLU activation
        self.omega = Omega(dim, oscillator_dim)  # anti-symmetric matrix
    def forward(self, x, c):
        y = self.attn(x + c)     # attention over latent tokens, obs added
        y = self.mlp(x + c + y)  # MLP on augmented features
        y = proj_tangent(y, x)   # y - (y^T x)x
        y = y + self.omega(x)    # anti-symmetric rotation
        return y
Two-stage Dyck architecture (Figure 14):
class FastSlow:
    def __init__(self, init_gamma):
        self.embed = nn.Embedding()
        self.fast1 = FastLayer(init_gamma)
        self.fast2 = FastLayer(init_gamma)
    def forward(self, tokens, H=4, T=5):
        c = self.embed(tokens)        # (L, K, C)
        x = randn(K, C)               # layer-1 latent state
        z = randn(H, K, C)            # layer-2 latent state
        h = queue(H)                  # history queue of the last H layer-1 readouts
        z_out = zeros(H, K, C)
        logits = []
        for s in range(len(tokens)):
            x = self.fast1(x, c[s], T)
            x_out = readout(x)
            h.enqueue(x_out)
            z = self.fast2(z, h + z_out, T)  # layer 2 conditioned on history + its own readout
            z_out = readout(z)
            logits.append(classifier(z_out))
        return logits
Recommended hyperparameter starting points: AdamW, lr $1 \times 10^{-3}$, batch 256, cosine schedule, gradient clip 0.1, $T$=5, channels 64, 4 heads, oscillator dim 4, initial $\gamma$=0.1.
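Those starting points can be bundled into a small setup helper (a sketch, not the authors' code; gradient clipping would be applied separately in the training step):

```python
import torch

def make_optimizer(model, lr=1e-3, weight_decay=0.1, epochs=300, steps_per_epoch=1):
    # Matches the recommended starting points above: AdamW + cosine schedule.
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=epochs * steps_per_epoch)
    return opt, sched
```

In the training loop, `torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)` before `opt.step()` would implement the 0.1 gradient clip used for the maze task.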
Notes
- The model’s key insight is aligning the time axes of the observation process and latent reasoning so the latent state persists continuously, avoiding context growth. This is architecturally simple but has strong empirical effects.
- The energy-like scalar $E(X)$ provides a useful interpretability lens across all three task domains (maze frustration events, Dyck bracket structure, RL key events).
- Weight sharing across the fast loop is crucial — simply stacking independent layers degrades OOD generalization. This aligns with the theoretical motivation that recurrent dynamics produce self-organizing clustering (Geshkovski et al., 2024).
- Parameter counts are 1.16M / 1.41M / 1.19M vs 1.25–37.1M for baselines. The recurrent core is highly parameter-efficient.
- Main limitations: (1) only tested on synthetic/grid domains, not real-world scale, (2) ~3× inference overhead vs optimized baselines, (3) sensitivity to random initialization (gradient instability from repeated layer reuse).
- Test-time scaling of $T$ at inference (training $T$=10) does not significantly harm performance, but the paper notes careful fast/slow balancing is needed for proper scaling.