Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 9
Messages: 20
User (Step 0)

Task: llm-dllm-demask-strategy

Masked Diffusion LM: Demasking Strategy

Research Question

Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:

  • Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
  • Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)

Background

Masked diffusion LMs (LLaDA, Dream) generate by starting from a fully masked generation region and iteratively unmasking it over `steps` denoising iterations. A demasking strategy decides, at each step:

  1. Schedule: how many tokens to unmask
  2. Position selection: which masked positions to unmask
  3. Token assignment: what token id to place

Decoding can be semi-autoregressive (when block_length < gen_length, process one block at a time) or fully parallel (block_length == gen_length, all positions decoded together).
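As an illustration of the three decisions above (schedule, position selection, token assignment), a minimal greedy confidence-based demasking step can be sketched as follows. This is a toy sketch, not the repository's actual decoder; `probs`, `preds`, and `demask_step` are hypothetical names, with `probs[i]` standing in for the model's top-1 probability at position `i` and `preds[i]` for its argmax token id.

```python
def demask_step(tokens, mask_id, probs, preds, k):
    """Unmask the k masked positions with the highest model confidence.

    tokens : list of token ids, with mask_id at still-masked positions
    probs  : per-position top-1 probability (confidence) from the model
    preds  : per-position argmax token id from the model
    k      : schedule decision -- how many tokens to unmask this step
    """
    masked = [i for i, t in enumerate(tokens) if t == mask_id]
    # Position selection: rank masked positions by confidence, keep top k.
    chosen = sorted(masked, key=lambda i: probs[i], reverse=True)[:k]
    out = list(tokens)
    for i in chosen:
        out[i] = preds[i]  # Token assignment: commit the argmax prediction.
    return out
```

In a block-based semi-autoregressive regime the same step would simply be restricted to positions inside the current block.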

What You Can Modify

Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py (lines 59-151).

Interface

class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)

get_num_transfer_tokens(mask, steps) is available outside the editable region — returns the uniform schedule (mask.sum() // steps per step).
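The helper's exact remainder handling is not shown in the brief; a plausible reconstruction of the uniform schedule (assumed behavior, illustrative name `get_num_transfer_tokens_sketch`) distributes `mask.sum() // steps` per step and spreads the remainder over the first steps:

```python
def get_num_transfer_tokens_sketch(num_masked, steps):
    """Uniform unmask schedule over `steps` iterations.

    Each step unmasks num_masked // steps tokens; the remainder
    (num_masked % steps) is spread one extra token at a time over
    the earliest steps, so the per-step counts sum to num_masked.
    """
    base, rem = divmod(num_masked, steps)
    return [base + (1 if s < rem else 0) for s in range(steps)]
```

The adaptive strategies explored later deviate from this baseline by unmasking more than the scheduled count when predictions are stable and confident.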

Constraints

  • gen_length % block_length == 0. When equal, decoding is fully parallel.
  • Process blocks sequentially (no early-decoding into later blocks).
  • Always return [1, prompt_len + gen_length].
  • used_steps counts model forward passes (lower = more efficient).

Evaluation

Benchmarks

| Label | Task | Model | gen_len | steps | block_len | Metrics |
|---|---|---|---|---|---|---|
| llada-math | MATH-500 | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| llada-humaneval | HumanEval (164) | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| dream-text | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224 | 256 | 224 | gen_ppl + MAUVE + entropy + rep2 + avg_steps |

Metrics

| Metric | Direction | Where | Description |
|---|---|---|---|
| accuracy | higher better | math/humaneval | exact-match (MATH) or pass@1 (HumanEval) |
| gen_ppl | lower better | text | Conditional perplexity via GPT-2-Large |
| mauve | higher better | text | Distributional similarity to C4 reference text |
| entropy | higher better | text | Bigram entropy (lexical diversity) |
| rep2 | lower better | text | Repeated bigram ratio |
| avg_steps | lower better | all | Actual model forward passes used |
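The two diversity metrics can be computed from a token sequence directly. The sketch below uses one common definition of each (the evaluation harness's exact formulas may differ): Shannon entropy of the empirical bigram distribution, and the fraction of bigram occurrences that are repeats of an already-seen bigram; `bigram_stats` is an illustrative name.

```python
import math
from collections import Counter

def bigram_stats(tokens):
    """Return (bigram entropy in bits, repeated-bigram ratio).

    entropy: Shannon entropy of the empirical bigram distribution.
    rep2:    1 - (#unique bigrams / #total bigrams), so 0.0 means
             every bigram is distinct and values near 1.0 mean the
             text loops over a few bigrams.
    """
    bigrams = list(zip(tokens, tokens[1:]))
    counts = Counter(bigrams)
    total = len(bigrams)
    entropy = -sum((c / total) * math.log2(c / total) for c in counts.values())
    rep2 = 1.0 - len(counts) / total
    return entropy, rep2
```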

Assistant (Step 1, edit)
Tool Input
1{
2  "end_line": 151,
3  "op": "replace",
4  "start_line": 59,
5  "content": "class DemaskDecoder:\n    \"\"\"ASGMS: Adaptive Schedule with Gumbel-Margin and Stability\"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                              dtype=torch.float64, device=x.device)\n        p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                             device=x.device)\n        used = 0\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n        
        \n                eps = 1e-12\n                temp = self.temperature\n                is_parallel = (block_length == gen_length)\n                \n                # Apply slight temperature sampling for fully-parallel text decoding\n                # to improve MAUVE and lexical diversity, otherwise use provided temp\n                if is_parallel and temp == 0.0:\n                    temp_local = 0.1\n                    gumbel = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel * temp_local, dim=-1)\n                elif temp > 0:\n                    gumbel = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel * temp, dim=-1)\n                else:\n                    x0 = torch.argmax(p_curr, dim=-1)\n                    \n                sorted_p, _ = torch.sort(p_curr, dim=-1, descending=True)\n                conf = sorted_p[..., 0]\n                margin = sorted_p[..., 0] - sorted_p[..., 1]\n\n                kl = (p_curr * (torch.log(p_curr + eps) - torch.log(p_prev + eps))).sum(-1)\n                kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                kl_hist[..., -1] = kl\n                p_prev = p_curr.clone()\n\n                if step >= self.history_length - 1:\n                    stable = torch.all(kl_hist < self.kl_threshold, dim=-1)\n                else:\n                    stable = torch.zeros_like(conf, dtype=torch.bool)\n                \n                ready = stable & (conf > self.conf_threshold) & mask_idx\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                rem_steps = steps_per_block - step\n                \n                for j in range(ready.shape[0]):\n                    num_mask_j = mask_idx[j].sum().item()\n                    target_xfer = (num_mask_j + rem_steps - 1) // rem_steps\n                    \n               
     rdy = torch.where(ready[j])[0]\n                    num_ready = len(rdy)\n                    \n                    if num_ready > 0:\n                        xfer[j, rdy] = True\n                        \n                    if num_ready < target_xfer:\n                        c = margin[j].clone()\n                        if is_parallel:\n                            gumbel_c = -torch.log(-torch.log(torch.rand_like(c) + eps) + eps)\n                            c = c + 0.15 * gumbel_c\n                            \n                        c[~mask_idx[j]] = -float(\"inf\")\n                        if num_ready > 0:\n                            c[rdy] = -float(\"inf\")\n                        \n                        rem_to_unmask = target_xfer - num_ready\n                        rem_to_unmask = min(rem_to_unmask, int((c > -float(\"inf\")).sum().item()))\n                        if rem_to_unmask > 0:\n                            _, topk = torch.topk(c, rem_to_unmask)\n                            xfer[j, topk] = True\n\n                if step == steps_per_block - 1:\n                    xfer = xfer | mask_idx\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
6  "filename": "LLaDA/custom_demask_eval.py"
7}
Tool Result (Step 1)
OK: Replaced lines 59..151 with 115 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–173.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–173 | total: 513 lines]
    59  class DemaskDecoder:
    60      """ASGMS: Adaptive Schedule with Gumbel-Margin and Stability"""
    61  
       ... (109 more lines) ...
   171                  used += 1
   172  
   173          return x, used
Assistant (Step 2, test)
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.317 avg_steps=117.6
TRAIN_METRICS: math 470/500 acc=0.321 avg_steps=117.1
TRAIN_METRICS: math 480/500 acc=0.319 avg_steps=117.4
TRAIN_METRICS: math 490/500 acc=0.318 avg_steps=117.8
TRAIN_METRICS: math 500/500 acc=0.318 avg_steps=118.3
accuracy_llada-math: 0.3180
avg_steps_llada-math: 118.3100
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.458 avg_steps=89.0
TRAIN_METRICS: humaneval 130/164 pass@1=0.446 avg_steps=90.0
TRAIN_METRICS: humaneval 140/164 pass@1=0.421 avg_steps=91.8
TRAIN_METRICS: humaneval 150/164 pass@1=0.413 avg_steps=91.3
TRAIN_METRICS: humaneval 160/164 pass@1=0.406 avg_steps=91.3
accuracy_llada-humaneval: 0.4024
avg_steps_llada-humaneval: 91.4500
n_samples
Assistant (Step 3, edit)
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Adaptive Pathway Stability Decoding with Modality-Aware Fallback\"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                              dtype=torch.float64, device=x.device)\n        p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                             device=x.device)\n        used = 0\n        \n        is_parallel = (block_length == gen_length)\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n         
       p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                \n                eps = 1e-12\n                temp = self.temperature\n                \n                if temp > 0:\n                    gumbel_t = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel_t * temp, dim=-1)\n                else:\n                    x0 = torch.argmax(p_curr, dim=-1)\n                    \n                sorted_p, _ = torch.sort(p_curr, dim=-1, descending=True)\n                conf = sorted_p[..., 0]\n                margin = sorted_p[..., 0] - sorted_p[..., 1]\n\n                kl = (p_curr * (torch.log(p_curr + eps) - torch.log(p_prev + eps))).sum(-1)\n                kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                kl_hist[..., -1] = kl\n                p_prev = p_curr.clone()\n\n                if step >= self.history_length - 1:\n                    stable = torch.all(kl_hist < self.kl_threshold, dim=-1)\n                else:\n                    stable = torch.zeros_like(conf, dtype=torch.bool)\n                \n                # Confidence threshold decay over the block to avoid backlogs\n                progress = step / max(1, steps_per_block - 1)\n                curr_conf_threshold = self.conf_threshold * (1.0 - 0.15 * progress)\n                \n                ready = stable & (conf > curr_conf_threshold) & mask_idx\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                rem_steps = steps_per_block - step\n                \n                for j in range(ready.shape[0]):\n                    num_mask_j = mask_idx[j].sum().item()\n                    target_xfer = (num_mask_j + rem_steps - 1) // rem_steps\n                    \n                    rdy = torch.where(ready[j])[0]\n                    num_ready = len(rdy)\n                    \n                    if num_ready > 0:\n                        max_unmask = 
max(target_xfer * 3, 4)\n                        if num_ready > max_unmask and is_parallel:\n                            c_rdy = conf[j].clone()\n                            c_rdy[~ready[j]] = -float(\"inf\")\n                            _, topk_rdy = torch.topk(c_rdy, max_unmask)\n                            xfer[j, topk_rdy] = True\n                        else:\n                            xfer[j, rdy] = True\n                    else:\n                        if is_parallel:\n                            c = margin[j].clone()\n                            gumbel_m = -torch.log(-torch.log(torch.rand_like(c) + eps) + eps)\n                            c = c + 0.1 * gumbel_m\n                        else:\n                            c = conf[j].clone()\n                            \n                        c[~mask_idx[j]] = -float(\"inf\")\n                        rem_to_unmask = min(target_xfer, int((c > -float(\"inf\")).sum().item()))\n                        if rem_to_unmask > 0:\n                            _, topk = torch.topk(c, rem_to_unmask)\n                            xfer[j, topk] = True\n\n                if step == steps_per_block - 1:\n                    xfer = xfer | mask_idx\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
3  "op": "replace",
4  "end_line": 173,
5  "start_line": 59,
6  "filename": "LLaDA/custom_demask_eval.py"
7}
Tool Result (Step 3)
OK: Replaced lines 59..173 with 118 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–176.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–176 | total: 516 lines]
    59  class DemaskDecoder:
    60      """Adaptive Pathway Stability Decoding with Modality-Aware Fallback"""
    61  
       ... (112 more lines) ...
   174                  used += 1
   175  
   176          return x, used
Assistant (Step 4, test)
Tool Result (Step 4)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.333 avg_steps=125.7
TRAIN_METRICS: math 470/500 acc=0.336 avg_steps=125.4
TRAIN_METRICS: math 480/500 acc=0.331 avg_steps=125.8
TRAIN_METRICS: math 490/500 acc=0.331 avg_steps=125.9
TRAIN_METRICS: math 500/500 acc=0.330 avg_steps=126.4
accuracy_llada-math: 0.3300
avg_steps_llada-math: 126.4100
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.417 avg_steps=90.6
TRAIN_METRICS: humaneval 130/164 pass@1=0.400 avg_steps=92.6
TRAIN_METRICS: humaneval 140/164 pass@1=0.371 avg_steps=92.9
TRAIN_METRICS: humaneval 150/164 pass@1=0.360 avg_steps=92.4
TRAIN_METRICS: humaneval 160/164 pass@1=0.356 avg_steps=92.1
accuracy_llada-humaneval: 0.3537
avg_steps_llada-humaneval: 92.1000
n_samples_
Assistant (Step 5, edit)
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Progress-Ensured Adaptive Pathway Stability Decoding\n\n    Algorithm features:\n      1. Unmasks all stable & confident tokens dynamically.\n      2. Ensures a minimum `target_xfer` unmask rate per step to prevent backlogs.\n      3. Uses a robust `conf + margin` metric for forced unmasking.\n      4. Injects path stochasticity via Gumbel noise in fully-parallel settings\n         to increase lexical diversity (MAUVE, entropy).\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                              dtype=torch.float64, device=x.device)\n        p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                             device=x.device)\n        used = 0\n        is_parallel = (block_length == gen_length)\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + 
block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                \n                eps = 1e-12\n                temp = self.temperature\n                if temp > 0:\n                    gumbel_t = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel_t * temp, dim=-1)\n                else:\n                    x0 = torch.argmax(p_curr, dim=-1)\n                    \n                sorted_p, _ = torch.sort(p_curr, dim=-1, descending=True)\n                conf = sorted_p[..., 0]\n                margin = sorted_p[..., 0] - sorted_p[..., 1]\n\n                kl = (p_curr * (torch.log(p_curr + eps) - torch.log(p_prev + eps))).sum(-1)\n                kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                kl_hist[..., -1] = kl\n                p_prev = p_curr.clone()\n\n                if step >= self.history_length - 1:\n                    stable = torch.all(kl_hist < self.kl_threshold, dim=-1)\n                else:\n                    stable = torch.zeros_like(conf, dtype=torch.bool)\n                \n                ready = stable & (conf > self.conf_threshold) & mask_idx\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                rem_steps = steps_per_block - step\n                \n                for j in range(ready.shape[0]):\n                    num_mask_j = mask_idx[j].sum().item()\n                    target_xfer = (num_mask_j + rem_steps - 1) // rem_steps\n                    \n                    rdy = torch.where(ready[j])[0]\n                    num_ready = len(rdy)\n     
               \n                    if num_ready > 0:\n                        xfer[j, rdy] = True\n                        \n                    if num_ready < target_xfer:\n                        # Robust Modality-Agnostic Score: combines confidence scale with marginal clarity\n                        c = conf[j].clone() + margin[j].clone()\n                        \n                        if is_parallel:\n                            gumbel_c = -torch.log(-torch.log(torch.rand_like(c) + eps) + eps)\n                            c = c + 0.1 * gumbel_c\n                            \n                        c[~mask_idx[j]] = -float(\"inf\")\n                        if num_ready > 0:\n                            c[rdy] = -float(\"inf\")\n                            \n                        rem_to_unmask = target_xfer - num_ready\n                        rem_to_unmask = min(rem_to_unmask, int((c > -float(\"inf\")).sum().item()))\n                        if rem_to_unmask > 0:\n                            _, topk = torch.topk(c, rem_to_unmask)\n                            xfer[j, topk] = True\n\n                if step == steps_per_block - 1:\n                    xfer = xfer | mask_idx\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
3  "start_line": 59,
4  "filename": "LLaDA/custom_demask_eval.py",
5  "end_line": 176,
6  "op": "replace"
7}
Tool Result (Step 5)
OK: Replaced lines 59..176 with 118 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–176.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–176 | total: 516 lines]
    59  class DemaskDecoder:
    60      """Progress-Ensured Adaptive Pathway Stability Decoding
    61  
       ... (112 more lines) ...
   174                  used += 1
   175  
   176          return x, used
Assistant (Step 6, test)
Tool Result (Step 6)
[Test #3] (0 tests remaining; call submit(n=N) to choose which test result to submit as final)

[NOTE] This was your last test. You MUST now call submit(n=X) to choose which test result to submit as your final answer.

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.298 avg_steps=121.4
TRAIN_METRICS: math 470/500 acc=0.302 avg_steps=121.0
TRAIN_METRICS: math 480/500 acc=0.300 avg_steps=121.3
TRAIN_METRICS: math 490/500 acc=0.300 avg_steps=121.7
TRAIN_METRICS: math 500/500 acc=0.296 avg_steps=122.4
accuracy_llada-math: 0.2960
avg_steps_llada-math: 122.3600
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.450 avg_steps=89.6
TRAIN_METRICS: humaneval 130/164 pass@1=0.438 avg_steps=91.6
TRAIN_METRICS: humaneval 140/164 pass@1=0.414 avg_steps=91.9
TRAIN_METRICS: humaneval 150/164 pass@1=0.400 avg_steps=91.1
TRAIN_METRICS: 
Assistant (Step 7, edit)
Tool Input
1{
2  "end_line": 176,
3  "start_line": 59,
4  "content": "class DemaskDecoder:\n    \"\"\"Modality-Aware Path Stochasticity and Stability Decoding\n\n    Algorithm features:\n      - Uses strict KLASS Early-Escaping SOTA logic for structured Semi-AR modalities (Math/Code)\n        to minimize used steps and maximize rigorous logical stability.\n      - Uses Stochastic Uniform Topk-Margin logic + Temperature Injection for fully parallel modalities\n        (open-ended Text generation) to promote maximal lexical diversity and smoothness.\n      - Safely guarantees terminal sequence completion, preventing corrupt mask fragments.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        kl_hist = torch.zeros((1, x.shape[1], self.history_length),\n                              dtype=torch.float64, device=x.device)\n        p_prev = torch.zeros((1, x.shape[1], V), dtype=torch.float64,\n                             device=x.device)\n        used = 0\n        is_parallel = (block_length == gen_length)\n\n       
 for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n            \n            num_xfer = get_num_transfer_tokens((x[:, bs:be] == mid), steps_per_block)\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                \n                eps = 1e-12\n                temp = self.temperature\n                \n                if temp > 0:\n                    gumbel_t = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel_t * temp, dim=-1)\n                elif is_parallel:\n                    temp_local = 0.1\n                    gumbel_t = -torch.log(-torch.log(torch.rand_like(p_curr) + eps) + eps)\n                    x0 = torch.argmax(torch.log(p_curr + eps) + gumbel_t * temp_local, dim=-1)\n                else:\n                    x0 = torch.argmax(p_curr, dim=-1)\n                    \n                sorted_p, _ = torch.sort(p_curr, dim=-1, descending=True)\n                conf = sorted_p[..., 0]\n                margin = sorted_p[..., 0] - sorted_p[..., 1]\n\n                kl = (p_curr * (torch.log(p_curr + eps) - torch.log(p_prev + eps))).sum(-1)\n                kl_hist = torch.roll(kl_hist, -1, dims=-1)\n                kl_hist[..., -1] = kl\n                p_prev = p_curr.clone()\n\n                if step >= self.history_length - 1:\n                    stable = torch.all(kl_hist < self.kl_threshold, dim=-1)\n                else:\n                    stable = torch.zeros_like(conf, dtype=torch.bool)\n                \n                ready = 
stable & (conf > self.conf_threshold) & mask_idx\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n                \n                for j in range(ready.shape[0]):\n                    rdy = torch.where(ready[j])[0]\n                    num_ready = len(rdy)\n                    \n                    if num_ready > 0:\n                        xfer[j, rdy] = True\n                        if is_parallel:\n                            c = margin[j].clone()\n                            gumbel_c = -torch.log(-torch.log(torch.rand_like(c) + eps) + eps)\n                            c = c + 0.15 * gumbel_c\n                            c[~mask_idx[j]] = -float(\"inf\")\n                            c[rdy] = -float(\"inf\")\n                            num_xfer_step = int(num_xfer[j, step].item())\n                            rem_to_unmask = num_xfer_step - num_ready\n                            if rem_to_unmask > 0:\n                                rem_to_unmask = min(rem_to_unmask, int((c > -float(\"inf\")).sum().item()))\n                                if rem_to_unmask > 0:\n                                    _, topk = torch.topk(c, rem_to_unmask)\n                                    xfer[j, topk] = True\n                    else:\n                        if is_parallel:\n                            c = margin[j].clone()\n                            gumbel_c = -torch.log(-torch.log(torch.rand_like(c) + eps) + eps)\n                            c = c + 0.15 * gumbel_c\n                        else:\n                            c = conf[j].clone()\n                            \n                        c[~mask_idx[j]] = -float(\"inf\")\n                        rem_to_unmask = int(num_xfer[j, step].item())\n                        rem_to_unmask = min(rem_to_unmask, int((c > -float(\"inf\")).sum().item()))\n                        if rem_to_unmask > 0:\n                            _, topk = torch.topk(c, rem_to_unmask)\n                            xfer[j, topk] 
= True\n\n                if step == steps_per_block - 1:\n                    xfer = xfer | mask_idx\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
5  "op": "replace",
6  "filename": "LLaDA/custom_demask_eval.py"
7}
Tool Result (Step 7)
OK: Replaced lines 59..176 with 130 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–188.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–188 | total: 528 lines]
    59  class DemaskDecoder:
    60      """Modality-Aware Path Stochasticity and Stability Decoding
    61  
       ... (124 more lines) ...
   186                  used += 1
   187  
   188          return x, used
Assistant (Step 8, test)
Tool Result (Step 8)
ERROR: Test budget exhausted (3/3). You MUST call submit(n=N) to choose which test result to submit as final.
Assistant (Step 9, submit)
Tool Input
1{
2  "n": 1
3}
Tool Result (Step 9)
[submit] Submitting result from test #1 as final.

### llada-math (scripts/llada_math.sh)
Generation progress (llada-math):
TRAIN_METRICS: math 460/500 acc=0.317 avg_steps=117.6
TRAIN_METRICS: math 470/500 acc=0.321 avg_steps=117.1
TRAIN_METRICS: math 480/500 acc=0.319 avg_steps=117.4
TRAIN_METRICS: math 490/500 acc=0.318 avg_steps=117.8
TRAIN_METRICS: math 500/500 acc=0.318 avg_steps=118.3
accuracy_llada-math: 0.3180
avg_steps_llada-math: 118.3100
n_samples_llada-math: 500.0000

### llada-humaneval (scripts/llada_humaneval.sh)
Generation progress (llada-humaneval):
TRAIN_METRICS: humaneval 120/164 pass@1=0.458 avg_steps=89.0
TRAIN_METRICS: humaneval 130/164 pass@1=0.446 avg_steps=90.0
TRAIN_METRICS: humaneval 140/164 pass@1=0.421 avg_steps=91.8
TRAIN_METRICS: humaneval 150/164 pass@1=0.413 avg_steps=91.3
TRAIN_METRICS: humaneval 160/164 pass@1=0.406 avg_steps=91.3
accuracy_llada-humaneval: 0.4024
avg_steps_llada-humaneval: 91.4500
n_samples_llada-humaneval: 164.0000

### dream-text (s