2026-04-02

HCLSM: Hierarchical Causal Latent State Machines for Object-Centric World Modeling

Jaber Jaber, Osama Jaber

world-models object-centric causal-reasoning state-space-models

Problem

World models that predict future states from video suffer from flat latent representations that entangle all objects, timescales, and causal relationships into a single unstructured vector. This prevents agents from performing planning and counterfactual reasoning on individual objects.

The paper identifies three dimensions where current methods fall short:

  1. Objects: Physical scenes contain discrete entities whose states evolve semi-independently. Flat latent spaces cannot represent, e.g., a mug’s position separately from a table’s surface.
  2. Time: Physical dynamics operate at multiple scales simultaneously (continuous trajectories in milliseconds, discrete collision events, abstract strategic plans over minutes). No existing system handles all three within a single architecture.
  3. Causality: Without explicit causal structure, a model cannot answer counterfactual queries like “what if the gripper had pushed harder?”

Prior art and limitations:

  • V-JEPA (Bardes, 2024) and V-JEPA 2 (Assran, 2025): Predict video in latent space using JEPA with masking, scaled to billions of parameters, but operate on flat unstructured latent vectors with no object decomposition, no hierarchical dynamics, and no explicit causal structure.
  • DreamerV3 (Hafner, 2023): Learns a world model for model-based RL using recurrent state-space models, but uses flat latent states with no object decomposition or causal reasoning.
  • GAIA-1 (Hu, 2023): Builds a driving world model via autoregressive video generation, but again uses flat representations.
  • Slot Attention (Locatello, 2020): Introduces iterative soft-attention grouping of image patches into object slots, but lacks temporal dynamics entirely.
  • SAVi (Kipf, 2022): Extends slot attention to video with temporal slot propagation, but has no hierarchical dynamics or causal structure.
  • SAVi++ (Elsayed, 2022): End-to-end object-centric learning from real-world videos, but no temporal hierarchy or causal reasoning.
  • DINOSAUR (Seitzer, 2023): Replaces pixel reconstruction with self-supervised ViT features enabling real-world decomposition, but lacks temporal dynamics and causal structure.
  • SlotFormer (Wu, 2023): Adds autoregressive prediction on top of frozen slot representations using a single-scale transformer — no hierarchical temporal dynamics, no causal reasoning, and no training on real data.
  • Slot SSM (Jiang, 2024): Applies SSMs to object-centric representations, but lacks hierarchical dynamics, causal reasoning, and real-data training.
  • Adaptive Slot Attention (Fan, 2024): Learns the number of slots dynamically, but no dynamics or causality.

No existing system unifies object decomposition, hierarchical temporal dynamics, and causal reasoning in a single differentiable architecture.

Architecture

HCLSM processes video through five interconnected layers: perception, object decomposition, hierarchical dynamics, causal reasoning, and continual memory. The model has 68M parameters in the Small configuration (Base: 262M, Large: 3B, though larger configs have unresolved stability issues).

Layer 1: Perception

A Vision Transformer encoder processes video frames of shape $(B, T, C, H, W)$ into patch embeddings $(B, T, M, d_{\text{model}})$, where $M = (H/p)(W/p)$ is the number of patches for patch size $p = 16$ (i.e., $(H/p)^2$ for square frames). Temporal position embeddings are added per frame. A linear projection maps from $d_{\text{model}}$ to $d_{\text{world}}$, the unified representation dimension.
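A quick check of the shape bookkeeping: with $224 \times 224$ frames and $p = 16$, the grid is $14 \times 14$, so $M = 196$ patches. This matches the $14 \times 14$ grid used by the spatial broadcast decoder. The helper names below are hypothetical, and `d_model=256` is a placeholder, not the paper's Small-config value.

```python
def patch_grid(H, W, p=16):
    """Return (rows, cols, M) for non-overlapping p x p patches of an H x W frame."""
    if H % p or W % p:
        raise ValueError("frame height and width must be divisible by the patch size")
    rows, cols = H // p, W // p
    return rows, cols, rows * cols  # M = (H/p)(W/p), i.e. (H/p)**2 when H == W


def embedding_shape(B, T, H, W, d_model, p=16):
    """Shape of the ViT patch embeddings for a (B, T, C, H, W) video batch."""
    _, _, M = patch_grid(H, W, p)
    return (B, T, M, d_model)


# PushT clips: batch 4, 16 frames, 224 x 224; d_model=256 is a placeholder.
print(patch_grid(224, 224))                    # (14, 14, 196)
print(embedding_shape(4, 16, 224, 224, 256))   # (4, 16, 196, 256)
```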

Layer 2: Object Decomposition

Slot Attention with Dynamic Birth/Death: $N_{\max}$ slot proposals (32 by default) are initialized from a learned Gaussian $\mathcal{N}(\mu, \sigma^2)$. For $K$ iterations, slots compete for patch tokens through softmax over the slot dimension:

\[A_{nm} = \frac{\exp(q_n \cdot k_m / \sqrt{d})}{\sum_{n'} \exp(q_{n'} \cdot k_m / \sqrt{d})}\]

Each iteration refines slots through weighted value aggregation and a GRU cell. An existence head predicts $p_{\text{alive}} \in [0, 1]$ per slot. When residual attention energy exceeds a threshold, a dormant slot is “born” by projection from the highest-residual token.
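A minimal NumPy sketch of one attention-and-aggregation step follows. It assumes single-head dot-product attention with hypothetical projection matrices `Wq`, `Wk`, `Wv`. It highlights the key difference from ordinary attention: the softmax runs over slots (axis 0), so each patch's attention sums to one across slots. The GRU refinement, existence head, and birth/death logic are omitted.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def slot_attention_step(slots, inputs, Wq, Wk, Wv, eps=1e-8):
    """One simplified Slot Attention iteration (GRU/MLP refinement omitted).

    slots:  (N, d_in) current slot states
    inputs: (M, d_in) patch tokens
    Returns attention A (N, M), normalized over slots, and the per-slot
    weighted-mean updates (N, d) that a GRU cell would consume.
    """
    q, k, v = slots @ Wq, inputs @ Wk, inputs @ Wv
    d = q.shape[-1]
    A = softmax(q @ k.T / np.sqrt(d), axis=0)            # slots compete for each patch
    weights = A / (A.sum(axis=1, keepdims=True) + eps)   # weighted mean over patches
    return A, weights @ v


rng = np.random.default_rng(0)
N, M, d = 4, 196, 8
slots = rng.normal(size=(N, d))      # stand-in for slots sampled from the learned Gaussian
inputs = rng.normal(size=(M, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
A, updates = slot_attention_step(slots, inputs, Wq, Wk, Wv)
```

Normalizing over slots rather than over patches is what makes slots compete: a patch strongly claimed by one slot contributes little to the others.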

Spatial Broadcast Decoder (SBD): Each slot is independently broadcast to a $14 \times 14$ spatial grid, concatenated with $(x, y)$ positional coordinates, and decoded by a 4-layer CNN into feature predictions plus an alpha mask. Alpha masks are softmax-normalized over alive slots:

\[\alpha_{n,p} = \frac{\exp(\hat{\alpha}_{n,p})}{\sum_{n':\,\text{alive}} \exp(\hat{\alpha}_{n',p})}\]
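The decoder input and the alive-only mask normalization can be sketched as follows. This is a sketch under stated assumptions: coordinates lie in $[-1, 1]$, a boolean `alive` vector stands in for a thresholded $p_{\text{alive}}$, and the per-slot 4-layer CNN is omitted. `spatial_broadcast` and `alive_alpha_masks` are hypothetical names.

```python
import numpy as np


def spatial_broadcast(slots, grid=14):
    """Tile each slot over a grid x grid map and append (x, y) coordinates.

    slots: (N, d) -> (N, grid, grid, d + 2), the input to each slot's CNN decoder.
    """
    N, d = slots.shape
    lin = np.linspace(-1.0, 1.0, grid)                 # coordinate range is an assumption
    ys, xs = np.meshgrid(lin, lin, indexing="ij")
    coords = np.broadcast_to(np.stack([xs, ys], axis=-1), (N, grid, grid, 2))
    tiled = np.broadcast_to(slots[:, None, None, :], (N, grid, grid, d))
    return np.concatenate([tiled, coords], axis=-1)


def alive_alpha_masks(alpha_logits, alive):
    """Softmax mask logits over alive slots only; dead slots get exactly zero weight.

    alpha_logits: (N, P) per-slot logits at each of P spatial positions
    alive:        (N,) boolean; at least one slot must be alive
    """
    masked = np.where(alive[:, None], alpha_logits, -np.inf)
    masked = masked - masked.max(axis=0, keepdims=True)
    e = np.exp(masked)                                 # exp(-inf) == 0 for dead slots
    return e / e.sum(axis=0, keepdims=True)


rng = np.random.default_rng(0)
dec_in = spatial_broadcast(rng.normal(size=(3, 8)))    # (3, 14, 14, 10)
alpha = alive_alpha_masks(rng.normal(size=(3, 14 * 14)), np.array([True, True, False]))
```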

The reconstruction target is frozen ViT patch features from an EMA target encoder (following DINOSAUR), giving slots a semantic signal rather than low-level texture matching:

\[\mathcal{L}_{\text{SBD}} = \sum_{n} \sum_{p} \alpha_{n,p} \, \| \hat{f}_{n,p} - f_p^* \|^2\]
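A direct NumPy transcription of $\mathcal{L}_{\text{SBD}}$ follows. It assumes precomputed alpha masks and treats the target features as constants; in an autograd framework they would be detached. `sbd_loss` is a hypothetical name.

```python
import numpy as np


def sbd_loss(alpha, f_hat, f_target):
    """L_SBD = sum_n sum_p alpha[n, p] * ||f_hat[n, p] - f_target[p]||^2.

    alpha:    (N, P) normalized alpha masks
    f_hat:    (N, P, D) per-slot feature predictions
    f_target: (P, D) frozen EMA-encoder patch features (treated as constants)
    """
    diff = f_hat - f_target[None, :, :]
    sq_err = (diff * diff).sum(axis=-1)   # written as x*x to mirror the paper's bf16 workaround
    return float((alpha * sq_err).sum())


rng = np.random.default_rng(0)
N, P, D = 3, 196, 16
target = rng.normal(size=(P, D))
alpha = np.full((N, P), 1.0 / N)
perfect = sbd_loss(alpha, np.broadcast_to(target, (N, P, D)), target)   # 0.0
noisy = sbd_loss(alpha, target + 0.1 * rng.normal(size=(N, P, D)), target)
```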

Relation Graph (GNN): A GNN processes all-pairs edge features $[o_i ; o_j ; o_i - o_j ; o_i \odot o_j]$ through an edge MLP, producing weighted messages aggregated per node. For $N > 32$ slots, chunked computation processes source nodes in blocks of 16 to prevent memory overflow (at $N = 64$, a full pair tensor would require 4GB).
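A sketch of chunked all-pairs message passing follows. It assumes a single linear layer with ReLU as the edge MLP and a sigmoid gate as the edge weight; both are stand-ins, since the paper's exact MLP is not given here. Chunking over source nodes keeps only a `(chunk, N, 4d)` pair tensor in memory and gives the same result as the unchunked computation:

```python
import numpy as np


def edge_features(o_src, o_dst):
    """[o_i ; o_j ; o_i - o_j ; o_i * o_j] for every (source i, receiver j) pair."""
    oi, oj = np.broadcast_arrays(o_src[:, None, :], o_dst[None, :, :])  # (S, N, d) each
    return np.concatenate([oi, oj, oi - oj, oi * oj], axis=-1)          # (S, N, 4d)


def relation_messages(objects, W_msg, w_gate, chunk=16):
    """Aggregate gated edge messages at each receiver, chunking over source nodes.

    objects: (N, d); W_msg: (4d, d_msg); w_gate: (4d,).
    Linear + ReLU stands in for the edge MLP and a sigmoid gate for the edge
    weight (hypothetical). Self-pairs (i == j) are included for brevity.
    Only a (chunk, N, 4d) pair tensor exists at a time, instead of (N, N, 4d).
    """
    N = objects.shape[0]
    out = np.zeros((N, W_msg.shape[1]))
    for start in range(0, N, chunk):
        feats = edge_features(objects[start:start + chunk], objects)
        msgs = np.maximum(feats @ W_msg, 0.0)            # (S, N, d_msg)
        gate = 1.0 / (1.0 + np.exp(-(feats @ w_gate)))   # (S, N) edge weights
        out += (gate[..., None] * msgs).sum(axis=0)      # sum over sources per receiver
    return out


rng = np.random.default_rng(0)
objs = rng.normal(size=(40, 8))
W_msg, w_gate = rng.normal(size=(32, 8)) * 0.1, rng.normal(size=32) * 0.1
chunked = relation_messages(objs, W_msg, w_gate, chunk=16)
full = relation_messages(objs, W_msg, w_gate, chunk=40)   # a single block, i.e. no chunking
```

Peak pair-tensor memory scales with `chunk * N` instead of `N * N`, the $N/16$ reduction quoted in the engineering notes.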

Layer 3: Hierarchical Dynamics

Level 0 — Selective SSM (Continuous Physics): Each object gets its own SSM track with shared parameters:

\[h_t = A_t \odot h_{t-1} + B_t \odot x_t, \quad y_t = C_t^\top h_t\]

where $A_t = \exp(\Delta_t A)$ and $\Delta_t, B_t, C_t$ are input-dependent (following Mamba). A global SSM processes mean-pooled object states and conditions per-object tracks via additive context. $A_{\log}$ is initialized in $[-0.5, 0]$ (not the standard $\log(1 \ldots d_{\text{state}})$) and $\Delta_t A$ is clamped to $[-20, 0]$ for numerical stability at bf16.
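A sequential reference scan for one object track follows. It assumes Mamba's parameterization $A = -\exp(A_{\log})$ and a softplus on the raw step size. Following the formula above, $B_t$ multiplies $x_t$ directly (Mamba additionally scales it by $\Delta_t$). The global SSM and the Triton parallel kernel are not shown. Because parameters are shared, the same function would be applied to every slot.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)


def selective_ssm_scan(x, delta_raw, B, C, A_log):
    """Sequential diagonal selective-SSM scan for one object track.

    x:         (T, D) inputs
    delta_raw: (T, D) input-dependent pre-activation step sizes
    B, C:      (T, S) input-dependent projections
    A_log:     (D, S); A = -exp(A_log) as in Mamba (assumed)
    """
    T, D = x.shape
    S = B.shape[1]
    A = -np.exp(A_log)                                   # negative -> decaying state
    h = np.zeros((D, S))
    ys = np.empty((T, D))
    for t in range(T):
        delta = softplus(delta_raw[t])                   # (D,) positive step sizes
        dA = np.clip(delta[:, None] * A, -20.0, 0.0)     # clamp Delta_t A to [-20, 0]
        h = np.exp(dA) * h + B[t][None, :] * x[t][:, None]
        ys[t] = h @ C[t]                                 # y_t = C_t^T h_t per channel
    return ys


rng = np.random.default_rng(0)
T, D, S = 16, 4, 8
x = rng.normal(size=(T, D))
delta_raw = rng.normal(size=(T, D))
B, C = rng.normal(size=(T, S)), rng.normal(size=(T, S))
A_log = rng.uniform(-0.5, 0.0, size=(D, S))              # the [-0.5, 0] init from the text
ys = selective_ssm_scan(x, delta_raw, B, C, A_log)
```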

Level 1 — Sparse Event Transformer: An event detector monitors Level 0 states using multi-scale temporal features (frame differences at scales 1, 2, 4) processed through causal dilated convolutions. When the event score exceeds a learned threshold, the corresponding timestep is gathered into a dense event tensor. A standard transformer with SwiGLU feed-forward processes only $K \ll T$ event timesteps, with cost $\mathcal{O}(K \cdot N^2)$ instead of $\mathcal{O}(T \cdot N^2)$. The detector identifies 2–3 events per 16-frame sequence in PushT.
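A simplified sketch of the detect-then-gather step follows. A plain average of multi-scale difference norms stands in for the causal dilated convolutions, and a fixed threshold stands in for the learned one; function names are hypothetical. Backward differences keep the score causal, so the score at $t$ uses only states up to $t$.

```python
import numpy as np


def multiscale_change(states, scales=(1, 2, 4)):
    """Per-timestep change score from backward frame differences at several scales.

    states: (T, N, d) Level 0 object states -> (T,) scores.
    """
    T = states.shape[0]
    feats = np.zeros((len(scales), T))
    for i, s in enumerate(scales):
        diff = states[s:] - states[:-s]                        # (T - s, N, d)
        feats[i, s:] = np.linalg.norm(diff, axis=-1).mean(axis=-1)
    return feats.mean(axis=0)


def gather_events(states, scores, threshold):
    """Gather the timesteps whose score exceeds the threshold into a dense tensor."""
    idx = np.flatnonzero(scores > threshold)                   # K event indices
    return idx, states[idx]                                    # (K,), (K, N, d)


states = np.zeros((16, 3, 4))
states[8:] += 5.0                                              # abrupt change at t = 8
scores = multiscale_change(states)
idx, events = gather_events(states, scores, threshold=5.0)
print(idx)   # [8 9]
```

With $K = 2$ of $T = 16$ timesteps selected, the Level 1 transformer attends over 2 rows instead of 16, which is where the $\mathcal{O}(K \cdot N^2)$ cost comes from.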

Level 2 — Goal Compression Transformer: Learned summary query tokens cross-attend to the event sequence, compressing it into $n_{\text{summary}}$ abstract state tokens processed by a goal-level transformer, optionally conditioned on language/goal embeddings.
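A single-head sketch of the compression step follows. It omits the query/key/value projections and the goal-level transformer. The output size depends only on the number of summary queries, so variable-length event sequences map to a fixed number of abstract tokens:

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def compress_events(summary_queries, event_tokens):
    """Cross-attention from n_summary learned queries to K event tokens.

    summary_queries: (n_summary, d); event_tokens: (K, d).
    Returns (n_summary, d) summary tokens and the (n_summary, K) attention map.
    """
    d = summary_queries.shape[-1]
    attn = softmax(summary_queries @ event_tokens.T / np.sqrt(d), axis=-1)
    return attn @ event_tokens, attn


rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 8))                                      # n_summary = 4 (hypothetical)
summary_a, attn_a = compress_events(queries, rng.normal(size=(2, 8)))  # K = 2 events
summary_b, _ = compress_events(queries, rng.normal(size=(7, 8)))       # K = 7 events
```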

Hierarchy Manager: The three levels are combined via vectorized gather/scatter operations (no Python loops) with learned per-level gating weights.
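The sketch below shows one loop-free way to combine the levels. It assumes softmax-normalized gates, event outputs scattered back to their timesteps (zeros elsewhere), and goal tokens mean-pooled and broadcast to every timestep and object. These combination rules are assumptions for illustration, not the paper's exact Hierarchy Manager.

```python
import numpy as np


def combine_levels(level0, event_idx, level1, level2, gate_logits):
    """Gate-weighted sum of the three hierarchy levels, with no Python loops.

    level0:      (T, N, d) per-timestep object states (Level 0)
    event_idx:   (K,) timesteps where events fired
    level1:      (K, N, d) event-transformer outputs (Level 1)
    level2:      (n_summary, d) goal tokens (Level 2)
    gate_logits: (3,) learned per-level gating logits
    """
    l1_full = np.zeros_like(level0)
    l1_full[event_idx] = level1                                   # scatter to event timesteps
    l2_full = np.broadcast_to(level2.mean(axis=0), level0.shape)  # goal context everywhere
    g = np.exp(gate_logits - gate_logits.max())
    g = g / g.sum()                                               # softmax over levels
    return g[0] * level0 + g[1] * l1_full + g[2] * l2_full


out = combine_levels(
    level0=np.ones((16, 3, 4)),
    event_idx=np.array([8, 9]),
    level1=np.full((2, 3, 4), 3.0),
    level2=np.full((2, 4), 6.0),
    gate_logits=np.zeros(3),
)
# equal gates: non-event steps -> (1 + 0 + 6) / 3, event steps -> (1 + 3 + 6) / 3
```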

Layer 4: Causal Structure

A causal adjacency matrix $W \in \mathbb{R}^{N \times N}$ is learned with Gumbel-softmax binary sampling, $\ell_1$ sparsity regularization, and a NOTEARS (Zheng, 2018) DAG constraint $h(W) = \text{tr}(e^{W \odot W}) - N = 0$, enforced via augmented Lagrangian optimization. In the current release, the GNN edge weights serve as the primary causal structure signal, with explicit DAG learning as a regularization pathway.
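The following NumPy sketch covers three pieces: the DAG penalty, the augmented Lagrangian term $\alpha h + \tfrac{\rho}{2} h^2$ with $\ell_1$ sparsity, and binary Gumbel-softmax (relaxed Bernoulli) edge sampling. The matrix exponential uses a truncated Taylor series, which is adequate only for small, well-scaled matrices. `alpha`, `rho`, and `l1_weight` are hypothetical hyperparameters. $h(W) = 0$ exactly when the weighted graph has no cycles:

```python
import numpy as np


def expm_taylor(M, terms=30):
    """Matrix exponential by truncated Taylor series (small, well-scaled M only)."""
    result = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ M / k
        result = result + term
    return result


def notears_h(W):
    """h(W) = tr(exp(W * W)) - N; zero iff the weighted graph W is acyclic."""
    return float(np.trace(expm_taylor(W * W)) - W.shape[0])


def causal_penalty(W, alpha, rho, l1_weight):
    """Augmented-Lagrangian DAG term plus l1 sparsity."""
    h = notears_h(W)
    return alpha * h + 0.5 * rho * h * h + l1_weight * np.abs(W).sum()


def gumbel_binary_sample(logits, tau, rng):
    """Relaxed Bernoulli edge sample in (0, 1): sigmoid((logits + logistic noise) / tau)."""
    u = rng.uniform(1e-6, 1 - 1e-6, size=logits.shape)
    noise = np.log(u) - np.log1p(-u)      # difference of two Gumbels is logistic
    return 1.0 / (1.0 + np.exp(-(logits + noise) / tau))


dag = np.array([[0.0, 0.8, 0.3], [0.0, 0.0, 1.2], [0.0, 0.0, 0.0]])
cyc = np.array([[0.0, 1.0], [1.0, 0.0]])
print(notears_h(dag))   # 0.0
h_cyc = notears_h(cyc)  # 2*cosh(1) - 2, about 1.086
edges = gumbel_binary_sample(np.zeros((3, 3)), tau=0.5, rng=np.random.default_rng(0))
```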

Layer 5: Continual Memory

Uses Hopfield networks with Elastic Weight Consolidation (EWC) for continual learning across tasks.
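The paper does not describe here how the two mechanisms are wired together, so the sketch below shows each one in isolation. The first is the modern continuous Hopfield retrieval rule $\xi^{\text{new}} = X^\top \operatorname{softmax}(\beta X \xi)$. The second is the standard EWC penalty $\tfrac{\lambda}{2}\sum_i F_i(\theta_i - \theta_i^*)^2$. Function names and the value of `beta` are assumptions.

```python
import numpy as np


def hopfield_retrieve(patterns, query, beta=8.0):
    """Modern Hopfield update: a softmax-weighted blend of stored patterns.

    patterns: (P, d) stored memories; query: (d,) probe state.
    """
    scores = beta * patterns @ query
    w = np.exp(scores - scores.max())
    w = w / w.sum()
    return w @ patterns


def ewc_penalty(theta, theta_star, fisher, lam):
    """EWC: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    diff = theta - theta_star
    return 0.5 * lam * float(np.sum(fisher * diff * diff))


patterns = np.eye(4)
probe = np.array([0.1, 0.0, 0.9, 0.05])                  # noisy version of pattern 2
recalled = hopfield_retrieve(patterns, probe, beta=20.0)
penalty = ewc_penalty(np.array([1.0, 2.0]), np.zeros(2), np.array([1.0, 0.5]), lam=2.0)  # 3.0
```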

Loss Functions

The total loss depends on the training stage:

\[\text{Stage 1: } \mathcal{L} = 5.0 \cdot \mathcal{L}_{\text{SBD}} + 0.1 \cdot \mathcal{L}_{\text{diversity}}\]
\[\text{Stage 2: } \mathcal{L} = \mathcal{L}_{\text{JEPA}} + \mathcal{L}_{\text{SBD}} + \lambda_{\text{obj}} \mathcal{L}_{\text{obj}} + \lambda_{\text{causal}} \mathcal{L}_{\text{causal}}\]

Key loss components:

  • $\mathcal{L}_{\text{SBD}}$: Spatial broadcast decoder reconstruction against frozen ViT features
  • $\mathcal{L}_{\text{JEPA}}$: JEPA-style latent next-state prediction loss
  • $\mathcal{L}_{\text{diversity}}$: Slot diversity regularizer
  • $\mathcal{L}_{\text{obj}}$: Object-level prediction loss
  • $\mathcal{L}_{\text{causal}}$: Causal graph regularizer (NOTEARS DAG constraint + $\ell_1$ sparsity)
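The stage switch can be written as a function of the training step. This sketch assumes the 40% boundary from the training schedule (step 20,000 of 50,000). The `lambda_obj` and `lambda_causal` defaults are placeholders, since their values are not given here.

```python
def total_loss(step, losses, total_steps=50_000, stage1_frac=0.4,
               lambda_obj=1.0, lambda_causal=1.0):
    """Two-stage total loss.

    losses: dict with keys 'sbd', 'diversity', 'jepa', 'obj', 'causal'.
    lambda_obj / lambda_causal defaults are placeholders, not the paper's values.
    """
    stage1_end = round(stage1_frac * total_steps)   # 20,000 for the PushT run
    if step < stage1_end:
        # Stage 1: only reconstruction and diversity drive gradients;
        # losses["jepa"] may still be logged for monitoring.
        return 5.0 * losses["sbd"] + 0.1 * losses["diversity"]
    return (losses["jepa"] + losses["sbd"]
            + lambda_obj * losses["obj"] + lambda_causal * losses["causal"])


ones = {k: 1.0 for k in ("sbd", "diversity", "jepa", "obj", "causal")}
stage1_value = total_loss(19_999, ones)   # 5.0 + 0.1
stage2_value = total_loss(20_000, ones)   # 1 + 1 + 1 + 1
```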

Training

Hardware: NVIDIA H100 80GB GPU (single GPU; multi-GPU FSDP failed due to NCCL version incompatibility on their cloud provider).

Dataset: PushT task from LeRobot / Open X-Embodiment ecosystem — 206 episodes, 25,650 frames total. A robot pushes a T-shaped block toward a target with 2D end-effector displacement actions. 16-frame clips at $224 \times 224$ resolution.

Optimizer: AdamW with learning rate $1.5 \times 10^{-4}$, cosine schedule with 2,000-step warmup.
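A sketch of linear warmup followed by cosine decay. It assumes the 50,000-step total and a final learning rate of zero; the minimum learning rate is not stated in the paper.

```python
import math


def lr_at(step, base_lr=1.5e-4, warmup=2_000, total=50_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr (min_lr=0 assumed)."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = {s: lr_at(s) for s in (0, 1_999, 2_000, 26_000, 50_000)}
```

In PyTorch, the ratio `lr_at(step) / base_lr` could be supplied as the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`.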

Batch size: 4.

Precision: bfloat16 mixed precision (GradScaler disabled; loss scaling is unnecessary with bf16 on H100).

Training duration: 50,000 steps total (~6 hours per run).

  • Stage 1 (first 40%, 20K steps): Only $\mathcal{L}_{\text{SBD}}$ and $\mathcal{L}_{\text{diversity}}$ produce gradients; the prediction loss is computed for monitoring only. This forces slots to specialize spatially before any dynamics objective applies.
  • Stage 2 (remaining 60%, 30K steps): The full loss is activated and the SBD weight is reduced from 5.0 to 1.0.

EMA target encoder: Updated with exponential moving average: $\theta^- \leftarrow \tau \theta^- + (1 - \tau) \theta$.
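A framework-free sketch of the update, with a hypothetical $\tau = 0.996$ (the paper's value is not stated here):

```python
def ema_update(target_params, online_params, tau=0.996):
    """theta_minus <- tau * theta_minus + (1 - tau) * theta, elementwise."""
    return [tau * t_minus + (1.0 - tau) * t
            for t_minus, t in zip(target_params, online_params)]


target = [1.0, -2.0]
online = [0.0, 0.0]
target = ema_update(target, online, tau=0.9)   # moves 10% of the way toward the online weights
```

In PyTorch the same update is typically applied in place under `torch.no_grad()`, e.g. `p_target.mul_(tau).add_(p, alpha=1 - tau)` for each parameter pair.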

Numerical stability tricks (critical for bf16):

  1. Replaced x**2 with x*x in all loss functions (PowBackward0 produces NaN when inputs exceed bf16 range).
  2. SSM $A_{\log}$ initialized in $[-0.5, 0]$ instead of $\log(1 \ldots d_{\text{state}})$.
  3. Intermediate activations clamped to $[-50, 50]$ after the ViT and after the slot attention GRU.
  4. GradScaler disabled (unnecessary for bf16 on H100).
  5. $\Delta_t A$ clamped to $[-20, 0]$.

Seed sensitivity: ~40–60% of training runs diverge to NaN within the first 1,000 steps due to seed-dependent gradient overflow in the slot attention GRU at bf16. Surviving runs converge reliably. The paper reports results across 2 successful runs out of 4 launched.

Engineering highlights:

  • Custom Triton SSM kernel: Parallelizes across $(B, d_{\text{inner}} / \text{block})$ dimensions, launching up to 4,096 parallel programs on a single H100. Achieves 38–39× speedup over sequential PyTorch.
  • GPU-native slot tracking: Replaced CPU Hungarian matching (scipy) with differentiable Sinkhorn-Knopp algorithm running entirely on GPU.
  • Memory-efficient GNN: Chunked computation for $N > 32$ slots processes source nodes in blocks of 16, reducing peak memory by $N/16$.
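A log-domain Sinkhorn-Knopp sketch for matching slots across frames follows. It assumes squared Euclidean distance between slot states as the cost and uses a hypothetical temperature `eps`. Every step is a smooth tensor operation, so the same code written in an autograd framework is differentiable and can stay on the GPU, unlike `scipy`'s CPU-based Hungarian solver:

```python
import numpy as np


def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True)), axis=axis)


def sinkhorn(cost, eps=0.05, n_iters=50):
    """Log-domain Sinkhorn-Knopp soft matching between slots at t-1 and t.

    cost: (N, N) pairwise cost. Returns a plan whose rows and columns each sum
    to ~1 (a soft permutation); argmax per row gives a hard match.
    """
    log_K = -cost / eps
    log_u = np.zeros(cost.shape[0])
    log_v = np.zeros(cost.shape[1])
    for _ in range(n_iters):
        log_u = -logsumexp(log_K + log_v[None, :], axis=1)   # normalize rows
        log_v = -logsumexp(log_K + log_u[:, None], axis=0)   # normalize columns
    return np.exp(log_u[:, None] + log_K + log_v[None, :])


rng = np.random.default_rng(0)
prev = rng.normal(size=(5, 8))
perm = np.array([2, 0, 4, 1, 3])
curr = prev[perm] + 0.01 * rng.normal(size=(5, 8))           # shuffled, slightly perturbed slots
cost = ((prev[:, None, :] - curr[None, :, :]) ** 2).sum(axis=-1)
plan = sinkhorn(cost)
match = plan.argmax(axis=1)                                  # match[i]: where prev slot i went
```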

Codebase: 8,478 lines of Python across 51 modules with 171 unit tests.

Evaluation

Main Results (PushT, 68M params, 50K steps, H100)

| Configuration | Pred. $\downarrow$ | Track. $\downarrow$ | Diversity $\downarrow$ | SBD $\downarrow$ | Total $\downarrow$ | Speed |
|---|---|---|---|---|---|---|
| HCLSM (no SBD) | 0.002 | 0.001 | 0.154 | n/a | 0.100 | 2.3 sps |
| HCLSM (two-stage) | 0.008 | 0.016 | 0.132 | 0.008 | 0.262 | 2.9 sps |

Key observations:

  • Without SBD, prediction loss is lower (0.002) because all 32 slots encode the scene in a distributed way, which is an easier but unstructured prediction target.
  • With two-stage training, prediction loss is higher (0.008) but SBD reconstruction reaches 0.008, indicating individual slots have learned to reconstruct specific spatial regions. Diversity loss is lower (0.132 vs 0.154), confirming slots are more differentiated.
  • Training throughput is higher with SBD (2.9 vs 2.3 sps, roughly 26% faster).

Spatial Decomposition

After two-stage training, the SBD alpha masks reveal that different slots claim different spatial regions of the PushT scene. The decomposition is not yet clean — 32 slots is excessive for a 3-object scene and each object is split across ~10 slots — but spatial structure emerges. Prior runs without SBD showed uniform attention across all slots with no spatial structure.

Event Detection

The learned event detector fires 2–3 times per 16-frame sequence, corresponding to moments of significant state change (e.g., contact between the robot end-effector and the T-block). It is trained via a contrastive signal that rewards alignment between event probability and actual state-change magnitude.

Latent Dynamics

PCA projection of slot state trajectories (57% variance explained in best run, 33.5% in two-stage run) shows structured temporal evolution: slots follow smooth paths with direction changes at event boundaries, and different slots track different aspects of scene dynamics.
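The variance-explained figures come from PCA on slot state trajectories. A minimal sketch of that computation via SVD follows, assuming states are flattened to `(num_samples, d)`. The number of components behind the reported 57% and 33.5% is not stated, so `k` is a parameter:

```python
import numpy as np


def pca_explained_variance(states, k=2):
    """Fraction of total variance captured by the top-k principal components.

    states: (num_samples, d), e.g. slot states stacked over timesteps.
    """
    X = states - states.mean(axis=0, keepdims=True)
    s = np.linalg.svd(X, compute_uv=False)   # singular values, descending
    var = s * s
    return float(var[:k].sum() / var.sum())


rng = np.random.default_rng(0)
traj = rng.normal(size=(200, 5)) * np.array([10.0, 1.0, 1.0, 1.0, 1.0])
frac1 = pca_explained_variance(traj, k=1)    # near 100 / 104 for this synthetic data
```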

Triton SSM Kernel Performance (NVIDIA T4)

| Config | B$\times$N | T | Sequential | Triton | Speedup |
|---|---|---|---|---|---|
| Tiny | 128 | 16 | 6.22 ms | 0.16 ms | 39.3× |
| Base | 512 | 16 | 69.64 ms | 1.83 ms | 38.0× |

Where It Wins

  • Claimed as the first architecture to unify object-centric slots, hierarchical temporal dynamics, and causal structure learning in a single differentiable model.
  • 38–39× SSM kernel speedup over PyTorch baseline.
  • Emerging spatial decomposition on real robot data (a hard unsupervised problem).
  • Functional event detection learning without explicit supervision.

Where It Falls Short

  • Slot count: All 32 slots remain alive; no dynamic slot killing. Reducing to 8 slots caused NaN gradients. Each object is split across ~10 slots instead of 1.
  • Causal discovery: Explicit causal adjacency matrix learns no edges (all collapse to zero under sparsity regularization). GNN edge weights provide implicit interaction structure but unverified against ground truth.
  • Scale: Only Small config (68M) successfully trained. Base (262M) and Large (3B) have NaN gradient issues at batch size $\geq$ 4.
  • Seed sensitivity: 40–60% of runs diverge to NaN in the first 1,000 steps.
  • No comparison against standard baselines on downstream RL/planning tasks — only ablation of its own two-stage protocol.

Reproduction Guide

Prerequisites

  • NVIDIA GPU: an H100 80GB was used for training and is recommended; smaller GPUs such as a T4 (16GB) work for smaller configs
  • Python 3.10+
  • PyTorch 2.x with Triton support

Install and Run

```bash
# Clone the repository
git clone https://github.com/rightnow-ai/hclsm.git
cd hclsm

# Install dependencies
pip install -e .

# Download the PushT dataset (from LeRobot / Open X-Embodiment)
# PushT consists of 206 episodes, ~25,650 frames
# The LeRobot library handles downloading:
pip install lerobot
python -c "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset; ..."
```

Training

```bash
# Train the Small (68M) configuration on PushT
python train.py \
  --config config/small.yaml \
  --dataset pusht \
  --data_path /path/to/pusht \
  --batch_size 4 \
  --lr 1.5e-4 \
  --warmup_steps 2000 \
  --total_steps 50000 \
  --stage_ratio 0.4 \
  --precision bf16

# Note: Multiple seeds may be needed due to ~40-60% NaN divergence rate.
# Launch with different seeds until a run survives past step 1000.
```

Verify

After training, check these indicators of success:

  1. SBD loss should decrease below ~0.01 during Stage 1 (first 20K steps), indicating slot specialization.
  2. Slot alpha heatmaps should show non-uniform spatial assignment (visualize with the evaluation suite).
  3. Prediction loss in Stage 2 should stabilize around 0.008–0.01.
  4. Diversity loss should be lower than ~0.15.
  5. Event detector should fire 2–3 times per 16-frame clip.

```bash
# Run evaluation suite
python eval.py \
  --checkpoint checkpoints/small/pusht/best.pt \
  --config config/small.yaml

# Run unit tests (171 tests)
python -m pytest tests/
```

Known Issues

  • If you hit NaN in the first 1,000 steps, try a different random seed.
  • Multi-GPU training via FSDP may fail with NCCL version issues.
  • The Base (262M) and Large (3B) configs are not currently stable at batch size $\geq$ 4.
  • The custom Triton SSM kernel requires a compatible GPU architecture; the code falls back to sequential PyTorch if Triton is unavailable.

Notes

Key insight (structure must precede prediction): The central thesis is that when all losses are active from step zero, the dynamics objective dominates because distributed slot codes are easier to predict than object-specific ones. In that regime, the model uses all 32 slots as a collective distributed representation rather than assigning one slot per object. The two-stage protocol (reconstruction first, dynamics second) resolves this, inspired by the biological observation that object recognition precedes motion prediction in visual development (Schölkopf, 2021).

Connection to JEPA framework: HCLSM adopts a JEPA-style prediction loss for Stage 2, predicting future latent states rather than pixels. This aligns with the broader JEPA agenda (V-JEPA, V-JEPA 2) but adds object-centric structure on top.

Connection to Mamba/SSMs: The selective SSM at Level 0 follows the Mamba architecture closely (input-dependent $\Delta_t, B_t, C_t$) but applies it per-object rather than to a flat sequence. The Slot SSM paper (Jiang, 2024) also applied SSMs to object-centric representations, but HCLSM additionally layers sparse and goal transformers on top.

Causal reasoning is aspirational: Despite the architectural inclusion of NOTEARS-style DAG learning and Gumbel-softmax edge sampling, the explicit causal adjacency matrix learns no edges in practice. The paper is transparent about this — causal discovery is listed as a limitation, not a claimed result. The GNN edge weights provide an implicit interaction signal but are not validated against ground-truth causal structure.

The paper is candid about limitations: This is refreshing — the authors explicitly state that (1) slots don’t properly specialize, (2) causal discovery doesn’t work yet, (3) larger configs are numerically unstable, and (4) 40–60% of runs crash. The paper is positioned as a “foundation” or proof-of-concept rather than a finished system.

Engineering contribution: The custom Triton SSM kernel (38–39× speedup) and GPU-native Sinkhorn-Knopp slot matching are practical engineering contributions that could be reused independently of the full HCLSM architecture.

Missing comparisons: The paper does not benchmark against DreamerV3, SlotFormer, or other world models on any shared evaluation protocol. Results are only shown as an internal ablation (with/without SBD) on PushT. Downstream task performance (e.g., model-based planning or control) is not evaluated despite CEM/MPPI planners being implemented.