llm-hybrid-posttraining

Tags: Language Models, verl, rigorous codebase

Description

LLM Hybrid Post-Training: SFT + RL Loss Combination

Objective

Design and implement a custom policy loss function that combines supervised fine-tuning (SFT) and reinforcement learning (RL/GRPO) objectives for LLM post-training. Your code goes in the compute_custom_policy_loss() function in custom_hybrid_loss.py. The read-only reference files core_algos.py and losses.py contain the vanilla policy loss and SFT loss implementations you can study.

Background

LLM post-training typically uses either SFT (learning from demonstrations) or RL (learning from reward signals), but combining them is an active research question. Key design choices include:

  • Static weighting: loss = α * sft_loss + (1-α) * rl_loss with fixed coefficients.
  • Curriculum scheduling: Start with SFT and gradually shift to RL (or vice versa) based on training progress (global_step / total_steps).
  • Adaptive mixing: Weight the SFT term based on model competence — e.g., use more SFT when the model struggles (low reward) and more RL when it succeeds.
  • Importance-weighted SFT (LUFFY): Apply importance sampling weights exp(log_prob - sft_log_prob) to the SFT loss to avoid rigid imitation of off-policy demonstrations.
  • Per-question switching (HPT): For each question, choose between pure SFT and pure RL based on the model's pass rate.
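The first three weighting strategies above can be sketched as small scheduling helpers. This is a minimal illustration in plain Python; the function names, the linear schedule, and the reward clamp range are illustrative choices, not part of the verl API or any of the cited methods:

```python
def curriculum_sft_weight(global_step: int, total_steps: int) -> float:
    # Linear curriculum: all-SFT at step 0, all-RL at the final step.
    progress = min(global_step / max(total_steps, 1), 1.0)
    return 1.0 - progress


def adaptive_sft_weight(mean_reward: float, low: float = 0.2, high: float = 0.8) -> float:
    # More SFT when the model struggles (low reward), more RL when it
    # succeeds. Clamps reward into [low, high], then maps it linearly
    # onto an SFT weight in [1, 0]. Thresholds are arbitrary examples.
    r = min(max(mean_reward, low), high)
    return (high - r) / (high - low)
```

Either helper yields an `alpha` for the static form `loss = alpha * sft_loss + (1 - alpha) * rl_loss`; the two can also be multiplied to combine curriculum and competence signals.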

Evaluation

Your policy loss is used to train Qwen3-1.7B (full parameter, non-thinking mode) using the verl framework. RL training uses the DeepMath-103K dataset for rollout generation with math-rule reward. SFT data comes from MetaMathQA — a large-scale augmented math instruction dataset providing question-answer demonstrations.

The model is evaluated on three math reasoning benchmarks:

  1. GSM8K — Grade school math (1,319 test problems). Metric: val-core/openai/gsm8k/acc/mean@1.
  2. MATH-500 — Curated 500-problem subset of MATH competition problems. Metric: val-core/HuggingFaceH4/MATH-500/acc/mean@1.
  3. AIME 2024 — 30 problems from the American Invitational Mathematics Examination. Metric: val-core/aime2024/acc/mean@1.

Training runs for 15 epochs with 5 rollout samples per prompt on 4 GPUs. Higher accuracy indicates better math reasoning.

Interface Contract

The training loop calls your function via the verl policy loss registry:

@register_policy_loss("custom")
def compute_custom_policy_loss(
    old_log_prob: torch.Tensor,      # (bs, response_length) — rollout policy log-probs
    log_prob: torch.Tensor,           # (bs, response_length) — current policy log-probs
    advantages: torch.Tensor,         # (bs, response_length) — GRPO advantage estimates
    response_mask: torch.Tensor,      # (bs, response_length) — valid token mask
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:

Parameters:

  • old_log_prob: Per-token log-probs under the rollout policy. Shape (bs, response_length).
  • log_prob: Per-token log-probs under the current (training) policy. Shape (bs, response_length).
  • advantages: GRPO-style advantage estimates per token. Shape (bs, response_length).
  • response_mask: Binary mask (1 = valid response token). Shape (bs, response_length).
  • loss_agg_mode: How to aggregate the per-token loss into a scalar (e.g., "token-mean").
  • config: ActorConfig with fields including:
    • config.clip_ratio (float): PPO clipping parameter ε (default 0.2).
    • config.global_batch_info (dict): Contains aggregation info AND hybrid-training data:
      • "sft_loss" (float): Pre-computed SFT cross-entropy loss for the current batch.
      • "sft_log_prob" (Tensor, (bs, response_length)): SFT demonstration log-probs under current policy.
      • "sft_labels_mask" (Tensor, (bs, response_length)): Mask for valid SFT tokens.
      • "global_step" (int): Current training step (1-indexed).
      • "total_steps" (int): Total training steps.
      • "dp_size", "batch_num_tokens", "global_batch_size", "loss_scale_factor": Standard aggregation fields.

Return values:

  • pg_loss: Scalar loss tensor (combined SFT+RL).
  • pg_metrics: Dict of metric name → float value for logging.

Available utilities:

  • agg_loss(loss_mat, loss_mask, loss_agg_mode, **config.global_batch_info) — aggregate per-token loss matrix to scalar.
  • verl_F.masked_mean(values, mask) — mean over masked elements.
  • verl_F.masked_sum(values, mask) — sum over masked elements.
  • torch, numpy (already imported).

Important notes:

  • The RL component should use PPO-style clipped surrogate objective (ratio = exp(log_prob - old_log_prob), then clip).
  • The SFT loss is pre-computed and available via config.global_batch_info["sft_loss"]. You can also compute custom SFT losses using sft_log_prob and sft_labels_mask.
  • Use global_step and total_steps for curriculum scheduling.
  • Return meaningful metrics (e.g., sft_weight, rl_weight, sft_loss, rl_loss) to aid debugging.
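Putting the notes above together, here is a minimal self-contained sketch of a hybrid loss with a linear curriculum. It omits the `@register_policy_loss("custom")` decorator and replaces the verl utilities with a local `masked_mean` stand-in so it runs standalone; in the actual submission you would use `verl_F.masked_mean`, `agg_loss`, and pull `sft_loss` / `global_step` / `total_steps` from `config.global_batch_info`:

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean over elements where mask == 1 (stand-in for verl_F.masked_mean).
    return (values * mask).sum() / mask.sum().clamp(min=1)


def hybrid_policy_loss(
    old_log_prob: torch.Tensor,   # (bs, response_length)
    log_prob: torch.Tensor,       # (bs, response_length)
    advantages: torch.Tensor,     # (bs, response_length)
    response_mask: torch.Tensor,  # (bs, response_length)
    sft_loss: float,              # pre-computed SFT cross-entropy
    global_step: int,
    total_steps: int,
    clip_ratio: float = 0.2,
) -> tuple[torch.Tensor, dict]:
    # PPO-style clipped surrogate objective on the RL side.
    ratio = torch.exp(log_prob - old_log_prob)
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    pg_losses = -torch.min(ratio * advantages, clipped * advantages)
    rl_loss = masked_mean(pg_losses, response_mask)

    # Linear curriculum: SFT-heavy early, RL-heavy late (one of many
    # possible schedules, chosen here purely for illustration).
    progress = min(global_step / max(total_steps, 1), 1.0)
    sft_weight = 1.0 - progress

    loss = sft_weight * sft_loss + (1.0 - sft_weight) * rl_loss
    metrics = {
        "sft_weight": sft_weight,
        "rl_weight": 1.0 - sft_weight,
        "sft_loss": float(sft_loss),
        "rl_loss": rl_loss.item(),
    }
    return loss, metrics
```

The metrics dict mirrors the debugging suggestion above, so weight schedules and per-component losses show up in the training logs.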

Code

custom_hybrid_loss.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0
"""Custom hybrid SFT+RL policy loss for verl PPO training."""

from typing import Any, Optional

import numpy as np
import torch

import verl.utils.torch_functional as verl_F
from verl.trainer.ppo.core_algos import agg_loss, register_policy_loss
from verl.workers.config import ActorConfig

# =====================================================================
# EDITABLE: Implement your custom hybrid SFT+RL policy loss below.

Additional context files (read-only):

  • verl/verl/trainer/ppo/core_algos.py
  • verl/verl/workers/utils/losses.py

Results

No results available yet.