🍡 Fix using reward model and DeepSpeed ZeRO 3 #3326

Merged · 2 commits · Apr 23, 2025

Changes from all commits
31 changes: 13 additions & 18 deletions trl/trainer/grpo_trainer.py
@@ -444,11 +444,16 @@ def __init__(
         # Reward functions
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]
+        self.reward_func_names = []
         for i, reward_func in enumerate(reward_funcs):
             if isinstance(reward_func, str):
                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
+            if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
+                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+            else:
+                self.reward_func_names.append(reward_funcs[i].__name__)
Comment on lines +447 to +456

Member Author: We need to get the reward name before it's wrapped with DeepSpeed.

         self.reward_funcs = reward_funcs

         # Reward weights
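
The rationale in the comment above is that once a reward model has been prepared by DeepSpeed/Accelerate, the object stored in `self.reward_funcs` may no longer expose the original `PreTrainedModel` attributes in the same way, so display names are resolved once, up front. A minimal sketch of that name-resolution logic; `resolve_reward_name` is a hypothetical helper name, not part of TRL:

```python
import torch.nn as nn

def resolve_reward_name(reward_func) -> str:
    """Return a short display name for either a reward model or a reward callable."""
    if isinstance(reward_func, nn.Module):  # covers PreTrainedModel and compiled models
        # e.g. "my-org/my-reward-model" -> "my-reward-model"
        return reward_func.config._name_or_path.split("/")[-1]
    return reward_func.__name__  # a plain Python reward function

# Names are captured before any DeepSpeed/Accelerate preparation, e.g.:
# self.reward_func_names = [resolve_reward_name(f) for f in reward_funcs]
```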
@@ -674,7 +679,10 @@ def data_collator(features):  # No data collation is needed in GRPO

         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
-                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+                if self.is_deepspeed_enabled:
+                    self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+                else:
+                    self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
Comment on lines +682 to +685

Member Author: This fixes this issue.


     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
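
For context on the fix itself: under ZeRO-3, an inference-only model generally needs its own DeepSpeed engine so that its partitioned parameters are gathered correctly at forward time; routing it through `accelerator.prepare_model(..., evaluation_mode=True)` is not sufficient. The sketch below approximates what a `prepare_deepspeed`-style helper does. It is a simplified assumption, not TRL's exact implementation, and `prepare_reward_model_for_eval` is a made-up name:

```python
from copy import deepcopy

import deepspeed

def prepare_reward_model_for_eval(model, accelerator):
    # Start from the same DeepSpeed config Accelerate was launched with.
    ds_config = deepcopy(accelerator.state.deepspeed_plugin.deepspeed_config)
    # The reward model is never trained, so no optimizer/scheduler sections are
    # needed; the batch size only has to be a valid placeholder.
    ds_config.setdefault("train_micro_batch_size_per_gpu", 1)
    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
```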
@@ -1022,13 +1030,9 @@ def _generate_and_score_completions(
         completions = completions_text

         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-        for i, (reward_func, reward_processing_class) in enumerate(
-            zip(self.reward_funcs, self.reward_processing_classes)
+        for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
         ):
-            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
-            else:
-                reward_func_name = reward_func.__name__
             with profiling_context(self, reward_func_name):
                 if isinstance(
                     reward_func, nn.Module
@@ -1113,17 +1117,8 @@ def _generate_and_score_completions(
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_mask.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_mask.float().max().item())

-        # Get the names of the reward functions
-        reward_func_names = []
-        for reward_func in self.reward_funcs:
-            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
-            else:
-                reward_func_name = reward_func.__name__
-            reward_func_names.append(reward_func_name)
-
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
-        for i, reward_func_name in enumerate(reward_func_names):
+        for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
             self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
             std_rewards = nanstd(rewards_per_func[:, i]).item()
@@ -1134,7 +1129,7 @@ def _generate_and_score_completions(
         # Log prompt and completion texts
         self._textual_logs["prompt"].extend(gather_object(prompts_text))
         self._textual_logs["completion"].extend(gather_object(completions_text))
-        for i, name in enumerate(reward_func_names):
+        for i, name in enumerate(self.reward_func_names):
             self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())

         return {
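
Finally, a minimal usage sketch of the code path this PR touches: GRPO training with a model-based reward function while DeepSpeed ZeRO-3 is enabled through Accelerate. The model and reward-model identifiers below are placeholders, and the launch config name is an assumption, not prescribed by this PR:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Any prompt dataset with a "prompt" column works; this one is used in the TRL docs.
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(output_dir="grpo-zero3", per_device_train_batch_size=2)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    # Passing a model id string makes the trainer load a sequence-classification
    # reward model, which is the branch that now goes through prepare_deepspeed
    # when DeepSpeed is enabled.
    reward_funcs="my-org/my-reward-model",  # placeholder reward-model id
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

# Launched with a ZeRO-3 Accelerate config, e.g.:
#   accelerate launch --config_file deepspeed_zero3.yaml train_grpo.py
```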