dlm-dkv-policy
Description
Diffusion LM Cache: Refresh Policy
Research Question
Design a better denoising-step refresh policy for diffusion language models. Given a fixed host model (LLaDA-8B-Instruct), fixed step budget, and fixed evaluation harness, can you decide which token states to refresh and how often to refresh them so that output quality stays high while compute and memory stay low?
This task isolates one scientific question: state reuse versus refresh scheduling inside the denoising trajectory. You may not change the scheduler, model architecture, serving stack, or decoding mode.
Evaluation Setup
The harness runs real LLaDA-8B-Instruct inference end-to-end using the
dLLM-cache library. For each workload and regime the harness:
- Loads prompts from checked-in trace files (real public benchmark prompts from MMLU-Pro, GSM8K, MBPP, and NuminaMath-CoT).
- Runs a reference pass with `gen_interval=1` (no caching) to establish ground-truth token sequences.
- Runs the policy pass with dLLM-cache hooks active, controlled by the editable `DLMRefreshPolicy` class.
- Reports token-level exact-match quality against the reference, plus efficiency metrics.
What You Can Modify
You may edit only the `DLMRefreshPolicy` class in `dLLM-cache/custom_dlm_eval.py`
(lines 51–100 in the harness file).
The editable methods are:
- `refresh_mask(step_id, token_stats, budget_state)` → `list[bool]`
- `prompt_refresh_interval(step_id, request_meta)` → `int`
- `gen_refresh_interval(step_id, request_meta)` → `int`
- `transfer_ratio(step_id, request_meta, token_stats)` → `float`
- `fallback_action(step_id, quality_proxy)` → `str`

Allowed `fallback_action` outputs: `hold`, `refresh_all`.
Token stats available in the `token_stats` list (each element is a dict):

| Field | Type | Meaning |
|---|---|---|
| importance | float 0–1 | per-token confidence score from model logits |
| staleness | float 0–1 | relative step progress (0 = early, 1 = late) |
| difficulty | float 0–1 | normalized token-distribution entropy |
| similarity | float 0–1 | cosine similarity to previous-step distribution |
Budget state fields in `budget_state`:

| Field | Type | Meaning |
|---|---|---|
| budget_scale | float | 1.0 = full, 0.70 = medium, 0.48 = tight |
| scarcity_pressure | float | 1 − budget_scale |
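These two dicts are the entire decision input for the editable hooks. As a hedged sketch (not the harness default), a `refresh_mask` can refresh tokens whose distribution has drifted away from the previous step and tighten that criterion as `scarcity_pressure` rises; the threshold constants below are illustrative assumptions:

```python
# Sketch of a refresh decision over the documented token_stats and
# budget_state fields. The real DLMRefreshPolicy class lives in
# dLLM-cache/custom_dlm_eval.py; the 0.90 / 0.10 / 0.8 / 0.5 constants
# here are illustrative assumptions, not the harness defaults.

def refresh_mask(step_id, token_stats, budget_state):
    """Return one bool per token: True = refresh its KV state this step."""
    pressure = budget_state["scarcity_pressure"]  # 0.0 full ... 0.52 tight
    # Under scarcity, demand a larger similarity drop before refreshing.
    sim_threshold = 0.90 - 0.10 * pressure
    mask = []
    for stats in token_stats:
        drifted = stats["similarity"] < sim_threshold
        hard = stats["difficulty"] > 0.8 and stats["importance"] > 0.5
        mask.append(drifted or hard)
    return mask
```

The design intent is that stable, easy tokens keep their cached state, while drifting or high-entropy tokens pay for a refresh.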
What You Cannot Modify
- the dLLM-cache host library
- the visible workload presets and regimes
- metric definitions and the parser
- the reference generation pass
Workloads and Regimes
Visible workload families:
| Workload | Benchmark Source | Character |
|---|---|---|
| general_instruction | MMLU-Pro | multiple-choice instruction following |
| math_reasoning | GSM8K | step-by-step arithmetic reasoning |
| code_generation | MBPP | Python function synthesis |
| reasoning_refresh_scarcity | NuminaMath-CoT | long-horizon math with shifting token importance |
Visible step regimes:
| Regime | budget_scale | Character |
|---|---|---|
| full_steps | 1.00 | unconstrained |
| medium_steps | 0.70 | moderate cache pressure |
| tight_steps | 0.48 | high cache pressure |
Visible test scripts run:
- instruction-medium: `general_instruction` × `medium_steps`
- math-tight: `math_reasoning` × `tight_steps`
- code-tight: `code_generation` × `tight_steps`
- reasoning-scarcity-tight: `reasoning_refresh_scarcity` × `tight_steps`
Metrics
The harness prints a `TEST_METRICS:` line with:
| Metric | Direction | Meaning |
|---|---|---|
| quality_main | ↑ higher | token-level exact match vs. uncached reference (%) |
| reuse_ratio | ↑ higher | fraction of generation steps where KV state is reused |
| refresh_ratio | ↓ lower | 1 − reuse_ratio |
| quality_efficiency_score | ↑ higher | quality_main × reuse_ratio (primary rank) |
| tokens_per_s | ↑ higher | generation throughput |
| peak_memory_mb | ↓ lower | peak GPU memory usage |
| n_prompts | — | number of prompts evaluated |
| eval_mode | — | always real_rollout |
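For scripted analysis, the metrics line can be pulled out of harness stdout. This helper is hypothetical and assumes the payload after the `TEST_METRICS:` prefix is a JSON object; check the harness's actual print format before relying on it:

```python
import json

# Hypothetical helper: extract the metrics dict from harness stdout.
# Assumes the text after "TEST_METRICS:" is a JSON object (an assumption,
# not something the task description guarantees).

def parse_test_metrics(stdout: str) -> dict:
    for line in stdout.splitlines():
        if line.startswith("TEST_METRICS:"):
            return json.loads(line[len("TEST_METRICS:"):])
    raise ValueError("no TEST_METRICS line found")
```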
Composite Ranking Metric
quality_efficiency_score = quality_main × reuse_ratio
This is the primary ranking signal. It rewards policies that simultaneously achieve high denoising quality and high KV-state reuse:
- A policy that refreshes everything every step: `reuse_ratio ≈ 0` → low score.
- A policy that almost never refreshes: `quality_main` degrades → low score.
- The optimal policy: near-reference quality at high reuse → high score.
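The trade-off is plain arithmetic. The quality and reuse numbers below are illustrative, not measured results:

```python
def quality_efficiency_score(quality_main, reuse_ratio):
    """Primary ranking metric: quality (%) times KV reuse fraction."""
    return quality_main * reuse_ratio

# Illustrative numbers for the three archetypes (assumed, not measured):
aggressive = quality_efficiency_score(99.0, 0.05)  # refreshes every step
lazy = quality_efficiency_score(40.0, 0.98)        # almost never refreshes
balanced = quality_efficiency_score(96.0, 0.85)    # near-reference, high reuse
```

Even a small quality sacrifice is worthwhile if it buys a large jump in reuse, which is why the balanced archetype dominates both extremes.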
Baselines
| Baseline | Family | Source | Status |
|---|---|---|---|
| fixed_interval | fixed cadence | task-native control | anchor (weakest) |
| d2cache | confidence + difficulty | D²Cache-inspired | representative |
| dllm_cache_similarity | similarity-guided | dLLM-Cache-inspired | SOTA |
| freecache | stable-state reuse | FreeCache-inspired | representative |
| dkv_cache | importance + staleness threshold | dKV-Cache threshold | representative |
| dkv_cache_greedy | importance + staleness top-fraction | dKV-Cache greedy | representative |
SOTA anchor: dllm_cache_similarity achieves the highest average
quality_efficiency_score across the four visible workloads by using
similarity-guided refresh decisions that balance quality preservation with
high KV-state reuse.
A new policy is considered an improvement when it beats dllm_cache_similarity
on quality_efficiency_score while not regressing quality_main below the
next-best baseline on any workload.
fixed_interval is the simplest baseline (fixed cadence, no token-level
adaptation). It serves as a lower-bound anchor: its refresh cadence is too
rigid to achieve competitive quality under tight budgets, so it scores
lowest on quality_efficiency_score overall.
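A minimal sketch of the fixed-cadence idea, with a hypothetical interval of 4 steps (the baseline's actual cadence lives in the harness and is not reproduced here):

```python
# Illustrative reconstruction of fixed_interval: refresh every k-th step
# regardless of token statistics. FIXED_INTERVAL = 4 is an assumption,
# not the baseline's actual setting.

FIXED_INTERVAL = 4

def gen_refresh_interval(step_id, request_meta):
    # Same cadence for every request, at every step.
    return FIXED_INTERVAL

def refresh_mask(step_id, token_stats, budget_state):
    # All tokens refresh together on interval boundaries; otherwise none do.
    refresh_now = (step_id % FIXED_INTERVAL == 0)
    return [refresh_now] * len(token_stats)
```

The rigidity is visible in the mask: it ignores `token_stats` and `budget_state` entirely, which is exactly why this anchor loses under tight budgets.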
The adaptive baselines (d2cache, freecache, dkv_cache, dkv_cache_greedy)
use token-level statistics to decide when to refresh, achieving higher
quality and reuse. Designing a policy that outperforms the best adaptive
methods across all workloads is the key challenge.
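As one example of the adaptive family, the top-fraction idea behind dkv_cache_greedy can be sketched as follows; the scoring rule and the refreshed fraction are illustrative assumptions, not the baseline's actual parameters:

```python
# Sketch of a dkv_cache_greedy-style policy: score tokens by importance
# plus staleness and refresh only the top fraction. The unweighted sum
# and top_frac=0.25 are illustrative assumptions.

def greedy_refresh_mask(step_id, token_stats, budget_state, top_frac=0.25):
    """Refresh the top_frac highest-scoring tokens (at least one)."""
    scores = [s["importance"] + s["staleness"] for s in token_stats]
    n_refresh = max(1, int(top_frac * len(token_stats)))
    # Indices of the highest-scoring tokens, best first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = set(ranked[:n_refresh])
    return [i in chosen for i in range(len(token_stats))]
```

Capping the refreshed fraction makes reuse high by construction, which matches this baseline's position at the maximum-reuse corner of the Pareto front.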
Notes
- The editable region is lines 51–100 in `dLLM-cache/custom_dlm_eval.py`. `mid_edit.py` creates the harness file; `baseline*.edit.py` files apply the policy class replacement.
- `reasoning_refresh_scarcity` is the key diversity workload: it features a longer reasoning trajectory in which important token groups shift across steps, making it a harder policy design problem.
- `dkv_cache_greedy` is the maximum-reuse / minimum-quality corner of the visible baseline Pareto front; it is not a target to beat on quality.
- Token stats (`importance`, `staleness`, `difficulty`, `similarity`) are computed from real model logits during each denoising step.
- `transfer_ratio` controls what fraction of generation tokens undergo partial key/value transfer even when not fully refreshed.
Code
```python
"""Real LLaDA rollout harness for dlm-dkv-policy."""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path

import torch

# dLLM-cache path resolution
_HERE = Path(__file__).resolve().parent
```
Results
No results yet.