humanoid-ppo-extractor
Description
Humanoid Control: PPO Feature Extractor Architecture
Objective
Improve PPO performance on humanoid locomotion by designing a better feature extractor network architecture. You can modify the CustomFeatureExtractor class (lines 20-38) and add custom imports (lines 14-16) in train_custom.py.
Background
The training uses Stable Baselines3 PPO with 4 parallel environments on three tasks from humanoid-bench: h1-stand-v0, h1-walk-v0, and h1-run-v0. The Unitree H1 humanoid robot must learn standing, walking, and running locomotion. The feature extractor processes raw proprioceptive observation vectors into feature representations used by the policy and value networks.
The default feature extractor is a 2-layer MLP with Tanh activations (64 hidden units, 64 output features). Training runs for 1M timesteps with a learning rate of 3e-4.
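For concreteness, the default extractor corresponds to the following plain-PyTorch sketch. The observation dimension used here is a placeholder; the real value comes from the task's observation space.

```python
import torch
import torch.nn as nn

# Sketch of the default extractor: a 2-layer MLP with Tanh activations,
# 64 hidden units, and 64 output features.
obs_dim = 51  # placeholder; the h1 tasks define the actual proprioceptive size

default_extractor = nn.Sequential(
    nn.Linear(obs_dim, 64),
    nn.Tanh(),
    nn.Linear(64, 64),
    nn.Tanh(),
)

features = default_extractor(torch.zeros(8, obs_dim))
print(features.shape)  # one 64-dim feature vector per batch element
```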
Interface
Your CustomFeatureExtractor must:
- Inherit from `BaseFeaturesExtractor`
- Call `super().__init__(observation_space, features_dim)` in `__init__`
- Accept `observation_space` and `features_dim` as constructor arguments
- Implement `forward(self, observations) -> torch.Tensor` returning a tensor of shape `(batch, features_dim)`
Evaluation
Final performance is measured by mean reward over 20 evaluation episodes with deterministic policy after 1M training timesteps, across all three environments.
Code
```python
 1  import argparse
 2  import numpy as np
 3  import torch
 4  import torch.nn as nn
 5  import gymnasium as gym
 6  from gymnasium.wrappers import TimeLimit
 7  from stable_baselines3 import PPO
 8  from stable_baselines3.common.monitor import Monitor
 9  from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
10  from stable_baselines3.common.evaluation import evaluate_policy
11  from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
12  from stable_baselines3.common.callbacks import BaseCallback
13  import humanoid_bench
14  # ── Custom imports (editable) ────────────────────────────────────────────
15
```
Results
| Model | Type | mean reward h1 stand ↑ | mean reward h1 walk ↑ | mean reward h1 run ↑ |
|---|---|---|---|---|
| layernorm_mlp | agent | 32.740 | 25.760 | 11.360 |
| residual_mlp | agent | 35.710 | 24.150 | 9.660 |
| wide_mlp | agent | 42.820 | 18.790 | 9.640 |