2026-03-29
LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero et al.
problem
joint-embedding predictive architectures (JEPAs) are compelling for learning world models in latent space, but existing methods are fragile: to avoid representation collapse they rely on complex multi-term losses, exponential moving averages (EMA), pretrained encoders, or auxiliary supervision. according to the authors, no JEPA had previously been shown to train stably end-to-end from raw pixels with a simple loss.
the core tension: a JEPA must learn a latent space in which the next state is predictable, without pixel-level reconstruction (which spends capacity on irrelevant details). the “collapse” problem is that the encoder can cheat by mapping every input to the same representation, which makes prediction trivially easy but useless.
prior art and their collapse prevention:
- I-JEPA (assran et al.): an EMA target encoder with a stop-gradient, plus a learned predictor. works, but requires careful tuning.
- V-JEPA (bardes et al.): similar approach for video, uses pretrained vision transformer backbone.
- Joint-Embedding Predictive Architecture (chen et al.): multi-term loss with both intra- and inter-view terms.
architecture
```mermaid
flowchart LR
    x_t[x_t pixels] --> E[encoder E]
    x_t1[x_t+1 pixels] --> E2[encoder E]
    E --> z_t[z_t latent]
    E2 --> z_t1[z_t+1 latent]
    z_t --> P[predictor P]
    z_t1 -.->|stop-gradient| L2[l2 prediction loss]
    P --> L2
    z_t --> KL[gaussian prior]
    N["N(0, I)"] --> KL
    KL --> D_KL[D_KL regularizer]
    D_KL --> Loss[total loss L]
    L2 --> Loss
```
encoder $E$: maps raw pixels $x_t$ to latent $z_t$. standard vision encoder (conv layers).
predictor $P$: takes latent $z_t$ and predicts the next latent $z_{t+1}$.
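a toy forward pass makes the data flow in the diagram concrete. this is a pure-Python sketch, not code from the le-wm repository: the single bank of 3x3 convolution filters, the ReLU, the global average pooling and the linear predictor are hypothetical stand-ins for the real layers. what it does show faithfully is that both frames go through the same encoder weights and that the predictor only ever sees latents, never pixels.

```python
import random
from typing import List

Image = List[List[float]]   # single-channel H x W image
Vector = List[float]


def conv3x3(img: Image, kernel: List[List[float]]) -> Image:
    """Valid (no padding) 3x3 cross-correlation of one channel."""
    h, w = len(img), len(img[0])
    out = []
    for i in range(h - 2):
        row = []
        for j in range(w - 2):
            s = 0.0
            for di in range(3):
                for dj in range(3):
                    s += img[i + di][j + dj] * kernel[di][dj]
            row.append(s)
        out.append(row)
    return out


class ToyEncoder:
    """Hypothetical encoder E: latent_dim conv filters -> ReLU -> global average pool."""

    def __init__(self, latent_dim: int, seed: int = 0):
        rng = random.Random(seed)
        self.kernels = [[[rng.gauss(0.0, 0.3) for _ in range(3)] for _ in range(3)]
                        for _ in range(latent_dim)]

    def __call__(self, img: Image) -> Vector:
        z = []
        for k in self.kernels:
            fmap = conv3x3(img, k)
            vals = [max(0.0, v) for row in fmap for v in row]  # ReLU
            z.append(sum(vals) / len(vals))                     # global average pool
        return z


class ToyPredictor:
    """Hypothetical predictor P: a single linear map z_t -> z_hat_{t+1}."""

    def __init__(self, latent_dim: int, seed: int = 1):
        rng = random.Random(seed)
        self.w = [[rng.gauss(0.0, 0.3) for _ in range(latent_dim)]
                  for _ in range(latent_dim)]

    def __call__(self, z: Vector) -> Vector:
        return [sum(wij * zj for wij, zj in zip(row, z)) for row in self.w]


# Usage: the SAME encoder instance embeds both frames (shared weights).
rng = random.Random(42)
x_t = [[rng.random() for _ in range(8)] for _ in range(8)]
x_t1 = [[rng.random() for _ in range(8)] for _ in range(8)]
E, P = ToyEncoder(latent_dim=4), ToyPredictor(latent_dim=4)
z_t, z_t1 = E(x_t), E(x_t1)
z_hat = P(z_t)   # prediction of z_t1; compared in latent space, never in pixel space
```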
loss: two terms in total, $\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda\,\mathcal{L}_{\text{prior}}$:
- prediction loss $\mathcal{L}_{\text{pred}}$: $\ell_2$ distance between $P(z_t)$ and $E(x_{t+1})$, with a stop-gradient on $E(x_{t+1})$ so the target is treated as a constant
- gaussian prior regularizer $\mathcal{L}_{\text{prior}}$: penalizes deviation of the latent distribution from a standard Gaussian $\mathcal{N}(0, I)$. this is the key contribution: it replaces all the complex collapse-prevention machinery.
the only loss hyperparameter is $\lambda$ (the weight of the prior regularizer), compared with $6+$ tunable loss hyperparameters in prior work.
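the two-term loss can be sketched as below. the notes don't give the exact form of the Gaussian prior term, so this sketch assumes one (hypothetical, chosen to match the D_KL label in the diagram): the KL divergence from a diagonal Gaussian fitted to the batch of latents to $\mathcal{N}(0, I)$. the paper's actual regularizer may be computed differently. the prediction term uses the squared $\ell_2$ distance averaged over the batch, a common choice.

```python
import math
from typing import List

Vector = List[float]


def prediction_loss(z_pred: List[Vector], z_next: List[Vector]) -> float:
    """Batch mean of the squared L2 distance between P(z_t) and E(x_{t+1}).

    z_next is the target. In an autograd framework the stop-gradient goes here
    (e.g. z_next.detach() in PyTorch); plain Python has no gradients, so it is
    only noted.
    """
    return sum(sum((p - t) ** 2 for p, t in zip(zp, zn))
               for zp, zn in zip(z_pred, z_next)) / len(z_pred)


def gaussian_prior_kl(z_batch: List[Vector], eps: float = 1e-8) -> float:
    """ASSUMED regularizer: KL( N(mu, diag(var)) || N(0, I) ) for a diagonal
    Gaussian fitted to the batch. Penalizes nonzero per-dimension means and
    per-dimension variances away from 1."""
    n, d = len(z_batch), len(z_batch[0])
    kl = 0.0
    for j in range(d):
        col = [z[j] for z in z_batch]
        mu = sum(col) / n
        var = sum((c - mu) ** 2 for c in col) / n + eps
        kl += 0.5 * (var + mu ** 2 - 1.0 - math.log(var))
    return kl


def total_loss(z_pred: List[Vector], z_next: List[Vector],
               z_batch: List[Vector], lam: float) -> float:
    """L = prediction loss + lambda * prior regularizer (lambda is the single knob)."""
    return prediction_loss(z_pred, z_next) + lam * gaussian_prior_kl(z_batch)


spread = [[1.0, -1.0], [-1.0, 1.0]]   # per-dim mean 0, variance 1
collapsed = [[0.5, 0.5]] * 4          # every input maps to the same latent
print(gaussian_prior_kl(spread))      # close to 0
print(gaussian_prior_kl(collapsed))   # large: roughly 17.7 with eps = 1e-8
```

the `eps` term keeps the logarithm finite; without it a fully collapsed batch would produce an infinite penalty.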
why this works: the Gaussian prior forces the latent space to spread out and use its full capacity. collapse means all inputs map to the same point, so the latent distribution degenerates into a delta with zero variance instead of a unit-variance Gaussian, and the regularizer penalizes it heavily. so, per the paper, the prior regularizer alone is sufficient to prevent collapse.
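if one assumes the regularizer is a KL divergence from a diagonal Gaussian fitted to the batch of latents to the prior (the notes do not confirm this form), the collapse argument becomes a one-line calculation:

$$
D_{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_{j=1}^{d} \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right)
$$

each summand is nonnegative and vanishes only at $\mu_j = 0,\ \sigma_j^2 = 1$, since $s - 1 - \log s \ge 0$ with equality at $s = 1$. if the encoder collapses, $\sigma_j^2 \to 0$ and $-\log \sigma_j^2 \to \infty$, so the penalty diverges even though the prediction loss can reach zero. it diverges just as well if only one dimension collapses, so under this assumption the prior also guards against partial (dimensional) collapse.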
training
- $\sim$15M parameters total
- single GPU (not specified which, but training takes “a few hours”)
- trains end-to-end from raw pixels (no pretrained encoder, no pretrained patches, no auxiliary tasks)
- standard optimizer setup (details in paper)
evaluation
planning benchmarks
tested on diverse 2D and 3D control tasks:
- plans up to $48\times$ faster than foundation-model-based world models (like UniSim, TWM)
- remains competitive on task completion metrics despite the speedup
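the notes don't say how planning is done. a common choice for latent world models is the cross-entropy method (CEM), and the sketch below assumes it: sample action sequences, roll each one out through the predictor entirely in latent space, score by the distance between the final latent and a goal latent, and refit a Gaussian to the best (elite) sequences. the action-conditioned step function `step(z, a)` and the toy dynamics are hypothetical; the predictor $P$ described above takes only $z_t$, but any world model used for control must also take the action as input. the sketch also shows where the speed comes from: each imagined step is one small predictor call on a low-dimensional latent, with no pixels generated.

```python
import random
from typing import Callable, List

Vector = List[float]


def cem_plan(z0: Vector, z_goal: Vector,
             step: Callable[[Vector, float], Vector],
             horizon: int = 5, samples: int = 64, elites: int = 8,
             iters: int = 5, seed: int = 0) -> List[float]:
    """Cross-entropy-method planner over scalar action sequences, in latent space."""
    rng = random.Random(seed)
    mean = [0.0] * horizon
    std = [1.0] * horizon

    def cost(actions: List[float]) -> float:
        z = z0
        for a in actions:
            z = step(z, a)                       # imagined rollout: no pixels involved
        return sum((zi - gi) ** 2 for zi, gi in zip(z, z_goal))

    for _ in range(iters):
        cands = [[rng.gauss(m, s) for m, s in zip(mean, std)] for _ in range(samples)]
        cands.sort(key=cost)                     # lowest latent distance to goal first
        elite = cands[:elites]
        mean = [sum(seq[t] for seq in elite) / elites for t in range(horizon)]
        std = [max(1e-3, (sum((seq[t] - mean[t]) ** 2 for seq in elite) / elites) ** 0.5)
               for t in range(horizon)]
    return mean


# Hypothetical toy dynamics: the action pushes the first latent coordinate.
def toy_step(z: Vector, a: float) -> Vector:
    return [z[0] + 0.5 * a, z[1]]


plan = cem_plan([0.0, 0.0], [2.0, 0.0], toy_step)
```

in a receding-horizon loop one would execute only the first action of `plan`, observe, re-encode, and plan again.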
latent space quality
- probing experiments: linear probes on the latent space can predict physical quantities (position, velocity), confirming the latent space encodes meaningful physical structure rather than pixel statistics
- surprise detection: the model reliably assigns high surprise to physically implausible events (objects teleporting, impossible physics), indicating that it captures physical regularities rather than only frame-to-frame correlations
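both diagnostics are cheap to reproduce once latents are available. the sketch makes two simplifying assumptions: the probe is a 1-D ordinary least-squares fit from a single latent feature to a single physical quantity (a real probe would regress from the full latent vector), and surprise is taken to be the predictor's latent error $\lVert P(z_t) - z_{t+1} \rVert_2^2$, a natural definition for a JEPA that may not match the paper's exact metric. the median-based flagging rule and its factor `k` are hypothetical.

```python
import statistics
from typing import List, Sequence, Tuple


def fit_linear_probe(feature: Sequence[float],
                     target: Sequence[float]) -> Tuple[float, float, float]:
    """Closed-form 1-D least squares: target ~ w * feature + b. Returns (w, b, r2)."""
    n = len(feature)
    mx, my = sum(feature) / n, sum(target) / n
    sxx = sum((x - mx) ** 2 for x in feature)
    sxy = sum((x - mx) * (y - my) for x, y in zip(feature, target))
    w = sxy / sxx
    b = my - w * mx
    ss_res = sum((y - (w * x + b)) ** 2 for x, y in zip(feature, target))
    ss_tot = sum((y - my) ** 2 for y in target)
    return w, b, 1.0 - ss_res / ss_tot


def surprise(z_pred: Sequence[float], z_next: Sequence[float]) -> float:
    """ASSUMED surprise score: squared L2 error between predicted and observed next latent."""
    return sum((p - o) ** 2 for p, o in zip(z_pred, z_next))


def flag_surprising(errors: Sequence[float], k: float = 10.0) -> List[int]:
    """Indices whose error exceeds k times the median error (hypothetical rule)."""
    threshold = k * statistics.median(errors)
    return [i for i, e in enumerate(errors) if e > threshold]


# Toy probe data: the latent feature is linear in position (0.2 * position + 0.1).
position = [0.0, 1.0, 2.0, 3.0, 4.0]
feature = [0.1, 0.3, 0.5, 0.7, 0.9]
w, b, r2 = fit_linear_probe(feature, position)   # w ~ 5.0, b ~ -0.5, r2 ~ 1.0

# A "teleport" at step 10 produces one large latent prediction error.
errors = [0.01] * 10 + [1.0] + [0.01] * 10
print(flag_surprising(errors))                    # [10]
```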
comparison to prior jepa approaches
- simpler loss ($2$ terms vs $6+$)
- fewer hyperparameters ($1$ vs $6+$)
- end-to-end from pixels (vs needing pretrained encoders)
- comparable or better task performance
reproduction guide
- clone https://github.com/lucas-maes/le-wm
- install dependencies (pytorch, standard ML stack)
- single GPU training, a few hours
- the key experiment: train on a simple 2D control task first to verify that the Gaussian prior prevents collapse. if the prediction loss goes to near zero while prediction quality is garbage, the encoder has collapsed and $\lambda$ needs adjustment (a stronger prior, i.e. a larger $\lambda$, is the natural first move).
- compare with/without the prior regularizer to see collapse firsthand
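the collapse check in the key experiment can be automated by logging the spread of the embeddings next to the prediction loss. a minimal sketch; the thresholds `std_floor` and `loss_floor` are hypothetical and should be calibrated on a run known to be healthy. full collapse shows up as a near-zero prediction loss together with near-zero per-dimension standard deviations; a few dead dimensions alongside healthy ones signal partial collapse.

```python
import statistics
from typing import Dict, List


def per_dim_std(z_batch: List[List[float]]) -> List[float]:
    """Population standard deviation of each latent dimension across the batch."""
    d = len(z_batch[0])
    return [statistics.pstdev([z[j] for z in z_batch]) for j in range(d)]


def collapse_report(z_batch: List[List[float]], pred_loss: float,
                    std_floor: float = 0.1, loss_floor: float = 1e-3) -> Dict[str, object]:
    """Flag full or partial collapse from one batch of embeddings (hypothetical thresholds)."""
    stds = per_dim_std(z_batch)
    dead = sum(1 for s in stds if s < std_floor)   # dimensions carrying ~no information
    return {
        "mean_std": sum(stds) / len(stds),
        "dead_dims": dead,
        "collapsed": pred_loss < loss_floor and dead == len(stds),
        "partial_collapse": 0 < dead < len(stds),
    }


# Every input mapped to the same latent: trivially predictable, fully collapsed.
print(collapse_report([[0.3, -0.2]] * 8, pred_loss=1e-5)["collapsed"])   # True
```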
notes
the key insight is that representation collapse can be avoided with just a Gaussian prior regularizer instead of the complex machinery other JEPAs use. this is a strong contribution because simplicity in the loss function directly translates to reproducibility and ease of use.
the $48\times$ speedup over foundation models while staying competitive is significant for robotics where iteration speed matters. a model you can retrain in hours vs days changes the development loop entirely.
open questions for a research sprint:
- does this generalize to manipulation tasks with diverse visual backgrounds? their eval is on relatively clean simulated environments
- how sensitive is $\lambda$? is there a narrow band where it works, or is it robust?
- can you combine this with a diffusion decoder for pixel-level prediction when needed (hybrid approach)?
- what happens at scale? does the simplicity hold with larger encoders and longer prediction horizons?