Agent Conversation: qwen/qwen3.6-plus

Model: qwen/qwen3.6-plus
Total steps: 17
Messages: 36
User · Step 0

Task: llm-dllm-demask-strategy

Masked Diffusion LM: Demasking Strategy

Research Question

Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:

  • Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
  • Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)

Background

Masked diffusion LMs (LLaDA, Dream) generate by starting from a fully masked generation region and iteratively unmasking it over `steps` denoising iterations. A demasking strategy decides, at each step:

  1. Schedule: how many tokens to unmask
  2. Position selection: which masked positions to unmask
  3. Token assignment: what token id to place

Decoding can be semi-autoregressive (when block_length < gen_length, process one block at a time) or fully parallel (block_length == gen_length, all positions decoded together).
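The schedule / position-selection / token-assignment loop above can be sketched as a toy example (illustrative only — `toy_predict`, `MASK`, and the confidence scores are stand-ins, not the repo's code):

```python
import random

MASK = -1  # illustrative mask id


def toy_predict(seq):
    """Stand-in for the diffusion LM: per position, a (token, confidence) pair."""
    rng = random.Random(len(seq))
    return [(97 + (i % 3), rng.random()) for i in range(len(seq))]


def demask(prompt, gen_length, steps):
    seq = list(prompt) + [MASK] * gen_length
    assert gen_length % steps == 0
    k = gen_length // steps                  # 1. schedule: uniform k per step
    for _ in range(steps):
        preds = toy_predict(seq)
        masked = [i for i, t in enumerate(seq) if t == MASK]
        # 2. position selection: top-k most confident masked positions
        chosen = sorted(masked, key=lambda i: preds[i][1], reverse=True)[:k]
        for i in chosen:
            seq[i] = preds[i][0]             # 3. token assignment: commit argmax
    return seq
```

With `block_length == gen_length` this single loop is the fully-parallel regime; the semi-AR regime simply runs the same loop once per block, restricted to that block's positions.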

What You Can Modify

Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py (lines 59-151).

Interface

class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)

get_num_transfer_tokens(mask, steps) is available outside the editable region — returns the uniform schedule (mask.sum() // steps per step).
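A minimal sketch of such a uniform schedule (the remainder handling — front-loading leftover tokens onto the earliest steps — is an assumption modeled on common reference implementations, not a guarantee about this repo's helper):

```python
def uniform_schedule(num_masked, steps):
    """Uniform transfer schedule: roughly num_masked // steps tokens per step,
    with the remainder spread over the earliest steps so the per-step counts
    sum exactly to num_masked."""
    base = num_masked // steps
    rem = num_masked % steps
    return [base + (1 if s < rem else 0) for s in range(steps)]
```

For the llada-math setting (64 masked tokens per block, 64 steps per block) this yields one token per step.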

Constraints

  • gen_length % block_length == 0. When equal, decoding is fully parallel.
  • Process blocks sequentially (no early-decoding into later blocks).
  • Always return [1, prompt_len + gen_length].
  • used_steps counts model forward passes (lower = more efficient).

Evaluation

Benchmarks

Label           | Task                                                                     | Model                | gen_len | steps | block_len | Metrics
llada-math      | MATH-500                                                                 | LLaDA-8B-Instruct    | 256     | 256   | 64        | accuracy + avg_steps
llada-humaneval | HumanEval (164)                                                          | LLaDA-8B-Instruct    | 256     | 256   | 64        | accuracy + avg_steps
dream-text      | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224     | 256   | 224       | gen_ppl + MAUVE + entropy + rep2 + avg_steps

Metrics

Metric    | Direction | Where          | Description
accuracy  | ↑         | math/humaneval | exact-match (MATH) or pass@1 (HumanEval)
gen_ppl   | ↓         | text           | Conditional perplexity via GPT-2-Large
mauve     | ↑         | text           | Distributional similarity to C4 reference text
entropy   | ↑         | text           | Bigram entropy (lexical diversity)
rep2      | ↓         | text           | Repeated bigram ratio
avg_steps | ↓         | all            | Actual model forward passes used
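For the diversity metrics, bigram entropy and rep2 can be sketched as follows (rep2 is taken here as 1 − distinct bigrams / total bigrams; the evaluation harness's exact definition may differ):

```python
import math
from collections import Counter


def bigram_stats(tokens):
    """Return (entropy, rep2) over the token sequence's bigrams.

    entropy: Shannon entropy (nats) of the empirical bigram distribution.
    rep2:    fraction of bigram occurrences that duplicate another bigram
             (1 - distinct / total); higher means more repetition.
    """
    bigrams = list(zip(tokens, tokens[1:]))
    n = len(bigrams)
    counts = Counter(bigrams)
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    rep2 = 1.0 - len(counts) / n
    return entropy, rep2
```

A maximally repetitive loop like `[1, 2, 1, 2, 1]` has only two distinct bigrams, giving low entropy and high rep2.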

Assistant · Step 1 · edit
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Entropy-Stability Adaptive decoder for masked diffusion LMs.\n\n    Novel contributions:\n    1. Entropy-Stability Score: Uses predictive entropy as primary signal\n       for position selection. Low entropy = model is unambiguous about\n       what token belongs there, so it's safe to commit.\n    2. EMA-based stability: Tracks exponential moving average of predictions\n       and measures KL(p_curr || ema) as stability signal. Avoids the\n       zero-initialization bug in KLASS's p_prev tracking.\n    3. Progressive threshold adaptation: The effective confidence threshold\n       decays over steps, starting conservative (ensuring correctness) and\n       relaxing as context builds. Also adapts based on how many tokens\n       remain vs steps left.\n    4. Fallback scoring: When not enough positions meet thresholds, uses a\n       hybrid score combining confidence, margin, and stability rather than\n       pure confidence.\n\n    This strategy generalizes across regimes:\n    - Semi-AR (block_length < gen_length): careful unmasking for accuracy\n    - Fully-parallel (block_length == gen_length): entropy-guided for\n      both quality and diversity in open-ended text generation.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        # Adaptive base thresholds\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n        # Hardcoded novel hyperparameters\n        self._entropy_threshold = 0.5        # max entropy to consider \"unambiguous\"\n        self._ema_decay = 0.7                # EMA decay factor for stability tracking\n        self._margin_weight = 0.3            # weight for margin in fallback score\n        
self._entropy_weight = 0.4           # weight for entropy in fallback score\n        self._max_entropy = None             # lazily computed from vocab size\n\n    def _adapt_conf_threshold(self, step: int, total_steps: int,\n                              frac_masked: float) -> float:\n        \"\"\"Adapt confidence threshold: conservative early, relaxed later.\n\n        When many tokens remain (frac_masked high), be conservative.\n        As we progress, relax to ensure all tokens get unmasked.\n        \"\"\"\n        # Step-based decay: start at base, decay toward 0.5 by end\n        step_factor = 1.0 - 0.5 * (step / max(total_steps - 1, 1))\n        # Mask-aware: if many tokens remain relative to steps, stay conservative\n        step_budget = max(total_steps - step, 1)\n        urgency = min(frac_masked * total_steps / step_budget, 1.0)\n        urgency_factor = 1.0 - 0.3 * urgency  # reduce threshold when need to hurry\n        return self.conf_threshold * step_factor * urgency_factor\n\n    def _adapt_ema_decay(self, step: int, total_steps: int) -> float:\n        \"\"\"Adapt EMA decay: faster tracking early, more stable later.\"\"\"\n        if total_steps <= 1:\n            return self._ema_decay\n        # Early steps: lower decay (faster adaptation)\n        # Later steps: higher decay (smooth, stable tracking)\n        progress = step / max(total_steps - 1, 1)\n        return min(0.3 + 0.7 * progress, 0.9)\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = 
model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        # Set max entropy for normalization\n        if self._max_entropy is None:\n            self._max_entropy = torch.log(torch.tensor(V, dtype=torch.float64))\n\n        # EMA prediction distribution for stability tracking\n        # Initialize with uniform to avoid zero-initialization KL bug\n        eps_unif = 1.0 / V\n        p_ema = torch.full((1, x.shape[1], V), eps_unif,\n                           dtype=torch.float64, device=x.device)\n\n        # Per-position confidence history for adaptive thresholds\n        conf_history = torch.zeros((1, x.shape[1], 2),\n                                   dtype=torch.float64, device=x.device)\n\n        used = 0\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                x0 = torch.argmax(p_curr, dim=-1)\n                conf = torch.gather(p_curr, -1, x0.unsqueeze(-1)).squeeze(-1)\n\n                # Compute predictive entropy for each position\n                ent = -(p_curr * torch.log(p_curr + 1e-12)).sum(-1)\n                # Normalize entropy to [0, 1]\n                norm_ent = ent / self._max_entropy\n\n                # Compute margin (top1 - top2 probability)\n                sorted_probs, _ = torch.sort(p_curr, dim=-1, descending=True)\n                margin = sorted_probs[..., 0] - sorted_probs[..., 1]\n\n                # Update EMA with adaptive decay\n                decay = 
self._adapt_ema_decay(step, steps_per_block)\n                p_ema = decay * p_ema + (1 - decay) * p_curr\n\n                # EMA-stability: KL(p_curr || p_ema)\n                # Small KL means prediction is stable\n                kl_stab = (p_curr * (torch.log(p_curr + 1e-12)\n                                     - torch.log(p_ema + 1e-12))).sum(-1)\n\n                # Compute adaptive thresholds\n                num_masked_in_block = mask_idx[:, bs:be].sum().item()\n                frac_masked = num_masked_in_block / max(block_length, 1)\n                adaptive_conf = self._adapt_conf_threshold(\n                    step, steps_per_block, frac_masked)\n\n                # Primary criteria: confident + low entropy (unambiguous) + stable\n                if step >= 1:  # Need at least one EMA update\n                    stable = kl_stab < self.kl_threshold\n                    low_ent = norm_ent < self._entropy_threshold\n                    ready = (conf > adaptive_conf) & low_ent & stable & mask_idx\n                else:\n                    ready = (conf > adaptive_conf) & mask_idx\n\n                # Count how many tokens we should transfer this step\n                num_xfer_tensor = get_num_transfer_tokens(\n                    (x[:, bs:be] == mid), steps_per_block)\n                k = int(num_xfer_tensor[0, step].item())\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n\n                # If we have enough ready positions, commit them\n                n_ready = ready.sum().item()\n                if n_ready >= k:\n                    # Select top-k by confidence among ready positions\n                    c_ready = conf.clone()\n                    c_ready[~ready] = -float(\"inf\")\n                    _, topk = torch.topk(c_ready, k)\n                    xfer[0, topk] = True\n                elif n_ready > 0:\n                    # Commit all ready, fill rest with hybrid score\n                    xfer = xfer | ready\n              
      remaining = k - n_ready\n\n                    # Hybrid score for remaining positions\n                    # Low entropy is good (unambiguous), high confidence is good,\n                    # high margin is good (clear winner), low KL is good (stable)\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[xfer[0]] = -float(\"inf\")\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], remaining)\n                    xfer[0, topk] = True\n                else:\n                    # No ready positions: use hybrid score to fill quota\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], k)\n                    xfer[0, topk] = True\n\n                # Update conf history\n                conf_history = torch.roll(conf_history, -1, dims=-1)\n                conf_history[..., -1] = conf\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
3  "filename": "LLaMA-Factory/src/LLaDA/custom_demask_eval.py",
4  "op": "replace",
5  "start_line": 59,
6  "end_line": 151
7}
Tool Result · Step 1
ERROR: Package 'LLaMA-Factory' is not in allowed packages
Assistant · Step 2 · edit
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Entropy-Stability Adaptive decoder for masked diffusion LMs.\n\n    Novel contributions:\n    1. Entropy-Stability Score: Uses predictive entropy as primary signal\n       for position selection. Low entropy = model is unambiguous about\n       what token belongs there, so it's safe to commit.\n    2. EMA-based stability: Tracks exponential moving average of predictions\n       and measures KL(p_curr || ema) as stability signal. Avoids the\n       zero-initialization bug in KLASS's p_prev tracking.\n    3. Progressive threshold adaptation: The effective confidence threshold\n       decays over steps, starting conservative (ensuring correctness) and\n       relaxing as context builds. Also adapts based on how many tokens\n       remain vs steps left.\n    4. Fallback scoring: When not enough positions meet thresholds, uses a\n       hybrid score combining confidence, margin, and stability rather than\n       pure confidence.\n\n    This strategy generalizes across regimes:\n    - Semi-AR (block_length < gen_length): careful unmasking for accuracy\n    - Fully-parallel (block_length == gen_length): entropy-guided for\n      both quality and diversity in open-ended text generation.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        # Adaptive base thresholds\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n        # Hardcoded novel hyperparameters\n        self._entropy_threshold = 0.5        # max entropy to consider \"unambiguous\"\n        self._ema_decay = 0.7                # EMA decay factor for stability tracking\n        self._margin_weight = 0.3            # weight for margin in fallback score\n        
self._entropy_weight = 0.4           # weight for entropy in fallback score\n        self._max_entropy = None             # lazily computed from vocab size\n\n    def _adapt_conf_threshold(self, step: int, total_steps: int,\n                              frac_masked: float) -> float:\n        \"\"\"Adapt confidence threshold: conservative early, relaxed later.\n\n        When many tokens remain (frac_masked high), be conservative.\n        As we progress, relax to ensure all tokens get unmasked.\n        \"\"\"\n        # Step-based decay: start at base, decay toward 0.5 by end\n        step_factor = 1.0 - 0.5 * (step / max(total_steps - 1, 1))\n        # Mask-aware: if many tokens remain relative to steps, stay conservative\n        step_budget = max(total_steps - step, 1)\n        urgency = min(frac_masked * total_steps / step_budget, 1.0)\n        urgency_factor = 1.0 - 0.3 * urgency  # reduce threshold when need to hurry\n        return self.conf_threshold * step_factor * urgency_factor\n\n    def _adapt_ema_decay(self, step: int, total_steps: int) -> float:\n        \"\"\"Adapt EMA decay: faster tracking early, more stable later.\"\"\"\n        if total_steps <= 1:\n            return self._ema_decay\n        # Early steps: lower decay (faster adaptation)\n        # Later steps: higher decay (smooth, stable tracking)\n        progress = step / max(total_steps - 1, 1)\n        return min(0.3 + 0.7 * progress, 0.9)\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = 
model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        # Set max entropy for normalization\n        if self._max_entropy is None:\n            self._max_entropy = torch.log(torch.tensor(V, dtype=torch.float64))\n\n        # EMA prediction distribution for stability tracking\n        # Initialize with uniform to avoid zero-initialization KL bug\n        eps_unif = 1.0 / V\n        p_ema = torch.full((1, x.shape[1], V), eps_unif,\n                           dtype=torch.float64, device=x.device)\n\n        # Per-position confidence history for adaptive thresholds\n        conf_history = torch.zeros((1, x.shape[1], 2),\n                                   dtype=torch.float64, device=x.device)\n\n        used = 0\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                x0 = torch.argmax(p_curr, dim=-1)\n                conf = torch.gather(p_curr, -1, x0.unsqueeze(-1)).squeeze(-1)\n\n                # Compute predictive entropy for each position\n                ent = -(p_curr * torch.log(p_curr + 1e-12)).sum(-1)\n                # Normalize entropy to [0, 1]\n                norm_ent = ent / self._max_entropy\n\n                # Compute margin (top1 - top2 probability)\n                sorted_probs, _ = torch.sort(p_curr, dim=-1, descending=True)\n                margin = sorted_probs[..., 0] - sorted_probs[..., 1]\n\n                # Update EMA with adaptive decay\n                decay = 
self._adapt_ema_decay(step, steps_per_block)\n                p_ema = decay * p_ema + (1 - decay) * p_curr\n\n                # EMA-stability: KL(p_curr || p_ema)\n                # Small KL means prediction is stable\n                kl_stab = (p_curr * (torch.log(p_curr + 1e-12)\n                                     - torch.log(p_ema + 1e-12))).sum(-1)\n\n                # Compute adaptive thresholds\n                num_masked_in_block = mask_idx[:, bs:be].sum().item()\n                frac_masked = num_masked_in_block / max(block_length, 1)\n                adaptive_conf = self._adapt_conf_threshold(\n                    step, steps_per_block, frac_masked)\n\n                # Primary criteria: confident + low entropy (unambiguous) + stable\n                if step >= 1:  # Need at least one EMA update\n                    stable = kl_stab < self.kl_threshold\n                    low_ent = norm_ent < self._entropy_threshold\n                    ready = (conf > adaptive_conf) & low_ent & stable & mask_idx\n                else:\n                    ready = (conf > adaptive_conf) & mask_idx\n\n                # Count how many tokens we should transfer this step\n                num_xfer_tensor = get_num_transfer_tokens(\n                    (x[:, bs:be] == mid), steps_per_block)\n                k = int(num_xfer_tensor[0, step].item())\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n\n                # If we have enough ready positions, commit them\n                n_ready = ready.sum().item()\n                if n_ready >= k:\n                    # Select top-k by confidence among ready positions\n                    c_ready = conf.clone()\n                    c_ready[~ready] = -float(\"inf\")\n                    _, topk = torch.topk(c_ready, k)\n                    xfer[0, topk] = True\n                elif n_ready > 0:\n                    # Commit all ready, fill rest with hybrid score\n                    xfer = xfer | ready\n              
      remaining = k - n_ready\n\n                    # Hybrid score for remaining positions\n                    # Low entropy is good (unambiguous), high confidence is good,\n                    # high margin is good (clear winner), low KL is good (stable)\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[xfer[0]] = -float(\"inf\")\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], remaining)\n                    xfer[0, topk] = True\n                else:\n                    # No ready positions: use hybrid score to fill quota\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], k)\n                    xfer[0, topk] = True\n\n                # Update conf history\n                conf_history = torch.roll(conf_history, -1, dims=-1)\n                conf_history[..., -1] = conf\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
3  "filename": "LLaDA/custom_demask_eval.py",
4  "op": "replace",
5  "start_line": 59,
6  "end_line": 151
7}
Tool Result · Step 2
OK: Replaced lines 59..151 with 196 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–254.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–254 | total: 594 lines]
    59  class DemaskDecoder:
    60      """Entropy-Stability Adaptive decoder for masked diffusion LMs.
    61  
       ... (190 more lines) ...
   252                  used += 1
   253  
   254          return x, used
Assistant · Step 3 · edit
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Entropy-Stability Adaptive decoder for masked diffusion LMs.\n\n    Novel contributions:\n    1. Entropy-Stability Score: Uses predictive entropy as primary signal\n       for position selection. Low entropy = model is unambiguous about\n       what token belongs there, so it's safe to commit.\n    2. EMA-based stability: Tracks exponential moving average of predictions\n       and measures KL(p_curr || ema) as stability signal. Avoids the\n       zero-initialization bug in KLASS's p_prev tracking.\n    3. Progressive threshold adaptation: The effective confidence threshold\n       decays over steps, starting conservative (ensuring correctness) and\n       relaxing as context builds. Also adapts based on how many tokens\n       remain vs steps left.\n    4. Fallback scoring: When not enough positions meet thresholds, uses a\n       hybrid score combining confidence, margin, and stability rather than\n       pure confidence.\n\n    This strategy generalizes across regimes:\n    - Semi-AR (block_length < gen_length): careful unmasking for accuracy\n    - Fully-parallel (block_length == gen_length): entropy-guided for\n      both quality and diversity in open-ended text generation.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        # Adaptive base thresholds\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n        # Hardcoded novel hyperparameters\n        self._entropy_threshold = 0.5        # max entropy to consider \"unambiguous\"\n        self._ema_decay = 0.7                # EMA decay factor for stability tracking\n        self._margin_weight = 0.3            # weight for margin in fallback score\n        
self._entropy_weight = 0.4           # weight for entropy in fallback score\n        self._max_entropy = None             # lazily computed from vocab size\n\n    def _adapt_conf_threshold(self, step: int, total_steps: int,\n                              frac_masked: float) -> float:\n        \"\"\"Adapt confidence threshold: conservative early, relaxed later.\n\n        When many tokens remain (frac_masked high), be conservative.\n        As we progress, relax to ensure all tokens get unmasked.\n        \"\"\"\n        # Step-based decay: start at base, decay toward 0.5 by end\n        step_factor = 1.0 - 0.5 * (step / max(total_steps - 1, 1))\n        # Mask-aware: if many tokens remain relative to steps, stay conservative\n        step_budget = max(total_steps - step, 1)\n        urgency = min(frac_masked * total_steps / step_budget, 1.0)\n        urgency_factor = 1.0 - 0.3 * urgency  # reduce threshold when need to hurry\n        return self.conf_threshold * step_factor * urgency_factor\n\n    def _adapt_ema_decay(self, step: int, total_steps: int) -> float:\n        \"\"\"Adapt EMA decay: faster tracking early, more stable later.\"\"\"\n        if total_steps <= 1:\n            return self._ema_decay\n        # Early steps: lower decay (faster adaptation)\n        # Later steps: higher decay (smooth, stable tracking)\n        progress = step / max(total_steps - 1, 1)\n        return min(0.3 + 0.7 * progress, 0.9)\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = 
model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        # Set max entropy for normalization\n        if self._max_entropy is None:\n            self._max_entropy = torch.log(torch.tensor(V, dtype=torch.float64))\n\n        # EMA prediction distribution for stability tracking\n        # Initialize with uniform to avoid zero-initialization KL bug\n        eps_unif = 1.0 / V\n        p_ema = torch.full((1, x.shape[1], V), eps_unif,\n                           dtype=torch.float64, device=x.device)\n\n        # Per-position confidence history for adaptive thresholds\n        conf_history = torch.zeros((1, x.shape[1], 2),\n                                   dtype=torch.float64, device=x.device)\n\n        used = 0\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                x0 = torch.argmax(p_curr, dim=-1)\n                conf = torch.gather(p_curr, -1, x0.unsqueeze(-1)).squeeze(-1)\n\n                # Compute predictive entropy for each position\n                ent = -(p_curr * torch.log(p_curr + 1e-12)).sum(-1)\n                # Normalize entropy to [0, 1]\n                norm_ent = ent / self._max_entropy\n\n                # Compute margin (top1 - top2 probability)\n                sorted_probs, _ = torch.sort(p_curr, dim=-1, descending=True)\n                margin = sorted_probs[..., 0] - sorted_probs[..., 1]\n\n                # Update EMA with adaptive decay\n                decay = 
self._adapt_ema_decay(step, steps_per_block)\n                p_ema = decay * p_ema + (1 - decay) * p_curr\n\n                # EMA-stability: KL(p_curr || p_ema)\n                # Small KL means prediction is stable\n                kl_stab = (p_curr * (torch.log(p_curr + 1e-12)\n                                     - torch.log(p_ema + 1e-12))).sum(-1)\n\n                # Compute adaptive thresholds\n                num_masked_in_block = mask_idx[:, bs:be].sum().item()\n                frac_masked = num_masked_in_block / max(block_length, 1)\n                adaptive_conf = self._adapt_conf_threshold(\n                    step, steps_per_block, frac_masked)\n\n                # Primary criteria: confident + low entropy (unambiguous) + stable\n                if step >= 1:  # Need at least one EMA update\n                    stable = kl_stab < self.kl_threshold\n                    low_ent = norm_ent < self._entropy_threshold\n                    ready = (conf > adaptive_conf) & low_ent & stable & mask_idx\n                else:\n                    ready = (conf > adaptive_conf) & mask_idx\n\n                # Count how many tokens we should transfer this step\n                num_xfer_tensor = get_num_transfer_tokens(\n                    (x[:, bs:be] == mid), steps_per_block)\n                k = int(num_xfer_tensor[0, step].item())\n\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n\n                # If we have enough ready positions, commit them\n                n_ready = ready.sum().item()\n                if n_ready >= k:\n                    # Select top-k by confidence among ready positions\n                    c_ready = conf.clone()\n                    c_ready[~ready] = -float(\"inf\")\n                    _, topk = torch.topk(c_ready, k)\n                    xfer[0, topk] = True\n                elif n_ready > 0:\n                    # Commit all ready, fill rest with hybrid score\n                    xfer = xfer | ready\n              
      remaining = k - n_ready\n\n                    # Hybrid score for remaining positions\n                    # Low entropy is good (unambiguous), high confidence is good,\n                    # high margin is good (clear winner), low KL is good (stable)\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[xfer[0]] = -float(\"inf\")\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], remaining)\n                    xfer[0, topk] = True\n                else:\n                    # No ready positions: use hybrid score to fill quota\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], k)\n                    xfer[0, topk] = True\n\n                # Update conf history\n                conf_history = torch.roll(conf_history, -1, dims=-1)\n                conf_history[..., -1] = conf\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n\n        return x, used",
3  "filename": "LLaDA/custom_demask_eval.py",
4  "op": "replace",
5  "start_line": 59,
6  "end_line": 151
7}
Tool Result (Step 3)
OK: Replaced lines 59..151 with 196 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–357.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–357 | total: 697 lines]
    59  class DemaskDecoder:
    60      """Entropy-Stability Adaptive decoder for masked diffusion LMs.
    61  
       ... (293 more lines) ...
   355                  used += 1
   356  
   357          return x, used
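Both revisions of the decoder call `get_num_transfer_tokens` to fix the per-step unmasking quota. Per the task header it lives outside the editable region and returns the uniform schedule (`mask.sum() // steps` per step). A plain-Python sketch of that arithmetic, under the assumption that any remainder is spent on the earliest steps (the real helper may place it differently, and returns a tensor per batch row):

```python
def uniform_transfer_schedule(num_masked: int, steps: int) -> list[int]:
    """Uniform quota: num_masked // steps per step; remainder assumed
    to go to the earliest steps (the real helper may differ)."""
    base, rem = divmod(num_masked, steps)
    return [base + 1 if i < rem else base for i in range(steps)]

print(uniform_transfer_schedule(10, 4))  # [3, 3, 2, 2]
print(uniform_transfer_schedule(8, 4))   # [2, 2, 2, 2]
```
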
Assistant (Step 4): undo
Tool Input
1{
2  "n": 1
3}
Tool Result (Step 4)
Undo complete:
Restored: LLaDA/custom_demask_eval.py
Assistant (Step 5): edit
Tool Input
1{
2  "content": "class DemaskDecoder:\n    \"\"\"Entropy-Stability Adaptive decoder for masked diffusion LMs.\n\n    Novel contributions:\n    1. Entropy-Stability Score: Uses predictive entropy as primary signal\n       for position selection. Low entropy = model is unambiguous about\n       what token belongs there, so it's safe to commit.\n    2. EMA-based stability: Tracks exponential moving average of predictions\n       and measures KL(p_curr || ema) as stability signal. Avoids the\n       zero-initialization bug in KLASS's p_prev tracking.\n    3. Progressive threshold adaptation: The effective confidence threshold\n       decays over steps, starting conservative (ensuring correctness) and\n       relaxing as context builds. Also adapts based on how many tokens\n       remain vs steps left.\n    4. Fallback scoring: When not enough positions meet thresholds, uses a\n       hybrid score combining confidence, margin, and stability rather than\n       pure confidence.\n\n    This strategy generalizes across regimes:\n    - Semi-AR (block_length < gen_length): careful unmasking for accuracy\n    - Fully-parallel (block_length == gen_length): entropy-guided for\n      both quality and diversity in open-ended text generation.\n    \"\"\"\n\n    def __init__(self, mask_id: int, temperature: float = 0.0,\n                 conf_threshold: float = 0.9, kl_threshold: float = 0.01,\n                 history_length: int = 2):\n        self.mask_id = mask_id\n        self.temperature = temperature\n        self.conf_threshold = conf_threshold\n        self.kl_threshold = kl_threshold\n        self.history_length = history_length\n        # Novel hyperparameters\n        self._entropy_threshold = 0.5\n        self._ema_decay = 0.7\n        self._margin_weight = 0.3\n        self._entropy_weight = 0.4\n        self._max_entropy = None\n\n    def _adapt_conf_threshold(self, step, total_steps, frac_masked):\n        \"\"\"Adapt confidence threshold: conservative early, relaxed 
later.\"\"\"\n        step_factor = 1.0 - 0.5 * (step / max(total_steps - 1, 1))\n        step_budget = max(total_steps - step, 1)\n        urgency = min(frac_masked * total_steps / step_budget, 1.0)\n        urgency_factor = 1.0 - 0.3 * urgency\n        return self.conf_threshold * step_factor * urgency_factor\n\n    def _adapt_ema_decay(self, step, total_steps):\n        \"\"\"Adapt EMA decay: faster tracking early, more stable later.\"\"\"\n        if total_steps <= 1:\n            return self._ema_decay\n        progress = step / max(total_steps - 1, 1)\n        return min(0.3 + 0.7 * progress, 0.9)\n\n    @torch.no_grad()\n    def decode(self, model, input_ids, gen_length: int, steps: int,\n               block_length: int):\n        mid = self.mask_id\n        x = torch.full((1, input_ids.shape[1] + gen_length), mid,\n                       dtype=torch.long, device=model.device)\n        x[:, :input_ids.shape[1]] = input_ids.clone()\n        assert gen_length % block_length == 0\n        num_blocks = gen_length // block_length\n        assert steps % num_blocks == 0\n        steps_per_block = steps // num_blocks\n\n        V = model.lm_head.out_features if hasattr(model, \"lm_head\") \\\n                                       else model.config.vocab_size\n        if self._max_entropy is None:\n            self._max_entropy = torch.log(torch.tensor(V, dtype=torch.float64))\n\n        # EMA distribution for stability tracking (uniform init avoids KL bug)\n        eps_unif = 1.0 / V\n        p_ema = torch.full((1, x.shape[1], V), eps_unif,\n                           dtype=torch.float64, device=x.device)\n\n        used = 0\n\n        for b in range(num_blocks):\n            bs = input_ids.shape[1] + b * block_length\n            be = bs + block_length\n            num_xfer = get_num_transfer_tokens(\n                (x[:, bs:be] == mid), steps_per_block)\n\n            for step in range(steps_per_block):\n                mask_idx = (x == mid)\n                
block_m = torch.zeros_like(mask_idx)\n                block_m[:, bs:be] = True\n                mask_idx = mask_idx & block_m\n                if not mask_idx.any():\n                    break\n\n                logits = model(x).logits\n                p_curr = F.softmax(logits.to(torch.float64), dim=-1)\n                x0 = torch.argmax(p_curr, dim=-1)\n                conf = torch.gather(p_curr, -1, x0.unsqueeze(-1)).squeeze(-1)\n\n                # Predictive entropy, normalized to [0,1]\n                ent = -(p_curr * torch.log(p_curr + 1e-12)).sum(-1)\n                norm_ent = ent / self._max_entropy\n\n                # Margin (top1 - top2)\n                sorted_probs, _ = torch.sort(p_curr, dim=-1, descending=True)\n                margin = sorted_probs[..., 0] - sorted_probs[..., 1]\n\n                # Update EMA with adaptive decay\n                decay = self._adapt_ema_decay(step, steps_per_block)\n                p_ema = decay * p_ema + (1 - decay) * p_curr\n\n                # EMA-stability: KL(p_curr || p_ema)\n                kl_stab = (p_curr * (torch.log(p_curr + 1e-12)\n                                     - torch.log(p_ema + 1e-12))).sum(-1)\n\n                # Adaptive thresholds\n                num_masked_in_block = mask_idx[:, bs:be].sum().item()\n                frac_masked = num_masked_in_block / max(block_length, 1)\n                adaptive_conf = self._adapt_conf_threshold(\n                    step, steps_per_block, frac_masked)\n\n                if step >= 1:\n                    stable = kl_stab < self.kl_threshold\n                    low_ent = norm_ent < self._entropy_threshold\n                    ready = (conf > adaptive_conf) & low_ent & stable & mask_idx\n                else:\n                    ready = (conf > adaptive_conf) & mask_idx\n\n                k = int(num_xfer[0, step].item())\n                n_ready = ready.sum().item()\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n\n                
if n_ready >= k:\n                    c_ready = conf.clone()\n                    c_ready[~ready] = -float(\"inf\")\n                    _, topk = torch.topk(c_ready, k)\n                    xfer[0, topk] = True\n                elif n_ready > 0:\n                    xfer = xfer | ready\n                    remaining = k - n_ready\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[xfer[0]] = -float(\"inf\")\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], remaining)\n                    xfer[0, topk] = True\n                else:\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    )\n                    hybrid_score[~mask_idx[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score[0], k)\n                    xfer[0, topk] = True\n\n                x = torch.where(xfer, x0, x)\n                used += 1\n        return x, used",
3  "filename": "LLaDA/custom_demask_eval.py",
4  "op": "replace",
5  "start_line": 59,
6  "end_line": 151
7}
Tool Result (Step 5)
OK: Replaced lines 59..151 with 162 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–323.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–323 | total: 663 lines]
    59  class DemaskDecoder:
    60      """Entropy-Stability Adaptive decoder for masked diffusion LMs.
    61  
       ... (259 more lines) ...
   321                  used += 1
   322  
   323          return x, used
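The Step 5 rewrite gates each commit on three per-position signals: top-1 confidence, normalized predictive entropy, and KL(p_curr || p_ema) against a uniform-initialized EMA. A toy scalar sketch of those signals, using the thresholds set in `__init__` (the in-file version computes them over [1, L, V] float64 tensors):

```python
import math

def readiness_signals(p_curr, p_ema):
    """Per-position signals behind the decoder's `ready` test (toy sketch)."""
    V = len(p_curr)
    conf = max(p_curr)                                    # top-1 probability
    ent = -sum(p * math.log(p + 1e-12) for p in p_curr)   # predictive entropy
    norm_ent = ent / math.log(V)                          # normalized to [0, 1]
    top1, top2 = sorted(p_curr, reverse=True)[:2]
    margin = top1 - top2                                  # clear-winner margin
    kl = sum(p * (math.log(p + 1e-12) - math.log(q + 1e-12))
             for p, q in zip(p_curr, p_ema))              # KL(p_curr || p_ema)
    return conf, norm_ent, margin, kl

# A peaked prediction whose EMA has already converged to it:
p = [0.96, 0.02, 0.01, 0.01]
conf, norm_ent, margin, kl = readiness_signals(p, p)
ready = conf > 0.9 and norm_ent < 0.5 and kl < 0.01  # thresholds from the code
print(ready)  # True
```
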
Assistant (Step 6): test
Tool Result (Step 6)
[Test #1] (2 tests remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
[COMMAND FAILED exit=1]
[INFO] Loading llada...

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  8.43it/s]
[INFO] task=math steps=256 gen_length=256 block_length=64
Traceback (most recent call last):
  File "/workspace/LLaDA/custom_demask_eval.py", line 663, in <module>
    main()
  File "/workspace/LLaDA/custom_demask_eval.py", line 635, in main
Assistant (Step 7): edit
Tool Input
1{
2  "content": "                k = int(num_xfer[0, step].item())\n                n_ready = ready.sum().item()\n                xfer = torch.zeros_like(x0, dtype=torch.bool)\n\n                if n_ready >= k:\n                    c_ready = conf.clone()\n                    c_ready[~ready] = -float(\"inf\")\n                    _, topk = torch.topk(c_ready.view(-1), k)\n                    xfer.view(-1)[topk] = True\n                elif n_ready > 0:\n                    xfer = xfer | ready\n                    remaining = k - n_ready\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    ).view(-1)\n                    hybrid_score[xfer.view(-1)] = -float(\"inf\")\n                    hybrid_score[torch.where(~mask_idx.view(-1))[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score, remaining)\n                    xfer.view(-1)[topk] = True\n                else:\n                    hybrid_score = (\n                        self._entropy_weight * (1.0 - norm_ent) +\n                        0.3 * conf +\n                        self._margin_weight * margin +\n                        0.1 * (1.0 - kl_stab.clamp(0, 1))\n                    ).view(-1)\n                    hybrid_score[torch.where(~mask_idx.view(-1))[0]] = -float(\"inf\")\n                    _, topk = torch.topk(hybrid_score, k)\n                    xfer.view(-1)[topk] = True",
3  "filename": "LLaDA/custom_demask_eval.py",
4  "op": "replace",
5  "start_line": 195,
6  "end_line": 219
7}
Tool Result (Step 7)
OK: Replaced lines 195..219 with 32 line(s) in LLaDA/custom_demask_eval.py. Editable range: 59–330.

[Current file: LLaDA/custom_demask_eval.py | editable: 59–330 | total: 670 lines]
    59  class DemaskDecoder:
    60      """Entropy-Stability Adaptive decoder for masked diffusion LMs.
    61  
       ... (266 more lines) ...
   328                  used += 1
   329  
   330          return x, used
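The Step 7 edit flattens every [1, L] tensor with `.view(-1)` so that boolean masks and `topk` indices address the same single axis. Stripped of tensors, the selection it implements is just "take the k best-scoring positions that are still eligible", as in this plain-Python sketch (the function name is illustrative, not from the file; the file instead sets ineligible scores to -inf before `topk`):

```python
def topk_transfer(scores, eligible, k):
    """Indices of the k best-scoring positions that are still eligible
    (illustrative name; mirrors the masked top-k in the Step 7 edit)."""
    ranked = sorted((i for i, ok in enumerate(eligible) if ok),
                    key=lambda i: scores[i], reverse=True)
    return set(ranked[:k])

scores = [0.9, 0.2, 0.7, 0.95, 0.4]
eligible = [True, True, True, False, True]  # index 3 is already committed
print(sorted(topk_transfer(scores, eligible, 2)))  # [0, 2]
```
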
Assistant (Step 8): test
Tool Result (Step 8)
[Test #2] (1 test remaining; call submit(n=N) to choose which test result to submit as final)

### llada-math (scripts/llada_math.sh)
[COMMAND FAILED exit=1]
  File "/workspace/LLaDA/custom_demask_eval.py", line 195
    k = int(num_xfer[0, step].item())
    ^
IndentationError: expected an indented block after 'elif' statement on line 194


### llada-humaneval (scripts/llada_humaneval.sh)
[COMMAND FAILED exit=1]
  File "/workspace/LLaDA/custom_demask_eval.py", line 195
    k = int(num_xfer[0, step].item())
    ^
IndentationError: expected an indented block after 'elif' statement on line 194
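A plausible reading of this failure: the Step 7 replace of lines 195..219 began one line below an `elif` header, so the header left at line 194 no longer has an indented suite. A minimal reproduction of that failure mode, independent of the actual file:

```python
# Sketch: a line-range replace that keeps an `elif` header but re-inserts
# the lines below it at a shallower indent orphans the header.
snippet = (
    "if n_ready >= k:\n"
    "    pass\n"
    "elif n_ready > 0:\n"   # header survived the line-range replace...
    "k = 0\n"               # ...but the replacement body lost its indent
)
try:
    compile(snippet, "<demask>", "exec")
except SyntaxError as e:
    print(type(e).__name__)  # IndentationError
```
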
Assistant (Step 9): read_file
Tool Input
1{
2  "end_line": "230",
3  "filename": "LLaDA/custom_demask_eval.py",
4  "start_line": "180"
5}
Tool Result (Step 9)
ERROR: Unknown tool 'read_file'