Agent Conversation: deepseek-reasoner
Task: rl-reward-learning
Inverse RL: Reward Learning from Expert Demonstrations
Objective
Design and implement an inverse reinforcement learning (IRL) algorithm that learns a reward function from expert demonstrations. Your code goes in custom_irl.py, specifically the RewardNetwork and IRLAlgorithm classes. Three reference implementations (GAIL, AIRL, BC) from the imitation library are provided as read-only context.
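The training harness drives these two classes through a fixed interface. The skeleton below is a minimal sketch of that interface, inferred from the replace operations later in this transcript; the constructor arguments, method signatures, and the args.irl_lr field are assumptions taken from that code rather than a separately documented API.

import torch
import torch.nn as nn
import torch.optim as optim


class RewardNetwork(nn.Module):
    # Maps a transition (s, a, s') to a scalar reward logit.
    def __init__(self, obs_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim + obs_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, state, action, next_state):
        x = torch.cat([state, action, next_state], dim=-1)
        return self.net(x).squeeze(-1)  # shape: (batch,)


class IRLAlgorithm:
    # Owns the reward network and its optimizer. The fixed PPO loop calls
    # compute_reward() to relabel rollouts and update() once per iteration.
    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):
        self.reward_net = reward_net
        self.expert_demos = expert_demos  # dict with "obs", "acts", "next_obs" tensors
        self.device = device
        self.args = args
        self.optimizer = optim.Adam(reward_net.parameters(), lr=args.irl_lr)

    def compute_reward(self, obs, acts, next_obs):
        with torch.no_grad():
            return self.reward_net(obs, acts, next_obs)

    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):
        # One reward-learning step (e.g. a discriminator update); returns a metrics dict.
        raise NotImplementedError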
Background
Inverse reinforcement learning recovers a reward function that explains observed expert behavior. The learned reward is then used to train a policy via standard RL (PPO in this benchmark). Key challenges include:
- Designing reward network architectures that capture the structure of expert behavior
- Balancing discriminator training with policy improvement
- Avoiding reward hacking where the policy exploits learned reward artifacts
- Ensuring the learned reward generalizes across different states visited during training
Different IRL approaches address these through adversarial training (GAIL), potential-based reward shaping (AIRL), or direct behavioral cloning. Your goal is to design a novel reward network architecture or IRL training algorithm that outperforms these baselines.
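As a concrete reference for the AIRL-style decomposition mentioned above, the sketch below splits the learned reward into a state-action term g and a potential-based shaping term built from h; the class name ShapedRewardNet and the layer sizes are illustrative assumptions, not the imitation library's API.

import torch
import torch.nn as nn


class ShapedRewardNet(nn.Module):
    # AIRL-style reward: f(s, a, s') = g(s, a) + gamma * h(s') - h(s).
    # The potential term h only shapes the reward, so it cannot change which
    # policy is optimal, while g is free to capture the expert's preferences.
    def __init__(self, obs_dim, action_dim, gamma=0.99):
        super().__init__()
        self.gamma = gamma
        self.g = nn.Sequential(nn.Linear(obs_dim + action_dim, 256), nn.ReLU(), nn.Linear(256, 1))
        self.h = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, 1))

    def forward(self, state, action, next_state):
        g = self.g(torch.cat([state, action], dim=-1)).squeeze(-1)
        shaping = self.gamma * self.h(next_state).squeeze(-1) - self.h(state).squeeze(-1)
        return g + shaping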
Evaluation
Policies are trained and evaluated on three MuJoCo locomotion environments using pre-generated expert demonstrations: HalfCheetah-v4, Hopper-v4, and Walker2d-v4. Metric: mean episodic return over 10 evaluation episodes (higher is better). The policy is trained with PPO using the learned reward signal.
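A minimal sketch of this evaluation protocol, assuming a Gymnasium environment and a policy object with an act(obs) helper (the helper name is an assumption; the benchmark's actual evaluation code lives in the fixed section of custom_irl.py):

import gymnasium as gym
import numpy as np


def evaluate(policy, env_id="HalfCheetah-v4", n_episodes=10, seed=0):
    # Mean episodic return over n_episodes, measured with the environment's true reward.
    env = gym.make(env_id)
    returns = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)
        done, total = False, 0.0
        while not done:
            action = policy.act(obs)  # assumed eval-time action helper
            obs, reward, terminated, truncated, _ = env.step(action)
            total += float(reward)
            done = terminated or truncated
        returns.append(total)
    env.close()
    return float(np.mean(returns))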
imitation/custom_irl.py [EDITABLE — lines 231–357 only]
1: # Custom IRL / Reward Learning algorithm for MLS-Bench
2: #
3: # EDITABLE section: RewardNetwork and IRLAlgorithm classes.
4: # FIXED sections: everything else (config, env, demo loading, PPO training, evaluation).
5: import os
6: import random
7: import time
8: from dataclasses import dataclass
9:
10: import gymnasium as gym
11: import numpy as np
12: import torch
13: import torch.nn as nn
14: import torch.nn.functional as F
15: import torch.optim as optim
16:
17:
18: # =====================================================================
19: # FIXED: Configuration
20: # =====================================================================
21: @dataclass
22: class Args:
23:     env_id: str = "HalfCheetah-v4"
24:     seed: int = 42
25:     torch_deterministic: bool = True
26:     cuda: bool = True
27:     # IRL training
28:     irl_epochs: int = 200
29:     irl_batch_size: int = 256
30:     irl_lr: float = 3e-4
31:     demo_path: str = ""  # set from env or CLI
32:     # Policy training (PPO via custom loop)
33:     total_timesteps: int = 1000000
34:     policy_lr: float = 3e-4
35:     gamma: float = 0.99
36:     gae_lambda: float = 0.95
37:     n_steps: int = 2048
38:     n_epochs: int = 10
39: ...
Tool Input
1{
2 "op": "replace",
3 "filename": "imitation/custom_irl.py",
4 "start_line": 231,
5 "end_line": 357,
6 "content": "class RewardNetwork(nn.Module):\n \"\"\"Reward network with uncertainty estimation.\n \n Outputs reward logit and log variance for each transition.\n \"\"\"\n\n def __init__(self, obs_dim, action_dim):\n super().__init__()\n input_dim = obs_dim + action_dim + obs_dim\n self.shared = nn.Sequential(\n nn.Linear(input_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n )\n self.reward_head = nn.Linear(256, 1)\n self.logvar_head = nn.Linear(256, 1)\n \n def forward(self, state, action, next_state):\n \"\"\"Compute reward logit and log variance for a batch of transitions.\n \n Returns:\n reward_logit: (batch,)\n log_var: (batch,)\n \"\"\"\n x = torch.cat([state, action, next_state], dim=-1)\n x = self.shared(x)\n reward_logit = self.reward_head(x).squeeze(-1)\n log_var = self.logvar_head(x).squeeze(-1)\n return reward_logit, log_var\n\n\nclass IRLAlgorithm:\n \"\"\"Uncertainty-Aware Adversarial IRL with Consistency Regularization.\n \n Features:\n 1. Uncertainty-weighted adversarial loss (inverse variance weighting).\n 2. Consistency regularization: reward should be smooth under small perturbations.\n 3. Entropy regularization on discriminator outputs to prevent overconfidence.\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n self.total_updates = 0\n \n # Hyperparameters\n self.uncertainty_weight_coef = 1.0 # weight by inverse variance\n self.consistency_coef = 0.1 # consistency loss weight\n self.entropy_coef = 0.01 # entropy regularization weight\n self.label_smoothing = 0.1\n self.grad_penalty_coef = 10.0\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"Compute learned reward: reward logit minus uncertainty penalty.\"\"\"\n with torch.no_grad():\n reward_logit, log_var = self.reward_net(obs, acts, next_obs)\n variance = torch.exp(log_var)\n # Reward = logit - beta * variance (encourage exploration in uncertain regions)\n beta = 0.1\n return reward_logit - beta * variance\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs):\n \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits, _ = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def _consistency_loss(self, obs, acts, next_obs, noise_std=0.01):\n \"\"\"Encourage reward consistency under small input perturbations.\"\"\"\n batch_size = obs.shape[0]\n # Add small Gaussian noise\n obs_noise = torch.randn_like(obs) * noise_std\n acts_noise = torch.randn_like(acts) * noise_std\n next_noise = torch.randn_like(next_obs) * 
noise_std\n \n logits_orig, _ = self.reward_net(obs, acts, next_obs)\n logits_noisy, _ = self.reward_net(obs + obs_noise, acts + acts_noise, next_obs + next_noise)\n \n return F.mse_loss(logits_orig, logits_noisy)\n\n def _entropy_regularization(self, logits, labels):\n \"\"\"Encourage discriminator to be uncertain (high entropy) about predictions.\"\"\"\n # Convert logits to probabilities\n probs = torch.sigmoid(logits)\n # Binary entropy: -p*log(p) - (1-p)*log(1-p)\n entropy = -probs * torch.log(probs + 1e-8) - (1 - probs) * torch.log(1 - probs + 1e-8)\n # We want high entropy (uncertainty) for generated data, low for expert?\n # Actually we want discriminator to be uncertain about both? Let's just maximize average entropy.\n return -entropy.mean() # negative because we want to maximize entropy\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Update reward network with uncertainty weighting and consistency regularization.\"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # Sample expert data\n n_expert = len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # Get discriminator outputs with uncertainty\n expert_logits, expert_logvar = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n gen_logits, gen_logvar = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n \n # Uncertainty weighting: weight = 1 / (variance + epsilon)\n expert_variance = torch.exp(expert_logvar)\n gen_variance = torch.exp(gen_logvar)\n expert_weights = 1.0 / (expert_variance + 1e-8)\n gen_weights = 1.0 / (gen_variance + 1e-8)\n \n # Normalize weights within each batch\n expert_weights = expert_weights / (expert_weights.mean() + 1e-8)\n gen_weights = gen_weights / (gen_weights.mean() + 1e-8)\n \n # Labels with smoothing\n eps = self.label_smoothing\n expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n gen_labels = torch.full((batch_size,), eps, device=self.device)\n \n # Weighted BCE loss\n expert_bce = F.binary_cross_entropy_with_logits(\n expert_logits, expert_labels, reduction='none'\n ) * expert_weights\n gen_bce = F.binary_cross_entropy_with_logits(\n gen_logits, gen_labels, reduction='none'\n ) * gen_weights\n bce_loss = (expert_bce.mean() + gen_bce.mean()) / 2.0\n \n # Consistency loss on both expert and policy data\n cons_loss_expert = self._consistency_loss(expert_obs, expert_acts, expert_next_obs)\n cons_loss_gen = self._consistency_loss(gen_obs, gen_acts, gen_next_obs)\n consistency_loss = (cons_loss_expert + cons_loss_gen) / 2.0\n \n # Entropy regularization on discriminator outputs\n all_logits = torch.cat([expert_logits, gen_logits], dim=0)\n all_labels = torch.cat([expert_labels, gen_labels], dim=0)\n entropy_loss = self._entropy_regularization(all_logits, all_labels)\n \n # Gradient penalty\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n \n # Total loss\n loss = (bce_loss + \n self.grad_penalty_coef * gp +\n self.consistency_coef * consistency_loss +\n self.entropy_coef * entropy_loss)\n \n 
self.optimizer.zero_grad()\n loss.backward()\n self.optimizer.step()\n \n with torch.no_grad():\n hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n torch.zeros(batch_size, device=self.device)], dim=0)\n acc = ((all_logits > 0).float() == hard_labels).float().mean().item()\n \n return {\n \"irl_loss\": loss.item(),\n \"disc_acc\": acc,\n \"grad_penalty\": gp.item(),\n \"consistency_loss\": consistency_loss.item(),\n \"entropy_loss\": entropy_loss.item(),\n \"expert_var\": expert_variance.mean().item(),\n \"gen_var\": gen_variance.mean().item(),\n }"
7}
OK: Replaced lines 231..357 with 199 line(s) in imitation/custom_irl.py. Editable range: 231–429.
[Current file: imitation/custom_irl.py | editable: 231–429 | total: 671 lines]
231 class RewardNetwork(nn.Module):
232 """Reward network with uncertainty estimation.
233
... (193 more lines) ...
427 "expert_var": expert_variance.mean().item(),
428 "gen_var": gen_variance.mean().item(),
429 }
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=-0.0187 v_loss=0.0003 entropy=4.5918 irl_loss=0.3268 disc_acc=0.7520 grad_penalty=0.0006 consistency_loss=0.0001 entropy_loss=-0.4838 expert_var=12268.6602 gen_var=0.7329
TRAIN_METRICS step=995328 pg_loss=-0.0188 v_loss=0.0002 entropy=4.5676 irl_loss=0.3267 disc_acc=0.7461 grad_penalty=0.0006 consistency_loss=0.0001 entropy_loss=-0.4842 expert_var=9476.1035 gen_var=0.5576
TRAIN_METRICS step=997376 pg_loss=-0.0100 v_loss=0.0003 entropy=4.5703 irl_loss=0.3264 disc_acc=0.7910 grad_penalty=0.0005 consistency_loss=0.0001 entropy_loss=-0.4845 expert_var=17365.0918 gen_var=1.1725
TRAIN_METRICS step=999424 pg_loss=-0.0161 v_loss=0.0005 entropy=4.5985 irl_loss=0.3256 disc_acc=0.7441 grad_penalty=0.0004 consistency_loss=0.0001 entropy_loss=-0.4863 expert_var=9939.9092 gen_var=0.4175
TRAIN_METRI
Tool Input
1{
2 "op": "replace",
3 "filename": "imitation/custom_irl.py",
4 "start_line": 267,
5 "end_line": 429,
6 "content": "class IRLAlgorithm:\n \"\"\"Contrastive Adversarial IRL with Variance Regularization.\n \n Uses a discriminator with variance estimation to weight samples.\n Adds consistency regularization and variance control.\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n self.total_updates = 0\n \n # Hyperparameters\n self.label_smoothing = 0.1\n self.grad_penalty_coef = 10.0\n self.consistency_coef = 0.05\n self.variance_reg_coef = 0.01\n self.variance_prior = 1.0 # target variance\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"Compute learned reward: raw discriminator logit (like GAIL).\"\"\"\n with torch.no_grad():\n reward_logit, _ = self.reward_net(obs, acts, next_obs)\n return reward_logit\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs):\n \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits, _ = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def _consistency_loss(self, obs, acts, next_obs, noise_std=0.01):\n \"\"\"Encourage reward consistency under small input perturbations.\"\"\"\n batch_size = obs.shape[0]\n # Add small Gaussian noise\n obs_noise = torch.randn_like(obs) * noise_std\n acts_noise = torch.randn_like(acts) * noise_std\n next_noise = torch.randn_like(next_obs) * noise_std\n \n logits_orig, _ = self.reward_net(obs, acts, next_obs)\n logits_noisy, _ = self.reward_net(obs + obs_noise, acts + acts_noise, next_obs + next_noise)\n \n return F.mse_loss(logits_orig, logits_noisy)\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Update reward network with variance regularization and consistency.\"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # Sample expert data\n n_expert = len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # Get discriminator outputs\n expert_logits, expert_logvar = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n gen_logits, gen_logvar = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n \n # Variance regularization: encourage 
expert variance low, policy variance close to prior\n expert_variance = torch.exp(expert_logvar)\n gen_variance = torch.exp(gen_logvar)\n \n var_loss_expert = expert_variance.mean() # minimize expert variance\n var_loss_gen = (gen_variance - self.variance_prior).pow(2).mean() # keep policy variance near prior\n variance_loss = var_loss_expert + var_loss_gen\n \n # Labels with smoothing\n eps = self.label_smoothing\n expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n gen_labels = torch.full((batch_size,), eps, device=self.device)\n \n # Standard BCE loss\n expert_bce = F.binary_cross_entropy_with_logits(expert_logits, expert_labels)\n gen_bce = F.binary_cross_entropy_with_logits(gen_logits, gen_labels)\n bce_loss = (expert_bce + gen_bce) / 2.0\n \n # Consistency loss on both expert and policy data\n cons_loss_expert = self._consistency_loss(expert_obs, expert_acts, expert_next_obs)\n cons_loss_gen = self._consistency_loss(gen_obs, gen_acts, gen_next_obs)\n consistency_loss = (cons_loss_expert + cons_loss_gen) / 2.0\n \n # Gradient penalty\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n \n # Total loss\n loss = (bce_loss + \n self.grad_penalty_coef * gp +\n self.consistency_coef * consistency_loss +\n self.variance_reg_coef * variance_loss)\n \n self.optimizer.zero_grad()\n loss.backward()\n self.optimizer.step()\n \n with torch.no_grad():\n hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n torch.zeros(batch_size, device=self.device)], dim=0)\n all_logits = torch.cat([expert_logits, gen_logits], dim=0)\n acc = ((all_logits > 0).float() == hard_labels).float().mean().item()\n \n return {\n \"irl_loss\": loss.item(),\n \"disc_acc\": acc,\n \"grad_penalty\": gp.item(),\n \"consistency_loss\": consistency_loss.item(),\n \"variance_loss\": variance_loss.item(),\n \"expert_var\": expert_variance.mean().item(),\n \"gen_var\": gen_variance.mean().item(),\n }"
7}
OK: Replaced lines 267..429 with 142 line(s) in imitation/custom_irl.py. Editable range: 231–408.
[Current file: imitation/custom_irl.py | editable: 231–408 | total: 650 lines]
231 class RewardNetwork(nn.Module):
232 """Reward network with uncertainty estimation.
233
... (172 more lines) ...
406 "expert_var": expert_variance.mean().item(),
407 "gen_var": gen_variance.mean().item(),
408 }
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### halfcheetah-v4 (scripts/halfcheetah.sh)
[COMMAND FAILED exit=1]
/usr/bin/bash: warning: setlocale: LC_ALL: cannot change locale (zh_CN.UTF-8)
File "/workspace/imitation/custom_irl.py", line 413
"""Run PPO update on the rollout buffer. Returns metrics dict."""
^
SyntaxError: unterminated triple-quoted string literal (detected at line 650)
### hopper-v4 (scripts/hopper.sh)
[COMMAND FAILED exit=1]
/usr/bin/bash: warning: setlocale: LC_ALL: cannot change locale (zh_CN.UTF-8)
File "/workspace/imitation/custom_irl.py", line 413
"""Run PPO update on the rollout buffer. Returns metrics dict."""
^
SyntaxError: unterminated triple-quoted string literal (detected at line 650)
### walker2d-v4 (scripts/walker2d.sh)
[COMMAND FAILED exit=1]
/usr/bin/bash: wa
Tool Input
1{
2 "n": 1
3}
Undo complete: Restored: imitation/custom_irl.py
Tool Input
1{
2 "op": "replace",
3 "filename": "imitation/custom_irl.py",
4 "start_line": 231,
5 "end_line": 429,
6 "content": "class RewardNetwork(nn.Module):\n \"\"\"Reward network R(s, a, s') -> scalar.\n \n Takes state, action, next_state as input and outputs a scalar reward.\n This is the discriminator/reward model used in IRL.\n \"\"\"\n\n def __init__(self, obs_dim, action_dim):\n super().__init__()\n input_dim = obs_dim + action_dim + obs_dim\n self.net = nn.Sequential(\n nn.Linear(input_dim, 256),\n nn.ReLU(),\n nn.Linear(256, 256),\n nn.ReLU(),\n nn.Linear(256, 1),\n )\n\n def forward(self, state, action, next_state):\n \"\"\"Compute reward for a batch of transitions.\n \n Args:\n state: (batch, obs_dim) current observations\n action: (batch, action_dim) actions taken\n next_state: (batch, obs_dim) next observations\n \n Returns:\n Reward tensor of shape (batch,)\n \"\"\"\n x = torch.cat([state, action, next_state], dim=-1)\n return self.net(x).squeeze(-1)\n\n\nclass IRLAlgorithm:\n \"\"\"Adversarial Behavioral Cloning (ABC): Combines GAIL discriminator with BC policy updates.\n \n Trains a discriminator to distinguish expert from policy data (GAIL),\n while also performing supervised BC updates on expert data.\n \"\"\"\n\n def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n self.reward_net = reward_net\n self.expert_demos = expert_demos\n self.device = device\n self.args = args\n self.obs_dim = obs_dim\n self.action_dim = action_dim\n\n # Discriminator optimizer\n disc_lr = min(args.irl_lr, 1e-4)\n self.disc_optimizer = optim.Adam(self.reward_net.parameters(), lr=disc_lr)\n self.total_updates = 0\n \n # Hyperparameters\n self.grad_penalty_coef = 10.0\n self.label_smoothing = 0.1\n self.bc_weight = 0.5 # weight of BC loss relative to discriminator loss\n self.bc_entropy_coef = 0.001\n \n # Policy reference (set later)\n self.policy = None\n self.policy_optimizer = None\n\n def set_policy(self, policy, optimizer):\n \"\"\"Set reference to policy for BC updates.\"\"\"\n self.policy = policy\n self.policy_optimizer = optimizer\n\n def compute_reward(self, obs, acts, next_obs):\n \"\"\"GAIL reward: use raw discriminator logit as reward.\"\"\"\n with torch.no_grad():\n logits = self.reward_net(obs, acts, next_obs)\n return logits\n\n def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs):\n \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n batch_size = expert_obs.shape[0]\n alpha = torch.rand(batch_size, 1, device=self.device)\n\n interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n logits = self.reward_net(interp_obs, interp_acts, interp_next)\n ones = torch.ones_like(logits)\n\n grads = torch.autograd.grad(\n outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n grad_outputs=ones, create_graph=True, retain_graph=True,\n )\n grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n return torch.mean((grad_norm - 1.0) ** 2)\n\n def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n \"\"\"Update discriminator with GAIL loss and policy with BC loss.\"\"\"\n self.total_updates += 1\n batch_size = self.args.irl_batch_size\n\n # Sample expert data for discriminator and BC\n n_expert = len(self.expert_demos[\"obs\"])\n expert_idx = torch.randint(0, n_expert, (batch_size,))\n 
expert_obs = self.expert_demos[\"obs\"][expert_idx]\n expert_acts = self.expert_demos[\"acts\"][expert_idx]\n expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n # Sample policy data for discriminator\n n_policy = len(policy_obs)\n policy_idx = torch.randint(0, n_policy, (batch_size,))\n gen_obs = policy_obs[policy_idx]\n gen_acts = policy_acts[policy_idx]\n gen_next_obs = policy_next_obs[policy_idx]\n\n # ----- Discriminator update (GAIL) -----\n expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n \n # Labels with smoothing\n eps = self.label_smoothing\n expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n gen_labels = torch.full((batch_size,), eps, device=self.device)\n \n logits = torch.cat([expert_logits, gen_logits], dim=0)\n labels = torch.cat([expert_labels, gen_labels], dim=0)\n \n bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n \n # Gradient penalty\n gp = self._gradient_penalty(\n expert_obs, expert_acts, expert_next_obs,\n gen_obs, gen_acts, gen_next_obs,\n )\n \n disc_loss = bce_loss + self.grad_penalty_coef * gp\n \n self.disc_optimizer.zero_grad()\n disc_loss.backward()\n self.disc_optimizer.step()\n \n with torch.no_grad():\n hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n torch.zeros(batch_size, device=self.device)], dim=0)\n disc_acc = ((logits > 0).float() == hard_labels).float().mean().item()\n \n # ----- Behavioral Cloning update -----\n bc_loss = 0.0\n if self.policy is not None:\n # Use a separate batch of expert data for BC\n bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n \n # Compute negative log-likelihood of expert actions under current policy\n _, log_prob, entropy, _ = self.policy.get_action_and_value(\n bc_expert_obs, bc_expert_acts,\n )\n neglogp = -log_prob.mean()\n ent_bonus = -self.bc_entropy_coef * entropy.mean()\n \n bc_loss = neglogp + ent_bonus\n \n # Update policy with BC loss\n self.policy_optimizer.zero_grad()\n bc_loss.backward()\n nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)\n self.policy_optimizer.step()\n \n return {\n \"irl_loss\": disc_loss.item(),\n \"disc_acc\": disc_acc,\n \"grad_penalty\": gp.item(),\n \"bc_loss\": bc_loss.item() if self.policy is not None else 0.0,\n }"
7}
OK: Replaced lines 231..429 with 174 line(s) in imitation/custom_irl.py. Editable range: 231–404.
[Current file: imitation/custom_irl.py | editable: 231–404 | total: 646 lines]
231 class RewardNetwork(nn.Module):
232 """Reward network R(s, a, s') -> scalar.
233
... (168 more lines) ...
402 "grad_penalty": gp.item(),
403 "bc_loss": bc_loss.item() if self.policy is not None else 0.0,
404 }
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.2028 v_loss=89.0396 entropy=2.3968 irl_loss=0.4323 disc_acc=0.9648 grad_penalty=0.0030 bc_loss=0.9373
TRAIN_METRICS step=995328 pg_loss=0.2743 v_loss=51.2301 entropy=2.4133 irl_loss=0.4037 disc_acc=0.9902 grad_penalty=0.0039 bc_loss=1.4728
TRAIN_METRICS step=997376 pg_loss=0.2532 v_loss=53.9700 entropy=2.4409 irl_loss=0.3951 disc_acc=0.9883 grad_penalty=0.0035 bc_loss=1.9543
TRAIN_METRICS step=999424 pg_loss=0.1881 v_loss=33.4631 entropy=2.4599 irl_loss=0.3807 disc_acc=0.9922 grad_penalty=0.0031 bc_loss=1.7203
TRAIN_METRICS step=1000000 pg_loss=0.1883 v_loss=37.5047 entropy=2.4909 irl_loss=0.3964 disc_acc=1.0000 grad_penalty=0.0048 bc_loss=1.7508
Evaluation (hal
Tool Input
1{
2 "n": 3
3}
[submit] Submitting result from test #3 as final.
### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.2028 v_loss=89.0396 entropy=2.3968 irl_loss=0.4323 disc_acc=0.9648 grad_penalty=0.0030 bc_loss=0.9373
TRAIN_METRICS step=995328 pg_loss=0.2743 v_loss=51.2301 entropy=2.4133 irl_loss=0.4037 disc_acc=0.9902 grad_penalty=0.0039 bc_loss=1.4728
TRAIN_METRICS step=997376 pg_loss=0.2532 v_loss=53.9700 entropy=2.4409 irl_loss=0.3951 disc_acc=0.9883 grad_penalty=0.0035 bc_loss=1.9543
TRAIN_METRICS step=999424 pg_loss=0.1881 v_loss=33.4631 entropy=2.4599 irl_loss=0.3807 disc_acc=0.9922 grad_penalty=0.0031 bc_loss=1.7203
TRAIN_METRICS step=1000000 pg_loss=0.1883 v_loss=37.5047 entropy=2.4909 irl_loss=0.3964 disc_acc=1.0000 grad_penalty=0.0048 bc_loss=1.7508
Evaluation (halfcheetah-v4):
Eval episodic_return: 658.49
Eval episodic_return: 452.97
Eval episodic_return: 1066.74
Final eval return: 1066.74
### hopper-v4 (scripts/hopper.sh)
Trai