llm-hybrid-posttraining
Description
LLM Hybrid Post-Training: SFT + RL Loss Combination
Objective
Design and implement a custom policy loss function that combines supervised fine-tuning (SFT) and reinforcement learning (RL/GRPO) objectives for LLM post-training. Your code goes in the `compute_custom_policy_loss()` function in `custom_hybrid_loss.py`. The read-only reference files `core_algos.py` and `losses.py` contain the vanilla policy loss and SFT loss implementations you can study.
Background
LLM post-training typically uses either SFT (learning from demonstrations) or RL (learning from reward signals), but combining them is an active research question. Key design choices include:
- Static weighting: `loss = α * sft_loss + (1 - α) * rl_loss` with fixed coefficients.
- Curriculum scheduling: Start with SFT and gradually shift to RL (or vice versa) based on training progress (`global_step / total_steps`).
- Adaptive mixing: Weight the SFT term based on model competence, e.g., use more SFT when the model struggles (low reward) and more RL when it succeeds.
- Importance-weighted SFT (LUFFY): Apply importance sampling weights `exp(log_prob - sft_log_prob)` to the SFT loss to avoid rigid imitation of off-policy demonstrations.
- Per-question switching (HPT): For each question, choose between pure SFT and pure RL based on the model's pass rate.
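The first three weighting schemes above can be sketched as simple scalar functions. This is an illustrative sketch only; the function names are made up here, and the adaptive variant assumes rewards can be normalized into [0, 1].

```python
# Illustrative sketches of the weighting schemes (names are hypothetical,
# not part of the verl API). Each returns the SFT weight; the RL weight
# is 1 minus that value.

def static_weight(alpha: float) -> float:
    """Fixed mixing: loss = alpha * sft_loss + (1 - alpha) * rl_loss."""
    return alpha

def curriculum_weight(global_step: int, total_steps: int) -> float:
    """Linear curriculum: start fully on SFT, end fully on RL."""
    progress = global_step / max(total_steps, 1)
    return max(0.0, 1.0 - progress)  # SFT weight decays from 1 to 0

def adaptive_weight(mean_reward: float, reward_scale: float = 1.0) -> float:
    """More SFT when reward is low, more RL when the model succeeds.
    Assumes rewards normalize into [0, 1] via reward_scale."""
    r = min(max(mean_reward / reward_scale, 0.0), 1.0)
    return 1.0 - r
```

In practice the curriculum and adaptive variants can be combined, e.g., by taking the maximum of the two weights early in training.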
Evaluation
Your policy loss is used to train Qwen3-1.7B (full parameter, non-thinking mode) using the verl framework. RL training uses the DeepMath-103K dataset for rollout generation with math-rule reward. SFT data comes from MetaMathQA — a large-scale augmented math instruction dataset providing question-answer demonstrations.
The model is evaluated on three math reasoning benchmarks:
- GSM8K: Grade school math (1,319 test problems). Metric: `val-core/openai/gsm8k/acc/mean@1`.
- MATH-500: Curated 500-problem subset of MATH competition problems. Metric: `val-core/HuggingFaceH4/MATH-500/acc/mean@1`.
- AIME 2024: 30 problems from the American Invitational Mathematics Examination. Metric: `val-core/aime2024/acc/mean@1`.
Training runs for 15 epochs with 5 rollout samples per prompt on 4 GPUs. Higher accuracy indicates better math reasoning.
Interface Contract
The training loop calls your function via the verl policy loss registry:
```python
@register_policy_loss("custom")
def compute_custom_policy_loss(
    old_log_prob: torch.Tensor,    # (bs, response_length) - rollout policy log-probs
    log_prob: torch.Tensor,        # (bs, response_length) - current policy log-probs
    advantages: torch.Tensor,      # (bs, response_length) - GRPO advantage estimates
    response_mask: torch.Tensor,   # (bs, response_length) - valid token mask
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
```
Parameters:
- `old_log_prob`: Per-token log-probs under the rollout policy. Shape `(bs, response_length)`.
- `log_prob`: Per-token log-probs under the current (training) policy. Shape `(bs, response_length)`.
- `advantages`: GRPO-style advantage estimates per token. Shape `(bs, response_length)`.
- `response_mask`: Binary mask (1 = valid response token). Shape `(bs, response_length)`.
- `loss_agg_mode`: How to aggregate the per-token loss into a scalar (e.g., `"token-mean"`).
- `config`: `ActorConfig` with fields including:
  - `config.clip_ratio` (float): PPO clipping parameter ε (default 0.2).
  - `config.global_batch_info` (dict): Contains aggregation info AND hybrid-training data:
    - `"sft_loss"` (float): Pre-computed SFT cross-entropy loss for the current batch.
    - `"sft_log_prob"` (Tensor, `(bs, response_length)`): SFT demonstration log-probs under the current policy.
    - `"sft_labels_mask"` (Tensor, `(bs, response_length)`): Mask for valid SFT tokens.
    - `"global_step"` (int): Current training step (1-indexed).
    - `"total_steps"` (int): Total training steps.
    - `"dp_size"`, `"batch_num_tokens"`, `"global_batch_size"`, `"loss_scale_factor"`: Standard aggregation fields.
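A small sketch of unpacking the hybrid-training fields from `config.global_batch_info` (the keys follow the contract above; the helper name is made up, and the progress formula accounts for `global_step` being 1-indexed):

```python
# Hypothetical helper: pull the hybrid-training fields out of
# global_batch_info and normalize training progress into [0, 1].
def read_hybrid_fields(global_batch_info: dict):
    sft_loss = global_batch_info["sft_loss"]
    global_step = global_batch_info["global_step"]   # 1-indexed per the contract
    total_steps = global_batch_info["total_steps"]
    # Map step 1 -> progress 0.0 and the final step -> progress 1.0.
    progress = (global_step - 1) / max(total_steps - 1, 1)
    return sft_loss, progress
```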
Return values:
- `pg_loss`: Scalar loss tensor (combined SFT+RL).
- `pg_metrics`: Dict of metric name → float value for logging.
Available utilities:
- `agg_loss(loss_mat, loss_mask, loss_agg_mode, **config.global_batch_info)`: aggregate a per-token loss matrix to a scalar.
- `verl_F.masked_mean(values, mask)`: mean over masked elements.
- `verl_F.masked_sum(values, mask)`: sum over masked elements.
- `torch`, `numpy` (already imported).
Important notes:
- The RL component should use a PPO-style clipped surrogate objective (`ratio = exp(log_prob - old_log_prob)`, then clip).
- The SFT loss is pre-computed and available via `config.global_batch_info["sft_loss"]`. You can also compute custom SFT losses using `sft_log_prob` and `sft_labels_mask`.
- Use `global_step` and `total_steps` for curriculum scheduling.
- Return meaningful metrics (e.g., `sft_weight`, `rl_weight`, `sft_loss`, `rl_loss`) to aid debugging.
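Putting the notes above together, here is a minimal self-contained sketch of one possible hybrid loss: a clipped PPO surrogate mixed with the pre-computed SFT loss under a linear curriculum. It is not the verl implementation (it uses plain token-mean aggregation instead of `agg_loss`, and the function name and flat argument list are illustrative).

```python
import torch

def hybrid_loss_sketch(old_log_prob, log_prob, advantages, response_mask,
                       sft_loss, global_step, total_steps, clip_ratio=0.2):
    """Sketch of a curriculum-weighted SFT+RL loss. All tensor arguments
    are assumed to have shape (bs, response_length)."""
    # PPO-style clipped surrogate objective.
    ratio = torch.exp(log_prob - old_log_prob)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    pg_per_token = -torch.min(surr1, surr2)
    # Plain token-mean aggregation over valid response tokens.
    rl_loss = (pg_per_token * response_mask).sum() / response_mask.sum().clamp(min=1)
    # Linear curriculum: SFT weight decays from 1 to 0 over training.
    sft_weight = max(0.0, 1.0 - global_step / max(total_steps, 1))
    loss = sft_weight * sft_loss + (1.0 - sft_weight) * rl_loss
    metrics = {"sft_weight": sft_weight, "rl_weight": 1.0 - sft_weight,
               "sft_loss": float(sft_loss), "rl_loss": float(rl_loss)}
    return loss, metrics
```

Inside the actual `compute_custom_policy_loss`, the same structure applies, with `sft_loss`, `global_step`, and `total_steps` read from `config.global_batch_info` and the final aggregation done via `agg_loss`.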
Code
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0
"""Custom hybrid SFT+RL policy loss for verl PPO training."""

from typing import Any, Optional

import numpy as np
import torch

import verl.utils.torch_functional as verl_F
from verl.trainer.ppo.core_algos import agg_loss, register_policy_loss
from verl.workers.config import ActorConfig

# =====================================================================
# EDITABLE: Implement your custom hybrid SFT+RL policy loss below.
```
Additional context files (read-only):
- `verl/verl/trainer/ppo/core_algos.py`
- `verl/verl/workers/utils/losses.py`
Results
No results available yet.