Agent Conversation: deepseek-reasoner
Task: rl-value-atari
Online RL: Value-Based Methods for Visual Control (Atari)
Objective
Design and implement a value-based RL algorithm for visual/Atari environments using CNN feature extraction. Your code goes in custom_value_atari.py. Three reference implementations (DQN, DoubleDQN, C51) are provided as read-only.
Background
Atari games require learning from raw pixel observations (84x84 grayscale, 4 stacked frames). Value-based methods must learn effective visual representations alongside Q-value estimation. Key challenges include high-dimensional observations, sparse rewards, and the memory cost of experience replay. Different approaches address these through distributional value functions, frame stacking, or architecture innovations.
Constraints
- Network architecture dimensions are FIXED and cannot be modified
- Total parameter count is enforced at runtime
- Focus on algorithmic innovation: new loss functions, update rules, exploration strategies, etc.
- Do NOT simply copy a reference implementation with minor changes
Evaluation
Trained and evaluated on Breakout, Pong, BeamRider. Additional held-out environments (not shown during intermediate testing) are used to assess generalization. Metric: mean episodic return over 10 evaluation episodes (higher is better).
cleanrl/cleanrl/custom_value_atari.py [EDITABLE — lines 186–249 only]
1: # Custom value-based RL algorithm for Atari -- MLS-Bench
2: #
3: # EDITABLE section: QNetwork head and ValueAlgorithm classes.
4: # FIXED sections: everything else (config, env, buffer, encoder, eval, training loop).
5: import os
6: import random
7: import time
8: from dataclasses import dataclass
9:
10: import gymnasium as gym
11: import numpy as np
12: import torch
13: import torch.nn as nn
14: import torch.nn.functional as F
15: import torch.optim as optim
16: import tyro
17:
18: from cleanrl_utils.atari_wrappers import (
19: ClipRewardEnv,
20: EpisodicLifeEnv,
21: FireResetEnv,
22: MaxAndSkipEnv,
23: NoopResetEnv,
24: )
25: from cleanrl_utils.buffers import ReplayBuffer
26:
27:
28: # =====================================================================
29: # FIXED: Configuration
30: # =====================================================================
31: @dataclass
32: class Args:
33: exp_name: str = os.path.basename(__file__)[: -len(".py")]
34: """the name of this experiment"""
35: seed: int = 1
36: """seed of the experiment"""
37: torch_deterministic: bool = True
38: """if toggled, `torch.backends.cudnn.deterministic=False`"""
39: cuda: bool = True
40: """if toggled, cuda will be enabled by default"""
41:
42: # Algorithm specific arguments
43: env_id: str = "BreakoutNoFrameskip-v4"
44: """the id of the environment""Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 186,
5 "end_line": 249,
6 "content": "class QNetwork(nn.Module):\n \"\"\"Implicit Quantile Network (IQN) for distributional RL.\n \n Encodes quantile samples via cosine embedding and combines with state features.\n \"\"\"\n def __init__(self, envs, quantile_embed_dim=64):\n super().__init__()\n n_actions = envs.single_action_space.n\n self.n_actions = n_actions\n self.quantile_embed_dim = quantile_embed_dim\n self.encoder = NatureDQNEncoder()\n \n # Cosine embedding for quantile samples\n self.cos_embed = nn.Linear(quantile_embed_dim, 512)\n # Main head: after elementwise product of features and quantile embedding\n self.head = nn.Linear(512, n_actions)\n \n def forward(self, x, tau_samples=None):\n \"\"\"Return Q-values (mean over quantiles) or quantile values if tau_samples provided.\n \n Args:\n x: state tensor [batch, 4, 84, 84]\n tau_samples: optional [batch, n_samples] quantile samples in (0,1)\n \n Returns:\n If tau_samples is None: Q-values [batch, n_actions]\n Else: quantile values [batch, n_samples, n_actions]\n \"\"\"\n features = self.encoder(x) # [batch, 512]\n \n if tau_samples is None:\n # For action selection: sample a single tau per state? 
Use expectation over tau.\n # We'll sample 32 tau per state and average.\n batch_size = x.shape[0]\n tau = torch.rand(batch_size, 32, device=x.device) # [batch, 32]\n quantile_vals = self._quantile_values(features, tau) # [batch, 32, n_actions]\n q_values = quantile_vals.mean(dim=1) # [batch, n_actions]\n return q_values\n else:\n # tau_samples: [batch, n_samples]\n quantile_vals = self._quantile_values(features, tau_samples) # [batch, n_samples, n_actions]\n return quantile_vals\n \n def _quantile_values(self, features, tau):\n \"\"\"Compute quantile values given features and tau samples.\n \n Args:\n features: [batch, 512]\n tau: [batch, n_samples]\n \n Returns:\n quantile values: [batch, n_samples, n_actions]\n \"\"\"\n batch_size, n_samples = tau.shape\n # Compute cosine embedding of tau\n # tau in (0,1), expand to [batch, n_samples, quantile_embed_dim]\n pi = torch.tensor(torch.pi, device=tau.device)\n # Create cosine basis: cos(pi * i * tau) for i=1..quantile_embed_dim\n i_vals = torch.arange(1, self.quantile_embed_dim + 1, device=tau.device).float()\n # i_vals: [quantile_embed_dim]\n # tau: [batch, n_samples, 1]\n tau = tau.unsqueeze(-1) # [batch, n_samples, 1]\n # cos(pi * i * tau): [batch, n_samples, quantile_embed_dim]\n cos_vals = torch.cos(pi * i_vals * tau)\n # Linear transform\n cos_embed = self.cos_embed(cos_vals) # [batch, n_samples, 512]\n # Expand features to match\n features_expanded = features.unsqueeze(1).expand(-1, n_samples, -1) # [batch, n_samples, 512]\n # Elementwise product\n combined = features_expanded * cos_embed # [batch, n_samples, 512]\n # ReLU activation (as in IQN paper)\n combined = F.relu(combined)\n # Linear head\n quantile_vals = self.head(combined) # [batch, n_samples, n_actions]\n return quantile_vals\n\n\nclass ValueAlgorithm:\n \"\"\"Implicit Quantile Networks (IQN) with latent temporal consistency regularization.\"\"\"\n \n def __init__(self, envs, device, args):\n self.device = device\n self.gamma = args.gamma\n self.tau 
= args.tau\n self.target_network_frequency = args.target_network_frequency\n self.n_quantile_samples = 8 # number of tau samples for loss\n self.n_target_samples = 8\n self.kappa = 1.0 # Huber loss threshold\n self.lambda_latent = 0.1 # weight for latent consistency loss\n \n self.q_network = QNetwork(envs).to(device)\n self.target_network = QNetwork(envs).to(device)\n self.target_network.load_state_dict(self.q_network.state_dict())\n self.optimizer = optim.Adam(self.q_network.parameters(), lr=args.learning_rate)\n \n # Small MLP for predicting next latent features\n self.latent_predictor = nn.Sequential(\n nn.Linear(ENCODER_FEATURE_DIM, 256),\n nn.ReLU(),\n nn.Linear(256, ENCODER_FEATURE_DIM)\n ).to(device)\n self.latent_optimizer = optim.Adam(self.latent_predictor.parameters(), lr=args.learning_rate)\n \n def select_action(self, obs, epsilon):\n \"\"\"Epsilon-greedy action selection using Q-values.\"\"\"\n if random.random() < epsilon:\n return np.array([self.q_network.n_actions])\n q_values = self.q_network(torch.Tensor(obs).to(self.device))\n return torch.argmax(q_values, dim=1).cpu().numpy()\n \n def update(self, batch, global_step):\n \"\"\"IQN update with latent temporal consistency regularization.\"\"\"\n batch_size = batch.observations.shape[0]\n \n # Sample quantiles for current and target networks\n tau = torch.rand(batch_size, self.n_quantile_samples, device=self.device) # [batch, N]\n tau_prime = torch.rand(batch_size, self.n_target_samples, device=self.device) # [batch, N']\n \n with torch.no_grad():\n # Compute target quantile values\n next_quantiles = self.target_network._quantile_values(\n self.target_network.encoder(batch.next_observations),\n tau_prime\n ) # [batch, N', n_actions]\n next_q = next_quantiles.mean(dim=1) # [batch, n_actions]\n next_actions = next_q.argmax(dim=1) # [batch]\n # Select quantiles for best actions\n next_quantiles_best = next_quantiles[torch.arange(batch_size), :, next_actions] # [batch, N']\n # Compute target quantile 
values\n target_quantiles = batch.rewards + self.gamma * next_quantiles_best * (1 - batch.dones)\n target_quantiles = target_quantiles.unsqueeze(1) # [batch, 1, N'] for broadcast\n \n # Current quantile values for taken actions\n current_quantiles_all = self.q_network._quantile_values(\n self.q_network.encoder(batch.observations),\n tau\n ) # [batch, N, n_actions]\n current_quantiles = current_quantiles_all[torch.arange(batch_size), :, batch.actions.flatten()] # [batch, N]\n \n # Quantile Huber loss\n # Pairwise TD errors: [batch, N, N']\n td_errors = target_quantiles - current_quantiles.unsqueeze(2)\n huber_loss = F.huber_loss(td_errors, torch.zeros_like(td_errors), reduction='none', delta=self.kappa)\n quantile_loss = (huber_loss * (tau.unsqueeze(2) - (td_errors < 0).float()).abs()).mean()\n \n # Latent temporal consistency loss\n features_current = self.q_network.encoder(batch.observations)\n features_next = self.q_network.encoder(batch.next_observations)\n predicted_next = self.latent_predictor(features_current)\n latent_loss = F.mse_loss(predicted_next, features_next.detach())\n \n # Total loss\n total_loss = quantile_loss + self.lambda_latent * latent_loss\n \n self.optimizer.zero_grad()\n self.latent_optimizer.zero_grad()\n total_loss.backward()\n self.optimizer.step()\n self.latent_optimizer.step()\n \n # Soft target network update\n if global_step % self.target_network_frequency == 0:\n for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):\n target_param.data.copy_(\n self.tau * q_param.data + (1.0 - self.tau) * target_param.data\n )\n \n return {\n \"td_loss\": quantile_loss.item(),\n \"latent_loss\": latent_loss.item(),\n \"total_loss\": total_loss.item(),\n \"q_values\": current_quantiles.mean().item()\n }"
7}OK: Replaced lines 186..249 with 173 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–358. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–358 | total: 468 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=64): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embeddin
Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 289,
5 "end_line": 294,
6 "content": " def select_action(self, obs, epsilon):\n \"\"\"Epsilon-greedy action selection using Q-values.\"\"\"\n if random.random() < epsilon:\n # Return random action (single environment)\n return np.array([random.randrange(self.q_network.n_actions)])\n q_values = self.q_network(torch.Tensor(obs).to(self.device))\n return torch.argmax(q_values, dim=1).cpu().numpy()"
7}OK: Replaced lines 289..294 with 7 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–359. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–359 | total: 469 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=64): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embedding
[Test #1] (2 tests remaining; call submit(n=1) to submit this result) ## Seed 42 ### breakout-v4 (scripts/breakout.sh) /usr/bin/bash: warning: setlocale: LC_ALL: cannot change locale (zh_CN.UTF-8) A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7) [Powered by Stella] QNetwork parameters: 1,719,460 / 2,199,254 (1.05x largest baseline, informational only) Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality. Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade. See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information. /usr/local/lib/python3.8/dist-packages/pygame/pkgdata.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html from pkg_resources import resource_stream, resource_exists global_step=114, episodic_ret
Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 263,
5 "end_line": 359,
6 "content": "class ValueAlgorithm:\n \"\"\"Implicit Quantile Networks (IQN) with temporal feature smoothing regularization.\"\"\"\n \n def __init__(self, envs, device, args):\n self.device = device\n self.gamma = args.gamma\n self.tau = args.tau\n self.target_network_frequency = args.target_network_frequency\n self.n_quantile_samples = 8 # number of tau samples for loss\n self.n_target_samples = 8\n self.kappa = 1.0 # Huber loss threshold\n self.lambda_latent = 0.01 # weight for temporal feature smoothing\n \n self.q_network = QNetwork(envs).to(device)\n self.target_network = QNetwork(envs).to(device)\n self.target_network.load_state_dict(self.q_network.state_dict())\n self.optimizer = optim.Adam(self.q_network.parameters(), lr=args.learning_rate)\n \n def select_action(self, obs, epsilon):\n \"\"\"Epsilon-greedy action selection using Q-values.\"\"\"\n if random.random() < epsilon:\n # Return random action (single environment)\n return np.array([random.randrange(self.q_network.n_actions)])\n q_values = self.q_network(torch.Tensor(obs).to(self.device))\n return torch.argmax(q_values, dim=1).cpu().numpy()\n \n def update(self, batch, global_step):\n \"\"\"IQN update with temporal feature smoothing regularization.\"\"\"\n batch_size = batch.observations.shape[0]\n \n # Sample quantiles for current and target networks\n tau = torch.rand(batch_size, self.n_quantile_samples, device=self.device) # [batch, N]\n tau_prime = torch.rand(batch_size, self.n_target_samples, device=self.device) # [batch, N']\n \n with torch.no_grad():\n # Compute target quantile values\n next_quantiles = self.target_network._quantile_values(\n self.target_network.encoder(batch.next_observations),\n tau_prime\n ) # [batch, N', n_actions]\n next_q = next_quantiles.mean(dim=1) # [batch, n_actions]\n next_actions = next_q.argmax(dim=1) # [batch]\n # Select quantiles for best actions\n next_quantiles_best = next_quantiles[torch.arange(batch_size), :, next_actions] # [batch, N']\n # Compute target 
quantile values\n target_quantiles = batch.rewards + self.gamma * next_quantiles_best * (1 - batch.dones)\n target_quantiles = target_quantiles.unsqueeze(1) # [batch, 1, N'] for broadcast\n \n # Current quantile values for taken actions\n current_quantiles_all = self.q_network._quantile_values(\n self.q_network.encoder(batch.observations),\n tau\n ) # [batch, N, n_actions]\n current_quantiles = current_quantiles_all[torch.arange(batch_size), :, batch.actions.flatten()] # [batch, N]\n \n # Quantile Huber loss\n # Pairwise TD errors: [batch, N, N']\n td_errors = target_quantiles - current_quantiles.unsqueeze(2)\n huber_loss = F.huber_loss(td_errors, torch.zeros_like(td_errors), reduction='none', delta=self.kappa)\n quantile_loss = (huber_loss * (tau.unsqueeze(2) - (td_errors < 0).float()).abs()).mean()\n \n # Temporal feature smoothing: encourage consecutive features to be similar\n features_current = self.q_network.encoder(batch.observations)\n features_next = self.q_network.encoder(batch.next_observations)\n latent_loss = F.mse_loss(features_current, features_next.detach())\n \n # Total loss\n total_loss = quantile_loss + self.lambda_latent * latent_loss\n \n self.optimizer.zero_grad()\n total_loss.backward()\n self.optimizer.step()\n \n # Soft target network update\n if global_step % self.target_network_frequency == 0:\n for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):\n target_param.data.copy_(\n self.tau * q_param.data + (1.0 - self.tau) * target_param.data\n )\n \n return {\n \"td_loss\": quantile_loss.item(),\n \"latent_loss\": latent_loss.item(),\n \"total_loss\": total_loss.item(),\n \"q_values\": current_quantiles.mean().item()\n }"
7}OK: Replaced lines 263..359 with 86 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–348. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–348 | total: 458 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=64): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embedding
Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 186,
5 "end_line": 260,
6 "content": "class QNetwork(nn.Module):\n \"\"\"Implicit Quantile Network (IQN) for distributional RL.\n \n Encodes quantile samples via cosine embedding and combines with state features.\n \"\"\"\n def __init__(self, envs, quantile_embed_dim=32):\n super().__init__()\n n_actions = envs.single_action_space.n\n self.n_actions = n_actions\n self.quantile_embed_dim = quantile_embed_dim\n self.encoder = NatureDQNEncoder()\n \n # Cosine embedding for quantile samples\n self.cos_embed = nn.Linear(quantile_embed_dim, 512)\n # Main head: after elementwise product of features and quantile embedding\n self.head = nn.Linear(512, n_actions)\n \n def forward(self, x, tau_samples=None):\n \"\"\"Return Q-values (mean over quantiles) or quantile values if tau_samples provided.\n \n Args:\n x: state tensor [batch, 4, 84, 84]\n tau_samples: optional [batch, n_samples] quantile samples in (0,1)\n \n Returns:\n If tau_samples is None: Q-values [batch, n_actions]\n Else: quantile values [batch, n_samples, n_actions]\n \"\"\"\n features = self.encoder(x) # [batch, 512]\n \n if tau_samples is None:\n # For action selection: sample 8 tau per state and average.\n batch_size = x.shape[0]\n tau = torch.rand(batch_size, 8, device=x.device) # [batch, 8]\n quantile_vals = self._quantile_values(features, tau) # [batch, 8, n_actions]\n q_values = quantile_vals.mean(dim=1) # [batch, n_actions]\n return q_values\n else:\n # tau_samples: [batch, n_samples]\n quantile_vals = self._quantile_values(features, tau_samples) # [batch, n_samples, n_actions]\n return quantile_vals\n \n def _quantile_values(self, features, tau):\n \"\"\"Compute quantile values given features and tau samples.\n \n Args:\n features: [batch, 512]\n tau: [batch, n_samples]\n \n Returns:\n quantile values: [batch, n_samples, n_actions]\n \"\"\"\n batch_size, n_samples = tau.shape\n # Compute cosine embedding of tau\n # tau in (0,1), expand to [batch, n_samples, quantile_embed_dim]\n pi = torch.tensor(torch.pi, 
device=tau.device)\n # Create cosine basis: cos(pi * i * tau) for i=1..quantile_embed_dim\n i_vals = torch.arange(1, self.quantile_embed_dim + 1, device=tau.device).float()\n # i_vals: [quantile_embed_dim]\n # tau: [batch, n_samples, 1]\n tau = tau.unsqueeze(-1) # [batch, n_samples, 1]\n # cos(pi * i * tau): [batch, n_samples, quantile_embed_dim]\n cos_vals = torch.cos(pi * i_vals * tau)\n # Linear transform\n cos_embed = self.cos_embed(cos_vals) # [batch, n_samples, 512]\n # Expand features to match\n features_expanded = features.unsqueeze(1).expand(-1, n_samples, -1) # [batch, n_samples, 512]\n # Elementwise product\n combined = features_expanded * cos_embed # [batch, n_samples, 512]\n # ReLU activation (as in IQN paper)\n combined = F.relu(combined)\n # Linear head\n quantile_vals = self.head(combined) # [batch, n_samples, n_actions]\n return quantile_vals"
7}OK: Replaced lines 186..260 with 74 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–347. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–347 | total: 457 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=32): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embedding
[Test #2] (1 test remaining; call submit(n=2) to submit this result) ### breakout-v4 (scripts/breakout.sh) Training metrics (last steps): TRAIN_METRICS step=2995000 td_loss=0.0300 latent_loss=0.0082 total_loss=0.0301 q_values=4.7895 TRAIN_METRICS step=2996000 td_loss=0.0113 latent_loss=0.0062 total_loss=0.0113 q_values=4.5404 TRAIN_METRICS step=2997000 td_loss=0.0304 latent_loss=0.0069 total_loss=0.0304 q_values=5.4329 TRAIN_METRICS step=2998000 td_loss=0.0471 latent_loss=0.0091 total_loss=0.0471 q_values=4.9636 TRAIN_METRICS step=2999000 td_loss=0.0203 latent_loss=0.0078 total_loss=0.0204 q_values=5.3423 Evaluation (breakout-v4): Eval episodic_return: 324.20 Eval episodic_return: 253.30 Eval episodic_return: 294.60 Final eval return: 294.60 ### seaquest-v4 (scripts/seaquest.sh) Training metrics (last steps): TRAIN_METRICS step=4939000 td_loss=0.0116 latent_loss=0.0037 total_loss=0.0117 q_values=9.1046 TRAIN_METRICS step=4940000 td_loss=0.0145 latent_loss=0.0035 total_loss=0.0145 q_v
Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 262,
5 "end_line": 347,
6 "content": "class ValueAlgorithm:\n \"\"\"Implicit Quantile Networks (IQN) with TD-weighted temporal feature smoothing.\"\"\"\n \n def __init__(self, envs, device, args):\n self.device = device\n self.gamma = args.gamma\n self.tau = args.tau\n self.target_network_frequency = args.target_network_frequency\n self.n_quantile_samples = 16 # number of tau samples for loss\n self.n_target_samples = 16\n self.kappa = 1.0 # Huber loss threshold\n self.lambda_latent = 0.001 # weight for TD-weighted feature smoothing\n \n self.q_network = QNetwork(envs).to(device)\n self.target_network = QNetwork(envs).to(device)\n self.target_network.load_state_dict(self.q_network.state_dict())\n self.optimizer = optim.Adam(self.q_network.parameters(), lr=args.learning_rate)\n \n def select_action(self, obs, epsilon):\n \"\"\"Epsilon-greedy action selection using Q-values.\"\"\"\n if random.random() < epsilon:\n # Return random action (single environment)\n return np.array([random.randrange(self.q_network.n_actions)])\n q_values = self.q_network(torch.Tensor(obs).to(self.device))\n return torch.argmax(q_values, dim=1).cpu().numpy()\n \n def update(self, batch, global_step):\n \"\"\"IQN update with TD-weighted temporal feature smoothing.\"\"\"\n batch_size = batch.observations.shape[0]\n \n # Sample quantiles for current and target networks\n tau = torch.rand(batch_size, self.n_quantile_samples, device=self.device) # [batch, N]\n tau_prime = torch.rand(batch_size, self.n_target_samples, device=self.device) # [batch, N']\n \n with torch.no_grad():\n # Compute target quantile values\n next_quantiles = self.target_network._quantile_values(\n self.target_network.encoder(batch.next_observations),\n tau_prime\n ) # [batch, N', n_actions]\n next_q = next_quantiles.mean(dim=1) # [batch, n_actions]\n next_actions = next_q.argmax(dim=1) # [batch]\n # Select quantiles for best actions\n next_quantiles_best = next_quantiles[torch.arange(batch_size), :, next_actions] # [batch, N']\n # Compute target 
quantile values\n target_quantiles = batch.rewards + self.gamma * next_quantiles_best * (1 - batch.dones)\n target_quantiles = target_quantiles.unsqueeze(1) # [batch, 1, N'] for broadcast\n \n # Current quantile values for taken actions\n current_quantiles_all = self.q_network._quantile_values(\n self.q_network.encoder(batch.observations),\n tau\n ) # [batch, N, n_actions]\n current_quantiles = current_quantiles_all[torch.arange(batch_size), :, batch.actions.flatten()] # [batch, N]\n \n # Quantile Huber loss\n # Pairwise TD errors: [batch, N, N']\n td_errors = target_quantiles - current_quantiles.unsqueeze(2)\n huber_loss = F.huber_loss(td_errors, torch.zeros_like(td_errors), reduction='none', delta=self.kappa)\n quantile_loss = (huber_loss * (tau.unsqueeze(2) - (td_errors < 0).float()).abs()).mean()\n \n # TD-weighted temporal feature smoothing\n features_current = self.q_network.encoder(batch.observations)\n features_next = self.q_network.encoder(batch.next_observations)\n # Compute MSE per sample\n mse_per_sample = F.mse_loss(features_current, features_next.detach(), reduction='none').mean(dim=1) # [batch]\n # Compute TD error magnitude per sample\n td_mag = td_errors.abs().mean(dim=(1,2)) # [batch]\n # Weight: higher weight where TD error is small (confidence)\n weight = torch.exp(-td_mag.detach()) # [batch]\n latent_loss = (weight * mse_per_sample).mean()\n \n # Total loss\n total_loss = quantile_loss + self.lambda_latent * latent_loss\n \n self.optimizer.zero_grad()\n total_loss.backward()\n self.optimizer.step()\n \n # Soft target network update\n if global_step % self.target_network_frequency == 0:\n for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):\n target_param.data.copy_(\n self.tau * q_param.data + (1.0 - self.tau) * target_param.data\n )\n \n return {\n \"td_loss\": quantile_loss.item(),\n \"latent_loss\": latent_loss.item(),\n \"total_loss\": total_loss.item(),\n \"q_values\": 
current_quantiles.mean().item()\n }"
7}OK: Replaced lines 262..347 with 92 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–353. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–353 | total: 463 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=32): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embedding
Tool Input
1{
2 "op": "replace",
3 "filename": "cleanrl/cleanrl/custom_value_atari.py",
4 "start_line": 216,
5 "end_line": 222,
6 "content": " if tau_samples is None:\n # For action selection: sample 16 tau per state and average.\n batch_size = x.shape[0]\n tau = torch.rand(batch_size, 16, device=x.device) # [batch, 16]\n quantile_vals = self._quantile_values(features, tau) # [batch, 16, n_actions]\n q_values = quantile_vals.mean(dim=1) # [batch, n_actions]\n return q_values"
7}OK: Replaced lines 216..222 with 7 line(s) in cleanrl/cleanrl/custom_value_atari.py. Editable range: 186–353. [Current file: cleanrl/cleanrl/custom_value_atari.py | editable: 186–353 | total: 463 lines] 183| # ===================================================================== 184| # EDITABLE: QNetwork head and ValueAlgorithm 185| # ===================================================================== 186 class QNetwork(nn.Module): 187 """Implicit Quantile Network (IQN) for distributional RL. 188 189 Encodes quantile samples via cosine embedding and combines with state features. 190 """ 191 def __init__(self, envs, quantile_embed_dim=32): 192 super().__init__() 193 n_actions = envs.single_action_space.n 194 self.n_actions = n_actions 195 self.quantile_embed_dim = quantile_embed_dim 196 self.encoder = NatureDQNEncoder() 197 198 # Cosine embedding
[Test #3] (0 tests remaining) ## Seed 42 ### breakout-v4 (scripts/breakout.sh) Training metrics (last steps): TRAIN_METRICS step=1807000 td_loss=0.0366 latent_loss=0.0078 total_loss=0.0366 q_values=3.1151 TRAIN_METRICS step=1808000 td_loss=0.0097 latent_loss=0.0133 total_loss=0.0097 q_values=3.3553 TRAIN_METRICS step=1809000 td_loss=0.0131 latent_loss=0.0112 total_loss=0.0132 q_values=3.2436 TRAIN_METRICS step=1810000 td_loss=0.0069 latent_loss=0.0064 total_loss=0.0069 q_values=3.4697 TRAIN_METRICS step=1811000 td_loss=0.0105 latent_loss=0.0079 total_loss=0.0105 q_values=3.2141 Evaluation (breakout-v4): Eval episodic_return: 67.60 Eval episodic_return: 38.40 Eval episodic_return: 149.40 Final eval return: 149.40 ### seaquest-v4 (scripts/seaquest.sh) Training metrics (last steps): TRAIN_METRICS step=3039000 td_loss=0.0052 latent_loss=0.0012 total_loss=0.0052 q_values=5.1346 TRAIN_METRICS step=3040000 td_loss=0.0106 latent_loss=0.0177 total_loss=0.0106 q_values=5.1634 TRAIN_METRICS s