ml-federated-aggregation

Tags: Classical ML, flower, rigorous, codebase

Description

Federated Learning Aggregation Strategy Design

Research Question

Design a novel server-side aggregation strategy for federated learning that converges faster and reaches higher test accuracy than FedAvg under heterogeneous (non-IID) data distributions across clients.

Background

Federated Learning (FL) trains a shared global model across many clients without centralizing data. The canonical algorithm, FedAvg, simply averages client model parameters weighted by sample count. However, when client data distributions are heterogeneous (non-IID), FedAvg suffers from "client drift" where local updates diverge, leading to slow convergence or poor final accuracy. Research has produced several improvements: FedProx adds a proximal penalty to local objectives, SCAFFOLD uses control variates for variance reduction, and methods like FedNova normalize updates by local steps. The aggregation strategy — how the server combines client updates into the global model — is the core algorithmic component that determines convergence behavior.
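To make the FedAvg baseline concrete, here is a minimal sketch of its sample-count-weighted parameter average. The shapes and the `fedavg` helper name are illustrative, not part of the benchmark code; real state dicts would hold torch tensors, but NumPy keeps the sketch self-contained.

```python
import numpy as np
from collections import OrderedDict

def fedavg(client_updates):
    """FedAvg: average client state dicts, each weighted by its sample count."""
    total = sum(n for _, n in client_updates)
    new_state = OrderedDict()
    for k in client_updates[0][0]:
        new_state[k] = sum(sd[k] * (n / total) for sd, n in client_updates)
    return new_state

# Two toy clients, each contributing a single parameter tensor.
c1 = (OrderedDict(w=np.array([1.0, 1.0])), 30)  # 30 samples -> weight 0.75
c2 = (OrderedDict(w=np.array([3.0, 3.0])), 10)  # 10 samples -> weight 0.25
print(fedavg([c1, c2])["w"])  # pulled toward c1: [1.5 1.5]
```

Under non-IID splits this average can drift away from the true global optimum, which is exactly the failure mode the methods above (FedProx, SCAFFOLD, FedNova) try to correct.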

Task

Modify the ServerAggregator class in custom_fl_aggregation.py. You must implement the aggregate() method, which takes the current global model state and a list of client updates (model parameters plus metadata) and returns the new global model state. You may also customize client selection via select_clients().

Interface

class ServerAggregator:
    def __init__(self, global_model, args):
        # Initialize aggregation state (momentum buffers, control variates, etc.)

    def aggregate(self, global_state_dict, client_updates, round_num):
        # global_state_dict: OrderedDict of current global model parameters
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples
        # round_num: current communication round (0-indexed)
        # Returns: OrderedDict of updated global model parameters

    def select_clients(self, num_available, num_to_select, round_num):
        # Returns: list of client indices to participate this round
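One way to fill in this interface is FedAvg with server-side momentum (in the style of FedAvgM). The sketch below is an assumption-laden starting point, not the benchmark's reference implementation: the `server_momentum` and `server_lr` fields on `args` are hypothetical, and NumPy arrays stand in for the torch tensors the real harness would pass.

```python
import random
from collections import OrderedDict

class ServerAggregator:
    """Sketch: weighted-average client deltas, then apply server momentum."""

    def __init__(self, global_model, args):
        # Hypothetical config fields; defaults used if absent.
        self.momentum = getattr(args, "server_momentum", 0.9)
        self.server_lr = getattr(args, "server_lr", 1.0)
        self.velocity = None  # momentum buffer, lazily initialized

    def aggregate(self, global_state_dict, client_updates, round_num):
        total = sum(n for _, n, _ in client_updates)
        # Sample-count-weighted average of client deltas vs. the global model.
        avg_delta = OrderedDict(
            (k, sum((sd[k] - global_state_dict[k]) * (n / total)
                    for sd, n, _ in client_updates))
            for k in global_state_dict)
        if self.velocity is None:
            self.velocity = OrderedDict((k, 0.0 * v) for k, v in avg_delta.items())
        new_state = OrderedDict()
        for k in global_state_dict:
            self.velocity[k] = self.momentum * self.velocity[k] + avg_delta[k]
            new_state[k] = global_state_dict[k] + self.server_lr * self.velocity[k]
        return new_state

    def select_clients(self, num_available, num_to_select, round_num):
        # Uniform sampling without replacement, deterministic per round.
        rng = random.Random(round_num)
        return rng.sample(range(num_available), num_to_select)
```

Averaging deltas rather than raw parameters makes it easy to layer on server-side optimizers (momentum here; Adam-style variants are a natural extension), while leaving local training untouched.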

Evaluation

The aggregation strategy is evaluated on three benchmarks with non-IID data:

  1. CIFAR-10 with Dirichlet split (alpha=0.1) — 100 clients, image classification
  2. FEMNIST (EMNIST ByClass) with Dirichlet split — 100 clients, character recognition
  3. Shakespeare (next character prediction) — naturally non-IID by speaker

Metric: test accuracy after 200 communication rounds (higher is better). Each round, 10 of 100 clients are selected, each trains for 5 local epochs with SGD (lr=0.01).
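For intuition about how skewed a Dirichlet split with alpha=0.1 is, here is a sketch of the standard per-class Dirichlet partitioning scheme. The `dirichlet_partition` helper and its defaults are illustrative; the benchmark's fixed data-partitioning code may differ in details.

```python
import numpy as np

def dirichlet_partition(labels, num_clients=100, alpha=0.1, seed=0):
    """For each class, draw client proportions from Dir(alpha) and split that
    class's sample indices accordingly. Smaller alpha -> more skewed shards."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for cid, shard in enumerate(np.split(idx, cuts)):
            client_idx[cid].extend(shard.tolist())
    return client_idx

labels = np.repeat(np.arange(10), 500)  # toy set: 10 classes, 500 samples each
shards = dirichlet_partition(labels, num_clients=10, alpha=0.1)
print(sum(len(s) for s in shards))  # every sample lands on exactly one client
```

At alpha=0.1 most clients end up dominated by one or two classes, which is what makes the CIFAR-10 and FEMNIST settings above genuinely non-IID.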

Code

custom_fl_aggregation.py
# Custom federated learning aggregation strategy for MLS-Bench
#
# EDITABLE section: ServerAggregator class (aggregate method + helpers).
# FIXED sections: everything else (config, data partitioning, client training,
# FL simulation loop, evaluation).
import argparse
import copy
import json
import os
import random
import time
from collections import OrderedDict
from pathlib import Path

import numpy as np

Results

| Model | Type | CIFAR-10 test acc | CIFAR-10 test loss | CIFAR-10 best acc | Shakespeare test acc | Shakespeare test loss | Shakespeare best acc | FEMNIST test acc | FEMNIST test loss | FEMNIST best acc |
|---|---|---|---|---|---|---|---|---|---|---|
| fedavg | baseline | 0.634 | 1.301 | 0.645 | 0.484 | 0.036 | 0.488 | 0.807 | 0.732 | 0.844 |
| fedprox | baseline | 0.630 | 1.309 | 0.649 | 0.483 | 0.036 | 0.488 | 0.811 | 0.682 | 0.846 |
| scaffold | baseline | 0.632 | 1.290 | 0.644 | 0.485 | 0.036 | 0.490 | 0.807 | 0.718 | 0.845 |
| anthropic/claude-opus-4.6 | vanilla | 0.654 | 1.701 | 0.669 | - | - | - | 0.808 | 0.884 | 0.838 |
| deepseek-reasoner | vanilla | - | - | - | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | 0.662 | 1.396 | 0.666 | - | - | - | 0.808 | 0.695 | 0.842 |
| anthropic/claude-opus-4.6 | agent | 0.641 | 1.666 | 0.673 | 0.493 | 0.037 | 0.503 | 0.807 | 0.864 | 0.839 |
| deepseek-reasoner | agent | 0.404 | 11.439 | 0.405 | 0.404 | 0.052 | 0.404 | 0.789 | 2.834 | 0.793 |
| google/gemini-3.1-pro-preview | agent | 0.662 | 1.396 | 0.666 | - | - | - | 0.808 | 0.695 | 0.842 |

Agent Conversations