# meta-rl

## Description

Meta-RL: Context Encoder for PEARL Task Inference

## Objective
Design a context encoder for the PEARL meta-reinforcement learning algorithm that maps transition tuples (state, action, reward) to latent task representations. The encoder should enable effective task inference from limited interaction data, allowing fast adaptation to unseen tasks.
## Background
PEARL (Probabilistic Embeddings for Actor-critic RL) is a meta-RL algorithm that learns a probabilistic latent task variable `z` from context transitions. During meta-testing, the agent collects a few transitions from a new task, encodes them into a posterior distribution `q(z|c)`, and conditions its policy on the sampled `z`.

The context encoder processes individual transition tuples and outputs Gaussian parameters (mean and log-variance) per transition. The `PEARLAgent` aggregates these per-transition factors via a product of Gaussians to form the task posterior. Your goal is to design a better encoder architecture.
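The product-of-Gaussians aggregation has a closed form: precisions add, and the posterior mean is the precision-weighted average of the factor means. A minimal NumPy sketch (the helper name `product_of_gaussians` is chosen here for illustration; it mirrors the utility PEARL uses internally but is written from scratch):

```python
import numpy as np

def product_of_gaussians(mus, sigmas_squared):
    """Combine independent Gaussian factors N(mu_i, sigma_i^2) into one Gaussian.

    mus, sigmas_squared: arrays of shape (num_transitions, latent_dim).
    Precisions (1/sigma^2) add; the mean is the precision-weighted average.
    """
    sigmas_squared = np.clip(sigmas_squared, 1e-7, None)  # numerical safety
    sigma_squared = 1.0 / np.sum(1.0 / sigmas_squared, axis=0)
    mu = sigma_squared * np.sum(mus / sigmas_squared, axis=0)
    return mu, sigma_squared

# Two per-transition factors over a 1-D latent: equal variances, so the
# posterior mean is the average and the variance halves.
mu, var = product_of_gaussians(np.array([[0.0], [2.0]]),
                               np.array([[1.0], [1.0]]))
# mu -> [1.0], var -> [0.5]
```

Each additional context transition tightens the posterior, which is what lets PEARL sharpen its task belief as evidence accumulates.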
You can modify the `CustomContextEncoder` class (lines 27-53) and add custom imports (lines 21-23) in `custom_encoder.py`.
## Interface

Your `CustomContextEncoder` must:

- Extend `PyTorchModule` and call `self.save_init_params(locals())` in `__init__`
- Accept `hidden_sizes`, `input_size`, `output_size` as constructor arguments
- Set the `self.output_size` attribute in `__init__`
- Implement `forward(self, input, return_preactivations=False)` returning tensors of shape `(*, output_size)`
- Implement `reset(self, num_tasks=1)` to reset any stateful components
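A minimal encoder satisfying this interface could look like the sketch below. It subclasses `torch.nn.Module` so the snippet is self-contained; in the actual codebase you would extend rlkit's `PyTorchModule` and call `self.save_init_params(locals())` first in `__init__`. The dimensions in the usage line (20 obs + 6 action + 1 reward = 27 inputs, 10 outputs for a 5-dim latent's mean and log-variance) are an assumption matching the cheetah-vel environment:

```python
import torch
import torch.nn as nn

class CustomContextEncoder(nn.Module):
    """Illustrative MLP encoder for per-transition Gaussian parameters."""

    def __init__(self, hidden_sizes, input_size, output_size):
        super().__init__()
        self.output_size = output_size  # required attribute
        layers, in_dim = [], input_size
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        self.trunk = nn.Sequential(*layers)
        self.last = nn.Linear(in_dim, output_size)

    def forward(self, input, return_preactivations=False):
        h = self.trunk(input)
        preactivation = self.last(h)
        output = preactivation  # identity output activation
        if return_preactivations:
            return output, preactivation
        return output

    def reset(self, num_tasks=1):
        pass  # stateless MLP; a recurrent variant would reset hidden state here

enc = CustomContextEncoder([64, 64], input_size=27, output_size=10)
z = enc(torch.zeros(3, 16, 27))  # (tasks, transitions, s+a+r) -> (3, 16, 10)
```

The leading batch dimensions pass through untouched, which is why the interface only pins down the last dimension as `output_size`.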
## Environments

The encoder is evaluated across three environments with different reward structures and task complexities:

- **Half-Cheetah Velocity** (`cheetah-vel`): 30 train / 10 test tasks. Target velocities in [0, 3] m/s. Obs dim 20, action dim 6. Dense reward based on velocity matching. Tests encoding quality on a continuous task distribution with high-dimensional observations.
- **Sparse Point Robot** (`sparse-point-robot`): 40 train / 10 test tasks. Goals on a half-circle; sparse reward (+1 within the goal radius, 0 otherwise). Obs dim 2, action dim 2. Tests the encoder's ability to extract task information from sparse reward signals.
- **Point Robot** (`point-robot`): 40 train / 10 test tasks. Goals sampled uniformly from [-1, 1]^2. Dense reward (negative L2 distance to goal). Obs dim 2, action dim 2. Tests basic encoding quality on a simple but diverse continuous task distribution.
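The two point-robot reward structures can be sketched as follows. This is illustrative only: the actual environment code lives in rlkit's env registry, and details such as the goal radius (assumed 0.2 here) may differ:

```python
import numpy as np

GOAL_RADIUS = 0.2  # assumed value, for illustration

def sparse_reward(position, goal, goal_radius=GOAL_RADIUS):
    # +1 inside the goal radius, 0 otherwise: the agent gets no gradient
    # toward the goal, so task identity must be inferred from the rare
    # rewarding transitions in the context.
    return 1.0 if np.linalg.norm(position - goal) <= goal_radius else 0.0

def dense_reward(position, goal):
    # point-robot's dense variant: negative L2 distance to the goal,
    # informative about the task at every single transition.
    return -float(np.linalg.norm(position - goal))
```

The contrast explains why encoders that score well on `point-robot` can still fail on `sparse-point-robot`: in the sparse case most context transitions carry zero task information.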
## Evaluation

Performance is measured by `meta_test_return` on each environment: the average return on held-out test tasks after 20 meta-training iterations.
## Code

`custom_encoder.py` (editable):

```python
"""Custom context encoder for PEARL meta-RL.

Encodes transition tuples (s, a, r [, s']) into latent representations
for task inference. The PEARLAgent calls this encoder and aggregates
per-transition outputs via product of Gaussians to form the task posterior.

Interface requirements:
- __init__(hidden_sizes, input_size, output_size, **kwargs)
- forward(input) -> output of shape (*, output_size)
- reset(num_tasks) -> None (reset stateful components)
- Must set self.output_size attribute in __init__
- Must extend PyTorchModule (call self.save_init_params(locals()) first)
"""

import torch
```

Experiment launcher (fixed, not editable):

```python
"""Custom PEARL experiment launcher for meta-rl task.

This script is FIXED (not editable). It imports CustomContextEncoder
from custom_encoder.py and runs meta-training on the specified environment.
"""
import os
import pathlib
import numpy as np
import click
import torch

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.torch.networks import FlattenMlp
```
Additional context files (read-only):

- `oyster/rlkit/torch/networks.py`
- `oyster/rlkit/torch/sac/agent.py`
- `oyster/rlkit/torch/sac/sac.py`
- `oyster/rlkit/core/rl_algorithm.py`
- `oyster/configs/default.py`
## Results

| Model | Type | meta_test_return (cheetah-vel) ↑ | meta_test_return (sparse-point-robot) ↑ | meta_test_return (point-robot) ↑ |
|---|---|---|---|---|
| attention_encoder | baseline | -96.754 | 1.991 | -10.837 |
| mlp_encoder | baseline | -141.860 | 3.803 | -11.878 |
| recurrent_encoder | baseline | -170.882 | 2.818 | -11.383 |
| anthropic/claude-opus-4.6 | vanilla | -122.398 | 4.434 | -11.745 |
| anthropic/claude-opus-4.6 | vanilla | -98.944 | 5.622 | -12.672 |
| anthropic/claude-opus-4.6 | vanilla | -141.620 | 5.533 | -9.424 |
| deepseek-reasoner | vanilla | - | - | -18.483 |
| google/gemini-3.1-pro-preview | vanilla | -193.094 | 1.097 | -18.540 |
| google/gemini-3.1-pro-preview | vanilla | -155.001 | 0.000 | -14.726 |
| openai/gpt-5.4-pro | vanilla | -90.651 | 4.639 | -12.149 |
| openai/gpt-5.4-pro | vanilla | -78.369 | 6.309 | -11.602 |
| openai/gpt-5.4-pro | vanilla | -82.487 | 1.171 | -11.716 |
| qwen3.6-plus | vanilla | -99.133 | 5.624 | -15.996 |
| anthropic/claude-opus-4.6 | agent | -160.103 | 6.687 | -13.449 |
| anthropic/claude-opus-4.6 | agent | -87.014 | 2.454 | -11.716 |
| anthropic/claude-opus-4.6 | agent | -85.275 | 2.582 | -13.124 |
| deepseek-reasoner | agent | - | - | -15.067 |
| deepseek-reasoner | agent | - | - | -18.483 |
| deepseek-reasoner | agent | - | - | -15.883 |
| google/gemini-3.1-pro-preview | agent | -92.873 | 3.976 | -10.464 |
| google/gemini-3.1-pro-preview | agent | -93.750 | 3.249 | -11.493 |
| google/gemini-3.1-pro-preview | agent | -128.155 | 1.360 | -12.799 |
| openai/gpt-5.4-pro | agent | -89.689 | 6.141 | -12.427 |
| openai/gpt-5.4-pro | agent | -150.465 | 5.494 | -13.848 |
| openai/gpt-5.4-pro | agent | -83.112 | 1.176 | -13.401 |
| qwen3.6-plus | agent | -99.690 | 4.637 | -14.073 |
| qwen3.6-plus | agent | -110.444 | 5.624 | -15.996 |
| qwen3.6-plus | agent | -89.782 | 2.559 | -14.200 |