Trajectory Optimization for Model-Based Planning

An online planning algorithm that selects actions by optimizing trajectories through a learned world model, with the goal of improving episode reward.

Robotics · tdmpc2 · Claude Opus 4.6 beats every baseline
tdmpc2-planning

Description

Planning Algorithm for Model-Based RL

Objective

Design and implement a custom trajectory optimization algorithm for online planning in model-based reinforcement learning. Your code goes in the custom_plan() function in custom_planner.py. This function is called at every environment step to select actions using the learned world model.

Background

TD-MPC2 (Hansen, Su, Wang, ICLR 2024, arXiv:2310.16828) — the scalable successor of TD-MPC (Hansen, Wang, Su, ICML 2022, arXiv:2203.04955) — uses Model Predictive Path Integral (MPPI) (Williams et al., arXiv:1509.01149) for planning. At each step, the agent:

  1. Samples num_pi_trajs = 24 trajectories from the learned policy as warm-starts.
  2. Iterates iterations = 6 rounds of:
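    • Sample the remaining num_samples - num_pi_trajs = 488 action sequences from N(mean, std), clamped to [-1, 1], so that together with the policy trajectories there are num_samples = 512 candidates.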
    • Roll out each trajectory through the latent dynamics model for horizon = 3 steps.
    • Estimate trajectory value using predicted rewards + terminal Q-value.
    • Select num_elites = 64 best trajectories.
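    • Update mean and std to the weighted mean and standard deviation of the elites, using softmax weights exp(temperature · (value - max value)) normalized over the elites (temperature = 0.5).
  3. Selects one elite trajectory by Gumbel-softmax sampling over the elite weights and executes its first action (adding std-scaled Gaussian noise outside of evaluation).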
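The loop above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not TD-MPC2's implementation: `ToyModel` is a hypothetical linear latent model standing in for `agent.model`, the 24 policy warm-start trajectories are omitted, and the exploration noise added to the final action during training is left out. The clamped Gaussian sampling, softmax-weighted mean/std update, and Gumbel-max selection follow the description of steps 2 and 3.

```python
import numpy as np


class ToyModel:
    """Hypothetical stand-in for agent.model: linear latent dynamics, quadratic reward."""

    def __init__(self, latent_dim=4, action_dim=2, seed=0):
        rng = np.random.default_rng(seed)
        self.A = np.eye(latent_dim) + 0.05 * rng.standard_normal((latent_dim, latent_dim))
        self.B = rng.standard_normal((latent_dim, action_dim))

    def next(self, z, a):  # z: (N, latent_dim), a: (N, action_dim)
        return z @ self.A.T + a @ self.B.T

    def reward(self, z, a):  # penalize distance from the origin and control effort
        return -(z ** 2).sum(-1) - 0.1 * (a ** 2).sum(-1)

    def terminal_value(self, z):  # plays the role of the terminal Q-value
        return -(z ** 2).sum(-1)


def estimate_value(model, z, actions, discount=0.99):
    """Discounted predicted rewards plus discounted terminal value, one per sample."""
    G, gamma = np.zeros(z.shape[0]), 1.0
    for a in actions:  # actions: (horizon, N, action_dim)
        G = G + gamma * model.reward(z, a)
        z = model.next(z, a)
        gamma *= discount
    return G + gamma * model.terminal_value(z)


def mppi_plan(model, z0, action_dim, horizon=3, iterations=6, num_samples=512,
              num_elites=64, temperature=0.5, min_std=0.05, max_std=2.0,
              prev_mean=None, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    mean = np.zeros((horizon, action_dim))
    if prev_mean is not None:  # warm start: previous plan shifted one step forward
        mean[:-1] = prev_mean[1:]
    std = np.full((horizon, action_dim), max_std)
    z = np.repeat(z0[None], num_samples, axis=0)
    for _ in range(iterations):
        noise = rng.standard_normal((horizon, num_samples, action_dim))
        actions = np.clip(mean[:, None] + std[:, None] * noise, -1.0, 1.0)
        value = estimate_value(model, z, actions)
        elite_idx = np.argsort(value)[-num_elites:]  # top-k sequences by value
        elite_value, elite_actions = value[elite_idx], actions[:, elite_idx]
        # Softmax weights; `temperature` scales the values before exponentiating.
        score = np.exp(temperature * (elite_value - elite_value.max()))
        score /= score.sum()
        w = score[None, :, None]
        mean = (w * elite_actions).sum(1)
        std = np.sqrt((w * (elite_actions - mean[:, None]) ** 2).sum(1))
        std = np.clip(std, min_std, max_std)
    # Gumbel-max trick: argmax(log p + Gumbel noise) draws index i with probability p_i.
    idx = np.argmax(np.log(score + 1e-20) + rng.gumbel(size=score.shape))
    action = np.clip(elite_actions[0, idx], -1.0, 1.0)  # first action of the chosen elite
    return action, mean
```

Feeding the returned `mean` back in as `prev_mean` on the next call mimics the warm start that `agent._prev_mean` provides across environment steps.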

Alternative planning approaches could improve sample efficiency, convergence speed, or final performance:

  • Cross-Entropy Method (CEM): refits the mean and std to the elites with uniform weights instead of softmax weights.
  • iCEM (Pinneri et al., CoRL 2020, arXiv:2008.06389): improved CEM with temporally correlated (colored) noise and keep-elites.
  • Gradient-based planning: backpropagating through the world model.
  • Hybrid approaches: combining sampling with gradient refinement.
  • Adaptive methods: adjusting sampling parameters during optimization.
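iCEM's colored noise can be illustrated by shaping white Gaussian noise in the frequency domain so that its power spectrum along the time axis falls off as 1/f^β. The NumPy sketch below is our own FFT-based construction in that spirit, not the iCEM reference code; the helper name `powerlaw_noise` and the rescaling to unit overall variance are assumptions. With the default horizon = 3 there are only two frequency bins, so the smoothing effect is much weaker than over long horizons.

```python
import numpy as np


def powerlaw_noise(beta, horizon, num_samples, action_dim, rng):
    """Gaussian noise with power spectral density ~ 1/f**beta along the time axis.

    beta = 0 gives white noise; larger beta gives smoother, more temporally
    correlated action sequences. Output shape: (horizon, num_samples, action_dim).
    """
    white = rng.standard_normal((num_samples, action_dim, horizon))
    freqs = np.fft.rfftfreq(horizon)  # 0, 1/H, 2/H, ..., up to 1/2
    freqs[0] = freqs[1] if horizon > 1 else 1.0  # avoid 0 ** negative at the DC bin
    # Amplitude ~ f^(-beta/2), so power ~ f^(-beta).
    spectrum = np.fft.rfft(white, axis=-1) * freqs ** (-beta / 2.0)
    colored = np.fft.irfft(spectrum, n=horizon, axis=-1)
    colored /= colored.std()  # rescale to unit variance overall
    return np.moveaxis(colored, -1, 0)  # time axis first, like the planner's actions
```

The result can replace the standard-normal draw in an MPPI/CEM loop (actions = mean + std * noise); iCEM additionally carries a fraction of the previous iteration's elites into the next sample set.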

What You Can Modify

The custom_plan() function in custom_planner.py. You have access to:

  • agent.model: WorldModel with encode, next, pi, Q, reward methods
  • agent._estimate_value(z, actions, task): evaluates trajectory returns
  • agent._prev_mean: warm-start buffer from previous planning step
  • agent.cfg: all configuration parameters (horizon, num_samples, etc.)
  • common.math: utility functions (gumbel_softmax_sample, two_hot_inv, etc.)
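For orientation, here is our reading of what agent._estimate_value(z, actions, task) returns in the reference TD-MPC2 code: the discounted sum of predicted rewards along the latent rollout, plus a discounted terminal Q-value at an action proposed by the policy. Here $d$ stands for model.next, $\hat r$ for model.reward decoded to a scalar (for example with two_hot_inv), $\hat Q$ for model.Q, $\pi$ for model.pi, $\gamma$ for the discount and $H$ for the horizon. Details such as how the Q-ensemble is aggregated or how termination is handled may differ between versions, so treat this as an approximation rather than the exact definition.

```latex
G\bigl(z_0, a_{0:H-1}\bigr) = \sum_{t=0}^{H-1} \gamma^{t}\,\hat{r}(z_t, a_t)
  \;+\; \gamma^{H}\,\hat{Q}\bigl(z_H, \pi(z_H)\bigr),
\qquad z_{t+1} = d(z_t, a_t).
```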

Evaluation

  • Metric: episode reward (higher is better)
  • Environments: DMControl walker-walk and cheetah-run
  • Model: TD-MPC2 with 1M parameters, 200K training steps
  • Note: the planning algorithm affects both data collection quality during training and action selection during evaluation.

Key Constraints

  • The function must return a single action tensor of shape (action_dim,) clamped to [-1, 1].
  • The function runs under @torch.no_grad() — no gradient computation.
  • Must update agent._prev_mean for temporal consistency across steps.
  • Planning budget: keep total computation comparable to the default (6 iterations × 512 samples).
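The bookkeeping that these constraints ask for can be sketched as below, with NumPy arrays standing in for torch tensors. It mirrors our understanding of the default planner (start from the previous mean shifted by one step unless this is the first step of an episode, then store the final mean); the helper names `shift_warm_start` and `finalize`, and the `state` dict standing in for `agent._prev_mean`, are hypothetical.

```python
import numpy as np


def shift_warm_start(prev_mean, t0):
    """Initial mean for this step's optimization (hypothetical helper).

    On the first step of an episode (t0=True) start from zeros; otherwise reuse
    the previous plan shifted forward one step and pad the last step with zeros.
    """
    mean = np.zeros_like(prev_mean)
    if not t0:
        mean[:-1] = prev_mean[1:]
    return mean


def finalize(final_mean, chosen_action, state):
    """Store the optimized mean for the next step and return a valid action."""
    state["prev_mean"] = final_mean.copy()  # stands in for updating agent._prev_mean
    return np.clip(chosen_action, -1.0, 1.0)  # shape (action_dim,), inside [-1, 1]
```

For the budget constraint, the default evaluates 6 × 512 = 3,072 candidate sequences per environment step, each rolled out for 3 latent steps.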

Code

custom_planner.py
"""Custom planning algorithm for TD-MPC2.

Replace the planning logic in custom_plan() with your trajectory
optimization method. The function is called at each environment step
to select actions using the learned world model.
"""

import torch
from common import math


# =====================================================================
# EDITABLE: Custom planning algorithm
# =====================================================================
@torch.no_grad()
tdmpc2.py
import torch
import torch.nn.functional as F

from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
from common.layers import api_model_conversion
from tensordict import TensorDict


class TDMPC2(torch.nn.Module):
    """
    TD-MPC2 agent. Implements training + inference.
    Can be used for both single-task and multi-task experiments,
    and supports both state and pixel observations.
world_model.py
from copy import deepcopy

import torch
import torch.nn as nn

from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams


class WorldModel(nn.Module):
    """
    TD-MPC2 implicit world model architecture.
    Can be used for both single-task and multi-task experiments.
    """
math.py
import torch
import torch.nn.functional as F
from tensordict import TensorDict


def soft_ce(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)


def log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)

Method Summary

Auto-summarized from each method's code by an LLM reviewer, not the model's original output. The Code section is independent of this summary.
Beats every baseline on every metric
Claude Opus 4.6 · Pseudocode · high

Hybrid MPPI/iCEM with rank blend

iCEM-style elite retention + colored noise + adaptive temperature, with the MPPI softmax score blended with rank weights and a momentum mean update.

1. Initialize $\mu, \sigma$ from the warm start; $n_{\mathrm{keep}} = \lceil 0.15 \cdot \mathrm{num\_elites} \rceil$
2. for $\mathrm{iter} = 0, \dots, I-1$ do
3.     $T_{\mathrm{iter}} = T \cdot \bigl(0.5 + \tfrac{\mathrm{iter}}{I-1} \cdot 1.5\bigr)$   // ramps from $0.5T$ to $2T$
4.     Generate AR(1) colored noise: $\eta_t = 0.25\,\eta_{t-1} + 0.75\,\xi_t$, with $\xi_t \sim \mathcal{N}(0, 1)$
5.     Sample $a_t \leftarrow \mu + \sigma\,\eta_t$; insert the kept elites and the $\pi$ warm starts
6.     Score elites: $s_{\mathrm{soft}} = \mathrm{softmax}(T_{\mathrm{iter}}\,V_{\mathrm{elite}})$, $s_{\mathrm{rank}} \propto 1/(1+\mathrm{rank})$
7.     $s = 0.8\,s_{\mathrm{soft}} + 0.2\,s_{\mathrm{rank}}$
8.     $\mu_{\mathrm{new}}, \sigma_{\mathrm{new}} \leftarrow$ weighted mean/std of the elites with weights $s$
9.     $\mu \leftarrow 0.15\,\mu + 0.85\,\mu_{\mathrm{new}}$ (no momentum on iter 0); kept $\leftarrow$ top-$n_{\mathrm{keep}}$ elites
10. Final action: Gumbel sample over $s$; add $\sigma_0$ noise unless eval_mode
Δ vs. baseline: Combines pieces from iCEM (elite retention, colored noise) with MPPI (softmax-weighted moment update, Gumbel selection), and adds (a) a per-iteration ramp of the MPPI temperature, (b) a blend of the softmax score with rank-based weights to reduce sensitivity to outliers, and (c) momentum on the mean update.
Hyperparameters: keep_fraction=0.15, momentum=0.15, temp_start_mult=0.5, temp_end_mult=2.0, noise_beta=0.25, rank_blend=0.2
Recovers standard MPPI when keep_fraction=0, momentum=0, rank_blend=0, noise_beta=0, and both temp mults=1.
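The per-iteration ingredients of this summary can be sketched with small NumPy helpers. They are reconstructed from the pseudocode and hyperparameters listed here, not copied from the agent's code; the helper names, the zero initial condition for the AR(1) noise, and the behavior when I = 1 are our assumptions. A full planner would plug these into the MPPI loop in place of white noise, plain softmax weights, and the direct mean assignment.

```python
import math

import numpy as np


def num_kept(num_elites, keep_fraction=0.15):
    """Number of elites carried over to the next iteration (rounded up)."""
    return math.ceil(keep_fraction * num_elites)


def temperature_schedule(base_temp, it, iterations, start_mult=0.5, end_mult=2.0):
    """Linear ramp from start_mult*T at iter 0 to end_mult*T at the last iteration."""
    frac = it / (iterations - 1) if iterations > 1 else 1.0  # assumption for I == 1
    return base_temp * (start_mult + frac * (end_mult - start_mult))


def ar1_noise(beta, horizon, num_samples, action_dim, rng):
    """eta_t = beta * eta_{t-1} + (1 - beta) * xi_t, xi_t ~ N(0, 1), eta_{-1} = 0 (assumed)."""
    xi = rng.standard_normal((horizon, num_samples, action_dim))
    eta = np.empty_like(xi)
    prev = np.zeros((num_samples, action_dim))
    for t in range(horizon):
        prev = beta * prev + (1.0 - beta) * xi[t]
        eta[t] = prev
    return eta


def blended_weights(elite_value, temp, rank_blend=0.2):
    """(1 - rank_blend) * softmax(temp * V) + rank_blend * normalized 1/(1 + rank)."""
    soft = np.exp(temp * (elite_value - elite_value.max()))
    soft /= soft.sum()
    ranks = np.empty(len(elite_value))
    ranks[np.argsort(-elite_value)] = np.arange(len(elite_value))  # 0 = best elite
    rank_w = 1.0 / (1.0 + ranks)
    rank_w /= rank_w.sum()
    return (1.0 - rank_blend) * soft + rank_blend * rank_w


def momentum_update(mean, mean_new, it, momentum=0.15):
    """mu <- momentum * mu + (1 - momentum) * mu_new, skipped on the first iteration."""
    if it == 0:
        return mean_new
    return momentum * mean + (1.0 - momentum) * mean_new
```

Setting rank_blend=0, momentum=0, noise_beta=0 and both temperature multipliers to 1 turns each helper into its plain MPPI counterpart, in line with the reduction stated in the hyperparameter line.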

Results