meta-rl

Tags: Classical ML, oyster, rigorous codebase

Description

Meta-RL: Context Encoder for PEARL Task Inference

Objective

Design a context encoder for the PEARL meta-reinforcement learning algorithm that maps transition tuples (state, action, reward) to latent task representations. The encoder should enable effective task inference from limited interaction data, allowing fast adaptation to unseen tasks.

Background

PEARL (Probabilistic Embeddings for Actor-critic RL) is a meta-RL algorithm that learns a probabilistic latent task variable z from context transitions. During meta-testing, the agent collects a few transitions from a new task, encodes them into a posterior distribution q(z|c), and conditions its policy on the sampled z.

The context encoder processes individual transition tuples and outputs Gaussian parameters (mean and log-variance). The PEARLAgent aggregates per-transition outputs via product of Gaussians to form the task posterior. Your goal is to design a better encoder architecture.
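The product-of-Gaussians aggregation can be sketched as follows. This is a minimal illustration of the math (the product of independent Gaussian factors is a Gaussian whose precision is the sum of the factor precisions), not the repo's exact code; the function name and the clamping constant are illustrative.

```python
import torch

def product_of_gaussians(mus, sigmas_squared):
    """Combine per-transition factors N(mu_i, sigma_i^2) into one Gaussian.

    mus, sigmas_squared: tensors of shape (num_transitions, latent_dim).
    Returns the mean and variance of the product distribution:
    precision of the product = sum of precisions, mean = precision-weighted
    average of the factor means.
    """
    sigmas_squared = torch.clamp(sigmas_squared, min=1e-7)  # numerical safety
    sigma_squared = 1.0 / torch.sum(1.0 / sigmas_squared, dim=0)
    mu = sigma_squared * torch.sum(mus / sigmas_squared, dim=0)
    return mu, sigma_squared
```

Because precisions add, each extra context transition can only sharpen the posterior q(z|c), which is what makes the aggregation order-invariant over the context set.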

You can modify the CustomContextEncoder class (lines 27-53) and add custom imports (lines 21-23) in custom_encoder.py.

Interface

Your CustomContextEncoder must:

  • Extend PyTorchModule and call self.save_init_params(locals()) in __init__
  • Accept hidden_sizes, input_size, output_size as constructor arguments
  • Set self.output_size attribute in __init__
  • Implement forward(self, input, return_preactivations=False) returning tensors of shape (*, output_size)
  • Implement reset(self, num_tasks=1) to reset any stateful components
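A minimal skeleton satisfying this interface is sketched below as a plain `torch.nn.Module` MLP; in the actual `custom_encoder.py` it would instead extend `PyTorchModule` and call `self.save_init_params(locals())` first, per the requirements above. The class name, the input size (27 = obs 20 + action 6 + reward 1 for cheetah-vel), and the `return_preactivations` handling are illustrative assumptions.

```python
import torch
import torch.nn as nn

class SketchContextEncoder(nn.Module):
    """Illustrative MLP context encoder matching the required interface."""

    def __init__(self, hidden_sizes, input_size, output_size):
        super().__init__()
        self.output_size = output_size  # required attribute
        layers, last = [], input_size
        for h in hidden_sizes:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        layers.append(nn.Linear(last, output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, input, return_preactivations=False):
        out = self.net(input)  # shape (*, output_size)
        if return_preactivations:
            # Final layer is linear, so preactivations equal the output here.
            return out, out
        return out

    def reset(self, num_tasks=1):
        pass  # stateless MLP: nothing to reset between tasks
```

A stateful design (e.g. a GRU over the context) would carry hidden state per task and clear it in `reset(num_tasks)`.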

Environments

The encoder is evaluated across three environments with different reward structures and task complexities:

  1. Half-Cheetah Velocity (cheetah-vel): 30 train / 10 test tasks. Target velocities in [0, 3] m/s. Obs dim 20, action dim 6. Dense reward based on velocity matching. Tests encoding quality on a continuous task distribution with high-dimensional observations.

  2. Sparse Point Robot (sparse-point-robot): 40 train / 10 test tasks. Goals on a half-circle, sparse reward (+1 within goal radius, 0 otherwise). Obs dim 2, action dim 2. Tests the encoder's ability to extract task information from sparse reward signals.

  3. Point Robot (point-robot): 40 train / 10 test tasks. Goals sampled uniformly from [-1, 1]^2. Dense reward (negative L2 distance to goal). Obs dim 2, action dim 2. Tests basic encoding quality on a simple but diverse continuous task distribution.

Evaluation

Performance is measured by meta_test_return on each environment: average return on held-out test tasks after 20 meta-training iterations.

Code

custom_encoder.py (editable)

```python
"""Custom context encoder for PEARL meta-RL.

Encodes transition tuples (s, a, r [, s']) into latent representations
for task inference. The PEARLAgent calls this encoder and aggregates
per-transition outputs via product of Gaussians to form the task posterior.

Interface requirements:
 - __init__(hidden_sizes, input_size, output_size, **kwargs)
 - forward(input) -> output of shape (*, output_size)
 - reset(num_tasks) -> None (reset stateful components)
 - Must set self.output_size attribute in __init__
 - Must extend PyTorchModule (call self.save_init_params(locals()) first)
"""

import torch
```
launch_custom.py (read-only)

```python
"""Custom PEARL experiment launcher for meta-rl task.

This script is FIXED (not editable). It imports CustomContextEncoder
from custom_encoder.py and runs meta-training on the specified environment.
"""
import os
import pathlib
import numpy as np
import click
import torch

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp
```

Additional context files (read-only):

  • oyster/rlkit/torch/networks.py
  • oyster/rlkit/torch/sac/agent.py
  • oyster/rlkit/torch/sac/sac.py
  • oyster/rlkit/core/rl_algorithm.py
  • oyster/configs/default.py

Results

| Model | Type | meta_test_return (cheetah-vel) | meta_test_return (sparse-point-robot) | meta_test_return (point-robot) |
|---|---|---|---|---|
| attention_encoder | baseline | -96.754 | 1.991 | -10.837 |
| mlp_encoder | baseline | -141.860 | 3.803 | -11.878 |
| recurrent_encoder | baseline | -170.882 | 2.818 | -11.383 |
| anthropic/claude-opus-4.6 | vanilla | -122.398 | 4.434 | -11.745 |
| anthropic/claude-opus-4.6 | vanilla | -98.944 | 5.622 | -12.672 |
| anthropic/claude-opus-4.6 | vanilla | -141.620 | 5.533 | -9.424 |
| deepseek-reasoner | vanilla | — | — | -18.483 |
| google/gemini-3.1-pro-preview | vanilla | -193.094 | 1.097 | -18.540 |
| google/gemini-3.1-pro-preview | vanilla | -155.001 | 0.000 | -14.726 |
| openai/gpt-5.4-pro | vanilla | -90.651 | 4.639 | -12.149 |
| openai/gpt-5.4-pro | vanilla | -78.369 | 6.309 | -11.602 |
| openai/gpt-5.4-pro | vanilla | -82.487 | 1.171 | -11.716 |
| qwen3.6-plus | vanilla | -99.133 | 5.624 | -15.996 |
| anthropic/claude-opus-4.6 | agent | -160.103 | 6.687 | -13.449 |
| anthropic/claude-opus-4.6 | agent | -87.014 | 2.454 | -11.716 |
| anthropic/claude-opus-4.6 | agent | -85.275 | 2.582 | -13.124 |
| deepseek-reasoner | agent | — | — | -15.067 |
| deepseek-reasoner | agent | — | — | -18.483 |
| deepseek-reasoner | agent | — | — | -15.883 |
| google/gemini-3.1-pro-preview | agent | -92.873 | 3.976 | -10.464 |
| google/gemini-3.1-pro-preview | agent | -93.750 | 3.249 | -11.493 |
| google/gemini-3.1-pro-preview | agent | -128.155 | 1.360 | -12.799 |
| openai/gpt-5.4-pro | agent | -89.689 | 6.141 | -12.427 |
| openai/gpt-5.4-pro | agent | -150.465 | 5.494 | -13.848 |
| openai/gpt-5.4-pro | agent | -83.112 | 1.176 | -13.401 |
| qwen3.6-plus | agent | -99.690 | 4.637 | -14.073 |
| qwen3.6-plus | agent | -110.444 | 5.624 | -15.996 |
| qwen3.6-plus | agent | -89.782 | 2.559 | -14.200 |