tdmpc2-planning
Tags: other · tdmpc2 · rigorous codebase
Description
Planning Algorithm for Model-Based RL
Objective
Design and implement a custom trajectory optimization algorithm for online planning in model-based reinforcement learning. Your code goes in the custom_plan() function in custom_planner.py. This function is called at every environment step to select actions using the learned world model.
Background
TD-MPC2 uses Model Predictive Path Integral (MPPI) for planning. At each step, the agent:
- Samples `num_pi_trajs=24` trajectories from the learned policy as warm-starts
- Iterates `iterations=6` rounds of:
  - Sampling `num_samples=512` action sequences from N(mean, std)
  - Rolling out each sequence through the latent dynamics model for `horizon=3` steps
  - Estimating trajectory value as predicted rewards plus a terminal Q-value
  - Selecting the `num_elites=64` best trajectories
  - Updating mean/std from softmax-weighted (temperature=0.5) elite statistics
- Selects the final action via Gumbel-softmax sampling over the elites
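The softmax-weighted elite update at the heart of the loop above can be sketched as follows. This is a minimal stand-alone version; the actual TD-MPC2 implementation differs in details (e.g. std clamping and exact normalization):

```python
import torch

def mppi_elite_update(values, elite_actions, temperature=0.5):
    """Softmax-weighted mean/std update over elite action sequences.

    values:        (num_elites,) estimated returns of the elites
    elite_actions: (horizon, num_elites, action_dim)
    """
    # Higher-value elites get exponentially larger weight; the softmax
    # normalizes the weights to sum to 1.
    score = torch.softmax(values * temperature, dim=0)        # (num_elites,)
    w = score.view(1, -1, 1)
    mean = (w * elite_actions).sum(dim=1)                     # (horizon, action_dim)
    var = (w * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1)
    return mean, var.sqrt(), score
```

The softmax weighting makes the update smooth in the trajectory values, which is the key difference from hard elite selection in plain CEM.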
Alternative planning approaches could improve sample efficiency, convergence speed, or final performance:
- Cross-Entropy Method (CEM): simpler elite selection without softmax weighting
- iCEM: improved CEM with temporally correlated noise and keep-elites
- Gradient-based planning: backpropagating through the world model
- Hybrid approaches: combining sampling with gradient refinement
- Adaptive methods: adjusting sampling parameters during optimization
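To make the first two alternatives concrete, here is a hedged sketch of a plain CEM update and an AR(1)-style stand-in for iCEM's temporally correlated noise. Function names and the `beta` noise parameter are illustrative, not taken from the codebase:

```python
import torch

def cem_update(values, actions, num_elites=64):
    """Plain CEM: unweighted mean/std of the top-k action sequences.

    values:  (num_samples,) estimated returns
    actions: (horizon, num_samples, action_dim)
    """
    elite_idx = torch.topk(values, num_elites).indices
    elites = actions[:, elite_idx]        # (horizon, num_elites, action_dim)
    return elites.mean(dim=1), elites.std(dim=1)

def temporally_correlated_noise(horizon, num_samples, action_dim, beta=0.25):
    """AR(1)-style noise as a simple stand-in for iCEM's colored noise:
    consecutive timesteps share a fraction beta of their perturbation,
    which tends to produce smoother sampled action sequences."""
    eps = torch.randn(horizon, num_samples, action_dim)
    for t in range(1, horizon):
        eps[t] = beta * eps[t - 1] + (1.0 - beta ** 2) ** 0.5 * eps[t]
    return eps
```

Sampled actions would then be `mean + std * noise`; with a short `horizon=3`, the benefit of correlated noise is smaller than in long-horizon settings.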
What You Can Modify
The `custom_plan()` function (lines 15-120) in `custom_planner.py`. You have access to:
- `agent.model`: WorldModel with `encode`, `next`, `pi`, `Q`, and `reward` methods
- `agent._estimate_value(z, actions, task)`: evaluates trajectory returns
- `agent._prev_mean`: warm-start buffer from the previous planning step
- `agent.cfg`: all configuration parameters (horizon, num_samples, etc.)
- `common.math`: utility functions (`gumbel_softmax_sample`, `two_hot_inv`, etc.)
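As a sketch of how these pieces fit together, the following rolls sampled action sequences through latent dynamics and accumulates discounted predicted rewards. `DummyModel` is a hypothetical stand-in so the snippet is self-contained; the real `WorldModel` and `_estimate_value` signatures may differ (in particular, the terminal Q-value bootstrap is omitted here):

```python
import torch

class DummyModel:
    """Hypothetical stand-in for agent.model; not the real WorldModel."""
    def next(self, z, a, task=None):
        return z + 0.1 * a.mean(dim=-1, keepdim=True)   # fake latent dynamics
    def reward(self, z, a, task=None):
        return -(z ** 2).mean(dim=-1, keepdim=True)     # fake reward head

@torch.no_grad()
def rollout_value(model, z0, actions, gamma=0.99):
    """Score action sequences by discounted predicted reward.

    z0:      (num_samples, latent_dim) initial latent states
    actions: (horizon, num_samples, action_dim)
    """
    G, discount, z = torch.zeros(z0.shape[0], 1), 1.0, z0
    for a in actions:                       # loop over the planning horizon
        G = G + discount * model.reward(z, a)
        z = model.next(z, a)                # advance the latent state
        discount *= gamma
    return G                                # (num_samples, 1)
```

In practice you would call `agent._estimate_value` instead, since it also adds the terminal Q-value at the end of the horizon.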
Evaluation
- Metric: Episode reward (higher is better)
- Environments: DMControl walker-walk, cheetah-run, and cartpole-swingup
- Model: TD-MPC2 with 1M parameters, 200K training steps
- Note: The planning algorithm affects both data collection quality during training and action selection during evaluation
Key Constraints
- The function must return a single action tensor of shape `(action_dim,)`, clamped to `[-1, 1]`
- The function runs under `@torch.no_grad()`: no gradient computation is available
- Must update `agent._prev_mean` for temporal consistency across steps
- Planning budget: keep total computation comparable to the default (6 iterations x 512 samples)
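The bookkeeping these constraints require can be sketched as follows. The helper name is illustrative, and `torch.multinomial` is used as a simple stand-in for Gumbel-softmax sampling over the elites:

```python
import torch

def finalize_action(elite_actions, score, prev_mean):
    """Return the first action of a sampled elite sequence and refresh
    the warm-start buffer.

    elite_actions: (horizon, num_elites, action_dim)
    score:         (num_elites,) normalized elite weights
    prev_mean:     (horizon, action_dim) buffer, updated in place
    """
    # Sample one elite in proportion to its weight (stand-in for
    # Gumbel-softmax sampling).
    idx = torch.multinomial(score, 1).item()
    # Warm-start: shift the weighted mean one step forward in time,
    # repeating the last step, so the next call starts near this plan.
    mean = (score.view(1, -1, 1) * elite_actions).sum(dim=1)
    prev_mean.copy_(torch.cat([mean[1:], mean[-1:]], dim=0))
    # Single action of shape (action_dim,), clamped to the valid range.
    return elite_actions[0, idx].clamp(-1, 1)
```

Forgetting the `prev_mean` shift is a common bug: the planner then re-optimizes from scratch each step and loses temporal consistency.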
Code
custom_planner.py
```python
 1  """Custom planning algorithm for TD-MPC2.
 2
 3  Replace the planning logic in custom_plan() with your trajectory
 4  optimization method. The function is called at each environment step
 5  to select actions using the learned world model.
 6  """
 7
 8  import torch
 9  from common import math
10
11
12  # =====================================================================
13  # EDITABLE: Custom planning algorithm
14  # =====================================================================
15  @torch.no_grad()
```
Additional context files (read-only):
- tdmpc2/tdmpc2/tdmpc2.py
- tdmpc2/tdmpc2/common/world_model.py
- tdmpc2/tdmpc2/common/math.py
Results
| Model | Type | Episode reward, walker-walk ↑ | Episode reward, cheetah-run ↑ | Episode reward, cartpole-swingup ↑ |
|---|---|---|---|---|
| cem | baseline | 976.963 | 833.477 | 867.257 |
| icem | baseline | 978.297 | 796.457 | 881.233 |
| mppi | baseline | 978.310 | 889.897 | 877.213 |