tdmpc2-planning


Description

Planning Algorithm for Model-Based RL

Objective

Design and implement a custom trajectory optimization algorithm for online planning in model-based reinforcement learning. Your code goes in the custom_plan() function in custom_planner.py. This function is called at every environment step to select actions using the learned world model.

Background

TD-MPC2 uses Model Predictive Path Integral (MPPI) for planning. At each step, the agent:

  1. Samples num_pi_trajs=24 trajectories from the learned policy as warm-starts
  2. Runs iterations=6 optimization rounds, each of which:
    • Samples num_samples=512 action sequences from N(mean, std)
    • Rolls out each trajectory through the latent dynamics model for horizon=3 steps
    • Estimates trajectory value using predicted rewards + terminal Q-value
    • Selects num_elites=64 best trajectories
    • Updates mean/std using softmax-weighted (temperature=0.5) elite statistics
  3. Selects final action via Gumbel-softmax sampling from elites
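The loop above can be sketched in plain PyTorch. Everything below is illustrative: value_fn stands in for the agent's trajectory evaluation (predicted rewards plus a terminal Q-value in the real agent), and the defaults mirror the numbers listed above.

```python
import torch

def mppi_sketch(value_fn, action_dim, horizon=3, iterations=6,
                num_samples=512, num_elites=64, temperature=0.5):
    """Toy MPPI loop mirroring the steps above (not the TD-MPC2 code itself)."""
    mean = torch.zeros(horizon, action_dim)
    std = torch.full((horizon, action_dim), 2.0)
    for _ in range(iterations):
        # 1) Sample action sequences from N(mean, std), clamped to [-1, 1]
        actions = (mean.unsqueeze(1)
                   + std.unsqueeze(1) * torch.randn(horizon, num_samples, action_dim)
                   ).clamp(-1, 1)
        # 2) Evaluate each trajectory and keep the elites
        values = value_fn(actions)                   # (num_samples,)
        elite_idx = values.topk(num_elites).indices
        elite_actions = actions[:, elite_idx]        # (horizon, num_elites, action_dim)
        # 3) Softmax-weighted elite statistics update
        w = torch.softmax(temperature * values[elite_idx], dim=0)
        mean = (w[None, :, None] * elite_actions).sum(1)
        std = ((w[None, :, None] * (elite_actions - mean.unsqueeze(1)) ** 2)
               .sum(1).sqrt().clamp_(0.05, 2.0))
    # Sample one elite in proportion to its weight; execute its first action
    idx = torch.multinomial(w, 1).item()
    return elite_actions[0, idx]
```

Note that sampling an elite by its softmax weight is equivalent in spirit to the Gumbel-softmax selection step; the real implementation lives in common.math.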

Alternative planning approaches could improve sample efficiency, convergence speed, or final performance:

  • Cross-Entropy Method (CEM): simpler elite selection without softmax weighting
  • iCEM: improved CEM with temporally correlated noise and keep-elites
  • Gradient-based planning: backpropagating through the world model
  • Hybrid approaches: combining sampling with gradient refinement
  • Adaptive methods: adjusting sampling parameters during optimization
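As one concrete example, the temporally correlated noise used by iCEM can be approximated with a simple AR(1) filter; this is a stand-in for the power-law ("colored") noise in the actual iCEM method, and beta here is an illustrative correlation parameter.

```python
import torch

def correlated_noise(horizon, num_samples, action_dim, beta=0.7):
    """Temporally correlated sampling noise via an AR(1) filter
    (a simple approximation of iCEM's power-law colored noise)."""
    eps = torch.randn(horizon, num_samples, action_dim)
    noise = torch.empty_like(eps)
    noise[0] = eps[0]
    for t in range(1, horizon):
        # Mix the previous step's noise with fresh Gaussian noise;
        # the scaling keeps the marginal variance at 1.
        noise[t] = beta * noise[t - 1] + (1 - beta ** 2) ** 0.5 * eps[t]
    return noise
```

Such noise tends to produce smoother candidate action sequences, which often helps on short planning horizons.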

What You Can Modify

You may edit the custom_plan() function (lines 15-120) in custom_planner.py. Inside it, you have access to:

  • agent.model: WorldModel with encode, next, pi, Q, reward methods
  • agent._estimate_value(z, actions, task): evaluates trajectory returns
  • agent._prev_mean: warm-start buffer from previous planning step
  • agent.cfg: all configuration parameters (horizon, num_samples, etc.)
  • common.math: utility functions (gumbel_softmax_sample, two_hot_inv, etc.)
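A hypothetical helper shows how these pieces typically fit together: roll latents forward through the dynamics model and accumulate discounted predicted rewards. Here next_fn and reward_fn stand in for agent.model.next and agent.model.reward, whose exact signatures should be checked against world_model.py; the terminal Q-value term handled by agent._estimate_value is omitted from this sketch.

```python
import torch

@torch.no_grad()
def rollout_value(next_fn, reward_fn, z0, actions, discount=0.99):
    """Score action sequences by rolling out latent dynamics.

    next_fn(z, a) -> next latent, reward_fn(z, a) -> predicted reward;
    both are stand-ins for the world model's methods.
    actions has shape (horizon, num_samples, action_dim).
    """
    horizon, num_samples, _ = actions.shape
    z = z0.expand(num_samples, -1).clone()   # broadcast start latent per sample
    returns = torch.zeros(num_samples)
    disc = 1.0
    for t in range(horizon):
        returns = returns + disc * reward_fn(z, actions[t])
        z = next_fn(z, actions[t])
        disc *= discount
    return returns
```

In the real function you would instead call agent._estimate_value(z, actions, task), which adds the terminal Q-value.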

Evaluation

  • Metric: Episode reward (higher is better)
  • Environments: DMControl walker-walk and cheetah-run (the baseline results below also report cartpole-swingup)
  • Model: TD-MPC2 with 1M parameters, 200K training steps
  • Note: The planning algorithm affects both data collection quality during training and action selection during evaluation

Key Constraints

  • The function must return a single action tensor of shape (action_dim,) clamped to [-1, 1]
  • The function runs under @torch.no_grad() — no gradient computation
  • Must update agent._prev_mean for temporal consistency across steps
  • Planning budget: keep total computation comparable to the default (6 iterations x 512 samples)
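The last two constraints can be sketched as follows, under the assumption that warm-starting works by shifting the optimized mean one step forward (the exact buffer handling in TD-MPC2 may differ); the caller would store the returned shifted tensor in agent._prev_mean.

```python
import torch

@torch.no_grad()
def finalize_action(mean, std, eval_mode=False):
    """Pick the executed action and build the warm-start for the next step.

    mean, std have shape (horizon, action_dim). Returns the clamped first
    action and the mean shifted one step forward (zero-padded at the tail),
    which the caller stores in agent._prev_mean.
    """
    if eval_mode:
        a = mean[0]
    else:
        # Exploration noise during data collection
        a = mean[0] + std[0] * torch.randn_like(std[0])
    shifted = torch.cat([mean[1:], torch.zeros_like(mean[:1])], dim=0)
    return a.clamp(-1, 1), shifted
```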

Code

custom_planner.py
"""Custom planning algorithm for TD-MPC2.

Replace the planning logic in custom_plan() with your trajectory
optimization method. The function is called at each environment step
to select actions using the learned world model.
"""

import torch
from common import math


# =====================================================================
# EDITABLE: Custom planning algorithm
# =====================================================================
@torch.no_grad()

Additional context files (read-only):

  • tdmpc2/tdmpc2/tdmpc2.py
  • tdmpc2/tdmpc2/common/world_model.py
  • tdmpc2/tdmpc2/common/math.py

Results

Episode reward by planner and environment:

Model  Type      Walker Walk  Cheetah Run  Cartpole Swingup
cem    baseline  976.963      833.477      867.257
icem   baseline  978.297      796.457      881.233
mppi   baseline  978.310      889.897      877.213