[WIP] Representing DTensor in thunder traces #1907
Conversation
# Inherit from TensorProxy as DTensor also supports
# Tensor methods like __add__, __div__, sin, etc.
As I don't remember the behavior, would `DTensorProxy.__add__` and the other methods return an instance of `DTensorProxy` or `TensorProxy`?
`DTensorProxy.__add__` will return an instance of `DTensorProxy`, as after method resolution it will finally dispatch to the DTensor symbol. For a method which hasn't been implemented yet for DTensor, it will error out with:

`Expected all inputs to be TensorProxy but found {list(map(lambda t: type(t), filter_tensor_proxies))}`
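The type-preserving dispatch described above can be sketched with a tiny, self-contained Python example (toy classes, not thunder's actual proxies; the real `DTensorProxy.__add__` goes through symbol dispatch rather than constructing the result directly):

```python
# Toy sketch (not thunder's implementation): the base proxy builds its
# result via type(self), so method resolution on a subclass instance
# yields a subclass instance, mirroring dispatch to the DTensor symbol.
class TensorProxy:
    def __add__(self, other):
        # In thunder this dispatches to a symbol; the key point is that
        # the subclass type is preserved in the result.
        return type(self)()

class DTensorProxy(TensorProxy):
    pass

a, b = DTensorProxy(), DTensorProxy()
result = a + b
print(type(result).__name__)  # DTensorProxy
```

So `a + b` on two `DTensorProxy` instances yields another `DTensorProxy`, matching the answer above.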
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Fixes #1898
TODO

- A `.contiguous` equivalent on DTensor to convert the grad_output to the standard placement (which is baked into the backward trace) - see "AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly" pytorch/pytorch#118670.

Design Doc - https://docs.google.com/document/d/1Gqb_jXrL-sSqs-D8KrZdcQinxuUSlccZBnnvbYJfYl0/edit?usp=sharing
Changes -

This PR adds support for DTensor inputs to the jitted function. Most of the additions required to support DTensor live in `thunder/torch/experimental`, like the `DTensorProxy`, the related prims, and the tracing utilities for the ATen decomposition.

NOTE: This PR just adds the basic infrastructure to be able to run a simple DTensor program (with `torch.mul` or `torch.add`). Coverage will be extended in subsequent PRs.

The following are the main updates:
- Prologue: Adds a new primitive `check_dtensor_spec_repr` which will match the repr of the `DTensorSpec` of the DTensor in question (see the example below). The PR also makes sure that besides the `DTensorSpec` check there is a tensor metadata check for the `DTensor` object as well as for the local tensor that it points to. NOTE - Another option for checking the `DTensorSpec` would be to keep the input's `DTensorSpec` in the TracingContext and have the prologue verify equality.
- DTensorProxy: Adds a new proxy object to represent the `DTensor`. This class inherits from `TensorProxy`, as `DTensor` is a tensor subclass and implements all the same methods that a tensor implements.
- Prims and Operations: For the computation trace, adds two prims, `get_dtensor_inner_tensor` and `construct_dtensor`, to extract the local tensor from a DTensor and to construct a DTensor, respectively. It also adds symbols for two `aten` operations.

Representation in trace -
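As a schematic of the decomposition these prims enable, here is a hypothetical sketch using plain Python lists and dicts (no torch; the names mirror the PR's prims but this is illustration only, not thunder's trace machinery):

```python
# Hypothetical sketch: a torch-level dtensor mul decomposing into
# unwrap -> local elementwise mul -> rewrap. This whole region is what
# an executor could claim, or a fusion executor could fuse through.
def get_dtensor_inner_tensor(dt):
    return dt["local"]  # extract the local shard

def construct_dtensor(local, spec):
    return {"local": local, "spec": spec}  # rewrap with the same spec

def aten_mul(x, y):  # stands in for the aten elementwise multiply
    return [a * b for a, b in zip(x, y)]

def dtensor_mul(a, b):
    local = aten_mul(get_dtensor_inner_tensor(a),
                     get_dtensor_inner_tensor(b))
    return construct_dtensor(local, a["spec"])

x = {"local": [1.0, 2.0], "spec": "Shard(0)"}
y = {"local": [3.0, 4.0], "spec": "Shard(0)"}
out = dtensor_mul(x, y)
print(out["local"])  # [3.0, 8.0]
```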
Example Program
Prologue Trace (relevant snippet)
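A minimal sketch of what the repr-based prologue check does, using hypothetical stand-in classes (thunder's actual check operates on torch's `DTensorSpec` and `DTensor`):

```python
# Toy model of check_dtensor_spec_repr: the spec repr recorded at trace
# time must match the runtime input's spec repr, else the prologue fails.
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeSpec:  # hypothetical stand-in for torch's DTensorSpec
    mesh: tuple
    placements: tuple

@dataclass
class FakeDTensor:  # hypothetical stand-in for DTensor
    local_tensor: list
    spec: FakeSpec

def check_dtensor_spec_repr(dt, expected_repr):
    if repr(dt.spec) != expected_repr:
        raise RuntimeError(f"unexpected DTensorSpec: {dt.spec!r}")

traced_spec = FakeSpec(mesh=(0, 1), placements=("Shard(0)",))
runtime_input = FakeDTensor([1.0, 2.0], traced_spec)
check_dtensor_spec_repr(runtime_input, repr(traced_spec))  # passes
```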
Computation Trace: There is a `torch`-level symbol `dtensor_mul` which is decomposed into its `aten` decomposition. This allows an executor to claim `dtensor_mul`, or a fusion executor to fuse the decomposition if it can.

Backward Trace
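The backward's overall shape for an elementwise mul can be sketched the same way (toy code on local shards, illustration only; note the open TODO above about converting grad_output back to the placement baked into the backward trace):

```python
# Toy sketch of mul's backward on local shards (hypothetical names):
# grad_a = grad_out * b and grad_b = grad_out * a, which would then be
# rewrapped into DTensors with the expected placements.
def dtensor_mul_backward(grad_out_local, a_local, b_local):
    grad_a = [g * y for g, y in zip(grad_out_local, b_local)]
    grad_b = [g * x for g, x in zip(grad_out_local, a_local)]
    return grad_a, grad_b

grad_a, grad_b = dtensor_mul_backward([1.0, 1.0], [1.0, 2.0], [3.0, 4.0])
print(grad_a, grad_b)  # [3.0, 4.0] [1.0, 2.0]
```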
Thank you Masaki, Ivan and Mike for the helpful discussions and guidance!