Agent Conversation: openai/gpt-5.4

Model: openai/gpt-5.4
Total steps: 5
Messages: 12
User, Step 0

Task: llm-dllm-demask-strategy

Masked Diffusion LM: Demasking Strategy

Research Question

Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:

  • Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
  • Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)

Background

Masked diffusion LMs (LLaDA, Dream) generate text by starting from a fully masked generation region and iteratively unmasking it over a fixed number of denoising iterations (steps). A demasking strategy decides at each step:

  1. Schedule: how many tokens to unmask
  2. Position selection: which masked positions to unmask
  3. Token assignment: what token id to place
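
The three decisions above can be illustrated with a toy, pure-Python step. This is only a sketch of a plain confidence-ranked baseline, not the task's reference decoder: the `MASK` value, the `demask_step` helper, and the tiny probability table are all hypothetical.

```python
MASK = -1  # hypothetical mask id used only in this toy example


def demask_step(tokens, probs, k):
    """Unmask up to k masked positions, most confident first.

    tokens: list of token ids, MASK marks still-masked positions
    probs:  probs[i][v] is the model probability of token v at position i
    k:      schedule decision, i.e. how many tokens to unmask this step
    """
    # Token assignment: greedy argmax and its probability at every position.
    best = [max(range(len(p)), key=p.__getitem__) for p in probs]
    conf = [p[b] for p, b in zip(probs, best)]
    # Position selection: rank the masked positions by confidence.
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    chosen = sorted(masked, key=lambda i: conf[i], reverse=True)[:k]
    out = list(tokens)
    for i in chosen:
        out[i] = best[i]
    return out


tokens = [7, MASK, MASK, MASK]
probs = [
    [0.0, 0.0, 0.0],    # position 0 is already decoded and is ignored
    [0.1, 0.8, 0.1],
    [0.4, 0.3, 0.3],
    [0.05, 0.05, 0.9],
]
print(demask_step(tokens, probs, k=2))  # [7, 1, -1, 2]
```

Positions 3 and 1 are unmasked first because their top probabilities (0.9 and 0.8) are the highest; position 2 stays masked for a later step.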

Decoding can be semi-autoregressive (when block_length < gen_length, process one block at a time) or fully parallel (block_length == gen_length, all positions decoded together).

What You Can Modify

Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py (lines 59-151).

Interface

import torch

class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)
        ...

get_num_transfer_tokens(mask, steps) is available outside the editable region — returns the uniform schedule (mask.sum() // steps per step).
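
As a rough illustration of what a uniform schedule looks like, here is a pure-Python sketch. It is not the actual `get_num_transfer_tokens` (which lives outside the editable region and operates on a mask tensor); the name `uniform_transfer_schedule` is hypothetical, and the handling of the remainder is an assumption.

```python
def uniform_transfer_schedule(num_masked, steps):
    """Return how many tokens to unmask at each of `steps` steps."""
    base, remainder = divmod(num_masked, steps)
    # Assumption: leftover tokens are handed to the earliest steps.
    return [base + (1 if i < remainder else 0) for i in range(steps)]


print(uniform_transfer_schedule(10, 4))          # [3, 3, 2, 2]
print(sum(uniform_transfer_schedule(224, 256)))  # 224
```

With fewer masked tokens than steps (224 tokens over 256 steps), the base quota is 0 and some steps receive no tokens under this sketch, so a decoder that wants progress on every forward pass has to enforce its own minimum.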

Constraints

  • gen_length % block_length == 0. When equal, decoding is fully parallel.
  • Process blocks sequentially (no early-decoding into later blocks).
  • Always return [1, prompt_len + gen_length].
  • used_steps counts model forward passes (lower = more efficient).
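
The divisibility and block-ordering constraints can be sketched as a small helper that yields the index range of each block in decoding order. The name `block_spans` is hypothetical and is not part of the evaluation harness.

```python
def block_spans(prompt_len, gen_length, block_length):
    """Return [(start, end), ...] index ranges of the blocks, left to right."""
    if block_length <= 0 or gen_length % block_length != 0:
        raise ValueError("gen_length must be a positive multiple of block_length")
    num_blocks = gen_length // block_length
    return [
        (prompt_len + b * block_length, prompt_len + (b + 1) * block_length)
        for b in range(num_blocks)
    ]


# Semi-autoregressive: four blocks of 64 after a 10-token prompt.
print(block_spans(10, 256, 64))   # [(10, 74), (74, 138), (138, 202), (202, 266)]
# Fully parallel: a single block covering the whole generation region.
print(block_spans(32, 224, 224))  # [(32, 256)]
```

The last span always ends at `prompt_len + gen_length`, which matches the required output length.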

Evaluation

Benchmarks

| Label | Task | Model | gen_len | steps | block_len | Metrics |
|---|---|---|---|---|---|---|
| llada-math | MATH-500 | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| llada-humaneval | HumanEval (164) | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| dream-text | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224 | 256 | 224 | gen_ppl + MAUVE + entropy + rep2 + avg_steps |

Metrics

| Metric | Direction | Where | Description |
|---|---|---|---|
| accuracy | | math/humaneval | exact-match (MATH) or pass@1 (HumanEval) |
| gen_ppl | | text | Conditional perplexity via GPT-2-Large |
| mauve | | text | Distributional similarity to C4 reference text |
| entropy | | text | Bigram entropy (lexical diversity) |
| rep2 | | text | Repeated bigram ratio |
| avg_steps | | all | Actual model forward passes used |
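
The two bigram-based text metrics can be sketched as follows. The task description does not give the benchmark's exact definitions, so the formulas here are assumptions: rep2 is taken as one minus the fraction of distinct bigrams, entropy as Shannon entropy (base 2) over the bigram distribution, and tokens as whitespace-split words.

```python
import math
from collections import Counter


def bigrams(tokens):
    """Adjacent token pairs of a sequence."""
    return list(zip(tokens, tokens[1:]))


def rep2(tokens):
    """Assumed definition: fraction of bigrams that repeat an earlier bigram."""
    grams = bigrams(tokens)
    if not grams:
        return 0.0
    return 1.0 - len(set(grams)) / len(grams)


def bigram_entropy(tokens):
    """Assumed definition: Shannon entropy (bits) of the bigram distribution."""
    grams = bigrams(tokens)
    if not grams:
        return 0.0
    total = len(grams)
    counts = Counter(grams)
    return -sum(c / total * math.log2(c / total) for c in counts.values())


sample = "a b a b c".split()
print(rep2(sample))            # 0.25
print(bigram_entropy(sample))  # 1.5
```

In the sample, the bigram `(a, b)` occurs twice among four bigrams, so one of four is a repeat and the distribution has probabilities 0.5, 0.25, 0.25.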

Assistant, Step 1: edit
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "content": "class DemaskDecoder:\n    \"\"\"Consensus-Frontier Adaptive Demasking (CFAD).\n\n    Key ideas:\n      1) use token-level consensus across recent denoising steps rather than\n         full-vocab KL only;\n      2) always transfer at least the uniform quota, avoiding KLASS-style stalls\n         when only a few positions are marked ready;\n      3) add a soft left-to-right frontier bias in semi-autoregressive blocks,\n         while turning that bias off for fully-parallel decoding.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n        # Hardcoded algorithmic constants.\n        self.relaxed_conf_factor = 0.75\n        self.min_ready_margin = 0.05\n        self.ready_bonus = 1.0\n        self.frontier_strength = 0.15\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        prompt_len = input_ids.shape[1]\n        x = torch.full((1, prompt_len + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :prompt_len] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n        used = 0\n\n        semi_ar_strength = max(0.0, 1.0 - (block_length / max(gen_length, 1)))\n        denom = float(max(block_length - 1, 1))\n        pos_bias = (1.0 - torch.arange(block_length, device=x.device,\n                                       dtype=torch.float32) / denom).view(1, -1)\n        pos_bias = pos_bias * (semi_ar_strength * self.frontier_strength)\n        relaxed_conf = max(0.55, self.conf_threshold * self.relaxed_conf_factor)\n\n        for b in range(num_blocks):\n            bs = prompt_len + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens((x[:, bs:be] == mid), steps_per_block)\n\n            # Block-local history: much cheaper than storing full-vocab histories.\n            tok_hist = torch.full((1, block_length, self.history_length), -1,\n                                  dtype=torch.long, device=x.device)\n            conf_hist = torch.zeros((1, block_length, self.history_length),\n                                    dtype=torch.float32, device=x.device)\n\n            for step in range(steps_per_block):\n                mask_block = (x[:, bs:be] == mid)\n                if not mask_block.any():\n                    break\n\n                logits = model(x).logits[:, bs:be, :].to(torch.float32)\n                top2_logits, top2_idx = torch.topk(logits, k=2, dim=-1)\n                x0 = top2_idx[..., 0]\n                log_z = torch.logsumexp(logits, dim=-1)\n                conf = torch.exp(top2_logits[..., 0] - log_z)\n                p2 = torch.exp(top2_logits[..., 1] - log_z)\n                margin = conf - p2\n\n                valid_hist = conf_hist > 0\n                hist_count = valid_hist.float().sum(-1)\n                same_tok = (tok_hist == x0.unsqueeze(-1)) & valid_hist\n                same_count = same_tok.float().sum(-1)\n                agreement = same_count / hist_count.clamp_min(1.0)\n                same_conf = (conf_hist * same_tok.float()).sum(-1) / same_count.clamp_min(1.0)\n                enough_hist = hist_count >= max(1, self.history_length - 1)\n                consensus = enough_hist & (agreement > 0.999)\n\n                ready = mask_block & (\n                    ((conf > self.conf_threshold) & (margin > self.min_ready_margin)) |\n                    (consensus & (conf > relaxed_conf) & (same_conf > relaxed_conf * 0.9))\n                )\n\n                masked_frac = mask_block.float().sum(dim=1, keepdim=True) / float(block_length)\n                score = conf + 0.75 * margin + 0.35 * same_conf + 0.25 * agreement\n                score = score + pos_bias * masked_frac + ready.float() * self.ready_bonus\n\n                xfer = torch.zeros_like(mask_block)\n                for j in range(mask_block.shape[0]):\n                    remain = int(mask_block[j].sum().item())\n                    base_k = max(1, int(num_xfer[j, step].item()))\n                    if step == steps_per_block - 1 or remain <= base_k:\n                        target_k = remain\n                    else:\n                        ready_count = int(ready[j].sum().item())\n                        maturity = 1.0 - remain / float(block_length)\n                        extra = int(max(0, ready_count - base_k) * (0.25 + 0.5 * maturity))\n                        if ready_count >= 2 * base_k:\n                            extra = max(extra, min(ready_count - base_k, base_k))\n                        target_k = min(remain, base_k + extra)\n\n                    s = score[j].clone()\n                    s[~mask_block[j]] = -float(\"inf\")\n                    _, topk = torch.topk(s, target_k)\n                    xfer[j, topk] = True\n\n                chosen = x0\n                if self.temperature > 0 and xfer.any():\n                    chosen = x0.clone()\n                    temp = max(float(self.temperature), 1e-6)\n                    samp_probs = F.softmax(logits[xfer] / temp, dim=-1)\n                    chosen[xfer] = torch.multinomial(samp_probs, 1).squeeze(-1)\n\n                x[:, bs:be] = torch.where(xfer, chosen, x[:, bs:be])\n                used += 1\n\n                tok_hist = torch.roll(tok_hist, -1, dims=-1)\n                tok_hist[..., -1] = x0\n                conf_hist = torch.roll(conf_hist, -1, dims=-1)\n                conf_hist[..., -1] = conf\n\n        return x, used",
  "after_line": 0,
  "start_line": 59,
  "end_line": 151
}
Tool Result, Step 1
OK: Replaced lines 59..151 with 125 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–183.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–183 | total: 523 lines]
    59  class DemaskDecoder:
    60      """Consensus-Frontier Adaptive Demasking (CFAD).
    61  
       ... (119 more lines) ...
   181                  conf_hist[..., -1] = conf
   182  
   183          return x, used
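
The readiness rule and the adaptive per-step quota in the Step 1 edit can be restated for a single position and a single block row in plain Python. This is a scalar sketch of the tensor logic in that edit, using the same constants; the function names `is_ready` and `target_k` are hypothetical wrappers introduced here for illustration.

```python
def is_ready(conf, margin, consensus, same_conf,
             conf_threshold=0.9, min_ready_margin=0.05,
             relaxed_conf_factor=0.75):
    """A masked position is ready if it is confident, or stable across steps."""
    relaxed_conf = max(0.55, conf_threshold * relaxed_conf_factor)
    confident = conf > conf_threshold and margin > min_ready_margin
    stable = consensus and conf > relaxed_conf and same_conf > relaxed_conf * 0.9
    return confident or stable


def target_k(remain, base_k, ready_count, block_length, last_step=False):
    """Number of tokens to unmask: at least the uniform quota, more if ready."""
    base_k = max(1, base_k)
    if last_step or remain <= base_k:
        return remain  # flush whatever is still masked
    maturity = 1.0 - remain / float(block_length)
    extra = int(max(0, ready_count - base_k) * (0.25 + 0.5 * maturity))
    if ready_count >= 2 * base_k:
        extra = max(extra, min(ready_count - base_k, base_k))
    return min(remain, base_k + extra)


print(is_ready(0.95, 0.50, False, 0.0))   # True  (high-confidence path)
print(is_ready(0.70, 0.30, True, 0.65))   # True  (consensus path)
print(is_ready(0.70, 0.30, False, 0.65))  # False
print(target_k(64, 1, 10, 64))            # 3
print(target_k(32, 1, 10, 64))            # 5
```

With the default thresholds the relaxed confidence bound is 0.675, so a position predicting the same token as in its recorded history can be released well below the 0.9 threshold. The quota grows with `maturity`, i.e. with the share of the block that is already decoded.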
Assistant, Step 2: test
Tool Result, Step 2
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.309 avg_steps=82.6
TRAIN_METRICS: math 470/500 acc=0.311 avg_steps=82.4
TRAIN_METRICS: math 480/500 acc=0.310 avg_steps=82.6
TRAIN_METRICS: math 490/500 acc=0.310 avg_steps=82.6
TRAIN_METRICS: math 500/500 acc=0.308 avg_steps=83.0
accuracy_llada-math: 0.3080
avg_steps_llada-math: 83.0400
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.450 avg_steps=72.0
TRAIN_METRICS: humaneval 130/164 pass@1=0.438 avg_steps=72.6
TRAIN_METRICS: humaneval 140/164 pass@1=0.407 avg_steps=72.6
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=72.2
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=71.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 71.6800
n_samples_llada
Assistant, Step 3: edit
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "content": "class DemaskDecoder:\n    \"\"\"Consensus-Frontier Adaptive Demasking (CFAD).\n\n    Key ideas:\n      1) token-consensus stability across recent denoising steps;\n      2) a minimum-transfer schedule so decoding never stalls;\n      3) regime-aware frontier control: semi-autoregressive blocks decode more\n         locally, while fully-parallel decoding remains aggressive.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n        # Hardcoded algorithmic constants.\n        self.relaxed_conf_factor = 0.75\n        self.min_ready_margin = 0.05\n        self.ready_bonus = 1.0\n        self.frontier_strength = 0.25\n        self.frontier_min_width = 8\n        self.frontier_growth = 0.35\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        prompt_len = input_ids.shape[1]\n        x = torch.full((1, prompt_len + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :prompt_len] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n        used = 0\n\n        semi_ar_strength = max(0.0, 1.0 - (block_length / max(gen_length, 1)))\n        denom = float(max(block_length - 1, 1))\n        pos_ids = torch.arange(block_length, device=x.device)\n        pos_bias = (1.0 - pos_ids.to(torch.float32) / denom).view(1, -1)\n        pos_bias = pos_bias * (semi_ar_strength * self.frontier_strength)\n        relaxed_conf = max(0.55, self.conf_threshold * self.relaxed_conf_factor)\n\n        for b in range(num_blocks):\n            bs = prompt_len + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens((x[:, bs:be] == mid), steps_per_block)\n\n            tok_hist = torch.full((1, block_length, self.history_length), -1,\n                                  dtype=torch.long, device=x.device)\n            conf_hist = torch.zeros((1, block_length, self.history_length),\n                                    dtype=torch.float32, device=x.device)\n\n            for step in range(steps_per_block):\n                mask_block = (x[:, bs:be] == mid)\n                if not mask_block.any():\n                    break\n\n                logits = model(x).logits[:, bs:be, :].to(torch.float32)\n                top2_logits, top2_idx = torch.topk(logits, k=2, dim=-1)\n                x0 = top2_idx[..., 0]\n                log_z = torch.logsumexp(logits, dim=-1)\n                conf = torch.exp(top2_logits[..., 0] - log_z)\n                p2 = torch.exp(top2_logits[..., 1] - log_z)\n                margin = conf - p2\n\n                valid_hist = conf_hist > 0\n                hist_count = valid_hist.float().sum(-1)\n                same_tok = (tok_hist == x0.unsqueeze(-1)) & valid_hist\n                same_count = same_tok.float().sum(-1)\n                agreement = same_count / hist_count.clamp_min(1.0)\n                same_conf = (conf_hist * same_tok.float()).sum(-1) / same_count.clamp_min(1.0)\n                enough_hist = hist_count >= max(1, self.history_length - 1)\n                consensus = enough_hist & (agreement > 0.999)\n\n                ready = mask_block & (\n                    ((conf > self.conf_threshold) & (margin > self.min_ready_margin)) |\n                    (consensus & (conf > relaxed_conf) & (same_conf > relaxed_conf * 0.9))\n                )\n\n                masked_frac = mask_block.float().sum(dim=1, keepdim=True) / float(block_length)\n                score = conf + 0.75 * margin + 0.35 * same_conf + 0.25 * agreement\n                score = score + pos_bias * masked_frac + ready.float() * self.ready_bonus\n\n                xfer = torch.zeros_like(mask_block)\n                for j in range(mask_block.shape[0]):\n                    remain = int(mask_block[j].sum().item())\n                    base_k = max(1, int(num_xfer[j, step].item()))\n                    maturity = 1.0 - remain / float(block_length)\n\n                    if step == steps_per_block - 1 or remain <= base_k:\n                        target_k = remain\n                    else:\n                        ready_count = int(ready[j].sum().item())\n                        if semi_ar_strength > 0:\n                            extra_scale = max(0.0, maturity - 0.25) / 0.75\n                            extra = int(max(0, ready_count - base_k) * extra_scale)\n                            if ready_count >= 4 * base_k and maturity > 0.6:\n                                extra = max(extra, min(ready_count - base_k, 2 * base_k))\n                        else:\n                            extra = int(max(0, ready_count - base_k) * (0.35 + 0.65 * maturity))\n                            if ready_count >= 2 * base_k:\n                                extra = max(extra, min(ready_count - base_k, base_k))\n                        target_k = min(remain, base_k + extra)\n\n                    s = score[j].clone()\n                    s[~mask_block[j]] = -float(\"inf\")\n\n                    if semi_ar_strength > 0 and target_k < remain:\n                        filled = block_length - remain\n                        width = max(self.frontier_min_width,\n                                    int(block_length * (0.15 + self.frontier_growth * maturity)))\n                        frontier = min(block_length, filled + width)\n                        eligible = pos_ids < frontier\n                        if int((mask_block[j] & eligible).sum().item()) >= target_k:\n                            s[~eligible] = -float(\"inf\")\n\n                    _, topk = torch.topk(s, target_k)\n                    xfer[j, topk] = True\n\n                chosen = x0\n                if self.temperature > 0 and xfer.any():\n                    chosen = x0.clone()\n                    temp = max(float(self.temperature), 1e-6)\n                    samp_probs = F.softmax(logits[xfer] / temp, dim=-1)\n                    chosen[xfer] = torch.multinomial(samp_probs, 1).squeeze(-1)\n\n                x[:, bs:be] = torch.where(xfer, chosen, x[:, bs:be])\n                used += 1\n\n                tok_hist = torch.roll(tok_hist, -1, dims=-1)\n                tok_hist[..., -1] = x0\n                conf_hist = torch.roll(conf_hist, -1, dims=-1)\n                conf_hist[..., -1] = conf\n\n        return x, used",
  "after_line": 0,
  "start_line": 59,
  "end_line": 183
}
Tool Result, Step 3
OK: Replaced lines 59..183 with 141 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–199.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–199 | total: 539 lines]
    59  class DemaskDecoder:
    60      """Consensus-Frontier Adaptive Demasking (CFAD).
    61  
       ... (135 more lines) ...
   197                  conf_hist[..., -1] = conf
   198  
   199          return x, used
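
The frontier window added in the Step 3 edit can be restated in plain Python for one block row. This is a scalar sketch of that edit's logic with the same constants; the helper names `frontier_limit` and `eligible_masked` are hypothetical. In the edit itself the restriction is applied only in the semi-autoregressive regime and only when `target_k < remain`; that guard is left to the caller here.

```python
def frontier_limit(remain, block_length,
                   frontier_min_width=8, frontier_growth=0.35):
    """Exclusive index bound of the positions that may be unmasked."""
    filled = block_length - remain
    maturity = 1.0 - remain / float(block_length)
    width = max(frontier_min_width,
                int(block_length * (0.15 + frontier_growth * maturity)))
    return min(block_length, filled + width)


def eligible_masked(mask, target_k):
    """Masked positions inside the frontier, or all masked ones as fallback."""
    block_length = len(mask)
    masked = [i for i, m in enumerate(mask) if m]
    limit = frontier_limit(len(masked), block_length)
    inside = [i for i in masked if i < limit]
    # Fallback: ignore the frontier if it holds too few masked positions.
    return inside if len(inside) >= target_k else masked


print(frontier_limit(64, 64))  # 9
print(frontier_limit(32, 64))  # 52
print(frontier_limit(4, 64))   # 64
print(eligible_masked([True] * 64, target_k=2))  # [0, 1, ..., 8]
```

Note that `filled` is a count of decoded positions rather than the index of the leftmost masked one, so the window is only an approximate left-to-right frontier when decoded positions are not contiguous.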
Assistant, Step 4: test
Tool Result, Step 4
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.335 avg_steps=149.0
TRAIN_METRICS: math 470/500 acc=0.336 avg_steps=148.9
TRAIN_METRICS: math 480/500 acc=0.338 avg_steps=149.1
TRAIN_METRICS: math 490/500 acc=0.339 avg_steps=149.2
TRAIN_METRICS: math 500/500 acc=0.336 avg_steps=149.5
accuracy_llada-math: 0.3360
avg_steps_llada-math: 149.5100
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.433 avg_steps=139.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.423 avg_steps=138.8
TRAIN_METRICS: humaneval 140/164 pass@1=0.400 avg_steps=138.3
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=137.9
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=137.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 137.6300
n_sa
Assistant, Step 5: submit
Tool Input
{
  "n": 2
}
Tool Result, Step 5
[submit] Submitting result from test #2 as final.

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.335 avg_steps=149.0
TRAIN_METRICS: math 470/500 acc=0.336 avg_steps=148.9
TRAIN_METRICS: math 480/500 acc=0.338 avg_steps=149.1
TRAIN_METRICS: math 490/500 acc=0.339 avg_steps=149.2
TRAIN_METRICS: math 500/500 acc=0.336 avg_steps=149.5
accuracy_llada-math: 0.3360
avg_steps_llada-math: 149.5100
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.433 avg_steps=139.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.423 avg_steps=138.8
TRAIN_METRICS: humaneval 140/164 pass@1=0.400 avg_steps=138.3
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=137.9
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=137.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 137.6300
n_samples_llada-humaneval: 164.0000

### dream-t