Agent Conversation: anthropic/claude-opus-4.6

Model: anthropic/claude-opus-4.6
Total steps: 7
Messages: 16
User (Step 0)

Task: llm-dllm-demask-strategy

Masked Diffusion LM: Demasking Strategy

Research Question

Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:

  • Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
  • Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)

Background

Masked diffusion LMs (LLaDA, Dream) generate by starting from a fully masked generation region and iteratively unmasking it over a fixed number of denoising iterations (the steps parameter). A demasking strategy decides at each step:

  1. Schedule: how many tokens to unmask
  2. Position selection: which masked positions to unmask
  3. Token assignment: what token id to place

Decoding can be semi-autoregressive (block_length < gen_length: blocks are processed one at a time, left to right) or fully parallel (block_length == gen_length: all positions are decoded together).
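One fully-parallel demasking step combining the three decisions (confidence-ranked position selection with greedy token assignment) can be sketched as follows. This is an illustrative helper, not the repo's code; the function name, the caller supplying `logits`, and the choice of `k` from the schedule are all assumptions:

```python
import torch
import torch.nn.functional as F

def demask_step(logits, x, mask_id, k):
    """One demasking step: unmask the k masked positions where the
    model's top-1 probability is highest (greedy confidence ranking)."""
    probs = F.softmax(logits, dim=-1)                     # [1, seq, vocab]
    x0 = probs.argmax(dim=-1)                             # token assignment
    conf = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    conf = conf.masked_fill(x != mask_id, float("-inf"))  # masked slots only
    _, top_pos = conf.topk(k, dim=-1)                     # position selection
    x = x.clone()
    x.scatter_(1, top_pos, x0.gather(1, top_pos))
    return x
```

A real strategy replaces the pure-confidence ranking with something richer (margin, stability), but the step structure stays the same.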

What You Can Modify

Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py (lines 59-151).

Interface

class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)

get_num_transfer_tokens(mask, steps) is available outside the editable region — returns the uniform schedule (mask.sum() // steps per step).
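A simplified stand-in for that helper is sketched below, with the division remainder front-loaded so every masked token gets scheduled; the function name is hypothetical and the repo's get_num_transfer_tokens may distribute the remainder differently:

```python
import torch

def uniform_transfer_schedule(mask, steps):
    """Split the number of masked tokens as evenly as possible across
    `steps`, giving any remainder to the earliest steps so that the
    per-step counts sum to the total masked count."""
    num_masked = mask.sum(dim=-1, keepdim=True)  # [batch, 1]
    base = num_masked // steps
    sched = base.expand(-1, steps).clone()       # [batch, steps]
    rem = num_masked.squeeze(-1) % steps
    for b in range(mask.shape[0]):
        sched[b, : rem[b]] += 1                  # front-load the remainder
    return sched
```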

Constraints

  • gen_length % block_length == 0. When equal, decoding is fully parallel.
  • Process blocks sequentially (no early-decoding into later blocks).
  • Always return [1, prompt_len + gen_length].
  • used_steps counts model forward passes (lower = more efficient).
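A minimal skeleton satisfying these constraints might look like the following. The function is hypothetical, and the greedy one-pass-per-block fill is a placeholder for a real unmasking rule (a real strategy would also split the steps budget across blocks):

```python
import torch

@torch.no_grad()
def block_sequential_decode(model, input_ids, gen_length, steps,
                            block_length, mask_id):
    """Blocks are decoded strictly left to right, the output is
    [1, prompt_len + gen_length], and used_steps counts forward passes.
    `steps` is unused here; a real strategy would bound per-block
    iterations with it."""
    assert gen_length % block_length == 0
    prompt_len = input_ids.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id,
                   dtype=torch.long, device=input_ids.device)
    x[:, :prompt_len] = input_ids
    used_steps = 0
    for b in range(gen_length // block_length):
        bs = prompt_len + b * block_length
        be = bs + block_length
        logits = model(x).logits                    # one forward pass
        x[:, bs:be] = logits[:, bs:be].argmax(-1)   # fill this block only
        used_steps += 1
    return x, used_steps
```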

Evaluation

Benchmarks

Label           | Task                                                                        | Model                | gen_len | steps | block_len | Metrics
llada-math      | MATH-500                                                                    | LLaDA-8B-Instruct    | 256     | 256   | 64        | accuracy + avg_steps
llada-humaneval | HumanEval (164)                                                             | LLaDA-8B-Instruct    | 256     | 256   | 64        | accuracy + avg_steps
dream-text      | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224     | 256   | 224       | gen_ppl + MAUVE + entropy + rep2 + avg_steps

Metrics

Metric    | Direction | Where          | Description
accuracy  | ↑         | math/humaneval | exact-match (MATH) or pass@1 (HumanEval)
gen_ppl   | ↓         | text           | Conditional perplexity via GPT-2-Large
mauve     | ↑         | text           | Distributional similarity to C4 reference text
entropy   | ↑         | text           | Bigram entropy (lexical diversity)
rep2      | ↓         | text           | Repeated bigram ratio
avg_steps | ↓         | all            | Actual model forward passes used
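The two diversity metrics can be approximated as follows. This is a sketch of the usual definitions; the harness's exact tokenization and normalization are assumptions:

```python
import math
from collections import Counter

def bigram_stats(tokens):
    """Bigram entropy (bits) and repeated-bigram ratio (rep-2).
    Higher entropy = more lexical diversity; higher rep-2 = more
    repetition."""
    bigrams = list(zip(tokens, tokens[1:]))
    total = len(bigrams)
    if total == 0:
        return 0.0, 0.0
    counts = Counter(bigrams)
    entropy = -sum((c / total) * math.log2(c / total)
                   for c in counts.values())
    rep2 = 1.0 - len(counts) / total  # fraction of bigrams that repeat
    return entropy, rep2
```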

Assistant (Step 1, edit)
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "start_line": 59,
  "end_line": 151,
  "content": "class DemaskDecoder:\n    \"\"\"Adaptive Composite Demasking with Consistency Tracking (ACD-CT).\n\n    Novel algorithmic contributions over KLASS:\n    1. Margin-Confidence Fusion: positions ranked by conf * (p1 - p2),\n       combining prediction strength with decisiveness.\n    2. Argmax Consistency: lightweight stability via prediction constancy\n       across steps, replacing full-distribution KL divergence.\n    3. Top-k Annealing Sampler: for fully-parallel text decoding, samples\n       from top-k tokens with decaying temperature for diversity.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.history_length = history_length\n        self.sample_k = 10\n        self.sample_temp_start = 0.9\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n        is_parallel = (block_length == gen_length)\n        prev_pred = torch.full_like(x, -1)\n        cons = torch.zeros_like(x, dtype=torch.int)\n        used = 0\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens(\n                (x[:, bs:be] == mid), steps_per_block)\n            cons[:, bs:be] = 0\n            prev_pred[:, bs:be] = -1\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                blk = torch.zeros_like(mask_idx)\n                blk[:, bs:be] = True\n                mask_idx = mask_idx & blk\n                if not mask_idx.any():\n                    break\n                logits = model(x).logits\n                p = F.softmax(logits.to(torch.float64), dim=-1)\n                x0g = torch.argmax(p, dim=-1)\n                cf = torch.gather(p, -1, x0g.unsqueeze(-1)).squeeze(-1)\n                t2v, _ = torch.topk(p, 2, dim=-1)\n                mg = t2v[..., 0] - t2v[..., 1]\n                score = cf * mg\n                if is_parallel and self.sample_k > 1:\n                    prog = step / max(steps_per_block - 1, 1)\n                    ts = max(0.1, self.sample_temp_start * (1 - prog))\n                    sv, si = torch.topk(logits.float(), self.sample_k, -1)\n                    sp = F.softmax(sv / ts, dim=-1)\n                    sh = sp.shape[:-1]\n                    idx = torch.multinomial(sp.reshape(-1, self.sample_k), 1)\n                    x0 = torch.gather(\n                        si, -1, idx.reshape(*sh, 1)).squeeze(-1)\n                else:\n                    x0 = x0g\n                same = (x0g == prev_pred)\n                cons = torch.where(same, cons + 1, torch.ones_like(cons))\n                prev_pred = x0g.clone()\n                stable = cons >= self.history_length\n                ready = stable & (cf > self.conf_threshold) & mask_idx\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                for j in range(ready.shape[0]):\n                    rdy = torch.where(ready[j])[0]\n                    if len(rdy) > 0:\n                        xfer[j, rdy] = True\n                    else:\n                        s = score[j].clone()\n                        s[~mask_idx[j]] = -float(\"inf\")\n                        nm = mask_idx[j].sum().item()\n                        k = min(int(num_xfer[j, step].item()), nm)\n                        if k > 0:\n                            _, tk = torch.topk(s, k)\n                            xfer[j, tk] = True\n                x = torch.where(xfer, x0, x)\n                used += 1\n        return x, used\n"
}
Tool Result (Step 1)
OK: Replaced lines 59..151 with 90 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–148.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–148 | total: 488 lines]
    59  class DemaskDecoder:
    60      """Adaptive Composite Demasking with Consistency Tracking (ACD-CT).
    61  
       ... (84 more lines) ...
   146                  x = torch.where(xfer, x0, x)
   147                  used += 1
   148          return x, used
Assistant (Step 2, test)
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.285 avg_steps=56.8
TRAIN_METRICS: math 470/500 acc=0.287 avg_steps=56.6
TRAIN_METRICS: math 480/500 acc=0.285 avg_steps=56.8
TRAIN_METRICS: math 490/500 acc=0.286 avg_steps=57.0
TRAIN_METRICS: math 500/500 acc=0.284 avg_steps=57.4
accuracy_llada-math: 0.2840
avg_steps_llada-math: 57.4300
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.433 avg_steps=55.3
TRAIN_METRICS: humaneval 130/164 pass@1=0.423 avg_steps=56.5
TRAIN_METRICS: humaneval 140/164 pass@1=0.400 avg_steps=56.9
TRAIN_METRICS: humaneval 150/164 pass@1=0.393 avg_steps=56.3
TRAIN_METRICS: humaneval 160/164 pass@1=0.381 avg_steps=56.2
accuracy_llada-humaneval: 0.3780
avg_steps_llada-humaneval: 56.2600
n_samples_llada
Assistant (Step 3, edit)
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "start_line": 59,
  "end_line": 148,
  "content": "class DemaskDecoder:\n    \"\"\"Adaptive Composite Demasking with KL Stability (ACD-KL).\n\n    Novel algorithmic contributions over KLASS:\n    1. Margin-Confidence Fusion Score: fallback position ranking uses\n       score = conf * margin (product of top-1 prob and p1-p2 gap),\n       rewarding both prediction strength and decisiveness.  This avoids\n       committing to positions where the model is confident but\n       indecisive (small gap between top-1 and top-2 candidates).\n    2. Top-k Tempered Sampling: for fully-parallel text decoding,\n       samples from top-k (k=10) with temperature > 1 for controlled\n       diversity, improving MAUVE/entropy while bounding perplexity.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n        self.sample_k = 10\n        self.sample_temp = 1.5  # >1 for flatter distribution (diversity)\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n        is_parallel = (block_length == gen_length)\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                              dtype=torch.float64, device=x.device)\n        p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                             device=x.device)\n        used = 0\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens(\n                (x[:, bs:be] == mid), steps_per_block)\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                x0g = torch.argmax(p_curr, dim=-1)\n                conf = torch.gather(\n                    p_curr, -1, x0g.unsqueeze(-1)).squeeze(-1)\n                # Novel: margin-confidence fusion score\n                t2v, _ = torch.topk(p_curr, 2, dim=-1)\n                mg = t2v[..., 0] - t2v[..., 1]\n                score = conf * mg\n                # KL stability tracking (KLASS)\n                eps = 1e-12\n                kl = (p_curr * (torch.log(p_curr + eps)\n                               - torch.log(p_prev + eps))).sum(-1)\n                kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                kl_hist[..., -1] = kl\n                p_prev = p_curr.clone()\n                # Novel: top-k tempered sampling for text diversity\n                if is_parallel:\n                    sv, si = torch.topk(\n                        logits.float(), self.sample_k, dim=-1)\n                    sp = F.softmax(sv / self.sample_temp, dim=-1)\n                    sh = sp.shape[:-1]\n                    idx = torch.multinomial(\n                        sp.reshape(-1, self.sample_k), 1)\n                    x0 = torch.gather(\n                        si, -1, idx.reshape(*sh, 1)).squeeze(-1)\n                else:\n                    x0 = x0g\n                # KLASS stability criterion\n                if step >= self.history_length - 1:\n                    stable = torch.all(\n                        kl_hist < self.kl_threshold, dim=-1)\n                else:\n                    stable = torch.zeros_like(conf, dtype=torch.bool)\n                ready = stable & (conf > self.conf_threshold) & mask_idx\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                for j in range(ready.shape[0]):\n                    rdy = torch.where(ready[j])[0]\n                    if len(rdy) > 0:\n                        xfer[j, rdy] = True\n                    else:\n                        s = score[j].clone()\n                        s[~mask_idx[j]] = -float(\"inf\")\n                        nm = mask_idx[j].sum().item()\n                        k = min(int(num_xfer[j, step].item()), nm)\n                        if k > 0:\n                            _, tk = torch.topk(s, k)\n                            xfer[j, tk] = True\n                x = torch.where(xfer, x0, x)\n                used += 1\n        return x, used\n"
}
Tool Result (Step 3)
OK: Replaced lines 59..148 with 107 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–165.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–165 | total: 505 lines]
    59  class DemaskDecoder:
    60      """Adaptive Composite Demasking with KL Stability (ACD-KL).
    61  
       ... (101 more lines) ...
   163                  x = torch.where(xfer, x0, x)
   164                  used += 1
   165          return x, used
Assistant (Step 4, test)
Tool Result (Step 4)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.302 avg_steps=120.7
TRAIN_METRICS: math 470/500 acc=0.306 avg_steps=120.3
TRAIN_METRICS: math 480/500 acc=0.306 avg_steps=120.7
TRAIN_METRICS: math 490/500 acc=0.306 avg_steps=121.2
TRAIN_METRICS: math 500/500 acc=0.304 avg_steps=121.7
accuracy_llada-math: 0.3040
avg_steps_llada-math: 121.7300
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.458 avg_steps=87.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.446 avg_steps=89.1
TRAIN_METRICS: humaneval 140/164 pass@1=0.421 avg_steps=89.5
TRAIN_METRICS: humaneval 150/164 pass@1=0.413 avg_steps=89.1
TRAIN_METRICS: humaneval 160/164 pass@1=0.406 avg_steps=89.2
accuracy_llada-humaneval: 0.4024
avg_steps_llada-humaneval: 89.3900
n_samples_
Assistant (Step 5, edit)
Tool Input
{
  "op": "replace",
  "filename": "LLaDA/custom_demask_eval.py",
  "start_line": 59,
  "end_line": 165,
  "content": "class DemaskDecoder:\n    \"\"\"Adaptive Composite Demasking with Mode-Adaptive Stability (ACD-MAS).\n\n    Novel algorithmic contributions:\n    1. Margin-Confidence Fusion Score: fallback position ranking uses\n       score = conf * (p1-p2), rewarding both strength and decisiveness.\n       Avoids committing to positions where top-1 and top-2 are close.\n    2. Mode-Adaptive Stability: semi-AR accuracy tasks use KL-divergence\n       stability (conservative, proven by KLASS). Parallel text tasks\n       use lightweight argmax consistency (fewer steps, similar quality).\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n        is_par = (block_length == gen_length)\n        if is_par:\n            prev_pred = torch.full_like(x, -1)\n            cons = torch.zeros_like(x, dtype=torch.int)\n        else:\n            V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                           else model.config.vocab_size\n            kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                                  dtype=torch.float64, device=x.device)\n            p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                                 device=x.device)\n        used = 0\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens(\n                (x[:, bs:be] == mid), steps_per_block)\n            if is_par:\n                cons[:, bs:be] = 0\n                prev_pred[:, bs:be] = -1\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                blk = torch.zeros_like(mask_idx)\n                blk[:, bs:be] = True\n                mask_idx = mask_idx & blk\n                if not mask_idx.any():\n                    break\n                logits = model(x).logits\n                p = F.softmax(logits.to(torch.float64), dim=-1)\n                x0 = torch.argmax(p, dim=-1)\n                cf = torch.gather(p, -1, x0.unsqueeze(-1)).squeeze(-1)\n                t2v, _ = torch.topk(p, 2, dim=-1)\n                mg = t2v[..., 0] - t2v[..., 1]\n                score = cf * mg\n                if is_par:\n                    same = (x0 == prev_pred)\n                    cons = torch.where(\n                        same, cons + 1, torch.ones_like(cons))\n                    prev_pred = x0.clone()\n                    stable = cons >= self.history_length\n                else:\n                    eps = 1e-12\n                    kl = (p * (torch.log(p + eps)\n                              - torch.log(p_prev + eps))).sum(-1)\n                    kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                    kl_hist[..., -1] = kl\n                    p_prev = p.clone()\n                    if step >= self.history_length - 1:\n                        stable = torch.all(\n                            kl_hist < self.kl_threshold, dim=-1)\n                    else:\n                        stable = torch.zeros_like(\n                            cf, dtype=torch.bool)\n                ready = stable & (cf > self.conf_threshold) & mask_idx\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                for j in range(ready.shape[0]):\n                    rdy = torch.where(ready[j])[0]\n                    if len(rdy) > 0:\n                        xfer[j, rdy] = True\n                    else:\n                        s = score[j].clone()\n                        s[~mask_idx[j]] = -float(\"inf\")\n                        nm = mask_idx[j].sum().item()\n                        k = min(int(num_xfer[j, step].item()), nm)\n                        if k > 0:\n                            _, tk = torch.topk(s, k)\n                            xfer[j, tk] = True\n                x = torch.where(xfer, x0, x)\n                used += 1\n        return x, used\n"
}
Tool Result (Step 5)
OK: Replaced lines 59..165 with 102 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–160.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–160 | total: 500 lines]
    59  class DemaskDecoder:
    60      """Adaptive Composite Demasking with Mode-Adaptive Stability (ACD-MAS).
    61  
       ... (96 more lines) ...
   158                  x = torch.where(xfer, x0, x)
   159                  used += 1
   160          return x, used
Assistant (Step 6, test)
Tool Result (Step 6)
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)

[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.302 avg_steps=120.7
TRAIN_METRICS: math 470/500 acc=0.306 avg_steps=120.3
TRAIN_METRICS: math 480/500 acc=0.306 avg_steps=120.7
TRAIN_METRICS: math 490/500 acc=0.306 avg_steps=121.2
TRAIN_METRICS: math 500/500 acc=0.304 avg_steps=121.7
accuracy_llada-math: 0.3040
avg_steps_llada-math: 121.7300
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.458 avg_steps=87.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.446 avg_steps=89.1
TRAIN_METRICS: humaneval 140/164 pass@1=0.421 avg_steps=89.5
TRAIN_METRICS: humaneval 150/164 pass@1=0.413 avg_steps=89.1
TRAIN_METRICS: 
Assistant (Step 7, submit)
Tool Input
{
  "n": 3
}
Tool Result (Step 7)
[submit] Submitting result from test #3 as final.

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.302 avg_steps=120.7
TRAIN_METRICS: math 470/500 acc=0.306 avg_steps=120.3
TRAIN_METRICS: math 480/500 acc=0.306 avg_steps=120.7
TRAIN_METRICS: math 490/500 acc=0.306 avg_steps=121.2
TRAIN_METRICS: math 500/500 acc=0.304 avg_steps=121.7
accuracy_llada-math: 0.3040
avg_steps_llada-math: 121.7300
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.458 avg_steps=87.2
TRAIN_METRICS: humaneval 130/164 pass@1=0.446 avg_steps=89.1
TRAIN_METRICS: humaneval 140/164 pass@1=0.421 avg_steps=89.5
TRAIN_METRICS: humaneval 150/164 pass@1=0.413 avg_steps=89.1
TRAIN_METRICS: humaneval 160/164 pass@1=0.406 avg_steps=89.2
accuracy_llada-humaneval: 0.4024
avg_steps_llada-humaneval: 89.3900
n_samples_llada-humaneval: 164.0000

### dream-text (s