Agent Conversation: google/gemini-3.1-pro-preview
Task: ml-federated-aggregation
Federated Learning Aggregation Strategy Design
Research Question
Design a novel server-side aggregation strategy for federated learning that achieves faster convergence and higher final test accuracy than standard FedAvg under heterogeneous (non-IID) data distributions across clients.
Background
Federated Learning (FL) trains a shared global model across many clients without centralizing data. The canonical algorithm, FedAvg, simply averages client model parameters weighted by sample count. However, when client data distributions are heterogeneous (non-IID), FedAvg suffers from "client drift" where local updates diverge, leading to slow convergence or poor final accuracy. Research has produced several improvements: FedProx adds a proximal penalty to local objectives, SCAFFOLD uses control variates for variance reduction, and methods like FedNova normalize updates by local steps. The aggregation strategy — how the server combines client updates into the global model — is the core algorithmic component that determines convergence behavior.
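FedProx, for instance, changes only the local objective by pulling local weights back toward the global model. A minimal sketch of one local SGD step on flat lists of floats (the function name and values are illustrative, not from the benchmark harness):

```python
def fedprox_grad_step(w, w_global, grad, lr=0.01, mu=0.1):
    """One local SGD step on the FedProx objective
    f_i(w) + (mu/2) * ||w - w_global||^2, over flat lists of floats.
    The mu-term pulls local weights back toward the global model,
    limiting client drift under non-IID data."""
    return [wi - lr * (gi + mu * (wi - wgi))
            for wi, gi, wgi in zip(w, grad, w_global)]
```

With mu = 0, this reduces to plain local SGD; larger mu trades local fit for proximity to the global model.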
Task
Modify the ServerAggregator class in custom_fl_aggregation.py. You must implement the aggregate() method, which takes the current global model state and a list of client updates (model parameters plus metadata) and returns the new global model state. You may also customize client selection via select_clients().
Interface
class ServerAggregator:
    def __init__(self, global_model, args):
        # Initialize aggregation state (momentum buffers, control variates, etc.)

    def aggregate(self, global_state_dict, client_updates, round_num):
        # global_state_dict: OrderedDict of current global model parameters
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples
        # round_num: current communication round (0-indexed)
        # Returns: OrderedDict of updated global model parameters

    def select_clients(self, num_available, num_to_select, round_num):
        # Returns: list of client indices to participate this round
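A minimal strategy conforming to this interface is plain FedAvg. The sketch below uses dicts of Python floats in place of torch tensors so it stands alone; the real harness passes OrderedDicts of tensors, but the structure is the same:

```python
import random
from collections import OrderedDict

class FedAvgAggregator:
    """Baseline ServerAggregator: sample-count-weighted averaging (FedAvg)."""

    def __init__(self, global_model=None, args=None):
        self.args = args  # plain FedAvg keeps no persistent server state

    def aggregate(self, global_state_dict, client_updates, round_num):
        # client_updates: list of (state_dict, num_samples, avg_loss)
        total = sum(n for _, n, _ in client_updates)
        new_state = OrderedDict()
        for key in global_state_dict:
            # Weight each client's parameter by its share of total samples
            new_state[key] = sum(sd[key] * (n / total)
                                 for sd, n, _ in client_updates)
        return new_state

    def select_clients(self, num_available, num_to_select, round_num):
        # Uniform random selection without replacement
        return random.sample(range(num_available),
                             min(num_to_select, num_available))
```

This baseline is what the improved strategies below are measured against.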
Evaluation
The aggregation strategy is evaluated on three benchmarks with non-IID data:
- CIFAR-10 with Dirichlet split (alpha=0.1) — 100 clients, image classification
- FEMNIST (EMNIST ByClass) with Dirichlet split — 100 clients, character recognition
- Shakespeare (next character prediction) — naturally non-IID by speaker
Metric: test accuracy after 200 communication rounds (higher is better). Each round, 10 of 100 clients are selected, each trains for 5 local epochs with SGD (lr=0.01).
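The Dirichlet split draws, for each class, a Dirichlet(alpha) vector of client proportions; small alpha (e.g., 0.1) yields heavily skewed per-client label mixes. A stdlib-only sketch of the idea (function name and seeding are illustrative, not the harness's actual partitioner):

```python
import random

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split example indices across clients with per-class Dirichlet(alpha)
    proportions; small alpha -> highly skewed (non-IID) client label mixes."""
    rng = random.Random(seed)
    clients = [[] for _ in range(num_clients)]
    for c in sorted(set(labels)):
        idxs = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(idxs)
        # Sample Dirichlet(alpha) proportions via normalized Gamma draws
        props = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        s = sum(props) or 1.0
        props = [p / s for p in props]
        # Cut this class's indices at the cumulative proportions
        cuts, acc = [], 0.0
        for p in props[:-1]:
            acc += p
            cuts.append(int(acc * len(idxs)))
        parts = [idxs[a:b] for a, b in zip([0] + cuts, cuts + [len(idxs)])]
        for client, part in zip(clients, parts):
            client.extend(part)
    return clients
```

Every index is assigned to exactly one client, but with alpha = 0.1 most clients see only a few of the classes.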
flower/custom_fl_aggregation.py [EDITABLE — lines 353–424 only]
1: # Custom federated learning aggregation strategy for MLS-Bench
2: #
3: # EDITABLE section: ServerAggregator class (aggregate method + helpers).
4: # FIXED sections: everything else (config, data partitioning, client training,
5: # FL simulation loop, evaluation).
6: import argparse
Tool Input
1{
2 "content": "class ServerAggregator:\n \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n self.momentum = None\n \n # Hyperparameters for FedNL-EMA\n # gamma: Loss-aware scale. Upweights clients with high loss (rare/hard strata).\n self.gamma = 1.0 \n # beta: Server EMA momentum. Filters non-IID jitter and aligns global direction.\n self.beta = 0.5 \n\n def aggregate(self, global_state_dict, client_updates, round_num):\n # Extract losses to compute relative client performance\n losses = [upd[2] for upd in client_updates]\n max_loss, min_loss = max(losses), min(losses)\n\n # 1. Compute Loss-Aware Weights and Local Step Counts (taus)\n raw_weights = []\n taus = []\n for state_dict, n_samples, avg_loss in client_updates:\n # Normalize loss to [0, 1] to bound the exponential scaling\n if max_loss > min_loss:\n norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n else:\n norm_loss = 0.5\n\n # Exponential loss-aware weighting: upweight worst-performing clients\n weight = n_samples * np.exp(self.gamma * norm_loss)\n raw_weights.append(weight)\n\n # Compute actual local steps (tau_i) taken by the client\n batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n tau = self.args.local_epochs * max(1, batches)\n taus.append(tau)\n\n # Normalize weights to sum to 1\n total_weight = sum(raw_weights)\n p = [w / total_weight for w in raw_weights]\n\n # Effective total tau (FedNova normalization factor)\n tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n # Initialize momentum buffers on the first round\n if self.momentum is None:\n self.momentum = OrderedDict()\n for key, val in global_state_dict.items():\n if val.is_floating_point():\n self.momentum[key] = torch.zeros_like(val).float()\n\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n # Integer layers (e.g., BN track_running_stats) cannot use momentum\n if not global_state_dict[key].is_floating_point():\n avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n new_state[key] = avg_val.to(global_state_dict[key].dtype)\n continue\n\n # 2. Compute normalized update (FedNova formulation)\n nova_update = torch.zeros_like(global_state_dict[key]).float()\n for i, (state_dict, _, _) in enumerate(client_updates):\n delta = (state_dict[key].float() - global_state_dict[key].float())\n nova_update += p[i] * (delta / taus[i])\n\n # Scale up by effective tau to recover update magnitude\n aggr_delta = tau_eff * nova_update\n\n # 3. Server-side Exponential Moving Average to filter client drift noise\n self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n \n # Unbiased EMA calculation (Adam-style bias correction)\n m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n # Apply robust update\n new_state[key] = global_state_dict[key].float() + m_hat\n \n # Keep original dtype (in case it was half or double, though unlikely here)\n new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n return new_state\n\n def select_clients(self, num_available, num_to_select, round_num):\n return random.sample(range(num_available), min(num_to_select, num_available))",
3 "end_line": 424,
4 "op": "replace",
5 "start_line": 353,
6 "filename": "custom_fl_aggregation.py"
7}
ERROR: Package 'custom_fl_aggregation.py' is not in allowed packages
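The server-side EMA with Adam-style bias correction used in this FedNL-EMA attempt behaves as follows on scalars (helper name is illustrative):

```python
def ema_step(momentum, delta, beta, round_num):
    """One server EMA update with Adam-style bias correction:
    m <- beta * m + (1 - beta) * delta;  m_hat = m / (1 - beta**(t + 1)).
    On round 0 the correction exactly undoes the (1 - beta) damping,
    so the very first aggregated delta is applied at full magnitude."""
    momentum = beta * momentum + (1.0 - beta) * delta
    m_hat = momentum / (1.0 - beta ** (round_num + 1))
    return momentum, m_hat
```

As round_num grows, the correction factor tends to 1 and m_hat approaches the raw EMA.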
Tool Input
1{
2 "filename": "flower/custom_fl_aggregation.py",
3 "op": "replace",
4 "end_line": 424,
5 "content": "class ServerAggregator:\n \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n self.momentum = None\n \n # Hyperparameters for FedNL-EMA\n self.gamma = 1.0 \n self.beta = 0.5 \n\n def aggregate(self, global_state_dict, client_updates, round_num):\n losses = [upd[2] for upd in client_updates]\n max_loss, min_loss = max(losses), min(losses)\n\n raw_weights = []\n taus = []\n for state_dict, n_samples, avg_loss in client_updates:\n if max_loss > min_loss:\n norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n else:\n norm_loss = 0.5\n\n weight = n_samples * np.exp(self.gamma * norm_loss)\n raw_weights.append(weight)\n\n batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n tau = self.args.local_epochs * max(1, batches)\n taus.append(tau)\n\n total_weight = sum(raw_weights)\n p = [w / total_weight for w in raw_weights]\n\n tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n if self.momentum is None:\n self.momentum = OrderedDict()\n for key, val in global_state_dict.items():\n if val.is_floating_point():\n self.momentum[key] = torch.zeros_like(val).float()\n\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n if not global_state_dict[key].is_floating_point():\n avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n new_state[key] = avg_val.to(global_state_dict[key].dtype)\n continue\n\n nova_update = torch.zeros_like(global_state_dict[key]).float()\n for i, (state_dict, _, _) in enumerate(client_updates):\n delta = (state_dict[key].float() - global_state_dict[key].float())\n nova_update += p[i] * (delta / taus[i])\n\n aggr_delta = tau_eff * nova_update\n\n self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n \n m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n new_state[key] = global_state_dict[key].float() + m_hat\n new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n return new_state\n\n def select_clients(self, num_available, num_to_select, round_num):\n return random.sample(range(num_available), min(num_to_select, num_available))",
6 "start_line": 353
7}
OK: Replaced lines 353..424 with 66 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–418.
[Current file: flower/custom_fl_aggregation.py | editable: 353–418 | total: 515 lines]
353 class ServerAggregator:
354 """Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)"""
355
... (60 more lines) ...
416
417 def select_clients(self, num_available, num_to_select, round_num):
418 return random.sample(range(num_available), min(num_to_select, num_available))
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0393 round_time=3.3s
TRAIN_METRICS round=185 avg_loss=0.0408 round_time=4.0s
TRAIN_METRICS round=190 avg_loss=0.0142 round_time=3.1s
TRAIN_METRICS round=195 avg_loss=0.0200 round_time=3.6s
TRAIN_METRICS round=200 avg_loss=0.0181 round_time=3.4s
Evaluation progress:
EVAL round=160 test_loss=1.3838 test_accuracy=0.6452 elapsed=569s
EVAL round=170 test_loss=1.3273 test_accuracy=0.6520 elapsed=608s
EVAL round=180 test_loss=1.3191 test_accuracy=0.6404 elapsed=638s
EVAL round=190 test_loss=1.2580 test_accuracy=0.6659 elapsed=674s
EVAL round=200 test_loss=1.3960 test_accuracy=0.6620 elapsed=711s
Final metrics (cifar10):
test_accuracy_cifar10: 0.6620
test_loss_cifar10: 1.3960
best_accuracy_cifar10: 0.6659
### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_los
Tool Input
1{
2 "end_line": 418,
3 "start_line": 353,
4 "filename": "flower/custom_fl_aggregation.py",
5 "content": "class ServerAggregator:\n \"\"\"Federated Wiener filtering with Adaptive Loss Tracking (FedWALT).\n \n 1. Loss-aware weighting: prioritizes clients poorly represented by the global model.\n 2. Cross-client Variance Filtering (Wiener Filter): Element-wise damping of parameters \n that have high disagreement among clients, suppressing variance from non-IID drift.\n 3. Server-side Momentum: Smooths updates across rounds.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n self.exp_avg_mean = None\n self.exp_avg_var = None\n\n self.beta = 0.9 # EMA rate (0.9 gives a ~10 round memory)\n self.gamma = 2.0 # Variance penalty weight. Higher = stricter suppression.\n self.alpha = 1.0 # Loss-aware scaling factor.\n self.server_lr = 1.0\n self.eps = 1e-6\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n losses = [upd[2] for upd in client_updates]\n min_L, max_L = min(losses), max(losses)\n\n raw_weights = []\n for upd in client_updates:\n n_samples = upd[1]\n avg_loss = upd[2]\n norm_l = (avg_loss - min_L) / (max_L - min_L + 1e-9)\n w = n_samples * np.exp(self.alpha * norm_l)\n raw_weights.append(w)\n\n total_w = sum(raw_weights)\n p = [w / total_w for w in raw_weights]\n\n if self.exp_avg_mean is None:\n self.exp_avg_mean = OrderedDict()\n self.exp_avg_var = OrderedDict()\n for key, val in global_state_dict.items():\n if val.is_floating_point():\n self.exp_avg_mean[key] = torch.zeros_like(val).float()\n self.exp_avg_var[key] = torch.zeros_like(val).float()\n\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n if not global_state_dict[key].is_floating_point():\n avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n new_state[key] = avg_val.to(global_state_dict[key].dtype)\n continue\n\n mean_delta = torch.zeros_like(global_state_dict[key]).float()\n client_deltas = []\n for i, upd in enumerate(client_updates):\n delta = upd[0][key].float() - global_state_dict[key].float()\n client_deltas.append(delta)\n mean_delta += p[i] * delta\n\n spatial_var = torch.zeros_like(global_state_dict[key]).float()\n for i in range(len(client_updates)):\n spatial_var += p[i] * ((client_deltas[i] - mean_delta) ** 2)\n\n self.exp_avg_mean[key] = self.beta * self.exp_avg_mean[key] + (1.0 - self.beta) * mean_delta\n self.exp_avg_var[key] = self.beta * self.exp_avg_var[key] + (1.0 - self.beta) * spatial_var\n\n bias_corr = 1.0 - (self.beta ** (round_num + 1))\n m_hat = self.exp_avg_mean[key] / bias_corr\n v_hat = self.exp_avg_var[key] / bias_corr\n\n # Wiener filter scaling\n signal = m_hat ** 2\n scale = signal / (signal + self.gamma * v_hat + self.eps)\n\n update = self.server_lr * m_hat * scale\n\n new_state[key] = global_state_dict[key].float() + update\n new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n return new_state\n\n def select_clients(self, num_available, num_to_select, round_num):\n return random.sample(range(num_available), min(num_to_select, num_available))",
6 "op": "replace"
7}
OK: Replaced lines 353..433 with 81 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–433.
[Current file: flower/custom_fl_aggregation.py | editable: 353–433 | total: 530 lines]
353 class ServerAggregator:
354 """Federated Wiener filtering with Adaptive Loss Tracking (FedWALT).
355
... (75 more lines) ...
431
432 def select_clients(self, num_available, num_to_select, round_num):
433 return random.sample(range(num_available), min(num_to_select, num_available))
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### cifar10 (scripts/cifar10.sh)
[COMMAND FAILED exit=1]
Device: cuda
Loading dataset: cifar10
Number of clients: 100
Traceback (most recent call last):
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 530, in <module>
run_fl_simulation(args)
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 476, in run_fl_simulation
updated_state, n_samples, avg_loss = client_local_train(
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 274, in client_local_train
model.to(device)
File "/srv/home/bohanlyu/miniconda3/envs/mlsbench-flower/lib/python3.10/site-p
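The traceback above points into the fixed client-training path (model.to(device)), suggesting a device or memory interaction rather than a logic error in aggregate() itself. The element-wise Wiener-style damping the FedWALT attempt applies can still be sanity-checked on scalars (helper name is illustrative):

```python
def wiener_scale(m_hat, v_hat, gamma=2.0, eps=1e-6):
    """Wiener-style shrinkage factor in [0, 1): signal / (signal + noise),
    with signal = m_hat**2 and noise = gamma * v_hat. Coordinates where
    cross-client variance v_hat swamps the squared mean update are damped
    toward zero; confidently-agreed coordinates pass nearly unchanged."""
    signal = m_hat ** 2
    return signal / (signal + gamma * v_hat + eps)
```

High-variance (disputed) coordinates get a factor near 0; zero-variance coordinates a factor near 1.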
Tool Input
1{
2 "start_line": 353,
3 "content": "class ServerAggregator:\n \"\"\"Federated Agreement-Aware Exponential Averaging (FedAA-EMA).\n \n 1. Loss-aware weighting: upweights clients with high loss (hard strata).\n 2. FedNova normalization: corrects local step heterogeneity.\n 3. Agreement-aware scaling: scales down updates where heterogeneous clients disagree.\n 4. Server-Side EMA: filters non-IID jitter and aligns global direction.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n self.momentum = None\n \n # Method Hyperparameters\n self.gamma = 1.0 # Loss-aware weight scale\n self.beta = 0.5 # Server EMA momentum\n self.kappa = 0.5 # Agreement strictness (0.5 = sqrt scaling)\n self.eps = 1e-6\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n losses = [upd[2] for upd in client_updates]\n max_loss, min_loss = max(losses), min(losses)\n\n raw_weights = []\n taus = []\n for state_dict, n_samples, avg_loss in client_updates:\n # Loss-aware dynamic weighting\n if max_loss > min_loss:\n norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n else:\n norm_loss = 0.5\n weight = n_samples * np.exp(self.gamma * norm_loss)\n raw_weights.append(weight)\n\n # Local steps calculation (FedNova)\n batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n tau = self.args.local_epochs * max(1, batches)\n taus.append(tau)\n\n total_weight = sum(raw_weights)\n p = [w / total_weight for w in raw_weights]\n tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n # Initialize momentum (EMA buffer)\n if self.momentum is None:\n self.momentum = OrderedDict()\n for key, val in global_state_dict.items():\n if val.is_floating_point():\n self.momentum[key] = torch.zeros_like(val).cpu()\n\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n global_val = global_state_dict[key]\n \n # Non-floating point tensors (e.g., BN num_batches_tracked)\n if not global_val.is_floating_point():\n avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n new_state[key] = avg_val.to(global_val.dtype)\n continue\n\n # Core float parameters\n aggr_delta = torch.zeros_like(global_val).float().cpu()\n abs_delta_sum = torch.zeros_like(global_val).float().cpu()\n \n for i, (state_dict, _, _) in enumerate(client_updates):\n local_val = state_dict[key].float().cpu()\n g_val = global_val.float().cpu()\n \n # Per-step local direction\n delta_normed = (local_val - g_val) / taus[i]\n \n aggr_delta.add_(delta_normed, alpha=p[i])\n abs_delta_sum.add_(torch.abs(delta_normed), alpha=p[i])\n\n # Measure parameter-wise agreement among diverse client updates\n agreement = torch.abs(aggr_delta) / (abs_delta_sum + self.eps)\n \n # Suppress updates in highly conflicting dimensions + apply effective scale\n aggr_delta = tau_eff * aggr_delta * (agreement ** self.kappa)\n\n # Server momentum to smooth trajectory\n self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n bias_corr = 1.0 - (self.beta ** (round_num + 1))\n m_hat = self.momentum[key] / bias_corr\n\n # Step global model\n new_val = global_val.float().cpu() + m_hat\n new_state[key] = new_val.to(global_val.dtype)\n\n # Clear memory safely\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n\n return new_state\n\n def select_clients(self, num_available, num_to_select, round_num):\n return random.sample(range(num_available), min(num_to_select, num_available))",
4 "filename": "flower/custom_fl_aggregation.py",
5 "op": "replace",
6 "end_line": 433
7}
OK: Replaced lines 353..433 with 97 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–449.
[Current file: flower/custom_fl_aggregation.py | editable: 353–449 | total: 546 lines]
353 class ServerAggregator:
354 """Federated Agreement-Aware Exponential Averaging (FedAA-EMA).
355
... (91 more lines) ...
447
448 def select_clients(self, num_available, num_to_select, round_num):
449 return random.sample(range(num_available), min(num_to_select, num_available))
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### cifar10 (scripts/cifar10.sh)
[COMMAND FAILED exit=1]
Device: cuda
Loading dataset: cifar10
Number of clients: 100
Traceback (most recent call last):
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 546, in <module>
run_fl_simulation(args)
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 492, in run_fl_simulation
updated_state, n_samples, avg_loss = client_local_train(
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/google_gemini-3.1-pro-preview_20260414_110948/flower/custom_fl_aggregation.py", line 274,
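The agreement factor this FedAA-EMA attempt computes per coordinate reduces, on scalars, to the weighted-mean over mean-absolute ratio below (helper name is illustrative):

```python
def agreement_scale(deltas, weights, kappa=0.5, eps=1e-6):
    """Per-coordinate agreement: |weighted mean delta| over weighted mean
    |delta|, raised to kappa. Near 1 when clients move the same direction,
    exactly 0 when their updates cancel."""
    mean = sum(w * d for w, d in zip(weights, deltas))
    abs_mean = sum(w * abs(d) for w, d in zip(weights, deltas))
    return (abs(mean) / (abs_mean + eps)) ** kappa
```

With kappa = 0.5 the suppression is a square root, so partially-conflicting coordinates are damped gently rather than zeroed out.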
Tool Input
1{
2 "filename": "flower/custom_fl_aggregation.py",
3 "end_line": 449,
4 "op": "replace",
5 "content": "class ServerAggregator:\n \"\"\"Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)\"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n self.momentum = None\n \n # Hyperparameters for FedNL-EMA\n self.gamma = 1.0 \n self.beta = 0.5 \n\n def aggregate(self, global_state_dict, client_updates, round_num):\n losses = [upd[2] for upd in client_updates]\n max_loss, min_loss = max(losses), min(losses)\n\n raw_weights = []\n taus = []\n for state_dict, n_samples, avg_loss in client_updates:\n if max_loss > min_loss:\n norm_loss = (avg_loss - min_loss) / (max_loss - min_loss)\n else:\n norm_loss = 0.5\n\n weight = n_samples * np.exp(self.gamma * norm_loss)\n raw_weights.append(weight)\n\n batches = (n_samples + self.args.local_batch_size - 1) // self.args.local_batch_size\n tau = self.args.local_epochs * max(1, batches)\n taus.append(tau)\n\n total_weight = sum(raw_weights)\n p = [w / total_weight for w in raw_weights]\n\n tau_eff = sum(p[i] * taus[i] for i in range(len(p)))\n\n if self.momentum is None:\n self.momentum = OrderedDict()\n for key, val in global_state_dict.items():\n if val.is_floating_point():\n self.momentum[key] = torch.zeros_like(val).float()\n\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n if not global_state_dict[key].is_floating_point():\n avg_val = sum(client_updates[i][0][key].float() * p[i] for i in range(len(client_updates)))\n new_state[key] = avg_val.to(global_state_dict[key].dtype)\n continue\n\n nova_update = torch.zeros_like(global_state_dict[key]).float()\n for i, (state_dict, _, _) in enumerate(client_updates):\n delta = (state_dict[key].float() - global_state_dict[key].float())\n nova_update += p[i] * (delta / taus[i])\n\n aggr_delta = tau_eff * nova_update\n\n self.momentum[key] = self.beta * self.momentum[key] + (1.0 - self.beta) * aggr_delta\n \n m_hat = self.momentum[key] / (1.0 - (self.beta ** (round_num + 1)))\n\n new_state[key] = global_state_dict[key].float() + m_hat\n new_state[key] = new_state[key].to(global_state_dict[key].dtype)\n\n return new_state\n\n def select_clients(self, num_available, num_to_select, round_num):\n return random.sample(range(num_available), min(num_to_select, num_available))",
6 "start_line": 353
7}
OK: Replaced lines 353..449 with 66 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–418.
[Current file: flower/custom_fl_aggregation.py | editable: 353–418 | total: 515 lines]
353 class ServerAggregator:
354 """Federated Normalized Loss-Aware Exponential Averaging (FedNL-EMA)"""
355
... (60 more lines) ...
416
417 def select_clients(self, num_available, num_to_select, round_num):
418 return random.sample(range(num_available), min(num_to_select, num_available))
Tool Input
1{
2 "n": 1
3}
[submit] Submitting result from test #1 as final.
### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0393 round_time=3.3s
TRAIN_METRICS round=185 avg_loss=0.0408 round_time=4.0s
TRAIN_METRICS round=190 avg_loss=0.0142 round_time=3.1s
TRAIN_METRICS round=195 avg_loss=0.0200 round_time=3.6s
TRAIN_METRICS round=200 avg_loss=0.0181 round_time=3.4s
Evaluation progress:
EVAL round=160 test_loss=1.3838 test_accuracy=0.6452 elapsed=569s
EVAL round=170 test_loss=1.3273 test_accuracy=0.6520 elapsed=608s
EVAL round=180 test_loss=1.3191 test_accuracy=0.6404 elapsed=638s
EVAL round=190 test_loss=1.2580 test_accuracy=0.6659 elapsed=674s
EVAL round=200 test_loss=1.3960 test_accuracy=0.6620 elapsed=711s
Final metrics (cifar10):
test_accuracy_cifar10: 0.6620
test_loss_cifar10: 1.3960
best_accuracy_cifar10: 0.6659
### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0172 round_time=50.5s
TRAIN_METRICS round