rl-offline-policy
Tags: Reinforcement Learning, FQL, rigorous codebase
Description
Offline RL Policy Training
Objective
Design a better policy training approach for offline reinforcement learning that effectively leverages offline data while maximizing Q-values. Your approach is evaluated on OGBench environments by success rate.
Background
Offline RL trains policies from fixed datasets without environment interaction. The key challenge is balancing:
- Staying close to the data (behavioral cloning) to avoid out-of-distribution actions
- Maximizing Q-values to discover better actions than those in the dataset
Different approaches tackle this differently:
- IQL: Uses implicit Q-learning with advantage-weighted regression (AWR) for the actor
- TD3+BC / ReBRAC: Deterministic policy gradient (DDPG) regularized with a BC loss
- FQL: Trains a flow-matching BC policy, then distills it into a one-step policy guided by Q-values
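To make the BC/Q trade-off concrete, here is a minimal numpy sketch of a TD3+BC-style actor loss (the `lam` normalization mirrors the `normalize_q_loss` option below). This is an illustration of the idea only, not the codebase's implementation; the function name and signature are invented for this example:

```python
import numpy as np

def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """TD3+BC-style actor loss (sketch): maximize Q while staying near the data.

    q_values:        Q(s, pi(s)) for a batch, shape (B,)
    policy_actions:  pi(s), shape (B, A)
    dataset_actions: actions from the offline dataset, shape (B, A)
    alpha:           weight on the Q term relative to the BC term
    """
    # Normalize the Q term by the mean |Q| so alpha is scale-invariant.
    lam = alpha / (np.abs(q_values).mean() + 1e-8)
    q_term = -lam * q_values.mean()                              # maximize Q
    bc_term = ((policy_actions - dataset_actions) ** 2).mean()   # stay near data
    return q_term + bc_term
```

Larger `alpha` pushes the policy toward higher Q-values; smaller `alpha` keeps it closer to the dataset actions.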
What to implement
Implement the PolicyTrainer class with three core methods:
- create_policy_networks() -- Define what neural network(s) your policy uses
- compute_actor_loss() -- Define the training loss for the policy
- sample_actions() -- Define how to produce actions at test time
You may also override get_target_update_modules() and needs_next_actions() if needed.
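A skeleton of the class might look as follows. The method names come from the task description above, but the signatures, return types, and defaults are assumptions; match the actual base class in custom_train.py:

```python
class PolicyTrainer:
    """Sketch of the editable class (signatures are assumed, not authoritative)."""

    def create_policy_networks(self, config):
        # Return the policy module(s), e.g. a GaussianActor or a flow-matching pair.
        raise NotImplementedError

    def compute_actor_loss(self, batch, networks, rng):
        # Return the scalar training loss (and any logging info) for the policy.
        raise NotImplementedError

    def sample_actions(self, observations, networks, rng):
        # Return actions for evaluation rollouts.
        raise NotImplementedError

    def get_target_update_modules(self):
        # Optional override: policy modules that need Polyak-averaged targets.
        return []

    def needs_next_actions(self):
        # Optional override: whether the critic's TD target needs pi(s') samples
        # (True for DDPG/TD3-style bootstrapping, False for IQL-style targets).
        return False
```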
Fixed components
- Critic: Twin Q-networks (Value with num_ensembles=2) trained with standard TD loss
- Target network: Polyak-averaged target critic for stable Q-value targets
- Training loop: 1M offline gradient steps, periodic evaluation
- Dataset: OGBench offline datasets with observations, actions, rewards, masks
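The fixed critic machinery above can be sketched in numpy: a one-step TD target from twin target critics (with the `q_agg` and `discount` options from the config) and the Polyak/EMA target update. Function names here are illustrative, not the codebase's:

```python
import numpy as np

def td_target(rewards, masks, next_q1, next_q2, discount=0.99, q_agg="min"):
    """One-step TD target from twin target critics; masks zero out terminal states."""
    next_q = np.minimum(next_q1, next_q2) if q_agg == "min" else 0.5 * (next_q1 + next_q2)
    return rewards + discount * masks * next_q

def polyak_update(target, online, tau=0.005):
    """Target-network EMA: target <- (1 - tau) * target + tau * online."""
    return (1.0 - tau) * target + tau * online
```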
Available primitives
- GaussianActor: Gaussian policy network (supports const_std, state_dependent_std, tanh_squash)
- VectorFieldActor: Vector field network for flow matching (takes observations, actions, optional times)
- Value: Value/critic network with optional ensemble
- MLP: Multi-layer perceptron with optional layer norm
- distrax: JAX probability distributions library
- optax: JAX optimizer library
Evaluation
Success rate on OGBench goal-reaching tasks (antmaze navigation, cube manipulation). Higher is better. The metric ranges from 0.0 (never reaches goal) to 1.0 (always reaches goal).
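Concretely, the reported metric is just the fraction of evaluation episodes that reach the goal (a minimal sketch; the episode outcomes here are made up):

```python
import numpy as np

# One binary outcome per evaluation episode: 1.0 = goal reached, 0.0 = not.
successes = np.array([1.0, 0.0, 1.0, 1.0])
success_rate = successes.mean()  # 0.75
```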
Hyperparameters available in config
- alpha: BC/regularization coefficient (environment-specific, passed via script)
- discount: Discount factor (default 0.99)
- tau: Target network EMA rate (default 0.005)
- lr: Learning rate (default 3e-4)
- q_agg: Q aggregation for target ('min' or 'mean', passed via script)
- normalize_q_loss: Whether to normalize Q loss by |Q| mean
- actor_hidden_dims, value_hidden_dims: Network sizes (default (512, 512, 512, 512))
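Collected as a plain dict, the defaults above might look like this. The dict shape, the alpha value, and the normalize_q_loss value are assumptions; alpha and q_agg are actually supplied per environment by the script:

```python
config = dict(
    alpha=10.0,                            # assumed placeholder; environment-specific
    discount=0.99,
    tau=0.005,
    lr=3e-4,
    q_agg="min",                           # passed via script ('min' or 'mean')
    normalize_q_loss=True,                 # assumed default
    actor_hidden_dims=(512, 512, 512, 512),
    value_hidden_dims=(512, 512, 512, 512),
)
```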
Code
custom_train.py
```python
"""Offline RL policy training on OGBench.

This script trains an offline RL agent on OGBench environments.
The critic (Value network with TD loss) and the training loop are FIXED.
Only the PolicyTrainer class (policy network creation, actor loss, action sampling)
is editable -- this defines the policy training approach.
"""

import copy
import os
import time
from functools import partial
from typing import Any, Sequence

import flax
```
Results
No results available yet.