Agent Conversation: gemini-3.1-pro-preview
Task: marl-centralized-critic
Cooperative MARL: Centralized Critic Architecture for MAPPO
Objective
Improve cooperative multi-agent reinforcement learning by designing a better centralized critic architecture for MAPPO (Multi-Agent PPO). You can modify the CustomCritic class (lines 13-69) and add custom imports (lines 7-8) in custom_critic.py.
Background
In cooperative MARL with partial observability, each agent only sees a local observation but the team shares a common reward. Centralized-Training-with-Decentralized-Execution (CTDE) methods train a centralized value function during training (which can see the global state and all agents' information) and use it to reduce variance when computing advantages for each agent's decentralized policy gradient update. The architecture of this centralized critic — what it conditions on and how it mixes per-agent features — directly determines how tight the bias-variance tradeoff is and therefore how well MAPPO scales to hard multi-agent cooperation tasks.
The training uses EPyMARL's ppo_learner with the MAPPO default hyperparameters on three SMAC maps via smaclite (a pure-Python reimplementation of the StarCraft Multi-Agent Challenge benchmark — no StarCraft II binary required):
- mmm — 1 Medivac + 2 Marauders + 7 Marines (team of 10) vs mirror; heterogeneous cooperation requiring heal micro (≈5M env steps).
- 2s3z — 2 Stalkers + 3 Zealots (team of 5) vs mirror; medium heterogeneous team (≈5M env steps).
- 3s5z — 3 Stalkers + 5 Zealots (team of 8) vs mirror; larger team, harder (≈5M env steps).
The default critic is a simple (state ⊕ agent-one-hot) → 3-layer MLP → V that ignores per-agent observations. It is a working baseline but leaves room for smarter architectures that integrate per-agent features, attention, or state conditioning.
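The default critic described above can be sketched as follows. This is an illustrative reconstruction, not EPyMARL's exact class; the class name, constructor arguments, and default `hidden_dim` are chosen here for clarity.

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class DefaultStateCritic(nn.Module):
    """Sketch of the baseline: (state + agent-one-hot) -> 3-layer MLP -> V."""
    def __init__(self, state_dim, n_agents, hidden_dim=64):
        super().__init__()
        self.n_agents = n_agents
        self.fc1 = nn.Linear(state_dim + n_agents, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        # state: (B, T, state_dim); replicate per agent and append agent identity
        B, T, _ = state.shape
        s = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        ids = th.eye(self.n_agents).expand(B, T, -1, -1)
        x = th.cat([s, ids], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # (B, T, n_agents, 1): one value per agent slot
```

Note that the per-agent observations never enter the forward pass; the one-hot identity is the only thing that differentiates the n value outputs.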
Interface
Your CustomCritic must:
- Inherit from `nn.Module`.
- Accept `(scheme, args)` in `__init__`, where:
  - `scheme["state"]["vshape"]` — global state dim
  - `scheme["obs"]["vshape"]` — per-agent observation dim
  - `args.n_agents`, `args.n_actions`, `args.hidden_dim`, `args.obs_last_action`, `args.obs_individual_obs`
- Set `self.output_type = "v"` in `__init__`.
- Implement `forward(self, batch, t=None)` where:
  - `batch["state"]` — shape `(B, T, state_dim)`
  - `batch["obs"]` — shape `(B, T, n_agents, obs_dim)`
  - `batch.batch_size`, `batch.max_seq_length`, `batch.device`
  - `t=None` means "whole sequence"; otherwise `t` is an integer
- Return `q` with shape `(B, T, n_agents, 1)` — the learner later does `.squeeze(3)`, so the trailing singleton is mandatory.
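The interface above can be exercised without the full training stack. The sketch below shows the smallest conforming critic together with a hypothetical `MockBatch` stand-in for EPyMARL's episode batch (only the fields named in the interface are mocked); it is a shape check, not the real learner plumbing.

```python
from types import SimpleNamespace
import torch as th
import torch.nn as nn

class MinimalCustomCritic(nn.Module):
    """Smallest critic satisfying the interface: (state + one-hot) -> V."""
    def __init__(self, scheme, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.output_type = "v"  # required by the learner
        state_dim = int(scheme["state"]["vshape"])
        self.net = nn.Sequential(
            nn.Linear(state_dim + self.n_agents, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, 1),
        )

    def forward(self, batch, t=None):
        # t=None -> whole sequence; integer t -> single timestep kept as T=1
        ts = slice(None) if t is None else slice(t, t + 1)
        state = batch["state"][:, ts]                   # (B, T, state_dim)
        B, T, _ = state.shape
        s = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        ids = th.eye(self.n_agents, device=batch.device).expand(B, T, -1, -1)
        return self.net(th.cat([s, ids], dim=-1))       # (B, T, n_agents, 1)

class MockBatch:
    """Hypothetical stand-in for EPyMARL's EpisodeBatch (shape checks only)."""
    def __init__(self, B, T, state_dim):
        self.batch_size, self.max_seq_length = B, T
        self.device = "cpu"
        self._data = {"state": th.randn(B, T, state_dim)}
    def __getitem__(self, key):
        return self._data[key]
```

Checking both calling conventions confirms the mandatory trailing singleton:

```python
scheme = {"state": {"vshape": 12}, "obs": {"vshape": 6}}
args = SimpleNamespace(n_agents=4, n_actions=5, hidden_dim=32)
critic = MinimalCustomCritic(scheme, args)
batch = MockBatch(B=2, T=7, state_dim=12)
critic(batch).shape       # (2, 7, 4, 1)
critic(batch, t=3).shape  # (2, 1, 4, 1)
```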
Reference Implementations
- IPPO critic (`ippo_critic.edit.py`): per-agent MLP over `batch["obs"]` ⊕ agent-one-hot; no centralization. Floor baseline from Yu et al. 2022's IPPO ablation. Also see `epymarl/src/modules/critics/ac.py`.
- MAPPO critic (`mappo_critic.edit.py`): shared MLP over `(batch["state"] ⊕ agent-one-hot)`. Standard MAPPO critic.
Tool Input
Arguments: `op="replace"`, `filename="epymarl/src/modules/critics/custom_critic.py"`, `start_line=13`, `end_line=69`, `content=` the following:

```python
class CustomCritic(nn.Module):
    """Attentive Context-Gated Critic (ACGC).

    Novelty: Combines global state context with a shared representation of
    all agents' local observations via Self-Attention and Gating.
    1. Local observations are enriched via self-attention across agents.
    2. The global state and enriched local observations are fused using
       a learned gating mechanism.
    """
    def __init__(self, scheme, args):
        super(CustomCritic, self).__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.output_type = "v"

        self.state_dim = int(scheme["state"]["vshape"])
        self.obs_dim = int(scheme["obs"]["vshape"])
        self.hidden_dim = args.hidden_dim

        # State path
        self.state_fc = nn.Linear(self.state_dim, self.hidden_dim)

        # Obs path
        self.obs_fc = nn.Linear(self.obs_dim + self.n_agents, self.hidden_dim)

        # Inter-agent attention
        self.num_heads = 4 if self.hidden_dim % 4 == 0 else 1
        self.attn = nn.MultiheadAttention(embed_dim=self.hidden_dim,
                                          num_heads=self.num_heads,
                                          batch_first=True)
        self.layer_norm = nn.LayerNorm(self.hidden_dim)

        # Gating
        self.gate_fc = nn.Linear(self.hidden_dim * 2, self.hidden_dim)

        # Value head
        self.v_fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.v_fc2 = nn.Linear(self.hidden_dim, 1)

    def forward(self, batch, t=None):
        bs = batch.batch_size
        max_t = batch.max_seq_length if t is None else 1
        ts = slice(None) if t is None else slice(t, t + 1)

        state = batch["state"][:, ts]  # (B, T, state_dim)
        obs = batch["obs"][:, ts]      # (B, T, n, obs_dim)

        # Global state features
        state_feat = F.relu(self.state_fc(state))
        state_feat = state_feat.unsqueeze(2).expand(-1, -1, self.n_agents, -1)

        # Local observation features
        agent_id = th.eye(self.n_agents, device=batch.device)
        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)
        obs_id = th.cat([obs, agent_id], dim=-1)

        obs_feat = F.relu(self.obs_fc(obs_id))  # (B, T, n, H)

        # Inter-agent communication via self-attention
        b, tt, n, d = obs_feat.shape
        obs_seq = obs_feat.reshape(b * tt, n, d)
        attn_out, _ = self.attn(obs_seq, obs_seq, obs_seq)
        attn_out = attn_out.reshape(b, tt, n, d)

        # Residual connection and LayerNorm
        obs_feat = self.layer_norm(obs_feat + attn_out)

        # Gated fusion of global state and enriched local observations
        gate_input = th.cat([state_feat, obs_feat], dim=-1)
        gate = th.sigmoid(self.gate_fc(gate_input))

        fused = state_feat * gate + obs_feat * (1 - gate)

        # Value prediction
        x = F.relu(self.v_fc1(fused))
        q = self.v_fc2(x)  # (B, T, n, 1)
        return q
```

OK: Replaced lines 13..69 with 76 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–88.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–88 | total: 88 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8
...
13 class CustomCritic(nn.Module):
14 """Attentive Context-Gated Critic (ACGC).
15
... (70 more lines) ...
86 x = F.relu(self.v_fc1(fused))
87 q = self.v_fc2(x) # (B, T, n, 1)
88         return q

[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4998461 return_mean=22.6254 return_std=2.5535 battle_won_mean=0.8929
TRAIN_METRICS t_env=5009210 return_mean=23.1345 return_std=1.7403 battle_won_mean=0.9648
TRAIN_METRICS t_env=5020382 return_mean=22.9234 return_std=2.6265 battle_won_mean=0.9219
TRAIN_METRICS t_env=5031411 return_mean=22.8518 return_std=2.0662 battle_won_mean=0.9102
TRAIN_METRICS t_env=5042677 return_mean=22.8929 return_std=2.2725 battle_won_mean=0.9180
Final test for mmm (t_env=5024633): Win rate: 0.8750 Mean return: 22.5152 +/- 2.1721

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5006872 return_mean=18.7060 return_std=2.5474 battle_won_mean=0.7812
TRAIN_METRICS t_env=5017691 return_mean=18.6935 return_std=2.6697 battle_won_mean=0.7969
TRAIN_METRICS t_env=5028469 return_mean=18.9069 return_std=2.4953 battle_w
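The ACGC critic in test #1 runs `nn.MultiheadAttention` across teammates by folding the batch and time axes into one batch dimension. A shape-level sketch of that trick, with all dimensions chosen arbitrarily for illustration:

```python
import torch as th
import torch.nn as nn

B, T, n, H = 2, 4, 5, 32                     # batch, time, agents, hidden
obs_feat = th.randn(B, T, n, H)
attn = nn.MultiheadAttention(embed_dim=H, num_heads=4, batch_first=True)

seq = obs_feat.reshape(B * T, n, H)          # agents become the sequence axis
out, weights = attn(seq, seq, seq)           # self-attention over teammates
out = out.reshape(B, T, n, H)                # restore (B, T, n, H)
```

Each agent's enriched feature is a learned mixture over all n teammates at the same timestep; `weights` has one n-by-n attention map per folded (batch, time) element.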
Tool Input
Arguments: `op="replace"`, `filename="epymarl/src/modules/critics/custom_critic.py"`, `start_line=13`, `end_line=88`, `content=` the following:

```python
class CustomCritic(nn.Module):
    """Residual FiLM Centralized Critic.

    Novelty: Combines the global state and local observations using a
    Residual Feature-wise Linear Modulation (FiLM) block.
    The global state is processed in the main stream, while the agent's
    local observation and ID generate FiLM parameters to modulate the
    state features. A residual connection ensures stable learning.
    """
    def __init__(self, scheme, args):
        super(CustomCritic, self).__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.output_type = "v"

        self.state_dim = int(scheme["state"]["vshape"])
        self.obs_dim = int(scheme["obs"]["vshape"])
        self.hidden_dim = args.hidden_dim

        # State encoder
        self.state_encoder = nn.Linear(self.state_dim, self.hidden_dim)
        self.state_ln = nn.LayerNorm(self.hidden_dim)

        # Observation encoder (condition generator)
        self.condition_encoder = nn.Sequential(
            nn.Linear(self.obs_dim + self.n_agents, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim)
        )

        # Post-modulation network
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, 1)

    def forward(self, batch, t=None):
        bs = batch.batch_size
        max_t = batch.max_seq_length if t is None else 1
        ts = slice(None) if t is None else slice(t, t + 1)

        state = batch["state"][:, ts]
        obs = batch["obs"][:, ts]

        # 1. Encode global state
        state_feat = self.state_encoder(state)  # (B, T, H)
        state_feat = self.state_ln(state_feat)  # (B, T, H)
        state_feat = F.relu(state_feat)
        state_feat = state_feat.unsqueeze(2).expand(-1, -1, self.n_agents, -1)

        # 2. Generate FiLM parameters from local obs + agent id
        agent_id = th.eye(self.n_agents, device=batch.device)
        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)
        obs_id = th.cat([obs, agent_id], dim=-1)

        film_params = self.condition_encoder(obs_id)  # (B, T, n, 2H)
        gamma, beta = film_params.chunk(2, dim=-1)    # Each is (B, T, n, H)

        # 3. Apply FiLM modulation in a residual branch
        modulated = (1.0 + gamma) * state_feat + beta
        modulated = F.relu(modulated)

        x = self.fc2(modulated)

        # Residual connection
        out = F.relu(x + state_feat)

        # 4. Final value prediction
        q = self.fc3(out)  # (B, T, n, 1)
        return q
```

OK: Replaced lines 13..88 with 69 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–81.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–81 | total: 81 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8
...
13 class CustomCritic(nn.Module):
14 """Residual FiLM Centralized Critic.
15
... (63 more lines) ...
79 # 4. Final value prediction
80 q = self.fc3(out) # (B, T, n, 1)
81         return q

[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5005255 return_mean=22.5879 return_std=2.8705 battle_won_mean=0.8482
TRAIN_METRICS t_env=5015999 return_mean=22.5235 return_std=2.9371 battle_won_mean=0.8705
TRAIN_METRICS t_env=5026406 return_mean=23.0353 return_std=2.5413 battle_won_mean=0.9286
TRAIN_METRICS t_env=5037268 return_mean=22.4955 return_std=2.7873 battle_won_mean=0.8571
TRAIN_METRICS t_env=5047898 return_mean=22.5993 return_std=3.0267 battle_won_mean=0.8571
Final test for mmm (t_env=5041887): Win rate: 0.9375 Mean return: 23.2647 +/- 2.0570

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4997891 return_mean=19.0736 return_std=2.1038 battle_won_mean=0.8313
TRAIN_METRICS t_env=5009017 return_mean=18.9850 return_std=2.3000 battle_won_mean=0.8250
TRAIN_METRICS t_env=5019319 return_mean=18.8241 return_std=2.4343 battle_wo
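The FiLM block in test #2 reduces to a per-channel affine transform of the state features, with scale and shift produced per agent. A minimal numeric sketch (constant tensors chosen so the result is easy to verify by hand; in the critic they come from the conditioning network):

```python
import torch as th

B, T, n, H = 2, 4, 3, 8
state_feat = th.zeros(B, T, 1, H).expand(-1, -1, n, -1)  # shared global features
gamma = th.zeros(B, T, n, H)                             # per-agent scale offsets
beta = th.ones(B, T, n, H)                               # per-agent shifts

# FiLM: out = (1 + gamma) * h + beta, applied per feature channel.
# With gamma=0, beta=1 the modulation is identity-plus-one: all outputs are 1.
modulated = (1.0 + gamma) * state_feat + beta
```

The `1 + gamma` parameterization means a zero-initialized conditioner leaves the state stream untouched, which is the usual argument for its training stability.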
Tool Input
Arguments: `op="replace"`, `filename="epymarl/src/modules/critics/custom_critic.py"`, `start_line=13`, `end_line=81`, `content=` the following:

```python
class CustomCritic(nn.Module):
    """Observation-Gated Dual-Stream Critic.

    Novelty: Processes the global state and local observation in parallel.
    A learned Sigmoid gating mechanism, conditioned on both the global state
    and the local observation, filters the global state to extract only the
    features relevant to the specific agent. The filtered state and local
    observation are then concatenated to predict the value, ensuring stable
    and bounded gradient flow.
    """
    def __init__(self, scheme, args):
        super(CustomCritic, self).__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.output_type = "v"

        self.state_dim = int(scheme["state"]["vshape"])
        self.obs_dim = int(scheme["obs"]["vshape"])
        self.hidden_dim = args.hidden_dim

        # State path
        self.state_fc = nn.Linear(self.state_dim, self.hidden_dim)

        # Obs path
        self.obs_fc = nn.Linear(self.obs_dim + self.n_agents, self.hidden_dim)

        # Gate generator
        self.gate_fc = nn.Linear(self.hidden_dim * 2, self.hidden_dim)

        # Value head
        self.v_fc1 = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.v_fc2 = nn.Linear(self.hidden_dim, 1)

    def forward(self, batch, t=None):
        bs = batch.batch_size
        max_t = batch.max_seq_length if t is None else 1
        ts = slice(None) if t is None else slice(t, t + 1)

        state = batch["state"][:, ts]  # (B, T, state_dim)
        obs = batch["obs"][:, ts]      # (B, T, n, obs_dim)

        # Global state features
        state_feat = F.relu(self.state_fc(state))  # (B, T, H)
        state_feat = state_feat.unsqueeze(2).expand(-1, -1, self.n_agents, -1)

        # Local observation features
        agent_id = th.eye(self.n_agents, device=batch.device)
        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)
        obs_id = th.cat([obs, agent_id], dim=-1)

        obs_feat = F.relu(self.obs_fc(obs_id))  # (B, T, n, H)

        # Gating: Conditioned on both state and local obs
        gate_input = th.cat([state_feat, obs_feat], dim=-1)  # (B, T, n, 2H)
        gate = th.sigmoid(self.gate_fc(gate_input))          # (B, T, n, H)

        gated_state = state_feat * gate  # (B, T, n, H)

        # Combine
        combined = th.cat([gated_state, obs_feat], dim=-1)  # (B, T, n, 2H)

        # Value prediction
        x = F.relu(self.v_fc1(combined))
        q = self.v_fc2(x)  # (B, T, n, 1)
        return q
```

OK: Replaced lines 13..81 with 66 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–78.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–78 | total: 78 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8
...
13 class CustomCritic(nn.Module):
14 """Observation-Gated Dual-Stream Critic.
15
... (60 more lines) ...
76 x = F.relu(self.v_fc1(combined))
77 q = self.v_fc2(x) # (B, T, n, 1)
78         return q

[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=585882 return_mean=23.3215 return_std=2.0347 battle_won_mean=0.9420
TRAIN_METRICS t_env=595975 return_mean=23.5946 return_std=1.2139 battle_won_mean=0.9866
TRAIN_METRICS t_env=606674 return_mean=22.9022 return_std=2.4836 battle_won_mean=0.9152
TRAIN_METRICS t_env=617421 return_mean=23.2909 return_std=1.9091 battle_won_mean=0.9464
TRAIN_METRICS t_env=627457 return_mean=22.3624 return_std=3.3771 battle_won_mean=0.8333
Final test for mmm (t_env=609819): Win rate: 0.9688 Mean return: 23.0930 +/- 1.8796

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=737317 return_mean=18.0539 return_std=2.9534 battle_won_mean=0.6641
TRAIN_METRICS t_env=748375 return_mean=18.5244
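The sigmoid gate in test #3 produces per-feature weights strictly inside (0, 1), so the filtered state stream stays bounded relative to its input. A minimal sketch of the fusion, with hypothetical dimensions:

```python
import torch as th
import torch.nn as nn

B, T, n, H = 2, 4, 3, 16
state_feat = th.randn(B, T, n, H)            # broadcast global features
obs_feat = th.randn(B, T, n, H)              # per-agent local features
gate_fc = nn.Linear(2 * H, H)

# Gate conditioned on both streams; sigmoid keeps each weight in (0, 1)
gate = th.sigmoid(gate_fc(th.cat([state_feat, obs_feat], dim=-1)))
gated_state = state_feat * gate              # bounded filtering of the state
combined = th.cat([gated_state, obs_feat], dim=-1)   # (B, T, n, 2H)
```

Unlike the convex mix in test #1 (`state * gate + obs * (1 - gate)`), this variant concatenates the gated state with the raw observation features, so the local stream always reaches the value head unattenuated.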
Tool Input
Arguments: `n=3`

[submit] Submitting result from test #3 as final.

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=585882 return_mean=23.3215 return_std=2.0347 battle_won_mean=0.9420
TRAIN_METRICS t_env=595975 return_mean=23.5946 return_std=1.2139 battle_won_mean=0.9866
TRAIN_METRICS t_env=606674 return_mean=22.9022 return_std=2.4836 battle_won_mean=0.9152
TRAIN_METRICS t_env=617421 return_mean=23.2909 return_std=1.9091 battle_won_mean=0.9464
TRAIN_METRICS t_env=627457 return_mean=22.3624 return_std=3.3771 battle_won_mean=0.8333
Final test for mmm (t_env=609819): Win rate: 0.9688 Mean return: 23.0930 +/- 1.8796

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=737317 return_mean=18.0539 return_std=2.9534 battle_won_mean=0.6641
TRAIN_METRICS t_env=748375 return_mean=18.5244 return_std=2.7175 battle_won_mean=0.7500
TRAIN_METRICS t_env=759324 return_mean=18.6547 return_std=2.7177 battle_won_mean=0.7734
TRAIN_METRICS t_env=770312 return_mean=