mlsys-moe-load-balance

Tags: ML Systems · eplb · rigorous codebase

Description

MoE Expert Parallelism Load Balancing

Research Question

Design an efficient expert placement algorithm for Mixture-of-Experts (MoE) inference that assigns expert replicas to GPUs to minimize load imbalance while keeping the rebalancing algorithm runtime low.

Background

In MoE models (e.g., DeepSeek-V2/V3), different experts receive different amounts of traffic depending on the input distribution. During inference, experts are distributed across GPUs, and load imbalance causes some GPUs to become bottlenecks. The Expert Parallelism Load Balancer (EPLB) runs periodically to rebalance expert placement as workload patterns change.

The standard three-stage hierarchical algorithm is:

  1. Group-to-node packing: Distribute expert groups across server nodes to balance inter-node load
  2. Expert replication: Create additional replicas of popular (hot) experts within each node
  3. Replica-to-GPU packing: Assign physical expert replicas to GPUs within each node

The baseline greedy bin-packing approach uses Python for-loops to assign items one at a time, which is correct but slow (~540 ms for medium configs). Vectorized tensor operations can achieve equivalent balance quality while running orders of magnitude faster.
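To make the baseline concrete, here is a minimal sketch of loop-based greedy balanced packing: sort items by weight descending and drop each into the lightest pack that still has room. This is an illustrative sketch (the function name and exact tie-breaking are assumptions, not the reference EPLB code); the vectorized variants replace the inner Python loop with sorted tensor assignment.

```python
import torch

def greedy_balanced_packing(weight: torch.Tensor, num_packs: int):
    """Greedy baseline sketch: place each item (heaviest first) into the
    currently lightest pack that still has a free slot.
    weight: [L, N] per-layer item loads; each pack holds exactly N // num_packs items."""
    L, N = weight.shape
    assert N % num_packs == 0
    cap = N // num_packs
    pack_index = torch.full_like(weight, -1, dtype=torch.long)
    rank_in_pack = torch.full_like(pack_index, -1)
    for layer in range(L):
        order = weight[layer].argsort(descending=True)
        loads = [0.0] * num_packs
        counts = [0] * num_packs
        for item in order.tolist():
            # lightest pack that still has room
            pack = min((p for p in range(num_packs) if counts[p] < cap),
                       key=loads.__getitem__)
            pack_index[layer, item] = pack
            rank_in_pack[layer, item] = counts[pack]
            loads[pack] += weight[layer, item].item()
            counts[pack] += 1
    return pack_index, rank_in_pack
```

The per-item Python loop over `num_packs` candidates is exactly the cost that vectorized approaches eliminate.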

Task

Modify the editable section of custom_eplb.py to implement an expert placement algorithm. You must implement:

  • balanced_packing(weight, num_packs) — pack weighted items into balanced packs
  • replicate_experts(weight, num_phy) — decide expert replication counts and assign physical IDs
  • rebalance_experts(weight, num_replicas, num_groups, num_nodes, num_gpus) — main entry point combining all three stages
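For the replication stage, one natural greedy policy is to give every expert one replica and then repeatedly add a replica to the expert whose replicas currently carry the most load each. The sketch below (names and return shapes are assumptions for illustration, not the required interface) shows that idea:

```python
import torch

def replicate_experts_sketch(weight: torch.Tensor, num_phy: int):
    """Greedy replication sketch. weight: [L, E] per-expert load.
    Returns phy2log [L, num_phy] (logical expert per physical slot) and
    logcnt [L, E] (replica count per logical expert)."""
    L, E = weight.shape
    assert num_phy >= E  # every expert needs at least one replica
    logcnt = torch.ones(L, E, dtype=torch.long)
    phy2log = torch.arange(E).repeat(L, 1)  # first E slots: one per expert
    extra = []
    for _ in range(num_phy - E):
        # expert whose replicas each carry the highest load right now
        hot = (weight / logcnt).argmax(dim=-1)  # [L]
        extra.append(hot)
        logcnt[torch.arange(L), hot] += 1
    if extra:
        phy2log = torch.cat([phy2log, torch.stack(extra, dim=1)], dim=1)
    return phy2log, logcnt
```

The Python loop runs `num_phy - E` times regardless of `E`, so a vectorized or batched version of this argmax loop is where most runtime savings come from at larger scales.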

Interface

def rebalance_experts(weight, num_replicas, num_groups, num_nodes, num_gpus):
    """
    Args:
        weight: [L, E] tensor — token load per expert per layer
        num_replicas: total physical expert slots (multiple of num_gpus)
        num_groups: number of expert groups (divisor of E)
        num_nodes: number of server nodes
        num_gpus: total GPUs (multiple of num_nodes)

    Returns:
        phy2log: [L, num_replicas] — logical expert ID for each physical slot
        log2phy: [L, E, max_rep] — physical IDs per expert (-1 = unused)
        logcnt: [L, E] — number of physical replicas per logical expert
    """

Constraints:

  • E % num_groups == 0, num_groups % num_nodes == 0
  • num_gpus % num_nodes == 0, num_replicas % num_gpus == 0
  • Each GPU must receive exactly num_replicas // num_gpus physical experts
  • Every logical expert must have at least one replica
  • logcnt.sum(-1) must equal num_replicas for every layer
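The constraints above can be checked mechanically before submission. The helper below is an illustrative sanity check written against the documented return shapes, not the benchmark's actual grader:

```python
import torch

def check_placement(phy2log, log2phy, logcnt, num_replicas, num_gpus):
    """Verify a placement against the documented constraints. Returns True
    or raises AssertionError. Shapes: phy2log [L, R], log2phy [L, E, max_rep],
    logcnt [L, E]."""
    L, R = phy2log.shape
    E = logcnt.shape[1]
    assert R == num_replicas and num_replicas % num_gpus == 0
    # every logical expert has at least one replica
    assert (logcnt >= 1).all()
    # replica counts sum to the number of physical slots
    assert (logcnt.sum(dim=-1) == num_replicas).all()
    # logcnt agrees with phy2log
    for layer in range(L):
        counts = torch.bincount(phy2log[layer], minlength=E)
        assert torch.equal(counts, logcnt[layer])
    # non-(-1) log2phy entries point back to the right expert
    valid = log2phy >= 0
    layers = torch.arange(L).view(L, 1, 1).expand_as(log2phy)
    experts = torch.arange(E).view(1, E, 1).expand_as(log2phy)
    assert (phy2log[layers[valid], log2phy[valid]] == experts[valid]).all()
    return True
```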

Evaluation

Three MoE configurations at different scales:

  • moe-small: 64 experts, 8 GPUs, 1 node, 128 replicas (2x)
  • moe-medium: 128 experts, 16 GPUs, 2 nodes, 256 replicas (2x)
  • moe-large: 256 experts, 32 GPUs, 4 nodes, 512 replicas (2x)

Metrics per configuration:

  • balance: avg_tokens_per_gpu / max_tokens_per_gpu (higher is better, 1.0 = perfect)
  • runtime_ms: median time to execute the placement algorithm (lower is better)

Both metrics matter: a fast algorithm with poor balance is not useful; a perfectly balanced algorithm that takes too long to recompute delays inference.
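The balance metric can be estimated locally from a placement. The sketch below assumes an expert's tokens are split evenly across its replicas and that consecutive blocks of `num_replicas // num_gpus` slots map to one GPU; the benchmark's exact accounting may differ:

```python
import torch

def gpu_balance(weight, phy2log, logcnt, num_gpus):
    """balance = avg_tokens_per_gpu / max_tokens_per_gpu, per layer.
    weight: [L, E], phy2log: [L, R], logcnt: [L, E]."""
    L, R = phy2log.shape
    # load carried by each physical slot: expert load / replica count
    per_slot = weight.gather(1, phy2log) / logcnt.gather(1, phy2log)  # [L, R]
    # contiguous slot-to-GPU assignment (assumption)
    per_gpu = per_slot.view(L, num_gpus, R // num_gpus).sum(dim=-1)   # [L, G]
    return per_gpu.mean(dim=-1) / per_gpu.max(dim=-1).values          # [L]
```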

Code

custom_eplb.py
1"""
2MoE Expert Parallelism Load Balancing (EPLB) Benchmark
3======================================================
4
5Design an efficient expert placement algorithm for Mixture-of-Experts (MoE)
6inference that assigns expert replicas to GPUs to minimize load imbalance
7while keeping the rebalancing algorithm runtime low.
8
9Metrics:
10 - balance: avg_tokens_per_gpu / max_tokens_per_gpu (higher is better, 1.0 = perfect)
11 - runtime_ms: time to run the placement algorithm (lower is better)
12
13Available libraries: torch, numpy
14"""
15

Results

| Model | Type | balance (deepseek-v3) | runtime ms (deepseek-v3) | balance (qwen3-moe) | runtime ms (qwen3-moe) | balance (deepseek-v2) | runtime ms (deepseek-v2) | balance (gpt-oss) | runtime ms (gpt-oss) |
|---|---|---|---|---|---|---|---|---|---|
| flat_zigzag | baseline | 0.980 | 5.851 | 0.975 | 2.664 | 0.983 | 2.552 | 0.995 | 1.846 |
| greedy | baseline | 0.680 | 245.164 | 0.939 | 103.323 | 0.927 | 182.243 | 0.987 | 132.558 |
| zigzag | baseline | 0.660 | 1.508 | 0.917 | 1.256 | 0.907 | 1.226 | 0.976 | 2.195 |