# ml-federated-aggregation

## Description
Federated Learning Aggregation Strategy Design
## Research Question
Design a novel server-side aggregation strategy for federated learning that achieves faster convergence and higher final test accuracy than standard baselines (e.g., FedAvg) under heterogeneous (non-IID) data distributions across clients.
## Background
Federated Learning (FL) trains a shared global model across many clients without centralizing data. The canonical algorithm, FedAvg, simply averages client model parameters weighted by sample count. However, when client data distributions are heterogeneous (non-IID), FedAvg suffers from "client drift" where local updates diverge, leading to slow convergence or poor final accuracy. Research has produced several improvements: FedProx adds a proximal penalty to local objectives, SCAFFOLD uses control variates for variance reduction, and methods like FedNova normalize updates by local steps. The aggregation strategy — how the server combines client updates into the global model — is the core algorithmic component that determines convergence behavior.
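As a concrete reference point, FedAvg's aggregation step is just a sample-count-weighted average of client parameters. A minimal sketch (using NumPy arrays in place of real model tensors, for illustration only):

```python
from collections import OrderedDict

import numpy as np


def fedavg_aggregate(client_updates):
    """FedAvg: average client parameters weighted by sample count.

    client_updates: list of (state_dict, num_samples) tuples, where each
    state_dict maps parameter names to arrays of identical shapes.
    """
    total = sum(n for _, n in client_updates)
    new_state = OrderedDict()
    for key in client_updates[0][0]:
        new_state[key] = sum(sd[key] * (n / total) for sd, n in client_updates)
    return new_state
```

Under non-IID splits, each client's local optimum differs, so this plain average can sit far from any good global solution; that is the "client drift" the improved methods above target.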
## Task
Modify the `ServerAggregator` class in `custom_fl_aggregation.py`. You must implement the `aggregate()` method, which takes the current global model state and a list of client updates (model parameters plus metadata) and returns the new global model state. You may also customize client selection via `select_clients()`.
## Interface

```python
class ServerAggregator:
    def __init__(self, global_model, args):
        # Initialize aggregation state (momentum buffers, control variates, etc.)
        ...

    def aggregate(self, global_state_dict, client_updates, round_num):
        # global_state_dict: OrderedDict of current global model parameters
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples
        # round_num: current communication round (0-indexed)
        # Returns: OrderedDict of updated global model parameters
        ...

    def select_clients(self, num_available, num_to_select, round_num):
        # Returns: list of client indices to participate this round
        ...
```
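One way to fill in this interface is server-side momentum in the style of FedAvgM: treat the gap between the global model and the weighted average of client models as a pseudo-gradient and apply it with a momentum buffer. The sketch below is a hypothetical illustration, not the harness's reference code; it assumes parameters are NumPy arrays, and the `beta=0.9` coefficient is an assumed hyperparameter:

```python
import random
from collections import OrderedDict

import numpy as np


class ServerAggregator:
    """Illustrative aggregator: weighted averaging plus server momentum."""

    def __init__(self, global_model, args):
        self.beta = 0.9    # server momentum coefficient (assumed value)
        self.momentum = {}  # per-parameter momentum buffers

    def aggregate(self, global_state_dict, client_updates, round_num):
        total = sum(n for _, n, _ in client_updates)
        new_state = OrderedDict()
        for key, g in global_state_dict.items():
            # Sample-count-weighted average of client parameters
            avg = sum(sd[key] * (n / total) for sd, n, _ in client_updates)
            delta = g - avg  # server-side pseudo-gradient
            buf = self.beta * self.momentum.get(key, np.zeros_like(g)) + delta
            self.momentum[key] = buf
            new_state[key] = g - buf
        return new_state

    def select_clients(self, num_available, num_to_select, round_num):
        # Uniform random selection, seeded per round for reproducibility
        rng = random.Random(round_num)
        return rng.sample(range(num_available), num_to_select)
```

On the first round the momentum buffer is zero, so this reduces exactly to FedAvg; later rounds accumulate the update direction, which tends to damp client drift.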
## Evaluation
The aggregation strategy is evaluated on three benchmarks with non-IID data:
- CIFAR-10 with Dirichlet split (alpha=0.1) — 100 clients, image classification
- FEMNIST (EMNIST ByClass) with Dirichlet split — 100 clients, character recognition
- Shakespeare (next character prediction) — naturally non-IID by speaker
Metric: test accuracy after 200 communication rounds (higher is better). Each round, 10 of 100 clients are selected, and each trains for 5 local epochs with SGD (lr=0.01).
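The Dirichlet split used for CIFAR-10 and FEMNIST can be reproduced with a partition along these lines (a common recipe in the FL literature, not necessarily the harness's exact code); smaller `alpha` yields more skewed per-client label distributions:

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign sample indices to clients with per-class Dirichlet proportions."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Fraction of this class's samples given to each client
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices
```

With `alpha=0.1`, most clients end up holding samples from only a handful of classes, which is what stresses the aggregation strategy.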
## Code

```python
# Custom federated learning aggregation strategy for MLS-Bench
#
# EDITABLE section: ServerAggregator class (aggregate method + helpers).
# FIXED sections: everything else (config, data partitioning, client training,
# FL simulation loop, evaluation).
import argparse
import copy
import json
import os
import random
import time
from collections import OrderedDict
from pathlib import Path

import numpy as np
```
## Results
| Model | Type | test accuracy cifar10 ↑ | test loss cifar10 ↓ | best accuracy cifar10 ↑ | test accuracy shakespeare ↑ | test loss shakespeare ↓ | best accuracy shakespeare ↑ | test accuracy femnist ↑ | test loss femnist ↓ | best accuracy femnist ↑ |
|---|---|---|---|---|---|---|---|---|---|---|
| fedavg | baseline | 0.634 | 1.301 | 0.645 | 0.484 | 0.036 | 0.488 | 0.807 | 0.732 | 0.844 |
| fedprox | baseline | 0.630 | 1.309 | 0.649 | 0.483 | 0.036 | 0.488 | 0.811 | 0.682 | 0.846 |
| scaffold | baseline | 0.632 | 1.290 | 0.644 | 0.485 | 0.036 | 0.490 | 0.807 | 0.718 | 0.845 |
| anthropic/claude-opus-4.6 | vanilla | 0.654 | 1.701 | 0.669 | - | - | - | 0.808 | 0.884 | 0.838 |
| deepseek-reasoner | vanilla | - | - | - | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.662 | 1.396 | 0.666 | - | - | - | 0.808 | 0.695 | 0.842 |
| anthropic/claude-opus-4.6 | agent | 0.641 | 1.666 | 0.673 | 0.493 | 0.037 | 0.503 | 0.807 | 0.864 | 0.839 |
| deepseek-reasoner | agent | 0.404 | 11.439 | 0.405 | 0.404 | 0.052 | 0.404 | 0.789 | 2.834 | 0.793 |
| google/gemini-3.1-pro-preview | agent | 0.662 | 1.396 | 0.666 | - | - | - | 0.808 | 0.695 | 0.842 |