Agent Conversation: claude-opus-4.6
Task: rl-reward-learning
Inverse RL: Reward Learning from Expert Demonstrations
Objective
Design and implement an inverse reinforcement learning (IRL) algorithm that learns a reward function from expert demonstrations. Your code goes in custom_irl.py, specifically the RewardNetwork and IRLAlgorithm classes. Three reference implementations (GAIL, AIRL, BC) from the imitation library are provided as read-only context.
Background
Inverse reinforcement learning recovers a reward function that explains observed expert behavior. The learned reward is then used to train a policy via standard RL (PPO in this benchmark). Key challenges include:
- Designing reward network architectures that capture the structure of expert behavior
- Balancing discriminator training with policy improvement
- Avoiding reward hacking where the policy exploits learned reward artifacts
- Ensuring the learned reward generalizes across different states visited during training
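The three baselines address these challenges in different ways. GAIL trains a discriminator adversarially against the policy. AIRL keeps that adversarial setup but adds potential-based reward shaping. BC skips reward learning entirely and fits expert actions by supervised learning. Your goal is to design a novel reward network architecture or IRL training algorithm that outperforms these baselines.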
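For reference, the baselines differ mainly in the per-transition signal they optimize. The sketch below writes the textbook forms as plain scalar functions. It is not the imitation library's API. Which GAIL reward transform a given implementation uses is an assumption here, because several variants are in common use.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gail_reward(disc_logit: float) -> float:
    """One common GAIL reward: -log(1 - D) with D = sigmoid(logit), expert-is-high convention."""
    # -log(1 - sigmoid(z)) = log(1 + exp(z)) = softplus(z)
    return softplus(disc_logit)


def airl_f(g_sa: float, h_s: float, h_next: float, gamma: float = 0.99) -> float:
    """AIRL's shaped score f(s, a, s') = g(s, a) + gamma * h(s') - h(s)."""
    return g_sa + gamma * h_next - h_s


def airl_reward(f: float, log_pi: float) -> float:
    """With D = exp(f) / (exp(f) + pi(a|s)), log D - log(1 - D) reduces to f - log pi(a|s)."""
    return f - log_pi


def bc_loss(expert_log_probs: Sequence[float]) -> float:
    """BC: mean negative log-likelihood of expert actions under the policy (no reward involved)."""
    return -sum(expert_log_probs) / len(expert_log_probs)
```

Only GAIL and AIRL produce a reward for PPO, while BC updates the policy directly. The submissions later in this transcript pass `f` itself to PPO through `compute_reward`, not `f - log π`.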
Evaluation
Methods are trained and evaluated on three MuJoCo locomotion environments (HalfCheetah-v4, Hopper-v4, Walker2d-v4) using pre-generated expert demonstrations. The policy is trained with PPO on the learned reward signal. Metric: mean episodic return over 10 evaluation episodes (higher is better).
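One way the metric could be computed from an evaluation rollout is sketched below, under stated assumptions. The benchmark's fixed evaluation code does not appear in this transcript. The returns are presumably summed from the environment's own reward, not the learned one. The helper names are hypothetical.

```python
from typing import List, Sequence


def episodic_returns(rewards: Sequence[float], dones: Sequence[bool]) -> List[float]:
    """Sum per-step rewards within each episode; a True done flag closes the episode."""
    returns, running = [], 0.0
    for r, done in zip(rewards, dones):
        running += r
        if done:
            returns.append(running)
            running = 0.0
    return returns  # an unfinished trailing episode is dropped


def mean_eval_return(returns: Sequence[float], n_episodes: int = 10) -> float:
    """Mean over the first n_episodes completed evaluation episodes."""
    if len(returns) < n_episodes:
        raise ValueError(f"need {n_episodes} episodes, got {len(returns)}")
    return sum(returns[:n_episodes]) / n_episodes
```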
imitation/custom_irl.py [EDITABLE — lines 231–357 only]
1: # Custom IRL / Reward Learning algorithm for MLS-Bench
2: #
3: # EDITABLE section: RewardNetwork and IRLAlgorithm classes.
4: # FIXED sections: everything else (config, env, demo loading, PPO training, evaluation).
5: import os
6: import random
7: import time
8: from dataclasses import dataclass
9:
10: import gymnasium as gym
11: import numpy as np
12: import torch
13: import torch.nn as nn
14: import torch.nn.functional as F
15: import torch.optim as optim
16:
17:
18: # =====================================================================
19: # FIXED: Configuration
20: # =====================================================================
21: @dataclass
22: class Args:
23:     env_id: str = "HalfCheetah-v4"
24:     seed: int = 42
25:     torch_deterministic: bool = True
26:     cuda: bool = True
27:     # IRL training
28:     irl_epochs: int = 200
29:     irl_batch_size: int = 256
30:     irl_lr: float = 3e-4
31:     demo_path: str = ""  # set from env or CLI
32:     # Policy training (PPO via custom loop)
33:     total_timesteps: int = 1000000
34:     policy_lr: float = 3e-4
35:     gamma: float = 0.99
36:     gae_lambda: float = 0.95
37:     n_steps: int = 2048
38:     n_epochs: int = 10
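The PPO settings above determine how many rollout rounds a run contains, and therefore how many IRL update calls it makes. The back-of-the-envelope sketch below assumes `n_irl_updates_per_round = 5`, the value the agent uses in its own estimate; that field is not shown in the truncated listing. The function name is hypothetical.

```python
def rollout_schedule(total_timesteps: int, n_steps: int, n_irl_updates_per_round: int):
    """Count full PPO rollouts and the IRL update calls they would trigger."""
    n_rounds = total_timesteps // n_steps  # full rollouts of n_steps env steps each
    n_irl_updates = n_rounds * n_irl_updates_per_round
    last_full_rollout_step = n_rounds * n_steps  # env steps covered by those full rollouts
    return n_rounds, n_irl_updates, last_full_rollout_step


# Args defaults above; n_irl_updates_per_round = 5 is an assumption (not shown in the listing)
print(rollout_schedule(1_000_000, 2048, 5))  # (488, 2440, 999424)
```

The 2440 figure agrees with the agent's own `(1M/2048)*5 ≈ 2440` estimate. The step count 999424 also appears as a logged training step in the test outputs.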
Tool Input
1{
2 "op": "replace",
3 "filename": "imitation/custom_irl.py",
4 "start_line": 231,
5 "end_line": 357,
6 "content": "class RewardNetwork(nn.Module):\n \"\"\"AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).\n\n Decomposes the reward into a state-action reward g(s,a) and a potential-based\n shaping function h(s). This decomposition is theoretically motivated: under\n certain conditions, g(s,a) recovers the true reward up to a constant, and h(s)\n captures the value function shaping that accelerates learning.\n \"\"\"\n\n def __init__(self, obs_dim, action_dim):\n super().__init__()\n # g(s, a): reward approximator\n self.g_net = nn.Sequential(\n nn.Linear(obs_dim + action_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n # h(s): potential-based shaping\n self.h_net = nn.Sequential(\n nn.Linear(obs_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n self.gamma = 0.99\n\n def g(self, state, action):\n \"\"\"Reward function g(s, a).\"\"\"\n x = torch.cat([state, action], dim=-1)\n return self.g_net(x).squeeze(-1)\n\n def h(self, state):\n \"\"\"Potential function h(s).\"\"\"\n return self.h_net(state).squeeze(-1)\n\n def forward(self, state, action, next_state):\n \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s).\"\"\"\n return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n \"\"\"BC-Annealed Wasserstein AIRL.\n\n Novel algorithm that combines three key ideas:\n\n 1. **Wasserstein objective** for discriminator training instead of BCE.\n The Wasserstein distance provides non-vanishing gradients even when the\n discriminator is highly accurate, solving the mode collapse / training\n instability issues that plague standard GAIL. Combined with gradient\n penalty for the Lipschitz constraint.\n\n 2. **AIRL potential-based reward shaping** f(s,a,s') = g(s,a) + γh(s') - h(s).\n This decomposition helps recover a transferable reward function and provides\n a natural value-function shaping that accelerates PPO learning.\n\n 3. 
**Behavioral cloning warm-start with exponential annealing**. Early in training,\n when the learned reward is unreliable, BC provides strong supervised signal\n directly to the policy. As training progresses, BC weight decays exponentially,\n letting the IRL reward signal take over. This gives BC's fast convergence\n benefits while avoiding its distribution shift limitation.\n\n The combination is novel: BC handles the cold-start problem (initial random policy\n gets no useful IRL signal), Wasserstein AIRL provides stable long-term reward\n learning, and the annealing schedule smoothly transitions between regimes.\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n # Discriminator optimizer with lower beta1 for Wasserstein training stability\n disc_lr = min(args.irl_lr, 1e-4)\n self.optimizer = optim.Adam(\n self.reward_net.parameters(), lr=disc_lr, betas=(0.5, 0.9)\n )\n self.total_updates = 0\n\n # Gradient penalty coefficient for Lipschitz constraint\n self.grad_penalty_coef = 10.0\n\n # BC annealing parameters\n self._policy = None\n self._policy_optimizer = None\n self.bc_weight = 1.0\n self.bc_decay = 0.97 # per IRL round decay (called n_irl_updates_per_round times)\n self.bc_min_weight = 0.005 # floor to maintain slight BC regularization\n self.n_bc_steps = 10 # BC gradient steps per update call\n\n def set_policy(self, policy, optimizer):\n \"\"\"Set reference to policy for BC warm-start updates.\"\"\"\n self._policy = policy\n self._policy_optimizer = optimizer\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"Use the full shaped reward f(s,a,s') for policy training.\"\"\"\n with torch.no_grad():\n return self.reward_net(obs, acts, next_obs)\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, 
gen_next_obs):\n \"\"\"Gradient penalty on interpolated data for Lipschitz constraint (WGAN-GP).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def _bc_update(self):\n \"\"\"Perform BC gradient steps on policy, weighted by annealing schedule.\"\"\"\n if self._policy is None or self.bc_weight < self.bc_min_weight:\n return 0.0\n\n batch_size = self.args.irl_batch_size\n n_expert = len(self.expert_demos[\"obs\"])\n total_bc_loss = 0.0\n\n # Scale number of BC steps with weight (fewer steps as we anneal)\n effective_steps = max(1, int(self.n_bc_steps * min(1.0, self.bc_weight * 2)))\n\n for _ in range(effective_steps):\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n\n # Negative log-likelihood of expert actions under current policy\n _, log_prob, entropy, _ = self._policy.get_action_and_value(\n expert_obs, expert_acts,\n )\n neglogp = -log_prob.mean()\n ent_bonus = -0.001 * entropy.mean()\n bc_loss = self.bc_weight * (neglogp + ent_bonus)\n\n self._policy_optimizer.zero_grad()\n bc_loss.backward()\n nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n self._policy_optimizer.step()\n total_bc_loss += 
bc_loss.item()\n\n return total_bc_loss / max(effective_steps, 1)\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Combined update: Wasserstein AIRL discriminator + BC warm-start.\n\n Phase 1: Train discriminator with Wasserstein loss + gradient penalty\n Phase 2: BC update on policy (annealed)\n \"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # === Phase 1: Wasserstein AIRL discriminator update ===\n\n # Sample expert data\n n_expert = len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # Wasserstein loss: critic should output high for expert, low for policy\n # Minimize: -(E[f(expert)] - E[f(policy)])\n expert_scores = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n gen_scores = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n wasserstein_loss = -(expert_scores.mean() - gen_scores.mean())\n\n # Gradient penalty for Lipschitz constraint\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n\n disc_loss = wasserstein_loss + self.grad_penalty_coef * gp\n\n self.optimizer.zero_grad()\n disc_loss.backward()\n self.optimizer.step()\n\n # === Phase 2: BC warm-start update (annealed) ===\n bc_loss_val = self._bc_update()\n\n # Decay BC weight\n self.bc_weight = max(self.bc_min_weight, self.bc_weight * self.bc_decay)\n\n return {\n \"irl_loss\": disc_loss.item(),\n \"w_dist\": -wasserstein_loss.item(),\n \"grad_penalty\": gp.item(),\n \"bc_loss\": bc_loss_val,\n \"bc_weight\": self.bc_weight,\n 
}\n\n"
7}
OK: Replaced lines 231..357 with 216 line(s) in imitation/custom_irl.py. Editable range: 231–446.
[Current file: imitation/custom_irl.py | editable: 231–446 | total: 688 lines]
231 class RewardNetwork(nn.Module):
232 """AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).
233
... (210 more lines) ...
444 "bc_weight": self.bc_weight,
445 }
446
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.1916 v_loss=106.0780 entropy=1.9125 irl_loss=-13.1565 w_dist=14.3133 grad_penalty=0.1157 bc_loss=0.0184 bc_weight=0.0050
TRAIN_METRICS step=995328 pg_loss=0.1949 v_loss=100.4625 entropy=1.9732 irl_loss=-6.3979 w_dist=7.6495 grad_penalty=0.1252 bc_loss=0.0208 bc_weight=0.0050
TRAIN_METRICS step=997376 pg_loss=0.1942 v_loss=140.1659 entropy=2.0257 irl_loss=-9.5510 w_dist=10.8182 grad_penalty=0.1267 bc_loss=0.0215 bc_weight=0.0050
TRAIN_METRICS step=999424 pg_loss=0.1736 v_loss=99.4692 entropy=2.0701 irl_loss=-8.6748 w_dist=9.8201 grad_penalty=0.1145 bc_loss=0.0171 bc_weight=0.0050
TRAIN_METRICS step=1000000 pg_loss=0.1748 v_loss=107.0256 entropy=2.0863 irl_loss=-8.3352 w_dist=9.5395 grad_penalty=0.1204 bc_loss=0.0108 bc_weight=0.0050
Evaluation (halfcheetah-v4):
Eval episodic_return:
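In the first submission, `update()` multiplies `bc_weight` by `bc_decay = 0.97` on every call and clamps it at `bc_min_weight = 0.005`. The quick replay below shows how early that floor is reached. It assumes five IRL update calls per 2048-step rollout, which is the agent's own estimate. The helper names are illustrative.

```python
import math


def updates_until_floor(decay: float, floor: float, start: float = 1.0) -> int:
    """Multiplicative decay steps until the weight first reaches or drops below the floor."""
    weight, k = start, 0
    while weight > floor:
        weight *= decay
        k += 1
    return k


def bc_steps_submission1(weight: float, n_bc_steps: int = 10) -> int:
    """First submission's rule: max(1, int(n_bc_steps * min(1.0, 2 * weight)))."""
    return max(1, int(n_bc_steps * min(1.0, 2 * weight)))


k = updates_until_floor(0.97, 0.005)
rounds = math.ceil(k / 5)  # assumes 5 IRL update calls per rollout
print(k, rounds, rounds * 2048)  # 174 35 71680
print(bc_steps_submission1(1.0), bc_steps_submission1(0.005))  # 10 1
```

Under that assumption, BC runs at meaningful strength only for roughly the first 72k of the 1M environment steps (about 7%). The weight is clamped to the floor, so the `bc_weight < bc_min_weight` check in `_bc_update` never fires. One BC step per call therefore continues at weight 0.005 for the rest of training, which matches the small but nonzero `bc_loss` values logged above.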
Tool Input
1{
2 "content": "class RewardNetwork(nn.Module):\n \"\"\"AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).\n\n Decomposes the reward into a state-action reward g(s,a) and a potential-based\n shaping function h(s). This decomposition helps recover a transferable reward\n and provides natural value-function shaping that accelerates PPO learning.\n \"\"\"\n\n def __init__(self, obs_dim, action_dim):\n super().__init__()\n # g(s, a): reward approximator\n self.g_net = nn.Sequential(\n nn.Linear(obs_dim + action_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n # h(s): potential-based shaping\n self.h_net = nn.Sequential(\n nn.Linear(obs_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n self.gamma = 0.99\n\n def g(self, state, action):\n \"\"\"Reward function g(s, a).\"\"\"\n x = torch.cat([state, action], dim=-1)\n return self.g_net(x).squeeze(-1)\n\n def h(self, state):\n \"\"\"Potential function h(s).\"\"\"\n return self.h_net(state).squeeze(-1)\n\n def forward(self, state, action, next_state):\n \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s).\"\"\"\n return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n \"\"\"BC-Annealed AIRL with Dual-Phase Training.\n\n Novel algorithm combining three complementary ideas:\n\n 1. **AIRL potential-based reward shaping** f(s,a,s') = g(s,a) + γh(s') - h(s).\n This decomposition helps recover a transferable reward function and provides\n natural value-function shaping that accelerates PPO learning.\n\n 2. **Behavioral cloning warm-start with cosine annealing**. Early in training,\n when the learned reward is unreliable and the policy is far from expert\n behavior, BC provides strong supervised signal directly to the policy.\n The BC weight follows a cosine schedule, smoothly transitioning from\n imitation to reward-driven learning. 
This resolves the cold-start problem\n where random policies get no useful IRL gradient signal.\n\n 3. **Stabilized discriminator training** with gradient penalty, label smoothing,\n and discriminator output clipping. The discriminator logits are clamped to\n prevent extreme reward values that destabilize PPO. Combined with the\n running reward normalization in the main loop, this ensures smooth,\n informative reward signals throughout training.\n\n The combination is novel: BC handles cold-start (giving the discriminator\n meaningful policy data to learn from), AIRL provides long-term reward learning,\n and the cosine annealing schedule smoothly transitions between regimes without\n the instability of abrupt switching.\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n # Discriminator optimizer - lower LR to prevent overwhelming the policy\n disc_lr = min(args.irl_lr, 1e-4)\n self.optimizer = optim.Adam(self.reward_net.parameters(), lr=disc_lr)\n self.total_updates = 0\n\n # Gradient penalty coefficient\n self.grad_penalty_coef = 10.0\n # Label smoothing to prevent overconfident discriminator\n self.label_smoothing = 0.1\n # Logit clipping to prevent extreme reward values\n self.logit_clip = 10.0\n\n # BC annealing parameters\n self._policy = None\n self._policy_optimizer = None\n # Total number of IRL update calls expected:\n # ~(total_timesteps / n_steps) * n_irl_updates_per_round = (1M/2048)*5 ≈ 2440\n self.total_expected_updates = (args.total_timesteps // args.n_steps) * args.n_irl_updates_per_round\n self.n_bc_steps = 15 # BC gradient steps per update call (when active)\n\n def set_policy(self, policy, optimizer):\n \"\"\"Set reference to policy for BC warm-start updates.\"\"\"\n self._policy = policy\n self._policy_optimizer = optimizer\n\n def 
_get_bc_weight(self):\n \"\"\"Cosine annealing schedule for BC weight.\n\n Starts at 1.0, smoothly decays to near 0 over the first 60% of training,\n then maintains a small floor for ongoing regularization.\n \"\"\"\n progress = min(1.0, self.total_updates / max(1, self.total_expected_updates * 0.5))\n # Cosine decay from 1.0 to 0.0\n cos_weight = 0.5 * (1.0 + np.cos(np.pi * progress))\n # Floor at small value for mild ongoing regularization\n return max(0.005, cos_weight)\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"Use clamped shaped reward f(s,a,s') for stable policy training.\"\"\"\n with torch.no_grad():\n logits = self.reward_net(obs, acts, next_obs)\n # Clamp to prevent extreme reward values\n return torch.clamp(logits, -self.logit_clip, self.logit_clip)\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs):\n \"\"\"Gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def _bc_update(self):\n \"\"\"Perform BC gradient steps on policy with cosine-annealed weight.\"\"\"\n bc_weight = self._get_bc_weight()\n if self._policy is None or bc_weight < 0.006:\n return 0.0\n\n batch_size = 
self.args.irl_batch_size\n n_expert = len(self.expert_demos[\"obs\"])\n total_bc_loss = 0.0\n\n # Scale number of steps with weight for efficiency\n effective_steps = max(1, int(self.n_bc_steps * min(1.0, bc_weight * 3)))\n\n for _ in range(effective_steps):\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n\n # Negative log-likelihood of expert actions under current policy\n _, log_prob, entropy, _ = self._policy.get_action_and_value(\n expert_obs, expert_acts,\n )\n neglogp = -log_prob.mean()\n ent_bonus = -0.001 * entropy.mean()\n bc_loss = bc_weight * (neglogp + ent_bonus)\n\n self._policy_optimizer.zero_grad()\n bc_loss.backward()\n nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n self._policy_optimizer.step()\n total_bc_loss += bc_loss.item()\n\n return total_bc_loss / max(effective_steps, 1)\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Combined update: AIRL discriminator + BC warm-start (annealed).\n\n Phase 1: Train discriminator with BCE + gradient penalty + label smoothing\n Phase 2: BC update on policy (cosine-annealed)\n \"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # === Phase 1: AIRL discriminator update ===\n\n # Sample expert data\n n_expert = len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # Discriminator logits (f-values from AIRL)\n expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n 
gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n # Labels with smoothing: expert=1-eps, generator=eps\n eps = self.label_smoothing\n expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n logits = torch.cat([expert_logits, gen_logits], dim=0)\n labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n # Gradient penalty\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n\n disc_loss = bce_loss + self.grad_penalty_coef * gp\n\n self.optimizer.zero_grad()\n disc_loss.backward()\n self.optimizer.step()\n\n # === Phase 2: BC warm-start update (cosine-annealed) ===\n bc_loss_val = self._bc_update()\n\n # Compute discriminator accuracy for logging\n with torch.no_grad():\n hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n torch.zeros(batch_size, device=self.devic\n... (truncated)",
3 "end_line": 446,
4 "filename": "imitation/custom_irl.py",
5 "op": "replace",
6 "start_line": 231
7}
OK: Replaced lines 231..446 with 245 line(s) in imitation/custom_irl.py. Editable range: 231–475.
[Current file: imitation/custom_irl.py | editable: 231–475 | total: 717 lines]
231 class RewardNetwork(nn.Module):
232 """AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).
233
... (239 more lines) ...
473 "bc_weight": self._get_bc_weight(),
474 }
475
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.2752 v_loss=24.1465 entropy=-3.2746 irl_loss=0.3416 disc_acc=1.0000 grad_penalty=0.0008 bc_loss=0.0000 bc_weight=0.0050
TRAIN_METRICS step=995328 pg_loss=0.2581 v_loss=30.1042 entropy=-3.2525 irl_loss=0.3425 disc_acc=0.9980 grad_penalty=0.0009 bc_loss=0.0000 bc_weight=0.0050
TRAIN_METRICS step=997376 pg_loss=0.2473 v_loss=35.7972 entropy=-3.2164 irl_loss=0.3429 disc_acc=1.0000 grad_penalty=0.0011 bc_loss=0.0000 bc_weight=0.0050
TRAIN_METRICS step=999424 pg_loss=0.2495 v_loss=37.1688 entropy=-3.1883 irl_loss=0.3420 disc_acc=1.0000 grad_penalty=0.0010 bc_loss=0.0000 bc_weight=0.0050
TRAIN_METRICS step=1000000 pg_loss=0.2219 v_loss=49.2893 entropy=-3.1782 irl_loss=0.3392 disc_acc=1.0000 grad_penalty=0.0009 bc_loss=0.0000 bc_weight=0.0050
Evaluation (halfcheetah-v4):
Eval episodic_return
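Here is one possible reading of the Test #2 numbers. With `label_smoothing = 0.1`, the mean binary cross-entropy over expert targets 0.9 and policy targets 0.1 cannot fall below the binary entropy H(0.1) ≈ 0.325 nats. That minimum is attained exactly when the discriminator outputs 0.9 on every expert sample and 0.1 on every policy sample. Assume `irl_loss` logs `disc_loss = bce_loss + 10 * gp`; the return dict is truncated in the transcript, so this is not confirmed. Under that assumption, subtracting `10 * grad_penalty` leaves roughly 0.330 to 0.334. The discriminator therefore sits almost on the floor, in line with `disc_acc` ≈ 1.0. The sketch below computes that floor.

```python
import math


def binary_entropy(p: float) -> float:
    """H(p) in nats: the minimum of binary cross-entropy against a soft target p."""
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def bce_with_logit(logit: float, target: float) -> float:
    """Per-element binary cross-entropy on a raw logit, natural log (as in BCE-with-logits)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


eps = 0.1  # label_smoothing in the second submission
floor = binary_entropy(eps)
# The best expert-side logit is log(0.9 / 0.1); the policy side is symmetric
best_expert = bce_with_logit(math.log((1 - eps) / eps), 1 - eps)
print(round(floor, 4), round(best_expert, 4))  # 0.3251 0.3251
```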
Tool Input
1{
2 "content": "class RewardNetwork(nn.Module):\n \"\"\"AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).\n\n Decomposes the reward into a state-action reward g(s,a) and a potential-based\n shaping function h(s). This decomposition helps recover a transferable reward\n and provides natural value-function shaping that accelerates PPO learning.\n \"\"\"\n\n def __init__(self, obs_dim, action_dim):\n super().__init__()\n # g(s, a): reward approximator\n self.g_net = nn.Sequential(\n nn.Linear(obs_dim + action_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n # h(s): potential-based shaping\n self.h_net = nn.Sequential(\n nn.Linear(obs_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n self.gamma = 0.99\n\n def g(self, state, action):\n \"\"\"Reward function g(s, a).\"\"\"\n x = torch.cat([state, action], dim=-1)\n return self.g_net(x).squeeze(-1)\n\n def h(self, state):\n \"\"\"Potential function h(s).\"\"\"\n return self.h_net(state).squeeze(-1)\n\n def forward(self, state, action, next_state):\n \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s).\"\"\"\n return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n \"\"\"BC-Annealed Wasserstein AIRL with Dual-Phase Training.\n\n Novel algorithm combining three complementary ideas:\n\n 1. **Wasserstein objective** for discriminator training instead of BCE.\n The Wasserstein distance provides non-vanishing gradients even when the\n discriminator becomes accurate. Standard BCE-based discriminators (GAIL/AIRL)\n suffer from mode collapse when disc_acc → 1.0, providing zero gradient to\n the policy. Wasserstein loss with gradient penalty maintains informative\n gradients throughout training.\n\n 2. 
**AIRL potential-based reward shaping** f(s,a,s') = g(s,a) + γh(s') - h(s).\n This decomposition helps recover a transferable reward function and provides\n natural value-function shaping that accelerates PPO learning.\n\n 3. **Behavioral cloning warm-start with cosine annealing** over 75% of training.\n Early in training, BC provides strong supervised signal directly to the policy,\n solving the cold-start problem where random policies produce uninformative\n transitions for the discriminator. The cosine schedule ensures smooth\n transition from imitation to reward-driven learning, maintaining a small\n floor for ongoing regularization that prevents late-training instability.\n\n The key insight is that BC and Wasserstein AIRL are synergistic:\n - BC quickly gets the policy into a reasonable state, giving the discriminator\n meaningful policy data to learn from (not random noise)\n - The discriminator learns a reward that captures nuances beyond simple imitation\n - As BC decays, the policy is free to explore reward-improving behaviors\n - Wasserstein loss prevents the discriminator from \"winning\" too hard\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n # Discriminator optimizer with lower beta1 for Wasserstein stability\n disc_lr = min(args.irl_lr, 1e-4)\n self.optimizer = optim.Adam(\n self.reward_net.parameters(), lr=disc_lr, betas=(0.5, 0.9)\n )\n self.total_updates = 0\n\n # Gradient penalty coefficient for Lipschitz constraint\n self.grad_penalty_coef = 10.0\n # Reward output clamping for PPO stability\n self.reward_clip = 10.0\n\n # BC annealing parameters\n self._policy = None\n self._policy_optimizer = None\n # Total expected IRL update calls: (total_timesteps / n_steps) * n_irl_updates_per_round\n self.total_expected_updates = max(1, 
(args.total_timesteps // args.n_steps) * args.n_irl_updates_per_round)\n # Anneal BC over 75% of training\n self.bc_anneal_fraction = 0.75\n self.n_bc_steps_max = 15 # max BC gradient steps per update call\n\n def set_policy(self, policy, optimizer):\n \"\"\"Set reference to policy for BC warm-start updates.\"\"\"\n self._policy = policy\n self._policy_optimizer = optimizer\n\n def _get_bc_weight(self):\n \"\"\"Cosine annealing schedule for BC weight.\n\n Decays from 1.0 to 0.0 over bc_anneal_fraction of total training,\n maintaining a small floor for mild ongoing regularization.\n \"\"\"\n anneal_end = self.total_expected_updates * self.bc_anneal_fraction\n progress = min(1.0, self.total_updates / max(1, anneal_end))\n # Cosine decay from 1.0 to 0.0\n cos_weight = 0.5 * (1.0 + np.cos(np.pi * progress))\n # Small floor to maintain mild regularization throughout\n return max(0.01, cos_weight)\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"Use clamped shaped reward f(s,a,s') for stable policy training.\n\n Clamping prevents extreme reward values from destabilizing PPO,\n while still allowing informative gradients through the Wasserstein critic.\n \"\"\"\n with torch.no_grad():\n logits = self.reward_net(obs, acts, next_obs)\n return torch.clamp(logits, -self.reward_clip, self.reward_clip)\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs):\n \"\"\"Gradient penalty on interpolated data for Lipschitz constraint (WGAN-GP).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = 
torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def _bc_update(self):\n \"\"\"Perform BC gradient steps on policy with cosine-annealed weight.\"\"\"\n bc_weight = self._get_bc_weight()\n if self._policy is None or bc_weight < 0.011:\n return 0.0\n\n batch_size = self.args.irl_batch_size\n n_expert = len(self.expert_demos[\"obs\"])\n total_bc_loss = 0.0\n\n # Scale number of steps: more when BC is important, fewer as it decays\n effective_steps = max(1, int(self.n_bc_steps_max * bc_weight))\n\n for _ in range(effective_steps):\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n\n # Negative log-likelihood of expert actions under current policy\n _, log_prob, entropy, _ = self._policy.get_action_and_value(\n expert_obs, expert_acts,\n )\n neglogp = -log_prob.mean()\n ent_bonus = -0.001 * entropy.mean()\n bc_loss = bc_weight * (neglogp + ent_bonus)\n\n self._policy_optimizer.zero_grad()\n bc_loss.backward()\n nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n self._policy_optimizer.step()\n total_bc_loss += bc_loss.item()\n\n return total_bc_loss / max(effective_steps, 1)\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Combined update: Wasserstein AIRL discriminator + BC warm-start.\n\n Phase 1: Train discriminator with Wasserstein loss + gradient penalty\n Phase 2: BC update on policy (cosine-annealed, active for first 75% of training)\n \"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # === Phase 1: Wasserstein AIRL discriminator update ===\n\n # Sample expert data\n n_expert = 
len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # Wasserstein loss: critic should output high for expert, low for policy\n # Minimize: -(E[f(expert)] - E[f(policy)])\n expert_scores = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n gen_scores = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n wasserstein_loss = -(expert_scores.mean() - gen_scores.mean())\n\n # Gradient penalty for Lipschitz constraint\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n\n disc_loss = wasserstein_loss + self.grad_penalty_coef * gp\n\n self.optimizer.zero_grad()\n disc_loss.backward()\n self.optimizer.step()\n\n # === Phase 2: BC warm-start update (cosine-annealed) ===\n bc_loss_val = self._bc_update()\n\n \n... (truncated)",
3 "end_line": 475,
4 "filename": "imitation/custom_irl.py",
5 "op": "replace",
6 "start_line": 231
7}
OK: Replaced lines 231..475 with 239 line(s) in imitation/custom_irl.py. Editable range: 231–469.
[Current file: imitation/custom_irl.py | editable: 231–469 | total: 711 lines]
231 class RewardNetwork(nn.Module):
232 """AIRL-style shaped reward: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).
233
... (233 more lines) ...
467 "bc_weight": self._get_bc_weight(),
468 }
469
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=-0.0063 v_loss=5.7138 entropy=-3.1275 irl_loss=-6.9273 w_dist=8.0012 grad_penalty=0.1074 bc_loss=0.0000 bc_weight=0.0100
TRAIN_METRICS step=995328 pg_loss=-0.0310 v_loss=12.1199 entropy=-3.1326 irl_loss=-8.4442 w_dist=9.5646 grad_penalty=0.1120 bc_loss=0.0000 bc_weight=0.0100
TRAIN_METRICS step=997376 pg_loss=-0.0316 v_loss=10.6274 entropy=-3.1536 irl_loss=-6.4403 w_dist=7.5679 grad_penalty=0.1128 bc_loss=0.0000 bc_weight=0.0100
TRAIN_METRICS step=999424 pg_loss=-0.0290 v_loss=18.3251 entropy=-3.1681 irl_loss=-7.3475 w_dist=8.3260 grad_penalty=0.0979 bc_loss=0.0000 bc_weight=0.0100
TRAIN_METRICS step=1000000 pg_loss=-0.1006 v_loss=12.8860 entropy=-3.1754 irl_loss=
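Both later submissions switch to a cosine BC schedule. In the second, `_get_bc_weight` anneals over `total_expected_updates * 0.5`, which is half of the expected updates even though its docstring says 60%. It uses a 0.005 floor, just under the 0.006 gate in `_bc_update`. The third anneals over 75%, with a 0.01 floor under a 0.011 gate. In both cases BC switches off completely once the weight nears the floor, which is consistent with `bc_loss=0.0000` in the final logs of Tests #2 and #3. The sketch below replays the third submission's schedule. It assumes 2440 expected update calls, the agent's own estimate. The function names are illustrative and do not come from the submitted code.

```python
import math


def cosine_bc_weight(t: int, anneal_end: float, floor: float) -> float:
    """Cosine decay from 1.0 to 0.0 over anneal_end update calls, clamped at floor."""
    progress = min(1.0, t / max(1.0, anneal_end))
    return max(floor, 0.5 * (1.0 + math.cos(math.pi * progress)))


def bc_steps(weight: float, n_max: int = 15, gate: float = 0.011) -> int:
    """Third submission's rule: no BC below the gate, else max(1, int(n_max * weight)) steps."""
    if weight < gate:
        return 0
    return max(1, int(n_max * weight))


def bc_step_profile(total_updates: int = 2440, anneal_fraction: float = 0.75, floor: float = 0.01):
    """Total BC gradient steps over a run, and the last update call that still runs BC."""
    anneal_end = total_updates * anneal_fraction
    total, last_active = 0, 0
    for t in range(1, total_updates + 1):  # update() increments its counter before BC runs
        steps = bc_steps(cosine_bc_weight(t, anneal_end, floor))
        total += steps
        if steps:
            last_active = t
    return total, last_active


total, last_active = bc_step_profile()
print(total, last_active, round(last_active / 2440, 2))
```

Under these assumptions, the last BC step falls at update 1707, about 70% of the way through the 2440 expected calls. That is a little before the nominal 75% anneal end, because the gate cuts in as soon as the cosine weight drops below 0.011.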
Tool Input
1{
2 "n": 1
3}
[submit] Submitting result from test #1 as final.
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.1916 v_loss=106.0780 entropy=1.9125 irl_loss=-13.1565 w_dist=14.3133 grad_penalty=0.1157 bc_loss=0.0184 bc_weight=0.0050
TRAIN_METRICS step=995328 pg_loss=0.1949 v_loss=100.4625 entropy=1.9732 irl_loss=-6.3979 w_dist=7.6495 grad_penalty=0.1252 bc_loss=0.0208 bc_weight=0.0050
TRAIN_METRICS step=997376 pg_loss=0.1942 v_loss=140.1659 entropy=2.0257 irl_loss=-9.5510 w_dist=10.8182 grad_penalty=0.1267 bc_loss=0.0215 bc_weight=0.0050
TRAIN_METRICS step=999424 pg_loss=0.1736 v_loss=99.4692 entropy=2.0701 irl_loss=-8.6748 w_dist=9.8201 grad_penalty=0.1145 bc_loss=0.0171 bc_weight=0.0050
TRAIN_METRICS step=1000000 pg_loss=0.1748 v_loss=107.0256 entropy=2.0863 irl_loss=-8.3352 w_dist=9.5395 grad_penalty=0.1204 bc_loss=0.0108 bc_weight=0.0050
Evaluation (halfcheetah-v4):
Eval episodic_return: 2946.00
Eval episodic_return: 2232.17
Eval ep
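In the submitted run, the loss in `update()` links the Wasserstein fields in the log. `irl_loss` reports `disc_loss = wasserstein_loss + 10 * gp`, and `w_dist` reports `-wasserstein_loss`. So `irl_loss` should equal `-w_dist + 10 * grad_penalty` up to rounding. The quick check below tests this against the five logged rows.

```python
GRAD_PENALTY_COEF = 10.0  # grad_penalty_coef in the submitted IRLAlgorithm

# (irl_loss, w_dist, grad_penalty) copied from the Test #1 TRAIN_METRICS lines
rows = [
    (-13.1565, 14.3133, 0.1157),
    (-6.3979, 7.6495, 0.1252),
    (-9.5510, 10.8182, 0.1267),
    (-8.6748, 9.8201, 0.1145),
    (-8.3352, 9.5395, 0.1204),
]


def reconstructed_loss(w_dist: float, grad_penalty: float, coef: float = GRAD_PENALTY_COEF) -> float:
    """disc_loss = -(E[f_expert] - E[f_policy]) + coef * gp, expressed via the logged w_dist."""
    return -w_dist + coef * grad_penalty


for irl_loss, w_dist, gp in rows:
    # Logged values have 4 decimals and gp is scaled by 10, so allow about 1e-3 of slack
    assert abs(reconstructed_loss(w_dist, gp) - irl_loss) < 1e-3
```

The penalty term contributes only about 1.1 to 1.3, against a critic gap (`w_dist`) of roughly 7.6 to 14.3. Through the end of training, then, the critic still separates expert and policy transitions by a wide margin.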