ml-active-learning

Classical ML · badge · rigorous codebase

Description

Active Learning: Query Strategy Design

Research Question

Design a novel pool-based active learning query strategy that outperforms existing methods (uncertainty sampling, entropy sampling, BADGE, BAIT, BALD) across diverse tabular classification datasets.

Background

Active learning aims to minimize labeling cost by intelligently selecting which unlabeled samples to query for labels. In pool-based active learning, a query strategy selects batches of samples from an unlabeled pool to be labeled by an oracle, then the model is retrained. The goal is to achieve the highest possible accuracy with the fewest labeled samples.
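The select-label-retrain loop just described can be sketched as follows. This is a generic illustration, not the badge framework's API: the `strategy`, `train`, and `evaluate` callables are placeholders.

```python
import numpy as np

def active_learning_loop(X, y, strategy, train, evaluate,
                         n_rounds, batch_size, n_init):
    """Generic pool-based active learning loop (sketch).

    strategy(model, X, idxs_lb, n) -> indices of n unlabeled samples to label.
    train(X, y, idxs_lb) -> fitted model; evaluate(model) -> test accuracy.
    """
    rng = np.random.default_rng(0)
    idxs_lb = np.zeros(len(X), dtype=bool)
    idxs_lb[rng.choice(len(X), n_init, replace=False)] = True  # random seed set
    history = []
    for _ in range(n_rounds):
        model = train(X, y, idxs_lb)                   # retrain on labeled set
        history.append((int(idxs_lb.sum()), evaluate(model)))
        chosen = strategy(model, X, idxs_lb, batch_size)  # query the oracle
        idxs_lb[chosen] = True                            # mark batch as labeled
    return history
```

The `history` of (labels used, accuracy) pairs is exactly the learning curve that the auc metric below summarizes.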

Classic approaches include:

  • Uncertainty Sampling: Select samples where the model is least confident (lowest max predicted probability)
  • Entropy Sampling: Select samples with highest predictive entropy
  • Query By Committee: Select samples with maximal disagreement among an ensemble
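The first two of these can be computed directly from a matrix of softmax probabilities. A minimal NumPy sketch (function names are illustrative, not from the codebase):

```python
import numpy as np

def uncertainty_sampling(probs, n):
    """Select the n samples with the lowest max predicted probability."""
    confidence = probs.max(axis=1)
    return np.argsort(confidence)[:n]

def entropy_sampling(probs, n):
    """Select the n samples with the highest predictive entropy."""
    entropy = -(probs * np.log(probs + 1e-12)).sum(axis=1)
    return np.argsort(-entropy)[:n]

probs = np.array([
    [0.90, 0.05, 0.05],   # confident
    [0.40, 0.35, 0.25],   # uncertain
    [0.34, 0.33, 0.33],   # near-uniform, most uncertain
])
print(uncertainty_sampling(probs, 1))  # -> [2]
print(entropy_sampling(probs, 1))      # -> [2]
```

Both strategies agree on clear-cut cases like the one above; they diverge when the tail of the distribution matters, since entropy uses all class probabilities while uncertainty sampling looks only at the top one.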

Modern approaches incorporate diversity and information-theoretic principles:

  • BADGE (Ash et al., ICLR 2020): Uses gradient embeddings with k-means++ for diverse, uncertain batch selection
  • BAIT (Ash et al., NeurIPS 2021): Optimizes Fisher information to select maximally informative batches
  • BALD (Houlsby et al., 2011): Uses MC Dropout to estimate mutual information between predictions and parameters
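As a concrete illustration, the BALD acquisition score can be estimated from MC-dropout probability samples shaped like the output of `predict_prob_dropout_split`, i.e. (n_drop, n_samples, n_classes). This is a sketch of the standard estimator, not the framework's own implementation; the Dirichlet draw only fakes dropout samples for the demo.

```python
import numpy as np

def bald_scores(mc_probs):
    """BALD: mutual information between predictions and model parameters.

    mc_probs: shape (n_drop, n_samples, n_classes), one softmax per dropout pass.
    """
    mean_probs = mc_probs.mean(axis=0)                                   # (N, C)
    # Entropy of the mean prediction (total uncertainty).
    h_mean = -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)
    # Mean per-pass entropy (aleatoric part, averaged over dropout masks).
    h_passes = -(mc_probs * np.log(mc_probs + 1e-12)).sum(axis=2).mean(axis=0)
    return h_mean - h_passes           # epistemic uncertainty, >= 0 by Jensen

rng = np.random.default_rng(0)
mc = rng.dirichlet([1.0, 1.0, 1.0], size=(10, 5))   # fake (n_drop=10, N=5, C=3)
scores = bald_scores(mc)
top2 = np.argsort(-scores)[:2]        # query the two highest-MI samples
```

High BALD scores flag samples where the dropout passes are individually confident but disagree with each other, which is the signature of parameter (not label-noise) uncertainty.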

Task

Modify the CustomSampling class in badge/query_strategies/custom_sampling.py to implement a novel query strategy. The strategy must implement the query(n) method that returns n indices from the unlabeled pool.

Interface

class CustomSampling(Strategy):
    def __init__(self, X, Y, idxs_lb, net, handler, args):
        super().__init__(X, Y, idxs_lb, net, handler, args)

    def query(self, n) -> np.ndarray:
        # Must return n indices into self.X of unlabeled samples to label
        ...

Available from the Strategy base class:

  • self.X: pool features (numpy array, shape [n_pool, n_features])
  • self.Y: pool labels (torch LongTensor, shape [n_pool])
  • self.idxs_lb: boolean mask of labeled samples
  • self.n_pool: total pool size
  • self.predict_prob(X, Y): softmax probabilities [len(X), n_classes]
  • self.predict_prob_dropout_split(X, Y, n_drop): MC dropout probs [n_drop, len(X), n_classes]
  • self.get_embedding(X, Y): penultimate-layer embeddings [len(X), emb_dim]
  • self.get_grad_embedding(X, Y): gradient embeddings [len(X), emb_dim * n_classes]
  • self.get_exp_grad_embedding(X, Y): expected Fisher embeddings [len(X), n_classes, emb_dim]
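For illustration, here is a margin-sampling variant of the query logic built only on the documented members. The real class would inherit from `Strategy`; here the two pieces of the interface it relies on (`idxs_lb` and the probabilities a `self.predict_prob` call would return) are stubbed so the selection logic runs standalone.

```python
import numpy as np

class CustomSampling:
    """Margin-based query sketch (standalone stub of the Strategy interface)."""

    def __init__(self, probs, idxs_lb):
        self._probs = probs      # stand-in for self.predict_prob(X, Y) output
        self.idxs_lb = idxs_lb   # boolean mask, True for labeled samples

    def query(self, n):
        # Restrict scoring to the unlabeled pool, as the contract requires.
        idxs_unlabeled = np.arange(len(self.idxs_lb))[~self.idxs_lb]
        probs = self._probs[idxs_unlabeled]
        # Margin = gap between the top two class probabilities;
        # a small margin means the model is torn between two classes.
        sorted_probs = np.sort(probs, axis=1)
        margin = sorted_probs[:, -1] - sorted_probs[:, -2]
        return idxs_unlabeled[np.argsort(margin)[:n]]
```

Note that `query` returns indices into the full pool (`self.X`), not into the unlabeled subset; forgetting to map back through `idxs_unlabeled` is a common off-by-indexing bug in this interface.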

Evaluation

  • Datasets: 3 OpenML tabular classification datasets (letter recognition, spambase, splice)
  • Protocol: 20 rounds of batch active learning, evaluated after each round
  • Metrics:
    • accuracy: Test accuracy at the end of 20 AL rounds (fixed label budget)
    • auc: Area under the learning curve (accuracy vs. number of labeled samples), measuring sample efficiency across all rounds
  • Higher is better for both metrics.
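The auc metric can be approximated with a trapezoidal integral over the learning curve. The exact normalization the benchmark uses is not stated here; dividing by the x-range, as below, is an assumption that maps the score into [0, 1].

```python
import numpy as np

def learning_curve_auc(n_labeled, accuracies):
    """Normalized area under the accuracy-vs-labels learning curve.

    n_labeled: labeled-set sizes after each round (strictly increasing).
    accuracies: test accuracy after each round.
    """
    n = np.asarray(n_labeled, dtype=float)
    acc = np.asarray(accuracies, dtype=float)
    # Trapezoid rule, normalized by the label budget spanned by the curve.
    area = np.sum((acc[1:] + acc[:-1]) / 2.0 * np.diff(n))
    return area / (n[-1] - n[0])
```

Under this definition a strategy that reaches high accuracy in early rounds scores a higher auc than one that only catches up at the final budget, which is why auc rewards sample efficiency while accuracy alone does not.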

Code

custom_sampling.py
1"""Custom active learning query strategy.
2
3This module defines a CustomSampling strategy that inherits from the badge
4framework's Strategy base class. The agent must implement the query() method
5to select the most informative samples from the unlabeled pool.
6
7Interface contract:
8 - self.X: numpy array of all pool features, shape (n_pool, n_features)
9 - self.Y: torch LongTensor of all pool labels, shape (n_pool,)
10 - self.idxs_lb: boolean array, True for labeled samples
11 - self.n_pool: total number of pool samples
12 - self.clf: the trained neural network model
13 - self.predict_prob(X, Y): returns softmax probabilities, shape (len(X), n_classes)
14 - self.predict_prob_dropout_split(X, Y, n_drop): returns MC dropout probs, shape (n_drop, len(X), n_classes)
15 - self.get_embedding(X, Y): returns penultimate-layer embeddings, shape (len(X), emb_dim)

Additional context files (read-only):

  • badge/query_strategies/strategy.py

Results

| Model | Type | accuracy (letter) | auc (letter) | accuracy (spambase) | auc (spambase) | accuracy (splice) | auc (splice) |
|---|---|---|---|---|---|---|---|
| badge | baseline | 0.832 | 0.725 | 0.928 | 0.896 | 0.809 | 0.736 |
| bait | baseline | 0.791 | 0.671 | 0.931 | 0.911 | 0.804 | 0.744 |
| bald | baseline | 0.804 | 0.687 | 0.898 | 0.887 | 0.789 | 0.729 |
| least_confidence | baseline | 0.752 | 0.626 | 0.924 | 0.899 | 0.805 | 0.729 |
| random | baseline | 0.807 | 0.702 | 0.893 | 0.887 | 0.790 | 0.730 |
| deepseek-reasoner | vanilla | – | – | – | – | – | – |
| google/gemini-3.1-pro-preview | vanilla | 0.799 | 0.702 | 0.909 | 0.883 | 0.774 | 0.732 |
| openai/gpt-5.4 | vanilla | – | – | – | – | – | – |
| qwen/qwen3.6-plus | vanilla | – | – | – | – | – | – |
| deepseek-reasoner | agent | – | – | – | – | – | – |
| google/gemini-3.1-pro-preview | agent | 0.822 | 0.685 | 0.937 | 0.908 | 0.862 | 0.763 |
| openai/gpt-5.4 | agent | – | – | – | – | – | – |
| qwen/qwen3.6-plus | agent | – | – | – | – | – | – |

Agent Conversations