Trajectory Optimization for Model-Based Planning
An online planning algorithm that selects actions by optimizing trajectories through a learned world model, with the goal of improving episode reward.

Description
Planning Algorithm for Model-Based RL
Objective
Design and implement a custom trajectory optimization algorithm for online planning in model-based reinforcement learning. Your code goes in the custom_plan() function in custom_planner.py. This function is called at every environment step to select actions using the learned world model.
Background
TD-MPC2 (Hansen, Su, Wang, ICLR 2024, arXiv:2310.16828) — the scalable successor of TD-MPC (Hansen, Wang, Su, ICML 2022, arXiv:2203.04955) — uses Model Predictive Path Integral (MPPI) (Williams et al., arXiv:1509.01149) for planning. At each step, the agent:
- Samples num_pi_trajs = 24 trajectories from the learned policy as warm-starts.
- Iterates iterations = 6 rounds of:
  - Sample num_samples = 512 action sequences from N(mean, std).
  - Roll out each trajectory through the latent dynamics model for horizon = 3 steps.
  - Estimate trajectory value using predicted rewards + terminal Q-value.
  - Select the num_elites = 64 best trajectories.
  - Update mean / std using softmax-weighted (temperature = 0.5) elite statistics.
- Selects the final action via Gumbel-softmax sampling from elites.
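The loop above can be sketched as a standalone function. This is a minimal toy version, not the TD-MPC2 implementation: `value_fn` is a hypothetical stand-in for rolling sequences through `agent.model` and calling `agent._estimate_value`, the `num_pi_trajs` policy warm-starts and the exploration noise added outside evaluation are omitted, and the `min_std` / `max_std` defaults are assumptions. Here `temperature` multiplies the max-centered elite values inside the exponent, so larger values give sharper weights.

```python
import torch


def mppi_plan(value_fn, action_dim, horizon=3, iterations=6, num_samples=512,
              num_elites=64, temperature=0.5, min_std=0.05, max_std=2.0,
              init_mean=None, generator=None):
    """Toy MPPI loop.

    value_fn (hypothetical) maps actions of shape (num_samples, horizon, action_dim)
    to returns of shape (num_samples,).
    """
    mean = torch.zeros(horizon, action_dim) if init_mean is None else init_mean.clone()
    std = torch.full((horizon, action_dim), max_std)
    for _ in range(iterations):
        # Sample action sequences from N(mean, std) and clamp to the action bounds.
        noise = torch.randn(num_samples, horizon, action_dim, generator=generator)
        actions = (mean + std * noise).clamp(-1, 1)
        # Score every sequence (TD-MPC2 uses predicted rewards + terminal Q here).
        value = value_fn(actions)
        # Keep the top-k sequences as elites.
        elite_idx = torch.topk(value, num_elites).indices
        elite_value, elite_actions = value[elite_idx], actions[elite_idx]
        # Softmax weights; subtracting the max keeps exp() numerically stable.
        score = torch.exp(temperature * (elite_value - elite_value.max()))
        score = score / score.sum()
        # Weighted mean / std of the elites (weights already sum to 1).
        w = score.view(-1, 1, 1)
        mean = (w * elite_actions).sum(0)
        std = (w * (elite_actions - mean) ** 2).sum(0).sqrt().clamp(min_std, max_std)
    # Gumbel-max trick: draw one elite index with probability proportional to score.
    u = torch.rand(num_elites, generator=generator).clamp_min(1e-10)
    idx = torch.argmax(torch.log(score) - torch.log(-torch.log(u)))
    # Execute only the first action of the chosen sequence; return the mean for warm-starting.
    return elite_actions[idx, 0].clamp(-1, 1), mean
```

For example, with a quadratic `value_fn` peaked at a fixed target action, the returned mean should settle near that target after the six iterations.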
Alternative planning approaches could improve sample efficiency, convergence speed, or final performance:
- Cross-Entropy Method (CEM): simpler elite selection without softmax weighting.
- iCEM (Pinneri et al., CoRL 2020, arXiv:2008.06389): improved CEM with temporally correlated (colored) noise and keep-elites.
- Gradient-based planning: backpropagating through the world model.
- Hybrid approaches: combining sampling with gradient refinement.
- Adaptive methods: adjusting sampling parameters during optimization.
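The difference between CEM and MPPI is confined to how the elites are averaged. The sketch below isolates that refit step; `elite_update` is a hypothetical helper. Passing `temperature=None` gives CEM's uniform average over elites, while a float gives MPPI-style softmax weights proportional to exp(temperature · value).

```python
import torch


def elite_update(elite_actions, elite_value, temperature=None):
    """Refit (mean, std) from elites of shape (k, horizon, action_dim).

    temperature=None -> plain CEM (uniform weights over elites).
    temperature=float -> MPPI-style softmax weights over elite values.
    """
    k = elite_actions.shape[0]
    if temperature is None:
        weights = torch.full((k,), 1.0 / k)
    else:
        # softmax is shift-invariant, so this equals exp(t * (v - max v)) / sum.
        weights = torch.softmax(temperature * elite_value, dim=0)
    w = weights.view(-1, 1, 1)
    mean = (w * elite_actions).sum(0)
    std = (w * (elite_actions - mean) ** 2).sum(0).sqrt()
    return mean, std
```

With two elites at 0 and 1, CEM returns their midpoint, whereas MPPI at temperature 0.5 and a value gap of 10 puts about 99% of the weight on the better elite.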
What You Can Modify
The custom_plan() function in custom_planner.py. You have access to:
- agent.model: WorldModel with encode, next, pi, Q, reward methods
- agent._estimate_value(z, actions, task): evaluates trajectory returns
- agent._prev_mean: warm-start buffer from the previous planning step
- agent.cfg: all configuration parameters (horizon, num_samples, etc.)
- common.math: utility functions (gumbel_softmax_sample, two_hot_inv, etc.)
Evaluation
- Metric: episode reward (higher is better)
- Environments: DMControl walker-walk and cheetah-run
- Model: TD-MPC2 with 1M parameters, 200K training steps
- Note: the planning algorithm affects both data collection quality during training and action selection during evaluation.
Key Constraints
- The function must return a single action tensor of shape (action_dim,) clamped to [-1, 1].
- The function runs under @torch.no_grad(), so no gradients are computed.
- It must update agent._prev_mean for temporal consistency across steps.
- Planning budget: keep total computation comparable to the default (6 iterations × 512 samples).
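The bookkeeping constraints can be illustrated with two small helpers; both names are hypothetical. `shift_warm_start` drops the step that was just executed and resets the final step to zero, which to our understanding matches how TD-MPC2 seeds its mean (repeating the last step is an equally plausible choice). `finalize` writes the new plan into the warm-start buffer in place and returns a clamped `(action_dim,)` action. The constants spell out the default budget, counting only latent dynamics steps.

```python
import torch


def shift_warm_start(prev_mean):
    """Shift a (horizon, action_dim) plan one step forward in time.

    The executed first step is dropped and the final step is reset to zero.
    """
    shifted = torch.zeros_like(prev_mean)
    shifted[:-1] = prev_mean[1:]
    return shifted


def finalize(mean, prev_mean_buffer):
    """Store the plan in-place for the next call and return a clamped first action."""
    prev_mean_buffer.copy_(mean)
    return mean[0].clamp(-1, 1)


# Default planning budget per environment step.
ITERATIONS, NUM_SAMPLES, HORIZON = 6, 512, 3
ROLLOUTS = ITERATIONS * NUM_SAMPLES   # 3072 action sequences evaluated
DYNAMICS_CALLS = ROLLOUTS * HORIZON   # 9216 latent transitions
```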
Code
```python
"""Custom planning algorithm for TD-MPC2.

Replace the planning logic in custom_plan() with your trajectory
optimization method. The function is called at each environment step
to select actions using the learned world model.
"""

import torch
from common import math


# =====================================================================
# EDITABLE: Custom planning algorithm
# =====================================================================
@torch.no_grad()
```
```python
import torch
import torch.nn.functional as F

from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
from common.layers import api_model_conversion
from tensordict import TensorDict


class TDMPC2(torch.nn.Module):
    """
    TD-MPC2 agent. Implements training + inference.
    Can be used for both single-task and multi-task experiments,
    and supports both state and pixel observations.
    """
```
```python
from copy import deepcopy

import torch
import torch.nn as nn

from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams


class WorldModel(nn.Module):
    """
    TD-MPC2 implicit world model architecture.
    Can be used for both single-task and multi-task experiments.
    """
```
```python
import torch
import torch.nn.functional as F
from tensordict import TensorDict


def soft_ce(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)


def log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)
```
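A quick way to see what these two utilities compute is to evaluate them on hand-picked inputs. `soft_ce_from_probs` is a hypothetical variant of `soft_ce` that takes the soft target as a probability vector directly, because `two_hot` is not included in the excerpt; the log-std bounds used in the check are illustrative values only.

```python
import torch
import torch.nn.functional as F


def log_std(x, low, dif):
    # Same formula as above: squashes any real x into [low, low + dif].
    return low + 0.5 * dif * (torch.tanh(x) + 1)


def soft_ce_from_probs(pred, target_probs):
    """Cross-entropy of logits `pred` against an already-encoded soft target.

    Hypothetical variant of soft_ce: target_probs replaces two_hot(target, cfg).
    """
    log_p = F.log_softmax(pred, dim=-1)
    return -(target_probs * log_p).sum(-1, keepdim=True)
```

For instance, with `low=-10` and `dif=12`, inputs of -100, 0 and 100 map to -10, -4 and 2; and uniform logits over four bins against a one-hot target give a loss of log 4.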
Method Summary
Hybrid MPPI/iCEM with rank blend
iCEM-style elite retention + colored noise + adaptive temperature, with the MPPI softmax score blended with rank weights and a momentum mean update.
1. Init mean/std from warm-start; n_keep = ⌈0.15·num_elites⌉
2. for iter = 0..I-1 do
3.     …  // 0.5T → 2T
4.     Generate AR(1) colored noise: …
5.     Sample …; insert kept elites + π warm-starts
6.     Score elites; …
7.     …
8.     Weighted mean/std of elites by …
9.     … (no momentum on iter 0); kept ← top-n_keep elites
10. Final action: Gumbel sample over …; add noise unless eval_mode
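Several steps of the summary can be sketched in isolation. Because the summary's formulas did not survive extraction, every name and coefficient here (`beta`, `alpha`, `momentum`, the linear rank scheme) is a hypothetical choice rather than the method's actual definition. The noise follows an AR(1) recursion as the summary states; iCEM itself draws power-law (1/f^β) colored noise, so this is a simpler stand-in. Only `n_keep = ⌈0.15·num_elites⌉` is taken directly from the summary.

```python
import math
import torch


def ar1_noise(num_samples, horizon, action_dim, beta=0.5, generator=None):
    """Temporally correlated noise: e_t = beta * e_{t-1} + sqrt(1 - beta^2) * w_t.

    Starting from e_0 ~ N(0, 1), every step keeps unit marginal variance.
    """
    w = torch.randn(num_samples, horizon, action_dim, generator=generator)
    eps = torch.empty_like(w)
    eps[:, 0] = w[:, 0]
    scale = math.sqrt(1.0 - beta ** 2)
    for t in range(1, horizon):
        eps[:, t] = beta * eps[:, t - 1] + scale * w[:, t]
    return eps


def blended_weights(elite_value, temperature, alpha=0.5):
    """Blend MPPI softmax weights with linear rank weights (alpha is hypothetical)."""
    soft = torch.softmax(temperature * elite_value, dim=0)
    k = elite_value.numel()
    # ranks[i] = position of elite i when sorted best-first (best elite -> 0).
    ranks = torch.empty(k)
    ranks[torch.argsort(elite_value, descending=True)] = torch.arange(k, dtype=torch.float32)
    rank_w = (k - ranks) / (k - ranks).sum()  # best gets weight k, worst gets 1
    return alpha * soft + (1 - alpha) * rank_w


def momentum_update(old_mean, new_mean, momentum=0.1, first_iter=False):
    """Mix the refit mean with the previous one; skipped on the first iteration."""
    if first_iter:
        return new_mean
    return momentum * old_mean + (1 - momentum) * new_mean


def n_keep(num_elites, frac=0.15):
    """Number of elites carried over to the next iteration."""
    return math.ceil(frac * num_elites)
```

With the default num_elites = 64, `n_keep` retains 10 elites per iteration. Blending a rank term into the softmax score keeps the weights from collapsing onto a single elite when value gaps are large.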