Agent Conversation: claude-opus-4.6

Model: claude-opus-4.6
Total steps: 6
Messages: 14
User — Step 0

Task: marl-centralized-critic

Cooperative MARL: Centralized Critic Architecture for MAPPO

Objective

Improve cooperative multi-agent reinforcement learning by designing a better centralized critic architecture for MAPPO (Multi-Agent PPO). You can modify the CustomCritic class (lines 13-69) and add custom imports (lines 7-8) in custom_critic.py.

Background

In cooperative MARL with partial observability, each agent only sees a local observation but the team shares a common reward. Centralized-Training-with-Decentralized-Execution (CTDE) methods train a centralized value function during training (which can see the global state and all agents' information) and use it to reduce variance when computing advantages for each agent's decentralized policy gradient update. The architecture of this centralized critic — what it conditions on and how it mixes per-agent features — directly determines how tight the bias-variance tradeoff is and therefore how well MAPPO scales to hard multi-agent cooperation tasks.
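The variance-reduction role of the centralized critic can be made concrete with a generalized advantage estimator (GAE) built from its per-step values, which each agent's decentralized policy gradient then consumes. This is a plain-Python illustration of the recursion, not EPyMARL's ppo_learner code (the function name and list-based interface are ours):

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """GAE(lambda) from per-step rewards and critic values.

    rewards: [r_0, ..., r_{T-1}] shared team rewards.
    values:  [V(s_0), ..., V(s_T)] from the centralized critic,
             including one bootstrap value for the final state.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # TD residual: error signal from the centralized value function
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Exponentially weighted sum of residuals; lam controls the
        # bias-variance tradeoff (lam=0: pure TD, lam=1: Monte Carlo)
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages
```

A better critic lowers the variance of `delta`, and hence of every advantage estimate derived from it.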

The training uses EPyMARL's ppo_learner with the MAPPO default hyperparameters on three SMAC maps via smaclite (a pure-Python reimplementation of the StarCraft Multi-Agent Challenge benchmark — no StarCraft II binary required):

  • mmm — 1 Medivac + 2 Marauders + 7 Marines (team of 10) vs mirror; heterogeneous cooperation requiring heal micro (≈5M env steps).
  • 2s3z — 2 Stalkers + 3 Zealots (team of 5) vs mirror; medium heterogeneous team (≈5M env steps).
  • 3s5z — 3 Stalkers + 5 Zealots (team of 8) vs mirror; larger team, harder (≈5M env steps).

The default critic is a simple (state ⊕ agent-one-hot) → 3-layer MLP → V that ignores per-agent observations. It is a working baseline but leaves room for smarter architectures that integrate per-agent features, attention, or state conditioning.
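For concreteness, that baseline can be sketched as below. This is a paraphrase of the described architecture, not EPyMARL's actual source; the class name and constructor arguments are illustrative:

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class DefaultCentralVCritic(nn.Module):
    """Baseline sketch: (global state ⊕ agent one-hot) → 3-layer MLP → V.
    Per-agent observations are never seen by this critic."""

    def __init__(self, state_dim, n_agents, hidden_dim):
        super().__init__()
        self.n_agents = n_agents
        self.fc1 = nn.Linear(state_dim + n_agents, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):  # state: (B, T, state_dim)
        B, T, _ = state.shape
        # Broadcast the shared global state to every agent slot
        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        # Agent identity enters only as a one-hot appended to the state
        agent_id = th.eye(self.n_agents, device=state.device)
        agent_id = agent_id.expand(B, T, -1, -1)
        x = th.cat([state, agent_id], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # (B, T, n_agents, 1)
```

The ⊕ above is plain concatenation; note the trailing singleton dimension, which the learner later squeezes.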

Interface

Your CustomCritic must:

  • Inherit from nn.Module.
  • Accept (scheme, args) in __init__, where:
    • scheme["state"]["vshape"] — global state dim
    • scheme["obs"]["vshape"] — per-agent observation dim
    • args.n_agents, args.n_actions, args.hidden_dim, args.obs_last_action, args.obs_individual_obs
  • Set self.output_type = "v" in __init__.
  • Implement forward(self, batch, t=None) where:
    • batch["state"] — shape (B, T, state_dim)
    • batch["obs"] — shape (B, T, n_agents, obs_dim)
    • batch.batch_size, batch.max_seq_length, batch.device
    • t=None means "whole sequence"; otherwise t is an integer
    • Returns q with shape (B, T, n_agents, 1) — the learner later calls .squeeze(3), so the trailing singleton dimension is mandatory.
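Taken together, a minimal interface-conforming CustomCritic might look like the following sketch. It reproduces the baseline behavior (state ⊕ agent one-hot) rather than adding anything new, and exercises only the fields listed above; the layer count and hidden sizes are illustrative:

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class CustomCritic(nn.Module):
    """Minimal sketch satisfying the required interface."""

    def __init__(self, scheme, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.output_type = "v"  # required: marks this as a V-critic
        state_dim = int(scheme["state"]["vshape"])
        H = args.hidden_dim
        self.fc1 = nn.Linear(state_dim + self.n_agents, H)
        self.fc2 = nn.Linear(H, 1)

    def forward(self, batch, t=None):
        bs = batch.batch_size
        # t=None means "score the whole sequence"; otherwise one timestep
        max_t = batch.max_seq_length if t is None else 1
        ts = slice(None) if t is None else slice(t, t + 1)
        state = batch["state"][:, ts]                                 # (B, T, state_dim)
        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)  # (B, T, n, state_dim)
        agent_id = th.eye(self.n_agents, device=batch.device)
        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)
        x = F.relu(self.fc1(th.cat([state, agent_id], dim=-1)))
        return self.fc2(x)  # (B, T, n, 1); trailing singleton is mandatory
```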

Reference Implementations

  • IPPO critic (ippo_critic.edit.py): per-agent MLP over batch["obs"] ⊕ agent-one-hot; no centralization. Floor baseline from Yu et al. 2022's IPPO ablation. Also see epymarl/src/modules/critics/ac.py.
  • MAPPO critic (mappo_critic.edit.py): shared MLP over (batch["state"] ⊕ agent-one-hot). Standard MAPPO critic.
Assistant — Step 1 (edit)
Tool Input
{
  "op": "replace",
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "start_line": 7,
  "end_line": 8,
  "content": "# ── Custom imports (editable) ────────────────────────────────────────────\n"
}
Tool Result — Step 1
OK: Replaced lines 7..8 with 1 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–68.

[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–68 | total: 68 lines]
     7  # ── Custom imports (editable) ────────────────────────────────────────────
...
    12  class CustomCritic(nn.Module):
    13      """Centralized critic for MAPPO on SMAC (via smaclite).
    14  
       ... (51 more lines) ...
    66          x = F.relu(self.fc2(x))
    67          q = self.fc3(x)                                                  # (B, T, n, 1)
    68          return q
Assistant — Step 2 (edit)
Tool Input
{
  "op": "replace",
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "start_line": 12,
  "end_line": 68,
  "content": "class CustomCritic(nn.Module):\n    \"\"\"Observation-Gated Centralized Critic (OGC²) for MAPPO.\n\n    Algorithmic contribution: Feature-wise Linear Modulation (FiLM) fusion\n    of global state and per-agent observations.\n\n    In cooperative MARL, each agent's value depends on both the global team\n    state AND the agent's local situation. The same global state has different\n    value implications for different agents depending on their local context\n    (e.g., a low-health unit should value retreat more than a full-health one).\n\n    Standard MAPPO ignores per-agent obs in the critic (state + agent_id only).\n    Simple concatenation of state + obs treats them symmetrically. Instead, we\n    use FiLM conditioning: each agent's obs encoder produces per-feature scale\n    (gamma) and shift (beta) that multiplicatively modulate the state\n    representation. This creates agent-specific \"lenses\" through which to\n    interpret the global state, capturing the interaction between global\n    context and local situation.\n\n    Architecture:\n      state  → Linear → ReLU → h_s (hidden state features)\n      obs+id → Linear → ReLU → h_o (hidden obs features)\n      h_o → Linear → Sigmoid → γ  (per-feature gate)\n      h_o → Linear           → β  (per-feature shift)\n      modulated = γ ⊙ h_s + β + h_o  (FiLM + residual)\n      LayerNorm(modulated) → Linear → ReLU → Linear → V\n    \"\"\"\n\n    def __init__(self, scheme, args):\n        super(CustomCritic, self).__init__()\n        self.args = args\n        self.n_agents = args.n_agents\n        self.n_actions = args.n_actions\n        self.output_type = \"v\"\n\n        obs_dim = int(scheme[\"obs\"][\"vshape\"])\n        state_dim = int(scheme[\"state\"][\"vshape\"])\n        H = args.hidden_dim\n\n        # Dual-stream first-layer encoders\n        self.state_fc = nn.Linear(state_dim, H)\n        self.obs_fc = nn.Linear(obs_dim + self.n_agents, H)\n\n        # FiLM generators: obs features → per-feature scale & shift for state\n        self.film_gamma = nn.Linear(H, H)   # → sigmoid → multiplicative gate\n        self.film_beta = nn.Linear(H, H)    # → additive shift\n\n        # Layer norm after modulation for training stability\n        self.ln = nn.LayerNorm(H)\n\n        # Value head (shared across agents)\n        self.fc2 = nn.Linear(H, H)\n        self.fc3 = nn.Linear(H, 1)\n\n    def forward(self, batch, t=None):\n        bs = batch.batch_size\n        max_t = batch.max_seq_length if t is None else 1\n        ts = slice(None) if t is None else slice(t, t + 1)\n\n        # Extract inputs\n        obs = batch[\"obs\"][:, ts]                                        # (B, T, n, obs_dim)\n        state = batch[\"state\"][:, ts]                                    # (B, T, state_dim)\n        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)     # (B, T, n, state_dim)\n\n        agent_id = th.eye(self.n_agents, device=batch.device)\n        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)\n\n        # Encode state stream\n        h_s = F.relu(self.state_fc(state))                               # (B, T, n, H)\n\n        # Encode obs stream (obs + agent one-hot)\n        obs_input = th.cat([obs, agent_id], dim=-1)                      # (B, T, n, obs+n)\n        h_o = F.relu(self.obs_fc(obs_input))                             # (B, T, n, H)\n\n        # FiLM: obs-conditioned gating of state features\n        gamma = th.sigmoid(self.film_gamma(h_o))                         # (B, T, n, H) ∈ (0,1)\n        beta = self.film_beta(h_o)                                       # (B, T, n, H)\n\n        # Modulate state features + residual from obs stream\n        modulated = gamma * h_s + beta + h_o                             # (B, T, n, H)\n        modulated = self.ln(modulated)\n\n        # Value head\n        x = F.relu(self.fc2(modulated))                                  # (B, T, n, H)\n        q = self.fc3(x)                                                  # (B, T, n, 1)\n        return q"
}
Tool Result — Step 2
OK: Replaced lines 12..68 with 86 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–97.

[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–97 | total: 97 lines]
     7  # ── Custom imports (editable) ────────────────────────────────────────────
...
    12  class CustomCritic(nn.Module):
    13      """Observation-Gated Centralized Critic (OGC²) for MAPPO.
    14  
       ... (80 more lines) ...
    95          x = F.relu(self.fc2(modulated))                                  # (B, T, n, H)
    96          q = self.fc3(x)                                                  # (B, T, n, 1)
    97          return q
Assistant — Step 3 (test)
Tool Result — Step 3
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004408 return_mean=23.0683 return_std=1.9680 battle_won_mean=0.9492
TRAIN_METRICS t_env=5015062 return_mean=23.0563 return_std=1.7606 battle_won_mean=0.9492
TRAIN_METRICS t_env=5025185 return_mean=23.2357 return_std=1.9818 battle_won_mean=0.9420
TRAIN_METRICS t_env=5036526 return_mean=22.9965 return_std=2.1114 battle_won_mean=0.9297
TRAIN_METRICS t_env=5047493 return_mean=22.9963 return_std=2.2469 battle_won_mean=0.9375
Final test for mmm (t_env=5032206):
  Win rate: 0.9062
  Mean return: 22.6018 +/- 2.9762

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4990619 return_mean=10.2005 return_std=0.7899 battle_won_mean=0.0000
TRAIN_METRICS t_env=5005019 return_mean=10.0051 return_std=0.9426 battle_won_mean=0.0000
TRAIN_METRICS t_env=5019419 return_mean=10.0766 return_std=0.7336 battle_w
Assistant — Step 4 (edit)
Tool Input
{
  "content": "class CustomCritic(nn.Module):\n    \"\"\"Observation-Gated Centralized Critic (OGC²) for MAPPO.\n\n    Algorithmic contribution: Dual-path critic with observation-conditioned\n    gating of global state features, plus a stabilizing residual path.\n\n    In cooperative MARL, each agent's value depends on both the global team\n    state AND the agent's local situation. Standard MAPPO uses only\n    state + agent_id, ignoring per-agent observations in the critic.\n\n    Our approach: the agent's local observation generates a soft gate that\n    selectively amplifies or suppresses features of the global state encoding.\n    This lets each agent extract value-relevant information from the shared\n    state based on its own situation (e.g., a low-health agent values\n    retreat states differently than a full-health one).\n\n    Two paths are combined:\n      1. Gated path:  obs → gate → element-wise modulate state features\n      2. Direct path:  state + agent_id → standard encoding (stabilizer)\n    The combination provides both novelty (gating) and stability (direct path).\n    \"\"\"\n\n    def __init__(self, scheme, args):\n        super(CustomCritic, self).__init__()\n        self.args = args\n        self.n_agents = args.n_agents\n        self.n_actions = args.n_actions\n        self.output_type = \"v\"\n\n        obs_dim = int(scheme[\"obs\"][\"vshape\"])\n        state_dim = int(scheme[\"state\"][\"vshape\"])\n        H = args.hidden_dim\n\n        # ── Direct path: standard MAPPO-style (state + agent_id) ──\n        self.direct_fc1 = nn.Linear(state_dim + self.n_agents, H)\n\n        # ── Gated path: obs-conditioned modulation of state ──\n        # State encoder\n        self.state_fc = nn.Linear(state_dim, H)\n        # Obs encoder (generates gate)\n        self.obs_fc = nn.Linear(obs_dim + self.n_agents, H)\n        # Gate projection\n        self.gate_fc = nn.Linear(H, H)\n\n        # ── Fusion: combine both paths ──\n        self.fc2 = nn.Linear(H, H)\n        self.fc3 = nn.Linear(H, 1)\n\n    def forward(self, batch, t=None):\n        bs = batch.batch_size\n        max_t = batch.max_seq_length if t is None else 1\n        ts = slice(None) if t is None else slice(t, t + 1)\n\n        # Extract inputs\n        obs = batch[\"obs\"][:, ts]                                        # (B, T, n, obs_dim)\n        state = batch[\"state\"][:, ts]                                    # (B, T, state_dim)\n        state = state.unsqueeze(2).expand(-1, -1, self.n_agents, -1)     # (B, T, n, state_dim)\n\n        agent_id = th.eye(self.n_agents, device=batch.device)\n        agent_id = agent_id.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)\n\n        # ── Direct path (stabilizer): standard state + agent_id ──\n        direct_input = th.cat([state, agent_id], dim=-1)                 # (B, T, n, state+n)\n        h_direct = F.relu(self.direct_fc1(direct_input))                 # (B, T, n, H)\n\n        # ── Gated path: obs-conditioned state modulation ──\n        h_s = F.relu(self.state_fc(state))                               # (B, T, n, H)\n        obs_input = th.cat([obs, agent_id], dim=-1)                      # (B, T, n, obs+n)\n        h_o = F.relu(self.obs_fc(obs_input))                             # (B, T, n, H)\n        gate = th.sigmoid(self.gate_fc(h_o))                             # (B, T, n, H) ∈ (0,1)\n        h_gated = gate * h_s                                             # (B, T, n, H)\n\n        # ── Combine: gated features + direct path ──\n        combined = h_direct + h_gated                                    # (B, T, n, H)\n\n        # ── Value head ──\n        x = F.relu(self.fc2(combined))                                   # (B, T, n, H)\n        q = self.fc3(x)                                                  # (B, T, n, 1)\n        return q",
  "end_line": 97,
  "filename": "epymarl/src/modules/critics/custom_critic.py",
  "op": "replace",
  "start_line": 12
}
Tool Result — Step 4
OK: Replaced lines 12..97 with 79 line(s) in epymarl/src/modules/critics/custom_critic.py. Editable range: 7–7, 12–90.

[Current file: epymarl/src/modules/critics/custom_critic.py | editable: 7–7, 12–90 | total: 90 lines]
     7  # ── Custom imports (editable) ────────────────────────────────────────────
...
    12  class CustomCritic(nn.Module):
    13      """Observation-Gated Centralized Critic (OGC²) for MAPPO.
    14  
       ... (73 more lines) ...
    88          x = F.relu(self.fc2(combined))                                   # (B, T, n, H)
    89          q = self.fc3(x)                                                  # (B, T, n, 1)
    90          return q
Assistant — Step 5 (test)
Tool Result — Step 5
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4997963 return_mean=23.2905 return_std=1.0618 battle_won_mean=0.9826
TRAIN_METRICS t_env=5008942 return_mean=23.3561 return_std=1.0968 battle_won_mean=0.9931
TRAIN_METRICS t_env=5019099 return_mean=23.2060 return_std=1.5927 battle_won_mean=0.9688
TRAIN_METRICS t_env=5029968 return_mean=23.1149 return_std=1.5201 battle_won_mean=0.9792
TRAIN_METRICS t_env=5041122 return_mean=23.0837 return_std=1.6447 battle_won_mean=0.9653
Final test for mmm (t_env=5012717):
  Win rate: 0.9688
  Mean return: 23.4138 +/- 1.2980

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004410 return_mean=19.4788 return_std=1.9169 battle_won_mean=0.9062
TRAIN_METRICS t_env=5014946 return_mean=19.7606 return_std=1.5216 battle_won_mean=0.9531
TRAIN_METRICS t_env=5025401 return_mean=19.7544 return_std=1.4147 battle_won_mean=0.9531
Assistant — Step 6 (submit)
Tool Input
{
  "n": 2
}
Tool Result — Step 6
[submit] Submitting result from test #2 as final.

### mmm (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=4997963 return_mean=23.2905 return_std=1.0618 battle_won_mean=0.9826
TRAIN_METRICS t_env=5008942 return_mean=23.3561 return_std=1.0968 battle_won_mean=0.9931
TRAIN_METRICS t_env=5019099 return_mean=23.2060 return_std=1.5927 battle_won_mean=0.9688
TRAIN_METRICS t_env=5029968 return_mean=23.1149 return_std=1.5201 battle_won_mean=0.9792
TRAIN_METRICS t_env=5041122 return_mean=23.0837 return_std=1.6447 battle_won_mean=0.9653
Final test for mmm (t_env=5012717):
  Win rate: 0.9688
  Mean return: 23.4138 +/- 1.2980

### 2s3z (scripts/train.sh)
Training metrics (last steps):
TRAIN_METRICS t_env=5004410 return_mean=19.4788 return_std=1.9169 battle_won_mean=0.9062
TRAIN_METRICS t_env=5014946 return_mean=19.7606 return_std=1.5216 battle_won_mean=0.9531
TRAIN_METRICS t_env=5025401 return_mean=19.7544 return_std=1.4147 battle_won_mean=0.9531
TRAIN_METRICS t_env=5036013 re