rl-offline-discrete
Tags: Reinforcement Learning, d3rlpy, rigorous codebase
Description
Offline RL: Discrete Action Control on Atari
Objective
Design and implement an offline RL algorithm for discrete action spaces with pixel observations. Your code goes in the QNetwork and OfflineAlgorithm classes in custom_atari.py. Three reference implementations (BC, CQL, BCQ) are provided as read-only.
Background
The offline datasets are "mixed"-quality replay-buffer data collected from a partially trained DQN agent. The agent must learn entirely from this fixed dataset, with no environment interaction during training.
Constraints
- The NatureDQNEncoder (CNN feature extractor) is FIXED and must not be replaced or modified. Your QNetwork must use it via self.encoder = NatureDQNEncoder(...). The convolutional layers are verified at runtime.
- Total model parameter count must not exceed 5,000,000. This is enforced at runtime; exceeding it will crash training.
- Focus on algorithmic innovation (loss functions, training procedures, action selection) rather than scaling up network capacity.
- Do NOT simply copy a reference implementation with minor changes.
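As a shape-level sketch of how these constraints fit together, the QNetwork below wraps the encoder exactly as required and stays well under the 5M-parameter budget. The NatureDQNEncoder here is a stand-in that mirrors the classic Nature DQN conv stack (the real, fixed version is provided in custom_atari.py); the 84x84x4 input shape, the `out_dim` attribute, and the 512-unit head are assumptions for illustration.

```python
import torch
import torch.nn as nn

class NatureDQNEncoder(nn.Module):
    """Stand-in for the FIXED encoder: the classic Nature DQN conv stack.
    The benchmark provides the real version; this only mirrors its shapes."""
    def __init__(self, in_channels: int = 4):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        self.out_dim = 3136  # 64 * 7 * 7 for 84x84 inputs

    def forward(self, x):
        return self.convs(x / 255.0)  # scale uint8 frames to [0, 1]

class QNetwork(nn.Module):
    """Q-value head on top of the fixed encoder."""
    def __init__(self, num_actions: int):
        super().__init__()
        self.encoder = NatureDQNEncoder()  # required usage pattern
        self.head = nn.Sequential(
            nn.Linear(self.encoder.out_dim, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, obs):
        return self.head(self.encoder(obs))

net = QNetwork(num_actions=4)
n_params = sum(p.numel() for p in net.parameters())  # roughly 1.7M here
```

With this head the total sits around 1.7M parameters, leaving headroom for algorithmic additions (e.g. an auxiliary head) without approaching the 5M cap.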
Evaluation
Trained and evaluated on Breakout (4 actions), Pong (6 actions), Qbert (6 actions) using d4rl-atari "mixed" datasets. Additional held-out environments (not shown during intermediate testing) are used to assess generalization. Metric: mean episode return over 10 evaluation episodes.
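The metric is simple to reproduce: roll out the greedy policy and average episode returns. A minimal sketch, assuming a Gymnasium-style env (5-tuple `step`, `(obs, info)` from `reset`) and a caller-supplied `select_action` function (both names are assumptions, not part of the fixed code):

```python
def evaluate(select_action, env, episodes: int = 10) -> float:
    """Mean episode return over `episodes` rollouts of a fixed policy."""
    returns = []
    for _ in range(episodes):
        obs, _ = env.reset()
        done, ep_ret = False, 0.0
        while not done:
            obs, reward, terminated, truncated, _ = env.step(select_action(obs))
            ep_ret += reward
            done = terminated or truncated
        returns.append(ep_ret)
    return sum(returns) / len(returns)
```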
Code
custom_atari.py
```python
# Custom offline RL algorithm for MLS-Bench — Atari discrete control
#
# EDITABLE section: QNetwork class + OfflineAlgorithm class.
# FIXED sections: everything else (config, encoder, buffer, eval, training loop).
import argparse
import os
import random

import ale_py
import gymnasium
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
```
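The editable OfflineAlgorithm drives training on the fixed buffer. As one illustrative structure (not a prescription, and deliberately not a copy of the BC/CQL/BCQ references), a single update step can combine a TD loss with a behavior-cloning regularizer that keeps the learned policy close to the dataset actions, a discrete analogue of TD3+BC. The function name, batch layout, and `bc_weight` coefficient are assumptions for this sketch:

```python
import torch
import torch.nn.functional as F

def offline_update(q_net, target_net, optimizer, batch,
                   gamma: float = 0.99, bc_weight: float = 0.1) -> float:
    """One offline gradient step: TD loss on dataset transitions plus a
    behavior-cloning term (Q-values as logits, dataset actions as labels)."""
    obs, actions, rewards, next_obs, dones = batch
    with torch.no_grad():
        next_q = target_net(next_obs).max(dim=1).values
        target = rewards + gamma * (1.0 - dones) * next_q
    q = q_net(obs)
    q_taken = q.gather(1, actions.unsqueeze(1)).squeeze(1)
    td_loss = F.smooth_l1_loss(q_taken, target)
    bc_loss = F.cross_entropy(q, actions)  # push up dataset actions
    loss = td_loss + bc_weight * bc_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

The BC term directly addresses the "mixed" dataset setting: it bounds how far action selection can drift from actions the behavior policy actually took, which is where offline Q-learning overestimation typically originates.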
Results
| Model | Type | eval return breakout ↑ | eval return qbert ↑ | eval return pong ↑ |
|---|---|---|---|---|
| bc | baseline | 11.000 | 2541.667 | -21.000 |
| bcq | baseline | 12.100 | 1450.000 | 0.067 |
| cql | baseline | 8.600 | 1105.833 | -10.767 |
| dt | baseline | 54.000 | 2222.500 | -1.067 |
| dt | baseline | - | - | - |
| dt | baseline | - | - | - |
| dt | baseline | 36.433 | 4106.667 | 6.400 |
| dt | baseline | 51.100 | 3035.833 | -0.400 |
| deepseek-reasoner | vanilla | 0.000 | 0.000 | -21.000 |
| deepseek-reasoner | agent | - | - | - |
| deepseek-reasoner | agent | 8.900 | 5460.000 | 8.100 |
| deepseek-reasoner | agent | - | - | - |