rl-offline-discrete

Tags: Reinforcement Learning · d3rlpy · rigorous codebase

Description

Offline RL: Discrete Action Control on Atari

Objective

Design and implement an offline RL algorithm for discrete action spaces with pixel observations. Your code goes in the QNetwork and OfflineAlgorithm classes in custom_atari.py. Three reference implementations (BC, CQL, BCQ) are provided as read-only.

Background

The offline datasets are "mixed"-quality replay-buffer data collected from a partially trained DQN agent. The agent must learn entirely from this fixed dataset, with no environment interaction during training.
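To make the offline setting concrete, here is an illustrative sketch on a toy tabular MDP (not the benchmark's pixel task; the dataset and dimensions below are invented). All learning happens by sampling from a fixed list of transitions; no environment is ever stepped. Note that state-action pairs absent from the dataset never receive updates, which is the coverage problem that offline methods such as CQL and BCQ are designed to address.

```python
import random

# Illustrative offline TD-learning sketch on a hypothetical 3-state MDP.
# The "replay buffer" is a fixed list; env.step() is never called.
random.seed(0)
GAMMA, ALPHA = 0.99, 0.5
N_STATES, N_ACTIONS = 3, 2

# Fixed dataset of (state, action, reward, next_state, done) transitions.
dataset = [
    (0, 1, 0.0, 1, False),
    (1, 1, 0.0, 2, False),
    (2, 0, 1.0, 2, True),
]

Q = [[0.0] * N_ACTIONS for _ in range(N_STATES)]
for _ in range(500):
    s, a, r, s2, done = random.choice(dataset)  # sample from the fixed data only
    target = r if done else r + GAMMA * max(Q[s2])
    Q[s][a] += ALPHA * (target - Q[s][a])
```

After training, `Q[2][0]` converges to 1.0 and values propagate backward through the dataset, but `Q[0][0]` remains 0.0 forever because that state-action pair never appears in the data.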

Constraints

  • The NatureDQNEncoder (CNN feature extractor) is FIXED and must not be replaced or modified. Your QNetwork must use it via self.encoder = NatureDQNEncoder(...). The convolutional layers are verified at runtime.
  • Total model parameter count must not exceed 5,000,000. This is enforced at runtime; exceeding it will crash training.
  • Focus on algorithmic innovation (loss functions, training procedures, action selection) rather than scaling up network capacity.
  • Do NOT simply copy a reference implementation with minor changes.
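A quick way to stay inside the 5,000,000-parameter budget is to count parameters layer by layer before training. The sketch below assumes the standard Nature-DQN CNN shapes on 84x84 stacked-frame inputs (4 input channels, three convolutions, a 512-unit fully connected layer); the benchmark's actual NatureDQNEncoder may differ, so treat this as a back-of-envelope estimate, not the enforced check.

```python
# Hypothetical parameter-budget estimate for a Nature-DQN-style QNetwork.
def conv_params(c_in, c_out, k):
    # weights (c_out x c_in x k x k) plus biases
    return c_out * (c_in * k * k) + c_out

def linear_params(n_in, n_out):
    return n_out * n_in + n_out

n_actions = 4  # e.g. Breakout
total = (
    conv_params(4, 32, 8)             # conv1: 8x8, stride 4
    + conv_params(32, 64, 4)          # conv2: 4x4, stride 2
    + conv_params(64, 64, 3)          # conv3: 3x3, stride 1
    + linear_params(64 * 7 * 7, 512)  # fc after flatten (84x84 -> 7x7 maps)
    + linear_params(512, n_actions)   # Q-value head
)
assert total < 5_000_000  # comfortably inside the budget (~1.69M here)
```

This leaves most of the budget unused, which is consistent with the task's emphasis on algorithmic innovation over network capacity.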

Evaluation

Models are trained and evaluated on Breakout (4 actions), Pong (6 actions), and Qbert (6 actions) using d4rl-atari "mixed" datasets. Additional held-out environments (not shown during intermediate testing) are used to assess generalization. Metric: mean episode return over 10 evaluation episodes.
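The metric itself is straightforward to sketch. The function below is a hypothetical stand-in, not the benchmark's fixed evaluation code: `env` and `policy` here follow an invented minimal interface (a `_ToyEnv` that runs three steps of reward 1), purely to show how the mean over evaluation episodes is computed.

```python
# Hypothetical sketch of the evaluation metric: mean episode return
# over a fixed number of evaluation episodes.
class _ToyEnv:
    """Invented stand-in environment: 3 steps, reward 1.0 each."""
    def reset(self):
        self.t = 0
        return 0
    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3  # (obs, reward, done)

def mean_episode_return(env, policy, n_episodes=10):
    returns = []
    for _ in range(n_episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done = env.step(policy(obs))
            total += reward
        returns.append(total)
    return sum(returns) / len(returns)

avg = mean_episode_return(_ToyEnv(), policy=lambda obs: 0)
```

Ten episodes is a small sample, so returns on stochastic Atari environments can vary noticeably between evaluations.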

Code

custom_atari.py
# Custom offline RL algorithm for MLS-Bench — Atari discrete control
#
# EDITABLE section: QNetwork class + OfflineAlgorithm class.
# FIXED sections: everything else (config, encoder, buffer, eval, training loop).
import argparse
import os
import random

import ale_py
import gymnasium
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


Results

Eval return per environment ("-" marks runs without a recorded result):

Model               Type      Breakout   Qbert      Pong
bc                  baseline  11.000     2541.667   -21.000
bcq                 baseline  12.100     1450.000   0.067
cql                 baseline  8.600      1105.833   -10.767
dt                  baseline  54.000     2222.500   -1.067
dt                  baseline  -          -          -
dt                  baseline  -          -          -
dt                  baseline  36.433     4106.667   6.400
dt                  baseline  51.100     3035.833   -0.400
deepseek-reasoner   vanilla   0.000      0.000      -21.000
deepseek-reasoner   agent     -          -          -
deepseek-reasoner   agent     8.900      5460.000   8.100
deepseek-reasoner   agent     -          -          -

Agent Conversations