meta-inner-loop-optimizer

Tags: Classical ML · learn2learn · rigorous codebase

Description

Meta-Learning: Inner-Loop Optimization Algorithm Design

Research Question

Design and implement a novel inner-loop adaptation algorithm for gradient-based meta-learning. Your code goes in the InnerLoopOptimizer class in custom_maml.py. Reference implementations of MAML, Meta-SGD, and GBML from the learn2learn library are provided as read-only context.

Background

Gradient-based meta-learning (MAML-style) learns a model initialization that can be quickly adapted to new tasks via a few gradient steps. The inner loop determines how parameters are updated during task-specific adaptation, while the outer loop optimizes the initialization (and any optimizer state) across tasks.
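To make the bilevel structure concrete, here is a hand-worked sketch (toy 1-D quadratic losses, values chosen for illustration; no autograd needed) of one inner SGD step and the meta-gradient that differentiates through it:

```python
# Meta-gradient through one inner SGD step on a 1-D quadratic loss.
# All values are toy illustrations, not from the actual benchmark.
alpha = 0.1                      # inner-loop learning rate
theta = 2.0                      # meta-learned initialization
t_support, t_query = 1.0, 1.5    # toy "support" / "query" targets

# Inner loop: one SGD step on the support loss L_s(theta) = 0.5*(theta - t_s)^2
grad_support = theta - t_support
theta_adapted = theta - alpha * grad_support          # theta' = theta - alpha*(theta - t_s)

# Outer loop: query loss at the adapted parameters
query_loss = 0.5 * (theta_adapted - t_query) ** 2

# The meta-gradient differentiates THROUGH the inner step:
# d(theta')/d(theta) = 1 - alpha, so dL_q/d(theta) = (1 - alpha)*(theta' - t_q)
meta_grad = (1 - alpha) * (theta_adapted - t_query)
```

With these numbers, `theta_adapted` is 1.9 and `meta_grad` is 0.36; note the `(1 - alpha)` factor is exactly the term that would vanish in a first-order approximation.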

The simplest inner loop is vanilla SGD (MAML), but many improvements exist:

  • Per-parameter learning rates: Meta-SGD learns a separate learning rate for each parameter
  • Selective adaptation: ANIL adapts only the classification head, freezing the backbone
  • Preconditioning: Meta-Curvature applies learned curvature matrices to gradients
  • Learned update rules: Using neural networks or structured transforms to generate updates
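As a toy sketch of the per-parameter-learning-rate idea (Meta-SGD), the inner update becomes an elementwise product, `theta' = theta - alpha ⊙ grad`, where `alpha` has the same shape as the parameters and is itself trained by the outer loop (values below are illustrative only):

```python
# Meta-SGD-style update: one learned learning rate per parameter, applied
# elementwise. Toy flattened parameters; alpha is meta-learned in practice.
theta = [0.5, -1.0, 2.0]   # model parameters
alpha = [0.1, 0.01, 0.5]   # learned per-parameter learning rates
grads = [1.0, 2.0, -0.5]   # support-set gradients

theta_adapted = [t - a * g for t, a, g in zip(theta, alpha, grads)]
# -> [0.4, -1.02, 2.25]
```

Because each coordinate gets its own step size (including sign), Meta-SGD can learn to adapt some parameters aggressively and effectively freeze others.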

Key design choices include:

  • What to adapt: All parameters, head only, or a learned subset
  • How to scale gradients: Fixed LR, per-parameter LR, preconditioned, or transformed
  • Momentum and memory: Whether to incorporate information across inner-loop steps
  • Regularization: Constraining adaptation to prevent overfitting on few-shot support sets
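The regularization choice can be sketched as a proximal inner step that penalizes drift from the meta-initialization, `theta' = theta - alpha * (grad + lambda * (theta - theta0))` (a common formulation, e.g. iMAML-style; the scalar values here are illustrative):

```python
# Proximally regularized inner step: the penalty pulls the iterate back
# toward the meta-initialization theta0, limiting overfitting on tiny
# support sets. Toy scalar values for illustration.
theta0 = 1.0            # meta-initialization
theta = 1.4             # current inner-loop iterate
g = 0.8                 # support-set gradient
alpha, lam = 0.1, 0.5   # inner LR and proximal strength

theta_new = theta - alpha * (g + lam * (theta - theta0))
# = 1.4 - 0.1 * (0.8 + 0.5 * 0.4) = 1.3
```

Setting `lam = 0` recovers the vanilla SGD inner step; larger `lam` shrinks adaptation toward the initialization.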

Task

Modify the InnerLoopOptimizer class in custom_maml.py to implement your inner-loop adaptation algorithm. The class has three methods:

  • __init__(model, inner_lr): Initialize the optimizer and any learnable state
  • adapt(model, support_x, support_y, n_steps): Perform n_steps of adaptation on the support set
  • meta_parameters(): Return any learnable parameters of the optimizer itself

Interface

class InnerLoopOptimizer:
    def __init__(self, model: nn.Module, inner_lr: float):
        # model is the base model (for inspecting parameter shapes)
        # inner_lr is the default learning rate
        # Create any learnable parameters here
        ...

    def adapt(self, model: nn.Module, support_x: Tensor, support_y: Tensor,
              n_steps: int) -> nn.Module:
        # model is a CLONE — safe to modify in-place
        # Must use differentiable operations (torch.autograd.grad, NOT torch.optim)
        # Return the adapted model
        ...

    def meta_parameters(self) -> List[Tensor]:
        # Return learnable optimizer parameters for outer-loop optimization
        # Empty list if optimizer has no learnable state (e.g., vanilla MAML)
        ...

Available Context

  • learn2learn/algorithms/maml.py: MAML implementation (clone + adapt via differentiable SGD)
  • learn2learn/algorithms/meta_sgd.py: Meta-SGD (per-parameter learned learning rates)
  • learn2learn/algorithms/gbml.py: GBML base class (general gradient-based meta-learning with transforms)

Evaluation

Submissions are trained and evaluated on three few-shot image classification benchmarks using a CNN4 backbone:

  • miniImageNet 5-way 1-shot (64 training / 16 val / 20 test classes, 600 examples each)
  • miniImageNet 5-way 5-shot (same dataset, 5 support examples per class)
  • CIFAR-FS 5-way 5-shot (100 classes from CIFAR-100, 5 support examples per class)

Metric: mean classification accuracy over 600 test episodes (higher is better). Meta-training runs for 60,000 iterations with 4 tasks per meta-batch. Inner loop uses 5 steps during training, 10 steps during evaluation.
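The reported metric is the mean over episodes; a 95% confidence interval (a common convention in few-shot papers, though not stated in the task text) can be attached with a small helper:

```python
# Summarize per-episode accuracies: mean plus a normal-approximation 95% CI.
# The CI is an added convention, not part of the task's stated metric.
from statistics import mean, stdev

def episode_summary(accs):
    m = mean(accs)
    ci95 = 1.96 * stdev(accs) / len(accs) ** 0.5
    return m, ci95
```

For 600 episodes the CI is typically a percentage point or two, which matters when comparing methods whose means differ by a similar margin.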

Code

custom_maml.py
# Custom inner-loop optimizer for gradient-based meta-learning
#
# EDITABLE section: InnerLoopOptimizer class and helper modules.
# FIXED sections: everything else (config, data loading, backbone, outer loop, evaluation).
#
# Research question: Design the inner-loop adaptation algorithm that determines
# HOW model parameters are updated during fast adaptation to a new task.
import os
import sys
import copy
import random
from statistics import mean
from typing import Optional, Tuple, Dict, List

# Fix import path: exclude the learn2learn source tree so that

Additional context files (read-only):

  • learn2learn/learn2learn/algorithms/gbml.py

Results

| Model | Type | miniImageNet 1-shot acc. | miniImageNet 5-shot acc. | CIFAR-FS 5-shot acc. |
|---|---|---|---|---|
| anil | baseline | 0.439 | 0.630 | 0.726 |
| maml | baseline | 0.424 | 0.646 | 0.710 |
| meta_sgd | baseline | 0.457 | 0.627 | 0.704 |
| anthropic/claude-opus-4.6 | vanilla | - | 0.598 | 0.730 |
| deepseek-reasoner | vanilla | 0.214 | 0.230 | 0.236 |
| google/gemini-3.1-pro-preview | vanilla | 0.242 | 0.622 | 0.687 |
| maml | vanilla | 0.217 | - | - |
| openai/gpt-5.4-pro | vanilla | - | 0.651 | 0.733 |
| qwen3.6-plus | vanilla | - | 0.586 | 0.687 |
| anil | agent | - | - | 0.625 |
| anthropic/claude-opus-4.6 | agent | 0.453 | - | 0.730 |
| deepseek-reasoner | agent | - | 0.230 | 0.236 |
| google/gemini-3.1-pro-preview | agent | - | 0.622 | 0.687 |
| maml | agent | 0.424 | - | - |
| meta_sgd | agent | 0.431 | - | - |
| openai/gpt-5.4-pro | agent | - | 0.651 | 0.733 |
| qwen3.6-plus | agent | - | 0.586 | 0.687 |