# marl-centralized-critic

## Description

Cooperative MARL: Centralized Critic Architecture for MAPPO

## Objective
Improve cooperative multi-agent reinforcement learning by designing a better centralized critic architecture for MAPPO (Multi-Agent PPO). You can modify the `CustomCritic` class (lines 13-69) and add custom imports (lines 7-8) in `custom_critic.py`.
## Background
In cooperative MARL with partial observability, each agent only sees a local observation but the team shares a common reward. Centralized-Training-with-Decentralized-Execution (CTDE) methods train a centralized value function during training (which can see the global state and all agents' information) and use it to reduce variance when computing advantages for each agent's decentralized policy gradient update. The architecture of this centralized critic — what it conditions on and how it mixes per-agent features — directly determines how tight the bias-variance tradeoff is and therefore how well MAPPO scales to hard multi-agent cooperation tasks.
The training uses EPyMARL's `ppo_learner` with the MAPPO default hyperparameters on three SMAC maps via smaclite (a pure-Python reimplementation of the StarCraft Multi-Agent Challenge benchmark — no StarCraft II binary required):
- `mmm` — 1 Medivac + 2 Marauders + 7 Marines (team of 10) vs mirror; heterogeneous cooperation requiring heal micro (≈5M env steps).
- `2s3z` — 2 Stalkers + 3 Zealots (team of 5) vs mirror; medium heterogeneous team (≈5M env steps).
- `3s5z` — 3 Stalkers + 5 Zealots (team of 8) vs mirror; larger team, harder (≈5M env steps).
The default critic is a simple (state ⊕ agent-one-hot) → 3-layer MLP → V that ignores per-agent observations. It is a working baseline but leaves room for smarter architectures that integrate per-agent features, attention, or state conditioning.
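The default critic just described can be sketched as a stand-alone module. This is a minimal illustration, assuming a 64-unit hidden width; the class name and constructor signature are invented for the sketch and differ from the real `(scheme, args)` interface defined below.

```python
import torch as th
import torch.nn as nn


class CentralVCritic(nn.Module):
    """Default-style MAPPO critic: (state ⊕ agent-one-hot) → 3-layer MLP → V."""

    def __init__(self, state_dim, n_agents, hidden_dim=64):
        super().__init__()
        self.n_agents = n_agents
        self.output_type = "v"
        self.net = nn.Sequential(
            nn.Linear(state_dim + n_agents, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        # state: (B, T, state_dim). The same global state is repeated for every
        # agent; only the appended one-hot agent id differs per agent.
        B, T, _ = state.shape
        s = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        ids = th.eye(self.n_agents, device=state.device)
        ids = ids.view(1, 1, self.n_agents, self.n_agents).expand(B, T, -1, -1)
        return self.net(th.cat([s, ids], dim=-1))  # (B, T, n_agents, 1)
```

Because the per-agent observations never enter the network, two agents in the same global state receive values that differ only through the one-hot id; this is the gap that smarter architectures target.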
## Interface

Your `CustomCritic` must:

- Inherit from `nn.Module`.
- Accept `(scheme, args)` in `__init__`, where:
  - `scheme["state"]["vshape"]` — global state dim
  - `scheme["obs"]["vshape"]` — per-agent observation dim
  - `args.n_agents`, `args.n_actions`, `args.hidden_dim`, `args.obs_last_action`, `args.obs_individual_obs`
- Set `self.output_type = "v"` in `__init__`.
- Implement `forward(self, batch, t=None)`, where:
  - `batch["state"]` — shape `(B, T, state_dim)`
  - `batch["obs"]` — shape `(B, T, n_agents, obs_dim)`
  - `batch.batch_size`, `batch.max_seq_length`, `batch.device` are available
  - `t=None` means "whole sequence"; otherwise `t` is an integer timestep
- Return `q` with shape `(B, T, n_agents, 1)` — the learner later calls `.squeeze(3)`, so the trailing singleton dimension is mandatory.
## Reference Implementations

- IPPO critic (`ippo_critic.edit.py`): per-agent MLP over `batch["obs"]` ⊕ agent-one-hot; no centralization. Floor baseline from Yu et al. 2022's IPPO ablation. Also see `epymarl/src/modules/critics/ac.py`.
- MAPPO critic (`mappo_critic.edit.py`): shared MLP over `batch["state"]` ⊕ agent-one-hot. The standard MAPPO central V from Yu et al. 2022. Also see `epymarl/src/modules/critics/centralV.py`.
- MAT-style attention critic (`mat_critic.edit.py`): projects per-agent features (obs ⊕ broadcast state) into tokens, then a single `TransformerEncoder` layer with self-attention across the agent axis produces a per-agent value. Adapted from Wen et al. 2022, "Multi-Agent Transformer" (arXiv:2205.14953), in critic-only form; the MAPPO actor is kept unchanged.
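The agent-axis self-attention behind the MAT-style critic can be sketched in isolation. The class name, embedding width, and head count here are assumptions; the real `mat_critic.edit.py` additionally follows the `(scheme, args)` interface above.

```python
import torch as th
import torch.nn as nn


class AttentionCritic(nn.Module):
    """MAT-style sketch: per-agent tokens, self-attention over the agent axis."""

    def __init__(self, state_dim, obs_dim, n_agents, embed_dim=64, n_heads=4):
        super().__init__()
        self.embed = nn.Linear(state_dim + obs_dim, embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.v_head = nn.Linear(embed_dim, 1)
        self.output_type = "v"

    def forward(self, state, obs):
        # state: (B, T, S); obs: (B, T, n, O). Each agent becomes one token of
        # its own obs concatenated with the broadcast global state.
        B, T, n, _ = obs.shape
        s = state.unsqueeze(2).expand(-1, -1, n, -1)
        tok = self.embed(th.cat([s, obs], dim=-1)).reshape(B * T, n, -1)
        tok = self.encoder(tok)                      # attention across agents
        return self.v_head(tok).reshape(B, T, n, 1)  # per-agent value
```

Attention lets each agent's value depend on every teammate's features, at the cost of more parameters and a harder optimization problem, which is consistent with the mixed `mat_critic` baseline results below.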
## Evaluation

Final performance is measured by test win rate (`battle_won_mean`) averaged over 32 test episodes with the greedy policy, evaluated separately on all three SMAC maps and recorded to the leaderboard under setup-specific metric keys:

- Primary: `test_battle_won_mean_<map>`
- Secondary: `test_return_mean_<map>`
Higher is better. A strong centralized critic should generalize across maps of varying difficulty.
## Code

```python
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


# ── Custom imports (editable) ────────────────────────────────────────────


# ======================================================================
# EDITABLE — Custom centralized critic for MAPPO
# ======================================================================
class CustomCritic(nn.Module):
    """Centralized critic for MAPPO on SMAC (via smaclite).
```
Additional context files (read-only):

- `epymarl/src/modules/critics/centralV.py`
- `epymarl/src/modules/critics/ac.py`
- `epymarl/src/learners/ppo_learner.py`
## Results
| Model | Type | test return mean mmm ↑ | test return std mmm ↑ | test battle won mean mmm ↑ | test return mean 2s3z ↑ | test return std 2s3z ↑ | test battle won mean 2s3z ↑ | test return mean 3s5z ↑ | test return std 3s5z ↑ | test battle won mean 3s5z ↑ | test return mean 2s_vs_1sc ↑ | test return std 2s_vs_1sc ↑ | test battle won mean 2s_vs_1sc ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ippo_critic | baseline | 21.725 | 4.144 | 0.635 | 15.749 | 2.885 | 0.448 | 18.188 | 2.603 | 0.635 | - | - | - |
| ippo_critic | baseline | - | - | - | 13.849 | 3.267 | 0.271 | 15.423 | 3.266 | 0.354 | 9.975 | 1.113 | 0.010 |
| ippo_critic | baseline | - | - | - | 15.749 | 2.885 | 0.448 | 18.188 | 2.603 | 0.635 | 9.975 | 1.113 | 0.010 |
| ippo_critic | baseline | - | - | - | 4.476 | 1.006 | 0.000 | 4.645 | 0.845 | 0.000 | - | - | - |
| ippo_critic | baseline | 21.725 | 4.144 | 0.635 | - | - | - | - | - | - | - | - | - |
| mappo_critic | baseline | 22.819 | 1.846 | 0.927 | 19.106 | 2.019 | 0.833 | 18.618 | 2.435 | 0.740 | - | - | - |
| mappo_critic | baseline | - | - | - | 19.068 | 2.212 | 0.833 | 17.242 | 2.721 | 0.469 | 10.033 | 0.036 | 0.000 |
| mappo_critic | baseline | - | - | - | 19.106 | 2.019 | 0.833 | 18.618 | 2.435 | 0.740 | 10.033 | 0.036 | 0.000 |
| mappo_critic | baseline | - | - | - | 4.461 | 1.116 | 0.000 | 4.595 | 0.883 | 0.000 | - | - | - |
| mappo_critic | baseline | 22.819 | 1.846 | 0.927 | - | - | - | - | - | - | - | - | - |
| mat_critic | baseline | 18.948 | 1.854 | 0.135 | 14.182 | 2.707 | 0.542 | 15.129 | 1.618 | 0.115 | - | - | - |
| mat_critic | baseline | - | - | - | 13.800 | 3.090 | 0.500 | 13.600 | 1.155 | 0.000 | 12.978 | 1.417 | 0.281 |
| mat_critic | baseline | - | - | - | 14.182 | 2.707 | 0.542 | 15.129 | 1.618 | 0.115 | 12.978 | 1.417 | 0.281 |
| mat_critic | baseline | - | - | - | 4.487 | 0.999 | 0.000 | 4.536 | 0.874 | 0.000 | - | - | - |
| mat_critic | baseline | 18.948 | 1.854 | 0.135 | - | - | - | - | - | - | - | - | - |
| claude-opus-4.6 | vanilla | 22.602 | 2.976 | 0.906 | 10.328 | 1.126 | 0.000 | 7.585 | 1.039 | 0.000 | - | - | - |
| deepseek-reasoner | vanilla | 22.736 | 1.910 | 0.938 | 18.024 | 3.028 | 0.688 | 18.918 | 2.207 | 0.781 | - | - | - |
| gemini-3.1-pro-preview | vanilla | 22.515 | 2.172 | 0.875 | 19.061 | 2.218 | 0.844 | 18.008 | 2.831 | 0.656 | - | - | - |
| qwen3.6-plus | vanilla | - | - | - | - | - | - | - | - | - | - | - | - |
| claude-opus-4.6 | agent | 23.414 | 1.298 | 0.969 | 19.178 | 2.895 | 0.906 | 17.099 | 3.255 | 0.500 | - | - | - |
| deepseek-reasoner | agent | 22.736 | 1.910 | 0.938 | 18.024 | 3.028 | 0.688 | 18.918 | 2.207 | 0.781 | - | - | - |
| gemini-3.1-pro-preview | agent | 23.093 | 1.880 | 0.969 | 19.518 | 1.658 | 0.906 | 16.782 | 2.813 | 0.375 | - | - | - |
| qwen3.6-plus | agent | 23.039 | 0.567 | 1.000 | 19.164 | 2.614 | 0.906 | 19.137 | 2.114 | 0.844 | - | - | - |