2026-04-02
Enhancing Policy Learning with World-Action Model
Yuci Han, Alper Yilmaz
problem
Conventional world models (DreamerV2 [Hafner et al., 2022], DreamerV3 [Hafner et al., 2025], PlaNet [Hafner et al., 2019]) are trained solely to predict future observations conditioned on past observations and actions. They never explicitly model the actions that drive state transitions. This creates an asymmetry: the latent state $z_t$ serves as the direct input to downstream policies (e.g., in DiWA [Chandra et al., 2025]), yet $z_t$ is optimized only for pixel reconstruction and KL regularization with no explicit pressure to encode action-relevant structure.
Prior art and limitations:
- DreamerV2 / DreamerV3 (Hafner et al.): RSSM-based world models trained purely via observation prediction. Latent features may discard fine-grained information about how the environment responds to agent behavior.
- DiWA [Chandra et al., 2025, CoRL]: Uses a DreamerV2 world model as a learned simulator for diffusion policy fine-tuning via PPO. Relies on observation-only features from the world model.
- DITTO [DeMoss et al., 2025, ICLR]: Imitation learning by optimizing a distance metric in a frozen world model’s latent space.
- LUMOS [Nematollahi et al., 2025, ICRA]: Extends world-model-based imitation to language-conditioned multi-task learning.
- DayDreamer [Wu et al., 2022, CoRL]: Demonstrates latent-imagination policy training on physical robots.
- WMPO [Zhu et al., 2025]: Pixel-based world models for on-policy GRPO optimization of VLA models.
- World4RL [Jiang et al., 2025]: Diffusion-based world models for end-to-end policy refinement.
- WorldVLA [Cen et al., 2025]: Unified action-world model generating actions and images jointly within a single autoregressive architecture. Requires large foundation models and fundamentally redesigns the architecture.
Key gap: All prior world-model-based policy learning methods use observation-only world models. None explicitly regularize latent representations toward action-relevant structure during world model training. WorldVLA addresses joint action-observation generation but does so by replacing the architecture entirely rather than improving existing world model representations.
architecture
WAM is a lightweight extension of DreamerV2 that augments the training objective with an inverse dynamics head, producing representations simultaneously predictive of visual dynamics and aware of the actions driving state transitions.
World Model Backbone (DreamerV2 RSSM)
A dual-stream CNN encoder processes static and gripper camera images ($64 \times 64$ RGB), fusing them with proprioceptive state to produce encoder embeddings $e_t \in \mathbb{R}^{1554}$. The RSSM models latent dynamics through four equations:
\[h_t = f_\phi(h_{t-1}, z_{t-1}, a_{t-1})\]
\[z_t \sim q_\phi(z_t \mid h_t, e_t)\]
\[\hat{z}_t \sim p_\phi(z_t \mid h_t)\]
\[\hat{o}_t \sim p_\phi(o_t \mid f_t)\]
where:
- $h_t$ is the deterministic recurrent state (GRU-based)
- $z_t$ is a stochastic categorical variable of size $32 \times 32$ (discrete, as in DreamerV2)
- $f_t = [h_t; z_t] \in \mathbb{R}^{2048}$ is the combined latent feature used for decoding and policy learning
- The decoder reconstructs observations $\hat{o}_t$ from $f_t$
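The data flow of these four equations can be sketched as follows; `gru`, `post`, and `prior` stand in for the learned maps $f_\phi$, $q_\phi$, $p_\phi$ (passed as callables so the wiring is explicit), and stochastic sampling is elided for brevity:

```python
import numpy as np

def rssm_step(h_prev, z_prev, a_prev, e_t, gru, post, prior):
    """One RSSM step mirroring the four equations above.

    gru, post, prior are stand-ins for the learned maps f_phi, q_phi,
    p_phi; treating their outputs as samples is a simplification.
    """
    h_t = gru(h_prev, np.concatenate([z_prev, a_prev]))  # deterministic path
    z_t = post(h_t, e_t)       # posterior (conditions on encoder embedding e_t)
    z_hat = prior(h_t)         # prior (imagined rollouts use this)
    f_t = np.concatenate([h_t, z_t])  # combined feature for decoder and policy
    return h_t, z_t, z_hat, f_t
```

Note that only the posterior sees $e_t$; imagination replaces $z_t$ with the prior sample $\hat{z}_t$.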
Inverse Dynamics Head (WAM’s Key Addition)
A three-layer MLP $\psi$ predicts actions from consecutive encoder embeddings:
\[\hat{a}_t = \psi([e_t; e_{t+1}])\]
where $[\cdot; \cdot]$ denotes concatenation. The input is $2 \times 1554 = 3108$-dimensional.
Critical design choice: The inverse dynamics head operates on encoder embeddings $e_t$ rather than RSSM features $f_t$. This is because $f_t$ receives $a_{t-1}$ through the GRU (Eq. 1), which would make action prediction trivially solvable by simply reading $a_{t-1}$ from the recurrent state.
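A minimal sketch of the head, assuming a ReLU MLP with hidden width 512 and a 7-DoF action space (neither detail is stated in the paper):

```python
import numpy as np

def _layer(x, w, b, relu):
    h = x @ w + b
    return np.maximum(h, 0.0) if relu else h

class InverseDynamicsHead:
    """Three-layer MLP psi predicting a_t from [e_t; e_{t+1}].

    Input dim follows the text (2 x 1554 = 3108). Hidden width (512),
    ReLU activations, and the 7-D action space are assumptions. The head
    consumes encoder embeddings e_t, not RSSM features f_t, so a_{t-1}
    cannot leak in through the GRU state.
    """
    def __init__(self, emb_dim=1554, hidden=512, action_dim=7, seed=0):
        rng = np.random.default_rng(seed)
        dims = [2 * emb_dim, hidden, hidden, action_dim]
        self.params = [(rng.normal(0.0, 0.02, size=(d_in, d_out)),
                        np.zeros(d_out))
                       for d_in, d_out in zip(dims[:-1], dims[1:])]

    def __call__(self, e_t, e_next):
        x = np.concatenate([e_t, e_next], axis=-1)  # [..., 3108]
        for i, (w, b) in enumerate(self.params):
            x = _layer(x, w, b, relu=(i < len(self.params) - 1))
        return x  # predicted action a_hat_t
```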
Cascading Effect
The action-aware structure cascades through the full model in a chain:
- Encoder: The inverse dynamics head regularizes encoder embeddings $e_t$ to retain action-relevant information
- Posterior: $z_t \sim q_\phi(z_t \mid h_t, e_t)$ is shaped by action-aware $e_t$
- Prior: The KL loss $L_{KL} = \text{KL}(q_\phi(z_t \mid h_t, e_t) \parallel p_\phi(z_t \mid h_t))$ propagates action-aware structure from posterior to prior $\hat{z}_t$
- Imagined rollouts: During policy fine-tuning, the prior $\hat{z}_t$ generates latent states for imagined trajectories, ensuring the diffusion policy benefits from action-relevant representations
Training Objective
WAM is trained end-to-end by minimizing:
\[\mathcal{L}_{WAM} = \lambda_{KL} \mathcal{L}_{KL} + \lambda_{img} \mathcal{L}_{recon} + \lambda_{act} \mathcal{L}_{action}\]
where:
- $\mathcal{L}_{KL} = \text{KL}(q_\phi(z_t \mid h_t, e_t) \parallel p_\phi(z_t \mid h_t))$ — posterior-prior KL divergence
- $\mathcal{L}_{recon} = \|o_t - \hat{o}_t\|_2^2$ — observation reconstruction loss (L2)
- $\mathcal{L}_{action} = \|\hat{a}_t - a_t\|_1$ — action prediction loss (L1 regression)
Loss weights: $\lambda_{KL} = 3.0$, $\lambda_{img} = 1.0$, $\lambda_{act} = 1000.0$.
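The combined objective with these weights can be sketched as follows, assuming `l_kl` and `l_recon` arrive as precomputed scalars:

```python
import numpy as np

def wam_loss(l_kl, l_recon, a_hat, a_true,
             lam_kl=3.0, lam_img=1.0, lam_act=1000.0):
    """Weighted WAM objective with the paper's loss weights.

    The action term is the L1 regression loss from the text, summed over
    action dimensions and averaged over the batch (the reduction is an
    assumption).
    """
    l_action = np.mean(np.sum(np.abs(a_hat - a_true), axis=-1))
    return lam_kl * l_kl + lam_img * l_recon + lam_act * l_action
```

With $\lambda_{act} = 1000$, even small per-dimension action errors dominate the total, which is consistent with the paper's emphasis on the action signal.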
Downstream Policy: DiffusionMLP
A diffusion-based policy $\pi_\theta(a_t \mid f_t)$ generates actions by iteratively denoising Gaussian noise through $K$ steps:
\[a_t^{k-1} = \mu_\theta(f_t, a_t^k, k) + \sigma_k \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
where $a_t^K \sim \mathcal{N}(0, I)$ and the policy minimizes:
\[\mathcal{L}_{BC} = \mathbb{E}_{k, \epsilon, (f_t, a_t)} \|\mu_\theta(f_t, a_t^k, k) - a_t^{k-1}\|_2^2\]
After BC pretraining, the policy is refined via DPPO [Ren et al., 2024] entirely within the frozen world model’s latent space:
\[\theta^* = \arg\max_\theta \mathbb{E}_{\tau \sim \pi_\theta, P_\phi} \left[ \sum_{t=0}^{T} \gamma^t R_\psi(z_t, a_t) \right]\]
where $R_\psi$ is a binary reward classifier trained on world model features.
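The reverse-diffusion sampling loop implied by the update rule above can be sketched as follows; the linearly decaying sigma schedule is an illustrative assumption, not the paper's schedule:

```python
import numpy as np

def sample_action(mu, f_t, K=20, action_dim=7, sigmas=None, rng=None):
    """Sample an action by reverse diffusion: start from a_t^K ~ N(0, I)
    and apply a_t^{k-1} = mu(f_t, a_t^k, k) + sigma_k * eps for k = K..1.

    mu is the learned denoiser; sigmas[k-1] is the noise scale for step k.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    if sigmas is None:
        sigmas = np.linspace(0.5, 0.0, K)  # illustrative schedule
    a = rng.standard_normal(action_dim)    # a_t^K
    for k in range(K, 0, -1):
        eps = rng.standard_normal(action_dim)
        a = mu(f_t, a, k) + sigmas[k - 1] * eps
    return a                                # a_t^0
```

During PPO rollouts the same loop would simply run with `K=10` instead of 20.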
Parameter Count
No explicit total parameter count is provided. The inverse dynamics head is a lightweight three-layer MLP operating on 3108-dimensional input, adding negligible parameters relative to the DreamerV2 backbone (encoder, RSSM, decoder).
training
World Model Training
- Dataset: CALVIN benchmark environment D — 6 hours of teleoperated play data (~500K transitions, ~512K play frames) from a 7-DoF Franka Emika Panda robot
- Observations: $64 \times 64$ RGB images from static and gripper cameras + proprioceptive state
- Optimizer: AdamW with weight decay 0.05
- Learning rate: $3 \times 10^{-4}$
- Batch size: 500
- Sequence length: $T = 50$
- KL balance coefficient: $\alpha = 0.8$
- Loss weights: $\lambda_{KL} = 3.0$, $\lambda_{img} = 1.0$, $\lambda_{act} = 1000.0$
- Gradient steps: 230K (roughly 8.7x fewer than the 2M steps used by the DreamerV2/DreamingV2 baseline)
- Hardware: Not explicitly specified in the paper
Behavioral Cloning Training
- Expert data: 50 episodes per task (8 tasks)
- Policy: DiffusionMLP with $K = 20$ denoising steps, action horizon $T_a = 4$
- Epochs: 5,000
- Batch size: 256
- Learning rate: $10^{-4}$ with cosine decay to $10^{-5}$
- Weight decay: $10^{-6}$
- EMA decay: 0.995
- Features: $f_t \in \mathbb{R}^{2048}$ extracted from frozen WAM encoder + RSSM
PPO Fine-tuning
- Optimizer: Clipped PPO (DPPO)
- Rollouts per iteration: 50 parallel imagined trajectories within the frozen world model
- PPO batch size: 7,500
- Update epochs per iteration: 10
- Actor learning rate: $10^{-5}$
- Critic learning rate: $10^{-3}$
- Discount factor: $\gamma = 0.999$
- GAE lambda: $\lambda = 0.95$
- Denoising steps during rollouts: 10
- BC regularization: $\alpha_{BC} = 0.025$ (prevents catastrophic forgetting during PPO)
- Total iterations: 800 (evaluation every 25 iterations)
- Reward classifiers: Binary contrastive classifiers retrained on WAM features; all 8 classifiers achieve $\geq 0.97$ precision and 1.00 recall
Special Tricks
- Encoder-level inverse dynamics: Operating on $e_t$ instead of $f_t$ avoids trivial action prediction via the GRU’s memory of past actions
- Cascading regularization: KL divergence propagates action-aware structure from posterior through to the prior, benefiting imagined rollouts
- BC regularization during PPO: $\alpha_{BC} = 0.025$ anchors the policy to pretrained behavior during online fine-tuning
- Reduced denoising during rollouts: 10 steps (vs 20 during BC) for computational efficiency during imagined rollouts
- Reward classifier retraining: Classifiers are retrained on WAM features (not reused from DreamerV2) since the latent space differs
evaluation
Benchmark: CALVIN (Environment D)
8 tabletop manipulation tasks with a 7-DoF Franka Emika Panda robot. Evaluation uses 29 held-out initial configurations per task. An episode succeeds if CALVIN’s built-in task checker confirms completion (up to 18 decision points per episode, 72 max steps).
Generation Quality (Table I)
50-step open-loop imagination rollouts conditioned on the first ground-truth observation, with ground-truth actions, evaluated over 100 random sequences from the CALVIN validation set.
| Metric | WAM (Ours) | DreamerV2 Baseline |
|---|---|---|
| PSNR $\uparrow$ | 22.10 $\pm$ 2.22 | 21.66 $\pm$ 2.20 |
| SSIM $\uparrow$ | 0.814 $\pm$ 0.061 | 0.807 $\pm$ 0.067 |
| LPIPS $\downarrow$ | 0.144 $\pm$ 0.072 | 0.149 $\pm$ 0.073 |
| FVD $\downarrow$ | 10.82 | 12.13 |
WAM outperforms DreamerV2 across all four metrics, with fewer training steps (230K vs 2M). Qualitative results show WAM preserves object shapes, fine details, and colors better, while DreamerV2 exhibits color drift and distorted object shapes.
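For reference, PSNR for images in $[0, 1]$ reduces to a few lines (SSIM, LPIPS, and FVD require their standard reference implementations):

```python
import numpy as np

def psnr(x, y, max_val=1.0):
    """Peak signal-to-noise ratio between two images in [0, max_val]."""
    mse = np.mean((np.asarray(x, dtype=np.float64)
                   - np.asarray(y, dtype=np.float64)) ** 2)
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```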
Behavioral Cloning Results (Table III)
Both methods use identical DiffusionMLP policy architecture and training procedure; the only difference is which world model provides features.
| Task | DiWA | WAM (Ours) | $\Delta$ |
|---|---|---|---|
| close drawer | 58.6 $\pm$ 4.2 | 89.7 $\pm$ 3.1 | +31.1 |
| open drawer | 53.3 $\pm$ 5.1 | 73.3 $\pm$ 4.8 | +20.0 |
| move slider left | 50.0 $\pm$ 3.7 | 68.8 $\pm$ 5.2 | +18.8 |
| move slider right | 51.7 $\pm$ 4.5 | 82.8 $\pm$ 3.9 | +31.1 |
| turn on lightbulb | 42.4 $\pm$ 3.3 | 51.5 $\pm$ 4.6 | +9.1 |
| turn off lightbulb | 3.4 $\pm$ 1.8 | 17.2 $\pm$ 3.4 | +13.8 |
| turn on led | 44.8 $\pm$ 3.9 | 41.4 $\pm$ 4.1 | -3.4 |
| turn off led | 62.5 $\pm$ 5.3 | 68.8 $\pm$ 4.7 | +6.3 |
| Average | 45.8 | 61.7 | +15.9 |
WAM wins on 7 of 8 tasks. The largest gains appear on articulated object tasks (close drawer +31.1pp, move slider right +31.1pp, open drawer +20.0pp) where precise position control is critical. The only task where DiWA marginally wins is turn on led (44.8% vs 41.4%), attributed to evaluation variance with small test episode count.
PPO Fine-tuning Results (Table IV)
After 800 iterations of model-based PPO fine-tuning inside the frozen world model:
| Task | DiWA | WAM (Ours) | $\Delta$ |
|---|---|---|---|
| open drawer | 70.0 | 96.7 $\pm$ 2.4 | +26.7 |
| close drawer | — | 96.6 $\pm$ 1.8 | — |
| move slider left | — | 87.5 $\pm$ 3.7 | — |
| move slider right | — | 89.7 $\pm$ 3.2 | — |
| turn on lightbulb | — | 100.0 $\pm$ 0.0 | — |
| turn off lightbulb | — | 75.9 $\pm$ 4.3 | — |
| turn on led | 86.2 | 96.6 $\pm$ 2.1 | +10.4 |
| turn off led | — | 100.0 $\pm$ 0.0 | — |
| Average | 79.8 | 92.8 | +13.0 |
WAM outperforms DiWA on both tasks for which DiWA scores are reported, as well as on the overall average (+13.0pp). Two tasks reach 100% success (turn on lightbulb, turn off led). The largest reported gains are on open drawer (+26.7pp) and turn on led (+10.4pp). WAM achieves this with 8.7x fewer world model training steps (230K vs 2M).
Sample Efficiency of PPO Fine-tuning
WAM requires substantially fewer environment-equivalent steps to match DiWA’s final performance:
| Metric | WAM (Ours) | DiWA |
|---|---|---|
| Total physical interactions | ~2.5M | ~8M |
| World model training steps | 230K | 2M |
WAM is roughly 3.2x more sample-efficient in PPO fine-tuning and 8.7x more efficient in world model training.
Where WAM Wins
- Generation quality: all four metrics (PSNR, SSIM, LPIPS, FVD)
- Behavioral cloning: 7/8 tasks, +15.9pp average improvement
- PPO fine-tuning: all 8 tasks, +13.0pp average improvement
- Sample efficiency: 8.7x fewer world model steps, 3.2x fewer PPO environment steps
- Articulated object manipulation: largest gains on drawer and slider tasks
Where WAM Underperforms
- turn on led BC: marginally loses to DiWA (41.4% vs 44.8%), likely due to evaluation variance (only 29 test episodes)
reproduction guide
The paper does not provide an official code repository. The following steps outline how to reproduce based on the described pipeline:
Step 1: Environment Setup
```bash
# Install dependencies
pip install torch torchvision
pip install tensorflow tensorflow-probability  # DreamerV2 dependencies
pip install diffusers  # For diffusion policy
pip install calvin-env  # CALVIN benchmark
pip install gym
```
Step 2: Prepare CALVIN Dataset
```bash
# Download CALVIN Environment D dataset
# ~6 hours of teleoperated play data (~500K transitions)
# 64x64 RGB images from static and gripper cameras
# Follow: https://github.com/mees/calvin
```
Step 3: Train World Model (WAM)
Modify DreamerV2 to add inverse dynamics head:
- Add a 3-layer MLP $\psi$ that takes concatenated encoder embeddings $[e_t; e_{t+1}]$ (dim 3108) and predicts action $\hat{a}_t$
- Add action prediction loss $\mathcal{L}_{action} = |\hat{a}_t - a_t|_1$ with weight $\lambda_{act} = 1000.0$
- Train with AdamW (lr $3 \times 10^{-4}$, weight decay 0.05), batch size 500, sequence length 50
- KL balance $\alpha = 0.8$, $\lambda_{KL} = 3.0$, $\lambda_{img} = 1.0$
- Train for 230K gradient steps on ~512K play frames
Step 4: Train Diffusion Policy via BC
- Freeze WAM encoder and RSSM
- Extract $f_t \in \mathbb{R}^{2048}$ features from 50 expert demonstrations per task
- Train DiffusionMLP with $K=20$ denoising steps, action horizon $T_a = 4$
- Train for 5,000 epochs, batch size 256, lr $10^{-4}$ (cosine decay to $10^{-5}$)
- Weight decay $10^{-6}$, EMA decay 0.995
Step 5: Retrain Reward Classifiers
- Featurize the entire CALVIN play dataset with the WAM encoder
- Extract matched features and raw states from 50 expert episodes per task (seed 42)
- Generate ground-truth reward labels by replaying expert actions in CALVIN simulator
- Augment training data with imagined rollouts inside WAM
- Balance class distribution
- Train contrastive reward classifiers (target: $\geq 0.97$ precision, 1.00 recall)
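Checking the precision/recall target reduces to standard binary-classification counting, e.g.:

```python
def precision_recall(pred, label):
    """Precision and recall for binary predictions (True = task success)."""
    tp = sum(p and l for p, l in zip(pred, label))
    fp = sum(p and not l for p, l in zip(pred, label))
    fn = sum((not p) and l for p, l in zip(pred, label))
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall
```

A classifier passing the paper's bar would need `precision >= 0.97` and `recall == 1.0` on held-out labels.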
Step 6: PPO Fine-tuning
- Generate 50 parallel imagined rollouts per iteration in frozen WAM
- Use 10 denoising steps during rollouts with BC regularization $\alpha_{BC} = 0.025$
- PPO batch size 7,500, 10 update epochs per iteration
- Actor lr $10^{-5}$, critic lr $10^{-3}$, $\gamma = 0.999$, GAE $\lambda = 0.95$
- Train for 800 iterations, evaluate every 25 iterations
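A sketch of the clipped surrogate with the BC anchor; folding the BC term additively into the actor loss is an assumption about how DPPO applies $\alpha_{BC}$:

```python
import numpy as np

def actor_loss(ratio, adv, bc_err, alpha_bc=0.025, clip=0.2):
    """Clipped PPO surrogate plus a BC anchor weighted by alpha_BC.

    ratio: pi_theta(a|s) / pi_old(a|s); adv: advantage estimates (GAE);
    bc_err: per-sample deviation from the BC-pretrained policy's output.
    The clip range of 0.2 is a common default, not stated in the paper.
    """
    surrogate = np.minimum(ratio * adv,
                           np.clip(ratio, 1.0 - clip, 1.0 + clip) * adv)
    return -np.mean(surrogate) + alpha_bc * np.mean(bc_err)
```

Minimizing this loss maximizes the clipped surrogate while penalizing drift from the pretrained behavior, which is the stated role of $\alpha_{BC}$.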
Step 7: Verify Results
Expected benchmarks on 29 held-out episodes per task:
- BC average: ~61.7% (range 17.2% to 89.7%)
- PPO average: ~92.8% (range 75.9% to 100.0%)
notes
Key Takeaways
- Simple but effective: WAM’s core contribution is remarkably simple, a 3-layer MLP inverse dynamics head added to DreamerV2. Yet it yields substantial gains (+15.9pp BC, +13.0pp PPO) without any policy architecture changes.
- Representation quality matters for policy learning: The cascading effect from encoder to prior to imagined rollouts is the central insight. Action-aware structure in the encoder propagates through the entire world model pipeline, improving both generation quality and downstream control.
- Training efficiency: WAM achieves state-of-the-art results with 8.7x fewer world model training steps (230K vs 2M), suggesting the action prediction objective provides a stronger learning signal.
- Complementary to architectural innovations: Unlike WorldVLA, which replaces the architecture with a large foundation model, WAM improves existing world model representations through a simple training objective change. These approaches are complementary.
Connections to Other Work
- Inverse dynamics in self-supervised learning: Draws on Pathak et al. [2017, ICML] (curiosity-driven exploration via inverse dynamics) and Baker et al. [2022, NeurIPS] (VPT, learning to act by watching videos). The key principle: predicting actions between consecutive states yields representations focused on controllable aspects of the environment while filtering out task-irrelevant distractors.
- DiWA [Chandra et al., 2025]: WAM’s most direct baseline. DiWA uses DreamerV2 + DiffusionMLP + PPO fine-tuning. WAM replaces only the world model (DreamerV2 → WAM) while keeping the identical policy and training pipeline, making the comparison clean.
- DITTO [DeMoss et al., 2025]: Optimizes a distance metric in a frozen world model’s latent space for offline IL. WAM’s improved representations could directly benefit DITTO.
- DreamingV2 [Okada & Taniguchi, 2022]: The generation quality baseline (DreamerV2 trained for 2M steps without reconstruction). WAM outperforms it on all metrics with 8.7x fewer steps.
- Action-conditional vs action-predictive: The distinction between treating actions as conditioning inputs (standard world models) vs. jointly predicting them (WAM) parallels the broader trend in representation learning from label-conditional to label-predictive objectives.
Limitations and Open Questions
- Single benchmark: Evaluation is limited to CALVIN (8 tabletop tasks). Generalization to other manipulation benchmarks (LIBERO, RLBench), locomotion, or real robots is untested.
- Hardware not specified: No GPU type or wall-clock training time is reported, making exact computational cost comparisons difficult.
- No code release: The paper lacks an official implementation.
- Abstract vs. body inconsistency: The abstract reports BC averages of 59.4% → 71.2%, while the body/conclusion reports 45.8% → 61.7%. The table data (45.8%, 61.7%) matches the conclusion text.
- Action loss weight sensitivity: $\lambda_{act} = 1000.0$ is very large relative to other weights. The paper states these were “carefully tuned” but does not provide an ablation over this value.
- No ablation study on head design: The paper does not compare the 3-layer MLP against simpler (linear) or more complex (deeper, attention-based) inverse dynamics heads.