Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 9
Messages: 20
User (Step 0)

Task: ml-federated-aggregation

Federated Learning Aggregation Strategy Design

Research Question

Design a novel server-side aggregation strategy for federated learning that achieves better convergence and higher test accuracy under heterogeneous (non-IID) data distributions across clients.

Background

Federated Learning (FL) trains a shared global model across many clients without centralizing data. The canonical algorithm, FedAvg, simply averages client model parameters weighted by sample count. However, when client data distributions are heterogeneous (non-IID), FedAvg suffers from "client drift" where local updates diverge, leading to slow convergence or poor final accuracy. Research has produced several improvements: FedProx adds a proximal penalty to local objectives, SCAFFOLD uses control variates for variance reduction, and methods like FedNova normalize updates by local steps. The aggregation strategy — how the server combines client updates into the global model — is the core algorithmic component that determines convergence behavior.
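The FedAvg baseline described above can be sketched in a few lines (a minimal illustration using PyTorch-style state dicts; the `(state_dict, num_samples)` tuple shape is an assumption mirroring the interface below):

```python
from collections import OrderedDict

import torch


def fedavg(client_updates):
    """FedAvg: average client parameters, weighted by each client's sample count.

    client_updates: list of (state_dict, num_samples) tuples (assumed shape).
    """
    total = sum(n for _, n in client_updates)
    new_state = OrderedDict()
    for key in client_updates[0][0]:
        # Weighted sum of each client's parameter tensor for this layer.
        new_state[key] = sum(sd[key].float() * (n / total) for sd, n in client_updates)
    return new_state
```

Under non-IID splits, each client's local optimum differs, so this plain average drifts away from the global optimum; the improvements listed above (FedProx, SCAFFOLD, FedNova) all modify either the local objective or this aggregation step.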

Task

Modify the ServerAggregator class in custom_fl_aggregation.py. You must implement the aggregate() method that takes the current global model state, a list of client updates (model parameters + metadata), and returns the new global model state. You may also customize client selection via select_clients().

Interface

class ServerAggregator:
    def __init__(self, global_model, args):
        # Initialize aggregation state (momentum buffers, control variates, etc.)

    def aggregate(self, global_state_dict, client_updates, round_num):
        # global_state_dict: OrderedDict of current global model parameters
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples
        # round_num: current communication round (0-indexed)
        # Returns: OrderedDict of updated global model parameters

    def select_clients(self, num_available, num_to_select, round_num):
        # Returns: list of client indices to participate this round
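A minimal implementation satisfying this interface is plain FedAvg with uniform random client selection (an illustrative baseline, not the strategy submitted in this conversation):

```python
import random
from collections import OrderedDict

import torch


class ServerAggregator:
    """Baseline FedAvg aggregator conforming to the interface above (sketch)."""

    def __init__(self, global_model, args):
        self.args = args  # plain FedAvg keeps no server-side state

    def aggregate(self, global_state_dict, client_updates, round_num):
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples.
        total = sum(n for _, n, _ in client_updates)
        new_state = OrderedDict()
        for key, val in global_state_dict.items():
            avg = sum(sd[key].float() * (n / total) for sd, n, _ in client_updates)
            new_state[key] = avg.to(val.dtype)  # preserve original dtype
        return new_state

    def select_clients(self, num_available, num_to_select, round_num):
        # Uniform sampling without replacement.
        return random.sample(range(num_available), min(num_to_select, num_available))
```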

Evaluation

The aggregation strategy is evaluated on three benchmarks with non-IID data:

  1. CIFAR-10 with Dirichlet split (alpha=0.1) — 100 clients, image classification
  2. FEMNIST (EMNIST ByClass) with Dirichlet split — 100 clients, character recognition
  3. Shakespeare (next character prediction) — naturally non-IID by speaker

Metric: test accuracy after 200 communication rounds (higher is better). Each round, 10 of the 100 clients are selected; each trains for 5 local epochs with SGD (lr=0.01).
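The Dirichlet split used in benchmarks 1 and 2 is a standard construction for simulating non-IID clients; a sketch (the helper name and seed are ours, the alpha value follows the setup above):

```python
import numpy as np


def dirichlet_partition(labels, num_clients=100, alpha=0.1, seed=0):
    """Assign sample indices to clients via a per-class Dirichlet split.

    Small alpha (e.g. 0.1) concentrates each class on a few clients
    (highly non-IID); large alpha approaches an IID split.
    """
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Fraction of this class that each client receives.
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for cid, part in enumerate(np.split(idx, cuts)):
            client_indices[cid].extend(part.tolist())
    return client_indices
```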

flower/custom_fl_aggregation.py [EDITABLE — lines 353–424 only]

     1: # Custom federated learning aggregation strategy for MLS-Bench
     2: #
     3: # EDITABLE section: ServerAggregator class (aggregate method + helpers).
     4: # FIXED sections: everything else (config, data partitioning, client training,
     5: #                 FL simulation loop, evaluation).
     6: import argparse
     
Assistant (Step 1, edit)
Tool Input
1{
2  "content": "class ServerAggregator:\n    \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n    def __init__(self, global_model, args):\n        self.args = args\n        self.momentum = None\n        \n        # Hyperparameters for FedNL-EMA\n        # gamma: Loss-aware scale. Upweights clients with high loss (rare/hard strata).\n        self.gamma = 1.0   \n        # beta: Server EMA momentum. Filters non-IID jitter and aligns global direction.\n        self.beta = 0.5    \n\n    def aggregate(self, global_state_dict, client_updates, round_num):\n        # Extract losses to compute relative client performance\n        losses = [upd[2] for upd in client_updates]\n        max_loss, min_loss = max(losses), min(losses)\n\n        # 1. Compute Loss-Aware Weights and Local Step Counts (taus)\n        raw_weights = []\n        taus = []\n        for state_dict, n_samples, avg_loss in client_updates:\n            # Normalize loss to [0, 1] to bound the exponential scaling\n            if max_loss > min_loss:\n                norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n            else:\n                norm_loss = 0.5\n\n            # Exponential loss-aware weighting: upweight worst-performing clients\n            weight = n_samples * np.exp(self.gamma * norm_loss)\n            raw_weights.append(weight)\n\n            # Compute actual local steps (tau_i) taken by the client\n            batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n            tau = self.args.local_epochs * max(1, batches)\n            taus.append(tau)\n\n        # Normalize weights to sum to 1\n        total_weight = sum(raw_weights)\n        p = [w / total_weight for w in raw_weights]\n\n        # Effective total tau (FedNova normalization factor)\n        tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n        # Initialize momentum buffers on the first round\n        if self.momentum is None:\n            
self.momentum = OrderedDict()\n            for key, val in global_state_dict.items():\n                if val.is_floating_point():\n                    self.momentum[key] = torch.zeros_like(val).float()\n\n        new_state = OrderedDict()\n        for key in global_state_dict.keys():\n            # Integer layers (e.g., BN track_running_stats) cannot use momentum\n            if not global_state_dict[key].is_floating_point():\n                avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n                new_state[key] = avg_val.to(global_state_dict[key].dtype)\n                continue\n\n            # 2. Compute normalized update (FedNova formulation)\n            nova_update = torch.zeros_like(global_state_dict[key]).float()\n            for i, (state_dict, _, _) in enumerate(client_updates):\n                delta = (state_dict[key].float() - global_state_dict[key].float())\n                nova_update += p[i] * (delta / taus[i])\n\n            # Scale up by effective tau to recover update magnitude\n            aggr_delta = tau_eff * nova_update\n\n            # 3. Server-side Exponential Moving Average to filter client drift noise\n            self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n            \n            # Unbiased EMA calculation (Adam-style bias correction)\n            m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n            # Apply robust update\n            new_state[key] = global_state_dict[key].float() + m_hat\n            \n            # Keep original dtype (in case it was half or double, though unlikely here)\n            new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n        return new_state\n\n    def select_clients(self, num_available, num_to_select, round_num):\n        return random.sample(range(num_available), min(num_to_select, num_available))",
3  "end_line": 424,
4  "op": "replace",
5  "start_line": 353,
6  "filename": "custom_fl_aggregation.py"
7}
Tool Result (Step 1)
ERROR: Package 'custom_fl_aggregation.py' is not in allowed packages
Assistant (Step 2, edit)
Tool Input
1{
2  "filename": "flower/custom_fl_aggregation.py",
3  "op": "replace",
4  "end_line": 424,
5  "content": "class ServerAggregator:\n    \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n    def __init__(self, global_model, args):\n        self.args = args\n        self.momentum = None\n        \n        # Hyperparameters for FedNL-EMA\n        self.gamma = 1.0   \n        self.beta = 0.5    \n\n    def aggregate(self, global_state_dict, client_updates, round_num):\n        losses = [upd[2] for upd in client_updates]\n        max_loss, min_loss = max(losses), min(losses)\n\n        raw_weights = []\n        taus = []\n        for state_dict, n_samples, avg_loss in client_updates:\n            if max_loss > min_loss:\n                norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n            else:\n                norm_loss = 0.5\n\n            weight = n_samples * np.exp(self.gamma * norm_loss)\n            raw_weights.append(weight)\n\n            batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n            tau = self.args.local_epochs * max(1, batches)\n            taus.append(tau)\n\n        total_weight = sum(raw_weights)\n        p = [w / total_weight for w in raw_weights]\n\n        tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n        if self.momentum is None:\n            self.momentum = OrderedDict()\n            for key, val in global_state_dict.items():\n                if val.is_floating_point():\n                    self.momentum[key] = torch.zeros_like(val).float()\n\n        new_state = OrderedDict()\n        for key in global_state_dict.keys():\n            if not global_state_dict[key].is_floating_point():\n                avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n                new_state[key] = avg_val.to(global_state_dict[key].dtype)\n                continue\n\n            nova_update = torch.zeros_like(global_state_dict[key]).float()\n            for i, (state_dict, _, _) in 
enumerate(client_updates):\n                delta = (state_dict[key].float() - global_state_dict[key].float())\n                nova_update += p[i] * (delta / taus[i])\n\n            aggr_delta = tau_eff * nova_update\n\n            self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n            \n            m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n            new_state[key] = global_state_dict[key].float() + m_hat\n            new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n        return new_state\n\n    def select_clients(self, num_available, num_to_select, round_num):\n        return random.sample(range(num_available), min(num_to_select, num_available))",
6  "start_line": 353
7}
Tool Result (Step 2)
OK: Replaced lines 353..424 with 66 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–418.

[Current file: flower/custom_fl_aggregation.py | editable: 353–418 | total: 515 lines]
   353  class ServerAggregator:
   354      """Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)"""
   355  
       ... (60 more lines) ...
   416  
   417      def select_clients(self, num_available, num_to_select, round_num):
   418          return random.sample(range(num_available), min(num_to_select, num_available))
Assistant (Step 3, test)
Tool Result (Step 3)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0393 round_time=3.3s
TRAIN_METRICS round=185 avg_loss=0.0408 round_time=4.0s
TRAIN_METRICS round=190 avg_loss=0.0142 round_time=3.1s
TRAIN_METRICS round=195 avg_loss=0.0200 round_time=3.6s
TRAIN_METRICS round=200 avg_loss=0.0181 round_time=3.4s
Evaluation progress:
EVAL round=160 test_loss=1.3838 test_accuracy=0.6452 elapsed=569s
EVAL round=170 test_loss=1.3273 test_accuracy=0.6520 elapsed=608s
EVAL round=180 test_loss=1.3191 test_accuracy=0.6404 elapsed=638s
EVAL round=190 test_loss=1.2580 test_accuracy=0.6659 elapsed=674s
EVAL round=200 test_loss=1.3960 test_accuracy=0.6620 elapsed=711s
Final metrics (cifar10):
test_accuracy_cifar10: 0.6620
test_loss_cifar10: 1.3960
best_accuracy_cifar10: 0.6659

### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_los
Assistant (Step 4, edit)
Tool Input
1{
2  "end_line": 418,
3  "start_line": 353,
4  "filename": "flower/custom_fl_aggregation.py",
5  "content": "class ServerAggregator:\n    \"\"\"Federated Wiener filtering with Adaptive Loss Tracking (FedWALT).\n    \n    1. Loss-aware weighting: priorities clients poorly represented by the global model.\n    2. Cross-client Variance Filtering (Wiener Filter): Element-wise damping of parameters \n       that have high disagreement among clients, suppressing variance from non-IID drift.\n    3. Server-side Momentum: Smooths updates across rounds.\n    \"\"\"\n\n    def __init__(self, global_model, args):\n        self.args = args\n        self.exp_avg_mean = None\n        self.exp_avg_var = None\n\n        self.beta = 0.9    # EMA rate (0.9 gives a ~10 round memory)\n        self.gamma = 2.0   # Variance penalty weight. Higher = stricter suppression.\n        self.alpha = 1.0   # Loss-aware scaling factor.\n        self.server_lr = 1.0\n        self.eps = 1e-6\n\n    def aggregate(self, global_state_dict, client_updates, round_num):\n        losses = [upd[2] for upd in client_updates]\n        min_L, max_L = min(losses), max(losses)\n\n        raw_weights = []\n        for upd in client_updates:\n            n_samples = upd[1]\n            avg_loss = upd[2]\n            norm_l = (avg_loss - min_L) / (max_L - min_L + 1e-9)\n            w = n_samples * np.exp(self.alpha * norm_l)\n            raw_weights.append(w)\n\n        total_w = sum(raw_weights)\n        p = [w / total_w for w in raw_weights]\n\n        if self.exp_avg_mean is None:\n            self.exp_avg_mean = OrderedDict()\n            self.exp_avg_var = OrderedDict()\n            for key, val in global_state_dict.items():\n                if val.is_floating_point():\n                    self.exp_avg_mean[key] = torch.zeros_like(val).float()\n                    self.exp_avg_var[key] = torch.zeros_like(val).float()\n\n        new_state = OrderedDict()\n        for key in global_state_dict.keys():\n            if not global_state_dict[key].is_floating_point():\n                avg_val = 
sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n                new_state[key] = avg_val.to(global_state_dict[key].dtype)\n                continue\n\n            mean_delta = torch.zeros_like(global_state_dict[key]).float()\n            client_deltas = []\n            for i, upd in enumerate(client_updates):\n                delta = upd[0][key].float() - global_state_dict[key].float()\n                client_deltas.append(delta)\n                mean_delta += p[i] * delta\n\n            spatial_var = torch.zeros_like(global_state_dict[key]).float()\n            for i in range(len(client_updates)):\n                spatial_var += p[i] * ((client_deltas[i] - mean_delta) ** 2)\n\n            self.exp_avg_mean[key] = self.beta * self.exp_avg_mean[key] + (1.0 - self.beta) * mean_delta\n            self.exp_avg_var[key] = self.beta * self.exp_avg_var[key] + (1.0 - self.beta) * spatial_var\n\n            bias_corr = 1.0 - (self.beta ** (round_num + 1))\n            m_hat = self.exp_avg_mean[key] / bias_corr\n            v_hat = self.exp_avg_var[key] / bias_corr\n\n            # Wiener filter scaling\n            signal = m_hat ** 2\n            scale = signal / (signal + self.gamma * v_hat + self.eps)\n\n            update = self.server_lr * m_hat * scale\n\n            new_state[key] = global_state_dict[key].float() + update\n            new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n        return new_state\n\n    def select_clients(self, num_available, num_to_select, round_num):\n        return random.sample(range(num_available), min(num_to_select, num_available))",
6  "op": "replace"
7}
Tool Result (Step 4)
OK: Replaced lines 353..418 with 81 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–433.

[Current file: flower/custom_fl_aggregation.py | editable: 353–433 | total: 530 lines]
   353  class ServerAggregator:
   354      """Federated Wiener filtering with Adaptive Loss Tracking (FedWALT).
   355      
       ... (75 more lines) ...
   431  
   432      def select_clients(self, num_available, num_to_select, round_num):
   433          return random.sample(range(num_available), min(num_to_select, num_available))
Assistant (Step 5, test)
Tool Result (Step 5)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### cifar10 (scripts/cifar10.sh)
[COMMAND FAILED exit=1]
Device: cuda
Loading dataset: cifar10
Number of clients: 100
Traceback (most recent call last):
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 530, in <module>
    run_fl_simulation(args)
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 476, in run_fl_simulation
    updated_state, n_samples, avg_loss = client_local_train(
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 274, in client_local_train
    model.to(device)
  File "/srv/home/bohanlyu/miniconda3/envs/mlsbench-flower/lib/python3.10/site-p
Assistant (Step 6, edit)
Tool Input
1{
2  "start_line": 353,
3  "content": "class ServerAggregator:\n    \"\"\"Federated Agreement-Aware Exponential Averaging (FedAA-EMA).\n    \n    1. Loss-aware weighting: upweights clients with high loss (hard strata).\n    2. FedNova normalization: corrects local step heterogeneity.\n    3. Agreement-aware scaling: scales down updates where heterogeneous clients disagree.\n    4. Server-Side EMA: filters non-IID jitter and aligns global direction.\n    \"\"\"\n\n    def __init__(self, global_model, args):\n        self.args = args\n        self.momentum = None\n        \n        # Method Hyperparameters\n        self.gamma = 1.0       # Loss-aware weight scale\n        self.beta = 0.5        # Server EMA momentum\n        self.kappa = 0.5       # Agreement strictness (0.5 = sqrt scaling)\n        self.eps = 1e-6\n\n    def aggregate(self, global_state_dict, client_updates, round_num):\n        losses = [upd[2] for upd in client_updates]\n        max_loss, min_loss = max(losses), min(losses)\n\n        raw_weights = []\n        taus = []\n        for state_dict, n_samples, avg_loss in client_updates:\n            # Loss-aware dynamic weighting\n            if max_loss > min_loss:\n                norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n            else:\n                norm_loss = 0.5\n            weight = n_samples * np.exp(self.gamma * norm_loss)\n            raw_weights.append(weight)\n\n            # Local steps calculation (FedNova)\n            batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n            tau = self.args.local_epochs * max(1, batches)\n            taus.append(tau)\n\n        total_weight = sum(raw_weights)\n        p = [w / total_weight for w in raw_weights]\n        tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n        # Initialize momentum (EMA buffer)\n        if self.momentum is None:\n            self.momentum = OrderedDict()\n            for key, val in global_state_dict.items():\n               
 if val.is_floating_point():\n                    self.momentum[key] = torch.zeros_like(val).cpu()\n\n        new_state = OrderedDict()\n        for key in global_state_dict.keys():\n            global_val = global_state_dict[key]\n            \n            # Non-floating point tensors (e.g., BN num_batches_tracked)\n            if not global_val.is_floating_point():\n                avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n                new_state[key] = avg_val.to(global_val.dtype)\n                continue\n\n            # Core float parameters\n            aggr_delta = torch.zeros_like(global_val).float().cpu()\n            abs_delta_sum = torch.zeros_like(global_val).float().cpu()\n            \n            for i, (state_dict, _, _) in enumerate(client_updates):\n                local_val = state_dict[key].float().cpu()\n                g_val = global_val.float().cpu()\n                \n                # Per-step local direction\n                delta_normed = (local_val - g_val) / taus[i]\n                \n                aggr_delta.add_(delta_normed, alpha=p[i])\n                abs_delta_sum.add_(torch.abs(delta_normed), alpha=p[i])\n\n            # Measure parameter-wise agreement among diverse client updates\n            agreement = torch.abs(aggr_delta) / (abs_delta_sum + self.eps)\n            \n            # Suppress updates in highly conflicting dimensions + apply effective scale\n            aggr_delta = tau_eff * aggr_delta * (agreement ** self.kappa)\n\n            # Server momentum to smooth trajectory\n            self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n            bias_corr = 1.0 - (self.beta ** (round_num + 1))\n            m_hat = self.momentum[key] / bias_corr\n\n            # Step global model\n            new_val = global_val.float().cpu() + m_hat\n            new_state[key] = new_val.to(global_val.dtype)\n\n        # Clear memory safely\n   
     if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n        return new_state\n\n    def select_clients(self, num_available, num_to_select, round_num):\n        return random.sample(range(num_available), min(num_to_select, num_available))",
4  "filename": "flower/custom_fl_aggregation.py",
5  "op": "replace",
6  "end_line": 433
7}
Tool Result (Step 6)
OK: Replaced lines 353..433 with 97 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–449.

[Current file: flower/custom_fl_aggregation.py | editable: 353–449 | total: 546 lines]
   353  class ServerAggregator:
   354      """Federated Agreement-Aware Exponential Averaging (FedAA-EMA).
   355      
       ... (91 more lines) ...
   447  
   448      def select_clients(self, num_available, num_to_select, round_num):
   449          return random.sample(range(num_available), min(num_to_select, num_available))
Assistant (Step 7, test)
Tool Result (Step 7)
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)

[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### cifar10 (scripts/cifar10.sh)
[COMMAND FAILED exit=1]
Device: cuda
Loading dataset: cifar10
Number of clients: 100
Traceback (most recent call last):
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 546, in <module>
    run_fl_simulation(args)
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 492, in run_fl_simulation
    updated_state, n_samples, avg_loss = client_local_train(
  File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 274,
Assistant (Step 8, edit)
Tool Input
1{
2  "filename": "flower/custom_fl_aggregation.py",
3  "end_line": 449,
4  "op": "replace",
5  "content": "class ServerAggregator:\n    \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n    def __init__(self, global_model, args):\n        self.args = args\n        self.momentum = None\n        \n        # Hyperparameters for FedNL-EMA\n        self.gamma = 1.0   \n        self.beta = 0.5    \n\n    def aggregate(self, global_state_dict, client_updates, round_num):\n        losses = [upd[2] for upd in client_updates]\n        max_loss, min_loss = max(losses), min(losses)\n\n        raw_weights = []\n        taus = []\n        for state_dict, n_samples, avg_loss in client_updates:\n            if max_loss > min_loss:\n                norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n            else:\n                norm_loss = 0.5\n\n            weight = n_samples * np.exp(self.gamma * norm_loss)\n            raw_weights.append(weight)\n\n            batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n            tau = self.args.local_epochs * max(1, batches)\n            taus.append(tau)\n\n        total_weight = sum(raw_weights)\n        p = [w / total_weight for w in raw_weights]\n\n        tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n        if self.momentum is None:\n            self.momentum = OrderedDict()\n            for key, val in global_state_dict.items():\n                if val.is_floating_point():\n                    self.momentum[key] = torch.zeros_like(val).float()\n\n        new_state = OrderedDict()\n        for key in global_state_dict.keys():\n            if not global_state_dict[key].is_floating_point():\n                avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n                new_state[key] = avg_val.to(global_state_dict[key].dtype)\n                continue\n\n            nova_update = torch.zeros_like(global_state_dict[key]).float()\n            for i, (state_dict, _, _) in 
enumerate(client_updates):\n                delta = (state_dict[key].float() - global_state_dict[key].float())\n                nova_update += p[i] * (delta / taus[i])\n\n            aggr_delta = tau_eff * nova_update\n\n            self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n            \n            m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n            new_state[key] = global_state_dict[key].float() + m_hat\n            new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n        return new_state\n\n    def select_clients(self, num_available, num_to_select, round_num):\n        return random.sample(range(num_available), min(num_to_select, num_available))",
6  "start_line": 353
7}
Tool Result (Step 8)
OK: Replaced lines 353..449 with 66 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–418.

[Current file: flower/custom_fl_aggregation.py | editable: 353–418 | total: 515 lines]
   353  class ServerAggregator:
   354      """Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)"""
   355  
       ... (60 more lines) ...
   416  
   417      def select_clients(self, num_available, num_to_select, round_num):
   418          return random.sample(range(num_available), min(num_to_select, num_available))
Assistant (Step 9, submit)
Tool Input
1{
2  "n": 1
3}
Tool Result (Step 9)
[submit] Submitting result from test #1 as final.

### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0393 round_time=3.3s
TRAIN_METRICS round=185 avg_loss=0.0408 round_time=4.0s
TRAIN_METRICS round=190 avg_loss=0.0142 round_time=3.1s
TRAIN_METRICS round=195 avg_loss=0.0200 round_time=3.6s
TRAIN_METRICS round=200 avg_loss=0.0181 round_time=3.4s
Evaluation progress:
EVAL round=160 test_loss=1.3838 test_accuracy=0.6452 elapsed=569s
EVAL round=170 test_loss=1.3273 test_accuracy=0.6520 elapsed=608s
EVAL round=180 test_loss=1.3191 test_accuracy=0.6404 elapsed=638s
EVAL round=190 test_loss=1.2580 test_accuracy=0.6659 elapsed=674s
EVAL round=200 test_loss=1.3960 test_accuracy=0.6620 elapsed=711s
Final metrics (cifar10):
test_accuracy_cifar10: 0.6620
test_loss_cifar10: 1.3960
best_accuracy_cifar10: 0.6659

### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0172 round_time=50.5s
TRAIN_METRICS round