jepa-regularizer
Description
JEPA Self-Supervised Learning: Anti-Collapse Regularization
Research Question
Design an improved anti-collapse regularization loss for Joint Embedding Predictive Architecture (JEPA) self-supervised image representation learning. Your regularizer should prevent representation collapse (where all inputs map to the same output) while encouraging the model to learn useful, discriminative features.
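To make "collapse" concrete, the following is a minimal diagnostic sketch and is not part of the task code. The helper name `embedding_std` is hypothetical. It reports the mean per-dimension standard deviation of a batch of embeddings, which drops to zero when every input maps to the same vector.

```python
import torch


def embedding_std(z: torch.Tensor) -> torch.Tensor:
    """Mean per-dimension standard deviation of a [B, D] embedding batch."""
    return z.std(dim=0).mean()


torch.manual_seed(0)
healthy = torch.randn(256, 32)          # embeddings that vary across the batch
collapsed = torch.full((256, 32), 0.5)  # every input maps to the same vector

healthy_std = embedding_std(healthy).item()
collapsed_std = embedding_std(collapsed).item()
```

This check only detects complete collapse. Embeddings that vary along only a few directions (dimensional collapse) can still show non-zero standard deviation, which is one reason regularizers often add a decorrelation term as well.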
What You Can Modify
The CustomRegularizer class (lines 33-53) in custom_regularizer.py. This class receives two projected embedding tensors from different augmented views of the same images and must return a loss dictionary.
Interface:
- Input: z1: [B, D] and z2: [B, D] -- projected embeddings from two augmented views
- Output: dict with at least a "loss" key containing a scalar tensor
You may add any parameters to __init__, define helper methods, and use any PyTorch operations. The imports at the top of the file (torch, torch.nn, torch.nn.functional, etc.) are available.
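As an illustration of the interface above, here is one possible sketch of a regularizer in the style of VICReg (invariance, variance, and covariance terms). It is not the repository's implementation or the benchmark's vicreg baseline. It assumes the class is an nn.Module invoked as forward(z1, z2). The coefficient defaults (25, 25, 1) are those commonly used with VICReg and are not tuned for this task, and the extra dictionary keys are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomRegularizer(nn.Module):
    """VICReg-style sketch: invariance + variance + covariance terms."""

    def __init__(self, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
        super().__init__()
        self.sim_coeff = sim_coeff  # weight of the invariance (MSE) term
        self.std_coeff = std_coeff  # weight of the variance hinge term
        self.cov_coeff = cov_coeff  # weight of the covariance term
        self.eps = eps

    def _variance(self, z):
        # Hinge on the per-dimension std: penalize dimensions with std < 1.
        std = torch.sqrt(z.var(dim=0) + self.eps)
        return F.relu(1.0 - std).mean()

    def _covariance(self, z):
        # Sum of squared off-diagonal covariance entries, divided by D.
        b, d = z.shape
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (b - 1)
        off_diag = cov - torch.diag(torch.diagonal(cov))
        return off_diag.pow(2).sum() / d

    def forward(self, z1, z2):
        inv = F.mse_loss(z1, z2)
        var = self._variance(z1) + self._variance(z2)
        cov = self._covariance(z1) + self._covariance(z2)
        loss = self.sim_coeff * inv + self.std_coeff * var + self.cov_coeff * cov
        # Only "loss" is required; the other keys are for logging.
        return {"loss": loss, "inv": inv.detach(), "var": var.detach(), "cov": cov.detach()}


torch.manual_seed(0)
reg = CustomRegularizer()

# Two slightly different views of varied embeddings.
z1 = torch.randn(256, 64)
z2 = z1 + 0.01 * torch.randn(256, 64)
out = reg(z1, z2)

# Fully collapsed embeddings: identical vectors for every input.
flat = torch.ones(256, 64)
collapsed_out = reg(flat, flat)
```

In this sketch the collapsed batch has zero invariance loss but a large variance penalty, so its total loss is higher than that of the varied batch. That is the behavior an anti-collapse term is meant to provide.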
Evaluation
- Metric: val_acc -- linear probe classification accuracy on CIFAR-10 (higher is better)
- Benchmarks: Three backbone architectures (ResNet-18, ResNet-34, ResNet-50) test regularizer generalization across model scales
- Projector: features_dim -> 2048 -> 2048 MLP
- Training: 100 epochs, batch size 256, LARS optimizer (lr=0.3), warmup cosine schedule
- Dataset: CIFAR-10 (50k train / 10k val)
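The training script's exact schedule is not shown here, so the following is only a sketch of what a warmup cosine schedule typically looks like. The base learning rate (0.3) and epoch count (100) come from the list above. The function name, the 10-epoch linear warmup, and the decay to zero are assumptions made for illustration.

```python
import math


def warmup_cosine_lr(epoch, base_lr=0.3, total_epochs=100, warmup_epochs=10, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp that reaches base_lr on the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase completed, in [0, 1).
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [warmup_cosine_lr(e) for e in range(100)]
```

With these assumed settings the rate climbs linearly to 0.3 over the first 10 epochs and then decays smoothly toward zero by the end of training.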
Code
```python
"""
CIFAR-10 JEPA Self-Supervised Training Script (Self-Contained)

Trains a ResNet-18 backbone with a projector using a two-view augmentation
pipeline and an anti-collapse regularization loss. Evaluation is performed
via an online linear probe on CIFAR-10 validation set.

Usage:
    python custom_regularizer.py
"""

import sys; sys.path = [p for p in sys.path if not __import__('os').path.isfile(__import__('os').path.join(p, 'logging.py'))]
import os
import math
import time
```
Additional context files (read-only):
- eb_jepa/losses.py
- eb_jepa/jepa.py
Results
| Model | Type | val acc resnet18 ↑ | val acc resnet34 ↑ | val acc resnet50 ↑ |
|---|---|---|---|---|
| naive | agent | 34.130 | 30.950 | 34.170 |
| sigreg | agent | 83.010 | 83.420 | 84.190 |
| vicreg | agent | 80.540 | 79.550 | 81.640 |