tdmpc2-simnorm
Description
Latent Representation Normalization for Model-Based RL
Objective
Design and implement a custom normalization technique for latent state representations in model-based reinforcement learning. Your code goes in the CustomSimNorm class in custom_simnorm.py. This normalization is applied as the final activation in both the encoder and dynamics networks of the TD-MPC2 world model.
Background
TD-MPC2 learns an implicit world model in a latent space and uses it for planning. The latent representation geometry is critical for stable learning. The default approach uses SimNorm (Simplicial Normalization), which reshapes the latent vector into groups of 8 and applies softmax within each group, constraining each group to lie on a simplex.
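The grouped-softmax behavior described above can be sketched as follows (a minimal illustrative sketch, not the reference implementation; the function name and shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def simnorm(x: torch.Tensor, group_dim: int = 8) -> torch.Tensor:
    """Simplicial normalization: softmax within groups of `group_dim`."""
    shape = x.shape
    x = x.view(*shape[:-1], -1, group_dim)  # split last dim into groups
    x = F.softmax(x, dim=-1)                # each group sums to 1 (a simplex)
    return x.view(*shape)                   # restore the original shape

z = torch.randn(4, 128)  # batch of latent vectors
out = simnorm(z)
# shape is preserved and every group of 8 entries sums to 1
```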
Alternative normalization strategies could improve learning stability, representation quality, or computational efficiency:
- L2 normalization: projects onto a hypersphere
- RMSNorm: root-mean-square normalization without mean centering
- Spectral normalization: controls the Lipschitz constant
- Gumbel-softmax: adds stochasticity to the simplex projection
- Hybrid approaches: combining different normalization strategies
What You Can Modify
The CustomSimNorm class (lines 16-43) in custom_simnorm.py:
- __init__(self, cfg): initialize parameters (cfg.simnorm_dim = 8)
- forward(self, x): normalize the latent vector (must preserve shape)
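A minimal skeleton matching this interface, defaulting to the stock SimNorm behavior (a sketch under the stated assumptions: `cfg` exposes `simnorm_dim`, and `SimpleNamespace` stands in for the real config object):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

class CustomSimNorm(nn.Module):
    """Final activation for the encoder/dynamics MLPs.

    This skeleton reproduces the default SimNorm behavior; replace
    forward() with your own normalization. cfg.simnorm_dim is the
    group size (8 by default, per the task description).
    """

    def __init__(self, cfg):
        super().__init__()
        self.dim = cfg.simnorm_dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)  # (..., groups, group_size)
        x = F.softmax(x, dim=-1)             # simplex projection per group
        return x.view(*shp)                  # shape must be preserved

cfg = SimpleNamespace(simnorm_dim=8)  # stand-in for the real config
norm = CustomSimNorm(cfg)
out = norm(torch.randn(4, 128))
```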
Evaluation
- Metric: Episode reward (higher is better)
- Environments: DMControl walker-walk and cheetah-run
- Model: TD-MPC2 with 1M parameters, 200K training steps
- Baselines: SimNorm (~800 reward on walker-walk), L2Norm, RMSNorm
Architecture Context
The normalization is used in:
- Encoder (layers.py: enc()): maps observations to latent states
- Dynamics (world_model.py: __init__): predicts next latent state from current state + action
Both use SimNorm as the final activation in their MLP stacks. The latent dimension is 128 with simnorm_dim=8 (16 groups).
Code
```python
"""Custom latent normalization for TD-MPC2 world model.

Replace the body of CustomSimNorm with your normalization implementation.
The class is used as the final activation in the encoder and dynamics
networks, constraining the latent representation geometry.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# =====================================================================
# EDITABLE: Custom latent normalization
# =====================================================================
```
Additional context files (read-only):
- tdmpc2/tdmpc2/common/layers.py
- tdmpc2/tdmpc2/common/world_model.py
Results
| Model | Type | Episode reward (walker-walk) ↑ | Episode reward (cheetah-run) ↑ | Episode reward (cartpole-swingup) ↑ |
|---|---|---|---|---|
| identity | baseline | 976.680 | 763.450 | 873.120 |
| l2norm | baseline | 976.330 | 813.487 | 878.993 |
| rmsnorm | baseline | 976.923 | 680.960 | 873.393 |
| simnorm | baseline | 977.207 | 888.937 | 881.517 |