RLHF (Part 3): A Deep Dive into TRL's GRPOTrainer

A note up front: the mainstream LLM post-training frameworks today are trl, OpenRLHF, and verl. The latter two are highly integrated and well suited to training LLMs with essentially no code, whereas trl offers more flexibility. This post walks through the training flow of the GRPO Trainer.

The GRPOTrainer class

It inherits from transformers.Trainer and overrides or extends several methods, including:
__init__
Purpose: initialize the policy model, the reference model (ref_model), the reward models/functions (reward_funcs), and so on, and set a number of hyperparameters (e.g. num_generations, beta).

  • model: the policy model to train; either a string (model ID or path) or a pretrained model object. Only causal language models are supported
  • reward_funcs: the reward function(s); can be a pretrained model (only SequenceClassification models are supported), a user-defined Python function, or a list of these to combine several reward functions
  • args: a GRPOConfig object containing all training parameters
  • train_dataset: the training dataset; it must contain a column named 'prompt' and can be a Dataset or an IterableDataset
  • eval_dataset: the evaluation dataset
  • processing_class: the processing class (tokenizer) used to preprocess the training and evaluation data; typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
  • reward_processing_classes: the tokenizer(s) for the reward function(s); either a single tokenizer or a list of tokenizers
  • callbacks: a list of custom training callbacks that extend or override the default training process; typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None
  • optimizers: the optimizer and learning-rate scheduler to use instead of the defaults
  • peft_config: the PEFT configuration for parameter-efficient (e.g. LoRA) fine-tuning

Additional notes:

  1. The padding side of processing_class must be set to "left". If it is None, the processing class is loaded from the model name with from_pretrained.
  2. reward_processing_classes: optional, defaults to None. When provided, it must match the order and length of the reward functions in reward_funcs.
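
To make the constructor arguments concrete, here is a minimal construction sketch, assuming a small instruct model and a toy prompt-only dataset (the model ID, dataset contents, and the specific GRPOConfig values are illustrative, not from the original post):

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Toy train_dataset: GRPOTrainer only requires a column named "prompt"
train_dataset = Dataset.from_dict(
    {"prompt": ["Write a haiku about the sea.", "Explain GRPO in one sentence."]}
)

def length_reward(completions, **kwargs):
    # Toy custom reward function: longer completions get higher scores
    return [float(len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-demo",  # illustrative output path
    num_generations=4,       # G: completions sampled per prompt
    beta=0.04,               # weight of the KL penalty
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # any causal LM ID or path, or a loaded model object
    reward_funcs=length_reward,        # a model, a function, or a list combining several
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

Passing a list of several reward functions combines them; their scores are summed per completion inside _prepare_inputs, as the code later in this post shows.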

_prepare_inputs
Purpose: in the training loop, for each batch, first sample several completions per prompt, call the reward model(s) to score them, and compute the group-relative advantages.
Brief flow:

  1. For each prompt in the batch, generate num_generations completions in a single call. If an EOS appears early in a completion, the tokens after the EOS are masked out via completion_mask
  2. Use ref_model to compute token-level log-probabilities over the full sequence (prompt + completion), which are used later for the KL term
  3. Call the reward model(s) to score each completion, giving rewards of shape [B*G]
  4. Compute the relative advantage: reshape [B*G] -> [B, G], take the mean and standard deviation over the G completions of the same prompt, then broadcast them back to obtain each completion's relative advantage (a small numeric sketch of this step follows the code below)
    Full code:
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    if is_compiled_module(unwrapped_model):
                        state_dict = unwrapped_model._orig_mod.state_dict()
                    else:
                        state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text) * self.num_generations

            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
            prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]
            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column for `num_generations` times
                        reward_kwargs[key].extend([example[key]] * self.num_generations)
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Log the metrics
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }
    For a detailed line-by-line walkthrough, see: https://blog.csdn.net/shizheng_Li/article/details/145794949
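
To make step 4 concrete, here is a small standalone sketch (made-up reward values, G = 4) of the group-relative advantage computation used at the end of _prepare_inputs:

import torch

num_generations = 4  # G
# Made-up rewards for B = 2 prompts x G = 4 completions, flattened to shape (B*G,)
rewards = torch.tensor([1.0, 0.0, 2.0, 1.0, 5.0, 5.0, 4.0, 6.0])

# Group statistics: reshape to (B, G), reduce over the G completions of each prompt
mean_g = rewards.view(-1, num_generations).mean(dim=1)  # tensor([1., 5.])
std_g = rewards.view(-1, num_generations).std(dim=1)

# Broadcast back to (B*G,) and normalize within each group
mean_g = mean_g.repeat_interleave(num_generations, dim=0)
std_g = std_g.repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_g) / (std_g + 1e-4)

# Each completion is scored relative to its own prompt's group, so the reward 4.0
# in the second group gets a negative advantage even though it is larger than
# every reward in the first group.
print(advantages)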

compute_loss
Purpose: following the GRPO formula, combine the KL penalty term with the relative advantages to compute the final loss, which is then backpropagated.
Brief flow:

  1. Compute the token-level log-probabilities of the full sequence under the current policy
  2. Compute the KL divergence from the per-token log-probabilities of the actor model and the ref model
  3. Combine the KL term with the relative advantages obtained in _prepare_inputs to compute the GRPO loss, then backpropagate and update the parameters (two implementation details of this step are isolated in a short sketch after the code)
    Full code:
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
    For a detailed walkthrough, see: https://blog.csdn.net/shizheng_Li/article/details/145793070
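
Two details in compute_loss are worth isolating. First, per_token_kl is the k3 estimator exp(r) - r - 1 with r = ref_logp - logp; it is non-negative per token and is an unbiased estimate of the KL divergence under samples from the current policy. Second, exp(per_token_logps - per_token_logps.detach()) always evaluates to exactly 1 in the forward pass (the "old" policy here is simply the current policy), but it keeps the gradient with respect to per_token_logps, so the policy-gradient term survives backpropagation. A small standalone sketch with made-up numbers:

import torch

# k3 KL estimator: exp(r) - r - 1 with r = log p_ref - log p_theta.
# It is always >= 0 and unbiased for KL(p_theta || p_ref) under samples from p_theta.
logp = torch.tensor([-1.2, -0.7, -2.0], requires_grad=True)  # current-policy token log-probs (made up)
ref_logp = torch.tensor([-1.0, -0.9, -1.5])                  # reference-policy token log-probs (made up)
r = ref_logp - logp
k3_kl = torch.exp(r) - r - 1
print(k3_kl)  # elementwise, all >= 0

# The "ratio" trick: the value is exactly 1, but the gradient w.r.t. logp is kept.
advantage = torch.tensor([0.5, -0.3, 1.0])
ratio = torch.exp(logp - logp.detach())
print(ratio)  # tensor([1., 1., 1.], grad_fn=...)
loss = -(ratio * advantage - 0.04 * k3_kl).mean()
loss.backward()
print(logp.grad)  # non-zero: gradients flow through both the ratio and the KL term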

prediction_step
Purpose: how the training / evaluation loop calls _prepare_inputs and obtains the loss.
During evaluation or prediction, _prepare_inputs still has to run to generate multiple completions and compute the loss, only without backpropagation. This method is what produces the loss at eval/prediction time, which is then used for logging, early stopping, and similar purposes.
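
Roughly, the method can be summarized by the following simplified sketch (a paraphrase, not the verbatim TRL source): it reuses _prepare_inputs and compute_loss and returns only the loss.

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    # Same sampling / reward / advantage pipeline as in training
    inputs = self._prepare_inputs(inputs)
    with torch.no_grad():
        loss = self.compute_loss(model, inputs)
        loss = loss.mean().detach()
    # No per-sample logits or labels are returned, only the scalar loss
    return loss, None, None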

log and create_model_card
Purpose: logging and the model card, which can be uploaded to the Hugging Face Hub for model management.
create_model_card input parameters:

  • model_name: the name of the model
  • dataset_name: the name of the dataset used for training
  • tags: tags to associate with the model card
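
A minimal usage sketch (the model and dataset names below are illustrative; push_to_hub is the standard Trainer method for uploading to the Hub):

trainer.create_model_card(
    model_name="Qwen2-0.5B-GRPO-demo",  # illustrative model name
    dataset_name="my-math-prompts",     # illustrative dataset name
    tags=["grpo", "rlhf", "trl"],
)
trainer.push_to_hub()  # optionally upload the model and card to the Hugging Face Hub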

Custom reward functions

GRPOTrainer supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

  1. Input arguments
    The function must accept the following as keyword arguments:
    • prompts (containing the prompts),
    • completions (containing the generated completions),
    • every column name the dataset may contain (except prompt). For example, if the dataset contains a column named ground_truth, the function is called with ground_truth as a keyword argument (a sketch of such a function is given after Example 2 below). The simplest way to satisfy this requirement is to use **kwargs in the function signature.

Depending on the dataset format, the inputs differ:

- For the standard format, prompts and completions are lists of strings.
- For the conversational format, prompts and completions are lists of message dictionaries.
  2. Return value
    The function must return a list of floats. Each float is the reward for the corresponding completion.

Example 1: reward longer completions

def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

Example 2: reward completions with a specific format

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks whether the completion follows a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
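
As a further illustration of requirement 1 (extra dataset columns are forwarded to the reward function as keyword arguments), here is a sketch of an accuracy-style reward that compares each completion against a hypothetical ground_truth column; the column name and the standard-format assumption are illustrative:

def accuracy_reward_func(completions, ground_truth, **kwargs):
    """Reward 1.0 when the completion exactly matches the ground_truth column, else 0.0.

    Assumes the standard (non-conversational) format, where completions are plain strings,
    and that the dataset has a column named "ground_truth".
    """
    return [1.0 if completion.strip() == truth.strip() else 0.0
            for completion, truth in zip(completions, ground_truth)]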

Closing note: for full details, see https://huggingface.co/docs/trl/main/en/grpo_trainer