diff --git a/tests/test_repad.py b/tests/test_repad.py
new file mode 100644
index 0000000000..652cd293a7
--- /dev/null
+++ b/tests/test_repad.py
@@ -0,0 +1,131 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import torch
+
+from trl.trainer.grpo_replay_buffer import repad
+
+
+PAD_TOKEN_ID = 123
+
+
+def test_repad_basic_padding():
+    sample = [
+        {
+            "prompt_ids": torch.LongTensor([1, 2, 3]),
+            "prompt_mask": torch.LongTensor([1, 1, 0]),
+            "completion_ids": torch.LongTensor([5, 6, 7, 8]),
+            "completion_mask": torch.LongTensor([1, 1, 1, 0]),
+            "old_per_token_logps": torch.tensor([0.1, 0.2, 0.3, 0.4]),
+            "ref_per_token_logps": torch.tensor([0.0, -0.1, -0.2, -0.3]),
+        },
+        {
+            "prompt_ids": torch.LongTensor([4, 5]),
+            "prompt_mask": torch.LongTensor([1, 1]),
+            "completion_ids": torch.LongTensor([9, 10]),
+            "completion_mask": torch.LongTensor([1, 1]),
+            "old_per_token_logps": torch.tensor([-0.5, -0.6]),
+            "ref_per_token_logps": torch.tensor([0.5, 0.6]),
+        },
+    ]
+
+    padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)
+
+    assert len(padded[0]["prompt_ids"]) == 2
+    assert len(padded[0]["completion_ids"]) == 3
+
+    for ex in padded:
+        # All sequences in the same batch should have the same length
+        assert len(ex["prompt_ids"]) == len(padded[0]["prompt_ids"])
+        assert len(ex["prompt_mask"]) == len(padded[0]["prompt_mask"])
+        assert len(ex["completion_ids"]) == len(padded[0]["completion_ids"])
+        assert len(ex["completion_mask"]) == len(padded[0]["completion_mask"])
+
+        # Mask and ids should match in shape
+        assert ex["prompt_ids"].shape == ex["prompt_mask"].shape
+        assert ex["completion_ids"].shape == ex["completion_mask"].shape
+
+
+def test_repad_logps_padding():
+    sample = [
+        {
+            "prompt_ids": torch.LongTensor([1]),
+            "prompt_mask": torch.LongTensor([1]),
+            "completion_ids": torch.LongTensor([2, 3, 4]),
+            "completion_mask": torch.LongTensor([1, 1, 0]),
+            "old_per_token_logps": torch.tensor([-0.1, -0.2, -0.3]),
+            "ref_per_token_logps": torch.tensor([-0.5, -0.6, -0.7]),
+        },
+        {
+            "prompt_ids": torch.LongTensor([5, 6]),
+            "prompt_mask": torch.LongTensor([1, 1]),
+            "completion_ids": torch.LongTensor([7, 8]),
+            "completion_mask": torch.LongTensor([1, 1]),
+            "old_per_token_logps": torch.tensor([0.4, 0.5]),
+            "ref_per_token_logps": torch.tensor([0.6, 0.7]),
+        },
+    ]
+
+    padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)
+
+    for logps in ["old_per_token_logps", "ref_per_token_logps"]:
+        for ex in padded:
+            assert len(ex[logps]) == len(padded[0][logps])
+            assert isinstance(ex[logps], torch.Tensor)
+
+
+def test_repad_empty_masks():
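+    """Examples whose prompt/completion masks are all zeros should come back as pure padding from repad."""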
+    sample = [
+        {
+            "prompt_ids": torch.tensor([0]),
+            "prompt_mask": torch.tensor([0]),
+            "completion_ids": torch.tensor([0]),
+            "completion_mask": torch.tensor([0]),
+            "old_per_token_logps": torch.tensor([0.0]),
+            "ref_per_token_logps": torch.tensor([0.0]),
+        },
+        {
+            "prompt_ids": torch.tensor([1]),
+            "prompt_mask": torch.tensor([0]),
+            "completion_ids": torch.tensor([1]),
+            "completion_mask": torch.tensor([0]),
+            "old_per_token_logps": torch.tensor([0.0]),
+            "ref_per_token_logps": torch.tensor([0.0]),
+        },
+        {
+            "prompt_ids": torch.tensor([1, 1]),
+            "prompt_mask": torch.tensor([0, 1]),
+            "completion_ids": torch.tensor([1, 2]),
+            "completion_mask": torch.tensor([1, 0]),
+            "old_per_token_logps": torch.tensor([0.0, 1.0]),
+            "ref_per_token_logps": torch.tensor([0.0, 1.0]),
+        },
+        {
+            "prompt_ids": torch.tensor([1, 1]),
+            "prompt_mask": torch.tensor([1, 1]),
+            "completion_ids": torch.tensor([1, 2]),
+            "completion_mask": torch.tensor([1, 0]),
+            "old_per_token_logps": torch.tensor([0.0, 1.0]),
+            "ref_per_token_logps": torch.tensor([0.0, 1.0]),
+        },
+    ]
+    padded = repad(deepcopy(sample), padding_value=999)
+
+    assert len(padded[0]["prompt_ids"]) == 2
+    assert len(padded[0]["completion_ids"]) == 1
+
+    assert padded[0]["prompt_ids"].eq(999).all()
+    assert padded[0]["completion_ids"].eq(999).all()
diff --git a/train_grpo.py b/train_grpo.py
new file mode 100644
index 0000000000..e28a5fe859
--- /dev/null
+++ b/train_grpo.py
@@ -0,0 +1,35 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datasets import load_dataset
+
+from trl import GRPOConfig, GRPOTrainer
+
+
+dataset = load_dataset("trl-lib/tldr", split="train")
+
+
+# Define the reward function, which rewards completions that are close to 20 characters
+def reward_len(completions, **kwargs):
+    return [-abs(20 - len(completion)) for completion in completions]
+
+
+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=1, replay_buffer_class="SSRReplayBuffer")
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2-0.5B-Instruct",
+    reward_funcs=reward_len,
+    args=training_args,
+    train_dataset=dataset,
+)
+trainer.train()
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index da2bad86b8..621d3f5c2f 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -153,6 +153,8 @@ class GRPOConfig(TrainingArguments):
         use_liger_loss (`bool`, *optional*, defaults to `False`):
             Whether to use the Liger GRPO loss.
 
+        replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
+            Replay buffer class to use. Options: `"ReplayBuffer"` and `"SSRReplayBuffer"`.
     > Parameters that control the logging
 
         log_completions (`bool`, *optional*, defaults to `False`):
@@ -393,6 +395,26 @@ class GRPOConfig(TrainingArguments):
         metadata={"help": "Whether to use the Liger GRPO loss."},
     )
 
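+    # Parameters that control the replay buffer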
+    replay_buffer_class: str = field(
+        default="ReplayBuffer",
+        metadata={
+            "help": "Replay buffer class to use. Options: `ReplayBuffer` (the default, which samples randomly "
+            "without replacement) and `SSRReplayBuffer` (which samples in proportion to the absolute advantage)."
+        },
+    )
+    ssr_capacity_scalar: int = field(
+        default=4,
+        metadata={
+            "help": "Scalar used to set the capacity of the `SSRReplayBuffer`: the capacity is this value times the "
+            "number of training samples in the effective batch. The default is 4."
+        },
+    )
+    ssr_alpha: float = field(
+        default=1.0,
+        metadata={
+            "help": "Alpha parameter controlling the sampling distribution of the `SSRReplayBuffer`: priorities are "
+            "absolute advantages raised to this power. The default is 1.0."
+        },
+    )
+
     # Parameters that control the logging
     log_completions: bool = field(
         default=False,
diff --git a/trl/trainer/grpo_replay_buffer.py b/trl/trainer/grpo_replay_buffer.py
new file mode 100644
index 0000000000..86956f2f18
--- /dev/null
+++ b/trl/trainer/grpo_replay_buffer.py
@@ -0,0 +1,154 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import numpy as np
+
+from .utils import pad
+
+
+def repad(list_of_tensor_dicts, padding_value):
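+    """
+    Strip the existing padding from a list of sampled examples and re-pad them to a common length within the
+    sampled batch: prompts are left-padded, completions are right-padded, and the per-token log-probs are padded
+    with a dummy value that is ignored downstream.
+    """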
+ """ + if attn_mask is not None: + input_ids = input_ids[attn_mask == 1] + attn_mask = attn_mask[attn_mask == 1] + return input_ids, attn_mask + + +class ReplayBuffer: + def __init__(self, capacity): + self.capacity = capacity + self.buffer = [] + self.sample_indices = [] + + def add(self, experience): + if len(self.buffer) < self.capacity: + self.buffer.append(experience) + else: + self.buffer.pop(0) + self.buffer.append(experience) + + # Clear index queue when buffer changes + self.sample_indices.clear() + + def _init_sampling_queue(self): + self.sample_indices = list(range(len(self.buffer))) + random.shuffle(self.sample_indices) + + def sample(self, batch_size): + if not self.sample_indices: + self._init_sampling_queue() + + batch = [] + while len(batch) < batch_size and self.sample_indices: + idx = self.sample_indices.pop(0) + batch.append(self.buffer[idx]) + + if len(batch) != batch_size: + raise ValueError("Not enough samples in the buffer to fill the batch.") + + return batch + + def __len__(self): + return len(self.buffer) + + +class SSRReplayBuffer(ReplayBuffer): + # implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837 + def __init__(self, capacity, alpha=1.0): + super().__init__(capacity) + self.alpha = alpha + self.advantages = [] + + def add(self, experience): + EPS = 0.0001 # ensures we get non-zero advs when the buffer contains all 0 advantages + advantage = experience["advantages"].item() + if len(self.buffer) < self.capacity: + self.buffer.append(experience) + self.advantages.append(abs(advantage) + EPS) # Store absolute advantage + else: + # Replace the oldest entry if the buffer is full + self.buffer.pop(0) + self.advantages.pop(0) + self.buffer.append(experience) + self.advantages.append(abs(advantage)) + + def sample(self, batch_size): + if not self.buffer: + raise ValueError("Buffer is empty. Cannot sample from an empty buffer.") + + # Convert advantages to priorities + scaled_priorities = np.power(self.advantages, self.alpha) + total_priority = np.sum(scaled_priorities) + probabilities = scaled_priorities / total_priority + + indices = np.random.choice(len(self.buffer), batch_size, p=probabilities) + return [self.buffer[i] for i in indices] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b6c9770213..1d8cb44964 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -51,6 +51,7 @@ from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .callbacks import SyncRefModelCallback from .grpo_config import GRPOConfig +from .grpo_replay_buffer import ReplayBuffer, SSRReplayBuffer, repad from .utils import ( disable_dropout_in_model, generate_model_card, @@ -231,6 +232,31 @@ def split_tensor_dict( ] +def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]: + """ + Combines a list of dictionaries containing tensors into a single dictionary by + concatenating the tensors along the first dimension. 
+    # implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837
+    def __init__(self, capacity, alpha=1.0):
+        super().__init__(capacity)
+        self.alpha = alpha
+        self.advantages = []
+
+    def add(self, experience):
+        EPS = 0.0001  # ensures we get non-zero advs when the buffer contains all 0 advantages
+        advantage = experience["advantages"].item()
+        if len(self.buffer) < self.capacity:
+            self.buffer.append(experience)
+            self.advantages.append(abs(advantage) + EPS)  # Store absolute advantage
+        else:
+            # Replace the oldest entry if the buffer is full
+            self.buffer.pop(0)
+            self.advantages.pop(0)
+            self.buffer.append(experience)
+            self.advantages.append(abs(advantage) + EPS)
+
+    def sample(self, batch_size):
+        if not self.buffer:
+            raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")
+
+        # Convert advantages to priorities
+        scaled_priorities = np.power(self.advantages, self.alpha)
+        total_priority = np.sum(scaled_priorities)
+        probabilities = scaled_priorities / total_priority
+
+        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
+        return [self.buffer[i] for i in indices]
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index b6c9770213..1d8cb44964 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -51,6 +51,7 @@
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
+from .grpo_replay_buffer import ReplayBuffer, SSRReplayBuffer, repad
 from .utils import (
     disable_dropout_in_model,
     generate_model_card,
@@ -231,6 +232,31 @@ def split_tensor_dict(
     ]
 
 
+def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]:
+    """
+    Combines a list of dictionaries containing tensors into a single dictionary by stacking the
+    tensors along a new first dimension.
+
+    Example:
+        >>> d1 = {"x": torch.tensor([0, 1]), "y": torch.tensor([0])}
+        >>> d2 = {"x": torch.tensor([2, 3]), "y": torch.tensor([1])}
+        >>> d3 = {"x": torch.tensor([4, 5]), "y": torch.tensor([2])}
+        >>> combine_tensor_dict([d1, d2, d3])
+        {
+            "x": tensor([[0, 1], [2, 3], [4, 5]]),
+            "y": tensor([[0], [1], [2]])
+        }
+    """
+    combined_dict = {}
+    keys = split_dicts[0].keys()
+
+    for key in keys:
+        tensors = [d[key] for d in split_dicts if d[key] is not None]
+        combined_dict[key] = torch.stack(tensors, dim=0) if tensors else None
+
+    return combined_dict
+
+
 def nanmin(tensor: torch.Tensor) -> torch.Tensor:
     """
     Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
@@ -684,6 +710,22 @@ def data_collator(features):  # No data collation is needed in GRPO
                 else:
                     self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+        # Set up the replay buffer used to store and sample generation batches. `ReplayBuffer` (the default)
+        # keeps the standard GRPO sampling behavior, while `SSRReplayBuffer` prioritizes by absolute advantage.
+        effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+
+        if self.args.replay_buffer_class == "ReplayBuffer":
+            self.replay_buffer = ReplayBuffer(capacity=effective_batch_size)
+        elif self.args.replay_buffer_class == "SSRReplayBuffer":
+            self.replay_buffer = SSRReplayBuffer(
+                capacity=effective_batch_size * self.args.ssr_capacity_scalar,
+                alpha=self.args.ssr_alpha,
+            )
+        else:
+            raise ValueError(
+                f"Invalid `replay_buffer_class` passed to `GRPOConfig`. Expected either 'ReplayBuffer' or 'SSRReplayBuffer', but got {self.args.replay_buffer_class}."
+            )
+
 
     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
         # By default, this method sets `self._signature_columns` to the model's expected inputs.
@@ -893,13 +935,19 @@ def _prepare_inputs(
         mode = "eval" if self.control.should_evaluate else "train"
         if mode == "train":
             generate_every = self.args.gradient_accumulation_steps * self.num_iterations
-            if self._step % generate_every == 0 or self._buffered_inputs is None:
-                # self._buffered_inputs=None can occur when resuming from a checkpoint
+            if self._step % generate_every == 0 or len(self.replay_buffer) == 0:
+                # The replay buffer can be empty when resuming from a checkpoint
                 accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
-                self._buffered_inputs = split_tensor_dict(
-                    accumulated_local_batch, self.args.gradient_accumulation_steps
-                )
-            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+                effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+                split_tensors = split_tensor_dict(accumulated_local_batch, effective_batch_size)
+
+                for tensor in split_tensors:
+                    self.replay_buffer.add(tensor)
+
+            split_inputs = self.replay_buffer.sample(self.args.per_device_train_batch_size)
+            repadded_split_inputs = repad(split_inputs, padding_value=self.processing_class.pad_token_id)
+            inputs = combine_tensor_dict(repadded_split_inputs)
+
             self._step += 1
         else:
             # In evaluation, there is neither gradient accumulation, nor multiple iterations