Add PEFT benchmarking script in thunder/benchmarks
#1978
Conversation
# Configure model for static shapes before FSDP2
if hasattr(model, "config"):
    model.config.use_cache = True
    model.config.max_position_embeddings = args.seq_length
I think we can move setting the model to static shapes into a helper function, as I see this in multiple different places.
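For example, a small helper along these lines (the name `configure_static_shapes` is just illustrative; the body mirrors the snippet quoted above):

```python
import torch


def configure_static_shapes(model: torch.nn.Module, seq_length: int) -> torch.nn.Module:
    """Configure a Hugging Face-style model config for static shapes (illustrative helper)."""
    if hasattr(model, "config"):
        model.config.use_cache = True
        model.config.max_position_embeddings = seq_length
    return model
```

Then every call site could just do `model = configure_static_shapes(model, args.seq_length)`.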
    return model


def setup_fsdp2(model: torch.nn.Module, devices: int, verbose: bool = False) -> torch.nn.Module:
cc: @crcrpar for review.
ack
dynamo_config.cache_size_limit = 64
# Disable gradient checkpointing for Thunder
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_disable()
What happens if this is not called?
"""Parse command line arguments.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model", default="meta-llama/Llama-3.2-1B") | ||
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "ddp", "fsdp2"]) |
I think `ddp` is not supported with this script.
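If that's the case, a sketch of failing fast (assuming DDP support is not planned for this script right now; otherwise keep the choice and raise later):

```python
parser.add_argument("--strategy", type=str, default="auto", choices=["auto", "fsdp2"])
# Or, if "ddp" should stay visible in --help, reject it explicitly after parsing:
# if args.strategy == "ddp":
#     raise NotImplementedError("DDP is not supported by this benchmark yet.")
```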
    )
    logger.info(f"Base model loaded on meta device")

    # Configure model for static shapes
What happens if this is not done?
logger.info(f"Configured model for static shapes with sequence length: {args.seq_length}") | ||
|
||
# Materialize the model on CUDA | ||
model = model.to_empty(device=f"cuda:{LOCAL_RANK}") |
In the case of FSDP, I think materialization should happen after the setup_fsdp2 step. Otherwise, we will get an OOM for a model that would have worked with FSDP.
Yes, I think that is correct. However, when materializing after setup_fsdp2, I see about a 30% slowdown in throughput compared to the current ordering. Is that "normal"?
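For reference, a rough sketch of the ordering being discussed (function and variable names follow the snippets in this PR; the exact call sites are assumptions):

```python
# Shard the meta-device model first, so each rank only owns its shard ...
model = setup_fsdp2(model, devices=WORLD_SIZE, verbose=args.verbose)
# ... then allocate real storage for the sharded parameters on the local GPU.
# (With meta-device init, the parameters still need to be (re)initialized after to_empty.)
model = model.to_empty(device=f"cuda:{LOCAL_RANK}")
```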
if "lora" in name.lower(): | ||
if not param.requires_grad: | ||
if args.verbose: | ||
logger.warning(f"LoRA parameter {name} does not require grad!") |
Should these be asserts instead? I think if requires_grad was set up wrong, we shouldn't proceed with collecting the numbers.
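i.e., something like this instead of the verbose-only warning (a sketch based on the snippet above):

```python
if "lora" in name.lower():
    # Fail fast: benchmarking with frozen LoRA weights would produce misleading numbers.
    assert param.requires_grad, f"LoRA parameter {name} does not require grad!"
```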
How about using one of the existing requirements files, maybe https://github.com/Lightning-AI/lightning-thunder/blob/main/requirements/devel.txt?
@@ -0,0 +1,724 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
I'm not quite sure about the license (also, shouldn't the year be 2025?).
from torch.distributed import DeviceMesh, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
I'm not sure we have tqdm included in any of the requirements at the moment. Would transformers or some other dependency install it transitively?
import random
import time
from contextlib import contextmanager
from distutils.version import LooseVersion
Nit-picking, suggested change:
-from distutils.version import LooseVersion
+from looseversion import LooseVersion
as we do in:
from looseversion import LooseVersion
    return args


def get_tokenizer(model_name: str, trust_remote_code: bool, fallback_model: str = "gpt2") -> Any:
Can we refine the type annotation of the return value and then remove `from typing import Any`?
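For example, assuming the function always returns a Hugging Face tokenizer, `PreTrainedTokenizerBase` could work as the return type (the fallback logic below is only a sketch of what I assume the function does):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def get_tokenizer(
    model_name: str, trust_remote_code: bool, fallback_model: str = "gpt2"
) -> PreTrainedTokenizerBase:
    try:
        return AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    except OSError:
        # Fall back to a small, always-available tokenizer.
        return AutoTokenizer.from_pretrained(fallback_model)
```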
import time
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, List, Optional
from typing import Any, List, Optional
Any has only one use, but I guess we can do away with it.
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from datasets import Dataset
from loguru import logger
from peft import LoraConfig, get_peft_model
from torch.distributed import DeviceMesh, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
Could you clean up these imports? At a glance, I'm not quite convinced by the imports of numpy and transformers. As for transformers, we already import three names from it, so I think it'd be a bit cleaner to avoid `import transformers`.
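i.e., roughly this, assuming numpy and the bare `import transformers` really are unused elsewhere in the script:

```python
import torch
import torch.nn.functional as F
from datasets import Dataset
from loguru import logger
from peft import LoraConfig, get_peft_model
from torch.distributed import DeviceMesh, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
```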
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
    if verbose:
        logger.info(f"Wrapping layer {name} with FSDP2")
    fully_shard(module, mesh=mesh)
Just for consistency with the code below, suggested change:
-fully_shard(module, mesh=mesh)
+fully_shard(module, mesh=mesh, reshard_after_forward=True)
logger.info(f"Set static cache size to sequence length: {args.seq_length}") | ||
|
||
executors = thunder.get_default_executors() | ||
xforms: list = [NvtxProfileTransform()] |
Why do we always use this one? I guess there's some overhead associated with it.
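If the overhead matters, a sketch of making it opt-in (the `--nvtx-profile` flag is hypothetical, not part of this PR):

```python
xforms: list = []
if args.nvtx_profile:  # hypothetical flag: only pay the NVTX marker overhead when profiling
    xforms.append(NvtxProfileTransform())
```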
if backend == "torchjit": | ||
logger.info("Compiling model with torch.compile") | ||
dist_print("Resetting cache size for torch.compile") | ||
import torch._dynamo.config as dynamo_config | ||
|
||
# Fixes recompilation issues with inductor | ||
dynamo_config.cache_size_limit = 64 | ||
model = torch.compile(model) | ||
elif backend == "thunder": | ||
import thunder | ||
import thunder.dynamo | ||
import torch._dynamo.config as dynamo_config | ||
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform | ||
from thunder.executors.transformer_engineex import transformer_engine_ex |
I may be missing something non-negligible, but could you move all the imports to the beginning of the file?
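For example, hoisting them next to the other top-level imports (a sketch; guarding the optional dependencies is just one possible choice):

```python
# At module scope, next to the other imports:
import torch._dynamo.config as dynamo_config

try:
    import thunder
    import thunder.dynamo
    from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
    from thunder.executors.transformer_engineex import transformer_engine_ex
except ImportError:  # keep the script runnable for the torch.compile backend only
    thunder = None
```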
@kshitij12345 @crcrpar Thanks a lot for the review and for the many valid points above, which had escaped my attention. I wanted to let you know that I will be OOTO starting tomorrow, and I am not sure whether anyone from my team will take care of this PR during my absence. I will implement the fixes for the above points with the highest priority once I am back.
Before submitting
What does this PR do?
It introduces a benchmarking script for the PEFT fine-tuning scenario, which supports:
To execute:
For multi-GPU:
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃