Add PEFT benchmarking script in thunder/benchmarks #1978

Open · wprazuch wants to merge 3 commits into main

Conversation

@wprazuch (Contributor) commented Apr 22, 2025

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements) -> Discussed with @IvanYashchuk
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests? -> To verify what kind of tests to add

What does this PR do?

It introduces a benchmarking script for the PEFT fine-tuning scenario, which supports:

  • compiler setup (inductor/thunder/eager)
  • single/multi-GPU setup (fsdp2)

To execute:

python thunder/benchmarks/benchmark_peft.py \
    --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
    --devices 1 \
    --trust-remote-code \
    --attn-implementation sdpa \
    --max-steps 10 \
    --mbs 1 \
    --seq-length 4096 \
    --jit-backend thunder

For multi-GPU:

torchrun --nproc_per_node=8 --master_port=12345 thunder/benchmarks/benchmark_peft.py \
    --model meta-llama/CodeLlama-34b-Instruct-hf \
    --strategy fsdp2 \
    --devices 8 \
    --mbs 1 \
    --seq-length 1024 \
    --max-steps 10 \
    --jit-backend eager \
    --attn-implementation sdpa \
    --trust-remote-code

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

# Configure model for static shapes before FSDP2
if hasattr(model, "config"):
model.config.use_cache = True
model.config.max_position_embeddings = args.seq_length
Collaborator:

I think we can move setting the model to static shapes into a function, as I see this in multiple different places.
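
A minimal sketch of such a helper (the function name is hypothetical, not from the PR):

def configure_static_shapes(model: torch.nn.Module, seq_length: int) -> torch.nn.Module:
    # Hypothetical helper collecting the repeated static-shape configuration.
    if hasattr(model, "config"):
        model.config.use_cache = True
        model.config.max_position_embeddings = seq_length
    return model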

return model


def setup_fsdp2(model: torch.nn.Module, devices: int, verbose: bool = False) -> torch.nn.Module:
Collaborator:

cc: @crcrpar for review.

Collaborator:

ack

dynamo_config.cache_size_limit = 64
# Disable gradient checkpointing for Thunder
if hasattr(model, "gradient_checkpointing_enable"):
model.gradient_checkpointing_disable()
Collaborator:

What happens if this is not called?

"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="meta-llama/Llama-3.2-1B")
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "ddp", "fsdp2"])
Collaborator:

I think ddp is not supported with this script.
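
If that is the case, one option (a sketch, not part of the review itself) is to drop ddp from the accepted choices so the script fails early with a clear argparse error:

parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "fsdp2"])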

)
logger.info(f"Base model loaded on meta device")

# Configure model for static shapes
Collaborator:

What happens if this is not done?

logger.info(f"Configured model for static shapes with sequence length: {args.seq_length}")

# Materialize the model on CUDA
model = model.to_empty(device=f"cuda:{LOCAL_RANK}")
Collaborator:

In the case of FSDP, I think materialization should happen after the setup_fsdp2 step. Otherwise, we will get an OOM for a model that would have worked with FSDP.

Contributor Author:

Yes, I think that is correct. However, when materializing it after setup_fsdp2, I see about a 30% slowdown in throughput compared to the current ordering. Is that "normal"?
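
For reference, a minimal sketch of the shard-then-materialize ordering discussed here (build_model is a hypothetical constructor; LOCAL_RANK is taken from the script):

import torch
from torch.distributed._composable.fsdp import fully_shard

# Construct on the meta device so no full-size weights are ever allocated.
with torch.device("meta"):
    model = build_model()  # hypothetical

# Shard first, so each rank only ever owns its parameter shards ...
for module in model.modules():
    if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
        fully_shard(module)
fully_shard(model)

# ... then materialize only the local shards on this rank's GPU.
model.to_empty(device=f"cuda:{LOCAL_RANK}")
# Parameters still need to be (re)initialized or loaded afterwards.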

if "lora" in name.lower():
if not param.requires_grad:
if args.verbose:
logger.warning(f"LoRA parameter {name} does not require grad!")
Collaborator:

Should these be asserts instead? I think if requires_grad was set up wrong, we shouldn't proceed with collecting the numbers.
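
A sketch of the stricter variant, assuming we do want to fail fast rather than only log:

if "lora" in name.lower():
    assert param.requires_grad, f"LoRA parameter {name} does not require grad!"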

Collaborator:

How about using one of the existing requirements files, maybe https://github.com/Lightning-AI/lightning-thunder/blob/main/requirements/devel.txt?

@@ -0,0 +1,724 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Collaborator @crcrpar commented Apr 24, 2025:

I'm not quite sure about the license (also, shouldn't the year be 2025?).

from torch.distributed import DeviceMesh, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
Collaborator:

I'm not sure if we have tqdm included in any of the requirements at the moment. Would transformers or some other package install it as a dependency?

import random
import time
from contextlib import contextmanager
from distutils.version import LooseVersion
Collaborator:

nit-picking

Suggested change
from distutils.version import LooseVersion
from looseversion import LooseVersion

as we do in

from looseversion import LooseVersion

return args


def get_tokenizer(model_name: str, trust_remote_code: bool, fallback_model: str = "gpt2") -> Any:
Collaborator:

Can we refine the type annotation of the return value and then remove from typing import Any?
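
For example, a sketch using the tokenizer base class (an assumption about what get_tokenizer actually returns):

from transformers import PreTrainedTokenizerBase

def get_tokenizer(model_name: str, trust_remote_code: bool, fallback_model: str = "gpt2") -> PreTrainedTokenizerBase:
    ...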

import time
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, List, Optional
Collaborator:

Suggested change
from typing import Any, List, Optional

Any has one use but I guess we can do away with it

Comment on lines +23 to +34
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from datasets import Dataset
from loguru import logger
from peft import LoraConfig, get_peft_model
from torch.distributed import DeviceMesh, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
Collaborator:

Could you clean up these imports? At a glance, I'm not quite convinced by the imports of numpy and transformers.
As for the transformers import, we already import three names from it, so I think it'd be a bit cleaner to avoid import transformers.

if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
    if verbose:
        logger.info(f"Wrapping layer {name} with FSDP2")
    fully_shard(module, mesh=mesh)
Collaborator:

Just for consistency with the code below:

Suggested change
fully_shard(module, mesh=mesh)
fully_shard(module, mesh=mesh, reshard_after_forward=True)

logger.info(f"Set static cache size to sequence length: {args.seq_length}")

executors = thunder.get_default_executors()
xforms: list = [NvtxProfileTransform()]
Collaborator:

Why do we always use this one? I guess there's some overhead with it.
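
One option would be to attach the transform only on request, e.g. behind a hypothetical --nvtx flag (a sketch, not in the current script):

xforms: list = []
if args.nvtx:  # hypothetical CLI flag
    xforms.append(NvtxProfileTransform())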

Comment on lines +244 to +257
if backend == "torchjit":
logger.info("Compiling model with torch.compile")
dist_print("Resetting cache size for torch.compile")
import torch._dynamo.config as dynamo_config

# Fixes recompilation issues with inductor
dynamo_config.cache_size_limit = 64
model = torch.compile(model)
elif backend == "thunder":
import thunder
import thunder.dynamo
import torch._dynamo.config as dynamo_config
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
from thunder.executors.transformer_engineex import transformer_engine_ex
Collaborator:

I might be missing something non-negligible, but could you have all the imports at the beginning of the file?

Contributor Author @wprazuch:

@kshitij12345 @crcrpar Thanks a lot for the review and for the many valid points above, which escaped my attention. I wanted to let you know that I will be OOTO starting tomorrow, and I am not sure if anyone from my team will take care of this PR during my absence. I will implement the fixes for the above points with the highest priority once I come back.
