Agent Conversation: gpt-5.4
Task: safe-rl
Safe RL: Constraint-Handling Mechanism Design
Objective
Design a constraint-handling mechanism for safe reinforcement learning. Your code goes in custom_lag.py, a subclass of PPO registered as CustomLag. Reference implementations (PPOLag, which uses a Lagrange multiplier, and CPPOPID, which uses a PID controller) are provided read-only.
Background
Safe RL aims to maximize reward while satisfying safety constraints (keeping episode cost below a limit). The key challenge is adaptively balancing reward and cost: the Lagrangian approach converts the constrained problem into an unconstrained dual problem via a multiplier lambda, while PID methods apply control theory for more responsive constraint satisfaction. You must design: (1) a multiplier update rule in _update(), and (2) an advantage combination formula in _compute_adv_surrogate().
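For orientation, the Lagrangian approach can be sketched in a few lines. This is a minimal illustration, not omnisafe's actual Lagrange class; the function names and the learning rate are ours. Dual gradient ascent pushes lambda up while episode cost Jc exceeds the limit d, and the surrogate advantage mixes reward and cost terms, rescaled so the effective step size stays bounded as lambda grows:

```python
import torch


def lagrange_step(lam: float, ep_cost: float, cost_limit: float, lr: float = 0.035) -> float:
    """One dual-ascent step on the multiplier: lambda grows while Jc > d,
    shrinks once feasible, and is projected back onto the non-negative reals."""
    return max(0.0, lam + lr * (ep_cost - cost_limit))


def combined_advantage(adv_r: torch.Tensor, adv_c: torch.Tensor, lam: float) -> torch.Tensor:
    """Combined surrogate advantage: subtract the lambda-weighted cost
    advantage, then divide by (1 + lambda) to keep the gradient scale stable."""
    return (adv_r - lam * adv_c) / (1.0 + lam)
```

When lambda is zero this reduces to plain PPO on the reward advantage, which is why constraint violations early in training drive the whole behavior of the mechanism.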
Evaluation
Evaluated on 3 Safety-Gymnasium environments to test generalization:
- SafetyPointGoal1-v0: point robot navigating to goals while avoiding hazards
- SafetyCarGoal1-v0: car robot (non-holonomic) navigating to goals while avoiding hazards
- SafetyPointButton1-v0: point robot pressing goal buttons while avoiding hazards
Metrics: episode reward (higher is better) and episode cost (lower is better, target <= 25.0). Each environment trains for 2M steps.
Baselines
- naive: no constraint handling (pure PPO, ignores cost)
- ppo_lag: Lagrangian multiplier updated via Adam optimizer
- pid_lag: PID controller for multiplier update
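The pid_lag baseline replaces dual ascent with a feedback controller, loosely following the PID Lagrangian method of Stooke et al. A rough sketch, with illustrative gains and projections (not omnisafe's defaults):

```python
class PIDLagrangianSketch:
    """Treats constraint violation as the error signal of a PID loop.

    The integral term plays the role of the classic multiplier; the
    proportional and derivative terms react to the current violation
    and its upward trend, damping the oscillations of pure dual ascent.
    """

    def __init__(self, cost_limit: float = 25.0,
                 kp: float = 0.1, ki: float = 0.01, kd: float = 0.01) -> None:
        self.cost_limit = cost_limit
        self.kp, self.ki, self.kd = kp, ki, kd
        self.integral = 0.0
        self.prev_cost = 0.0

    def update(self, ep_cost: float) -> float:
        error = ep_cost - self.cost_limit
        # Accumulate violation over time; clip so the integral never goes negative.
        self.integral = max(0.0, self.integral + self.ki * error)
        # Only penalize a rising cost trend, not a falling one.
        derivative = max(0.0, ep_cost - self.prev_cost)
        self.prev_cost = ep_cost
        # Project onto the non-negative reals: lambda must remain a valid multiplier.
        return max(0.0, self.kp * error + self.integral + self.kd * derivative)
```

The proportional term gives an immediate response that plain dual ascent lacks, which is the usual argument for pid_lag converging to the cost limit with less overshoot.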
omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py [EDITABLE — lines 20–20, lines 48–70 only]
1: """Custom Lagrangian-based safe PPO for MLS-Bench.
2:
3: EDITABLE section: imports + constraint handling methods.
4: FIXED sections: algorithm registration, learn() with metrics reporting.
5: """
6:
7: from __future__ import annotations
8:
9: import time
10:
11: import numpy as np
12: import torch
13:
14: from omnisafe.algorithms import registry
15: from omnisafe.algorithms.on_policy.base.ppo import PPO
16:
17: # ===================================================================
18: # EDITABLE: Custom imports
19: # ===================================================================
20:
21:
22: # ===================================================================
23: # FIXED: Algorithm class definition
24: # ===================================================================
25: @registry.register
26: class CustomLag(PPO):
27: """Custom Lagrangian-based safe RL algorithm.
28:
29: Extends PPO with constraint handling for safe reinforcement learning.
30: The agent must design:
31: 1. _init: Initialize constraint handler state (call super()._init() first)
32: 2. _init_log: Register logging keys (call super()._init_log() first)
33: 3. _update: ...

Tool Input
{
  "op": "replace",
  "filename": "omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py",
  "start_line": 48,
  "end_line": 70,
  "content": " def _init(self) -> None:\n super()._init()\n self._cost_limit: float = self._cfgs.lagrange_cfgs.cost_limit\n self._lambda_lr: float = self._cfgs.lagrange_cfgs.lambda_lr\n # Novel mechanism: optimistic-conservative mirror multiplier with violation momentum\n self._lagrangian_multiplier: float = 0.0\n self._violation_ema: float = 0.0\n self._feasible_ema: float = 0.0\n self._cost_var_ema: float = 0.0\n self._prev_delta: float = 0.0\n self._ema_beta: float = 0.9\n self._var_beta: float = 0.95\n self._damp_beta: float = 0.5\n self._safety_margin: float = 0.10 * self._cost_limit\n self._lambda_cap: float = 10.0\n\n def _init_log(self) -> None:\n super()._init_log()\n self._logger.register_key('Metrics/LagrangeMultiplier', min_and_max=True)\n\n def _update(self) -> None:\n Jc = self._logger.get_stats('Metrics/EpCost')[0]\n assert not np.isnan(Jc), 'cost is nan'\n delta = float(Jc - self._cost_limit)\n # Track both violation pressure and feasibility confidence.\n pos_delta = max(delta, 0.0)\n neg_delta = max(-delta, 0.0)\n self._violation_ema = self._ema_beta * self._violation_ema + (1 - self._ema_beta) * pos_delta\n self._feasible_ema = self._ema_beta * self._feasible_ema + (1 - self._ema_beta) * neg_delta\n centered = delta - self._prev_delta\n self._cost_var_ema = self._var_beta * self._cost_var_ema + (1 - self._var_beta) * (delta * delta)\n risk_scale = float(np.sqrt(self._cost_var_ema + 1e-8))\n # Update target combines current violation, trend, and a variance-adaptive safety margin.\n trend = centered\n target = (\n self._violation_ema\n + 0.5 * pos_delta\n + 0.25 * max(trend, 0.0)\n - 0.15 * self._feasible_ema\n - 0.05 * max(-trend, 0.0)\n + 0.1 * max(risk_scale - self._safety_margin, 0.0)\n )\n target = max(0.0, target / (self._cost_limit + 1e-8))\n # Mirror-style interpolation: rise fast when unsafe, decay slowly when safe.\n if delta > 0.0:\n step = self._lambda_lr * (1.0 + target)\n else:\n step = self._lambda_lr * 0.25\n self._lagrangian_multiplier = (1 - step) * self._lagrangian_multiplier + step * target\n self._lagrangian_multiplier = float(np.clip(self._lagrangian_multiplier, 0.0, self._lambda_cap))\n self._prev_delta = delta\n super()._update()\n self._logger.store({'Metrics/LagrangeMultiplier': self._lagrangian_multiplier})\n\n def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor:\n \"\"\"Variance-aware gated reward-cost advantage coupling.\"\"\"\n penalty = float(self._lagrangian_multiplier)\n adv_c_center = adv_c - adv_c.mean()\n cost_scale = adv_c_center.abs().mean().detach() + 1e-6\n reward_scale = adv_r.abs().mean().detach() + 1e-6\n normalized_cost = adv_c_center / cost_scale\n # Gate sharpens only when multiplier is active, preserving reward learning in feasible regions.\n gate = torch.sigmoid(normalized_cost * penalty)\n adaptive_penalty = penalty * (0.5 + gate)\n return (adv_r / reward_scale - adaptive_penalty * normalized_cost) / (1.0 + penalty)"
}

OK: Replaced lines 48..70 with 65 line(s) in omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py. Editable range: 20–20, 48–112.
[Current file: omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py | editable: 20–20, 48–112 | total: 191 lines]
20
...
48 def _init(self) -> None:
49 super()._init()
50 self._cost_limit: float = self._cfgs.lagrange_cfgs.cost_limit
... (59 more lines) ...
110 gate = torch.sigmoid(normalized_cost * penalty)
111 adaptive_penalty = penalty * (0.5 + gate)
112 return (adv_r / reward_scale - adaptive_penalty * normalized_cost) / (1.0 + penalty)

[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

## Seed 42

### SafetyPointGoal1-v0 (scripts/point_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=91 ep_ret=18.5372 ep_cost=49.6400 ep_len=1000.0
TRAIN_METRICS epoch=92 ep_ret=18.7751 ep_cost=52.1400 ep_len=1000.0
TRAIN_METRICS epoch=93 ep_ret=18.9717 ep_cost=55.6400 ep_len=1000.0
TRAIN_METRICS epoch=94 ep_ret=19.1327 ep_cost=53.4600 ep_len=1000.0
TRAIN_METRICS epoch=95 ep_ret=19.4145 ep_cost=51.7200 ep_len=1000.0

### SafetyCarGoal1-v0 (scripts/car_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=94 ep_ret=14.5266 ep_cost=51.6900 ep_len=1000.0
TRAIN_METRICS epoch=95 ep_ret=14.6518 ep_cost=53.5500 ep_len=1000.0
TRAIN_METRICS epoch=96 ep_ret=14.6068 ep_cost=56.2000 ep_len=1000.0
TRAIN_METRICS epoch=97 ep_ret=14.6224 ep_cost=54.3700 ep_len=1000.0
TRAIN_METRICS epoch=98 ep_ret=15.2221 ep_cost=56.5200 ep_len=1000.0

### SafetyPointButton1-v0 (scripts/point_button.sh)
T
Tool Input
{
  "op": "replace",
  "filename": "omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py",
  "start_line": 48,
  "end_line": 112,
  "content": " def _init(self) -> None:\n super()._init()\n self._cost_limit: float = self._cfgs.lagrange_cfgs.cost_limit\n self._lambda_lr: float = self._cfgs.lagrange_cfgs.lambda_lr\n # Novel mechanism: predictive integral barrier with safety-budget normalization.\n self._lagrangian_multiplier: float = 0.0\n self._ema_delta: float = 0.0\n self._ema_abs_delta: float = 0.0\n self._integral_violation: float = 0.0\n self._prev_cost: float = self._cost_limit\n self._ema_beta: float = 0.9\n self._lambda_cap: float = 20.0\n\n def _init_log(self) -> None:\n super()._init_log()\n self._logger.register_key('Metrics/LagrangeMultiplier', min_and_max=True)\n\n def _update(self) -> None:\n Jc = self._logger.get_stats('Metrics/EpCost')[0]\n assert not np.isnan(Jc), 'cost is nan'\n\n delta = float(Jc - self._cost_limit)\n norm_delta = delta / (self._cost_limit + 1e-8)\n cost_velocity = float(Jc - self._prev_cost) / (self._cost_limit + 1e-8)\n\n self._ema_delta = self._ema_beta * self._ema_delta + (1.0 - self._ema_beta) * norm_delta\n self._ema_abs_delta = self._ema_beta * self._ema_abs_delta + (1.0 - self._ema_beta) * abs(norm_delta)\n self._integral_violation = max(0.0, self._integral_violation + self._ema_delta)\n\n # Predictive barrier target: current violation + upward trend + persistent excess cost.\n pred_violation = self._ema_delta + 0.5 * max(cost_velocity, 0.0)\n raw_target = max(0.0, pred_violation) + 0.5 * self._integral_violation\n # If persistently feasible, actively release the multiplier.\n if norm_delta < 0.0:\n raw_target -= 0.25 * abs(norm_delta)\n target = float(np.clip(raw_target, 0.0, self._lambda_cap))\n\n # Aggressive rise when unsafe, conservative decay when safe.\n rise_rate = min(1.0, self._lambda_lr * (2.0 + 4.0 * max(norm_delta, 0.0)))\n decay_rate = min(1.0, self._lambda_lr * 0.5)\n mix = rise_rate if target > self._lagrangian_multiplier else decay_rate\n self._lagrangian_multiplier = (1.0 - mix) * self._lagrangian_multiplier + mix * target\n self._lagrangian_multiplier = float(np.clip(self._lagrangian_multiplier, 0.0, self._lambda_cap))\n self._prev_cost = float(Jc)\n\n super()._update()\n self._logger.store({'Metrics/LagrangeMultiplier': self._lagrangian_multiplier})\n\n def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor:\n \"\"\"Budget-normalized adaptive penalization with one-sided cost emphasis.\"\"\"\n penalty = float(self._lagrangian_multiplier)\n adv_c_center = adv_c - adv_c.mean()\n adv_c_scale = adv_c_center.abs().mean().detach() + 1e-6\n normalized_cost = adv_c_center / adv_c_scale\n # Penalize above-average cost actions more strongly than below-average ones.\n risky_cost = torch.relu(normalized_cost) + 0.25 * torch.minimum(\n normalized_cost,\n torch.zeros_like(normalized_cost),\n )\n return (adv_r - penalty * risky_cost) / (1.0 + penalty)"
}

OK: Replaced lines 48..112 with 60 line(s) in omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py. Editable range: 20–20, 48–107.
[Current file: omnisafe/omnisafe/algorithms/on_policy/naive_lagrange/custom_lag.py | editable: 20–20, 48–107 | total: 186 lines]
20
...
48 def _init(self) -> None:
49 super()._init()
50 self._cost_limit: float = self._cfgs.lagrange_cfgs.cost_limit
... (54 more lines) ...
105 torch.zeros_like(normalized_cost),
106 )
107 return (adv_r - penalty * risky_cost) / (1.0 + penalty)

[Test #2] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

## Seed 42

### SafetyPointGoal1-v0 (scripts/point_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=90 ep_ret=-0.0481 ep_cost=5.1400 ep_len=1000.0
TRAIN_METRICS epoch=91 ep_ret=-0.0740 ep_cost=4.7300 ep_len=1000.0
TRAIN_METRICS epoch=92 ep_ret=-0.0022 ep_cost=7.6500 ep_len=1000.0
TRAIN_METRICS epoch=93 ep_ret=0.0948 ep_cost=5.0200 ep_len=1000.0
TRAIN_METRICS epoch=94 ep_ret=0.3729 ep_cost=6.8000 ep_len=1000.0

### SafetyCarGoal1-v0 (scripts/car_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=95 ep_ret=0.4582 ep_cost=8.5000 ep_len=1000.0
TRAIN_METRICS epoch=96 ep_ret=0.4625 ep_cost=9.8800 ep_len=1000.0
TRAIN_METRICS epoch=97 ep_ret=0.4391 ep_cost=7.6400 ep_len=1000.0
TRAIN_METRICS epoch=98 ep_ret=0.2665 ep_cost=8.5100 ep_len=1000.0
TRAIN_METRICS
Tool Input
{
  "n": 2
}

[submit] Submitting result from test #2 as final.

## Seed 42

### SafetyPointGoal1-v0 (scripts/point_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=90 ep_ret=-0.0481 ep_cost=5.1400 ep_len=1000.0
TRAIN_METRICS epoch=91 ep_ret=-0.0740 ep_cost=4.7300 ep_len=1000.0
TRAIN_METRICS epoch=92 ep_ret=-0.0022 ep_cost=7.6500 ep_len=1000.0
TRAIN_METRICS epoch=93 ep_ret=0.0948 ep_cost=5.0200 ep_len=1000.0
TRAIN_METRICS epoch=94 ep_ret=0.3729 ep_cost=6.8000 ep_len=1000.0

### SafetyCarGoal1-v0 (scripts/car_goal.sh)
Training metrics (last epochs):
TRAIN_METRICS epoch=95 ep_ret=0.4582 ep_cost=8.5000 ep_len=1000.0
TRAIN_METRICS epoch=96 ep_ret=0.4625 ep_cost=9.8800 ep_len=1000.0
TRAIN_METRICS epoch=97 ep_ret=0.4391 ep_cost=7.6400 ep_len=1000.0
TRAIN_METRICS epoch=98 ep_ret=0.2665 ep_cost=8.5100 ep_len=1000.0
TRAIN_METRICS epoch=99 ep_ret=0.2395 ep_cost=8.1500 ep_len=1000.0

Final evaluation:
TEST_METRICS ep_ret=0.2395 ep_cost=8.1500 ep_len=1000.0

### SafetyPointButton1-v0 (scripts/poin