Agent Conversation: openai/gpt-5.4
Model: openai/gpt-5.4
Total steps: 5
Messages: 12
User (Step 0)
Task: llm-dllm-demask-strategy
Masked Diffusion LM: Demasking Strategy
Research Question
Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:
- Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
- Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)
Background
Masked diffusion LMs (LLaDA, Dream) generate text by starting from a fully masked
generation region and iteratively unmasking it over a fixed number of denoising iterations (the steps parameter).
A demasking strategy decides at each step:
- Schedule: how many tokens to unmask
- Position selection: which masked positions to unmask
- Token assignment: what token id to place
Decoding is semi-autoregressive when block_length < gen_length (blocks are processed one at a time, left to right)
and fully parallel when block_length == gen_length (all positions are decoded together).
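The three per-step decisions can be illustrated with a framework-free sketch of a single step. It assumes the model output for each position is already a probability list over a toy vocabulary, and it uses one common heuristic (top-1 probability as the confidence for position selection, greedy argmax for token assignment). MASK and demask_step are hypothetical names, not part of the task code.

```python
MASK = -1  # hypothetical mask id for this toy example


def demask_step(tokens, probs, k):
    """Unmask k positions in one step (greedy, confidence-ranked).

    tokens: current token ids, MASK where a position is still masked.
    probs:  per-position probability lists over a toy vocabulary.
    k:      schedule decision, i.e. how many tokens to unmask this step.
    """
    candidates = []
    for i, tok in enumerate(tokens):
        if tok != MASK:
            continue
        p = probs[i]
        best = max(range(len(p)), key=p.__getitem__)  # token assignment: argmax
        candidates.append((p[best], i, best))
    # Position selection: keep the k most confident masked positions.
    candidates.sort(reverse=True)
    out = list(tokens)
    for _, i, best in candidates[:k]:
        out[i] = best
    return out
```

Swapping the confidence score (for example, using the margin between the top two probabilities) changes only the position-selection line; the schedule and token assignment stay separate decisions.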
What You Can Modify
Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py
(lines 59-151).
Interface
```python
class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)
        ...
```
get_num_transfer_tokens(mask, steps) is available outside the editable
region; it returns the uniform schedule (mask.sum() // steps tokens per step).
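A pure-Python sketch of such a uniform schedule follows. The remainder handling is an assumption: the task text only specifies mask.sum() // steps, and here the leftover mask.sum() % steps tokens are added one per step to the earliest steps so that the schedule sums to the number of masked tokens (our understanding of LLaDA's public reference code). uniform_schedule is a hypothetical stand-in for the real helper, which works on tensors.

```python
def uniform_schedule(num_masked, steps):
    """Tokens to unmask at each step: num_masked // steps, with the remainder
    (assumed) spread one per step over the earliest steps."""
    base, remainder = divmod(num_masked, steps)
    return [base + (1 if s < remainder else 0) for s in range(steps)]


print(uniform_schedule(10, 4))  # four steps, ten masked tokens
```

Under this assumption, each LLaDA block (64 tokens, 256 / 4 = 64 steps per block) gets exactly one token per step, while the dream-text configuration (224 masked tokens, 256 steps, one block) gets a quota of 1 for the first 224 steps and 0 for the last 32. Both decoders in this log clamp the per-step quota with max(1, ...).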
Constraints
- gen_length % block_length == 0. When the two are equal, decoding is fully parallel.
- Process blocks sequentially (no early decoding into later blocks).
- Always return an output of shape [1, prompt_len + gen_length]. used_steps counts model forward passes (lower = more efficient).
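These constraints fix the outer loop structure of any decoder. The sketch below is framework-free and hypothetical: toy_model stands in for a forward pass and returns per-position probabilities, the schedule is the uniform one with leftover tokens assigned to early steps (an assumption), and position selection is plain confidence ranking. It only illustrates the contract: blocks are decoded strictly left to right, the output always has length prompt_len + gen_length, and used counts forward passes, including the saving from leaving a block as soon as it has no masks left.

```python
MASK = -1  # hypothetical mask id for this toy example


def decode(model, prompt, gen_length, steps, block_length):
    """Block-sequential demasking loop; returns (tokens, forward_passes)."""
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks
    x = list(prompt) + [MASK] * gen_length
    used = 0
    for b in range(num_blocks):
        bs = len(prompt) + b * block_length
        be = bs + block_length
        base, rem = divmod(block_length, steps_per_block)
        for step in range(steps_per_block):
            masked = [i for i in range(bs, be) if x[i] == MASK]
            if not masked:
                break  # block already finished: skip the remaining passes
            probs = model(x)  # one forward pass over the whole sequence
            used += 1
            # Uniform quota, clamped to at least one token per pass.
            k = max(1, base + (1 if step < rem else 0))
            if step == steps_per_block - 1:
                k = len(masked)  # last step of the block: commit everything left
            # Only positions inside the current block are candidates.
            ranked = sorted(masked, key=lambda i: max(probs[i]), reverse=True)
            for i in ranked[:k]:
                x[i] = max(range(len(probs[i])), key=probs[i].__getitem__)
    return x, used


def toy_model(x):
    """Hypothetical model: position i prefers token i % 3, and confidence
    decreases with i, so earlier positions tend to be unmasked first."""
    probs = []
    for i in range(len(x)):
        conf = 0.9 - 0.01 * i
        p = [0.0, 0.0, 0.0]
        p[i % 3] = conf
        p[(i + 1) % 3] = 1.0 - conf
        probs.append(p)
    return probs


tokens, passes = decode(toy_model, [7, 7], gen_length=8, steps=8, block_length=4)
print(tokens, passes)
```

With 16 steps and a single 8-token block, the same call stops after 8 passes: every pass commits at least one token, and the loop exits once the block has no masks left.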
Evaluation
Benchmarks
| Label | Task | Model | gen_len | steps | block_len | Metrics |
|---|---|---|---|---|---|---|
| llada-math | MATH-500 | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| llada-humaneval | HumanEval (164) | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| dream-text | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224 | 256 | 224 | gen_ppl + MAUVE + entropy + rep2 + avg_steps |
Metrics
| Metric | Direction | Where | Description |
|---|---|---|---|
| accuracy | ↑ | math/humaneval | exact-match (MATH) or pass@1 (HumanEval) |
| gen_ppl | ↓ | text | Conditional perplexity via GPT-2-Large |
| mauve | ↑ | text | Distributional similarity to C4 reference text |
| entropy | ↑ | text | Bigram entropy (lexical diversity) |
| rep2 | ↓ | text | Repeated bigram ratio |
| avg_steps | ↓ | all | Actual model forward passes used |
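The two lexical-diversity metrics can be sketched as follows. The table only names them, so the exact formulas here are assumptions: rep2 is taken as the fraction of bigram occurrences that repeat an earlier bigram (1 - unique/total), and entropy as the Shannon entropy, in bits, of the empirical bigram distribution, both over whitespace tokens of a single sample. The evaluation harness may tokenize, normalize, or aggregate differently.

```python
import math
from collections import Counter


def bigrams(text):
    """Adjacent word pairs over whitespace tokens (assumed tokenization)."""
    words = text.split()
    return list(zip(words, words[1:]))


def rep2(text):
    """Share of bigram occurrences that duplicate another bigram (lower is better)."""
    bg = bigrams(text)
    if not bg:
        return 0.0
    return 1.0 - len(set(bg)) / len(bg)


def bigram_entropy(text):
    """Shannon entropy in bits of the empirical bigram distribution (higher is better)."""
    counts = Counter(bigrams(text))
    total = sum(counts.values())
    return -sum(c / total * math.log2(c / total) for c in counts.values())


print(rep2("the cat the cat the cat"), bigram_entropy("a b c d"))
```

A degenerate, looping continuation drives rep2 up and bigram entropy down at the same time, which is why the two are reported together.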
Assistant (Step 1): edit
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "content": "class DemaskDecoder:\n \"\"\"Consensus-Frontier Adaptive Demasking (CFAD).\n\n Key ideas:\n 1) use token-level consensus across recent denoising steps rather than\n full-vocab KL only;\n 2) always transfer at least the uniform quota, avoiding KLASS-style stalls\n when only a few positions are marked ready;\n 3) add a soft left-to-right frontier bias in semi-autoregressive blocks,\n while turning that bias off for fully-parallel decoding.\n \"\"\"\n\n def __init__(self, mask_id: int, temperature: float = 0.0,\n conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n history_length: int = 2):\n self.mask_id = mask_id\n self.temperature = temperature\n self.conf_threshold = conf_threshold\n self.kl_threshold = kl_threshold\n self.history_length = history_length\n\n # Hardcoded algorithmic constants.\n self.relaxed_conf_factor = 0.75\n self.min_ready_margin = 0.05\n self.ready_bonus = 1.0\n self.frontier_strength = 0.15\n\n @torch.no_grad()\n def decode(self, model, input_ids, gen_length: int, steps: int,\n block_length: int):\n mid = self.mask_id\n prompt_len = input_ids.shape[1]\n x = torch.full((1, prompt_len + gen_length), mid,\n dtype=torch.long, device=model.device)\n x[:, :prompt_len] = input_ids.clone()\n assert gen_length % block_length == 0\n num_blocks = gen_length // block_length\n assert steps % num_blocks == 0\n steps_per_block = steps // num_blocks\n used = 0\n\n semi_ar_strength = max(0.0, 1.0 - (block_length / max(gen_length, 1)))\n denom = float(max(block_length - 1, 1))\n pos_bias = (1.0 - torch.arange(block_length, device=x.device,\n dtype=torch.float32) / denom).view(1, -1)\n pos_bias = pos_bias * (semi_ar_strength * self.frontier_strength)\n relaxed_conf = max(0.55, self.conf_threshold * self.relaxed_conf_factor)\n\n for b in range(num_blocks):\n bs = prompt_len + b * block_length\n be = bs + block_length\n num_xfer = get_num_transfer_tokens((x[:, bs:be] == mid), steps_per_block)\n\n # Block-local history: much cheaper than 
storing full-vocab histories.\n tok_hist = torch.full((1, block_length, self.history_length), -1,\n dtype=torch.long, device=x.device)\n conf_hist = torch.zeros((1, block_length, self.history_length),\n dtype=torch.float32, device=x.device)\n\n for step in range(steps_per_block):\n mask_block = (x[:, bs:be] == mid)\n if not mask_block.any():\n break\n\n logits = model(x).logits[:, bs:be, :].to(torch.float32)\n top2_logits, top2_idx = torch.topk(logits, k=2, dim=-1)\n x0 = top2_idx[..., 0]\n log_z = torch.logsumexp(logits, dim=-1)\n conf = torch.exp(top2_logits[..., 0] - log_z)\n p2 = torch.exp(top2_logits[..., 1] - log_z)\n margin = conf - p2\n\n valid_hist = conf_hist > 0\n hist_count = valid_hist.float().sum(-1)\n same_tok = (tok_hist == x0.unsqueeze(-1)) & valid_hist\n same_count = same_tok.float().sum(-1)\n agreement = same_count / hist_count.clamp_min(1.0)\n same_conf = (conf_hist * same_tok.float()).sum(-1) / same_count.clamp_min(1.0)\n enough_hist = hist_count >= max(1, self.history_length - 1)\n consensus = enough_hist & (agreement > 0.999)\n\n ready = mask_block & (\n ((conf > self.conf_threshold) & (margin > self.min_ready_margin)) |\n (consensus & (conf > relaxed_conf) & (same_conf > relaxed_conf * 0.9))\n )\n\n masked_frac = mask_block.float().sum(dim=1, keepdim=True) / float(block_length)\n score = conf + 0.75 * margin + 0.35 * same_conf + 0.25 * agreement\n score = score + pos_bias * masked_frac + ready.float() * self.ready_bonus\n\n xfer = torch.zeros_like(mask_block)\n for j in range(mask_block.shape[0]):\n remain = int(mask_block[j].sum().item())\n base_k = max(1, int(num_xfer[j, step].item()))\n if step == steps_per_block - 1 or remain <= base_k:\n target_k = remain\n else:\n ready_count = int(ready[j].sum().item())\n maturity = 1.0 - remain / float(block_length)\n extra = int(max(0, ready_count - base_k) * (0.25 + 0.5 * maturity))\n if ready_count >= 2 * base_k:\n extra = max(extra, min(ready_count - base_k, base_k))\n target_k = min(remain, 
base_k + extra)\n\n s = score[j].clone()\n s[~mask_block[j]] = -float(\"inf\")\n _, topk = torch.topk(s, target_k)\n xfer[j, topk] = True\n\n chosen = x0\n if self.temperature > 0 and xfer.any():\n chosen = x0.clone()\n temp = max(float(self.temperature), 1e-6)\n samp_probs = F.softmax(logits[xfer] / temp, dim=-1)\n chosen[xfer] = torch.multinomial(samp_probs, 1).squeeze(-1)\n\n x[:, bs:be] = torch.where(xfer, chosen, x[:, bs:be])\n used += 1\n\n tok_hist = torch.roll(tok_hist, -1, dims=-1)\n tok_hist[..., -1] = x0\n conf_hist = torch.roll(conf_hist, -1, dims=-1)\n conf_hist[..., -1] = conf\n\n return x, used",
  "after_line": 0,
  "start_line": 59,
  "end_line": 151
}
Tool Result (Step 1)
OK: Replaced lines 59..151 with 125 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–183.
[Current file: LLaDA/custom_demask_eval.py | editable: 59–183 | total: 523 lines]
59 class DemaskDecoder:
60 """Consensus-Frontier Adaptive Demasking (CFAD).
61
... (119 more lines) ...
181 conf_hist[..., -1] = conf
182
183 return x, used
Assistant (Step 2): test
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)
### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.309 avg_steps=82.6
TRAIN_METRICS: math 470/500 acc=0.311 avg_steps=82.4
TRAIN_METRICS: math 480/500 acc=0.310 avg_steps=82.6
TRAIN_METRICS: math 490/500 acc=0.310 avg_steps=82.6
TRAIN_METRICS: math 500/500 acc=0.308 avg_steps=83.0
accuracy_llada-math: 0.3080
avg_steps_llada-math: 83.0400
n_samples_llada-math: 500.0000
### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.450 avg_steps=72.0
TRAIN_METRICS: humaneval 130/164 pass@1=0.438 avg_steps=72.6
TRAIN_METRICS: humaneval 140/164 pass@1=0.407 avg_steps=72.6
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=72.2
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=71.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 71.6800
n_samples_llada
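The readiness test in the Step 1 decoder combines a one-shot confidence check with a cross-step consensus check. The scalar restatement below mirrors that tensor code for a single masked position; is_ready is a hypothetical helper name, and the defaults are the constants hard-coded in that edit (0.9 threshold, 0.05 margin, relaxed threshold max(0.55, 0.9 × 0.75) = 0.675).

```python
def is_ready(token, conf, p2, history, conf_threshold=0.9, min_margin=0.05,
             relaxed_factor=0.75, history_length=2):
    """Scalar restatement of the Step 1 readiness test for one masked position.

    token, conf, p2: current argmax token, its probability, runner-up probability.
    history: (token, conf) pairs from earlier steps, oldest first; entries with
    conf <= 0 are treated as empty slots, as in the tensor code.
    """
    relaxed = max(0.55, conf_threshold * relaxed_factor)
    # Path 1: confident now and clearly ahead of the runner-up token.
    if conf > conf_threshold and conf - p2 > min_margin:
        return True
    valid = [(t, c) for t, c in history[-history_length:] if c > 0]
    same = [c for t, c in valid if t == token]
    agreement = len(same) / max(len(valid), 1)
    same_conf = sum(same) / max(len(same), 1)
    consensus = len(valid) >= max(1, history_length - 1) and agreement > 0.999
    # Path 2: moderately confident and every recent step agreed on this token.
    return consensus and conf > relaxed and same_conf > relaxed * 0.9


print(is_ready(42, 0.7, 0.2, [(42, 0.7)]), is_ready(42, 0.7, 0.2, [(7, 0.7)]))
```

With the default history_length=2, one earlier step of agreement is enough for the relaxed path, so a token predicted at 0.7 on two consecutive passes counts as ready even though it never crosses 0.9.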
Assistant (Step 3): edit
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "content": "class DemaskDecoder:\n \"\"\"Consensus-Frontier Adaptive Demasking (CFAD).\n\n Key ideas:\n 1) token-consensus stability across recent denoising steps;\n 2) a minimum-transfer schedule so decoding never stalls;\n 3) regime-aware frontier control: semi-autoregressive blocks decode more\n locally, while fully-parallel decoding remains aggressive.\n \"\"\"\n\n def __init__(self, mask_id: int, temperature: float = 0.0,\n conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n history_length: int = 2):\n self.mask_id = mask_id\n self.temperature = temperature\n self.conf_threshold = conf_threshold\n self.kl_threshold = kl_threshold\n self.history_length = history_length\n\n # Hardcoded algorithmic constants.\n self.relaxed_conf_factor = 0.75\n self.min_ready_margin = 0.05\n self.ready_bonus = 1.0\n self.frontier_strength = 0.25\n self.frontier_min_width = 8\n self.frontier_growth = 0.35\n\n @torch.no_grad()\n def decode(self, model, input_ids, gen_length: int, steps: int,\n block_length: int):\n mid = self.mask_id\n prompt_len = input_ids.shape[1]\n x = torch.full((1, prompt_len + gen_length), mid,\n dtype=torch.long, device=model.device)\n x[:, :prompt_len] = input_ids.clone()\n assert gen_length % block_length == 0\n num_blocks = gen_length // block_length\n assert steps % num_blocks == 0\n steps_per_block = steps // num_blocks\n used = 0\n\n semi_ar_strength = max(0.0, 1.0 - (block_length / max(gen_length, 1)))\n denom = float(max(block_length - 1, 1))\n pos_ids = torch.arange(block_length, device=x.device)\n pos_bias = (1.0 - pos_ids.to(torch.float32) / denom).view(1, -1)\n pos_bias = pos_bias * (semi_ar_strength * self.frontier_strength)\n relaxed_conf = max(0.55, self.conf_threshold * self.relaxed_conf_factor)\n\n for b in range(num_blocks):\n bs = prompt_len + b * block_length\n be = bs + block_length\n num_xfer = get_num_transfer_tokens((x[:, bs:be] == mid), steps_per_block)\n\n tok_hist = torch.full((1, block_length, self.history_length), -1,\n 
dtype=torch.long, device=x.device)\n conf_hist = torch.zeros((1, block_length, self.history_length),\n dtype=torch.float32, device=x.device)\n\n for step in range(steps_per_block):\n mask_block = (x[:, bs:be] == mid)\n if not mask_block.any():\n break\n\n logits = model(x).logits[:, bs:be, :].to(torch.float32)\n top2_logits, top2_idx = torch.topk(logits, k=2, dim=-1)\n x0 = top2_idx[..., 0]\n log_z = torch.logsumexp(logits, dim=-1)\n conf = torch.exp(top2_logits[..., 0] - log_z)\n p2 = torch.exp(top2_logits[..., 1] - log_z)\n margin = conf - p2\n\n valid_hist = conf_hist > 0\n hist_count = valid_hist.float().sum(-1)\n same_tok = (tok_hist == x0.unsqueeze(-1)) & valid_hist\n same_count = same_tok.float().sum(-1)\n agreement = same_count / hist_count.clamp_min(1.0)\n same_conf = (conf_hist * same_tok.float()).sum(-1) / same_count.clamp_min(1.0)\n enough_hist = hist_count >= max(1, self.history_length - 1)\n consensus = enough_hist & (agreement > 0.999)\n\n ready = mask_block & (\n ((conf > self.conf_threshold) & (margin > self.min_ready_margin)) |\n (consensus & (conf > relaxed_conf) & (same_conf > relaxed_conf * 0.9))\n )\n\n masked_frac = mask_block.float().sum(dim=1, keepdim=True) / float(block_length)\n score = conf + 0.75 * margin + 0.35 * same_conf + 0.25 * agreement\n score = score + pos_bias * masked_frac + ready.float() * self.ready_bonus\n\n xfer = torch.zeros_like(mask_block)\n for j in range(mask_block.shape[0]):\n remain = int(mask_block[j].sum().item())\n base_k = max(1, int(num_xfer[j, step].item()))\n maturity = 1.0 - remain / float(block_length)\n\n if step == steps_per_block - 1 or remain <= base_k:\n target_k = remain\n else:\n ready_count = int(ready[j].sum().item())\n if semi_ar_strength > 0:\n extra_scale = max(0.0, maturity - 0.25) / 0.75\n extra = int(max(0, ready_count - base_k) * extra_scale)\n if ready_count >= 4 * base_k and maturity > 0.6:\n extra = max(extra, min(ready_count - base_k, 2 * base_k))\n else:\n extra = int(max(0, ready_count 
- base_k) * (0.35 + 0.65 * maturity))\n if ready_count >= 2 * base_k:\n extra = max(extra, min(ready_count - base_k, base_k))\n target_k = min(remain, base_k + extra)\n\n s = score[j].clone()\n s[~mask_block[j]] = -float(\"inf\")\n\n if semi_ar_strength > 0 and target_k < remain:\n filled = block_length - remain\n width = max(self.frontier_min_width,\n int(block_length * (0.15 + self.frontier_growth * maturity)))\n frontier = min(block_length, filled + width)\n eligible = pos_ids < frontier\n if int((mask_block[j] & eligible).sum().item()) >= target_k:\n s[~eligible] = -float(\"inf\")\n\n _, topk = torch.topk(s, target_k)\n xfer[j, topk] = True\n\n chosen = x0\n if self.temperature > 0 and xfer.any():\n chosen = x0.clone()\n temp = max(float(self.temperature), 1e-6)\n samp_probs = F.softmax(logits[xfer] / temp, dim=-1)\n chosen[xfer] = torch.multinomial(samp_probs, 1).squeeze(-1)\n\n x[:, bs:be] = torch.where(xfer, chosen, x[:, bs:be])\n used += 1\n\n tok_hist = torch.roll(tok_hist, -1, dims=-1)\n tok_hist[..., -1] = x0\n conf_hist = torch.roll(conf_hist, -1, dims=-1)\n conf_hist[..., -1] = conf\n\n return x, used",
  "after_line": 0,
  "start_line": 59,
  "end_line": 183
}
Tool Result (Step 3)
OK: Replaced lines 59..183 with 141 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–199.
[Current file: LLaDA/custom_demask_eval.py | editable: 59–199 | total: 539 lines]
59 class DemaskDecoder:
60 """Consensus-Frontier Adaptive Demasking (CFAD).
61
... (135 more lines) ...
197 conf_hist[..., -1] = conf
198
199 return x, used
Assistant (Step 4): test
Tool Result (Step 4)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)
### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.335 avg_steps=149.0
TRAIN_METRICS: math 470/500 acc=0.336 avg_steps=148.9
TRAIN_METRICS: math 480/500 acc=0.338 avg_steps=149.1
TRAIN_METRICS: math 490/500 acc=0.339 avg_steps=149.2
TRAIN_METRICS: math 500/500 acc=0.336 avg_steps=149.5
accuracy_llada-math: 0.3360
avg_steps_llada-math: 149.5100
n_samples_llada-math: 500.0000
### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.433 avg_steps=139.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.423 avg_steps=138.8
TRAIN_METRICS: humaneval 140/164 pass@1=0.400 avg_steps=138.3
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=137.9
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=137.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 137.6300
n_sa
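Step 3 keeps the readiness logic but changes how many tokens each pass transfers in the semi-autoregressive regime and restricts where they may land. The function below restates that rule for one sequence as plain Python. semi_ar_transfer is a hypothetical name; the constants 8, 0.15 and 0.35 are the hard-coded frontier_min_width, base width fraction and frontier_growth from that edit. While less than a quarter of a block is filled, extra_scale is zero, so each forward pass transfers only the uniform quota base_k. The original code enforces the frontier only when at least target_k masked positions fall inside it; that check needs the mask pattern and is left to the caller here.

```python
def semi_ar_transfer(remain, block_length, quota, ready_count, is_last_step,
                     min_width=8, growth=0.35):
    """Return (target_k, frontier): how many masked tokens to commit this pass,
    and the exclusive in-block position bound that candidates come from."""
    base_k = max(1, quota)  # never stall: at least one token per pass
    maturity = 1.0 - remain / block_length  # fraction of the block already filled
    if is_last_step or remain <= base_k:
        return remain, block_length  # flush everything that is left
    # Extra transfers for "ready" tokens ramp up only after 25% of the block is filled.
    extra_scale = max(0.0, maturity - 0.25) / 0.75
    extra = int(max(0, ready_count - base_k) * extra_scale)
    if ready_count >= 4 * base_k and maturity > 0.6:
        extra = max(extra, min(ready_count - base_k, 2 * base_k))
    target_k = min(remain, base_k + extra)
    if target_k >= remain:
        return target_k, block_length
    # Frontier window: starts about 15% of the block wide and widens with maturity.
    filled = block_length - remain
    width = max(min_width, int(block_length * (0.15 + growth * maturity)))
    return target_k, min(block_length, filled + width)


print(semi_ar_transfer(64, 64, 1, 20, False))  # start of a 64-token LLaDA block
```

At the start of a 64-token block with a one-token quota, 20 ready positions still yield a single transfer inside the first 9 positions, so early progress through a block costs one forward pass per token.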
Assistant (Step 5): submit
Tool Input
{
  "n": 2
}
Tool Result (Step 5)
[submit] Submitting result from test #2 as final.
### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.335 avg_steps=149.0
TRAIN_METRICS: math 470/500 acc=0.336 avg_steps=148.9
TRAIN_METRICS: math 480/500 acc=0.338 avg_steps=149.1
TRAIN_METRICS: math 490/500 acc=0.339 avg_steps=149.2
TRAIN_METRICS: math 500/500 acc=0.336 avg_steps=149.5
accuracy_llada-math: 0.3360
avg_steps_llada-math: 149.5100
n_samples_llada-math: 500.0000
### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.433 avg_steps=139.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.423 avg_steps=138.8
TRAIN_METRICS: humaneval 140/164 pass@1=0.400 avg_steps=138.3
TRAIN_METRICS: humaneval 150/164 pass@1=0.387 avg_steps=137.9
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=137.7
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 137.6300
n_samples_llada-humaneval: 164.0000
### dream-t
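The log is cut off before the dream-text results, so only the two LLaDA benchmarks can be compared across runs. A small script over the final numbers printed in the Test #1 and Test #2 logs (nothing else is assumed) makes the trade-off explicit: Test #2 gains 2.8 accuracy points on MATH-500 and none on HumanEval, while using roughly 1.8 to 1.9 times as many forward passes, still below the 256-step budget.

```python
# Final metrics copied from the Test #1 and Test #2 logs (LLaDA tasks only):
# (accuracy, avg_steps) per task.
runs = {
    "test1": {"llada-math": (0.3080, 83.04), "llada-humaneval": (0.3780, 71.68)},
    "test2": {"llada-math": (0.3360, 149.51), "llada-humaneval": (0.3780, 137.63)},
}
BUDGET = 256  # steps per sample in both LLaDA configurations


def compare(a, b):
    """Accuracy change (points), step ratio b/a, and budget share used by b."""
    rows = {}
    for task in runs[a]:
        acc_a, steps_a = runs[a][task]
        acc_b, steps_b = runs[b][task]
        rows[task] = {
            "delta_acc_points": round(100 * (acc_b - acc_a), 1),
            "step_ratio": round(steps_b / steps_a, 2),
            "budget_used_b": round(steps_b / BUDGET, 2),
        }
    return rows


for task, row in compare("test1", "test2").items():
    print(task, row)
```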