Agent Conversation: claude-opus-4.6
Task: marl-centralized-critic
Cooperative MARL: Centralized Critic Architecture for MAPPO
Objective
Improve cooperative multi-agent reinforcement learning by designing a better centralized critic architecture for MAPPO (Multi-Agent PPO). You can modify the CustomCritic class (lines 13-69) and add custom imports (lines 7-8) in custom_critic.py.
Background
In cooperative MARL with partial observability, each agent sees only a local observation, but the team shares a common reward. Centralized-Training-with-Decentralized-Execution (CTDE) methods train a centralized value function (one that can see the global state and all agents' information) and use it to reduce variance when computing advantages for each agent's decentralized policy-gradient update. The architecture of this centralized critic — what it conditions on and how it mixes per-agent features — directly determines how tight the bias-variance tradeoff is, and therefore how well MAPPO scales to hard multi-agent cooperation tasks.
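The variance-reduction step above can be made concrete with a short sketch: standard Generalized Advantage Estimation (GAE) computed from a centralized critic's value estimates. This is a minimal illustration with hypothetical tensor shapes, not EPyMARL's actual learner code; note that the centralized values are needed only here, at training time, so execution stays decentralized.

```python
import torch

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """GAE from a centralized critic (illustrative shapes).

    rewards: (B, T)   shared team reward per timestep
    values:  (B, T+1) centralized V(s_t), with a bootstrap value appended
    Returns: (B, T)   advantages used in each agent's PPO update
    """
    B, T = rewards.shape
    adv = torch.zeros(B, T)
    last = torch.zeros(B)
    # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(T)):
        delta = rewards[:, t] + gamma * values[:, t + 1] - values[:, t]
        last = delta + gamma * lam * last
        adv[:, t] = last
    return adv
```

With gamma = lam = 1, a constant reward of 1, and zero values, the advantage at step t is simply the number of remaining steps, which is a quick sanity check on the recursion.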
The training uses EPyMARL's ppo_learner with the MAPPO default hyperparameters on three SMAC maps via smaclite (a pure-Python reimplementation of the StarCraft Multi-Agent Challenge benchmark — no StarCraft II binary required):
- mmm — 1 Medivac + 2 Marauders + 7 Marines (team of 10) vs mirror; heterogeneous cooperation requiring heal micro (≈5M env steps).
- 2s3z — 2 Stalkers + 3 Zealots (team of 5) vs mirror; medium heterogeneous team (≈5M env steps).
- 3s5z — 3 Stalkers + 5 Zealots (team of 8) vs mirror; larger team, harder (≈5M env steps).
The default critic is a simple (state ⊕ agent-one-hot) → 3-layer MLP → V that ignores per-agent observations. It is a working baseline but leaves room for smarter architectures that integrate per-agent features, attention, or state conditioning.
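As a point of reference, the default critic described above can be sketched in a few lines of PyTorch. This is a minimal sketch with illustrative names and shapes, not EPyMARL's exact module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DefaultMAPPOCritic(nn.Module):
    """Baseline: (state ⊕ agent-one-hot) → 3-layer MLP → V."""

    def __init__(self, state_dim, n_agents, hidden_dim=64):
        super().__init__()
        self.n_agents = n_agents
        self.fc1 = nn.Linear(state_dim + n_agents, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        # state: (B, T, state_dim) — broadcast to every agent, tag with one-hot id
        B, T, _ = state.shape
        s = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        ids = torch.eye(self.n_agents).expand(B, T, -1, -1)
        x = torch.cat([s, ids], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # (B, T, n_agents, 1)
```

Note that per-agent observations never enter this network; only the one-hot agent index distinguishes the n value outputs, which is exactly the gap the custom architectures below try to close.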
Interface
Your CustomCritic must:
- Inherit from `nn.Module`.
- Accept `(scheme, args)` in `__init__`, where:
  - `scheme["state"]["vshape"]` — global state dim
  - `scheme["obs"]["vshape"]` — per-agent observation dim
  - `args.n_agents`, `args.n_actions`, `args.hidden_dim`, `args.obs_last_action`, `args.obs_individual_obs`
- Set `self.output_type = "v"` in `__init__`.
- Implement `forward(self, batch, t=None)`, where:
  - `batch["state"]` — shape `(B, T, state_dim)`
  - `batch["obs"]` — shape `(B, T, n_agents, obs_dim)`
  - `batch.batch_size`, `batch.max_seq_length`, `batch.device`
  - `t=None` means "whole sequence"; otherwise `t` is an integer
- Return `q` with shape `(B, T, n_agents, 1)` — the learner later does `.squeeze(3)`, so the trailing singleton is mandatory.
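Putting the contract together, the smallest critic that satisfies this interface looks like the following. This is a hedged sketch: `_Args`-style objects, the `scheme` dict, and the batch object are stand-ins carrying only the fields listed above, not EPyMARL's actual classes.

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class CustomCritic(nn.Module):
    """Minimal interface-conforming critic: state + agent-one-hot → MLP → V."""

    def __init__(self, scheme, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.output_type = "v"  # required by the ppo_learner
        state_dim = int(scheme["state"]["vshape"])
        H = args.hidden_dim
        self.fc1 = nn.Linear(state_dim + self.n_agents, H)
        self.fc2 = nn.Linear(H, 1)

    def forward(self, batch, t=None):
        bs = batch.batch_size
        max_t = batch.max_seq_length if t is None else 1
        ts = slice(None) if t is None else slice(t, t + 1)
        state = batch["state"][:, ts]                                # (B, T, state_dim)
        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, state_dim)
        ids = th.eye(self.n_agents, device=batch.device)
        ids = ids.expand(bs, max_t, -1, -1)                          # (B, T, n, n)
        x = F.relu(self.fc1(th.cat([state, ids], dim=-1)))
        return self.fc2(x)                                           # (B, T, n, 1)
```

The `t=None` / integer-`t` branching is the part that is easiest to get wrong: both paths must return a 4-D tensor, with the time dimension of size 1 when a single step is requested.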
Reference Implementations
- IPPO critic (`ippo_critic.edit.py`): per-agent MLP over `batch["obs"]` ⊕ agent-one-hot; no centralization. Floor baseline from Yu et al. 2022's IPPO ablation. Also see `epymarl/src/modules/critics/ac.py`.
- MAPPO critic (`mappo_critic.edit.py`): shared MLP over `batch["state"]` ⊕ agent-one-hot. Standard MAPPO critic.
Tool Input
{
  "op": "replace",
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "start_line": 7,
  "end_line": 8,
  "content": "# ── Custom imports (editable) ────────────────────────────────────────────\n"
}
OK: Replaced lines 7..8 with 1 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–68.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–68 | total: 68 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
...
12 class CustomCritic(nn.Module):
13 """Centralized critic for MAPPO on SMAC (via smaclite).
14
... (51 more lines) ...
66 x = F.relu(self.fc2(x))
67 q = self.fc3(x) # (B, T, n, 1)
68 return q
Tool Input
{
  "op": "replace",
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "start_line": 12,
  "end_line": 68,
  "content": "class CustomCritic(nn.Module):\n \"\"\"Observation-Gated Centralized Critic (OGC²) for MAPPO.\n\n Algorithmic contribution: Feature-wise Linear Modulation (FiLM) fusion\n of global state and per-agent observations.\n\n In cooperative MARL, each agent's value depends on both the global team\n state AND the agent's local situation. The same global state has different\n value implications for different agents depending on their local context\n (e.g., a low-health unit should value retreat more than a full-health one).\n\n Standard MAPPO ignores per-agent obs in the critic (state + agent_id only).\n Simple concatenation of state + obs treats them symmetrically. Instead, we\n use FiLM conditioning: each agent's obs encoder produces per-feature scale\n (gamma) and shift (beta) that multiplicatively modulate the state\n representation. This creates agent-specific \"lenses\" through which to\n interpret the global state, capturing the interaction between global\n context and local situation.\n\n Architecture:\n state → Linear → ReLU → h_s (hidden state features)\n obs+id → Linear → ReLU → h_o (hidden obs features)\n h_o → Linear → Sigmoid → γ (per-feature gate)\n h_o → Linear → β (per-feature shift)\n modulated = γ ⊙ h_s + β + h_o (FiLM + residual)\n LayerNorm(modulated) → Linear → ReLU → Linear → V\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n H = args.hidden_dim\n\n # Dual-stream first-layer encoders\n self.state_fc = nn.Linear(state_dim, H)\n self.obs_fc = nn.Linear(obs_dim + self.n_agents, H)\n\n # FiLM generators: obs features → per-feature scale & shift for state\n self.film_gamma = nn.Linear(H, H) # → sigmoid → multiplicative gate\n self.film_beta = nn.Linear(H, H) # → additive shift\n\n # Layer norm after modulation for training stability\n self.ln = nn.LayerNorm(H)\n\n # Value head (shared across agents)\n self.fc2 = nn.Linear(H, H)\n self.fc3 = nn.Linear(H, 1)\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n # Extract inputs\n obs = batch[\"obs\"][:, ts] # (B, T, n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, state_dim)\n\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)\n\n # Encode state stream\n h_s = F.relu(self.state_fc(state)) # (B, T, n, H)\n\n # Encode obs stream (obs + agent one-hot)\n obs_input = th.cat([obs, agent_id], dim=-1) # (B, T, n, obs+n)\n h_o = F.relu(self.obs_fc(obs_input)) # (B, T, n, H)\n\n # FiLM: obs-conditioned gating of state features\n gamma = th.sigmoid(self.film_gamma(h_o)) # (B, T, n, H) ∈ (0,1)\n beta = self.film_beta(h_o) # (B, T, n, H)\n\n # Modulate state features + residual from obs stream\n modulated = gamma * h_s + beta + h_o # (B, T, n, H)\n modulated = self.ln(modulated)\n\n # Value head\n x = F.relu(self.fc2(modulated)) # (B, T, n, H)\n q = self.fc3(x) # (B, T, n, 1)\n return q"
}
OK: Replaced lines 12..68 with 86 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–97.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–97 | total: 97 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
...
12 class CustomCritic(nn.Module):
13 """Observation-Gated Centralized Critic (OGC²) for MAPPO.
14
... (80 more lines) ...
95 x = F.relu(self.fc2(modulated)) # (B, T, n, H)
96 q = self.fc3(x) # (B, T, n, 1)
97 return q
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004408 return_mean=23.0683 return_std=1.9680 battle_won_mean=0.9492
TRAIN_METRICS t_env=5015062 return_mean=23.0563 return_std=1.7606 battle_won_mean=0.9492
TRAIN_METRICS t_env=5025185 return_mean=23.2357 return_std=1.9818 battle_won_mean=0.9420
TRAIN_METRICS t_env=5036526 return_mean=22.9965 return_std=2.1114 battle_won_mean=0.9297
TRAIN_METRICS t_env=5047493 return_mean=22.9963 return_std=2.2469 battle_won_mean=0.9375
Final test for mmm (t_env=5032206): Win rate: 0.9062 Mean return: 22.6018 +/- 2.9762
### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4990619 return_mean=10.2005 return_std=0.7899 battle_won_mean=0.0000
TRAIN_METRICS t_env=5005019 return_mean=10.0051 return_std=0.9426 battle_won_mean=0.0000
TRAIN_METRICS t_env=5019419 return_mean=10.0766 return_std=0.7336 battle_w
Tool Input
{
  "content": "class CustomCritic(nn.Module):\n \"\"\"Observation-Gated Centralized Critic (OGC²) for MAPPO.\n\n Algorithmic contribution: Dual-path critic with observation-conditioned\n gating of global state features, plus a stabilizing residual path.\n\n In cooperative MARL, each agent's value depends on both the global team\n state AND the agent's local situation. Standard MAPPO uses only\n state + agent_id, ignoring per-agent observations in the critic.\n\n Our approach: the agent's local observation generates a soft gate that\n selectively amplifies or suppresses features of the global state encoding.\n This lets each agent extract value-relevant information from the shared\n state based on its own situation (e.g., a low-health agent values\n retreat states differently than a full-health one).\n\n Two paths are combined:\n 1. Gated path: obs → gate → element-wise modulate state features\n 2. Direct path: state + agent_id → standard encoding (stabilizer)\n The combination provides both novelty (gating) and stability (direct path).\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n H = args.hidden_dim\n\n # ── Direct path: standard MAPPO-style (state + agent_id) ──\n self.direct_fc1 = nn.Linear(state_dim + self.n_agents, H)\n\n # ── Gated path: obs-conditioned modulation of state ──\n # State encoder\n self.state_fc = nn.Linear(state_dim, H)\n # Obs encoder (generates gate)\n self.obs_fc = nn.Linear(obs_dim + self.n_agents, H)\n # Gate projection\n self.gate_fc = nn.Linear(H, H)\n\n # ── Fusion: combine both paths ──\n self.fc2 = nn.Linear(H, H)\n self.fc3 = nn.Linear(H, 1)\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n # Extract inputs\n obs = batch[\"obs\"][:, ts] # (B, T, n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, state_dim)\n\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)\n\n # ── Direct path (stabilizer): standard state + agent_id ──\n direct_input = th.cat([state, agent_id], dim=-1) # (B, T, n, state+n)\n h_direct = F.relu(self.direct_fc1(direct_input)) # (B, T, n, H)\n\n # ── Gated path: obs-conditioned state modulation ──\n h_s = F.relu(self.state_fc(state)) # (B, T, n, H)\n obs_input = th.cat([obs, agent_id], dim=-1) # (B, T, n, obs+n)\n h_o = F.relu(self.obs_fc(obs_input)) # (B, T, n, H)\n gate = th.sigmoid(self.gate_fc(h_o)) # (B, T, n, H) ∈ (0,1)\n h_gated = gate * h_s # (B, T, n, H)\n\n # ── Combine: gated features + direct path ──\n combined = h_direct + h_gated # (B, T, n, H)\n\n # ── Value head ──\n x = F.relu(self.fc2(combined)) # (B, T, n, H)\n q = self.fc3(x) # (B, T, n, 1)\n return q",
  "end_line": 97,
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "op": "replace",
  "start_line": 12
}
OK: Replaced lines 12..97 with 79 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–90.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–90 | total: 90 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
...
12 class CustomCritic(nn.Module):
13 """Observation-Gated Centralized Critic (OGC²) for MAPPO.
14
... (73 more lines) ...
88 x = F.relu(self.fc2(combined)) # (B, T, n, H)
89 q = self.fc3(x) # (B, T, n, 1)
90 return q
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4997963 return_mean=23.2905 return_std=1.0618 battle_won_mean=0.9826
TRAIN_METRICS t_env=5008942 return_mean=23.3561 return_std=1.0968 battle_won_mean=0.9931
TRAIN_METRICS t_env=5019099 return_mean=23.2060 return_std=1.5927 battle_won_mean=0.9688
TRAIN_METRICS t_env=5029968 return_mean=23.1149 return_std=1.5201 battle_won_mean=0.9792
TRAIN_METRICS t_env=5041122 return_mean=23.0837 return_std=1.6447 battle_won_mean=0.9653
Final test for mmm (t_env=5012717): Win rate: 0.9688 Mean return: 23.4138 +/- 1.2980
### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004410 return_mean=19.4788 return_std=1.9169 battle_won_mean=0.9062
TRAIN_METRICS t_env=5014946 return_mean=19.7606 return_std=1.5216 battle_won_mean=0.9531
TRAIN_METRICS t_env=5025401 return_mean=19.7544 return_std=1.4147 battle_wo
Tool Input
{
  "n": 2
}
[submit] Submitting result from test #2 as final.
### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4997963 return_mean=23.2905 return_std=1.0618 battle_won_mean=0.9826
TRAIN_METRICS t_env=5008942 return_mean=23.3561 return_std=1.0968 battle_won_mean=0.9931
TRAIN_METRICS t_env=5019099 return_mean=23.2060 return_std=1.5927 battle_won_mean=0.9688
TRAIN_METRICS t_env=5029968 return_mean=23.1149 return_std=1.5201 battle_won_mean=0.9792
TRAIN_METRICS t_env=5041122 return_mean=23.0837 return_std=1.6447 battle_won_mean=0.9653
Final test for mmm (t_env=5012717): Win rate: 0.9688 Mean return: 23.4138 +/- 1.2980
### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004410 return_mean=19.4788 return_std=1.9169 battle_won_mean=0.9062
TRAIN_METRICS t_env=5014946 return_mean=19.7606 return_std=1.5216 battle_won_mean=0.9531
TRAIN_METRICS t_env=5025401 return_mean=19.7544 return_std=1.4147 battle_won_mean=0.9531
TRAIN_METRICS t_env=5036013 re