Agent Conversation: deepseek-reasoner
Task: marl-centralized-critic
Cooperative MARL: Centralized Critic Architecture for MAPPO
Objective
Improve cooperative multi-agent reinforcement learning by designing a better centralized critic architecture for MAPPO (Multi-Agent PPO). You can modify the CustomCritic class (lines 13-69) and add custom imports (lines 7-8) in custom_critic.py.
Background
In cooperative MARL with partial observability, each agent sees only a local observation, but the team shares a common reward. Centralized-Training-with-Decentralized-Execution (CTDE) methods train a centralized value function during training (which can see the global state and all agents' information) and use it to reduce variance when computing advantages for each agent's decentralized policy-gradient update. What this centralized critic conditions on, and how it mixes per-agent features, directly determines the tightness of the bias-variance tradeoff, and therefore how well MAPPO scales to hard multi-agent cooperation tasks.
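The variance-reduction step above can be sketched with plain GAE over the shared team reward, using the centralized value function as the baseline. This is a sketch under assumed shapes, not EPyMARL's actual `ppo_learner` code; the function name `gae_advantages` is illustrative:

```python
import torch

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation with a centralized baseline.

    rewards: (T,) shared team reward; values: (T+1,) centralized V(s_t).
    The resulting (T,) advantage vector is broadcast to every agent's
    decentralized policy-gradient update.
    """
    T = rewards.shape[0]
    adv = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual against the centralized value estimate
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        adv[t] = gae
    return adv
```

A tighter critic lowers the variance of `delta`, which is exactly where the architecture choice pays off.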
The training uses EPyMARL's ppo_learner with the MAPPO default hyperparameters on three SMAC maps via smaclite (a pure-Python reimplementation of the StarCraft Multi-Agent Challenge benchmark — no StarCraft II binary required):
- mmm — 1 Medivac + 2 Marauders + 7 Marines (team of 10) vs mirror; heterogeneous cooperation requiring heal micro (≈5M env steps).
- 2s3z — 2 Stalkers + 3 Zealots (team of 5) vs mirror; medium heterogeneous team (≈5M env steps).
- 3s5z — 3 Stalkers + 5 Zealots (team of 8) vs mirror; larger team, harder (≈5M env steps).
The default critic is a simple (state ⊕ agent-one-hot) → 3-layer MLP → V that ignores per-agent observations. It is a working baseline but leaves room for smarter architectures that integrate per-agent features, attention, or state conditioning.
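That default can be sketched as follows (assumed class and argument names; the real module lives under `epymarl/src/modules/critics/`):

```python
import torch
import torch.nn as nn

class DefaultCritic(nn.Module):
    """(state ⊕ agent-one-hot) -> 3-layer MLP -> V; per-agent obs are ignored."""

    def __init__(self, state_dim, n_agents, hidden_dim=64):
        super().__init__()
        self.n_agents = n_agents
        self.mlp = nn.Sequential(
            nn.Linear(state_dim + n_agents, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        # state: (B, T, state_dim) -> values: (B, T, n_agents, 1)
        B, T, _ = state.shape
        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        agent_id = torch.eye(self.n_agents, device=state.device).expand(B, T, -1, -1)
        return self.mlp(torch.cat([state, agent_id], dim=-1))
```

Every agent receives the same state; only the one-hot identity differentiates the values, which is why richer per-agent conditioning is the obvious upgrade path.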
Interface
Your CustomCritic must:
- Inherit from `nn.Module`.
- Accept `(scheme, args)` in `__init__`, where:
  - `scheme["state"]["vshape"]` — global state dim
  - `scheme["obs"]["vshape"]` — per-agent observation dim
  - `args.n_agents`, `args.n_actions`, `args.hidden_dim`, `args.obs_last_action`, `args.obs_individual_obs`
- Set `self.output_type = "v"` in `__init__`.
- Implement `forward(self, batch, t=None)`, where:
  - `batch["state"]` — shape `(B, T, state_dim)`
  - `batch["obs"]` — shape `(B, T, n_agents, obs_dim)`
  - `batch.batch_size`, `batch.max_seq_length`, `batch.device` are available
  - `t=None` means "whole sequence"; otherwise `t` is an integer timestep
- Return `q` with shape `(B, T, n_agents, 1)` — the learner later does `.squeeze(3)`, so the trailing singleton is mandatory.
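A minimal skeleton satisfying this contract might look as follows (a sketch: it conditions only on the state, like the default; the names `SkeletonCritic` and `net` are illustrative):

```python
import torch as th
import torch.nn as nn

class SkeletonCritic(nn.Module):
    """Smallest conforming critic: (state ⊕ agent-one-hot) -> V."""

    def __init__(self, scheme, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.output_type = "v"  # required by the learner
        state_dim = int(scheme["state"]["vshape"])
        self.net = nn.Sequential(
            nn.Linear(state_dim + self.n_agents, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, 1),
        )

    def forward(self, batch, t=None):
        ts = slice(None) if t is None else slice(t, t + 1)  # t=None -> whole sequence
        state = batch["state"][:, ts]                       # (B, T, state_dim)
        B, T, _ = state.shape
        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        ids = th.eye(self.n_agents, device=state.device).expand(B, T, -1, -1)
        return self.net(th.cat([state, ids], dim=-1))       # (B, T, n_agents, 1)
```

A plain dict of tensors is enough to smoke-test the shape contract, even though the real `EpisodeBatch` also exposes `batch_size`, `max_seq_length`, and `device`.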
Reference Implementations
- IPPO critic (`ippo_critic.edit.py`): per-agent MLP over `batch["obs"]` ⊕ agent-one-hot; no centralization. Floor baseline from Yu et al. 2022's IPPO ablation. Also see `epymarl/src/modules/critics/ac.py`.
- MAPPO critic (`mappo_critic.edit.py`): shared MLP over `(batch["state"] ⊕ agent-one-hot)`. Standard MAPPO critic.
Tool Input
1{
2 "op": "replace",
3 "filename": "epymarl/src/modules/critics/custom_critic.py",
4 "start_line": 13,
5 "end_line": 69,
6 "content": "class CustomCritic(nn.Module):\n \"\"\"Gated Feature Fusion centralized critic for MAPPO.\n\n Combines per-agent local observations with global state via a learned gating\n mechanism. Each agent's value is computed as:\n local_feat = MLP(obs_i ⊕ agent_id ⊕ [last_action_i])\n global_feat = MLP(state) broadcast\n gate = sigmoid(MLP(local_feat ⊕ global_feat))\n fused = gate * local_feat + (1 - gate) * global_feat\n value = MLP(fused)\n\n This allows adaptive weighting between local and global information.\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n self.obs_dim = obs_dim\n self.state_dim = state_dim\n\n # Determine input dimensions\n local_input_dim = obs_dim + self.n_agents # obs + agent one-hot\n if args.obs_last_action:\n local_input_dim += scheme[\"actions_onehot\"][\"vshape\"][0]\n self.local_input_dim = local_input_dim\n\n # Local encoder\n self.local_encoder = nn.Sequential(\n nn.Linear(local_input_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n )\n # Global encoder (processes state)\n self.global_encoder = nn.Sequential(\n nn.Linear(state_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n )\n # Gating network\n self.gate_net = nn.Sequential(\n nn.Linear(2 * args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, 1),\n nn.Sigmoid(),\n )\n # Fusion MLP\n self.fusion_mlp = nn.Sequential(\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, 1),\n )\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n obs = batch[\"obs\"][:, ts] # (B, T, 
n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, state_dim)\n\n # Build local inputs\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1) # (B, T, n, n_agents)\n local_inputs = [obs, agent_id]\n\n if self.args.obs_last_action:\n # Incorporate last actions\n if t is None:\n # whole sequence: last_actions = previous step actions\n last_actions = th.cat(\n [th.zeros_like(batch[\"actions_onehot\"][:, 0:1]), batch[\"actions_onehot\"][:, :-1]],\n dim=1\n ) # (B, T, n, n_actions)\n last_actions = last_actions.view(bs, max_t, self.n_agents, -1)\n elif t == 0:\n last_actions = th.zeros_like(batch[\"actions_onehot\"][:, 0:1]).view(bs, 1, self.n_agents, -1)\n else:\n last_actions = batch[\"actions_onehot\"][:, t-1:t].view(bs, 1, self.n_agents, -1)\n local_inputs.append(last_actions)\n\n local_input = th.cat(local_inputs, dim=-1) # (B, T, n, local_input_dim)\n\n # Encode\n local_feat = self.local_encoder(local_input) # (B, T, n, hidden_dim)\n global_feat = self.global_encoder(state) # (B, T, n, hidden_dim)\n\n # Gating\n gate_input = th.cat([local_feat, global_feat], dim=-1)\n gate = self.gate_net(gate_input) # (B, T, n, 1)\n\n # Fuse\n fused = gate * local_feat + (1 - gate) * global_feat # (B, T, n, hidden_dim)\n\n # Output value\n q = self.fusion_mlp(fused) # (B, T, n, 1)\n return q"
7}

OK: Replaced lines 13..69 with 105 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–117.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–117 | total: 117 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8
...
13 class CustomCritic(nn.Module):
14 """Gated Feature Fusion centralized critic for MAPPO.
15
... (99 more lines) ...
115 # Output value
116 q = self.fusion_mlp(fused) # (B, T, n, 1)
117 return q

Tool Input
1{
2 "op": "replace",
3 "filename": "epymarl/src/modules/critics/custom_critic.py",
4 "start_line": 13,
5 "end_line": 117,
6 "content": "class CustomCritic(nn.Module):\n \"\"\"Gated Feature Fusion centralized critic for MAPPO.\n\n Combines per-agent local observations with global state via a learned gating\n mechanism. Each agent's value is computed as:\n local_feat = MLP(obs_i ⊕ agent_id ⊕ [last_action_i])\n global_feat = MLP(state ⊕ [flattened_observations])\n gate = sigmoid(MLP(local_feat ⊕ global_feat))\n fused = gate * local_feat + (1 - gate) * global_feat\n value = MLP(fused)\n\n This allows adaptive weighting between local and global information.\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n self.obs_dim = obs_dim\n self.state_dim = state_dim\n\n # Determine input dimensions\n local_input_dim = obs_dim + self.n_agents # obs + agent one-hot\n if args.obs_last_action:\n local_input_dim += scheme[\"actions_onehot\"][\"vshape\"][0]\n self.local_input_dim = local_input_dim\n\n global_input_dim = state_dim\n if args.obs_individual_obs:\n global_input_dim += obs_dim * self.n_agents\n self.global_input_dim = global_input_dim\n\n # Local encoder\n self.local_encoder = nn.Sequential(\n nn.Linear(local_input_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n )\n # Global encoder (processes state + possibly flattened observations)\n self.global_encoder = nn.Sequential(\n nn.Linear(global_input_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n )\n # Gating network\n self.gate_net = nn.Sequential(\n nn.Linear(2 * args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, 1),\n nn.Sigmoid(),\n )\n # Fusion MLP\n self.fusion_mlp = nn.Sequential(\n nn.Linear(args.hidden_dim, args.hidden_dim),\n nn.ReLU(),\n nn.Linear(args.hidden_dim, 1),\n 
)\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n obs = batch[\"obs\"][:, ts] # (B, T, n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n\n # Build local inputs\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1) # (B, T, n, n_agents)\n local_inputs = [obs, agent_id]\n\n if self.args.obs_last_action:\n # Incorporate last actions\n if t is None:\n # whole sequence: last_actions = previous step actions\n last_actions = th.cat(\n [th.zeros_like(batch[\"actions_onehot\"][:, 0:1]), batch[\"actions_onehot\"][:, :-1]],\n dim=1\n ) # (B, T, n, n_actions)\n last_actions = last_actions.view(bs, max_t, self.n_agents, -1)\n elif t == 0:\n last_actions = th.zeros_like(batch[\"actions_onehot\"][:, 0:1]).view(bs, 1, self.n_agents, -1)\n else:\n last_actions = batch[\"actions_onehot\"][:, t-1:t].view(bs, 1, self.n_agents, -1)\n local_inputs.append(last_actions)\n\n local_input = th.cat(local_inputs, dim=-1) # (B, T, n, local_input_dim)\n\n # Build global inputs\n global_inputs = [state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)] # (B, T, n, state_dim)\n if self.args.obs_individual_obs:\n # Flatten observations across agents\n flat_obs = obs.view(bs, max_t, -1) # (B, T, n*obs_dim)\n flat_obs = flat_obs.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, n*obs_dim)\n global_inputs.append(flat_obs)\n global_input = th.cat(global_inputs, dim=-1) # (B, T, n, global_input_dim)\n\n # Encode\n local_feat = self.local_encoder(local_input) # (B, T, n, hidden_dim)\n global_feat = self.global_encoder(global_input) # (B, T, n, hidden_dim)\n\n # Gating\n gate_input = th.cat([local_feat, global_feat], dim=-1)\n gate = self.gate_net(gate_input) # (B, T, n, 1)\n\n # Fuse\n fused = gate * local_feat + (1 - gate) * global_feat # (B, T, n, hidden_dim)\n\n # Output value\n q = 
self.fusion_mlp(fused) # (B, T, n, 1)\n return q"
7}

OK: Replaced lines 13..117 with 118 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–130.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–130 | total: 130 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8
...
13 class CustomCritic(nn.Module):
14 """Gated Feature Fusion centralized critic for MAPPO.
15
... (112 more lines) ...
128 # Output value
129 q = self.fusion_mlp(fused) # (B, T, n, 1)
130 return q

[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5002921 return_mean=23.0704 return_std=1.0577 battle_won_mean=0.9688
TRAIN_METRICS t_env=5012926 return_mean=23.1879 return_std=1.0573 battle_won_mean=0.9727
TRAIN_METRICS t_env=5023896 return_mean=23.1903 return_std=0.8815 battle_won_mean=0.9861
TRAIN_METRICS t_env=5034820 return_mean=23.0402 return_std=1.4911 battle_won_mean=0.9757
TRAIN_METRICS t_env=5044923 return_mean=23.0474 return_std=1.3343 battle_won_mean=0.9570
Final test for mmm (t_env=5021493): Win rate: 0.9375 Mean return: 22.7363 +/- 1.9098

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5007628 return_mean=18.9021 return_std=2.3894 battle_won_mean=0.8156
TRAIN_METRICS t_env=5018409 return_mean=18.7273 return_std=2.4549 battle_won_mean=0.7781
TRAIN_METRICS t_env=5028884 return_mean=19.0800 return_std=2.2289 battle_w
Tool Input
1{
2 "op": "replace",
3 "filename": "epymarl/src/modules/critics/custom_critic.py",
4 "start_line": 7,
5 "end_line": 8,
6 "content": "# ── Custom imports (editable) ────────────────────────────────────────────\nfrom torch.nn import TransformerEncoderLayer, TransformerEncoder"
7}

OK: Replaced lines 7..8 with 2 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–130.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–130 | total: 130 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8 from torch.nn import TransformerEncoderLayer, TransformerEncoder
...
13 class CustomCritic(nn.Module):
14 """Gated Feature Fusion centralized critic for MAPPO.
15
... (112 more lines) ...
128 # Output value
129 q = self.fusion_mlp(fused) # (B, T, n, 1)
130 return q

Tool Input
1{
2 "op": "replace",
3 "filename": "epymarl/src/modules/critics/custom_critic.py",
4 "start_line": 13,
5 "end_line": 130,
6 "content": "class CustomCritic(nn.Module):\n \"\"\"Multi-Agent Transformer with State Gating (MAT-SG) centralized critic.\n\n Per-agent tokens (obs_i + agent_id + [last_action_i]) are processed by a\n transformer encoder layer to enable inter-agent communication. A separate\n global state encoder produces a context feature. A gating mechanism\n adaptively combines the transformer output with the global feature before\n predicting per-agent values.\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n self.obs_dim = obs_dim\n self.state_dim = state_dim\n\n # Determine token dimension\n token_dim = obs_dim + self.n_agents # obs + agent one-hot\n if args.obs_last_action:\n token_dim += scheme[\"actions_onehot\"][\"vshape\"][0]\n self.token_dim = token_dim\n\n global_input_dim = state_dim\n if args.obs_individual_obs:\n global_input_dim += obs_dim * self.n_agents\n self.global_input_dim = global_input_dim\n\n self.d_model = args.hidden_dim\n\n # Per-agent token projection\n self.token_proj = nn.Sequential(\n nn.Linear(token_dim, self.d_model),\n nn.ReLU(),\n nn.Linear(self.d_model, self.d_model),\n nn.LayerNorm(self.d_model),\n )\n\n # Transformer encoder layer (self-attention across agents)\n self.transformer = TransformerEncoder(\n TransformerEncoderLayer(\n d_model=self.d_model,\n nhead=4,\n dim_feedforward=4 * self.d_model,\n dropout=0.0,\n batch_first=True,\n activation=\"gelu\",\n ),\n num_layers=1,\n )\n\n # Global state encoder\n self.global_encoder = nn.Sequential(\n nn.Linear(global_input_dim, self.d_model),\n nn.ReLU(),\n nn.Linear(self.d_model, self.d_model),\n nn.LayerNorm(self.d_model),\n )\n\n # Gating network\n self.gate_net = nn.Sequential(\n nn.Linear(2 * self.d_model, self.d_model),\n nn.ReLU(),\n 
nn.Linear(self.d_model, 1),\n nn.Sigmoid(),\n )\n\n # Value head\n self.v_head = nn.Sequential(\n nn.Linear(self.d_model, self.d_model),\n nn.ReLU(),\n nn.Linear(self.d_model, 1),\n )\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n obs = batch[\"obs\"][:, ts] # (B, T, n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n\n # Build token inputs\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1) # (B, T, n, n_agents)\n token_inputs = [obs, agent_id]\n\n if self.args.obs_last_action:\n # Incorporate last actions\n if t is None:\n # whole sequence: last_actions = previous step actions\n last_actions = th.cat(\n [th.zeros_like(batch[\"actions_onehot\"][:, 0:1]), batch[\"actions_onehot\"][:, :-1]],\n dim=1\n ) # (B, T, n, n_actions)\n last_actions = last_actions.view(bs, max_t, self.n_agents, -1)\n elif t == 0:\n last_actions = th.zeros_like(batch[\"actions_onehot\"][:, 0:1]).view(bs, 1, self.n_agents, -1)\n else:\n last_actions = batch[\"actions_onehot\"][:, t-1:t].view(bs, 1, self.n_agents, -1)\n token_inputs.append(last_actions)\n\n token_input = th.cat(token_inputs, dim=-1) # (B, T, n, token_dim)\n\n # Build global inputs\n global_inputs = [state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)] # (B, T, n, state_dim)\n if self.args.obs_individual_obs:\n # Flatten observations across agents\n flat_obs = obs.view(bs, max_t, -1) # (B, T, n*obs_dim)\n flat_obs = flat_obs.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, n*obs_dim)\n global_inputs.append(flat_obs)\n global_input = th.cat(global_inputs, dim=-1) # (B, T, n, global_input_dim)\n\n # Project tokens\n tokens = self.token_proj(token_input) # (B, T, n, d_model)\n\n # Transformer expects (batch, seq_len, d_model); we treat agents as sequence.\n # Flatten (B, T) dimension, apply transformer, then 
restore.\n b, tt, n, d = tokens.shape\n tokens = tokens.reshape(b * tt, n, d)\n trans_out = self.transformer(tokens) # (B*T, n, d_model)\n trans_out = trans_out.reshape(b, tt, n, d) # (B, T, n, d_model)\n\n # Global feature\n global_feat = self.global_encoder(global_input) # (B, T, n, d_model)\n\n # Gating\n gate_input = th.cat([trans_out, global_feat], dim=-1)\n gate = self.gate_net(gate_input) # (B, T, n, 1)\n\n # Fuse\n fused = gate * trans_out + (1 - gate) * global_feat # (B, T, n, d_model)\n\n # Value head\n q = self.v_head(fused) # (B, T, n, 1)\n return q"
7}

OK: Replaced lines 13..130 with 141 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–153.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–153 | total: 153 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8 from torch.nn import TransformerEncoderLayer, TransformerEncoder
...
13 class CustomCritic(nn.Module):
14 """Multi-Agent Transformer with State Gating (MAT-SG) centralized critic.
15
... (135 more lines) ...
151 # Value head
152 q = self.v_head(fused) # (B, T, n, 1)
153 return q

[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5007398 return_mean=23.1212 return_std=2.2774 battle_won_mean=0.9286
TRAIN_METRICS t_env=5018242 return_mean=23.1149 return_std=2.2768 battle_won_mean=0.9062
TRAIN_METRICS t_env=5028756 return_mean=22.6901 return_std=2.9503 battle_won_mean=0.8795
TRAIN_METRICS t_env=5039581 return_mean=22.4543 return_std=2.9675 battle_won_mean=0.8393
TRAIN_METRICS t_env=5050043 return_mean=22.6369 return_std=2.9312 battle_won_mean=0.8929
Final test for mmm (t_env=5033364): Win rate: 0.8750 Mean return: 22.3823 +/- 3.1649

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5009388 return_mean=18.3443 return_std=3.1750 battle_won_mean=0.7656
TRAIN_METRICS t_env=5019697 return_mean=18.3365 return_std=3.2921 battle_won_mean=0.7773
TRAIN_METRICS t_env=5030074 return_mean=18.2129 return_std=3.2439 battle_wo
Tool Input
1{
2 "op": "replace",
3 "filename": "epymarl/src/modules/critics/custom_critic.py",
4 "start_line": 13,
5 "end_line": 153,
6 "content": "class CustomCritic(nn.Module):\n \"\"\"FiLM-conditioned centralized critic for MAPPO.\n\n Global state (and optionally flattened observations) produces per-agent\n scale and shift parameters via a FiLM generator. These modulate the\n per-agent local features (obs + agent_id + last_action). The modulated\n features are then combined with the global feature via residual addition\n before predicting per-agent values.\n \"\"\"\n\n def __init__(self, scheme, args):\n super(CustomCritic, self).__init__()\n self.args = args\n self.n_agents = args.n_agents\n self.n_actions = args.n_actions\n self.output_type = \"v\"\n\n obs_dim = int(scheme[\"obs\"][\"vshape\"])\n state_dim = int(scheme[\"state\"][\"vshape\"])\n self.obs_dim = obs_dim\n self.state_dim = state_dim\n\n # Determine input dimensions\n local_input_dim = obs_dim + self.n_agents # obs + agent one-hot\n if args.obs_last_action:\n local_input_dim += scheme[\"actions_onehot\"][\"vshape\"][0]\n self.local_input_dim = local_input_dim\n\n global_input_dim = state_dim\n if args.obs_individual_obs:\n global_input_dim += obs_dim * self.n_agents\n self.global_input_dim = global_input_dim\n\n self.hidden_dim = args.hidden_dim\n\n # Local encoder\n self.local_encoder = nn.Sequential(\n nn.Linear(local_input_dim, self.hidden_dim),\n nn.ReLU(),\n nn.Linear(self.hidden_dim, self.hidden_dim),\n nn.LayerNorm(self.hidden_dim),\n )\n\n # Global encoder (produces base global feature)\n self.global_encoder = nn.Sequential(\n nn.Linear(global_input_dim, self.hidden_dim),\n nn.ReLU(),\n nn.Linear(self.hidden_dim, self.hidden_dim),\n nn.LayerNorm(self.hidden_dim),\n )\n\n # FiLM generator: from global feature to scale and shift\n self.film_generator = nn.Sequential(\n nn.Linear(self.hidden_dim, 2 * self.hidden_dim),\n nn.ReLU(),\n nn.Linear(2 * self.hidden_dim, 2 * self.hidden_dim),\n )\n\n # Fusion MLP (after FiLM modulation)\n self.fusion_mlp = nn.Sequential(\n nn.Linear(self.hidden_dim, self.hidden_dim),\n nn.ReLU(),\n 
nn.Linear(self.hidden_dim, 1),\n )\n\n def forward(self, batch, t=None):\n bs = batch.batch_size\n max_t = batch.max_seq_length if t is None else 1\n ts = slice(None) if t is None else slice(t, t + 1)\n\n obs = batch[\"obs\"][:, ts] # (B, T, n, obs_dim)\n state = batch[\"state\"][:, ts] # (B, T, state_dim)\n\n # Build local inputs\n agent_id = th.eye(self.n_agents, device=batch.device)\n agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1) # (B, T, n, n_agents)\n local_inputs = [obs, agent_id]\n\n if self.args.obs_last_action:\n # Incorporate last actions\n if t is None:\n # whole sequence: last_actions = previous step actions\n last_actions = th.cat(\n [th.zeros_like(batch[\"actions_onehot\"][:, 0:1]), batch[\"actions_onehot\"][:, :-1]],\n dim=1\n ) # (B, T, n, n_actions)\n last_actions = last_actions.view(bs, max_t, self.n_agents, -1)\n elif t == 0:\n last_actions = th.zeros_like(batch[\"actions_onehot\"][:, 0:1]).view(bs, 1, self.n_agents, -1)\n else:\n last_actions = batch[\"actions_onehot\"][:, t-1:t].view(bs, 1, self.n_agents, -1)\n local_inputs.append(last_actions)\n\n local_input = th.cat(local_inputs, dim=-1) # (B, T, n, local_input_dim)\n\n # Build global inputs\n global_inputs = [state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)] # (B, T, n, state_dim)\n if self.args.obs_individual_obs:\n # Flatten observations across agents\n flat_obs = obs.view(bs, max_t, -1) # (B, T, n*obs_dim)\n flat_obs = flat_obs.unsqueeze(2).expand(-1, -1, self.n_agents, -1) # (B, T, n, n*obs_dim)\n global_inputs.append(flat_obs)\n global_input = th.cat(global_inputs, dim=-1) # (B, T, n, global_input_dim)\n\n # Encode\n local_feat = self.local_encoder(local_input) # (B, T, n, hidden_dim)\n global_feat = self.global_encoder(global_input) # (B, T, n, hidden_dim)\n\n # FiLM modulation\n film_params = self.film_generator(global_feat) # (B, T, n, 2*hidden_dim)\n scale, shift = th.chunk(film_params, 2, dim=-1) # each (B, T, n, hidden_dim)\n modulated_local = scale 
* local_feat + shift\n\n # Residual combination\n fused = modulated_local + global_feat # (B, T, n, hidden_dim)\n\n # Output value\n q = self.fusion_mlp(fused) # (B, T, n, 1)\n return q"
7}

OK: Replaced lines 13..153 with 119 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–8, 13–131.
[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–8, 13–131 | total: 131 lines]
7 # ── Custom imports (editable) ────────────────────────────────────────────
8 from torch.nn import TransformerEncoderLayer, TransformerEncoder
...
13 class CustomCritic(nn.Module):
14 """FiLM-conditioned centralized critic for MAPPO.
15
... (113 more lines) ...
129 # Output value
130 q = self.fusion_mlp(fused) # (B, T, n, 1)
131 return q

[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=1126033 return_mean=23.1863 return_std=3.2733 battle_won_mean=0.8562
TRAIN_METRICS t_env=1136532 return_mean=23.4843 return_std=2.8687 battle_won_mean=0.9000
TRAIN_METRICS t_env=1146890 return_mean=23.6116 return_std=2.6975 battle_won_mean=0.9000
TRAIN_METRICS t_env=1157385 return_mean=23.5266 return_std=2.8912 battle_won_mean=0.9062
TRAIN_METRICS t_env=1168072 return_mean=23.7615 return_std=2.7397 battle_won_mean=0.9250
Final test for mmm (t_env=1123884): Win rate: 0.8438 Mean return: 22.9774 +/- 3.6075

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=1453959 return_mean=19.0537 return_std=2.2091 battle_won_mean=0.8359
TRAIN_METRICS t_env=1464075 return_mean=
Tool Input
1{
2 "n": 1
3}

[submit] Submitting result from test #1 as final.

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5002921 return_mean=23.0704 return_std=1.0577 battle_won_mean=0.9688
TRAIN_METRICS t_env=5012926 return_mean=23.1879 return_std=1.0573 battle_won_mean=0.9727
TRAIN_METRICS t_env=5023896 return_mean=23.1903 return_std=0.8815 battle_won_mean=0.9861
TRAIN_METRICS t_env=5034820 return_mean=23.0402 return_std=1.4911 battle_won_mean=0.9757
TRAIN_METRICS t_env=5044923 return_mean=23.0474 return_std=1.3343 battle_won_mean=0.9570
Final test for mmm (t_env=5021493): Win rate: 0.9375 Mean return: 22.7363 +/- 1.9098

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5007628 return_mean=18.9021 return_std=2.3894 battle_won_mean=0.8156
TRAIN_METRICS t_env=5018409 return_mean=18.7273 return_std=2.4549 battle_won_mean=0.7781
TRAIN_METRICS t_env=5028884 return_mean=19.0800 return_std=2.2289 battle_won_mean=0.8438
TRAIN_METRICS t_env=5039565 re