
[WIP] Representing DTensor in thunder traces #1907


Draft · wants to merge 22 commits into main
Conversation

@kshitij12345 (Collaborator) commented Mar 26, 2025

Fixes #1898

TODO

  1. For backward, call the DTensor equivalent of .contiguous to convert grad_output to the standard placement that is baked into the backward trace - see AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly pytorch/pytorch#118670.
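A minimal sketch of what that coercion could look like. `coerce_tangent` is a hypothetical helper name; the real `DTensor.redistribute` API changes a DTensor's placements by inserting the necessary collectives:

```python
def coerce_tangent(grad_output, expected_placements):
    """Sketch (hypothetical helper): make the runtime cotangent match the
    placements that were baked into the backward trace at trace time.

    If the placements already match, the tensor is returned unchanged;
    otherwise it is redistributed to the expected placements.
    """
    if tuple(grad_output.placements) == tuple(expected_placements):
        return grad_output
    return grad_output.redistribute(placements=list(expected_placements))
```

This mirrors the AOTDispatch approach linked above: guess the tangent's metadata at trace time, then correct at runtime if the guess was wrong.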

Design Doc - https://docs.google.com/document/d/1Gqb_jXrL-sSqs-D8KrZdcQinxuUSlccZBnnvbYJfYl0/edit?usp=sharing

Changes -
This PR adds support for DTensor inputs to the jitted function. Most of the additions required to support DTensor live in thunder/torch/experimental, such as the DTensorProxy, related prims, and tracing utilities for the ATen decompositions.

NOTE: This PR only adds the basic infrastructure needed to run a simple DTensor program (with torch.mul or torch.add). Operator coverage will be extended in subsequent PRs.

Following are the main updates:

  1. Prologue: Adds a new primitive check_dtensor_spec_repr which matches the repr of the DTensorSpec of the DTensor in question (see the example below). The PR also makes sure that besides the DTensorSpec check there is a tensor-metadata check for the DTensor object as well as for the local tensor it points to. NOTE - Another option for checking the DTensorSpec would be to keep the input's DTensorSpec in the TracingContext and have the prologue verify equality against it.

  2. DTensorProxy: Adds a new proxy object to represent a DTensor. This class inherits from TensorProxy, since DTensor is a tensor subclass and implements the same methods that a tensor implements.

  3. Prims and Operations: For the computation trace, adds two prims, get_dtensor_inner_tensor and construct_dtensor, to extract the local tensor from a DTensor and to construct a DTensor, respectively. For now, it only adds symbols for two ATen operations.

  4. Representation in trace -

Example Program

def fn(x, w):
    return x * w

thunder.jit(fn)(x_dtensor, w_dtensor)

Prologue Trace (relevant snippet)

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 2)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  l_x_: "DTensor cuda:0 f32[16, 16]" = args[0]
  l_w_: "DTensor cuda:0 f32[16, 16]" = args[1]
  dtensor_spec0: "<class 'NoneType'>" = l_x_._spec
  thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec0, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
  t1: "cuda:0 f32[8, 16]" = l_x_._local_tensor
  prims.check_tensor_shape_and_metadata(t1, (8, 16), 'cuda:0', torch.float32, True)
  prims.check_tensor_shape_and_metadata(l_x_, (16, 16), 'cuda:0', torch.float32, True)
  dtensor_spec2: "<class 'NoneType'>" = l_w_._spec
  thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec2, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
  t3: "cuda:0 f32[8, 16]" = l_w_._local_tensor
  prims.check_tensor_shape_and_metadata(t3, (8, 16), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(l_w_, (16, 16), 'cuda:0', torch.float32, False)
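Conceptually, the check_dtensor_spec_repr call above only needs to compare the repr string recorded at trace time against the runtime spec. A minimal sketch (not the actual implementation; the real prologue failure path triggers recompilation rather than a plain RuntimeError):

```python
def check_dtensor_spec_repr(spec, expected_repr: str) -> None:
    # Sketch: the prologue bakes the trace-time repr of the DTensorSpec
    # into the trace as a string and compares it at runtime. A mismatch
    # (different mesh, placements, or tensor metadata) means the trace
    # is not valid for this input.
    actual = repr(spec)
    if actual != expected_repr:
        raise RuntimeError(
            f"Expected DTensorSpec {expected_repr} but found {actual}"
        )
```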

Computation Trace: There is a torch-level symbol dtensor_mul which is decomposed into its ATen decomposition. This allows an executor to claim dtensor_mul as a whole, or a fusion executor to fuse the decomposition if it can.

@torch.no_grad()
@no_autocast
def computation(l_x_, l_w_):
  # l_x_: "DTensor cuda:0 f32[16, 16]"
  # l_w_: "DTensor cuda:0 f32[16, 16]"

  # <eval_with_key>.10:5: 	    mul = torch.mul(l_x_, l_w_);  l_x_ = l_w_ = None
  mul = thunder.torch.experimental.dtensor_torch_and_aten_ops.dtensor_mul(l_x_, l_w_)  # mul: "DTensor cuda:0 f32[16, 16]"
    # t4 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_x_)  # t4: "cuda:0 f32[8, 16]"
    # t5 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_w_)  # t5: "cuda:0 f32[8, 16]"
    # t0 = thunder.torch.experimental.dtensor_torch_and_aten_ops.aten_mul(t4, t5)  # t0: "cuda:0 f32[8, 16]"
      # t0 = prims.mul(t4, t5)  # t0: "cuda:0 f32[8, 16]"
    # mul = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t0, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # mul: "DTensor cuda:0 f32[16, 16]"
  return (mul,)
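The two prims in the nested decomposition have simple semantics: unwrap the shard this rank holds, compute on it locally, and re-wrap the result with the spec recorded at trace time. The sketch below uses FakeDTensor, a hypothetical stand-in for torch.distributed.tensor.DTensor, so the example runs without a process group:

```python
from dataclasses import dataclass


@dataclass
class FakeDTensor:
    # Hypothetical stand-in for DTensor: a local shard plus a spec
    # describing the mesh, placements, and global tensor metadata.
    _local_tensor: list
    _spec: str


def get_dtensor_inner_tensor(dt):
    # Unwrap: return the shard this rank holds.
    return dt._local_tensor


def construct_dtensor(local_tensor, spec):
    # Re-wrap: attach the spec recorded at trace time to the
    # locally computed result.
    return FakeDTensor(local_tensor, spec)


def dtensor_mul(a, b):
    # Mirrors the decomposition in the trace above: unwrap both
    # operands, multiply the local shards, re-wrap with the output spec.
    t4 = get_dtensor_inner_tensor(a)
    t5 = get_dtensor_inner_tensor(b)
    t0 = [x * y for x, y in zip(t4, t5)]
    return construct_dtensor(t0, a._spec)
```

Because Shard(dim=0) * Shard(dim=0) requires no communication, the local multiply is the whole computation here; operations whose output placement differs from the inputs would need a redistribute step as well.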

Backward Trace

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  t2, = cotangents
  # t2: "DTensor cuda:0 f32[16, 16]"
  t20, t21, = C0
  # t20: "cuda:0 f32[8, 16]"
  # t21: "cuda:0 f32[8, 16]"
  # C1 (empty sequence)
  t11 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(t2)  # t11: "cuda:0 f32[8, 16]"
  t13 = ltorch.mul(t21, t11)  # t13: "cuda:0 f32[8, 16]"
    # t13 = prims.mul(t21, t11)  # t13: "cuda:0 f32[8, 16]"
  t14 = ltorch.mul(t20, t11)  # t14: "cuda:0 f32[8, 16]"
    # t14 = prims.mul(t20, t11)  # t14: "cuda:0 f32[8, 16]"
  t16 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t14, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t16: "DTensor cuda:0 f32[16, 16]"
  t18 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t13, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t18: "DTensor cuda:0 f32[16, 16]"
  return (t18, None)

Thank you Masaki, Ivan and Mike for the helpful discussions and guidance!

Comment on lines +10 to +11
# Inherit from TensorProxy as DTensor also supports
# Tensor methods like __add__, __div__, sin, etc.

As I don't remember the behavior, would DTensorProxy.__add__ and others return an instance of DTensorProxy or TensorProxy?

kshitij12345 (Collaborator, Author) replied:

DTensorProxy.__add__ will return an instance of DTensorProxy, since after method resolution it finally dispatches to the DTensor symbol.

For a method that hasn't been implemented for DTensor yet, it errors out with:

Expected all inputs to be TensorProxy but found {list(map(lambda t: type(t), filter_tensor_proxies))}
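The dispatch behavior described above can be illustrated with a minimal sketch (all names hypothetical; thunder's actual symbol machinery is more involved):

```python
class TensorProxy:
    def __add__(self, other):
        # Method resolution routes through a symbol; the symbol's meta
        # function decides the output proxy type.
        return add_symbol(self, other)


class DTensorProxy(TensorProxy):
    # Inherits tensor methods from TensorProxy; the symbol sees the
    # DTensor-ness of the inputs and produces a DTensorProxy output.
    pass


def add_symbol(a, b):
    # Sketch: if any input is a DTensorProxy, the DTensor symbol is
    # selected and its output proxy is a DTensorProxy.
    if isinstance(a, DTensorProxy) or isinstance(b, DTensorProxy):
        return DTensorProxy()
    return TensorProxy()
```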

@IvanYashchuk IvanYashchuk added the DTensor Issues about DTensor support in Thunder label Apr 2, 2025
Successfully merging this pull request may close these issues.

Accept DTensor input without errors
3 participants