Agent Conversation: openai/gpt-5.4-pro
Task: optimization-gradient-compression
Gradient Compression for Communication-Efficient Distributed Training
Research Question
Design a gradient compression operator that reduces communication cost in distributed training while maintaining convergence quality (test accuracy).
Background
In distributed data-parallel training, gradient communication is often the bottleneck. Workers compute local gradients, which must be aggregated (e.g., via all-reduce) before the optimizer step. Gradient compression reduces the volume of data communicated by applying lossy compression to gradients before transmission.
Three main families of compression exist:
- Sparsification: Keep only a subset of gradient elements (e.g., TopK selects the largest magnitudes)
- Quantization: Reduce the precision of gradient values (e.g., QSGD uses stochastic rounding to discrete levels)
- Low-rank approximation: Approximate gradient matrices with low-rank factors (e.g., PowerSGD)
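As a concrete illustration of the quantization family, here is a minimal QSGD-style stochastic-rounding sketch. It is illustrative only; the function names are hypothetical and not part of the task code, and real QSGD additionally entropy-codes the small integers rather than storing them as int8.

```python
import torch

def qsgd_quantize(t: torch.Tensor, levels: int = 16):
    """Stochastically round |t| / ||t|| onto `levels` uniform levels.

    Rounding up with probability equal to the fractional part makes the
    reconstruction unbiased in expectation, as in QSGD.
    """
    norm = t.norm()
    sign = torch.sign(t)
    if norm == 0:
        return torch.zeros_like(t, dtype=torch.int8), norm, sign
    scaled = t.abs() / norm * levels             # each entry lies in [0, levels]
    lower = scaled.floor()
    q = lower + torch.bernoulli(scaled - lower)  # stochastic rounding
    return q.to(torch.int8), norm, sign

def qsgd_dequantize(q, norm, sign, levels: int = 16):
    # Undo the scaling; only q (small ints), one norm, and signs are sent.
    return sign * q.to(torch.float32) * norm / levels
```

Because each element moves at most one level, the per-element error is bounded by `norm / levels`, and averaging over draws recovers the original tensor.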
A key challenge is that naive compression introduces bias or variance that degrades convergence. Error feedback (accumulating compression residuals for the next iteration) is a widely-used technique to correct this.
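The error-feedback loop can be sketched in a few lines (illustrative only; `topk_compress` and the module-level `residual` dict are hypothetical helpers, not the task's API):

```python
import torch

def topk_compress(t: torch.Tensor, ratio: float) -> torch.Tensor:
    # Keep only the k largest-magnitude entries (a biased compressor).
    flat = t.reshape(-1)
    k = max(1, int(flat.numel() * ratio))
    _, idx = torch.topk(flat.abs(), k, sorted=False)
    out = torch.zeros_like(flat)
    out[idx] = flat[idx]
    return out.view_as(t)

residual = {}  # per-parameter memory of what was dropped

def compress_with_feedback(grad, name, ratio=0.01):
    corrected = grad + residual.get(name, torch.zeros_like(grad))
    sent = topk_compress(corrected, ratio)
    residual[name] = corrected - sent  # dropped mass re-enters next step
    return sent
```

The invariant is that `sent + residual[name]` always equals the error-corrected gradient, so no gradient mass is ever lost, only delayed.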
Task
Modify the Compressor class in custom_compressor.py. Your compressor must implement:
- __init__(self, compress_ratio): Initialize with a target compression ratio (0.01 = 100x compression)
- compress(self, tensor, name): Compress a gradient tensor, returning (compressed_tensors, ctx)
- decompress(self, compressed_tensors, ctx): Reconstruct the gradient
The compressor may maintain internal state (e.g., error feedback residuals) across calls. The name parameter identifies parameters for per-parameter state tracking.
Interface
class Compressor:
    def __init__(self, compress_ratio=0.01): ...
    def compress(self, tensor, name) -> (list[Tensor], ctx): ...
    def decompress(self, compressed_tensors, ctx) -> Tensor: ...
- compress_ratio: Fraction of gradient elements/information to retain (0.01 = keep 1%)
- compressed_tensors: List of tensors that would be communicated over the network
- ctx: Local context (not communicated) needed for decompression
- The decompressed tensor must have the same shape as the original input
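A minimal baseline that satisfies this interface is plain TopK with error feedback. This is a hedged sketch for orientation, not the compressor the agent ultimately submits:

```python
import torch

class TopKCompressor:
    """Minimal interface-conforming baseline: TopK + error feedback."""

    def __init__(self, compress_ratio=0.01):
        self.compress_ratio = compress_ratio
        self.residuals = {}  # per-parameter error-feedback memory

    def compress(self, tensor, name):
        # Fold in the residual from the previous step before selecting.
        tensor = tensor + self.residuals.get(name, torch.zeros_like(tensor))
        flat = tensor.reshape(-1)
        k = max(1, int(flat.numel() * self.compress_ratio))
        _, idx = torch.topk(flat.abs(), k, sorted=False)
        vals = flat[idx]
        rec = torch.zeros_like(flat)
        rec.scatter_(0, idx, vals)
        self.residuals[name] = (flat - rec).view_as(tensor)
        # Only vals and idx would travel over the network; ctx stays local.
        return [vals, idx], (flat.numel(), tensor.shape)

    def decompress(self, compressed_tensors, ctx):
        numel, shape = ctx
        vals, idx = compressed_tensors
        out = torch.zeros(numel, dtype=vals.dtype, device=vals.device)
        out.scatter_(0, idx, vals)
        return out.view(shape)
```

Note that `ctx` carries only shape metadata, so decompression on the receiving side needs nothing beyond the two communicated tensors.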
Evaluation
Trained and evaluated on three settings with 100x compression (compress_ratio=0.01):
- ResNet-20 / CIFAR-10 (0.27M params): Small model, standard benchmark
- VGG-11-BN / CIFAR-100 (9.8M params): Larger model, harder 100-class problem
- ResNet-56 / CIFAR-10 (0.85M params): Deeper model, tests scalability
Metric: best test accuracy (higher is better). All settings use SGD with momentum, cosine LR schedule, and 200 training epochs.
Hints
- Error feedback is essential for biased compressors (like TopK). Without it, convergence degrades significantly.
- Combining sparsification with quantization can yield better compression/accuracy tradeoffs.
- Adaptive compression (differen
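The sparsification+quantization combination mentioned in the hints can be sketched as follows: select TopK values, then store them as 8-bit integers under a single shared scale (helper names are hypothetical, not from the task code):

```python
import torch

def topk_then_quantize(t: torch.Tensor, ratio: float = 0.01):
    """Sketch: keep the top-k values, then uniformly quantize them to int8."""
    flat = t.reshape(-1)
    k = max(1, int(flat.numel() * ratio))
    _, idx = torch.topk(flat.abs(), k, sorted=False)
    vals = flat[idx]
    scale = vals.abs().max() / 127.0 + 1e-12     # one float shared by all values
    q = torch.round(vals / scale).to(torch.int8)  # 8-bit payload per kept value
    return q, idx, scale

def dequantize_topk(q, idx, scale, numel, shape):
    out = torch.zeros(numel)
    out[idx] = q.to(torch.float32) * scale
    return out.view(shape)
```

Relative to float32 TopK, this shrinks the value payload 4x at the cost of a quantization error of at most half a step per kept entry.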
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_compressor.py",
  "content": "class Compressor:\n \"\"\"Residual-adaptive stable-support hybrid compressor.\n\n Uses error feedback, dense transmission for critical 1D statistics,\n EMA-stabilized TopK support, and a low-rank fallback for matrix/conv\n gradients when rank-1 is cheaper and more accurate than sparse updates.\n \"\"\"\n\n def __init__(self, compress_ratio=0.01):\n self.compress_ratio = compress_ratio\n self.residuals = {}\n self.scores = {}\n self.subspaces = {}\n self.beta = 0.9\n self.eps = 1e-8\n self.small_tensor = max(128, int(1.0 / max(compress_ratio, 1e-8)))\n\n def compress(self, tensor, name):\n residual = self.residuals.get(name)\n tensor = tensor if residual is None else tensor + residual\n shape, flat = tensor.shape, tensor.flatten()\n numel = flat.numel()\n\n if tensor.ndim == 1 or numel <= self.small_tensor:\n self.residuals[name] = torch.zeros_like(tensor)\n return [flat.clone()], ('dense', shape)\n\n score = self.scores.get(name)\n if score is None or score.numel() != numel or score.device != flat.device:\n score = flat.abs()\n else:\n score = self.beta * score + (1.0 - self.beta) * flat.abs()\n self.scores[name] = score.detach()\n\n boost = 0.0 if residual is None else min(\n 3.0, (residual.norm() / (tensor.norm() + self.eps)).item())\n k = max(1, min(numel, int(numel * self.compress_ratio * (1.0 + boost))))\n _, indices = torch.topk(score, k, sorted=False)\n values = flat[indices]\n sparse = torch.zeros_like(flat)\n sparse.scatter_(0, indices, values)\n sparse_rec = sparse.view(shape)\n sparse_err = (tensor - sparse_rec).norm()\n\n if tensor.ndim in (2, 4):\n mat = tensor.reshape(tensor.shape[0], -1)\n m, n = mat.shape\n if min(m, n) > 1 and (m + n + 1) <= 2 * k:\n q = self.subspaces.get(name)\n if q is None or q.numel() != n or q.device != mat.device:\n q = torch.randn(n, device=mat.device, dtype=mat.dtype)\n q = q / (q.norm() + self.eps)\n p = mat.mv(q)\n p = p / (p.norm() + self.eps)\n q = mat.t().mv(p)\n q = q / (q.norm() + self.eps)\n self.subspaces[name] = q.detach()\n sigma = p.dot(mat.mv(q))\n low_rank = (sigma * torch.outer(p, q)).view(shape)\n if (tensor - low_rank).norm() <= sparse_err:\n self.residuals[name] = (tensor - low_rank).detach()\n return [p, q, sigma.view(1)], ('lowrank', shape)\n\n self.residuals[name] = (tensor - sparse_rec).detach()\n return [values, indices], ('topk', numel, shape)\n\n def decompress(self, compressed_tensors, ctx):\n if ctx[0] == 'dense':\n return compressed_tensors[0].view(ctx[1])\n if ctx[0] == 'topk':\n _, numel, shape = ctx\n values, indices = compressed_tensors\n out = torch.zeros(numel, dtype=values.dtype, device=values.device)\n out.scatter_(0, indices, values)\n return out.view(shape)\n p, q, sigma = compressed_tensors\n return (sigma * torch.outer(p, q)).view(ctx[1])",
  "after_line": 232,
  "start_line": 183,
  "end_line": 232
}
OK: Replaced lines 183..232 with 77 line(s) in pytorch-vision/custom_compressor.py. Editable range: 182–259.
[Current file: pytorch-vision/custom_compressor.py | editable: 182–259 | total: 400 lines]
179| # ============================================================================
180| # EDITABLE SECTION — Gradient Compressor (lines 182-232)
181| # ============================================================================
182
183 class Compressor:
184 """Residual-adaptive stable-support hybrid compressor.
185
186 Uses error feedback, dense transmission for critical 1D statistics,
187 EMA-stabilized TopK support, and a low-rank fallback for matrix/conv
188 gradients when rank-1 is cheaper and more accurate than sparse updates.
189 """
190
191 def __init__(self, compress_ratio=0.01):
192 self.compress_ratio = compress_ratio
193 self.residuals = {}
194 self.scores = {}
1
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_compressor.py",
  "content": "class Compressor:\n \"\"\"Structure-adaptive error-feedback compressor.\n\n Core idea: compare a TopK sparse update against a same-budget hybrid\n low-rank+TopK update, while leaving tiny / 1D tensors uncompressed.\n This preserves sensitive BN/bias gradients, exploits matrix structure in\n large conv/linear layers, and still uses error feedback to correct the\n remaining compression error.\n \"\"\"\n\n def __init__(self, compress_ratio=0.01):\n self.compress_ratio = compress_ratio\n self.residuals = {}\n self.proj = {}\n self.small_dense = 2048\n\n def compress(self, tensor, name):\n tensor = tensor + self.residuals.get(name, 0)\n shape = tensor.shape\n flat = tensor.reshape(-1)\n numel = flat.numel()\n\n if tensor.ndim == 1 or numel <= self.small_dense:\n self.residuals[name] = torch.zeros_like(tensor)\n return [tensor.clone()], (\"dense\", shape)\n\n k = max(1, int(numel * self.compress_ratio))\n _, idx = torch.topk(flat.abs(), k, sorted=False)\n vals = flat[idx]\n topk_rec = torch.zeros_like(flat)\n topk_rec.scatter_(0, idx, vals)\n best_pack, best_ctx, best_rec = [vals, idx], (\"topk\", numel, shape), topk_rec\n best_err = (flat - topk_rec).pow(2).sum()\n\n if tensor.ndim >= 2:\n mat = tensor.reshape(tensor.shape[0], -1)\n m, n = mat.shape\n base_cost = 2 * k # TopK communicates values + indices.\n lr_cost = m + n # rank-1 communicates two dense factors.\n if min(m, n) > 1 and lr_cost <= base_cost:\n q = self.proj.get(name)\n if q is None or q.numel() != n or q.device != flat.device or q.dtype != flat.dtype:\n q = torch.randn(n, device=flat.device, dtype=flat.dtype)\n q = q / (q.norm() + 1e-12)\n u = mat @ q\n if u.norm() > 0:\n u = u / (u.norm() + 1e-12)\n v = mat.t() @ u\n approx = torch.outer(u, v).reshape(-1)\n self.proj[name] = (v / (v.norm() + 1e-12)).detach()\n\n tail = flat - approx\n tail_k = max(0, (base_cost - lr_cost) // 2)\n if tail_k > 0:\n _, ridx = torch.topk(tail.abs(), min(tail_k, numel), sorted=False)\n rvals = tail[ridx]\n tail_rec = torch.zeros_like(flat)\n tail_rec.scatter_(0, ridx, rvals)\n else:\n rvals = torch.empty(0, device=flat.device, dtype=flat.dtype)\n ridx = torch.empty(0, device=flat.device, dtype=torch.long)\n tail_rec = torch.zeros_like(flat)\n\n hyb_rec = approx + tail_rec\n hyb_err = (flat - hyb_rec).pow(2).sum()\n if hyb_err < best_err:\n best_pack = [u, v, rvals, ridx]\n best_ctx = (\"hybrid\", numel, shape)\n best_rec = hyb_rec\n best_err = hyb_err\n else:\n self.proj.pop(name, None)\n else:\n self.proj.pop(name, None)\n\n self.residuals[name] = (flat - best_rec).view(shape).detach()\n return best_pack, best_ctx\n\n def decompress(self, compressed_tensors, ctx):\n mode = ctx[0]\n if mode == \"dense\":\n return compressed_tensors[0].view(ctx[1])\n if mode == \"topk\":\n _, numel, shape = ctx\n vals, idx = compressed_tensors\n out = torch.zeros(numel, dtype=vals.dtype, device=vals.device)\n out.scatter_(0, idx, vals)\n return out.view(shape)\n _, numel, shape = ctx\n u, v, rvals, ridx = compressed_tensors\n out = torch.outer(u, v).reshape(-1)\n if rvals.numel() > 0:\n tail = torch.zeros(numel, dtype=out.dtype, device=out.device)\n tail.scatter_(0, ridx, rvals)\n out = out + tail\n return out.view(shape)\n",
  "after_line": 182,
  "start_line": 183,
  "end_line": 232
}
OK: Replaced lines 183..232 with 96 line(s) in pytorch-vision/custom_compressor.py. Editable range: 182–278.
[Current file: pytorch-vision/custom_compressor.py | editable: 182–278 | total: 419 lines]
179| # ============================================================================
180| # EDITABLE SECTION — Gradient Compressor (lines 182-232)
181| # ============================================================================
182
183 class Compressor:
184     """Structure-adaptive error-feedback compressor.
185
186     Core idea: compare a TopK sparse update against a same-budget hybrid
187     low-rank+TopK update, while leaving tiny / 1D tensors uncompressed.
188     This preserves sensitive BN/bias gradients, exploits matrix structure in
189     large conv/linear layers, and still uses error feedback to correct the
190     remaining compression error.
191     """
192
193     def __init__(self, compress_ratio=0.01):
194
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
## Seed 42
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.001161 train_loss=0.0209 train_acc=99.56
TRAIN_METRICS epoch=197 lr=0.001103 train_loss=0.0205 train_acc=99.59
TRAIN_METRICS epoch=198 lr=0.001058 train_loss=0.0200 train_acc=99.60
TRAIN_METRICS epoch=199 lr=0.001026 train_loss=0.0188 train_acc=99.65
TRAIN_METRICS epoch=200 lr=0.001006 train_loss=0.0187 train_acc=99.64
test_acc=92.59 test_loss=0.2759
Final metrics (resnet20-cifar10): test_acc_resnet20-cifar10=92.5900, best_acc_resnet20-cifar10=92.5900, test_loss_resnet20-cifar10=0.2759
### vgg11-cifar100 (scripts/vgg11_cifar100.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.000580 train_loss=0.0082 train_acc=99.88
TRAIN_METRICS epoch=197 lr=0.000551 train_loss=0.0078 train_acc=99.91
TRAIN_METRICS epoch=198 lr=0.000529 train_loss=0.0081 train_acc=99.89
TRAIN_METRICS epoch=199 l
Tool Input
{
  "op": "replace",
  "filename": "pytorch-vision/custom_compressor.py",
  "content": "class Compressor:\n \"\"\"Structure-adaptive error-feedback compressor.\n\n It compares a sparse TopK update against a same-budget hybrid made of a\n low-rank matrix approximation plus a sparse residual tail. Small tensors,\n normalization/bias vectors, and narrow classifier heads are sent dense to\n avoid hurting statistically sensitive parameters.\n \"\"\"\n\n def __init__(self, compress_ratio=0.01):\n self.compress_ratio = compress_ratio\n self.residuals = {}\n self.proj = {}\n self.small_dense = 2048\n self.head_dense = 65536\n self.max_rank = 4\n\n def compress(self, tensor, name):\n tensor = tensor + self.residuals.get(name, 0)\n shape = tensor.shape\n flat = tensor.reshape(-1)\n numel = flat.numel()\n dense_head = tensor.ndim == 2 and tensor.shape[0] <= 128 and numel <= self.head_dense\n\n if tensor.ndim == 1 or numel <= self.small_dense or dense_head:\n self.residuals[name] = torch.zeros_like(tensor)\n self.proj.pop(name, None)\n return [tensor.clone()], (\"dense\", shape)\n\n k = max(1, int(numel * self.compress_ratio))\n _, idx = torch.topk(flat.abs(), k, sorted=False)\n vals = flat[idx]\n topk_rec = torch.zeros_like(flat)\n topk_rec.scatter_(0, idx, vals)\n best_pack, best_ctx, best_rec = [vals, idx], (\"topk\", numel, shape), topk_rec\n best_err = (flat - topk_rec).pow(2).sum()\n\n if tensor.ndim >= 2:\n mat = tensor.reshape(tensor.shape[0], -1)\n m, n = mat.shape\n base_cost = 2 * k # values + indices for TopK.\n rank = min(self.max_rank, min(m, n), max(1, (base_cost // 2) // max(1, m + n)))\n lr_cost = rank * (m + n)\n if rank > 0 and lr_cost <= base_cost:\n Q = self.proj.get(name)\n if Q is None or Q.shape != (n, rank) or Q.device != flat.device or Q.dtype != flat.dtype:\n Q = torch.randn(n, rank, device=flat.device, dtype=flat.dtype)\n Q = torch.linalg.qr(Q, mode='reduced')[0]\n P = mat @ Q\n if P.norm() > 0:\n P = torch.linalg.qr(P, mode='reduced')[0]\n Q_send = mat.t() @ P\n approx = (P @ Q_send.t()).reshape(-1)\n if Q_send.norm() > 0:\n self.proj[name] = torch.linalg.qr(Q_send, mode='reduced')[0].detach()\n else:\n self.proj.pop(name, None)\n\n tail = flat - approx\n tail_k = max(0, (base_cost - lr_cost) // 2)\n if tail_k > 0:\n _, ridx = torch.topk(tail.abs(), min(tail_k, numel), sorted=False)\n rvals = tail[ridx]\n tail_rec = torch.zeros_like(flat)\n tail_rec.scatter_(0, ridx, rvals)\n else:\n rvals = torch.empty(0, device=flat.device, dtype=flat.dtype)\n ridx = torch.empty(0, device=flat.device, dtype=torch.long)\n tail_rec = torch.zeros_like(flat)\n\n hyb_rec = approx + tail_rec\n hyb_err = (flat - hyb_rec).pow(2).sum()\n if hyb_err < best_err:\n best_pack = [P, Q_send, rvals, ridx]\n best_ctx = (\"hybrid\", numel, shape)\n best_rec = hyb_rec\n best_err = hyb_err\n else:\n self.proj.pop(name, None)\n else:\n self.proj.pop(name, None)\n\n self.residuals[name] = (flat - best_rec).view(shape).detach()\n return best_pack, best_ctx\n\n def decompress(self, compressed_tensors, ctx):\n mode = ctx[0]\n if mode == \"dense\":\n return compressed_tensors[0].view(ctx[1])\n if mode == \"topk\":\n _, numel, shape = ctx\n vals, idx = compressed_tensors\n out = torch.zeros(numel, dtype=vals.dtype, device=vals.device)\n out.scatter_(0, idx, vals)\n return out.view(shape)\n _, numel, shape = ctx\n P, Q_send, rvals, ridx = compressed_tensors\n out = (P @ Q_send.t()).reshape(-1)\n if rvals.numel() > 0:\n tail = torch.zeros(numel, dtype=out.dtype, device=out.device)\n tail.scatter_(0, ridx, rvals)\n out = out + tail\n return out.view(shape)\n",
  "after_line": 182,
  "start_line": 183,
  "end_line": 278
}
OK: Replaced lines 183..278 with 103 line(s) in pytorch-vision/custom_compressor.py. Editable range: 182–285.
[Current file: pytorch-vision/custom_compressor.py | editable: 182–285 | total: 426 lines]
179| # ============================================================================
180| # EDITABLE SECTION — Gradient Compressor (lines 182-232)
181| # ============================================================================
182
183 class Compressor:
184     """Structure-adaptive error-feedback compressor.
185
186     It compares a sparse TopK update against a same-budget hybrid made of a
187     low-rank matrix approximation plus a sparse residual tail. Small tensors,
188     normalization/bias vectors, and narrow classifier heads are sent dense to
189     avoid hurting statistically sensitive parameters.
190     """
191
192     def __init__(self, compress_ratio=0.01):
193         self.compress_ratio = compress_ratio
19
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.001161 train_loss=0.0205 train_acc=99.58
TRAIN_METRICS epoch=197 lr=0.001103 train_loss=0.0205 train_acc=99.56
TRAIN_METRICS epoch=198 lr=0.001058 train_loss=0.0195 train_acc=99.65
TRAIN_METRICS epoch=199 lr=0.001026 train_loss=0.0189 train_acc=99.64
TRAIN_METRICS epoch=200 lr=0.001006 train_loss=0.0177 train_acc=99.69
test_acc=92.72 test_loss=0.2665
Final metrics (resnet20-cifar10): test_acc_resnet20-cifar10=92.7200, best_acc_resnet20-cifar10=92.7200, test_loss_resnet20-cifar10=0.2665
### vgg11-cifar100 (scripts/vgg11_cifar100.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.000580 train_loss=0.0085 train_acc=99.87
TRAIN_METRICS epoch=197 lr=0.000551 train_loss=0.0080 train_acc=99.91
TRAIN_METRICS epoch=198 lr=0.000529 train_loss=0.0081 train_acc=99.89
TRAIN_METRICS epoch=199 lr=0.000513 tra
[Test #3] (0 tests remaining)
## Seed 42
### resnet20-cifar10 (scripts/resnet20_cifar10.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.001161 train_loss=0.0209 train_acc=99.59
TRAIN_METRICS epoch=197 lr=0.001103 train_loss=0.0207 train_acc=99.61
TRAIN_METRICS epoch=198 lr=0.001058 train_loss=0.0211 train_acc=99.56
TRAIN_METRICS epoch=199 lr=0.001026 train_loss=0.0200 train_acc=99.61
TRAIN_METRICS epoch=200 lr=0.001006 train_loss=0.0188 train_acc=99.66
test_acc=92.85 test_loss=0.2709
Final metrics (resnet20-cifar10): test_acc_resnet20-cifar10=92.8500, best_acc_resnet20-cifar10=92.8500, test_loss_resnet20-cifar10=0.2709
### vgg11-cifar100 (scripts/vgg11_cifar100.sh)
Training metrics (last steps):
TRAIN_METRICS epoch=196 lr=0.000580 train_loss=0.0082 train_acc=99.90
TRAIN_METRICS epoch=197 lr=0.000551 train_loss=0.0078 train_acc=99.90
TRAIN_METRICS epoch=198 lr=0.000529 train_loss=0.0080 train_acc=99.89
TRAIN_METRICS epoch=199 lr=0.000513 train_loss=0.0076 train_acc=9