meta-rl-algorithm

Tags: Classical ML · oyster · rigorous codebase

Description

Meta-RL Algorithm Design

Objective

Design a complete meta-reinforcement learning algorithm for fast adaptation to new tasks from limited interaction data. You must implement both the agent (how to encode context and condition the policy) and the training algorithm (how to meta-train the agent across tasks).

Background

Meta-RL algorithms learn to learn: they train across a distribution of tasks so that at test time, the agent can quickly adapt to a new, unseen task from just a few interactions. The key challenge is designing:

  1. Task inference: How to encode past experience (context) into a compact task representation
  2. Policy conditioning: How to condition the policy on this task representation
  3. Meta-training: How to optimize the agent across tasks so it generalizes to new ones

Different approaches exist: PEARL uses a probabilistic encoder with product-of-Gaussians aggregation; FOCAL uses contrastive learning for task embeddings; VariBAD uses a recurrent encoder with reward prediction.
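As a concrete illustration of the PEARL-style aggregation mentioned above, here is a minimal NumPy sketch of product-of-Gaussians pooling: each context transition contributes an independent Gaussian factor over the task variable, and the factors are combined by adding precisions. The function name and array shapes are ours for illustration, not part of the template.

```python
import numpy as np

def product_of_gaussians(mus, sigmas_sq):
    """Combine per-transition factors N(mu_i, sigma_i^2) into one posterior.

    Precisions (1/sigma^2) add across independent factors; the pooled mean
    is the precision-weighted average of the factor means. Inputs have shape
    (num_transitions, latent_dim); outputs have shape (latent_dim,).
    """
    sigmas_sq = np.clip(sigmas_sq, 1e-7, None)          # numerical floor
    pooled_var = 1.0 / np.sum(1.0 / sigmas_sq, axis=0)  # add precisions
    pooled_mu = pooled_var * np.sum(mus / sigmas_sq, axis=0)
    return pooled_mu, pooled_var
```

Note the pooled variance shrinks as more context arrives, so the task belief sharpens with experience, and the result is invariant to the order of transitions.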

Your Task

Modify the CustomMetaRLAgent and CustomMetaRLAlgorithm classes in custom_meta_rl.py. The template provides fixed infrastructure (environment setup, evaluation, replay buffers, network building blocks) — you design the algorithm.

Agent Interface (CustomMetaRLAgent)

Your agent must implement:

  • get_action(obs, deterministic=False) -> (action_np, agent_info) — sample action conditioned on task belief
  • update_context(transition_tuple) -> None — accumulate online experience (called during rollout)
  • adapt() -> None — perform task inference from collected context (called after exploration)
  • clear_context(num_tasks=1) -> None — reset context and task belief
  • infer_posterior(context_tensor) -> None — encode context from replay buffer (for training)
  • context property — return collected context
  • z attribute — latent task variable tensor
  • networks property — list of nn.Module for GPU transfer
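The interface above can be sketched as a minimal class. The method names and signatures follow the task spec; everything internal (the random linear "policy", the fixed projection used for task inference, the deterministic mean-pooling) is a placeholder to show the call pattern, not the intended design, and it deliberately avoids torch so it runs standalone.

```python
import numpy as np

class MinimalMetaRLAgent:
    """Interface sketch for CustomMetaRLAgent; internals are placeholders."""

    def __init__(self, obs_dim, action_dim, latent_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.latent_dim = latent_dim
        # Placeholder policy: linear map from [obs, z] to a tanh-squashed mean.
        self.W = rng.normal(scale=0.1, size=(action_dim, obs_dim + latent_dim))
        # Placeholder encoder: fixed projection of [obs, action, reward].
        self.E = rng.normal(scale=0.1, size=(latent_dim, obs_dim + action_dim + 1))
        self._rng = rng
        self.clear_context()

    def clear_context(self, num_tasks=1):
        self._context = []                  # collected transitions
        self.z = np.zeros(self.latent_dim)  # reset task belief

    def update_context(self, transition_tuple):
        # transition_tuple = (obs, action, reward, next_obs)
        self._context.append(transition_tuple)

    def adapt(self):
        # Deterministic, permutation-invariant task inference:
        # project each transition, then mean-pool over the context.
        if not self._context:
            return
        feats = np.stack([np.concatenate([o, a, [r]])
                          for o, a, r, _ in self._context])
        self.z = (feats @ self.E.T).mean(axis=0)

    def get_action(self, obs, deterministic=False):
        mean = np.tanh(self.W @ np.concatenate([obs, self.z]))
        if deterministic:
            return mean, {}
        noisy = mean + 0.1 * self._rng.normal(size=mean.shape)
        return np.clip(noisy, -1.0, 1.0), {}

    @property
    def context(self):
        return self._context

    @property
    def networks(self):
        return []  # a real agent returns its nn.Modules here for GPU transfer
```

A real submission would replace the placeholder arrays with the template's `build_policy` / `build_mlp` networks and implement `infer_posterior` over batched context tensors.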

Algorithm Interface (CustomMetaRLAlgorithm)

Your algorithm must implement:

  • collect_initial_data() — gather initial exploration data for all training tasks
  • train_iteration(iteration_idx) -> dict — one meta-training iteration (data collection + gradient updates)
  • networks property — all networks for GPU transfer
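The control flow of a typical `train_iteration` looks like the sketch below. The two stub functions stand in for the template's fixed data-collection and update utilities (their bodies here are invented placeholders); only the collect-then-update structure and the returned stats dict are the point.

```python
import random

def collect_data(task_idx):
    """Stand-in for the template's collect_data(...) utility."""
    return 10  # pretend 10 transitions were added to the buffers

def gradient_step(task_indices):
    """Stand-in for one SAC + encoder update on a meta-batch of tasks."""
    return {"qf_loss": 0.5, "policy_loss": -1.0, "encoder_loss": 0.01}

def train_iteration(iteration_idx, num_train_tasks=5,
                    tasks_per_iter=2, updates_per_iter=3):
    """One meta-training iteration: per-task rollouts, then gradient
    updates on task batches sampled from the replay buffers."""
    rng = random.Random(iteration_idx)
    # 1) Data collection on a subset of training tasks.
    for task_idx in rng.sample(range(num_train_tasks), tasks_per_iter):
        collect_data(task_idx)
    # 2) Gradient updates; each update draws a fresh meta-batch of tasks.
    stats = {}
    for _ in range(updates_per_iter):
        indices = rng.sample(range(num_train_tasks), tasks_per_iter)
        stats = gradient_step(indices)
    stats["iteration"] = iteration_idx
    return stats  # logged by the fixed outer loop
```

The keys of the returned dict are illustrative; the template's outer loop simply logs whatever `train_iteration` returns.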

Available Utilities

The template provides these fixed utilities you can use:

  • build_mlp(input_dim, output_dim, hidden_dim, n_layers) — simple MLP
  • build_policy(obs_dim, action_dim, latent_dim, net_size) — TanhGaussianPolicy
  • build_qf(obs_dim, action_dim, latent_dim, net_size) — Q-function
  • build_vf(obs_dim, latent_dim, net_size) — V-function
  • create_replay_buffers(env, tasks) — replay buffer pair
  • sample_context_from_buffer(enc_replay_buffer, indices, batch_size, ...) — sample context
  • sample_sac_batch(replay_buffer, indices, batch_size) — sample RL batch
  • collect_data(agent, env, sampler, replay_buffer, enc_replay_buffer, ...) — collect trajectories
  • InPlacePathSampler from rlkit — trajectory sampler

Environments

Three MuJoCo environments with different challenges:

  1. Half-Cheetah Velocity (cheetah-vel): 30 train / 10 test tasks. Target velocities in [0, 3] m/s. Obs dim 20, action dim 6. Dense reward (velocity matching). High-dimensional observations require strong encoding.

  2. Sparse Point Robot (sparse-point-robot): 40 train / 10 test tasks. Goals on a half-circle, sparse reward (+1 near goal, 0 otherwise). Obs dim 2, action dim 2. Sparse reward makes task inference especially challenging.

  3. Point Robot (point-robot): 40 train / 10 test tasks. Goals in [-1, 1]^2. Dense reward (neg L2 distance). Obs dim 2, action dim 2. Tests basic meta-learning quality.

Evaluation

Performance is measured by meta_test_return on each environment: average return on held-out test tasks after meta-training. The evaluation protocol collects exploration trajectories, calls agent.adapt(), then evaluates with a deterministic policy.

Key Design Dimensions

  • Context encoding: Permutation-invariant (MLP + aggregation) vs. sequential (RNN/GRU) vs. attention
  • Task variable: Probabilistic (information bottleneck) vs. deterministic
  • Encoder loss: KL divergence, contrastive, reward prediction, or reconstruction
  • RL algorithm: SAC variants, policy gradient, or other
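If the task variable is probabilistic, the information-bottleneck option above is usually a KL penalty pulling the inferred posterior toward a unit-Gaussian prior. A minimal NumPy version of that term for a diagonal Gaussian (closed form; the function name is ours):

```python
import numpy as np

def kl_to_standard_normal(mu, sigma_sq):
    """KL( N(mu, diag(sigma_sq)) || N(0, I) ) in closed form.

    Added to the encoder loss, this regularizes the task posterior so it
    only encodes information the context actually supports.
    """
    return 0.5 * np.sum(sigma_sq + mu ** 2 - 1.0 - np.log(sigma_sq))
```

The term is zero exactly when the posterior equals the prior and grows as the posterior sharpens or drifts, which is what makes it a capacity penalty on the task embedding.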

Code

custom_meta_rl.py
```python
"""Custom meta-RL algorithm template for meta-rl-algorithm task.

FIXED infrastructure (not editable): environment setup, network building blocks,
replay buffers, sampler, evaluation protocol, and outer training loop.
EDITABLE region: CustomMetaRLAgent and CustomMetaRLAlgorithm classes.
"""
import os
import sys
import copy
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
```

Additional context files (read-only):

  • oyster/rlkit/torch/networks.py
  • oyster/rlkit/torch/sac/policies.py
  • oyster/configs/default.py

Results

| Model | Type | meta_test_return (point-robot) | meta_test_return (cheetah-vel) | meta_test_return (sparse-point-robot) |
|---|---|---|---|---|
| focal | baseline | -12.862 | -91.923 | 0.233 |
| pearl | baseline | -15.468 | -64.634 | 5.491 |
| varibad | baseline | -12.494 | -69.431 | 0.000 |
| anthropic/claude-opus-4.6 | vanilla | -11.160 | -74.851 | 0.256 |
| anthropic/claude-opus-4.6 | vanilla | -14.165 | -64.144 | 4.850 |
| anthropic/claude-opus-4.6 | vanilla | -15.411 | -67.187 | 4.876 |
| deepseek-reasoner | vanilla | -22.428 | -277.317 | 0.000 |
| openai/gpt-5.4-pro | vanilla | -15.459 | -52.637 | 0.000 |
| openai/gpt-5.4-pro | vanilla | -11.774 | -74.712 | 0.000 |
| openai/gpt-5.4-pro | vanilla | -9.803 | -87.606 | 0.000 |
| anthropic/claude-opus-4.6 | agent | -14.056 | -56.088 | 1.098 |
| anthropic/claude-opus-4.6 | agent | -15.326 | -57.070 | 0.000 |
| anthropic/claude-opus-4.6 | agent | -13.009 | -55.973 | 4.140 |
| deepseek-reasoner | agent | -22.430 | -276.766 | 0.000 |
| deepseek-reasoner | agent | -22.428 | -277.418 | 0.000 |
| deepseek-reasoner | agent | -22.429 | -277.236 | 0.000 |
| openai/gpt-5.4-pro | agent | -11.934 | -80.094 | 1.045 |