llm-dllm-demask-strategy

Language Models, LLaDA, rigorous codebase

Description

Masked Diffusion LM: Demasking Strategy

Research Question

Design a better demasking (decoding) strategy for masked diffusion language models. The strategy must generalize across different decoding regimes:

  • Block-based semi-autoregressive decoding for downstream task accuracy (LLaDA on MATH/HumanEval, following the KLASS protocol)
  • Fully-parallel decoding for open-ended text generation (Dream on prefix-conditioned C4 continuation, measured by perplexity / diversity)

Background

Masked diffusion LMs (LLaDA, Dream) generate text by starting from a fully masked generation region and iteratively unmasking it over a fixed number of denoising iterations (steps). At each step, a demasking strategy decides:

  1. Schedule: how many tokens to unmask
  2. Position selection: which masked positions to unmask
  3. Token assignment: what token id to place
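The three decisions above can be illustrated with one confidence-greedy step. This is a minimal sketch in plain Python rather than the torch code the benchmark uses; MASK, softmax, and demask_step are hypothetical names introduced here, and the schedule is reduced to a fixed k passed in by the caller.

```python
import math

MASK = -1  # hypothetical mask id, used only in this sketch


def softmax(logits):
    """Numerically stable softmax over one position's logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def demask_step(tokens, logits, k):
    """Unmask the k masked positions whose argmax token is most confident."""
    candidates = []
    for pos, tok in enumerate(tokens):
        if tok != MASK:
            continue  # already-decoded positions are left untouched
        probs = softmax(logits[pos])
        best = max(range(len(probs)), key=probs.__getitem__)  # token assignment
        candidates.append((probs[best], pos, best))
    # Position selection: highest confidence first, ties broken by position.
    candidates.sort(key=lambda c: (-c[0], c[1]))
    out = list(tokens)
    for _, pos, tok in candidates[:k]:  # schedule: unmask k tokens this step
        out[pos] = tok
    return out


tokens = [7, MASK, MASK, MASK]
logits = [
    [0.0, 0.0, 0.0],  # ignored: position 0 is not masked
    [0.1, 0.2, 0.3],  # nearly flat, low confidence
    [5.0, 0.0, 0.0],  # sharply peaked on token 0
    [0.0, 0.0, 2.0],  # moderately peaked on token 2
]
print(demask_step(tokens, logits, k=2))  # [7, -1, 0, 2]
```

Positions 2 and 3 are filled because their top probabilities are the highest; position 1 stays masked for a later step.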

Decoding can be semi-autoregressive (when block_length < gen_length, process one block at a time) or fully parallel (block_length == gen_length, all positions decoded together).
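The two regimes differ only in how the generation region is partitioned. A minimal sketch, assuming blocks are contiguous, equally sized, and laid out left to right after the prompt; block_ranges is a hypothetical helper, not a function from the codebase.

```python
def block_ranges(prompt_len, gen_length, block_length):
    """Return (start, end) index pairs, end exclusive, one per block."""
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be a multiple of block_length")
    return [
        (prompt_len + offset, prompt_len + offset + block_length)
        for offset in range(0, gen_length, block_length)
    ]


# Semi-autoregressive: four blocks of 64 tokens after a 10-token prompt.
print(block_ranges(10, 256, 64))   # [(10, 74), (74, 138), (138, 202), (202, 266)]
# Fully parallel: a single block covering the whole generation region.
print(block_ranges(32, 224, 224))  # [(32, 256)]
```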

What You Can Modify

Edit the DemaskDecoder class in LLaDA/custom_demask_eval.py (lines 59-151).

Interface

class DemaskDecoder:
    def __init__(self, mask_id, temperature=0.0,
                 conf_threshold=0.9, kl_threshold=0.01, history_length=2):
        ...

    @torch.no_grad()
    def decode(self, model, input_ids, gen_length, steps, block_length):
        # Returns (x_output [1, prompt_len + gen_length], used_steps)
        ...

get_num_transfer_tokens(mask, steps) is available outside the editable region — returns the uniform schedule (mask.sum() // steps per step).
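As a rough illustration of a uniform schedule, the sketch below works on a plain integer count of masked tokens instead of a torch mask tensor. The base amount, num_masked // steps, follows the description above; spreading the remainder over the earliest steps is an assumption of this sketch, and uniform_schedule is a hypothetical name.

```python
def uniform_schedule(num_masked, steps):
    """Number of tokens to unmask at each step, summing to num_masked."""
    base, remainder = divmod(num_masked, steps)
    # Assumption: the first `remainder` steps each take one extra token.
    return [base + (1 if i < remainder else 0) for i in range(steps)]


print(uniform_schedule(64, 64))  # one token per step
print(uniform_schedule(10, 4))   # [3, 3, 2, 2]
```

A custom strategy can ignore this schedule entirely, for example by unmasking every position that passes a confidence threshold, which is how used_steps can fall below steps.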

Constraints

  • gen_length % block_length == 0. When block_length == gen_length, decoding is fully parallel.
  • Process blocks sequentially (no early-decoding into later blocks).
  • Always return [1, prompt_len + gen_length].
  • used_steps counts model forward passes (lower = more efficient).
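The constraints can be read as the shape of the decode loop. The toy skeleton below is a sketch under stated assumptions: it uses a flat Python list instead of a [1, L] tensor, replaces the model with a constant fill token, and splits steps evenly across blocks (the source does not say how steps are allocated). MASK and decode_skeleton are hypothetical names.

```python
MASK = -1  # hypothetical mask id, used only in this sketch


def decode_skeleton(prompt, gen_length, steps, block_length, fill_token=0):
    """Toy decode loop: sequential blocks, one counted 'forward pass' per step."""
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be a multiple of block_length")
    num_blocks = gen_length // block_length
    steps_per_block = max(1, steps // num_blocks)  # assumption: even split
    x = list(prompt) + [MASK] * gen_length
    used_steps = 0
    for b in range(num_blocks):
        start = len(prompt) + b * block_length
        end = start + block_length
        for step in range(steps_per_block):
            masked = [i for i in range(start, end) if x[i] == MASK]
            if not masked:
                break  # block finished early; move on to the next block
            used_steps += 1  # stands in for one model forward pass
            remaining = steps_per_block - step
            k = -(-len(masked) // remaining)  # ceiling division: finish on time
            for i in masked[:k]:  # only positions inside the current block
                x[i] = fill_token
    return x, used_steps


x, used = decode_skeleton([1, 2, 3], gen_length=8, steps=4, block_length=4)
print(len(x), used)  # 11 4
```

Because later blocks are never touched before the current one is complete, the skeleton respects the sequential-block rule, and the returned length is always prompt_len + gen_length.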

Evaluation

Benchmarks

| Label | Task | Model | gen_len | steps | block_len | Metrics |
|---|---|---|---|---|---|---|
| llada-math | MATH-500 | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| llada-humaneval | HumanEval (164) | LLaDA-8B-Instruct | 256 | 256 | 64 | accuracy + avg_steps |
| dream-text | C4 prefix-continuation (256 samples, 32-tok prefix → 224-tok continuation) | Dream-v0-Instruct-7B | 224 | 256 | 224 | gen_ppl + MAUVE + entropy + rep2 + avg_steps |

Metrics

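| Metric | Direction | Where | Description |
|---|---|---|---|
| accuracy | higher is better | math/humaneval | exact-match (MATH) or pass@1 (HumanEval) |
| gen_ppl | lower is better | text | Conditional perplexity via GPT-2-Large |
| mauve | higher is better | text | Distributional similarity to C4 reference text |
| entropy | | text | Bigram entropy (lexical diversity) |
| rep2 | lower is better | text | Repeated bigram ratio |
| avg_steps | lower is better | all | Actual model forward passes used |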
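The two diversity metrics can be sketched as follows. The exact formulas used by the evaluation script are not given here, so this is one plausible reading: rep2 as the fraction of bigrams that duplicate another bigram in the same sequence, and entropy as the Shannon entropy (base 2) of the bigram distribution. The function names are hypothetical.

```python
import math
from collections import Counter


def bigrams(tokens):
    """Adjacent token pairs of a sequence."""
    return list(zip(tokens, tokens[1:]))


def rep2(tokens):
    """Assumed definition: 1 - (unique bigrams / total bigrams)."""
    grams = bigrams(tokens)
    if not grams:
        return 0.0
    return 1.0 - len(set(grams)) / len(grams)


def bigram_entropy(tokens):
    """Assumed definition: Shannon entropy in bits of the bigram distribution."""
    grams = bigrams(tokens)
    if not grams:
        return 0.0
    total = len(grams)
    counts = Counter(grams)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


repetitive = "a b a b a b".split()
varied = "a b c d e".split()
print(rep2(repetitive), rep2(varied))
print(bigram_entropy(repetitive), bigram_entropy(varied))
```

The repetitive sequence has only two distinct bigrams out of five, so its rep2 is high and its entropy is low; the varied sequence has no repeats and reaches the maximum entropy for four bigrams.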

Protocol references

  • MATH/HumanEval: KLASS (Kim et al., NeurIPS 2025; arXiv 2511.05664). We use KLASS's exact data/math_test.json, prompts, and utils.py for answer extraction (extract_math_answer, compare_answers).
  • Text generation: prefix-conditioned C4 continuation, similar to MDLM / ReMDM evaluation but with conditioning on a 32-token prefix.

Baselines (from KLASS algorithms)

  • confidence_greedy — LLaDA's low_confidence remasking: top-k by max prob.
  • topk_margin — Dream's topk_margin: top-k by (top1 prob − top2 prob).
  • klass — SOTA: KL-adaptive stability + confidence thresholds.
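The scoring rules behind these baselines can be sketched per position. The confidence and margin scores follow the descriptions above directly. The stability check is only one plausible reading of a KL-plus-confidence criterion, reusing the conf_threshold, kl_threshold, and history_length defaults from the interface; it is not taken from the KLASS code, and all function names are hypothetical.

```python
import math


def confidence_score(probs):
    """confidence_greedy: rank positions by their top-1 probability."""
    return max(probs)


def margin_score(probs):
    """topk_margin: rank positions by top-1 minus top-2 probability."""
    top1, top2 = sorted(probs, reverse=True)[:2]
    return top1 - top2


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def is_stable(history, conf_threshold=0.9, kl_threshold=0.01, history_length=2):
    """Assumed rule: confident now, and nearly unchanged over recent steps.

    `history` holds one position's predicted distributions, oldest first.
    """
    if len(history) < history_length:
        return False  # not enough steps observed yet
    recent = history[-history_length:]
    kls = [kl_divergence(cur, prev) for prev, cur in zip(recent, recent[1:])]
    confident = max(recent[-1]) >= conf_threshold
    return confident and all(kl < kl_threshold for kl in kls)


settled = [[0.95, 0.03, 0.02], [0.95, 0.03, 0.02]]
shifting = [[0.40, 0.30, 0.30], [0.95, 0.03, 0.02]]
print(is_stable(settled), is_stable(shifting))  # True False
```

A threshold rule like this unmasks a variable number of tokens per step, which is what lets a strategy finish in fewer forward passes than the uniform top-k baselines.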

Reference Performance

LLaDA paper (EVAL.md; gen_length=256, steps=256, block_length=256), on LLaDA-8B-Base: MATH = 30.3%, HumanEval = 32.9%.

KLASS paper, on LLaDA-8B-Instruct (MATH, block_length=64): about 33.8% accuracy with KLASS, while reducing steps by 40-70%.

Code

custom_demask_eval.py
"""Downstream task evaluation (MATH, HumanEval) for masked diffusion LMs.

Following the KLASS evaluation protocol (Kim et al., NeurIPS 2025):
  https://github.com/shkim0116/KLASS
"""

from __future__ import annotations

import argparse
import gzip
import json
import os
import re
import sys
import time

Results

Main benchmarks (llada-math, llada-humaneval, dream-text). A dash (-) marks a missing value.

| Model | Type | accuracy llada-math | avg steps llada-math | n samples llada-math | accuracy llada-humaneval | avg steps llada-humaneval | n samples llada-humaneval | gen ppl dream-text | mauve dream-text | entropy dream-text | rep2 dream-text | avg steps dream-text | n samples dream-text |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| confidence_greedy | baseline | 0.316 | 256.000 | 500.000 | 0.366 | 256.000 | 164.000 | 170.609 | 0.032 | 6.413 | 0.013 | 224.000 | 256.000 |
| klass | baseline | 0.334 | 127.860 | 500.000 | 0.372 | 93.810 | 164.000 | 64.219 | 0.068 | 6.324 | 0.016 | 88.540 | 256.000 |
| topk_margin | baseline | 0.322 | 256.000 | 500.000 | 0.390 | 256.000 | 164.000 | 237.050 | 0.112 | 5.926 | 0.025 | 224.000 | 256.000 |
| anthropic/claude-opus-4.6 | vanilla | 0.284 | 57.430 | 500.000 | 0.378 | 56.260 | 164.000 | 39.504 | 0.085 | 6.026 | 0.034 | 40.820 | 256.000 |
| deepseek-reasoner | vanilla | 0.328 | 256.000 | 500.000 | 0.415 | 256.000 | 164.000 | 221.498 | 0.055 | 5.276 | 0.015 | 224.000 | 256.000 |
| google/gemini-3.1-pro-preview | vanilla | 0.318 | 118.310 | 500.000 | 0.402 | 91.450 | 164.000 | 12.428 | 0.094 | 4.328 | 0.080 | 49.090 | 256.000 |
| openai/gpt-5.4 | vanilla | 0.308 | 83.040 | 500.000 | 0.378 | 71.680 | 164.000 | 34.420 | 0.093 | 6.232 | 0.004 | 35.320 | 256.000 |
| qwen/qwen3.6-plus | vanilla | - | - | - | - | - | - | - | - | - | - | - | - |
| anthropic/claude-opus-4.6 | agent | 0.304 | 121.730 | 500.000 | 0.402 | 89.390 | 164.000 | 28.442 | 0.230 | 6.134 | 0.013 | 35.230 | 256.000 |
| deepseek-reasoner | agent | 0.290 | 114.620 | 500.000 | 0.378 | 74.800 | 164.000 | 26.749 | 0.210 | 6.031 | 0.007 | 26.720 | 256.000 |
| google/gemini-3.1-pro-preview | agent | 0.318 | 118.310 | 500.000 | 0.402 | 91.450 | 164.000 | 12.428 | 0.094 | 4.328 | 0.080 | 49.090 | 256.000 |
| openai/gpt-5.4 | agent | 0.336 | 149.510 | 500.000 | 0.378 | 137.630 | 164.000 | 27.182 | 0.102 | 6.216 | 0.004 | 31.980 | 256.000 |
| qwen/qwen3.6-plus | agent | - | - | - | - | - | - | - | - | - | - | - | - |

Additional text-generation settings (one row per model and setting)

| Model | Type | Setting | gen ppl | mauve | entropy | rep2 | avg steps |
|---|---|---|---|---|---|---|---|
| confidence_greedy | baseline | llada-16step | 9999.000 | 0.031 | 4.769 | 0.612 | 16.000 |
| confidence_greedy | baseline | llada-64step | 9999.000 | 0.048 | 9.424 | 0.648 | 64.000 |
| confidence_greedy | baseline | dream-16step | 669.218 | 0.030 | 7.836 | 0.039 | 16.000 |
| confidence_greedy | baseline | dream-8step | 383.115 | 0.023 | 7.797 | 0.095 | 8.000 |
| confidence_greedy | baseline | dream-64step | 108.939 | 0.141 | 5.421 | 0.002 | 64.000 |
| confidence_greedy | baseline | dream-128step | 136.184 | 0.056 | 5.630 | 0.015 | 128.000 |
| confidence_greedy | baseline | llada-256step | 9999.000 | 0.097 | 12.220 | 0.658 | 224.000 |
| klass_kl | baseline | llada-16step | 9999.000 | 0.037 | 3.770 | 0.551 | 15.210 |
| klass_kl | baseline | llada-64step | 9999.000 | 0.024 | 4.368 | 0.621 | 51.270 |
| klass_kl | baseline | dream-16step | 299.423 | 0.029 | 6.371 | 0.053 | 15.880 |
| klass_kl | baseline | dream-8step | 138.932 | 0.021 | 6.483 | 0.098 | 8.000 |
| klass_kl | baseline | dream-64step | 74.680 | 0.047 | 4.416 | 0.015 | 51.260 |
| klass_kl | baseline | dream-128step | 127.860 | 0.060 | 5.298 | 0.014 | 80.810 |
| klass_kl | baseline | llada-256step | 9999.000 | 0.121 | 11.267 | 0.565 | 113.000 |
| prophet | baseline | dream-16step | 671.774 | 0.014 | 7.910 | 0.044 | 11.300 |
| prophet | baseline | dream-8step | 403.231 | 0.023 | 7.811 | 0.089 | 5.610 |
| prophet | baseline | dream-64step | 170.392 | 0.023 | 6.556 | 0.014 | 48.430 |
| prophet | baseline | dream-128step | 182.622 | 0.018 | 6.379 | 0.025 | 96.800 |
| prophet | baseline | llada-256step | 9999.000 | 0.103 | 12.225 | 0.657 | 181.640 |
| random | baseline | llada-16step | 9999.000 | 0.046 | 4.289 | 0.649 | 16.000 |
| random | baseline | llada-64step | 9999.000 | 0.080 | 6.495 | 0.576 | 64.000 |
| random | baseline | dream-16step | 9999.000 | 0.012 | 6.610 | 0.255 | 16.000 |
| random | baseline | dream-8step | 9999.000 | 0.011 | 6.580 | 0.252 | 8.000 |
| random | baseline | dream-64step | 9999.000 | 0.010 | 6.101 | 0.288 | 64.000 |

Dream downstream tasks (dream-humaneval, dream-math)

| Model | Type | accuracy dream-humaneval | avg steps dream-humaneval | n samples dream-humaneval | accuracy dream-math | avg steps dream-math | n samples dream-math |
|---|---|---|---|---|---|---|---|
| confidence_greedy | baseline | 0.000 | 256.000 | 164.000 | - | - | - |
| klass | baseline | - | - | - | 0.004 | 44.860 | 500.000 |
| klass_kl | baseline | 0.000 | 129.000 | 164.000 | - | - | - |
| prophet | baseline | 0.000 | 208.460 | 164.000 | - | - | - |

Agent Conversations