rl-offline-pomdp

Reinforcement Learning, katakomba, rigorous codebase

Description

Offline RL: Partially Observable MDP (POMDP) on NetHack

Objective

Design and implement an offline RL algorithm that handles partial observability using recurrent memory. Your code goes in the Model and OfflineAlgorithm classes in custom_nethack_pomdp.py. Three reference implementations (BC, CQL, IQL with chaotic LSTM) are provided as read-only.
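The editable Model wraps the fixed backbone and adds output heads on top of it. The sketch below assumes, hypothetically, that the backbone's forward takes an observation tensor plus an LSTM state and returns `(features, new_state)`, with features of size `rnn_hidden_dim`. `DummyBackbone`, its `hidden_dim` attribute, and the policy/value head layout are illustrative stand-ins only; the real ModelBackbone interface is not shown on this page.

```python
import torch
import torch.nn as nn


class DummyBackbone(nn.Module):
    """Hypothetical stand-in for the fixed ModelBackbone (encoders + LSTM).

    Assumed interface: forward(obs, state) -> (features, new_state), where
    features has shape (T, B, hidden_dim). The real class may differ.
    """

    def __init__(self, obs_dim: int = 32, hidden_dim: int = 2048, num_layers: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder = nn.Linear(obs_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers)

    def forward(self, obs, state=None):
        x = torch.relu(self.encoder(obs))  # (T, B, hidden_dim)
        return self.lstm(x, state)  # (features, (h_n, c_n))


class Model(nn.Module):
    def __init__(self, backbone: nn.Module, num_actions: int = 121, head_dim: int = 256):
        super().__init__()
        self.backbone = backbone  # required: store the fixed backbone unchanged
        hidden = backbone.hidden_dim  # assumption: backbone exposes its output size
        # Everything below counts toward the output-head parameter budget.
        self.policy_head = nn.Sequential(
            nn.Linear(hidden, head_dim), nn.ReLU(), nn.Linear(head_dim, num_actions)
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden, head_dim), nn.ReLU(), nn.Linear(head_dim, 1)
        )

    def forward(self, obs, state=None):
        features, new_state = self.backbone(obs, state)
        logits = self.policy_head(features)            # (T, B, num_actions)
        value = self.value_head(features).squeeze(-1)  # (T, B)
        return logits, value, new_state
```

With `hidden_dim=2048` and `head_dim=256`, the two heads together hold 1,080,442 parameters, which stays under the 2,000,000 limit; the loss that trains them is where an OfflineAlgorithm would differ from the references.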

Background

NetHack is a procedurally generated roguelike with 121 discrete actions and terminal-based observations (an 80x24 character grid). It is inherently a POMDP: from the current observation alone, the agent cannot see the full map, hidden enemies, or trap states, so it needs temporal memory.
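Because a single observation is incomplete, recurrent memory only helps if the hidden state is carried correctly through time. A minimal sketch, assuming training data arrives as consecutive `(T, B, D)` chunks together with a hypothetical boolean `done` tensor marking where a new episode begins: memory is zeroed at episode boundaries and detached between chunks (truncated backpropagation through time). The function name `unroll_with_resets` and the `done` convention are illustrative, not part of the provided codebase.

```python
import torch
import torch.nn as nn


def unroll_with_resets(lstm: nn.LSTM, x: torch.Tensor, done: torch.Tensor, state):
    """Step an LSTM through a (T, B, D) chunk, zeroing memory at episode starts.

    Assumed convention: done[t, b] is True when step t of sequence b begins a
    new episode, so the state is reset before that step is processed.
    """
    outputs = []
    h, c = state  # each (num_layers, B, hidden)
    for t in range(x.shape[0]):
        keep = (~done[t]).float().view(1, -1, 1)  # broadcast over layers and hidden
        h, c = h * keep, c * keep                 # forget memory from the previous episode
        out, (h, c) = lstm(x[t : t + 1], (h, c))
        outputs.append(out)
    # Detach so gradients do not flow into the previous chunk (truncated BPTT).
    return torch.cat(outputs, dim=0), (h.detach(), c.detach())
```

Resetting inside the loop, rather than only at chunk boundaries, matters when episodes end mid-chunk: otherwise memory from one dungeon leaks into the next.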

Constraints

  • The ModelBackbone (observation encoders + LSTM) is FIXED and instantiated by the training loop. Your Model.__init__ receives it as the backbone argument and must store it as self.backbone. Do not replace or modify the backbone.
  • The LSTM dimensions (rnn_hidden_dim=2048, rnn_layers=2) are fixed by the TrainConfig and cannot be overridden.
  • Output head parameters (everything beyond the backbone) must not exceed 2,000,000. This is enforced at runtime.
  • Focus on algorithmic innovation (loss functions, training procedures, output head design) rather than scaling up network capacity.
  • Do NOT simply copy a reference implementation with minor changes.
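To check the 2,000,000-parameter head budget locally before a run, one option is to count every parameter that does not belong to `self.backbone`. This is a sketch: the training loop's actual enforcement code is not shown, so `head_param_count`, the `_HeadOnly` wrapper, and the counting rule are assumptions.

```python
import torch.nn as nn

HEAD_PARAM_LIMIT = 2_000_000  # output-head budget stated in the constraints


def head_param_count(model: nn.Module) -> int:
    """Count parameters of `model` that are not part of `model.backbone`."""
    backbone_ids = {id(p) for p in model.backbone.parameters()}
    return sum(p.numel() for p in model.parameters() if id(p) not in backbone_ids)


class _HeadOnly(nn.Module):
    """Hypothetical wrapper: a tiny stand-in backbone plus a candidate head."""

    def __init__(self, head: nn.Module):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # stand-in; excluded from the count
        self.head = head


# Candidate heads on the 2048-d LSTM output, producing 121 action logits.
narrow = _HeadOnly(nn.Sequential(nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, 121)))
wide = _HeadOnly(nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 121)))

for name, m in [("narrow", narrow), ("wide", wide)]:
    n = head_param_count(m)
    print(f"{name}: {n:,} params, within budget: {n <= HEAD_PARAM_LIMIT}")
```

The narrow head uses 1,111,161 parameters; the wide one uses 4,444,281 and would be rejected. A single `Linear(2048, 2048)` on the LSTM output already costs 4,196,352 parameters, which is why the constraint pushes effort toward losses and training procedures rather than width.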

Evaluation

Models are trained and evaluated on the Monk, Valkyrie, and Ranger characters. Additional held-out NetHack characters (not shown during intermediate testing) are used to assess generalization. Metric: normalized score (your score divided by the AutoAscend bot's mean score).
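The metric can be written as a one-line ratio. The sketch assumes "your score" means the mean in-game score over evaluation episodes for one character; the function name and the reference value in the usage line are hypothetical, not the benchmark's actual constants.

```python
def normalized_score(episode_scores: list[float], autoascend_mean: float) -> float:
    """Mean in-game score over evaluation episodes divided by the AutoAscend mean.

    Assumption: the benchmark averages over episodes before dividing.
    """
    if not episode_scores:
        raise ValueError("need at least one evaluation episode")
    if autoascend_mean <= 0:
        raise ValueError("reference score must be positive")
    mean_score = sum(episode_scores) / len(episode_scores)
    return mean_score / autoascend_mean


# Hypothetical numbers: two episodes scoring 100 and 300 against a reference of 10,000.
print(normalized_score([100.0, 300.0], 10_000.0))
```

Read this way, a normalized score of 0.021 means the agent reaches about 2.1% of AutoAscend's mean score on that character.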

Code

custom_nethack_pomdp.py
```python
# Custom offline RL algorithm for MLS-Bench — NetHack POMDP (LSTM backbone)
#
# FIXED: ModelBackbone (encoders + LSTM), config, utilities, data, eval, training loop.
# EDITABLE: Model (output heads on top of backbone) + OfflineAlgorithm class.
import os
import sys
import random
import uuid
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pyrallis
```

Results

Normalized score per character (role-race-alignment):

Model  Type      mon-hum-neu  val-hum-neu  ran-hum-neu
bc     baseline  0.021        0.021        -
bc     baseline  -            0.016        0.081
cql    baseline  -            0.000        0.001
cql    baseline  0.000        0.000        -
iql    baseline  0.024        0.021        0.067
iql    baseline  -            0.012        0.076