Agent Conversation: qwen3.6-plus

Model: qwen3.6-plus
Total steps: 17
Messages: 36
User (Step 0)

Task: rl-reward-learning

Inverse RL: Reward Learning from Expert Demonstrations

Objective

Design and implement an inverse reinforcement learning (IRL) algorithm that learns a reward function from expert demonstrations. Your code goes in custom_irl.py, specifically the RewardNetwork and IRLAlgorithm classes. Three reference implementations (GAIL, AIRL, BC) from the imitation library are provided as read-only context.

Background

Inverse reinforcement learning recovers a reward function that explains observed expert behavior. The learned reward is then used to train a policy via standard RL (PPO in this benchmark). Key challenges include:

  • Designing reward network architectures that capture the structure of expert behavior
  • Balancing discriminator training with policy improvement
  • Avoiding reward hacking where the policy exploits learned reward artifacts
  • Ensuring the learned reward generalizes across different states visited during training

Different IRL approaches address these through adversarial training (GAIL), potential-based reward shaping (AIRL), or direct behavioral cloning. Your goal is to design a novel reward network architecture or IRL training algorithm that outperforms these baselines.
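AIRL's potential-based shaping term is worth pausing on: adding gamma*h(s') - h(s) to a reward changes the discounted return of every trajectory by the same policy-independent constant, so it cannot change which policy is optimal. A minimal pure-Python sketch of this telescoping identity (the function names and numbers here are illustrative, not part of the benchmark code):

```python
import math

def discounted_sum(xs, gamma):
    # Standard discounted return: sum_t gamma^t * x_t
    return sum((gamma ** t) * x for t, x in enumerate(xs))

def shaped_rewards(rewards, potentials, gamma):
    # AIRL-style shaping: r'_t = r_t + gamma*h(s_{t+1}) - h(s_t).
    # `potentials` holds one h(s) value per state, so it has
    # length len(rewards) + 1.
    return [r + gamma * potentials[t + 1] - potentials[t]
            for t, r in enumerate(rewards)]

gamma = 0.99
rewards = [1.0, 0.5, -0.2, 2.0]          # arbitrary base rewards
potentials = [0.3, -0.1, 0.7, 0.2, 0.4]  # arbitrary h(s_0) .. h(s_4)

base = discounted_sum(rewards, gamma)
shaped = discounted_sum(shaped_rewards(rewards, potentials, gamma), gamma)

# The shaping terms telescope: the two returns differ only by
# gamma^T * h(s_T) - h(s_0), a constant independent of the actions taken.
T = len(rewards)
offset = (gamma ** T) * potentials[-1] - potentials[0]
print(abs((shaped - base) - offset) < 1e-9)  # True
```

Because the difference is a constant offset, the h-network in AIRL can absorb dynamics-dependent structure without distorting the policy ranking, which is the stability argument behind this decomposition.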

Evaluation

Policies are trained and evaluated on three MuJoCo locomotion environments using pre-generated expert demonstrations: HalfCheetah-v4, Hopper-v4, and Walker2d-v4. Metric: mean episodic return over 10 evaluation episodes (higher is better). The policy is trained using PPO with the learned reward signal.
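The metric reduces to summing the environment's per-step rewards within each evaluation episode and averaging across episodes. A minimal sketch (the function names and the toy episode data are hypothetical stand-ins for real MuJoCo rollouts):

```python
def episodic_return(step_rewards):
    # Undiscounted sum of rewards within one evaluation episode.
    return sum(step_rewards)

def mean_episodic_return(episodes):
    # Benchmark metric: average return over the evaluation episodes
    # (10 per environment in this task); higher is better.
    return sum(episodic_return(ep) for ep in episodes) / len(episodes)

# Illustrative data: three short episodes of per-step rewards.
episodes = [[1.0, 2.0], [0.5, 1.5], [1.0, 2.0, 1.0]]
print(mean_episodic_return(episodes))  # 3.0
```

Note the average is taken over episodes, not over steps, so episodes of different lengths are weighted equally.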

imitation/custom_irl.py [EDITABLE — lines 231–357 only]

     1: # Custom IRL / Reward Learning algorithm for MLS-Bench
     2: #
     3: # EDITABLE section: RewardNetwork and IRLAlgorithm classes.
     4: # FIXED sections: everything else (config, env, demo loading, PPO training, evaluation).
     5: import os
     6: import random
     7: import time
     8: from dataclasses import dataclass
     9: 
    10: import gymnasium as gym
    11: import numpy as np
    12: import torch
    13: import torch.nn as nn
    14: import torch.nn.functional as F
    15: import torch.optim as optim
    16: 
    17: 
    18: # =====================================================================
    19: # FIXED: Configuration
    20: # =====================================================================
    21: @dataclass
    22: class Args:
    23:     env_id: str = "HalfCheetah-v4"
    24:     seed: int = 42
    25:     torch_deterministic: bool = True
    26:     cuda: bool = True
    27:     # IRL training
    28:     irl_epochs: int = 200
    29:     irl_batch_size: int = 256
    30:     irl_lr: float = 3e-4
    31:     demo_path: str = ""  # set from env or CLI
    32:     # Policy training (PPO via custom loop)
    33:     total_timesteps: int = 1000000
    34:     policy_lr: float = 3e-4
    35:     gamma: float = 0.99
    36:     gae_lambda: float = 0.95
    37:     n_steps: int = 2048
    38:     n_epochs: int = 10
    39
Assistant (Step 1): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "LLaMA-Factory/custom_irl.py",
4  "start_line": 231,
5  "end_line": 357,
6  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Reward network with action-similarity embedding and potential shaping.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s')\n\n    where:\n    - g(s,a) is the base reward function\n    - h(s) is potential-based shaping for stability\n    - phi(s,a,s') is an action-similarity bonus that rewards transitions\n      where the action aligns with expert-like behavior patterns\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): base reward function\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        # h(s): potential-based shaping function\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 128),\n            nn.ReLU(),\n            nn.Linear(128, 1),\n        )\n        # phi(s,a,s'): action-similarity embedding\n        # Uses a smaller network to model transition quality\n        self.phi_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim + obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 64),\n            nn.ReLU(),\n            nn.Linear(64, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        \"\"\"Base reward function g(s, a).\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        return self.h_net(state).squeeze(-1)\n\n    def phi(self, state, action, next_state):\n        \"\"\"Action-similarity bonus phi(s,a,s').\"\"\"\n        x = torch.cat([state, action, next_state], dim=-1)\n        return self.phi_net(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: 
f(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s').\"\"\"\n        return (self.g(state, action)\n                + self.gamma * self.h(next_state)\n                - self.h(state)\n                + self.phi(state, action, next_state))\n\n\nclass IRLAlgorithm:\n    \"\"\"Advantage-Weighted IRL with Transition Consistency (AW-IRL).\n\n    Combines three objectives:\n    1. GAIL-style discriminator to separate expert from policy data\n    2. Transition consistency: reward transitions that look \"expert-like\"\n    3. Advantage-weighted BC: directly train policy on expert data, weighted\n       by how expert-like the discriminator thinks each transition is\n\n    The key insight: the discriminator provides a notion of \"expert-ness\"\n    for each transition. We use this to weight BC updates, so the policy\n    focuses on mimicking expert behavior in regions where the discriminator\n    is most confident about what's expert-like.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded as per instructions)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.bc_weight = 0.5  # Weight for BC auxiliary loss\n        self.transition_reg = 0.1  # Transition consistency regularization\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        \"\"\"Set reference to policy for BC updates.\"\"\"\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for 
policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"AW-IRL update: discriminator + transition consistency + BC.\n\n        1. Train discriminator to separate expert from policy data\n        2. Add transition consistency regularization\n        3. 
Optionally perform BC update on policy weighted by expert-ness\n        \"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        # Sample policy data\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        # ── Discriminator logits ──\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # Labels with smoothing\n        eps = self.label_smoothing\n        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n        gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n        logits = torch.cat([expert_logits, gen_logits], dim=0)\n        labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n        # Gradient penalty\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        # ── Transition Consistency Regularization ──\n        # Encourage the reward network to give consistent rewards for\n        # expert transitions: transitions with similar (s,a,s') should\n        # have similar rewards. 
This prevents reward hacking.\n        with torch.no_grad():\n            # Compute pairwise distances between expert transitions\n            expert_features = torch.cat([expert_obs, expert_acts, expert_next_obs], dim=-1)\n            # Subsample for efficiency\n            sub_size = min(64, batch_size)\n            sub_idx = torch.randperm(batch_size, device=self.device)[:sub_size]\n            sub_features = expert_features[sub_idx]\n\n        # Compute reward variance for nearby transitions (should be low)\n        # Using a simple approach: reward of expert transitions should be\n        # tightly clustered (low variance)\n        expert_rewards = self.reward_net.g(expert_obs, expert_acts)\n        reward_variance = torch.var(expert_rewards)\n\n        # Transition consistency: nearby states should have smooth rewards\n        # Compute reward differences for randomly paired expert transitions\n        perm_idx = torch.randperm(batch_size, device=self.device)\n        diff_obs = expert_obs - expert_obs[perm_idx]\n        diff_acts = expert_acts - expert_acts[perm_idx]\n        dist = torch.norm(diff_obs, dim=-1) + torch.norm(diff_acts, dim=-1)\n\n        reward_diff = torch.abs(\n            self.reward_net.g(expert_obs, expert_acts) -\n            self.reward_net.g(expert_obs[perm_idx], expert_acts[perm_idx])\n        )\n        # Penalize large reward differences for similar transitions\n        # Weight by inverse distance (closer transitions should have more similar rewards)\n        smoothness_weight = torch.exp(-dist / (dist.mean() + 1e-8))\n        transition_consistency = (smoothness_weight * reward_diff).mean()\n\n        # ── Total IRL loss ──\n        loss = bce_loss + self.grad_penalty_coef * gp + self.transition_reg * transition_consistency\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # ── Advantage-weighted BC update on policy ──\n        bc_loss = torch.tensor(0.0, 
device=self.device)\n        if self._policy is not None and self._policy_optimizer is not None:\n            # Sample expert batch for BC\n            bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n            bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n            bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n\n            # Compute discriminator \"expert-ness\" as weights\n            with torch.no_grad():\n                expert_logits_for_weight = self.reward_net(\n                    bc_expert_obs, bc_expert_acts,\n                    self.expert_demos[\"next_obs\"][bc_expert_idx]\n                )\n                # Sigmoid gives probability of being expert\n                expert_weights = torch.sigmoid(expert_logits_for_weight)\n                # Use as weights for BC (higher weight for more expert-like)\n                # Clamp to avoid extreme weights\n                expert_weights = expert_weights.clamp(0.1, 1.0)\n\n            # BC loss weighted by exp\n... (truncated)"
7}
Tool Result (Step 1)
ERROR: Package 'LLaMA-Factory' is not in allowed packages
Assistant (Step 2): edit
Tool Input
1{
2  "filename": "custom_irl.py",
3  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Reward network with action-similarity embedding and potential shaping.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s')\n\n    where:\n    - g(s,a) is the base reward function\n    - h(s) is potential-based shaping for stability\n    - phi(s,a,s') is an action-similarity bonus that rewards transitions\n      where the action aligns with expert-like behavior patterns\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): base reward function\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        # h(s): potential-based shaping function\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 128),\n            nn.ReLU(),\n            nn.Linear(128, 1),\n        )\n        # phi(s,a,s'): action-similarity embedding\n        # Uses a smaller network to model transition quality\n        self.phi_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim + obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 64),\n            nn.ReLU(),\n            nn.Linear(64, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        \"\"\"Base reward function g(s, a).\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        return self.h_net(state).squeeze(-1)\n\n    def phi(self, state, action, next_state):\n        \"\"\"Action-similarity bonus phi(s,a,s').\"\"\"\n        x = torch.cat([state, action, next_state], dim=-1)\n        return self.phi_net(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: 
f(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s').\"\"\"\n        return (self.g(state, action)\n                + self.gamma * self.h(next_state)\n                - self.h(state)\n                + self.phi(state, action, next_state))\n\n\nclass IRLAlgorithm:\n    \"\"\"Advantage-Weighted IRL with Transition Consistency (AW-IRL).\n\n    Combines three objectives:\n    1. GAIL-style discriminator to separate expert from policy data\n    2. Transition consistency: reward transitions that look \"expert-like\"\n    3. Advantage-weighted BC: directly train policy on expert data, weighted\n       by how expert-like the discriminator thinks each transition is\n\n    The key insight: the discriminator provides a notion of \"expert-ness\"\n    for each transition. We use this to weight BC updates, so the policy\n    focuses on mimicking expert behavior in regions where the discriminator\n    is most confident about what's expert-like.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded as per instructions)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.bc_weight = 0.5  # Weight for BC auxiliary loss\n        self.transition_reg = 0.1  # Transition consistency regularization\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        \"\"\"Set reference to policy for BC updates.\"\"\"\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for 
policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"AW-IRL update: discriminator + transition consistency + BC.\n\n        1. Train discriminator to separate expert from policy data\n        2. Add transition consistency regularization\n        3. 
Optionally perform BC update on policy weighted by expert-ness\n        \"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        # Sample policy data\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        # ── Discriminator logits ──\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # Labels with smoothing\n        eps = self.label_smoothing\n        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n        gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n        logits = torch.cat([expert_logits, gen_logits], dim=0)\n        labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n        # Gradient penalty\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        # ── Transition Consistency Regularization ──\n        # Encourage the reward network to give consistent rewards for\n        # expert transitions: transitions with similar (s,a,s') should\n        # have similar rewards. 
This prevents reward hacking.\n        # Transition consistency: nearby states should have smooth rewards\n        # Compute reward differences for randomly paired expert transitions\n        perm_idx = torch.randperm(batch_size, device=self.device)\n        diff_obs = expert_obs - expert_obs[perm_idx]\n        diff_acts = expert_acts - expert_acts[perm_idx]\n        dist = torch.norm(diff_obs, dim=-1) + torch.norm(diff_acts, dim=-1)\n\n        reward_diff = torch.abs(\n            self.reward_net.g(expert_obs, expert_acts) -\n            self.reward_net.g(expert_obs[perm_idx], expert_acts[perm_idx])\n        )\n        # Penalize large reward differences for similar transitions\n        # Weight by inverse distance (closer transitions should have more similar rewards)\n        smoothness_weight = torch.exp(-dist / (dist.mean() + 1e-8))\n        transition_consistency = (smoothness_weight * reward_diff).mean()\n\n        # ── Total IRL loss ──\n        loss = bce_loss + self.grad_penalty_coef * gp + self.transition_reg * transition_consistency\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # ── Advantage-weighted BC update on policy ──\n        bc_loss = torch.tensor(0.0, device=self.device)\n        if self._policy is not None and self._policy_optimizer is not None:\n            # Sample expert batch for BC\n            bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n            bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n            bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n\n            # Compute discriminator \"expert-ness\" as weights\n            with torch.no_grad():\n                expert_logits_for_weight = self.reward_net(\n                    bc_expert_obs, bc_expert_acts,\n                    self.expert_demos[\"next_obs\"][bc_expert_idx]\n                )\n                # Sigmoid gives probability of being expert\n                expert_weights = 
torch.sigmoid(expert_logits_for_weight)\n                # Use as weights for BC (higher weight for more expert-like)\n                # Clamp to avoid extreme weights\n                expert_weights = expert_weights.clamp(0.1, 1.0)\n\n            # BC loss weighted by expert-ness\n            _, log_prob, entropy, _ = self._policy.get_action_and_value(\n                bc_expert_obs, bc_expert_acts,\n            )\n\n            # Advantage-weighted: weight by expert confidence\n            weighted_nll = -(expert_weights * log_prob).mean()\n            ent_bonus = -0.001 * entropy.mean()\n\n            bc_loss = self.bc_weight * (weighted_nll + ent_bonus)\n\n            self._policy_optimizer.zero_grad()\n            bc_loss.backward()\n            nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n            self._policy_optimizer.step()\n\n        with torch.no_grad():\n            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n                              \n... (truncated)",
4  "start_line": 231,
5  "end_line": 356
6}
Tool Result (Step 2)
ERROR: Invalid arguments for 'edit': WorkspaceTools.edit() missing 1 required positional argument: 'op'
Assistant (Step 3): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "custom_irl.py",
4  "start_line": 231,
5  "end_line": 356,
6  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Reward network with action-similarity embedding and potential shaping.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s')\n\n    where:\n    - g(s,a) is the base reward function\n    - h(s) is potential-based shaping for stability\n    - phi(s,a,s') is an action-similarity bonus that rewards transitions\n      where the action aligns with expert-like behavior patterns\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): base reward function\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        # h(s): potential-based shaping function\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 128),\n            nn.ReLU(),\n            nn.Linear(128, 1),\n        )\n        # phi(s,a,s'): action-similarity embedding\n        # Uses a smaller network to model transition quality\n        self.phi_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim + obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 64),\n            nn.ReLU(),\n            nn.Linear(64, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        \"\"\"Base reward function g(s, a).\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        return self.h_net(state).squeeze(-1)\n\n    def phi(self, state, action, next_state):\n        \"\"\"Action-similarity bonus phi(s,a,s').\"\"\"\n        x = torch.cat([state, action, next_state], dim=-1)\n        return self.phi_net(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: 
f(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s').\"\"\"\n        return (self.g(state, action)\n                + self.gamma * self.h(next_state)\n                - self.h(state)\n                + self.phi(state, action, next_state))\n\n\nclass IRLAlgorithm:\n    \"\"\"Advantage-Weighted IRL with Transition Consistency (AW-IRL).\n\n    Combines three objectives:\n    1. GAIL-style discriminator to separate expert from policy data\n    2. Transition consistency: reward transitions that look \"expert-like\"\n    3. Advantage-weighted BC: directly train policy on expert data, weighted\n       by how expert-like the discriminator thinks each transition is\n\n    The key insight: the discriminator provides a notion of \"expert-ness\"\n    for each transition. We use this to weight BC updates, so the policy\n    focuses on mimicking expert behavior in regions where the discriminator\n    is most confident about what's expert-like.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded as per instructions)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.bc_weight = 0.5  # Weight for BC auxiliary loss\n        self.transition_reg = 0.1  # Transition consistency regularization\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        \"\"\"Set reference to policy for BC updates.\"\"\"\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for 
policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"AW-IRL update: discriminator + transition consistency + BC.\n\n        1. Train discriminator to separate expert from policy data\n        2. Add transition consistency regularization\n        3. 
Optionally perform BC update on policy weighted by expert-ness\n        \"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        # Sample policy data\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        # ── Discriminator logits ──\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # Labels with smoothing\n        eps = self.label_smoothing\n        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n        gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n        logits = torch.cat([expert_logits, gen_logits], dim=0)\n        labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n        # Gradient penalty\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        # ── Transition Consistency Regularization ──\n        # Encourage the reward network to give consistent rewards for\n        # expert transitions: transitions with similar (s,a,s') should\n        # have similar rewards. 
This prevents reward hacking.\n        perm_idx = torch.randperm(batch_size, device=self.device)\n        diff_obs = expert_obs - expert_obs[perm_idx]\n        diff_acts = expert_acts - expert_acts[perm_idx]\n        dist = torch.norm(diff_obs, dim=-1) + torch.norm(diff_acts, dim=-1)\n\n        reward_diff = torch.abs(\n            self.reward_net.g(expert_obs, expert_acts) -\n            self.reward_net.g(expert_obs[perm_idx], expert_acts[perm_idx])\n        )\n        # Penalize large reward differences for similar transitions\n        smoothness_weight = torch.exp(-dist / (dist.mean() + 1e-8))\n        transition_consistency = (smoothness_weight * reward_diff).mean()\n\n        # ── Total IRL loss ──\n        loss = bce_loss + self.grad_penalty_coef * gp + self.transition_reg * transition_consistency\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # ── Advantage-weighted BC update on policy ──\n        bc_loss = torch.tensor(0.0, device=self.device)\n        if self._policy is not None and self._policy_optimizer is not None:\n            # Sample expert batch for BC\n            bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n            bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n            bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n\n            # Compute discriminator \"expert-ness\" as weights\n            with torch.no_grad():\n                expert_logits_for_weight = self.reward_net(\n                    bc_expert_obs, bc_expert_acts,\n                    self.expert_demos[\"next_obs\"][bc_expert_idx]\n                )\n                # Sigmoid gives probability of being expert\n                expert_weights = torch.sigmoid(expert_logits_for_weight)\n                # Clamp to avoid extreme weights\n                expert_weights = expert_weights.clamp(0.1, 1.0)\n\n            # BC loss weighted by expert-ness\n            _, log_prob, entropy, _ = 
self._policy.get_action_and_value(\n                bc_expert_obs, bc_expert_acts,\n            )\n\n            # Advantage-weighted: weight by expert confidence\n            weighted_nll = -(expert_weights * log_prob).mean()\n            ent_bonus = -0.001 * entropy.mean()\n\n            bc_loss = self.bc_weight * (weighted_nll + ent_bonus)\n\n            self._policy_optimizer.zero_grad()\n            bc_loss.backward()\n            nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n            self._policy_optimizer.step()\n\n        with torch.no_grad():\n            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n                                     torch.zeros(batch_size, device=self.device)], dim=0)\n            acc = ((logits > 0).float() == hard_labels).float().mean().item()\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"disc_acc\": acc,\n            \"grad_penalty\": gp.item(),\n            \"transition_reg\": transition_consistency.item(\n... (truncated)"
7}
Tool Result (Step 3)
ERROR: Package 'custom_irl.py' is not in allowed packages
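The truncated tool input above implements a transition-consistency regularizer: randomly paired expert samples are penalized for reward differences, with an exponential weight so that similar (s, a) pairs dominate the penalty. A minimal self-contained sketch of just that term (the linear `g` below is a stand-in for illustration, not the file's `RewardNetwork.g`):

```python
import torch

def transition_consistency(g, obs, acts):
    """Penalize reward differences between randomly paired expert samples,
    weighted so that similar (s, a) pairs dominate the penalty."""
    perm = torch.randperm(obs.shape[0])
    # distance between each sample and its random partner
    dist = (obs - obs[perm]).norm(dim=-1) + (acts - acts[perm]).norm(dim=-1)
    reward_diff = (g(obs, acts) - g(obs[perm], acts[perm])).abs()
    # similar pairs (small dist) get weight near 1, distant pairs near 0
    weight = torch.exp(-dist / (dist.mean() + 1e-8))
    return (weight * reward_diff).mean()

torch.manual_seed(0)
g = lambda o, a: o.sum(dim=-1) + a.sum(dim=-1)  # stand-in reward head
reg = transition_consistency(g, torch.randn(32, 4), torch.randn(32, 2))
```

Because the weight decays exponentially with distance, distant pairs contribute almost nothing: the regularizer enforces local smoothness of the reward rather than global flatness.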
Assistant (Step 4, edit)
Tool Input
1{
2  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Reward network with potential shaping and phi consistency term.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s')\n\n    where:\n    - g(s,a) is the base reward function\n    - h(s) is potential-based shaping for stability\n    - phi(s,a,s') is a transition consistency bonus\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): base reward function\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        # h(s): potential-based shaping function\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 128),\n            nn.ReLU(),\n            nn.Linear(128, 1),\n        )\n        # phi(s,a,s'): transition consistency embedding\n        self.phi_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim + obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 64),\n            nn.ReLU(),\n            nn.Linear(64, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        \"\"\"Base reward function g(s, a).\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        return self.h_net(state).squeeze(-1)\n\n    def phi(self, state, action, next_state):\n        \"\"\"Transition consistency bonus phi(s,a,s').\"\"\"\n        x = torch.cat([state, action, next_state], dim=-1)\n        return self.phi_net(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s').\"\"\"\n        return (self.g(state, action)\n                + self.gamma * 
self.h(next_state)\n                - self.h(state)\n                + self.phi(state, action, next_state))\n\n\nclass IRLAlgorithm:\n    \"\"\"Advantage-Weighted IRL with Transition Consistency (AW-IRL).\n\n    Combines three objectives:\n    1. GAIL-style discriminator to separate expert from policy data\n    2. Transition consistency regularization on the reward network\n    3. Advantage-weighted BC: directly train policy on expert data, weighted\n       by how expert-like the discriminator thinks each transition is\n\n    The key insight: the discriminator provides a notion of \"expert-ness\"\n    for each transition. We use this to weight BC updates, so the policy\n    focuses on mimicking expert behavior in regions where the discriminator\n    is most confident about what's expert-like.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded as per instructions)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.bc_weight = 0.5\n        self.transition_reg = 0.1\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        \"\"\"Set reference to policy for BC updates.\"\"\"\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          
gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"AW-IRL update: discriminator + transition consistency + BC.\"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        # Sample policy data\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        # Discriminator logits\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        
gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # Labels with smoothing\n        eps = self.label_smoothing\n        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n        gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n        logits = torch.cat([expert_logits, gen_logits], dim=0)\n        labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n        # Gradient penalty\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        # Transition Consistency Regularization\n        perm_idx = torch.randperm(batch_size, device=self.device)\n        diff_obs = expert_obs - expert_obs[perm_idx]\n        diff_acts = expert_acts - expert_acts[perm_idx]\n        dist = torch.norm(diff_obs, dim=-1) + torch.norm(diff_acts, dim=-1)\n\n        reward_diff = torch.abs(\n            self.reward_net.g(expert_obs, expert_acts) -\n            self.reward_net.g(expert_obs[perm_idx], expert_acts[perm_idx])\n        )\n        smoothness_weight = torch.exp(-dist / (dist.mean() + 1e-8))\n        transition_consistency = (smoothness_weight * reward_diff).mean()\n\n        # Total IRL loss\n        loss = bce_loss + self.grad_penalty_coef * gp + self.transition_reg * transition_consistency\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # Advantage-weighted BC update on policy\n        bc_loss = torch.tensor(0.0, device=self.device)\n        if self._policy is not None and self._policy_optimizer is not None:\n            bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n            bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n            bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n\n            with torch.no_grad():\n                
expert_logits_for_weight = self.reward_net(\n                    bc_expert_obs, bc_expert_acts,\n                    self.expert_demos[\"next_obs\"][bc_expert_idx]\n                )\n                expert_weights = torch.sigmoid(expert_logits_for_weight)\n                expert_weights = expert_weights.clamp(0.1, 1.0)\n\n            _, log_prob, entropy, _ = self._policy.get_action_and_value(\n                bc_expert_obs, bc_expert_acts,\n            )\n\n            weighted_nll = -(expert_weights * log_prob).mean()\n            ent_bonus = -0.001 * entropy.mean()\n\n            bc_loss = self.bc_weight * (weighted_nll + ent_bonus)\n\n            self._policy_optimizer.zero_grad()\n            bc_loss.backward()\n            nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n            self._policy_optimizer.step()\n\n        with torch.no_grad():\n            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n                                     torch.zeros(batch_size, device=self.device)], dim=0)\n            acc = ((logits > 0).float() == hard_labels).float().mean().item()\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"disc_acc\": acc,\n            \"grad_penalty\": gp.item(),\n            \"transition_reg\": transition_consistency.item(),\n            \"bc_loss\": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,\n        }",
3  "end_line": 356,
4  "filename": "custom_irl.py",
5  "op": "replace",
6  "start_line": 231
7}
Tool Result (Step 4)
ERROR: Package 'custom_irl.py' is not in allowed packages
Assistant (Step 5, edit)
Tool Input
1{
2  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Reward network with potential shaping and phi consistency term.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s')\n\n    where:\n    - g(s,a) is the base reward function\n    - h(s) is potential-based shaping for stability\n    - phi(s,a,s') is a transition consistency bonus\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): base reward function\n        self.g_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim, 256),\n            nn.ReLU(),\n            nn.Linear(256, 256),\n            nn.ReLU(),\n            nn.Linear(256, 1),\n        )\n        # h(s): potential-based shaping function\n        self.h_net = nn.Sequential(\n            nn.Linear(obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 128),\n            nn.ReLU(),\n            nn.Linear(128, 1),\n        )\n        # phi(s,a,s'): transition consistency embedding\n        self.phi_net = nn.Sequential(\n            nn.Linear(obs_dim + action_dim + obs_dim, 128),\n            nn.ReLU(),\n            nn.Linear(128, 64),\n            nn.ReLU(),\n            nn.Linear(64, 1),\n        )\n        self.gamma = 0.99\n\n    def g(self, state, action):\n        \"\"\"Base reward function g(s, a).\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        return self.g_net(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        return self.h_net(state).squeeze(-1)\n\n    def phi(self, state, action, next_state):\n        \"\"\"Transition consistency bonus phi(s,a,s').\"\"\"\n        x = torch.cat([state, action, next_state], dim=-1)\n        return self.phi_net(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s) + phi(s,a,s').\"\"\"\n        return (self.g(state, action)\n                + self.gamma * 
self.h(next_state)\n                - self.h(state)\n                + self.phi(state, action, next_state))\n\n\nclass IRLAlgorithm:\n    \"\"\"Advantage-Weighted IRL with Transition Consistency (AW-IRL).\n\n    Combines three objectives:\n    1. GAIL-style discriminator to separate expert from policy data\n    2. Transition consistency regularization on the reward network\n    3. Advantage-weighted BC: directly train policy on expert data, weighted\n       by how expert-like the discriminator thinks each transition is\n\n    The key insight: the discriminator provides a notion of \"expert-ness\"\n    for each transition. We use this to weight BC updates, so the policy\n    focuses on mimicking expert behavior in regions where the discriminator\n    is most confident about what's expert-like.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded as per instructions)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.bc_weight = 0.5\n        self.transition_reg = 0.1\n        self._policy = None\n        self._policy_optimizer = None\n\n    def set_policy(self, policy, optimizer):\n        \"\"\"Set reference to policy for BC updates.\"\"\"\n        self._policy = policy\n        self._policy_optimizer = optimizer\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          
gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"AW-IRL update: discriminator + transition consistency + BC.\"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n        expert_acts = self.expert_demos[\"acts\"][expert_idx]\n        expert_next_obs = self.expert_demos[\"next_obs\"][expert_idx]\n\n        # Sample policy data\n        n_policy = len(policy_obs)\n        policy_idx = torch.randint(0, n_policy, (batch_size,))\n        gen_obs = policy_obs[policy_idx]\n        gen_acts = policy_acts[policy_idx]\n        gen_next_obs = policy_next_obs[policy_idx]\n\n        # Discriminator logits\n        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)\n        
gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)\n\n        # Labels with smoothing\n        eps = self.label_smoothing\n        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)\n        gen_labels = torch.full((batch_size,), eps, device=self.device)\n\n        logits = torch.cat([expert_logits, gen_logits], dim=0)\n        labels = torch.cat([expert_labels, gen_labels], dim=0)\n\n        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)\n\n        # Gradient penalty\n        gp = self._gradient_penalty(\n            expert_obs, expert_acts, expert_next_obs,\n            gen_obs, gen_acts, gen_next_obs,\n        )\n\n        # Transition Consistency Regularization\n        perm_idx = torch.randperm(batch_size, device=self.device)\n        diff_obs = expert_obs - expert_obs[perm_idx]\n        diff_acts = expert_acts - expert_acts[perm_idx]\n        dist = torch.norm(diff_obs, dim=-1) + torch.norm(diff_acts, dim=-1)\n\n        reward_diff = torch.abs(\n            self.reward_net.g(expert_obs, expert_acts) -\n            self.reward_net.g(expert_obs[perm_idx], expert_acts[perm_idx])\n        )\n        smoothness_weight = torch.exp(-dist / (dist.mean() + 1e-8))\n        transition_consistency = (smoothness_weight * reward_diff).mean()\n\n        # Total IRL loss\n        loss = bce_loss + self.grad_penalty_coef * gp + self.transition_reg * transition_consistency\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # Advantage-weighted BC update on policy\n        bc_loss = torch.tensor(0.0, device=self.device)\n        if self._policy is not None and self._policy_optimizer is not None:\n            bc_expert_idx = torch.randint(0, n_expert, (batch_size,))\n            bc_expert_obs = self.expert_demos[\"obs\"][bc_expert_idx]\n            bc_expert_acts = self.expert_demos[\"acts\"][bc_expert_idx]\n\n            with torch.no_grad():\n                
expert_logits_for_weight = self.reward_net(\n                    bc_expert_obs, bc_expert_acts,\n                    self.expert_demos[\"next_obs\"][bc_expert_idx]\n                )\n                expert_weights = torch.sigmoid(expert_logits_for_weight)\n                expert_weights = expert_weights.clamp(0.1, 1.0)\n\n            _, log_prob, entropy, _ = self._policy.get_action_and_value(\n                bc_expert_obs, bc_expert_acts,\n            )\n\n            weighted_nll = -(expert_weights * log_prob).mean()\n            ent_bonus = -0.001 * entropy.mean()\n\n            bc_loss = self.bc_weight * (weighted_nll + ent_bonus)\n\n            self._policy_optimizer.zero_grad()\n            bc_loss.backward()\n            nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)\n            self._policy_optimizer.step()\n\n        with torch.no_grad():\n            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),\n                                     torch.zeros(batch_size, device=self.device)], dim=0)\n            acc = ((logits > 0).float() == hard_labels).float().mean().item()\n\n        return {\n            \"irl_loss\": loss.item(),\n            \"disc_acc\": acc,\n            \"grad_penalty\": gp.item(),\n            \"transition_reg\": transition_consistency.item(),\n            \"bc_loss\": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,\n        }",
3  "end_line": 356,
4  "filename": "imitation/custom_irl.py",
5  "op": "replace",
6  "start_line": 231
7}
Tool Result (Step 5)
OK: Replaced lines 231..356 with 226 line(s) in imitation/custom_irl.py. Editable range: 231–457.

[Current file: imitation/custom_irl.py | editable: 231–457 | total: 699 lines]
   231  class RewardNetwork(nn.Module):
   232      """Reward network with potential shaping and phi consistency term.
   233  
       ... (221 more lines) ...
   455              "bc_loss": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,
   456          }
   457  
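The `RewardNetwork` just written uses AIRL-style potential shaping, f(s,a,s') = g(s,a) + gamma*h(s') - h(s). The shaping part telescopes over a trajectory, which is why it cannot change the optimal policy; a quick numeric check with a stand-in potential (not the file's `h_net`):

```python
import torch

torch.manual_seed(0)
gamma = 0.99
h = lambda s: (s ** 2).sum(dim=-1)  # stand-in potential function
states = torch.randn(6, 3)          # s_0 .. s_5 of one short trajectory

# discounted sum of the shaping terms gamma*h(s_{t+1}) - h(s_t)
shaping = sum(
    gamma ** t * (gamma * h(states[t + 1]) - h(states[t]))
    for t in range(5)
)
# telescoping closed form: gamma^T * h(s_T) - h(s_0)
closed_form = gamma ** 5 * h(states[5]) - h(states[0])
```

Since the intermediate h terms cancel, only h at the endpoints survives, so the shaping term shifts returns by a policy-independent amount.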
Assistant (Step 6, test)
Tool Result (Step 6)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### halfcheetah-v4 (scripts/halfcheetah.sh)
Training metrics (last steps):
TRAIN_METRICS step=993280 pg_loss=0.1877 v_loss=30.7484 entropy=1.6392 irl_loss=0.3709 disc_acc=0.9883 grad_penalty=0.0012 transition_reg=0.0485 bc_loss=-0.0613
TRAIN_METRICS step=995328 pg_loss=0.0897 v_loss=21.7861 entropy=1.6648 irl_loss=0.3616 disc_acc=0.9902 grad_penalty=0.0012 transition_reg=0.0452 bc_loss=0.7740
TRAIN_METRICS step=997376 pg_loss=0.2416 v_loss=50.4157 entropy=1.6372 irl_loss=0.3980 disc_acc=0.9453 grad_penalty=0.0010 transition_reg=0.0448 bc_loss=0.0218
TRAIN_METRICS step=999424 pg_loss=0.0643 v_loss=19.2602 entropy=1.6516 irl_loss=0.3685 disc_acc=0.9883 grad_penalty=0.0011 transition_reg=0.0525 bc_loss=0.1913
TRAIN_METRICS step=1000000 pg_loss=0.0915 v_loss=15.2053 entropy=1.6682 irl_loss=0.3542 disc_acc=0.9922 grad_penalty=0.0010 transition_reg=0.0523 bc_loss=0.3104
Evaluation (halfcheetah-v4)
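disc_acc holding near 0.99 in the metrics above means the discriminator has saturated, leaving the policy little gradient signal; this motivates the temperature schedule introduced in the next edit. A sketch of that schedule with the same constants (note that conventional temperature scaling divides logits by the temperature, whereas the next edit multiplies, so its "soft early" reading is an assumption of that design):

```python
def temperature(update, initial_temp=5.0, min_temp=1.0, horizon=500):
    """Linear anneal from initial_temp down to min_temp over `horizon` updates."""
    progress = min(update / horizon, 1.0)
    return initial_temp - (initial_temp - min_temp) * progress

t0 = temperature(0)       # 5.0 at the start
t_mid = temperature(250)  # 3.0 halfway through
t_end = temperature(900)  # clamped at 1.0 past the horizon
```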
Assistant (Step 7, edit)
Tool Input
1{
2  "content": "class RewardNetwork(nn.Module):\n    \"\"\"Progressive reward network with AIRL shaping and spectral normalization.\n\n    Architecture: R(s,a,s') = g(s,a) + gamma*h(s') - h(s)\n\n    Uses spectral normalization on g_net to prevent the discriminator from\n    becoming too confident, which maintains meaningful gradient signals\n    for policy improvement.\n    \"\"\"\n\n    def __init__(self, obs_dim, action_dim):\n        super().__init__()\n        # g(s, a): reward function with spectral normalization\n        self.g_fc1 = nn.Linear(obs_dim + action_dim, 256)\n        self.g_fc2 = nn.Linear(256, 256)\n        self.g_fc3 = nn.Linear(256, 1)\n        # h(s): potential-based shaping\n        self.h_fc1 = nn.Linear(obs_dim, 256)\n        self.h_fc2 = nn.Linear(256, 256)\n        self.h_fc3 = nn.Linear(256, 1)\n        self.gamma = 0.99\n\n        # Apply spectral normalization\n        self.g_fc1 = nn.utils.parametrizations.spectral_norm(self.g_fc1)\n        self.g_fc2 = nn.utils.parametrizations.spectral_norm(self.g_fc2)\n        self.g_fc3 = nn.utils.parametrizations.spectral_norm(self.g_fc3)\n        self.h_fc1 = nn.utils.parametrizations.spectral_norm(self.h_fc1)\n        self.h_fc2 = nn.utils.parametrizations.spectral_norm(self.h_fc2)\n        self.h_fc3 = nn.utils.parametrizations.spectral_norm(self.h_fc3)\n\n    def g(self, state, action):\n        \"\"\"Reward function g(s, a) with spectral norm bounded.\"\"\"\n        x = torch.cat([state, action], dim=-1)\n        x = F.relu(self.g_fc1(x))\n        x = F.relu(self.g_fc2(x))\n        return self.g_fc3(x).squeeze(-1)\n\n    def h(self, state):\n        \"\"\"Potential function h(s).\"\"\"\n        x = F.relu(self.h_fc1(state))\n        x = F.relu(self.h_fc2(x))\n        return self.h_fc3(x).squeeze(-1)\n\n    def forward(self, state, action, next_state):\n        \"\"\"Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s).\"\"\"\n        return self.g(state, action) + self.gamma * 
self.h(next_state) - self.h(state)\n\n\nclass IRLAlgorithm:\n    \"\"\"Progressive AIRL with Action Matching (PA-IRL).\n\n    Key innovations:\n    1. Spectral normalization on reward network prevents overconfidence\n    2. Progressive discriminator: temperature increases over training,\n       starting with soft gradients that sharpen as policy improves\n    3. Action matching regularization: rewards transitions where policy\n       actions are close to expert actions in similar states\n    4. Conservative margin loss: ensures expert rewards exceed policy\n       rewards by a margin, but doesn't push them apart unnecessarily\n\n    The progressive temperature schedule is critical: early in training,\n    the policy is poor and the discriminator could easily separate them.\n    With high temperature (soft sigmoid), the discriminator still provides\n    meaningful gradients. As training progresses, temperature decreases\n    (sharper sigmoid) for stronger discrimination.\n    \"\"\"\n\n    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):\n        self.reward_net = reward_net\n        self.expert_demos = expert_demos\n        self.device = device\n        self.args = args\n        self.obs_dim = obs_dim\n        self.action_dim = action_dim\n\n        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)\n        self.total_updates = 0\n\n        # Hyperparameters (hardcoded)\n        self.grad_penalty_coef = 10.0\n        self.label_smoothing = 0.1\n        self.action_match_weight = 0.5\n        self.margin = 1.0  # Minimum margin between expert and policy rewards\n\n        # Temperature schedule: starts high (soft), decreases over time\n        self.initial_temp = 5.0\n        self.min_temp = 1.0\n        self.total_updates_for_schedule = 500  # Updates to reach min temp\n\n        # Store expert actions in a simple k-d tree-like structure for action matching\n        self._expert_obs_buffer = 
expert_demos[\"obs\"]\n        self._expert_acts_buffer = expert_demos[\"acts\"]\n\n    def _get_temperature(self):\n        \"\"\"Get current discriminator temperature (progressive schedule).\"\"\"\n        progress = min(self.total_updates / self.total_updates_for_schedule, 1.0)\n        return self.initial_temp - (self.initial_temp - self.min_temp) * progress\n\n    def compute_reward(self, obs, acts, next_obs):\n        \"\"\"Use the full shaped reward for policy training.\"\"\"\n        with torch.no_grad():\n            return self.reward_net(obs, acts, next_obs)\n\n    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,\n                          gen_obs, gen_acts, gen_next_obs):\n        \"\"\"Compute gradient penalty on interpolated data (WGAN-GP style).\"\"\"\n        batch_size = expert_obs.shape[0]\n        alpha = torch.rand(batch_size, 1, device=self.device)\n\n        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)\n        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)\n        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)\n\n        logits = self.reward_net(interp_obs, interp_acts, interp_next)\n        ones = torch.ones_like(logits)\n\n        grads = torch.autograd.grad(\n            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],\n            grad_outputs=ones, create_graph=True, retain_graph=True,\n        )\n        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)\n        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)\n        return torch.mean((grad_norm - 1.0) ** 2)\n\n    def _action_matching_reward(self, policy_obs, policy_acts):\n        \"\"\"Compute action matching bonus: reward actions similar to expert.\n\n        For each policy transition, find the nearest expert state and\n        reward the policy for taking similar actions. 
This provides a\n        direct imitation signal that complements the adversarial signal.\n        \"\"\"\n        batch_size = policy_obs.shape[0]\n        sub_size = min(128, batch_size)\n        sub_idx = torch.randperm(batch_size, device=self.device)[:sub_size]\n        policy_obs_sub = policy_obs[sub_idx]\n        policy_acts_sub = policy_acts[sub_idx]\n\n        # Find nearest expert states (efficient approximate nearest neighbor)\n        n_expert = len(self._expert_obs_buffer)\n        # Sample a subset of expert states for efficiency\n        expert_sub_size = min(512, n_expert)\n        expert_sub_idx = torch.randint(0, n_expert, (expert_sub_size,), device=self.device)\n        expert_obs_sub = self._expert_obs_buffer[expert_sub_idx]\n        expert_acts_sub = self._expert_acts_buffer[expert_sub_idx]\n\n        # Compute distances: (sub_size, expert_sub_size)\n        diff = policy_obs_sub.unsqueeze(1) - expert_obs_sub.unsqueeze(0)\n        dists = torch.sum(diff ** 2, dim=-1)  # (sub_size, expert_sub_size)\n\n        # Find nearest expert for each policy state\n        nearest_idx = torch.argmin(dists, dim=1)  # (sub_size,)\n        nearest_expert_acts = expert_acts_sub[nearest_idx]  # (sub_size, action_dim)\n\n        # Action matching bonus: negative squared distance (higher = better)\n        action_dist = torch.sum((policy_acts_sub - nearest_expert_acts) ** 2, dim=-1)\n        action_bonus = torch.exp(-action_dist / (action_dist.mean() + 1e-8))\n\n        return action_bonus\n\n    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):\n        \"\"\"Progressive AIRL update with action matching.\"\"\"\n        self.total_updates += 1\n        batch_size = self.args.irl_batch_size\n        temp = self._get_temperature()\n\n        # Sample expert data\n        n_expert = len(self.expert_demos[\"obs\"])\n        expert_idx = torch.randint(0, n_expert, (batch_size,))\n        expert_obs = self.expert_demos[\"obs\"][expert_idx]\n       
        expert_acts = self.expert_demos["acts"][expert_idx]
        expert_next_obs = self.expert_demos["next_obs"][expert_idx]

        # Sample policy data
        n_policy = len(policy_obs)
        policy_idx = torch.randint(0, n_policy, (batch_size,))
        gen_obs = policy_obs[policy_idx]
        gen_acts = policy_acts[policy_idx]
        gen_next_obs = policy_next_obs[policy_idx]

        # Raw shaped rewards (logits)
        expert_rewards = self.reward_net(expert_obs, expert_acts, expert_next_obs)
        gen_rewards = self.reward_net(gen_obs, gen_acts, gen_next_obs)

        # Progressive discriminator loss with temperature
        # High temp early = soft discrimination = meaningful gradients
        # Low temp later = sharp discrimination = strong signal
        scaled_expert = expert_rewards * temp
        scaled_gen = gen_rewards * temp

        # Soft labels with smoothing
        eps = self.label_smoothing
        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)
        gen_labels = torch.full((batch_size,), eps, device=self.device)

        # Scaled logits for BCE
        logits = torch.cat([scaled_expert, scaled_gen], dim=0)
        labels = torch.cat([expert_labels, gen_labels], dim=0)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)

        # Gradient penalty
        gp = self._gradient_penalty(
            expert_obs, expert_acts, expert_next_obs,
            gen_obs, gen_acts, gen_next_obs,
        )

        # Action matching regularization
        # Compute on policy data only (reward for mimicking expert actions)
        action_bonus = self._action_matching_reward(policy_obs, policy_acts)
        # The action matching reward should be maximized, so we minimize its negative
        action_match_loss = -action_bonus.mean()

        # Total loss
        loss = bce_loss + self.grad_penalty_coef * gp + self.action_match_weight * action_match_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        with torch.no_grad():
            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),
                                     torch.zeros(batch_size, device=self.device)], dim=0)
            acc = ((logits > 0).float() == hard_labels).float().mean().item()

... (truncated)

(op: replace, file: imitation/custom_irl.py, lines 231–456)
Tool Result (Step 7)
OK: Replaced lines 231..456 with 230 line(s) in imitation/custom_irl.py. Editable range: 231–461.

[Current file: imitation/custom_irl.py | editable: 231–461 | total: 703 lines]
   231  class RewardNetwork(nn.Module):
   232      """Progressive reward network with AIRL shaping and spectral normalization.
   233  
       ... (225 more lines) ...
   459              "temperature": temp,
   460          }
   461  
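The temperature-scaled, label-smoothed BCE in the Step 7 update can be checked in isolation. A minimal dependency-free sketch (function names are illustrative, not part of custom_irl.py): smoothing pulls targets from {0, 1} to {eps, 1 - eps}, which penalizes over-confident logits, and temp < 1 softens the discriminator's decisions.

```python
import math

def bce_with_logits(logit, target):
    # Numerically stable BCE with logits: max(x, 0) - x*t + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def smoothed_disc_loss(expert_logits, gen_logits, eps=0.1, temp=1.0):
    # Expert targets 1 - eps, generator targets eps, on temperature-scaled logits
    losses = [bce_with_logits(x * temp, 1.0 - eps) for x in expert_logits]
    losses += [bce_with_logits(x * temp, eps) for x in gen_logits]
    return sum(losses) / len(losses)

expert = [4.0, 5.0]    # confidently classified as expert
gen = [-4.0, -5.0]     # confidently classified as generated
hard = smoothed_disc_loss(expert, gen, eps=0.0)   # near zero: labels fully satisfied
soft = smoothed_disc_loss(expert, gen, eps=0.1)   # larger: smoothing resists saturation
```

Even a perfectly separating discriminator pays a floor cost under smoothing, which keeps gradients flowing to the reward network.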
Assistant (Step 8): edit
Tool Input (op: replace, file: imitation/custom_irl.py, lines 231–460)
Content:

class RewardNetwork(nn.Module):
    """AIRL-shaped reward network with action-dependent potential.

    Architecture: R(s,a,s') = g(s,a) + gamma*h(s',a') - h(s,a)

    Uses an action-dependent potential function h(s,a) which provides
    richer shaping than state-only potentials. This allows the reward
    to better capture the structure of expert behavior in continuous
    action spaces.

    The g function is kept simple to avoid overfitting, while the
    h function has sufficient capacity to learn complex potentials.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        # g(s, a): base reward function
        self.g_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        # h(s, a): action-dependent potential function
        self.h_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        self.gamma = 0.99

    def g(self, state, action):
        """Base reward function g(s, a)."""
        x = torch.cat([state, action], dim=-1)
        return self.g_net(x).squeeze(-1)

    def h(self, state, action):
        """Action-dependent potential function h(s, a)."""
        x = torch.cat([state, action], dim=-1)
        return self.h_net(x).squeeze(-1)

    def forward(self, state, action, next_state):
        """Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s',a') - h(s,a).

        Note: the next action a' is not available during reward
        computation, so we approximate h(s', a') with a zero action
        as a neutral reference.
        """
        # For the potential at the next state, approximate h(s',a') ≈ h(s',0)
        # (using zero action as a neutral reference)
        zero_acts = torch.zeros_like(action)
        h_current = self.h(state, action)
        h_next = self.h(next_state, zero_acts)
        return self.g(state, action) + self.gamma * h_next - h_current


class IRLAlgorithm:
    """Conservative Reward Learning with BC Warmup (CRL-BC).

    This algorithm combines three key ideas:

    1. **Conservative Discriminator**: Instead of pushing expert and policy
       rewards far apart (which causes reward hacking), we use a margin-based
       loss that only penalizes the discriminator when it fails to rank
       expert > policy by a small margin.

    2. **BC Warmup with Decay**: The policy receives strong BC updates early
       in training, which decay over time as the learned reward becomes more
       informative. This ensures the policy starts near expert behavior.

    3. **Reward Regularization**: We regularize the reward network to produce
       smooth, conservative rewards. This prevents the policy from exploiting
       discontinuities in the learned reward.

    The key insight: GAIL/AIRL push expert and policy rewards far apart,
    which creates a strong signal but also leads to reward hacking. By using
    a conservative margin-based loss and strong BC warmup, we get the policy
    close to expert behavior quickly, then refine it with the learned reward.
    """

    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):
        self.reward_net = reward_net
        self.expert_demos = expert_demos
        self.device = device
        self.args = args
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=args.irl_lr)
        self.total_updates = 0

        # Hyperparameters (hardcoded)
        self.grad_penalty_coef = 10.0
        self.label_smoothing = 0.1
        self.margin = 0.5  # Minimum margin between expert and policy rewards
        self._policy = None
        self._policy_optimizer = None

        # BC schedule parameters
        self.n_bc_steps = 15  # Number of BC steps per IRL update
        self.initial_bc_weight = 1.0  # BC weight at the start
        self.final_bc_weight = 0.1  # BC weight after warmup
        self.bc_warmup_steps = 200  # Updates for BC weight to decay

    def set_policy(self, policy, optimizer):
        """Set reference to policy for BC updates."""
        self._policy = policy
        self._policy_optimizer = optimizer

    def compute_reward(self, obs, acts, next_obs):
        """Use the full shaped reward for policy training."""
        with torch.no_grad():
            return self.reward_net(obs, acts, next_obs)

    def _get_bc_weight(self):
        """Get current BC weight (decays over time)."""
        progress = min(self.total_updates / self.bc_warmup_steps, 1.0)
        return self.initial_bc_weight - (self.initial_bc_weight - self.final_bc_weight) * progress

    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,
                          gen_obs, gen_acts, gen_next_obs):
        """Compute gradient penalty on interpolated data (WGAN-GP style)."""
        batch_size = expert_obs.shape[0]
        alpha = torch.rand(batch_size, 1, device=self.device)

        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)
        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)
        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)

        logits = self.reward_net(interp_obs, interp_acts, interp_next)
        ones = torch.ones_like(logits)

        grads = torch.autograd.grad(
            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],
            grad_outputs=ones, create_graph=True, retain_graph=True,
        )
        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)
        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)
        return torch.mean((grad_norm - 1.0) ** 2)

    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):
        """CRL-BC update: conservative discriminator + BC with decay."""
        self.total_updates += 1
        batch_size = self.args.irl_batch_size

        # Sample expert data
        n_expert = len(self.expert_demos["obs"])
        expert_idx = torch.randint(0, n_expert, (batch_size,))
        expert_obs = self.expert_demos["obs"][expert_idx]
        expert_acts = self.expert_demos["acts"][expert_idx]
        expert_next_obs = self.expert_demos["next_obs"][expert_idx]

        # Sample policy data
        n_policy = len(policy_obs)
        policy_idx = torch.randint(0, n_policy, (batch_size,))
        gen_obs = policy_obs[policy_idx]
        gen_acts = policy_acts[policy_idx]
        gen_next_obs = policy_next_obs[policy_idx]

        # ── Conservative discriminator loss ──
        # Use BCE with label smoothing (same as GAIL/AIRL baselines)
        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)
        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)

        eps = self.label_smoothing
        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)
        gen_labels = torch.full((batch_size,), eps, device=self.device)

        logits = torch.cat([expert_logits, gen_logits], dim=0)
        labels = torch.cat([expert_labels, gen_labels], dim=0)

        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)

        # Gradient penalty
        gp = self._gradient_penalty(
            expert_obs, expert_acts, expert_next_obs,
            gen_obs, gen_acts, gen_next_obs,
        )

        # Reward consistency regularization
        # Penalize high variance in expert rewards (encourages smooth rewards)
        expert_reward_var = torch.var(expert_logits)
        gen_reward_var = torch.var(gen_logits)
        reward_var_loss = expert_reward_var + gen_reward_var

        # Total discriminator loss
        disc_loss = bce_loss + self.grad_penalty_coef * gp + 0.01 * reward_var_loss

        self.optimizer.zero_grad()
        disc_loss.backward()
        self.optimizer.step()

        # ── BC update with decay ──
        bc_loss = torch.tensor(0.0, device=self.device)
        if self._policy is not None and self._policy_optimizer is not None:
            bc_weight = self._get_bc_weight()
            n_bc = self.n_bc_steps

            for _ in range(n_bc):
                expert_idx = torch.randint(0, n_expert, (batch_size,))
                expert_obs = self.expert_demos["obs"][expert_idx]
                expert_acts = self.expert_demos["acts"][expert_idx]

                _, log_prob, entropy, _ = self._policy.get_action_and_value(
                    expert_obs, expert_acts,
                )

                neglogp = -log_prob.mean()
                ent_bonus = -0.001 * entropy.mean()

                bc_loss = bc_weight * (neglogp + ent_bonus)

                self._policy_optimizer.zero_grad()
                bc_loss.backward()
                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)
                self._policy_optimizer.step()

        with torch.no_grad():
            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),
                                     torch.zeros(batch_size, device=self.device)], dim=0)
            acc = ((logits > 0).float() == hard_labels).float().mean().item()

        return {
            "irl_loss": disc_loss.item(),
            "disc_acc": acc,
            "grad_penalty": gp.item(),
            "bc_weight": self._get_bc_weight(),
            "bc_loss": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,
        }
Tool Result (Step 8)
OK: Replaced lines 231..460 with 233 line(s) in imitation/custom_irl.py. Editable range: 231–464.

[Current file: imitation/custom_irl.py | editable: 231–464 | total: 706 lines]
   231  class RewardNetwork(nn.Module):
   232      """AIRL-shaped reward network with action-dependent potential.
   233  
       ... (228 more lines) ...
   462              "bc_loss": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,
   463          }
   464  
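The BC warmup schedule in the Step 8 code (`_get_bc_weight`) is a linear interpolation from the initial to the final weight over `bc_warmup_steps` updates, clamped thereafter. A standalone sketch with the same default hyperparameters (the free function name is illustrative):

```python
def bc_weight(total_updates, warmup_steps=200, initial=1.0, final=0.1):
    # Linear decay from `initial` to `final` over `warmup_steps` updates,
    # then held constant at `final`.
    progress = min(total_updates / warmup_steps, 1.0)
    return initial - (initial - final) * progress
```

At update 0 the policy is trained almost purely by behavioral cloning (weight 1.0); halfway through warmup the weight is 0.55; from update 200 onward it stays at 0.1, letting the learned reward dominate.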
Assistant (Step 9): edit
Tool Input (op: replace, file: imitation/custom_irl.py, lines 231–463)
Content:

class RewardNetwork(nn.Module):
    """AIRL-shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).

    g(s,a) is the reward function, h(s) is the potential/shaping function.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        # g(s, a): reward approximator
        self.g_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        # h(s): potential-based shaping
        self.h_net = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        self.gamma = 0.99

    def g(self, state, action):
        """Reward function g(s, a)."""
        x = torch.cat([state, action], dim=-1)
        return self.g_net(x).squeeze(-1)

    def h(self, state):
        """Potential function h(s)."""
        return self.h_net(state).squeeze(-1)

    def forward(self, state, action, next_state):
        """Shaped reward: f(s,a,s') = g(s,a) + gamma*h(s') - h(s)."""
        return self.g(state, action) + self.gamma * self.h(next_state) - self.h(state)


class IRLAlgorithm:
    """GAIL-BC Hybrid with Reward Smoothing (GBH-RS).

    Key innovations:
    1. Raw discriminator logits as reward (avoids softplus saturation)
    2. Persistent BC updates throughout training (matching BC baseline strength)
    3. Reward variance regularization to prevent reward hacking
    4. Gradient penalty for discriminator stability

    The BC component ensures the policy stays close to expert behavior,
    while the discriminator provides a signal for generalization beyond
    the expert demonstrations. The reward variance regularization prevents
    the discriminator from creating reward artifacts that the policy can
    exploit (reward hacking).
    """

    def __init__(self, reward_net, expert_demos, obs_dim, action_dim, device, args):
        self.reward_net = reward_net
        self.expert_demos = expert_demos
        self.device = device
        self.args = args
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # Lower discriminator LR to prevent it from overwhelming the policy
        disc_lr = min(args.irl_lr, 1e-4)
        self.optimizer = optim.Adam(self.reward_net.parameters(), lr=disc_lr)
        self.total_updates = 0

        self.grad_penalty_coef = 10.0
        self.label_smoothing = 0.1
        self.reward_var_coef = 0.01  # Reward variance regularization

        # BC parameters (matching BC baseline)
        self._policy = None
        self._policy_optimizer = None
        self.n_bc_steps = 20

    def set_policy(self, policy, optimizer):
        """Set reference to policy for BC updates."""
        self._policy = policy
        self._policy_optimizer = optimizer

    def compute_reward(self, obs, acts, next_obs):
        """Use raw discriminator logit as reward (not shaped reward).

        Raw logits provide meaningful gradient signals even when the
        discriminator is confident, unlike softplus-transformed rewards
        which can collapse to zero.
        """
        with torch.no_grad():
            return self.reward_net(obs, acts, next_obs)

    def _gradient_penalty(self, expert_obs, expert_acts, expert_next_obs,
                          gen_obs, gen_acts, gen_next_obs):
        """Compute gradient penalty on interpolated data (WGAN-GP style)."""
        batch_size = expert_obs.shape[0]
        alpha = torch.rand(batch_size, 1, device=self.device)

        interp_obs = (alpha * expert_obs + (1 - alpha) * gen_obs).requires_grad_(True)
        interp_acts = (alpha * expert_acts + (1 - alpha) * gen_acts).requires_grad_(True)
        interp_next = (alpha * expert_next_obs + (1 - alpha) * gen_next_obs).requires_grad_(True)

        logits = self.reward_net(interp_obs, interp_acts, interp_next)
        ones = torch.ones_like(logits)

        grads = torch.autograd.grad(
            outputs=logits, inputs=[interp_obs, interp_acts, interp_next],
            grad_outputs=ones, create_graph=True, retain_graph=True,
        )
        grad_cat = torch.cat([g.reshape(batch_size, -1) for g in grads], dim=1)
        grad_norm = torch.sqrt(torch.sum(grad_cat ** 2, dim=1) + 1e-12)
        return torch.mean((grad_norm - 1.0) ** 2)

    def update(self, policy_obs, policy_acts, policy_next_obs, policy_dones):
        """GAIL-BC update: discriminator + persistent BC."""
        self.total_updates += 1
        batch_size = self.args.irl_batch_size

        # Sample expert data
        n_expert = len(self.expert_demos["obs"])
        expert_idx = torch.randint(0, n_expert, (batch_size,))
        expert_obs = self.expert_demos["obs"][expert_idx]
        expert_acts = self.expert_demos["acts"][expert_idx]
        expert_next_obs = self.expert_demos["next_obs"][expert_idx]

        # Sample policy data
        n_policy = len(policy_obs)
        policy_idx = torch.randint(0, n_policy, (batch_size,))
        gen_obs = policy_obs[policy_idx]
        gen_acts = policy_acts[policy_idx]
        gen_next_obs = policy_next_obs[policy_idx]

        # Discriminator logits
        expert_logits = self.reward_net(expert_obs, expert_acts, expert_next_obs)
        gen_logits = self.reward_net(gen_obs, gen_acts, gen_next_obs)

        # Labels with smoothing
        eps = self.label_smoothing
        expert_labels = torch.full((batch_size,), 1.0 - eps, device=self.device)
        gen_labels = torch.full((batch_size,), eps, device=self.device)

        logits = torch.cat([expert_logits, gen_logits], dim=0)
        labels = torch.cat([expert_labels, gen_labels], dim=0)

        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)

        # Gradient penalty
        gp = self._gradient_penalty(
            expert_obs, expert_acts, expert_next_obs,
            gen_obs, gen_acts, gen_next_obs,
        )

        # Reward variance regularization (prevents reward hacking)
        reward_var = torch.var(expert_logits) + torch.var(gen_logits)
        reward_var_loss = self.reward_var_coef * reward_var

        # Total discriminator loss
        disc_loss = bce_loss + self.grad_penalty_coef * gp + reward_var_loss

        self.optimizer.zero_grad()
        disc_loss.backward()
        self.optimizer.step()

        # Persistent BC update (matching BC baseline strength)
        bc_loss = torch.tensor(0.0, device=self.device)
        if self._policy is not None and self._policy_optimizer is not None:
            for _ in range(self.n_bc_steps):
                expert_idx = torch.randint(0, n_expert, (batch_size,))
                expert_obs_bc = self.expert_demos["obs"][expert_idx]
                expert_acts_bc = self.expert_demos["acts"][expert_idx]

                _, log_prob, entropy, _ = self._policy.get_action_and_value(
                    expert_obs_bc, expert_acts_bc,
                )

                neglogp = -log_prob.mean()
                ent_bonus = -0.001 * entropy.mean()

                bc_loss = neglogp + ent_bonus

                self._policy_optimizer.zero_grad()
                bc_loss.backward()
                nn.utils.clip_grad_norm_(self._policy.parameters(), 0.5)
                self._policy_optimizer.step()

        with torch.no_grad():
            hard_labels = torch.cat([torch.ones(batch_size, device=self.device),
                                     torch.zeros(batch_size, device=self.device)], dim=0)
            acc = ((logits > 0).float() == hard_labels).float().mean().item()

        return {
            "irl_loss": disc_loss.item(),
            "disc_acc": acc,
            "grad_penalty": gp.item(),
            "reward_var": reward_var.item(),
            "bc_loss": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,
        }
Tool Result (Step 9)
OK: Replaced lines 231..463 with 198 line(s) in imitation/custom_irl.py. Editable range: 231–429.

[Current file: imitation/custom_irl.py | editable: 231–429 | total: 671 lines]
   231  class RewardNetwork(nn.Module):
   232      """AIRL-shaped reward network: f(s,a,s') = g(s,a) + gamma * h(s') - h(s).
   233  
       ... (193 more lines) ...
   427              "bc_loss": bc_loss.item() if isinstance(bc_loss, torch.Tensor) else 0.0,
   428          }
   429
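The AIRL-style shaping used in the Step 9 RewardNetwork, f(s,a,s') = g(s,a) + gamma*h(s') - h(s), has a telescoping property: over an episode, the potential terms cancel except at the endpoints, so shaping redistributes reward across timesteps without changing the discounted return (up to boundary terms). A dependency-free sketch with arbitrary illustrative numbers (the helper names are not from custom_irl.py):

```python
def discounted(vals, gamma):
    # Discounted sum: sum_t gamma^t * vals[t]
    return sum((gamma ** t) * v for t, v in enumerate(vals))

def shaped_rewards(g, h, gamma):
    # f_t = g_t + gamma * h(s_{t+1}) - h(s_t); h has one more entry than g
    return [g[t] + gamma * h[t + 1] - h[t] for t in range(len(g))]

gamma = 0.99
g = [0.3, -0.1, 0.7, 0.2]            # base rewards g(s_t, a_t)
h = [1.5, -2.0, 0.4, 3.1, -0.6]      # potentials h(s_0) .. h(s_T)

f = shaped_rewards(g, h, gamma)
lhs = discounted(f, gamma)
# Telescoped form: only the endpoint potentials survive
rhs = discounted(g, gamma) + (gamma ** len(g)) * h[-1] - h[0]
```

Because `lhs` equals `rhs` for any choice of potentials, the h network can only reshape when reward arrives, not which policies are optimal; that is the rationale for separating g and h in the architecture.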