Agent Conversation: deepseek-reasoner
Task: ml-federated-aggregation
Federated Learning Aggregation Strategy Design
Research Question
Design a novel server-side aggregation strategy for federated learning that achieves better convergence and higher test accuracy under heterogeneous (non-IID) data distributions across clients.
Background
Federated Learning (FL) trains a shared global model across many clients without centralizing data. The canonical algorithm, FedAvg, simply averages client model parameters weighted by sample count. However, when client data distributions are heterogeneous (non-IID), FedAvg suffers from "client drift" where local updates diverge, leading to slow convergence or poor final accuracy. Research has produced several improvements: FedProx adds a proximal penalty to local objectives, SCAFFOLD uses control variates for variance reduction, and methods like FedNova normalize updates by local steps. The aggregation strategy — how the server combines client updates into the global model — is the core algorithmic component that determines convergence behavior.
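To make the contrast between FedAvg and step-normalized aggregation concrete, here is a minimal pure-Python sketch. Parameters are modeled as a hypothetical `Params` dict (name to flat list of floats) standing in for a PyTorch `state_dict`. The FedNova function is a simplified reading that assumes plain local SGD with no momentum: each client's delta is divided by its local step count tau_i, the results are averaged with sample weights p_i, and the average is rescaled by tau_eff = sum(p_i * tau_i). It is an illustration, not a reference implementation of either paper.

```python
from typing import Dict, List

# Hypothetical stand-in for a state_dict: parameter name -> flattened values.
Params = Dict[str, List[float]]


def fedavg(global_params: Params, client_params: List[Params],
           num_samples: List[int]) -> Params:
    """FedAvg: the new global model is the sample-weighted mean of client models."""
    total = sum(num_samples)
    weights = [n / total for n in num_samples]
    return {
        name: [sum(w * cp[name][j] for w, cp in zip(weights, client_params))
               for j in range(len(values))]
        for name, values in global_params.items()
    }


def fednova_plain_sgd(global_params: Params, client_params: List[Params],
                      num_samples: List[int], local_steps: List[int]) -> Params:
    """Simplified FedNova for vanilla local SGD: divide each client's delta by its
    step count tau_i, take the sample-weighted mean, then rescale by tau_eff."""
    total = sum(num_samples)
    p = [n / total for n in num_samples]
    tau_eff = sum(pi * tau for pi, tau in zip(p, local_steps))
    new_params: Params = {}
    for name, values in global_params.items():
        new_params[name] = []
        for j, g in enumerate(values):
            # Sample-weighted mean of per-step client updates for this coordinate.
            per_step = sum(pi * (cp[name][j] - g) / tau
                           for pi, cp, tau in zip(p, client_params, local_steps))
            new_params[name].append(g + tau_eff * per_step)
    return new_params
```

When every client runs the same number of local steps, the two rules coincide, which is a quick sanity check. Assuming a fixed batch size, "5 local epochs" gives clients with more data more SGD steps, which is exactly the situation in which the two rules diverge.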
Task
Modify the ServerAggregator class in custom_fl_aggregation.py. You must implement the aggregate() method that takes the current global model state, a list of client updates (model parameters + metadata), and returns the new global model state. You may also customize client selection via select_clients().
Interface
class ServerAggregator:
    def __init__(self, global_model, args):
        # Initialize aggregation state (momentum buffers, control variates, etc.)
        ...

    def aggregate(self, global_state_dict, client_updates, round_num):
        # global_state_dict: OrderedDict of current global model parameters
        # client_updates: list of (state_dict, num_samples, avg_loss) tuples
        # round_num: current communication round (0-indexed)
        # Returns: OrderedDict of updated global model parameters
        ...

    def select_clients(self, num_available, num_to_select, round_num):
        # Returns: list of client indices to participate this round
        ...
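A minimal baseline that satisfies this interface is FedAvg written in "delta" form: the new global value is `global + server_lr * weighted_mean(client - global)`, which reduces to plain FedAvg when `server_lr = 1.0`. The sketch below is an illustration, not the benchmark's reference code. The `server_lr` keyword argument is a hypothetical addition, and the arithmetic only assumes that state-dict values support `+`, `-`, and multiplication by a float, which holds for PyTorch tensors and for plain floats alike.

```python
import random
from collections import OrderedDict


class ServerAggregator:
    """Baseline sketch: FedAvg in delta form plus uniform random client selection."""

    def __init__(self, global_model, args, server_lr=1.0):
        self.args = args
        self.server_lr = server_lr  # hypothetical knob; 1.0 reproduces plain FedAvg

    def aggregate(self, global_state_dict, client_updates, round_num):
        total = sum(n for _, n, _ in client_updates)
        new_state = OrderedDict()
        for key, g in global_state_dict.items():
            # Sample-weighted mean of client deltas (client value minus global value).
            avg_delta = sum((n / total) * (sd[key] - g) for sd, n, _ in client_updates)
            new_state[key] = g + self.server_lr * avg_delta
        return new_state

    def select_clients(self, num_available, num_to_select, round_num):
        # Uniform sampling without replacement, as in standard FedAvg.
        return random.sample(range(num_available), min(num_to_select, num_available))
```

Writing the update as a delta makes the server-side step explicit, so server momentum or adaptive optimizers can be added later by transforming `avg_delta` instead of the raw parameters.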
Evaluation
The aggregation strategy is evaluated on three benchmarks with non-IID data:
- CIFAR-10 with Dirichlet split (alpha=0.1) — 100 clients, image classification
- FEMNIST (EMNIST ByClass) with Dirichlet split — 100 clients, character recognition
- Shakespeare (next character prediction) — naturally non-IID by speaker
Metric: test accuracy after 200 communication rounds (higher is better). Each round, 10 of 100 clients are selected, each trains for 5 local epochs with SGD (lr=0.01).
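The Dirichlet splits are produced in the FIXED part of `custom_fl_aggregation.py`, which is not shown here. For intuition about what alpha=0.1 implies, the sketch below shows one common way to build a label-skewed Dirichlet split. For each class, it draws client proportions from Dirichlet(alpha, ..., alpha) by normalizing Gamma(alpha, 1) samples, then deals that class's samples out accordingly. The function name and signature are hypothetical, and it omits refinements some partitioners add, such as a minimum number of samples per client. Smaller alpha concentrates each class on fewer clients.

```python
import random
from collections import defaultdict
from typing import List, Sequence


def dirichlet_partition(labels: Sequence[int], num_clients: int, alpha: float,
                        seed: int = 0) -> List[List[int]]:
    """Label-skewed non-IID split: per class, client shares ~ Dirichlet(alpha)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        # A Dirichlet draw is a vector of Gamma(alpha, 1) draws divided by their sum.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        s = sum(draws)
        start, cum = 0, 0.0
        for c in range(num_clients):
            cum += draws[c] / s
            # The last client takes the remainder so no sample is lost to rounding.
            end = len(idxs) if c == num_clients - 1 else int(round(cum * len(idxs)))
            clients[c].extend(idxs[start:end])
            start = end
    return clients
```

Every sample is assigned to exactly one client, but with alpha=0.1 many clients end up holding only one or two classes, and some may receive very few samples.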
flower/custom_fl_aggregation.py [EDITABLE — lines 353–424 only]
1: # Custom federated learning aggregation strategy for MLS-Bench
2: #
3: # EDITABLE section: ServerAggregator class (aggregate method + helpers).
4: # FIXED sections: everything else (config, data partitioning, client training,
5: # FL simulation loop, evaluation).
6: import argparse
Tool Input
1{
2 "op": "replace",
3 "filename": "flower/custom_fl_aggregation.py",
4 "start_line": 353,
5 "end_line": 424,
6 "content": "class ServerAggregator:\n \"\"\"FedAM: Federated Adaptive Momentum Aggregation.\n\n Novel aggregation strategy combining:\n 1. Server-side momentum (exponential moving average) of updates to smooth trajectory.\n 2. Per-parameter adaptive scaling based on variance (similar to Adam) for non-uniform learning rates.\n 3. Adaptive client weighting using loss reduction and consistency with momentum.\n 4. Outlier detection via Median Absolute Deviation (MAD) to reject harmful updates.\n\n Designed to handle non-IID data by reducing client drift and adapting to heterogeneous updates.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n # Momentum buffer for each parameter\n self.momentum = OrderedDict()\n # Variance buffer for each parameter (for adaptive scaling)\n self.variance = OrderedDict()\n # Initialize buffers with zeros\n for name, param in global_model.state_dict().items():\n self.momentum[name] = torch.zeros_like(param)\n self.variance[name] = torch.zeros_like(param)\n # Hyperparameters\n self.beta1 = 0.9 # momentum decay\n self.beta2 = 0.999 # variance decay\n self.epsilon = 1e-8\n self.server_lr = 0.01 # server learning rate\n self.outlier_threshold = 3.0 # MAD threshold for outlier rejection\n # Track per-client loss history for adaptive weighting\n self.client_loss_history = {}\n # Track global loss history for momentum weighting\n self.global_loss_history = []\n self.round_num = 0\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n \"\"\"Aggregate client updates with adaptive momentum.\"\"\"\n self.round_num = round_num\n total_samples = sum(n for _, n, _ in client_updates)\n \n # Compute per-client deltas (update directions)\n client_deltas = []\n client_weights = []\n for state_dict, n_samples, avg_loss in client_updates:\n delta = OrderedDict()\n for key in global_state_dict.keys():\n delta[key] = state_dict[key] - global_state_dict[key]\n client_deltas.append(delta)\n # Base weight proportional to 
sample size\n base_weight = n_samples / total_samples\n client_weights.append(base_weight)\n\n # Detect and remove outlier updates using MAD\n filtered_deltas, filtered_weights = self._filter_outliers(\n client_deltas, client_weights, global_state_dict\n )\n \n # Compute adaptive weights based on loss reduction and momentum consistency\n adaptive_weights = self._compute_adaptive_weights(\n filtered_deltas, filtered_weights, client_updates\n )\n \n # Compute weighted average delta\n avg_delta = OrderedDict()\n for key in global_state_dict.keys():\n weighted_sum = torch.zeros_like(global_state_dict[key])\n for i, delta in enumerate(filtered_deltas):\n weighted_sum += delta[key] * adaptive_weights[i]\n avg_delta[key] = weighted_sum\n\n # Update momentum and variance buffers (Adam-like)\n for key in global_state_dict.keys():\n # Update momentum\n self.momentum[key] = self.beta1 * self.momentum[key] + (1 - self.beta1) * avg_delta[key]\n # Update variance\n self.variance[key] = self.beta2 * self.variance[key] + (1 - self.beta2) * avg_delta[key].pow(2)\n\n # Bias correction\n momentum_corrected = OrderedDict()\n variance_corrected = OrderedDict()\n for key in global_state_dict.keys():\n momentum_corrected[key] = self.momentum[key] / (1 - self.beta1 ** (round_num + 1))\n variance_corrected[key] = self.variance[key] / (1 - self.beta2 ** (round_num + 1))\n\n # Compute update with adaptive scaling\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n adaptive_update = momentum_corrected[key] / (variance_corrected[key].sqrt() + self.epsilon)\n new_state[key] = global_state_dict[key] + self.server_lr * adaptive_update\n\n return new_state\n\n def _filter_outliers(self, client_deltas, client_weights, global_state_dict):\n \"\"\"Filter outlier updates using Median Absolute Deviation (MAD).\"\"\"\n if len(client_deltas) <= 1:\n return client_deltas, client_weights\n \n # Compute L2 norm of each delta\n norms = []\n for delta in client_deltas:\n norm_sq = 0.0\n for 
key in delta.keys():\n norm_sq += delta[key].float().pow(2).sum().item()\n norms.append(np.sqrt(norm_sq))\n \n # Compute median and MAD\n median_norm = np.median(norms)\n mad = np.median(np.abs(norms - median_norm))\n if mad == 0:\n mad = 1e-6\n \n # Identify outliers (norm far from median)\n filtered_deltas = []\n filtered_weights = []\n for i, norm in enumerate(norms):\n z_score = abs(norm - median_norm) / mad\n if z_score < self.outlier_threshold:\n filtered_deltas.append(client_deltas[i])\n filtered_weights.append(client_weights[i])\n \n # If all filtered out, keep all\n if len(filtered_deltas) == 0:\n return client_deltas, client_weights\n \n # Normalize weights\n total = sum(filtered_weights)\n filtered_weights = [w / total for w in filtered_weights]\n return filtered_deltas, filtered_weights\n\n def _compute_adaptive_weights(self, deltas, base_weights, client_updates):\n \"\"\"Compute adaptive weights based on loss reduction and momentum consistency.\"\"\"\n # If first round, use base weights\n if self.round_num == 0:\n return base_weights\n \n # Compute consistency with momentum direction\n consistency_scores = []\n for delta in deltas:\n cos_sim = 0.0\n total_params = 0\n for key in delta.keys():\n if self.momentum[key].numel() > 0:\n # Flatten tensors for cosine similarity\n d_flat = delta[key].flatten().float()\n m_flat = self.momentum[key].flatten().float()\n if torch.norm(d_flat) > 1e-12 and torch.norm(m_flat) > 1e-12:\n cos_sim += torch.dot(d_flat, m_flat) / (torch.norm(d_flat) * torch.norm(m_flat))\n total_params += 1\n if total_params > 0:\n consistency_scores.append(cos_sim.item() / total_params)\n else:\n consistency_scores.append(0.0)\n \n # Normalize consistency scores to [0, 1]\n min_cons = min(consistency_scores)\n max_cons = max(consistency_scores)\n if max_cons - min_cons > 1e-6:\n consistency_norm = [(c - min_cons) / (max_cons - min_cons) for c in consistency_scores]\n else:\n consistency_norm = [0.5 for _ in consistency_scores]\n \n # 
Combine base weight, consistency, and loss reduction\n adaptive_weights = []\n for i, (_, n_samples, avg_loss) in enumerate(client_updates):\n # Loss reduction factor (lower loss -> higher weight)\n loss_factor = 1.0 / (1.0 + avg_loss) # simple inverse loss\n # Combine factors\n adaptive = base_weights[i] * (0.5 + 0.5 * consistency_norm[i]) * loss_factor\n adaptive_weights.append(adaptive)\n \n # Normalize to sum to 1\n total = sum(adaptive_weights)\n if total > 0:\n adaptive_weights = [w / total for w in adaptive_weights]\n else:\n adaptive_weights = base_weights\n \n return adaptive_weights\n\n def select_clients(self, num_available, num_to_select, round_num):\n \"\"\"Select clients with higher loss reduction (active learning).\"\"\"\n # For first few rounds, random selection\n if round_num < 5 or not self.client_loss_history:\n return random.sample(range(num_available), min(num_to_select, num_available))\n \n # Compute average loss reduction per client (if history available)\n client_scores = []\n for client_idx in range(num_available):\n if client_idx in self.client_loss_history:\n losses = self.client_loss_history[client_idx]\n if len(losses) >= 2:\n # Score = average loss reduction (negative if loss increased)\n reduction = losses[-2] - losses[-1] # positive if loss decreased\n client_scores.append((client_idx, reduction))\n else:\n client_scores.append((client_idx, 0.0))\n else:\n client_scores.append((client_idx, 0.0))\n \n # Sort by loss reduction (descending)\n client_scores.sort(key=lambda x: x[1], reverse=True)\n selected = [idx for idx, _ in client_scores[:num_to_select]]\n \n # Ensure we have exactly num_to_select (pad with random if needed)\n if len(selected) < num_to_select:\n remaining = [i for i in range(num_available) if i not in selected]\n selected.extend(random.sample(remaining, num_to_select - len(selected)))\n \n return selected"
7}
OK: Replaced lines 353..424 with 208 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–560.
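The MAD outlier rule in `_filter_outliers` (later `_detect_outliers`) can be isolated from the tensor bookkeeping. The sketch below restates it in pure Python on precomputed norms. Like the transcript's code, it uses the raw MAD (no 1.4826 consistency factor), floors a zero MAD at 1e-6, and keeps every update if all would be rejected. The helper names are hypothetical. Because MAD is about 0.6745 standard deviations for Gaussian data, a threshold of 3.0 on the raw MAD corresponds to roughly a 2-sigma cut, which is stricter than the threshold value suggests.

```python
import math
import statistics
from typing import Dict, List


def delta_l2_norm(delta: Dict[str, List[float]]) -> float:
    """L2 norm of a whole client delta, flattened across all parameters."""
    return math.sqrt(sum(v * v for values in delta.values() for v in values))


def mad_keep_indices(norms: List[float], threshold: float = 3.0) -> List[int]:
    """Indices of updates whose norm is within `threshold` raw MADs of the median."""
    if len(norms) <= 1:
        return list(range(len(norms)))
    med = statistics.median(norms)
    # Raw MAD, floored at 1e-6 when it is exactly zero.
    mad = statistics.median([abs(x - med) for x in norms]) or 1e-6
    keep = [i for i, x in enumerate(norms) if abs(x - med) / mad < threshold]
    # If every update would be rejected, fall back to keeping all of them.
    return keep or list(range(len(norms)))
```

For example, with norms `[1.0, 1.2, 0.8, 1.1, 0.9, 5.0]`, only the last update is dropped. Returning indices, rather than filtered lists, keeps deltas, weights, and per-client losses aligned. The third edit in this transcript switches to that approach.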
[Current file: flower/custom_fl_aggregation.py | editable: 353–560 | total: 657 lines]
353 class ServerAggregator:
354 """FedAM: Federated Adaptive Momentum Aggregation.
355
... (202 more lines) ...
558 selected.extend(random.sample(remaining, num_to_select - len(selected)))
559
560 return selected
Tool Input
1{
2 "op": "replace",
3 "filename": "flower/custom_fl_aggregation.py",
4 "start_line": 353,
5 "end_line": 560,
6 "content": "class ServerAggregator:\n \"\"\"FedAM: Federated Adaptive Momentum Aggregation.\n\n Novel aggregation strategy combining:\n 1. Server-side momentum (exponential moving average) of updates to smooth trajectory.\n 2. Per-parameter adaptive scaling based on variance (similar to Adam) for non-uniform learning rates.\n 3. Adaptive client weighting using loss reduction and consistency with momentum.\n 4. Outlier detection via Median Absolute Deviation (MAD) to reject harmful updates.\n\n Designed to handle non-IID data by reducing client drift and adapting to heterogeneous updates.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n # Momentum buffer for each parameter\n self.momentum = OrderedDict()\n # Variance buffer for each parameter (for adaptive scaling)\n self.variance = OrderedDict()\n # Initialize buffers with zeros\n for name, param in global_model.state_dict().items():\n self.momentum[name] = torch.zeros_like(param)\n self.variance[name] = torch.zeros_like(param)\n # Hyperparameters\n self.beta1 = 0.9 # momentum decay\n self.beta2 = 0.999 # variance decay\n self.epsilon = 1e-8\n self.server_lr = 0.01 # server learning rate\n self.outlier_threshold = 3.0 # MAD threshold for outlier rejection\n # Track per-client loss history for adaptive weighting\n self.client_loss_history = {}\n # Track selected clients for the current round\n self.selected_clients = []\n self.round_num = 0\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n \"\"\"Aggregate client updates with adaptive momentum.\"\"\"\n self.round_num = round_num\n \n # Update loss history for selected clients\n for idx, (_, _, avg_loss) in zip(self.selected_clients, client_updates):\n if idx not in self.client_loss_history:\n self.client_loss_history[idx] = []\n self.client_loss_history[idx].append(avg_loss)\n # Keep only recent history (last 5 rounds)\n if len(self.client_loss_history[idx]) > 5:\n self.client_loss_history[idx].pop(0)\n \n 
total_samples = sum(n for _, n, _ in client_updates)\n \n # Compute per-client deltas (update directions)\n client_deltas = []\n client_weights = []\n for state_dict, n_samples, avg_loss in client_updates:\n delta = OrderedDict()\n for key in global_state_dict.keys():\n delta[key] = state_dict[key] - global_state_dict[key]\n client_deltas.append(delta)\n # Base weight proportional to sample size\n base_weight = n_samples / total_samples\n client_weights.append(base_weight)\n\n # Detect and remove outlier updates using MAD\n filtered_deltas, filtered_weights = self._filter_outliers(\n client_deltas, client_weights, global_state_dict\n )\n \n # Compute adaptive weights based on loss reduction and momentum consistency\n adaptive_weights = self._compute_adaptive_weights(\n filtered_deltas, filtered_weights, client_updates\n )\n \n # Compute weighted average delta\n avg_delta = OrderedDict()\n for key in global_state_dict.keys():\n weighted_sum = torch.zeros_like(global_state_dict[key])\n for i, delta in enumerate(filtered_deltas):\n weighted_sum += delta[key] * adaptive_weights[i]\n avg_delta[key] = weighted_sum\n\n # Update momentum and variance buffers (Adam-like)\n for key in global_state_dict.keys():\n # Update momentum\n self.momentum[key] = self.beta1 * self.momentum[key] + (1 - self.beta1) * avg_delta[key]\n # Update variance\n self.variance[key] = self.beta2 * self.variance[key] + (1 - self.beta2) * avg_delta[key].pow(2)\n\n # Bias correction\n momentum_corrected = OrderedDict()\n variance_corrected = OrderedDict()\n for key in global_state_dict.keys():\n momentum_corrected[key] = self.momentum[key] / (1 - self.beta1 ** (round_num + 1))\n variance_corrected[key] = self.variance[key] / (1 - self.beta2 ** (round_num + 1))\n\n # Compute update with adaptive scaling\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n adaptive_update = momentum_corrected[key] / (variance_corrected[key].sqrt() + self.epsilon)\n new_state[key] = global_state_dict[key] 
+ self.server_lr * adaptive_update\n\n return new_state\n\n def _filter_outliers(self, client_deltas, client_weights, global_state_dict):\n \"\"\"Filter outlier updates using Median Absolute Deviation (MAD).\"\"\"\n if len(client_deltas) <= 1:\n return client_deltas, client_weights\n \n # Compute L2 norm of each delta\n norms = []\n for delta in client_deltas:\n norm_sq = 0.0\n for key in delta.keys():\n norm_sq += delta[key].float().pow(2).sum().item()\n norms.append(np.sqrt(norm_sq))\n \n # Compute median and MAD\n median_norm = np.median(norms)\n mad = np.median(np.abs(norms - median_norm))\n if mad == 0:\n mad = 1e-6\n \n # Identify outliers (norm far from median)\n filtered_deltas = []\n filtered_weights = []\n for i, norm in enumerate(norms):\n z_score = abs(norm - median_norm) / mad\n if z_score < self.outlier_threshold:\n filtered_deltas.append(client_deltas[i])\n filtered_weights.append(client_weights[i])\n \n # If all filtered out, keep all\n if len(filtered_deltas) == 0:\n return client_deltas, client_weights\n \n # Normalize weights\n total = sum(filtered_weights)\n filtered_weights = [w / total for w in filtered_weights]\n return filtered_deltas, filtered_weights\n\n def _compute_adaptive_weights(self, deltas, base_weights, client_updates):\n \"\"\"Compute adaptive weights based on loss reduction and momentum consistency.\"\"\"\n # If first round, use base weights\n if self.round_num == 0:\n return base_weights\n \n # Compute consistency with momentum direction\n consistency_scores = []\n for delta in deltas:\n cos_sim = 0.0\n total_params = 0\n for key in delta.keys():\n if self.momentum[key].numel() > 0:\n # Flatten tensors for cosine similarity\n d_flat = delta[key].flatten().float()\n m_flat = self.momentum[key].flatten().float()\n if torch.norm(d_flat) > 1e-12 and torch.norm(m_flat) > 1e-12:\n cos_sim += torch.dot(d_flat, m_flat) / (torch.norm(d_flat) * torch.norm(m_flat))\n total_params += 1\n if total_params > 0:\n 
consistency_scores.append(cos_sim.item() / total_params)\n else:\n consistency_scores.append(0.0)\n \n # Normalize consistency scores to [0, 1]\n min_cons = min(consistency_scores)\n max_cons = max(consistency_scores)\n if max_cons - min_cons > 1e-6:\n consistency_norm = [(c - min_cons) / (max_cons - min_cons) for c in consistency_scores]\n else:\n consistency_norm = [0.5 for _ in consistency_scores]\n \n # Combine base weight, consistency, and loss reduction\n adaptive_weights = []\n for i, (_, _, avg_loss) in enumerate(client_updates):\n # Loss reduction factor (lower loss -> higher weight)\n loss_factor = 1.0 / (1.0 + avg_loss) # simple inverse loss\n # Combine factors\n adaptive = base_weights[i] * (0.5 + 0.5 * consistency_norm[i]) * loss_factor\n adaptive_weights.append(adaptive)\n \n # Normalize to sum to 1\n total = sum(adaptive_weights)\n if total > 0:\n adaptive_weights = [w / total for w in adaptive_weights]\n else:\n adaptive_weights = base_weights\n \n return adaptive_weights\n\n def select_clients(self, num_available, num_to_select, round_num):\n \"\"\"Select clients with higher loss reduction (active learning).\"\"\"\n # For first few rounds, random selection\n if round_num < 5 or not self.client_loss_history:\n self.selected_clients = random.sample(range(num_available), min(num_to_select, num_available))\n return self.selected_clients\n \n # Compute average loss reduction per client (if history available)\n client_scores = []\n for client_idx in range(num_available):\n if client_idx in self.client_loss_history:\n losses = self.client_loss_history[client_idx]\n if len(losses) >= 2:\n # Score = average loss reduction (negative if loss increased)\n reduction = losses[-2] - losses[-1] # positive if loss decreased\n client_scores.append((client_idx, reduction))\n else:\n client_scores.append((client_idx, 0.0))\n else:\n client_scores.append((client_idx, 0.0))\n \n # Sort by loss reduction (descending)\n client_scores.sort(key=lambda x: x[1], 
reverse=True)\n selected = [idx for idx, _ in client_scores[:num_to_select]]\n \n # Ensure we have exactly num_to_select (pad with random if needed)\n if len(selected) < num_to_select:\n remaining = [i for i in range(num_available) if i not in selected]\n selected.extend(random.sample(remaining, num_to_select - len(selected)))\n \n self.selected_clients = selected\n return selected"
7}
OK: Replaced lines 353..560 with 220 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–572.
[Current file: flower/custom_fl_aggregation.py | editable: 353–572 | total: 669 lines]
353 class ServerAggregator:
354 """FedAM: Federated Adaptive Momentum Aggregation.
355
... (214 more lines) ...
570
571 self.selected_clients = selected
572 return selected
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=1 avg_loss=1.0773 round_time=5.4s
### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=1 avg_loss=0.4806 round_time=38.1s
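In the first FedAM version, `client_loss_history` was initialized but never written, so `select_clients` always took the random branch. The second version records each round's `avg_loss` against `self.selected_clients`. This implicitly assumes the fixed simulation loop returns `client_updates` in the same order as the indices from `select_clients`, which the excerpt shown here does not confirm. The transcript's selection step also has a tie issue. Python's sort is stable, and unseen clients all score 0.0, so ties are resolved by client index and low-numbered clients are favored. The sketch below is a hypothetical standalone version that keeps a bounded history with `deque(maxlen=...)` and shuffles candidates before sorting, which breaks ties at random.

```python
import random
from collections import deque
from typing import Deque, Dict, List


class LossReductionSelector:
    """Sketch of loss-reduction client selection with bounded per-client history."""

    def __init__(self, history_len: int = 5, warmup_rounds: int = 5, seed: int = 0):
        self.history: Dict[int, Deque[float]] = {}
        self.history_len = history_len
        self.warmup_rounds = warmup_rounds
        self.rng = random.Random(seed)

    def record(self, client_ids: List[int], losses: List[float]) -> None:
        # Assumes losses arrive in the same order in which client_ids were selected.
        for cid, loss in zip(client_ids, losses):
            self.history.setdefault(cid, deque(maxlen=self.history_len)).append(loss)

    def score(self, cid: int) -> float:
        h = self.history.get(cid)
        if h is None or len(h) < 2:
            return 0.0
        return h[-2] - h[-1]  # positive if the client's loss went down

    def select(self, num_available: int, k: int, round_num: int) -> List[int]:
        k = min(k, num_available)
        if round_num < self.warmup_rounds or not self.history:
            return self.rng.sample(range(num_available), k)
        candidates = list(range(num_available))
        self.rng.shuffle(candidates)  # random order survives the stable sort for ties
        candidates.sort(key=self.score, reverse=True)
        return candidates[:k]
```

Even with random tie-breaking, a greedy top-k policy like this one keeps revisiting the same few clients whose loss happens to be falling. Mixing in some uniformly random picks is one way to preserve coverage.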
Tool Input
1{
2 "op": "replace",
3 "filename": "flower/custom_fl_aggregation.py",
4 "start_line": 353,
5 "end_line": 572,
6 "content": "class ServerAggregator:\n \"\"\"FedAM: Federated Adaptive Momentum Aggregation.\n\n Novel aggregation strategy combining:\n 1. Server-side momentum (exponential moving average) of updates to smooth trajectory.\n 2. Per-parameter adaptive scaling based on variance (similar to Adam) for non-uniform learning rates.\n 3. Adaptive client weighting using loss reduction and consistency with momentum.\n 4. Outlier detection via Median Absolute Deviation (MAD) to reject harmful updates.\n\n Designed to handle non-IID data by reducing client drift and adapting to heterogeneous updates.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n # Momentum buffer for each parameter\n self.momentum = OrderedDict()\n # Variance buffer for each parameter (for adaptive scaling)\n self.variance = OrderedDict()\n # Initialize buffers with zeros\n for name, param in global_model.state_dict().items():\n self.momentum[name] = torch.zeros_like(param)\n self.variance[name] = torch.zeros_like(param)\n # Hyperparameters\n self.beta1 = 0.9 # momentum decay\n self.beta2 = 0.999 # variance decay\n self.epsilon = 1e-8\n self.server_lr = 0.01 # server learning rate\n self.outlier_threshold = 3.0 # MAD threshold for outlier rejection\n # Track per-client loss history for adaptive weighting\n self.client_loss_history = {}\n # Track selected clients for the current round\n self.selected_clients = []\n self.round_num = 0\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n \"\"\"Aggregate client updates with adaptive momentum.\"\"\"\n self.round_num = round_num\n \n # Update loss history for selected clients\n for idx, (_, _, avg_loss) in zip(self.selected_clients, client_updates):\n if idx not in self.client_loss_history:\n self.client_loss_history[idx] = []\n self.client_loss_history[idx].append(avg_loss)\n # Keep only recent history (last 5 rounds)\n if len(self.client_loss_history[idx]) > 5:\n self.client_loss_history[idx].pop(0)\n \n 
total_samples = sum(n for _, n, _ in client_updates)\n \n # Compute per-client deltas (update directions)\n client_deltas = []\n client_weights = []\n for state_dict, n_samples, avg_loss in client_updates:\n delta = OrderedDict()\n for key in global_state_dict.keys():\n delta[key] = state_dict[key] - global_state_dict[key]\n client_deltas.append(delta)\n # Base weight proportional to sample size\n base_weight = n_samples / total_samples\n client_weights.append(base_weight)\n\n # Detect and remove outlier updates using MAD\n kept_indices = self._detect_outliers(client_deltas)\n filtered_deltas = [client_deltas[i] for i in kept_indices]\n filtered_updates = [client_updates[i] for i in kept_indices]\n filtered_weights = [client_weights[i] for i in kept_indices]\n # Normalize weights after filtering\n total_filtered_weight = sum(filtered_weights)\n if total_filtered_weight > 0:\n filtered_weights = [w / total_filtered_weight for w in filtered_weights]\n \n # Compute adaptive weights based on loss reduction and momentum consistency\n adaptive_weights = self._compute_adaptive_weights(\n filtered_deltas, filtered_weights, filtered_updates\n )\n \n # Compute weighted average delta\n avg_delta = OrderedDict()\n for key in global_state_dict.keys():\n weighted_sum = torch.zeros_like(global_state_dict[key])\n for i, delta in enumerate(filtered_deltas):\n weighted_sum += delta[key] * adaptive_weights[i]\n avg_delta[key] = weighted_sum\n\n # Update momentum and variance buffers (Adam-like)\n for key in global_state_dict.keys():\n # Update momentum\n self.momentum[key] = self.beta1 * self.momentum[key] + (1 - self.beta1) * avg_delta[key]\n # Update variance\n self.variance[key] = self.beta2 * self.variance[key] + (1 - self.beta2) * avg_delta[key].pow(2)\n\n # Bias correction\n momentum_corrected = OrderedDict()\n variance_corrected = OrderedDict()\n for key in global_state_dict.keys():\n momentum_corrected[key] = self.momentum[key] / (1 - self.beta1 ** (round_num + 1))\n 
variance_corrected[key] = self.variance[key] / (1 - self.beta2 ** (round_num + 1))\n\n # Compute update with adaptive scaling\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n adaptive_update = momentum_corrected[key] / (variance_corrected[key].sqrt() + self.epsilon)\n new_state[key] = global_state_dict[key] + self.server_lr * adaptive_update\n\n return new_state\n\n def _detect_outliers(self, client_deltas):\n \"\"\"Detect outlier updates using Median Absolute Deviation (MAD). Return indices of kept updates.\"\"\"\n if len(client_deltas) <= 1:\n return list(range(len(client_deltas)))\n \n # Compute L2 norm of each delta\n norms = []\n for delta in client_deltas:\n norm_sq = 0.0\n for key in delta.keys():\n norm_sq += delta[key].float().pow(2).sum().item()\n norms.append(np.sqrt(norm_sq))\n \n # Compute median and MAD\n median_norm = np.median(norms)\n mad = np.median(np.abs(norms - median_norm))\n if mad == 0:\n mad = 1e-6\n \n # Identify outliers (norm far from median)\n kept_indices = []\n for i, norm in enumerate(norms):\n z_score = abs(norm - median_norm) / mad\n if z_score < self.outlier_threshold:\n kept_indices.append(i)\n \n # If all filtered out, keep all\n if len(kept_indices) == 0:\n kept_indices = list(range(len(client_deltas)))\n return kept_indices\n\n def _compute_adaptive_weights(self, deltas, base_weights, client_updates):\n \"\"\"Compute adaptive weights based on loss reduction and momentum consistency.\"\"\"\n # If first round, use base weights\n if self.round_num == 0:\n return base_weights\n \n # Compute consistency with momentum direction\n consistency_scores = []\n for delta in deltas:\n cos_sim = 0.0\n total_params = 0\n for key in delta.keys():\n if self.momentum[key].numel() > 0:\n # Flatten tensors for cosine similarity\n d_flat = delta[key].flatten().float()\n m_flat = self.momentum[key].flatten().float()\n if torch.norm(d_flat) > 1e-12 and torch.norm(m_flat) > 1e-12:\n cos_sim += torch.dot(d_flat, m_flat) / 
(torch.norm(d_flat) * torch.norm(m_flat))\n total_params += 1\n if total_params > 0:\n consistency_scores.append(cos_sim.item() / total_params)\n else:\n consistency_scores.append(0.0)\n \n # Normalize consistency scores to [0, 1]\n min_cons = min(consistency_scores)\n max_cons = max(consistency_scores)\n if max_cons - min_cons > 1e-6:\n consistency_norm = [(c - min_cons) / (max_cons - min_cons) for c in consistency_scores]\n else:\n consistency_norm = [0.5 for _ in consistency_scores]\n \n # Combine base weight, consistency, and loss reduction\n adaptive_weights = []\n for i, (_, _, avg_loss) in enumerate(client_updates):\n # Loss reduction factor (lower loss -> higher weight)\n loss_factor = 1.0 / (1.0 + avg_loss) # simple inverse loss\n # Combine factors\n adaptive = base_weights[i] * (0.5 + 0.5 * consistency_norm[i]) * loss_factor\n adaptive_weights.append(adaptive)\n \n # Normalize to sum to 1\n total = sum(adaptive_weights)\n if total > 0:\n adaptive_weights = [w / total for w in adaptive_weights]\n else:\n adaptive_weights = base_weights\n \n return adaptive_weights\n\n def select_clients(self, num_available, num_to_select, round_num):\n \"\"\"Select clients with higher loss reduction (active learning).\"\"\"\n # For first few rounds, random selection\n if round_num < 5 or not self.client_loss_history:\n self.selected_clients = random.sample(range(num_available), min(num_to_select, num_available))\n return self.selected_clients\n \n # Compute average loss reduction per client (if history available)\n client_scores = []\n for client_idx in range(num_available):\n if client_idx in self.client_loss_history:\n losses = self.client_loss_history[client_idx]\n if len(losses) >= 2:\n # Score = average loss reduction (negative if loss increased)\n reduction = losses[-2] - losses[-1] # positive if loss decreased\n client_scores.append((client_idx, reduction))\n else:\n client_scores.append((client_idx, 0.0))\n else:\n client_scores.append((client_idx, 0.0))\n \n # 
Sort by loss reduction (descending)\n client_scores.sort(key=lambda x: x[1], reverse=True)\n selected = [idx for idx, _ in client_scores[:num_to_select]]\n \n # Ensure we have exactly num_to_select (pad with random if needed)\n if len(selected) < num_to_select:\n remaining = [i for i in range(num_available) if i not in selected]\n selected.extend(random.sample(remaining, num_to_select - len(selected)))\n \n self.selected_clients = selected\n return \n... (truncated)"
7}
OK: Replaced lines 353..572 with 219 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–571.
[Current file: flower/custom_fl_aggregation.py | editable: 353–571 | total: 668 lines]
353 class ServerAggregator:
354 """FedAM: Federated Adaptive Momentum Aggregation.
355
... (213 more lines) ...
569
570 self.selected_clients = selected
571 return selected
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0870 round_time=5.0s
TRAIN_METRICS round=185 avg_loss=0.0943 round_time=5.0s
TRAIN_METRICS round=190 avg_loss=0.0866 round_time=5.0s
TRAIN_METRICS round=195 avg_loss=0.0860 round_time=4.9s
TRAIN_METRICS round=200 avg_loss=0.0882 round_time=5.0s
Evaluation progress:
EVAL round=160 test_loss=9.5152 test_accuracy=0.3915 elapsed=815s
EVAL round=170 test_loss=10.4654 test_accuracy=0.3956 elapsed=869s
EVAL round=180 test_loss=10.7825 test_accuracy=0.4039 elapsed=921s
EVAL round=190 test_loss=10.9166 test_accuracy=0.4054 elapsed=973s
EVAL round=200 test_loss=11.4390 test_accuracy=0.4035 elapsed=1025s
Final metrics (cifar10):
test_accuracy_cifar10: 0.4035
test_loss_cifar10: 11.4390
best_accuracy_cifar10: 0.4054
### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 av
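Test #2 ran the Adam-style FedAM server step. The next edit replaces it with an EMA momentum step and `server_lr = 1.0`. The sketch below restates both server updates on flat lists of floats, treating the averaged client delta as the update direction, as the transcript's code does. Function names are hypothetical.

Two properties can be checked directly. First, at the first round the bias-corrected Adam moments are exactly `delta` and `delta**2`, so every coordinate moves by about `server_lr` in the direction of `sign(delta)`, regardless of how large the averaged delta was. Second, the EMA form `buf = beta*buf + (1-beta)*delta` applies only 10% of the first round's delta and approaches the full delta only for persistent updates. This differs from the formulation usually cited for FedAvgM, `v = beta*v + delta`.

```python
import math
from typing import List, Tuple


def adam_server_step(x: List[float], delta: List[float], m: List[float],
                     v: List[float], t: int, lr: float = 0.01, b1: float = 0.9,
                     b2: float = 0.999, eps: float = 1e-8
                     ) -> Tuple[List[float], List[float], List[float]]:
    """Bias-corrected Adam-style server step; t is 1-based (round_num + 1)."""
    m = [b1 * mi + (1 - b1) * d for mi, d in zip(m, delta)]
    v = [b2 * vi + (1 - b2) * d * d for vi, d in zip(v, delta)]
    m_hat = [mi / (1 - b1 ** t) for mi in m]
    v_hat = [vi / (1 - b2 ** t) for vi in v]
    x = [xi + lr * mh / (math.sqrt(vh) + eps) for xi, mh, vh in zip(x, m_hat, v_hat)]
    return x, m, v


def ema_momentum_server_step(x: List[float], delta: List[float], buf: List[float],
                             lr: float = 1.0, beta: float = 0.9
                             ) -> Tuple[List[float], List[float]]:
    """EMA momentum as in the FedAdaptive edit: no bias correction, step = lr * buf."""
    buf = [beta * b + (1 - beta) * d for b, d in zip(buf, delta)]
    x = [xi + lr * b for xi, b in zip(x, buf)]
    return x, buf
```

With `delta = [0.5, -0.002]`, the first Adam step moves both coordinates by about 0.01 in magnitude, even though the two deltas differ by a factor of 250.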
Tool Input
1{
2 "op": "replace",
3 "filename": "flower/custom_fl_aggregation.py",
4 "start_line": 353,
5 "end_line": 571,
6 "content": "class ServerAggregator:\n \"\"\"FedAdaptive: Federated Adaptive Aggregation with Momentum.\n\n Combines:\n 1. Server-side momentum (FedAvgM) to smooth updates and accelerate convergence.\n 2. Adaptive client weighting based on loss, consistency with momentum, and sample size.\n 3. Client selection favoring clients with higher loss (to improve underperforming clients).\n \n Simpler and more robust than full Adam-like scaling, focusing on non-IID robustness.\n \"\"\"\n\n def __init__(self, global_model, args):\n self.args = args\n # Momentum buffer for each parameter\n self.momentum = OrderedDict()\n for name, param in global_model.state_dict().items():\n self.momentum[name] = torch.zeros_like(param)\n # Hyperparameters\n self.beta = 0.9 # momentum decay\n self.server_lr = 1.0 # server learning rate (like FedAvg)\n # Track per-client loss history for selection\n self.client_loss_history = {}\n # Track selected clients for the current round\n self.selected_clients = []\n self.round_num = 0\n\n def aggregate(self, global_state_dict, client_updates, round_num):\n \"\"\"Aggregate client updates with adaptive weighting and momentum.\"\"\"\n self.round_num = round_num\n \n # Update loss history for selected clients\n for idx, (_, _, avg_loss) in zip(self.selected_clients, client_updates):\n if idx not in self.client_loss_history:\n self.client_loss_history[idx] = []\n self.client_loss_history[idx].append(avg_loss)\n # Keep only recent history (last 5 rounds)\n if len(self.client_loss_history[idx]) > 5:\n self.client_loss_history[idx].pop(0)\n \n total_samples = sum(n for _, n, _ in client_updates)\n \n # Compute per-client deltas (update directions)\n client_deltas = []\n for state_dict, n_samples, avg_loss in client_updates:\n delta = OrderedDict()\n for key in global_state_dict.keys():\n delta[key] = state_dict[key] - global_state_dict[key]\n client_deltas.append(delta)\n\n # Compute adaptive weights\n weights = self._compute_weights(client_deltas, 
client_updates, total_samples)\n \n # Compute weighted average delta\n avg_delta = OrderedDict()\n for key in global_state_dict.keys():\n weighted_sum = torch.zeros_like(global_state_dict[key])\n for i, delta in enumerate(client_deltas):\n weighted_sum += delta[key] * weights[i]\n avg_delta[key] = weighted_sum\n\n # Update momentum\n for key in global_state_dict.keys():\n self.momentum[key] = self.beta * self.momentum[key] + (1 - self.beta) * avg_delta[key]\n\n # Apply momentum update\n new_state = OrderedDict()\n for key in global_state_dict.keys():\n new_state[key] = global_state_dict[key] + self.server_lr * self.momentum[key]\n\n return new_state\n\n def _compute_weights(self, deltas, client_updates, total_samples):\n \"\"\"Compute adaptive weights for each client.\"\"\"\n n = len(deltas)\n base_weights = [n_samples / total_samples for _, n_samples, _ in client_updates]\n \n # If first round, use only base weights\n if self.round_num == 0:\n return base_weights\n \n # Compute consistency with momentum direction (cosine similarity per parameter, averaged)\n consistency_scores = []\n for delta in deltas:\n cos_sim = 0.0\n total_params = 0\n for key in delta.keys():\n if self.momentum[key].numel() > 0:\n d_flat = delta[key].flatten().float()\n m_flat = self.momentum[key].flatten().float()\n norm_d = torch.norm(d_flat)\n norm_m = torch.norm(m_flat)\n if norm_d > 1e-12 and norm_m > 1e-12:\n cos_sim += torch.dot(d_flat, m_flat) / (norm_d * norm_m)\n total_params += 1\n if total_params > 0:\n consistency_scores.append(cos_sim.item() / total_params)\n else:\n consistency_scores.append(0.0)\n \n # Normalize consistency to [0, 1]\n min_cons = min(consistency_scores)\n max_cons = max(consistency_scores)\n if max_cons - min_cons > 1e-6:\n consistency_norm = [(c - min_cons) / (max_cons - min_cons) for c in consistency_scores]\n else:\n consistency_norm = [0.5 for _ in consistency_scores]\n \n # Compute loss factor (lower loss -> higher weight)\n loss_factors = [1.0 / (1.0 + 
avg_loss) for _, _, avg_loss in client_updates]\n \n # Combine: weight = base_weight * (0.5 + 0.5 * consistency) * loss_factor\n combined_weights = []\n for i in range(n):\n w = base_weights[i] * (0.5 + 0.5 * consistency_norm[i]) * loss_factors[i]\n combined_weights.append(w)\n \n # Normalize to sum to 1\n total = sum(combined_weights)\n if total > 0:\n combined_weights = [w / total for w in combined_weights]\n else:\n combined_weights = base_weights\n \n return combined_weights\n\n def select_clients(self, num_available, num_to_select, round_num):\n \"\"\"Select clients with higher recent loss (to improve underperforming clients).\"\"\"\n # For first few rounds, random selection\n if round_num < 5 or not self.client_loss_history:\n self.selected_clients = random.sample(range(num_available), min(num_to_select, num_available))\n return self.selected_clients\n \n # Compute average loss per client over recent rounds\n client_scores = []\n for client_idx in range(num_available):\n if client_idx in self.client_loss_history:\n losses = self.client_loss_history[client_idx]\n avg_loss = sum(losses) / len(losses)\n client_scores.append((client_idx, avg_loss))\n else:\n client_scores.append((client_idx, float('inf'))) # unseen clients prioritized\n \n # Sort by loss descending (higher loss -> selected first)\n client_scores.sort(key=lambda x: x[1], reverse=True)\n selected = [idx for idx, _ in client_scores[:num_to_select]]\n \n # Ensure we have exactly num_to_select (pad with random if needed)\n if len(selected) < num_to_select:\n remaining = [i for i in range(num_available) if i not in selected]\n selected.extend(random.sample(remaining, num_to_select - len(selected)))\n \n self.selected_clients = selected\n return selected"
7}
OK: Replaced lines 353..571 with 153 line(s) in flower/custom_fl_aggregation.py. Editable range: 353–505.
[Current file: flower/custom_fl_aggregation.py | editable: 353–505 | total: 602 lines]
353 class ServerAggregator:
354 """FedAdaptive: Federated Adaptive Aggregation with Momentum.
355
... (147 more lines) ...
503
504 self.selected_clients = selected
505 return selected
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)
[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.
### cifar10 (scripts/cifar10.sh)
[COMMAND FAILED exit=1]
Device: cuda
Loading dataset: cifar10
Number of clients: 100
Traceback (most recent call last):
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/deepseek-reasoner_20260414_110948/flower/custom_fl_aggregation.py", line 602, in <module>
run_fl_simulation(args)
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/deepseek-reasoner_20260414_110948/flower/custom_fl_aggregation.py", line 548, in run_fl_simulation
updated_state, n_samples, avg_loss = client_local_train(
File "/srv/home/bohanlyu/MLS-Bench/vendor/workspace/ml-federated-aggregation/deepseek-reasoner_20260414_110948/flower/custom_fl_aggregation.py", line 274, in client_local_train
model.to(
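For reference, the three scoring rules in the submitted `ServerAggregator` can be restated on plain Python floats. This is a minimal sketch under stated assumptions, not the code that was run. `combine_weights`, `ema_trajectory` and `select_by_recent_loss` are hypothetical helpers that do not exist in custom_fl_aggregation.py. Tensors are replaced by scalars. The consistency scores are taken as inputs (the aggregator computes them as per-parameter cosine similarity with the momentum buffer, averaged over parameters). The five-round random warm-up in `select_clients` is left out.

```python
import math
from typing import Dict, List, Sequence

# Hypothetical helpers (NOT part of custom_fl_aggregation.py): a scalar
# restatement of the submitted aggregator's weighting, momentum and selection.


def combine_weights(
    num_samples: Sequence[int],
    avg_losses: Sequence[float],
    consistency: Sequence[float],
    round_num: int,
) -> List[float]:
    """Restates _compute_weights: base * (0.5 + 0.5 * cons) / (1 + loss), normalized."""
    total = sum(num_samples)
    base = [n / total for n in num_samples]
    if round_num == 0:
        # Round 0: the momentum buffer is empty, so use sample-count (FedAvg) weights.
        return base
    lo, hi = min(consistency), max(consistency)
    if hi - lo > 1e-6:
        # Min-max normalize consistency scores to [0, 1].
        cons = [(c - lo) / (hi - lo) for c in consistency]
    else:
        cons = [0.5] * len(consistency)
    # Lower loss -> larger factor 1 / (1 + loss).
    raw = [b * (0.5 + 0.5 * c) / (1.0 + l) for b, c, l in zip(base, cons, avg_losses)]
    s = sum(raw)
    return [w / s for w in raw] if s > 0 else base


def ema_trajectory(deltas: Sequence[float], beta: float = 0.9, server_lr: float = 1.0) -> List[float]:
    """Scalar version of m <- beta*m + (1-beta)*delta; x <- x + server_lr*m."""
    m, x, traj = 0.0, 0.0, []
    for d in deltas:
        m = beta * m + (1.0 - beta) * d
        x += server_lr * m
        traj.append(x)  # cumulative displacement of the global model
    return traj


def select_by_recent_loss(history: Dict[int, List[float]], num_available: int, k: int) -> List[int]:
    """Top-k clients by mean recent loss; unseen clients score +inf."""
    scores = []
    for idx in range(num_available):
        losses = history.get(idx)
        scores.append((idx, sum(losses) / len(losses) if losses else math.inf))
    # list.sort stays stable with reverse=True, so tied clients (e.g. all
    # unseen clients at +inf) keep ascending index order instead of being shuffled.
    scores.sort(key=lambda t: t[1], reverse=True)
    return [idx for idx, _ in scores[:k]]


if __name__ == "__main__":
    print(combine_weights([100, 300], [0.0, 0.0], [0.0, 0.0], round_num=3))  # [0.25, 0.75]
    print(ema_trajectory([1.0, 1.0, 1.0]))
    print(select_by_recent_loss({0: [1.0], 1: [3.0]}, num_available=4, k=3))  # [2, 3, 1]
```

In this sketch, `ema_trajectory([1.0, 1.0, 1.0])` moves the model by about 0.1, 0.29 and 0.561 after the first three rounds. With `beta = 0.9`, the dampened update applies only `1 - beta` of the averaged delta in the first aggregated round, where plain FedAvg with `server_lr = 1.0` would apply all of it. The sketch also shows that the `1 / (1 + loss)` factor down-weights the same high-loss clients that the selection rule prefers.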
Tool Input
1{
2 "n": 2
3}
[submit] Submitting result from test #2 as final.
### cifar10 (scripts/cifar10.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0870 round_time=5.0s
TRAIN_METRICS round=185 avg_loss=0.0943 round_time=5.0s
TRAIN_METRICS round=190 avg_loss=0.0866 round_time=5.0s
TRAIN_METRICS round=195 avg_loss=0.0860 round_time=4.9s
TRAIN_METRICS round=200 avg_loss=0.0882 round_time=5.0s
Evaluation progress:
EVAL round=160 test_loss=9.5152 test_accuracy=0.3915 elapsed=815s
EVAL round=170 test_loss=10.4654 test_accuracy=0.3956 elapsed=869s
EVAL round=180 test_loss=10.7825 test_accuracy=0.4039 elapsed=921s
EVAL round=190 test_loss=10.9166 test_accuracy=0.4054 elapsed=973s
EVAL round=200 test_loss=11.4390 test_accuracy=0.4035 elapsed=1025s
Final metrics (cifar10):
test_accuracy_cifar10: 0.4035
test_loss_cifar10: 11.4390
best_accuracy_cifar10: 0.4054
### femnist (scripts/femnist.sh)
Training metrics (last rounds):
TRAIN_METRICS round=180 avg_loss=0.0016 round_time=58.9s
TRAIN_METRICS