diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 5b8fcbbd69..b6c9770213 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -444,11 +444,16 @@ def __init__(
         # Reward functions
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]
+        self.reward_func_names = []
         for i, reward_func in enumerate(reward_funcs):
             if isinstance(reward_func, str):
                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
+            if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
+                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+            else:
+                self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs
 
         # Reward weights
@@ -674,7 +679,10 @@ def data_collator(features):  # No data collation is needed in GRPO
 
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
-                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+                if self.is_deepspeed_enabled:
+                    self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+                else:
+                    self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
 
     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
@@ -1022,13 +1030,9 @@ def _generate_and_score_completions(
             completions = completions_text
 
         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-        for i, (reward_func, reward_processing_class) in enumerate(
-            zip(self.reward_funcs, self.reward_processing_classes)
+        for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
         ):
-            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
-            else:
-                reward_func_name = reward_func.__name__
             with profiling_context(self, reward_func_name):
                 if isinstance(
                     reward_func, nn.Module
@@ -1113,17 +1117,8 @@ def _generate_and_score_completions(
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_mask.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_mask.float().max().item())
 
-        # Get the names of the reward functions
-        reward_func_names = []
-        for reward_func in self.reward_funcs:
-            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
-            else:
-                reward_func_name = reward_func.__name__
-            reward_func_names.append(reward_func_name)
-
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
-        for i, reward_func_name in enumerate(reward_func_names):
+        for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
             self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
             std_rewards = nanstd(rewards_per_func[:, i]).item()
@@ -1134,7 +1129,7 @@ def _generate_and_score_completions(
         # Log prompt and completion texts
         self._textual_logs["prompt"].extend(gather_object(prompts_text))
         self._textual_logs["completion"].extend(gather_object(completions_text))
-        for i, name in enumerate(reward_func_names):
+        for i, name in enumerate(self.reward_func_names):
             self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
 
         return {
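
For context, the naming rule this patch hoists into `__init__` can be read as the small sketch below. It is illustrative only, and the `resolve_reward_func_names` helper and `accuracy_reward` function are hypothetical, not part of the patch: `nn.Module` rewards take the last component of `config._name_or_path`, while plain callables fall back to `__name__`.

```python
import torch.nn as nn


def resolve_reward_func_names(reward_funcs):
    """Hypothetical helper mirroring the naming rule the patch caches in __init__."""
    names = []
    for reward_func in reward_funcs:
        # nn.Module is checked instead of PreTrainedModel for compatibility with compiled models.
        if isinstance(reward_func, nn.Module):
            names.append(reward_func.config._name_or_path.split("/")[-1])
        else:
            names.append(reward_func.__name__)
    return names


def accuracy_reward(completions, **kwargs):
    """Toy custom reward function, used only to show the fallback to __name__."""
    return [1.0 for _ in completions]


print(resolve_reward_func_names([accuracy_reward]))  # ['accuracy_reward']
```

With the names computed once and stored on the trainer, the reward loop, the per-function metrics, and the textual logs all index the same `self.reward_func_names` list instead of re-deriving the names on every step.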