jepa-regularizer

Deep Learning, eb_jepa, rigorous codebase

Description

JEPA Self-Supervised Learning: Anti-Collapse Regularization

Research Question

Design an improved anti-collapse regularization loss for Joint Embedding Predictive Architecture (JEPA) self-supervised image representation learning. Your regularizer should prevent representation collapse (where all inputs map to the same output) while encouraging the model to learn useful, discriminative features.
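
To make the failure mode concrete, the following minimal sketch shows one common way to detect collapse: the per-dimension standard deviation of a batch of embeddings drops to zero when every input maps to the same vector. The helper name `embedding_std` and the tensor sizes are illustrative assumptions, not part of the task codebase.

```python
import torch


def embedding_std(z: torch.Tensor) -> float:
    """Mean per-dimension standard deviation of a batch of embeddings [B, D]."""
    return z.std(dim=0).mean().item()


torch.manual_seed(0)
healthy = torch.randn(256, 128)          # embeddings that vary across inputs
collapsed = torch.full((256, 128), 0.5)  # every input maps to the same vector

print(embedding_std(healthy))    # close to 1 for unit-variance features
print(embedding_std(collapsed))  # essentially zero: complete collapse
```

A regularizer that keeps this statistic away from zero rules out the trivial solution, although it does not by itself guarantee discriminative features.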

What You Can Modify

The CustomRegularizer class (lines 33-53) in custom_regularizer.py. This class receives two projected embedding tensors from different augmented views of the same images and must return a loss dictionary.

Interface:

  • Input: z1: [B, D] and z2: [B, D] -- projected embeddings from two augmented views
  • Output: dict with at least a "loss" key containing a scalar tensor

You may add any parameters to __init__, define helper methods, and use any PyTorch operations. The imports at the top of the file (torch, torch.nn, torch.nn.functional, etc.) are available.
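
As an illustration of the interface, here is a minimal VICReg-style sketch that combines an invariance term with variance and covariance penalties. It assumes the class is an `nn.Module` that the training script calls as `regularizer(z1, z2)`. The coefficient names and default values are hypothetical choices for this sketch, not values taken from the task codebase.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomRegularizer(nn.Module):
    """VICReg-style sketch: invariance + variance + covariance terms.

    Assumption: the caller invokes forward(z1, z2) with two [B, D] tensors.
    Coefficient names and defaults are illustrative only.
    """

    def __init__(self, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
        super().__init__()
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.eps = eps

    def _variance(self, z):
        # Hinge on the per-dimension std: penalize dimensions whose std is below 1.
        std = torch.sqrt(z.var(dim=0) + self.eps)
        return F.relu(1.0 - std).mean()

    def _covariance(self, z):
        # Penalize squared off-diagonal entries of the feature covariance matrix.
        b, d = z.shape
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (b - 1)
        off_diag = cov - torch.diag(torch.diagonal(cov))
        return off_diag.pow(2).sum() / d

    def forward(self, z1, z2):
        inv = F.mse_loss(z1, z2)  # pull the two views together
        var = 0.5 * (self._variance(z1) + self._variance(z2))
        cov = self._covariance(z1) + self._covariance(z2)
        loss = self.sim_coeff * inv + self.std_coeff * var + self.cov_coeff * cov
        return {
            "loss": loss,
            "invariance": inv.detach(),
            "variance": var.detach(),
            "covariance": cov.detach(),
        }


# Usage with random stand-ins for the projected embeddings.
torch.manual_seed(0)
regularizer = CustomRegularizer()
z1 = torch.randn(256, 64, requires_grad=True)
z2 = z1 + 0.1 * torch.randn(256, 64)
out = regularizer(z1, z2)
out["loss"].backward()  # the loss is a scalar, so it can be backpropagated directly
```

The extra dictionary keys are detached so they can be logged without keeping the graph alive. Only the "loss" key is required by the interface described above.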

Evaluation

  • Metric: val_acc -- linear probe classification accuracy on CIFAR-10 (higher is better)
  • Benchmarks: Three backbone architectures (ResNet-18, ResNet-34, ResNet-50) test regularizer generalization across model scales
  • Projector: features_dim -> 2048 -> 2048 MLP
  • Training: 100 epochs, batch size 256, LARS optimizer (lr=0.3), warmup cosine schedule
  • Dataset: CIFAR-10 (50k train / 10k val)
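
The warmup cosine schedule named in the Training bullet can be sketched as a plain function of the epoch index. The base learning rate of 0.3 and the 100-epoch length come from the settings above. The 10-epoch warmup, the minimum learning rate of 0, and the function name `warmup_cosine_lr` are assumptions made for this sketch, since the page does not state them.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr=0.3, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear ramp: reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Per-epoch learning rates for 100 epochs with an assumed 10-epoch warmup.
schedule = [warmup_cosine_lr(epoch, total_steps=100, warmup_steps=10) for epoch in range(100)]
```

In practice the schedule is often evaluated per optimizer step rather than per epoch. The same function applies if `step` and `total_steps` are counted in iterations.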

Code

custom_regularizer.py
"""
CIFAR-10 JEPA Self-Supervised Training Script (Self-Contained)

Trains a ResNet-18 backbone with a projector using a two-view augmentation
pipeline and an anti-collapse regularization loss. Evaluation is performed
via an online linear probe on CIFAR-10 validation set.

Usage:
    python custom_regularizer.py
"""

import sys; sys.path = [p for p in sys.path if not __import__('os').path.isfile(__import__('os').path.join(p, 'logging.py'))]
import os
import math
import time

Additional context files (read-only):

  • eb_jepa/losses.py
  • eb_jepa/jepa.py

Results

Model   Type   val acc (ResNet-18)  val acc (ResNet-34)  val acc (ResNet-50)
naive   agent  34.130               30.950               34.170
sigreg  agent  83.010               83.420               84.190
vicreg  agent  80.540               79.550               81.640