Skip to content

Allow for saving the PPOTrainer value model (critic model) #3308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class PPOConfig(OnPolicyConfig):
Discount factor.
lam (`float`, *optional*, defaults to `0.95`):
Lambda value for GAE.
save_value_model (`bool`, *optional*, defaults to `False`):
Whether the value model (also known as the critic model) should be saved when the policy model is saved. If `False`, the folder will contain the files for the policy only. If `True`, the folder will contain sub-folders for the policy and value model. You can import them by specifying the subfolder using a keyword argument: `from_pretrained(repo_id, subfolder=subfolder)`.
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
improving generation speed. However, disabling this option allows training models that exceed the VRAM
Expand Down Expand Up @@ -121,6 +123,12 @@ class PPOConfig(OnPolicyConfig):
default=0.95,
metadata={"help": "Lambda value for GAE."},
)
# When True, `save_model` writes "policy_model/" and "value_model/" sub-folders
# instead of placing the policy files directly in the output directory, so the
# critic can be reloaded later via `from_pretrained(repo_id, subfolder=...)`.
save_value_model: bool = field(
    default=False,
    metadata={
        "help": "Whether the value model (also known as the critic model) should be saved when the policy model is saved. If `False`, the folder will contain the files for the policy only. If `True`, the folder will contain sub-folders for the policy and value model. You can import them by specifying the subfolder using a keyword argument: `from_pretrained(repo_id, subfolder=subfolder)`"
    },
)
ds3_gather_for_generation: bool = field(
default=True,
metadata={
Expand Down
25 changes: 24 additions & 1 deletion trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,43 @@ def null_ref_context(self):
self.model.policy.set_adapter(self.model_adapter_name or "default")

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save the policy model — and, when `args.save_value_model` is set, the value
    model — to `output_dir`.

    When `save_value_model` is `False`, the policy is written directly to
    `output_dir` (same layout as before this option existed). When `True`, the
    policy goes to `output_dir/policy_model` and the critic to
    `output_dir/value_model`.

    Args:
        output_dir: Target directory. Falls back to `self.args.output_dir`.
        _internal_call: Forwarded to `transformers.Trainer.save_model`.

    Raises:
        ValueError: If neither `output_dir` nor `self.args.output_dir` is set.
    """
    # Resolve the directory here (the parent class normally does this) so we
    # can derive the policy/value sub-folders from a concrete path.
    if output_dir is None:
        output_dir = self.args.output_dir
    if output_dir is None:
        raise ValueError("No output directory specified for saving the model")
    # NOTE(review): Trainer.save_model must run on all processes for TPU
    # training. This early return skips processes where `self.model` has
    # already been swapped to the bare policy (no `.policy` attribute), which
    # avoids errors seen with parallel saves — confirm this is safe on TPU.
    if not hasattr(self.model, "policy"):
        return

    policy_output_dir = (
        os.path.join(output_dir, "policy_model") if self.args.save_value_model else output_dir
    )
    self._swap_and_save(self.model.policy, policy_output_dir, _internal_call)

    if self.args.save_value_model:
        self._swap_and_save(
            self.model.value_model, os.path.join(output_dir, "value_model"), _internal_call
        )

def _swap_and_save(self, component, output_dir, _internal_call):
    """Temporarily point `self.model` (and, under DeepSpeed, `self.deepspeed`)
    at `component`, delegate to `Trainer.save_model`, then restore the
    originals — even if saving raises, so the trainer is never left in a
    half-swapped state."""
    backup_model = self.model
    self.model = component
    if self.is_deepspeed_enabled:
        backup_deepspeed = self.deepspeed
        self.deepspeed = self.model
    try:
        super().save_model(output_dir, _internal_call)
    finally:
        self.model = backup_model
        if self.is_deepspeed_enabled:
            self.deepspeed = backup_deepspeed

def train(self):
args = self.args
accelerator = self.accelerator
Expand Down