Agent Conversation: gemini-3.1-pro-preview
Task: rl-reward-learning
Inverse RL: Reward Learning from Expert Demonstrations
Objective
Design and implement an inverse reinforcement learning (IRL) algorithm that learns a reward function from expert demonstrations. Your code goes in imitation/custom_irl.py, specifically the RewardNetwork and IRLAlgorithm classes. Three reference implementations (GAIL, AIRL, BC) from the imitation library are provided as read-only context.
Background
Inverse reinforcement learning recovers a reward function that explains observed expert behavior. The learned reward is then used to train a policy via standard RL (PPO in this benchmark). Key challenges include:
- Designing reward network architectures that capture the structure of expert behavior
- Balancing discriminator training with policy improvement
- Avoiding reward hacking where the policy exploits learned reward artifacts
- Ensuring the learned reward generalizes across different states visited during training
Different IRL approaches address these through adversarial training (GAIL), potential-based reward shaping (AIRL), or direct behavioral cloning. Your goal is to design a novel reward network architecture or IRL training algorithm that outperforms these baselines.
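For concreteness, the adversarial baselines differ mainly in the functional form they give the learned reward. Below is a minimal, illustrative sketch of the two shapes (the class and method names are ours, not the imitation library's API): GAIL scores (s, a) pairs with a binary discriminator, while AIRL adds a potential term h whose shaping cannot change the optimal policy.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GAILDiscriminator(nn.Module):
    """Binary classifier over (s, a); a common GAIL reward is -log(1 - D)."""
    def __init__(self, obs_dim, act_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def reward(self, obs, act):
        logits = self.net(torch.cat([obs, act], dim=-1)).squeeze(-1)
        return F.softplus(logits)  # equals -log(1 - sigmoid(logits))

class AIRLReward(nn.Module):
    """f(s, a, s') = g(s, a) + gamma * h(s') - h(s); h is a state potential."""
    def __init__(self, obs_dim, act_dim, gamma=0.99, hidden=256):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.h = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.gamma = gamma

    def forward(self, obs, act, next_obs):
        g = self.g(torch.cat([obs, act], dim=-1)).squeeze(-1)
        return g + self.gamma * self.h(next_obs).squeeze(-1) - self.h(obs).squeeze(-1)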
Evaluation
Policies are trained and evaluated on three MuJoCo locomotion environments (HalfCheetah-v4, Hopper-v4, Walker2d-v4) using pre-generated expert demonstrations. Metric: mean episodic return over 10 evaluation episodes (higher is better). The policy is trained with PPO on the learned reward signal.
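The metric itself is simple to reproduce. A minimal sketch of the evaluation loop, assuming a Gymnasium environment and a policy_fn callable that maps an observation to an action (these names are ours; the benchmark's fixed evaluation harness is not shown in this excerpt):

import gymnasium as gym
import numpy as np

def evaluate(policy_fn, env_id="HalfCheetah-v4", n_episodes=10, seed=42):
    """Mean episodic return over n_episodes, mirroring the benchmark metric."""
    env = gym.make(env_id)
    returns = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)
        done, ep_ret = False, 0.0
        while not done:
            obs, reward, terminated, truncated, _ = env.step(policy_fn(obs))
            ep_ret += reward
            done = terminated or truncated
        returns.append(ep_ret)
    env.close()
    return float(np.mean(returns))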
imitation/custom_irl.py [EDITABLE — lines 231–357 only]
1: # Custom IRL / Reward Learning algorithm for MLS-Bench
2: #
3: # EDITABLE section: RewardNetwork and IRLAlgorithm classes.
4: # FIXED sections: everything else (config, env, demo loading, PPO training, evaluation).
5: import os
6: import random
7: import time
8: from dataclasses import dataclass
9:
10: import gymnasium as gym
11: import numpy as np
12: import torch
13: import torch.nn as nn
14: import torch.nn.functional as F
15: import torch.optim as optim
16:
17:
18: # =====================================================================
19: # FIXED: Configuration
20: # =====================================================================
21: @dataclass
22: class Args:
23: env_id: str = "HalfCheetah-v4"
24: seed: int = 42
25: torch_deterministic: bool = True
26: cuda: bool = True
27: # IRL training
28: irl_epochs: int = 200
29: irl_batch_size: int = 256
30: irl_lr: float = 3e-4
31: demo_path: str = "" # set from env or CLI
32: # Policy training (PPO via custom loop)
33: total_timesteps: int = 1000000
34: policy_lr: float = 3e-4
35: gamma: float = 0.99
36: gae_lambda: float = 0.95
37: n_steps: int = 2048
38: n_epochs: int = 10
39: ...

Tool Input
{
  "end_line": 357,
  "start_line": 231,
  "filename": "custom_irl.py",
  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Multi-Task Reward Network with shared state features.\"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        self.state_net = nn.Sequential(\n            nn.Linear(obs_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n        )\n        self.disc_head = nn.Sequential(\n            nn.Linear(256 + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.bc_head = nn.Linear(256, action_dim)\n        self.bc_coef = 2.0\n\n    def forward(self, state, action, next_state):\n        h = self.state_net(state)\n        h_a = torch.cat([h, action], dim=-1)\n        d_val = self.disc_head(h_a).squeeze(-1)\n\n        bc_act = self.bc_head(h)\n        penalty = torch.mean((action - bc_act)**2, dim=-1)\n\n        return d_val - self.bc_coef * penalty\n\n    def compute_components(self, state, action):\n        h = self.state_net(state)\n        h_a = torch.cat([h, action], dim=-1)\n        d_val = self.disc_head(h_a).squeeze(-1)\n        bc_act = self.bc_head(h)\n        return d_val, bc_act\n\n\nclass IRLAlgorithm:\n    \"\"\"LSGAN + BC Regularization Multi-Task IRL.\"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n        self.grad_penalty_coef = 10.0\n\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n\n        expert_d, expert_bc = self.reward_net.compute_components(expert_obs, expert_acts)\n        gen_d, gen_bc = self.reward_net.compute_components(gen_obs, gen_acts)\n\n        disc_loss = 0.5 * torch.mean((expert_d - 1.0)**2) + 0.5 * torch.mean((gen_d - (-1.0))**2)\n        bc_loss = F.mse_loss(expert_bc, expert_acts)\n\n        alpha = torch.rand(batch_size, 1, device=self.device)\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n\n        h_interp = self.reward_net.state_net(interp_obs)\n        d_interp = self.reward_net.disc_head(torch.cat([h_interp, interp_acts], dim=-1)).squeeze(-1)\n\n        ones = torch.ones_like(d_interp)\n        grads = torch.autograd.grad(\n            outputs=d_interp, inputs=[interp_obs, interp_acts],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        gp = torch.mean((grad_norm - 1.0) ** 2)\n\n        loss = disc_loss + self.grad_penalty_coef * gp + 10.0 * bc_loss\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        p_bc_loss_val = 0.0\n        if self._policy is not None:\n            n_bc_steps = 10\n            for _ in range(n_bc_steps):\n                idx = torch.randint(0, n_expert, (batch_size,))\n                exp_o = self.expert_demos[\"obs\"][idx]\n                exp_a = self.expert_demos[\"acts\"][idx]\n\n                _, log_prob, entropy, _ = self._policy.get_action_and_value(exp_o, exp_a)\n                p_bc_loss = -log_prob.mean() - 0.001 * entropy.mean()\n\n                self._policy_optimizer.zero_grad()\n                p_bc_loss.backward()\n                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n                self._policy_optimizer.step()\n                p_bc_loss_val += p_bc_loss.item()\n            p_bc_loss_val /= n_bc_steps\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"disc_loss\": disc_loss.item(),\n            \"bc_loss\": bc_loss.item(),\n            \"p_bc_loss\": p_bc_loss_val\n        }",
  "op": "replace"
}
ERROR: Package 'custom_irl.py' is not in allowed packages
Tool Input
{
  "start_line": 231,
  "op": "replace",
  "end_line": 357,
  "filename": "imitation/custom_irl.py",
  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Multi-Task Reward Network with shared state features.\"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        self.state_net = nn.Sequential(\n            nn.Linear(obs_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n        )\n        self.disc_head = nn.Sequential(\n            nn.Linear(256 + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.bc_head = nn.Linear(256, action_dim)\n        self.bc_coef = 2.0\n\n    def forward(self, state, action, next_state):\n        h = self.state_net(state)\n        h_a = torch.cat([h, action], dim=-1)\n        d_val = self.disc_head(h_a).squeeze(-1)\n\n        bc_act = self.bc_head(h)\n        penalty = torch.mean((action - bc_act)**2, dim=-1)\n\n        return d_val - self.bc_coef * penalty\n\n    def compute_components(self, state, action):\n        h = self.state_net(state)\n        h_a = torch.cat([h, action], dim=-1)\n        d_val = self.disc_head(h_a).squeeze(-1)\n        bc_act = self.bc_head(h)\n        return d_val, bc_act\n\n\nclass IRLAlgorithm:\n    \"\"\"LSGAN + BC Regularization Multi-Task IRL.\"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n        self.grad_penalty_coef = 10.0\n\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n\n        expert_d, expert_bc = self.reward_net.compute_components(expert_obs, expert_acts)\n        gen_d, gen_bc = self.reward_net.compute_components(gen_obs, gen_acts)\n\n        disc_loss = 0.5 * torch.mean((expert_d - 1.0)**2) + 0.5 * torch.mean((gen_d - (-1.0))**2)\n        bc_loss = F.mse_loss(expert_bc, expert_acts)\n\n        alpha = torch.rand(batch_size, 1, device=self.device)\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n\n        h_interp = self.reward_net.state_net(interp_obs)\n        d_interp = self.reward_net.disc_head(torch.cat([h_interp, interp_acts], dim=-1)).squeeze(-1)\n\n        ones = torch.ones_like(d_interp)\n        grads = torch.autograd.grad(\n            outputs=d_interp, inputs=[interp_obs, interp_acts],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        gp = torch.mean((grad_norm - 1.0) ** 2)\n\n        loss = disc_loss + self.grad_penalty_coef * gp + 10.0 * bc_loss\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        p_bc_loss_val = 0.0\n        if self._policy is not None:\n            n_bc_steps = 10\n            for _ in range(n_bc_steps):\n                idx = torch.randint(0, n_expert, (batch_size,))\n                exp_o = self.expert_demos[\"obs\"][idx]\n                exp_a = self.expert_demos[\"acts\"][idx]\n\n                _, log_prob, entropy, _ = self._policy.get_action_and_value(exp_o, exp_a)\n                p_bc_loss = -log_prob.mean() - 0.001 * entropy.mean()\n\n                self._policy_optimizer.zero_grad()\n                p_bc_loss.backward()\n                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n                self._policy_optimizer.step()\n                p_bc_loss_val += p_bc_loss.item()\n            p_bc_loss_val /= n_bc_steps\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"disc_loss\": disc_loss.item(),\n            \"bc_loss\": bc_loss.item(),\n            \"p_bc_loss\": p_bc_loss_val\n        }"
}
OK: Replaced lines 231..357 with 129 line(s) in imitation/custom_irl.py. Editable range: 231–359.
[Current file: imitation/custom_irl.py | editable: 231–359 | total: 601 lines]
231 class RewardNetwork(nn.Module):
232 """Multi-Task Reward Network with shared state features."""
233
... (123 more lines) ...
357 "bc_loss": bc_loss.item(),
358 "p_bc_loss": p_bc_loss_val
359 }

[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.3003 v_loss=15.8558 entropy=-10.8984 irl_loss=1.2236 disc_loss=0.8169 bc_loss=0.0115 p_bc_loss=-11.7359
TRAIN_METRICS step=995328 pg_loss=0.2950 v_loss=17.1726 entropy=-10.8995 irl_loss=1.3185 disc_loss=0.8744 bc_loss=0.0125 p_bc_loss=-11.7414
TRAIN_METRICS step=997376 pg_loss=0.3046 v_loss=20.0570 entropy=-10.9221 irl_loss=1.3204 disc_loss=0.8921 bc_loss=0.0132 p_bc_loss=-11.8087
TRAIN_METRICS step=999424 pg_loss=0.3025 v_loss=24.9686 entropy=-10.8659 irl_loss=1.0573 disc_loss=0.6147 bc_loss=0.0122 p_bc_loss=-11.7211
TRAIN_METRICS step=1000000 pg_loss=0.3048 v_loss=53.1911 entropy=-10.9230 irl_loss=1.5116 disc_loss=1.0351 bc_loss=0.0128 p_bc_loss=-11.7419

Evaluation (halfcheetah-v4):
Eval episodic_return: 3198.03
Eval episodic_return: 3532.08
Eval episodic_return: 3419.66
Final eva
Tool Input
{
  "content": "class RewardNetwork(nn.Module):\n    \"\"\"AIRL shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).\"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        return self.h_net(state).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n    \"\"\"Demo-Regularized LS-AIRL (Least-Squares AIRL + BC).\"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n        self.grad_penalty_coef = 10.0\n\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # LSGAN Loss\n        disc_loss = 0.5 * torch.mean((expert_logits - 1.0)**2) + 0.5 * torch.mean((gen_logits - (-1.0))**2)\n\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        loss = disc_loss + self.grad_penalty_coef * gp\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # Demo-Guided Policy Regularization (BC)\n        p_bc_loss_val = 0.0\n        if self._policy is not None:\n            n_bc_steps = 20\n            for _ in range(n_bc_steps):\n                idx = torch.randint(0, n_expert, (batch_size,))\n                exp_o = self.expert_demos[\"obs\"][idx]\n                exp_a = self.expert_demos[\"acts\"][idx]\n\n                _, log_prob, entropy, _ = self._policy.get_action_and_value(exp_o, exp_a)\n                p_bc_loss = -log_prob.mean() - 0.001 * entropy.mean()\n\n                self._policy_optimizer.zero_grad()\n                p_bc_loss.backward()\n                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n                self._policy_optimizer.step()\n                p_bc_loss_val += p_bc_loss.item()\n            p_bc_loss_val /= n_bc_steps\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"grad_penalty\": gp.item(),\n            \"p_bc_loss\": p_bc_loss_val\n        }",
  "start_line": 231,
  "end_line": 359,
  "op": "replace",
  "filename": "imitation/custom_irl.py"
}
OK: Replaced lines 231..359 with 135 line(s) in imitation/custom_irl.py. Editable range: 231–365.
[Current file: imitation/custom_irl.py | editable: 231–365 | total: 607 lines]
231 class RewardNetwork(nn.Module):
232 """AIRL shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s)."""
233
... (129 more lines) ...
363 "grad_penalty": gp.item(),
364 "p_bc_loss": p_bc_loss_val
365 }

[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.3114 v_loss=41.3663 entropy=-13.3465 irl_loss=0.8699 grad_penalty=0.0012 p_bc_loss=-13.7833
TRAIN_METRICS step=995328 pg_loss=0.3053 v_loss=22.9781 entropy=-13.3245 irl_loss=1.0233 grad_penalty=0.0013 p_bc_loss=-13.7528
TRAIN_METRICS step=997376 pg_loss=0.3080 v_loss=23.4172 entropy=-13.3865 irl_loss=1.0218 grad_penalty=0.0012 p_bc_loss=-13.8431
TRAIN_METRICS step=999424 pg_loss=0.3012 v_loss=17.5125 entropy=-13.4157 irl_loss=0.9965 grad_penalty=0.0011 p_bc_loss=-13.7837
TRAIN_METRICS step=1000000 pg_loss=0.3019 v_loss=19.5665 entropy=-13.4646 irl_loss=1.0194 grad_penalty=0.0014 p_bc_loss=-13.9039

Evaluation (halfcheetah-v4):
Eval episodic_return: 3338.48
Eval episodic_return: 3272.54
Eval episodic_return: 2969.47
Final eval return: 2969.47

### hopper-v4 (scripts/hopper.sh)
Training
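The entropy column in Test #2 (drifting to around -13.9, versus -10.9 in Test #1) suggests the log-likelihood BC term can be minimized partly by shrinking the policy's log-std rather than by improving its action means; the next edit replaces it with an MSE loss on the mean only. A minimal sketch contrasting the two regularizers, assuming a CleanRL-style Gaussian policy with an actor_mean network and a state-independent actor_logstd parameter (consistent with the get_action_and_value and actor_mean calls in this transcript, but not verified against the FIXED policy code):

import torch
import torch.nn.functional as F
from torch.distributions import Normal

def bc_logprob_loss(policy, exp_obs, exp_acts, ent_coef=0.001):
    """Log-likelihood BC (Tests #1-#2): gradients also flow into actor_logstd,
    so the loss can fall by collapsing the action variance."""
    mean = policy.actor_mean(exp_obs)
    std = policy.actor_logstd.exp().expand_as(mean)
    dist = Normal(mean, std)
    log_prob = dist.log_prob(exp_acts).sum(-1)
    entropy = dist.entropy().sum(-1)
    return -log_prob.mean() - ent_coef * entropy.mean()

def bc_mean_loss(policy, exp_obs, exp_acts):
    """Mean-only BC (Test #3): pulls actor_mean toward expert actions and
    leaves the exploration noise untouched."""
    return F.mse_loss(policy.actor_mean(exp_obs), exp_acts)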
Tool Input
{
  "op": "replace",
  "end_line": 365,
  "start_line": 231,
  "content": "class RewardNetwork(nn.Module):\n    \"\"\"LS-AIRL shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).\"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        return self.h_net(state).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n    \"\"\"Demo-Regularized LS-AIRL (Least-Squares AIRL + BC on Mean).\"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n        self.grad_penalty_coef = 10.0\n\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # LSGAN Loss\n        disc_loss = 0.5 * torch.mean((expert_logits - 1.0)**2) + 0.5 * torch.mean((gen_logits - (-1.0))**2)\n\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        loss = disc_loss + self.grad_penalty_coef * gp\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # Demo-Guided Policy Regularization (BC on Mean ONLY to prevent variance collapse)\n        p_bc_loss_val = 0.0\n        if self._policy is not None:\n            n_bc_steps = 20\n            for _ in range(n_bc_steps):\n                idx = torch.randint(0, n_expert, (batch_size,))\n                exp_o = self.expert_demos[\"obs\"][idx]\n                exp_a = self.expert_demos[\"acts\"][idx]\n\n                # Update only the mean to match expert actions, keeping logstd untouched\n                action_mean = self._policy.actor_mean(exp_o)\n                p_bc_loss = F.mse_loss(action_mean, exp_a)\n\n                self._policy_optimizer.zero_grad()\n                p_bc_loss.backward()\n                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n                self._policy_optimizer.step()\n                p_bc_loss_val += p_bc_loss.item()\n            p_bc_loss_val /= n_bc_steps\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"grad_penalty\": gp.item(),\n            \"p_bc_loss\": p_bc_loss_val\n        }",
  "filename": "imitation/custom_irl.py"
}
OK: Replaced lines 231..365 with 136 line(s) in imitation/custom_irl.py. Editable range: 231–366.
[Current file: imitation/custom_irl.py | editable: 231–366 | total: 608 lines]
231 class RewardNetwork(nn.Module):
232 """LS-AIRL shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s)."""
233
... (130 more lines) ...
364 "grad_penalty": gp.item(),
365 "p_bc_loss": p_bc_loss_val
366 }

[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=-0.0010 v_loss=49.9868 entropy=-0.6299 irl_loss=0.6296 grad_penalty=0.0022 p_bc_loss=0.0009
TRAIN_METRICS step=995328 pg_loss=-0.0016 v_loss=83.0296 entropy=-0.6479 irl_loss=0.5901 grad_penalty=0.0020 p_bc_loss=0.0009
TRAIN_METRICS step=997376 pg_loss=-0.0122 v_loss=37.8522 entropy=-0.6711 irl_loss=0.6081 grad_penalty=0.0020 p_bc_loss=0.0008
TRAIN_METRICS step=999424 pg_loss=-0.0052 v_loss=58.0866 entropy=-0.6993 irl_loss=0.6023 grad_penalty=0.0018 p_bc_loss=0.0009
TRAIN_METRICS step=1000000 pg_loss=0.0383 v_loss=62.9409 entropy=-0.7076 irl_loss=0.6384 grad_penalty=0.0019 p_bc_loss=0.0008

Evaluation (halfcheetah-v4):
Eval episodic_return: 3008.39
Eval episodic_ret
Tool Input
{
  "n": 3
}
[submit] Submitting result from test #3 as final.

### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=-0.0010 v_loss=49.9868 entropy=-0.6299 irl_loss=0.6296 grad_penalty=0.0022 p_bc_loss=0.0009
TRAIN_METRICS step=995328 pg_loss=-0.0016 v_loss=83.0296 entropy=-0.6479 irl_loss=0.5901 grad_penalty=0.0020 p_bc_loss=0.0009
TRAIN_METRICS step=997376 pg_loss=-0.0122 v_loss=37.8522 entropy=-0.6711 irl_loss=0.6081 grad_penalty=0.0020 p_bc_loss=0.0008
TRAIN_METRICS step=999424 pg_loss=-0.0052 v_loss=58.0866 entropy=-0.6993 irl_loss=0.6023 grad_penalty=0.0018 p_bc_loss=0.0009
TRAIN_METRICS step=1000000 pg_loss=0.0383 v_loss=62.9409 entropy=-0.7076 irl_loss=0.6384 grad_penalty=0.0019 p_bc_loss=0.0008

Evaluation (halfcheetah-v4):
Eval episodic_return: 3008.39
Eval episodic_return: 2991.91
Eval episodic_return: 3014.00
Final eval return: 3014.00

### hopper-v4 (scripts/hopper.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_los