Skip to content

⏩ Train on completion only #3329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 23, 2025
Merged

Conversation

qgallouedec
Copy link
Member

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@LeonEricsson
Copy link
Contributor

Suggestion to update the DataCollator example to:

    Examples:
    ```python
    >>> from trl import DataCollatorForLanguageModeling
    >>> collator = DataCollatorForLanguageModeling(pad_token_id=0)
    >>> examples = [
    ...     {"input_ids": [1, 2, 3, 4], "completion_mask": [0, 0, 1, 1]},
    ...     {"input_ids": [5, 6, 7], "completion_mask": [0, 1, 1]}
    ... ]
    >>> collator(examples)
    {'input_ids': tensor([[1, 2, 3, 4],
                          [5, 6, 7, 0]]),
     'attention_mask': tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]]),
     'labels': tensor([[-100, -100,    3,    4],
                       [ -100,    6,    7, -100]])}
    ```

given that completion_only_loss is true by default

Comment on lines -1067 to -1083
def test_train_model_wrong_torch_dtype(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, model_init_kwargs={"torch_dtype": -1}, report_to="none")
with self.assertRaises(ValueError) as context:
SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
self.assertIn(
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
"a `torch.dtype` (e.g., 'float32'), but got -1.",
str(context.exception),
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not related to the core change of this PR.
With the new serialisation logic of TrainingArguments, passing a wrong dtype fails when you instantiate the TrainingArguments. There is no need for such test anymore

Comment on lines -503 to -512
# If the dataset is prompt-completion, convert it to language modeling type
first_example = next(iter(dataset))
if "prompt" in first_example.keys() and "completion" in first_example.keys():
key = "messages" if is_conversational(first_example) else "text"

def concat_prompt_completion(example):
return {key: example["prompt"] + example["completion"]}

dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This concatenation needs to be removed, as we loses the information about where the completion starts. This completion is now managed in tokenize.

@qgallouedec
Copy link
Member Author

qgallouedec commented Apr 22, 2025

Suggestion to update the DataCollator example to:

Thanks! Done in 7f7f2a4

@qgallouedec qgallouedec merged commit 9497527 into fix-add_special_tokens Apr 23, 2025
10 checks passed
@qgallouedec qgallouedec deleted the train-completion-only branch April 23, 2025 00:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants