RLHF(三):基于TRL的GrpoTrainer详解
RLHF(三):基于TRL的GrpoTrainer详解
Peter写在前面:目前主流的LLM post-training框架主要有trl, OpenRLHF, verl。后两者集成度较高,适合对LLM零代码训练,而trl灵活性较强,这里主要对GRPO Trainer的训练流程进行梳理
GRPOTrainer类
它继承了transformers.Trainer,并重写或拓展了若干方法,包括:
init
作用:初始化模型、参考模型(ref_model)、奖励模型(reward_funcs)等,并作一些超参数设置(如 num_generations, beta 等)。
- model: 加载策略模型, 可以是字符串(模型ID或路径)或预训练模型对象。仅支持因果语言模型
- reward_funcs: 加载奖励函数,可以是预训练模型(仅支持
SequenceClassification
模型);用户自定义Python函数;或者是一个列表,意味着多种奖励函数一起用 - args: GRPOConfig对象,包含训练的所有参数
- train_dataset: 训练数据集,必须包含名为’prompt’的列,可以是Dataset或IterableDataset
- eval_dataset: 评估数据集
- processing_class: 数据处理器类,用于对训练和评估数据进行预处理 typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
- reward_processing_class: 奖励函数对应的分词器,支持单个分词器或多个分词器的列表
- callback: 自定义训练回调列表,可扩展或覆盖默认的训练过程 typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None
- optimizer
- peft_config
补充说明:
- processing_class 填充侧必须设置为 “left”。如果为 None,则使用 from_pretrained 从模型名称加载处理类。
- reward_processing_class: 可选,默认为None。在自定义时,必须与 reward_funcs 中奖励函数的顺序和长度匹配。
_prepare_inputs
作用:在训练循环中,每一个 batch 先对 prompt 进行采样生成多条回答,调用奖励模型打分,计算组内相对优势。
简要流程:
- 对batch中每个prompt调用模型一次性生成
num_generations
条回答。若生成中提前出现EOS,对EOS之后的token使用completion_mask进行掩码 - 使用ref_model对完整序列(prompt+completion) 计算token级对数概率,用于后面进行KL
- 调用reward model对每条回答打分,形成[B*G]的reward
- 计算相对优势:[B*G]-reshape->[B,G], 对同一个 prompt 的 G 条回答做“均值、标准差”,再 broadcast 回去,以得到每条回答的相对 advantage
全部代码:具体细节:https://blog.csdn.net/shizheng_Li/article/details/145794949def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] # repeat prompts
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# Log the metrics
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}
compute_loss
作用:根据 GRPO 公式,结合 KL 惩罚项和相对优势,计算最终损失并进行反向传播。
简要流程:
- 在当前策略下计算完整序列的token级对数概率
- 根据actor model和ref model的token_log_prob,计算KL散度
- 利用KL散度和
_prepare_inputs
中得到的相对优势计算GRPO Loss,然后反向传播进行梯度更新
全部代码:具体细节:https://blog.csdn.net/shizheng_Li/article/details/145793070def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
return loss
prediction_step
作用:训练 / 验证阶段如何调用 _prepare_inputs 并获取 loss。
在评估或预测时,也需要执行 _prepare_inputs 来生成多条回答并算 loss,只不过不会再反向传播。这个函数就是在 eval 阶段或预测时,对应地拿到 loss,用于打印日志或 early stopping 等操作。
log 与 create_model_card
作用:日志与模型卡,可上传到 Hugging Face Hub 做模型管理。
create_model_card输入参数:
- model_name:模型的名称
- dataset_name:用于训练的数据集的名称
- tags:要与模型卡关联的标签
自定义奖励函数
GRPOTrainer 支持使用自定义奖励函数来代替密集奖励模型。为了确保兼容性,您的奖励函数必须满足以下要求:
- 输入参数
函数必须接受以下内容作为关键字参数:- prompts(包含提示),
- completions(包含生成的补全),
- 数据集可能包含的所有列名(但prompt除外)。例如,如果数据集包含名为ground_truth的列,则将使用ground_truth作为关键字参数调用该函数。满足此要求的最简单方法是在函数签名中使用**kwargs。
根据数据集格式,输入会有所不同:
- 对于标准格式,prompts和completions将是字符串列表。
- 对于对话格式,prompts和completions将是消息字典列表。
- 返回值
函数必须返回一个浮点数列表。每个浮点数代表对应于单个补全的奖励。
示例一: 奖励较长的补全
def reward_func(completions, **kwargs): |
示例二:奖励具有特定格式的补全
import re |
写在最后: 详见https://huggingface.co/docs/trl/main/en/grpo_trainer