jepa-planning
Description
JEPA World Model Planning: Algorithm Design
Objective
Design a novel planning algorithm that leverages a learned JEPA (Joint Embedding Predictive Architecture) world model for goal-conditioned navigation. The task evaluates planning performance on the Two Rooms environment, where an agent must navigate around walls and through doors to reach a target location.
Research Question
Can you design a planning algorithm that outperforms standard derivative-free methods (CEM, MPPI) by better exploiting the structure of a learned JEPA world model?
What You Can Modify
You must implement the CustomPlanner class in custom_planner.py within the editable region on lines 303-349. The class extends the Planner abstract base class and must implement the plan() method. The JEPA world model checkpoint is fixed and provided by the evaluation environment; the task is to improve planning, not to retrain the model.
Interface
CustomPlanner Constructor
```python
def __init__(self, unroll, action_dim=2, plan_length=15,
             num_samples=200, n_iters=20, **kwargs):
```
- unroll: Function to forward-simulate through the world model
- action_dim: Dimensionality of the action space (2 for x/y movement)
- plan_length: Maximum planning horizon
- num_samples: Number of action samples (adjustable)
- n_iters: Number of optimization iterations (adjustable)
plan() Method
```python
def plan(self, obs_init, steps_left=None, eval_mode=True,
         t0=False, plan_vis_path=None) -> PlanningResult:
```
- obs_init: Initial observation encoding [1, C, 1, H, W]
- steps_left: Remaining steps in the episode
- Returns: PlanningResult(actions=Tensor[T, A], ...)
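The shape contract above can be illustrated with a trivial do-nothing plan. This is a sketch, not part of the task code: `zero_plan` is a hypothetical helper, and the `[T, A]` layout follows the `PlanningResult` description above.

```python
import torch

def zero_plan(steps_left, plan_length=15, action_dim=2):
    """Build a do-nothing action sequence in the [T, A] layout that
    PlanningResult(actions=...) expects. When steps_left is given,
    the horizon is capped by the remaining episode steps."""
    horizon = plan_length if steps_left is None else min(plan_length, steps_left)
    return torch.zeros(horizon, action_dim)
```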
Available Methods (Inherited)
self.unroll(obs_init, actions): Forward-simulate actions through the world model.
- obs_init: [1, C, 1, H, W] initial observation encoding
- actions: [B, A, T] batch of action sequences
- Returns: [B, D, T+1, H, W] predicted state encodings

self.objective(encodings): Compute cost for predicted state encodings.
- encodings: [B, D, T, H, W]
- Returns: [B] cost per sample (lower is better)

self.cost_function(actions, obs_init): Convenience method that calls unroll then objective.
- Returns: [B] cost per sample
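As a concrete baseline for how these pieces fit together, here is a minimal cross-entropy-method (CEM) loop over action sequences. This is a sketch under stated assumptions: `cem_plan` and `elite_frac` are illustrative names, and `cost_function(actions, obs_init) -> [B]` mirrors the inherited `self.cost_function` described above (lower is better). Inside `CustomPlanner.plan()`, the returned `[T, A]` tensor would be wrapped in `PlanningResult(actions=...)`.

```python
import torch

def cem_plan(cost_function, obs_init, plan_length=15, action_dim=2,
             num_samples=200, n_iters=20, elite_frac=0.1):
    """Cross-entropy method over action sequences shaped [B, A, T],
    matching the batch layout that self.unroll expects."""
    mean = torch.zeros(action_dim, plan_length)
    std = torch.ones(action_dim, plan_length)
    n_elite = max(1, int(elite_frac * num_samples))
    for _ in range(n_iters):
        # Sample candidate action sequences around the current distribution.
        actions = mean + std * torch.randn(num_samples, action_dim, plan_length)
        costs = cost_function(actions, obs_init)  # [B] costs, lower is better
        elite = actions[costs.argsort()[:n_elite]]
        # Refit the sampling distribution to the elite set.
        mean, std = elite.mean(0), elite.std(0) + 1e-6
    return mean.T  # [T, A], as expected by PlanningResult(actions=...)
```

The research question asks for something that beats this kind of derivative-free search, so a baseline like the above is a natural starting point to measure against.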
Evaluation
- Environment: Two Rooms (65x65 grid with wall and door)
- Episodes: 20 with random start and goal positions (fixed seed for reproducibility)
- Max steps per episode: 200
- Success threshold: Euclidean distance < 4.5 from goal
- Benchmarks: Three planning horizons (30, 60, 90 steps) test the algorithm across short, medium, and long-range planning
- Metric: success_rate (fraction of successful episodes) per horizon
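The evaluation criteria above amount to a simple check per episode. A minimal sketch, assuming `(x, y)` grid coordinates for agent and goal (the helper names are illustrative, not the evaluation harness's API):

```python
import math

def episode_success(final_pos, goal_pos, threshold=4.5):
    """An episode succeeds when the agent's final position lies within
    the Euclidean success threshold of the goal."""
    return math.dist(final_pos, goal_pos) < threshold

def success_rate(final_positions, goal_positions):
    """Fraction of successful episodes, computed per planning horizon."""
    hits = sum(episode_success(f, g)
               for f, g in zip(final_positions, goal_positions))
    return hits / len(final_positions)
```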
Code
```python
"""
Self-contained script for JEPA planning evaluation.

Part 1: Load a released AC Video JEPA checkpoint (or retrain on demand).
Part 2: Define CustomPlanner (EDITABLE REGION).
Part 3: Run planning evaluation and report metrics.
"""

import sys
from pathlib import Path as _Path

# Prevent eb_jepa/logging.py from shadowing stdlib logging
_script_dir = str(_Path(__file__).resolve().parent)
sys.path = [p for p in sys.path if p != _script_dir]
```
Additional context files (read-only):
- eb_jepa/planning.py
- eb_jepa/jepa.py
Results
No results yet.