
Trace Transform for Tensor Wrapper Subclasses #1883


Draft: wants to merge 38 commits into main
Conversation


@crcrpar crcrpar commented Mar 14, 2025

What does this PR do?

This is one of a handful of my pull requests on this topic.
It is based on a recent main commit and could be helpful for some other tensor subclasses as well.

I try to keep the transform agnostic to tensor subclass implementations, while the tests focus on torchao.float8.Float8Tensor of v0.7.0.

There are many caveats:

crcrpar and others added 30 commits March 14, 2025 21:40
to support programs that only call ctor of tensor wrapper subclasses

to support `__torch_dispatch__`.
Since it extends behavior that is implemented at the C++ level,
we'd need to apply the transform to the split forward and backward traces
separately.

- Add `scaled_mm`
- Change how the lookaside of `torch.autograd.Function.apply` applies `dce`, taking the failure of apex fused RMS norm into consideration.

```python
@torch.no_grad()
@no_autocast
def FusedRMSNormAffineMixedDtypesFunction(t_0, t_1, tup11, f12, b13):
  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128:                 weight_ = weight.contiguous()
  # t_0: "cuda:0 f32[4, 5, 3, 2]"
  # t_1: "cuda:0 f32[3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:127:                 input_ = input.contiguous()
  t5 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t5: "cuda:0 f32[4, 5, 3, 2]"
    # t5 = prims.stride_order(t_0, (3, 2, 1, 0))  # t5: "cuda:0 f32[4, 5, 3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128:                 weight_ = weight.contiguous()
  t6 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[3, 2]"
    # t6 = prims.stride_order(t_1, (1, 0))  # t6: "cuda:0 f32[3, 2]"
  (t10, t9) = apex_fused_rms_norm_forward_affine_mixed_dtypes(t5, (3, 2), t6, 1e-05)
  return t10
```
For this trace, `thunder.core.transforms.dce` replaces `t9` with `_`, so the
augmented forward trace would lose access to it. By reusing the augmented
forward trace in the basic forward trace, `dce` no longer does so.

also use `pytorch_executor` in the `transform_for_execution` of
`prologue_trace`, as the prologue could contain the prim of tensor subclass
flattening, whose definition is only available in the PyTorch executor.

for more information, see https://pre-commit.ci

as bsyms of the adhoc executor are tricky to handle

and use it when converting proxies into faketensors

as any of them could have strides whose first element is 1.

crcrpar added 8 commits March 14, 2025 22:06
to avoid a swap-map entry whose key and value have the same name
but different `id`s

as `torchao.float8` linear programs with `bias=True` seem to mandate it

when `updated_bsym != bsym` as the new args and kwargs could involve
tensor proxies that are introduced by `__torch_dispatch__` evaluation.

with the tanh approximation, things look more complicated and I'm seeing some
errors

github-actions bot added the documentation label Mar 14, 2025
```diff
@@ -479,7 +480,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com

     prologue_traces += transform_for_execution(
         prologue_trc,
-        executors_list=(pythonex,),
+        executors_list=(pythonex, get_executor("torch")),
```

To be strict and precise, I want to include torchex only when the prologue takes SubclassTensorProxys, because torchex is only here for `tensor_subclass.__tensor_flatten__()` used inside the prologue.
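A rough sketch of that refinement, assuming the prologue's inputs can be inspected directly; the helper below and its parameters are hypothetical, not existing Thunder APIs.

```python
# Hypothetical helper: widen the executor list only when the prologue actually
# receives a SubclassTensorProxy, since torchex is needed solely for the
# tensor_subclass.__tensor_flatten__() call emitted into the prologue.
def pick_prologue_executors(prologue_inputs, pythonex, torchex, SubclassTensorProxy):
    if any(isinstance(p, SubclassTensorProxy) for p in prologue_inputs):
        return (pythonex, torchex)
    return (pythonex,)
```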

Comment on lines +1923 to +1926
```python
def get_default_prefix(self) -> str:
    if (subclass_type := getattr(self, SubclassTensorProxy.SUBCLASS_TYPE_ATTR, None)) is None:
        return super().get_default_prefix()
    return subclass_type.__name__.lower()
```

I'd say it's quite reasonable to remove this method, but I also think it'd be helpful if we could tell what tensor subclass type a proxy represents.
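For illustration, here is the naming behavior in isolation with a stand-in proxy class; `_FakeProxy`, its attribute handling, and the `"t"` fallback prefix are placeholders, not Thunder's actual `Proxy`.

```python
# Illustrative only: a proxy tagged with a subclass type gets a readable prefix
# instead of the generic one, so trace names reveal the wrapper type.
class _FakeProxy:
    SUBCLASS_TYPE_ATTR = "_subclass_type"

    def __init__(self, subclass_type=None):
        if subclass_type is not None:
            setattr(self, self.SUBCLASS_TYPE_ATTR, subclass_type)

    def get_default_prefix(self) -> str:
        subclass_type = getattr(self, self.SUBCLASS_TYPE_ATTR, None)
        # "t" stands in for whatever super().get_default_prefix() would return.
        return "t" if subclass_type is None else subclass_type.__name__.lower()


class Float8Tensor: ...


print(_FakeProxy().get_default_prefix())              # t
print(_FakeProxy(Float8Tensor).get_default_prefix())  # float8tensor
```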

Comment on lines +1989 to +1999
```python
bsym = prims.tensor_subclass_ctor.bind(
    self._subclass_type,
    self.name,
    self.shape,
    self.device,
    self.dtype,
    self.requires_grad,
    self._tensors,
    self._non_tensors,
    output=self,
)
```

I'm speculating that this prim call could cause the interpreter to throw the AssertionError that I mention in the PR description.

Comment on lines +134 to +136
```python
# note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`,
# seems to cause `new.primal` == `old`, leading to a cycle in swapping.
if (key := variableify(new.primal)) != variableify(old):
```

I need to revisit this comment


Most of the lines are for core ATen IR ops.
The most important change is `torch._scaled_mm`; the second is `memory_format`.

Comment on lines +181 to +185
```diff
 def add_name(self, name: str, *, prefix: str | None = None) -> None:
     from thunder.core.proxies import PREFIXES_ALLOW_NAME_DUPLICATES

     baseutils.check(
-        name not in self.names,
+        name not in self.names or (prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES),
```

If a program calls a custom torch.autograd.Function that builds a tensor subclass instance from torch.Tensors, the lookaside of torch.autograd.Function is invoked. The lookaside creates a trace of that function, and the trace contains a BoundSymbol of the tensor-subclass-ctor prim, which calls SubclassTensorProxy.__init__, which calls Proxy.__init__, which calls TraceCtx.add_name.
The trace's bound symbols are then passed to OpExProcessor and their sym.meta is evaluated, so the trace tries to add the same name for the subclass tensor proxy at least twice. This special casing allows that duplication.
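Below is a minimal sketch of that special case in isolation; the registry class, the exception it raises, and the example contents of `PREFIXES_ALLOW_NAME_DUPLICATES` are assumptions for illustration, not Thunder's `TraceCtx`.

```python
# Illustrative name registry: duplicates are tolerated only for allow-listed
# prefixes, mirroring the relaxed check in the diff above.
PREFIXES_ALLOW_NAME_DUPLICATES = {"scaletensorsubclass"}  # assumed contents


class NameRegistry:
    def __init__(self):
        self.names: set[str] = set()

    def add_name(self, name: str, *, prefix: str | None = None) -> None:
        duplicates_ok = prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES
        if name in self.names and not duplicates_ok:
            raise RuntimeError(f"name {name!r} was already registered")
        self.names.add(name)


reg = NameRegistry()
reg.add_name("scaletensorsubclass0", prefix="scaletensorsubclass")
reg.add_name("scaletensorsubclass0", prefix="scaletensorsubclass")  # tolerated
reg.add_name("t0")
# reg.add_name("t0")  # would raise: duplicates only allowed for listed prefixes
```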


crcrpar commented Apr 14, 2025

I noticed one interesting thing.
In the following snippet, ScaleTensorSubclass explicitly disallows torch.abs on its instances.
The function h checks the type of data, but apparently that check doesn't work well with Thunder's interpreter.
This kind of check exists in torchao.float8, e.g. _cast_weight_to_float8_t and tensor_already_casted_to_fp8.

```
$$$ `data` has the type of <class '__main__.ScaleTensorSubclass'>
$$$ `data` has the type of <class 'thunder.core.proxies.SubclassTensorProxy'>
Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/a.py", line 169, in <module>
    main()
  File "/opt/pytorch/lightning-thunder/a.py", line 161, in main
    actual = jitted(x, data)
             ^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 771, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 811, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 750, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 238, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 716, in get_computation_and_inputs
    cache_entry = apply_transforms_and_build_cache_entry(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 520, in apply_transforms_and_build_cache_entry
    computation_trc, _ = unroll_tensor_subclasses(computation_trc)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/transforms/tensor_wrapper_subclass.py", line 1110, in unroll_tensor_subclasses
    maybe_desugared_bsyms = desugar_tensor_subclass(bsym)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/transforms/tensor_wrapper_subclass.py", line 751, in __call__
    return self._process_bound_symbol_with_fx(updated_bsym, is_subclass_ctor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/transforms/tensor_wrapper_subclass.py", line 803, in _process_bound_symbol_with_fx
    fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/transforms/tensor_wrapper_subclass.py", line 684, in convert_trace_to_fx_graph_and_get_fake_result
    fx: GraphModule = make_fx(f_with_wrap_and_unwrap)(*desugared_args)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2288, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2226, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2197, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 837, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 691, in flatten_fn
    tree_out = root_fn(*tree_args)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1276, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/opt/pytorch/lightning-thunder/thunder/transforms/tensor_wrapper_subclass.py", line 665, in f_with_wrap_and_unwrap
    out = f(*args)
          ^^^^^^^^
  File "thunder.tmp_abs_5", line 6, in tmp_abs
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1324, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/a.py", line 126, in __torch_dispatch__
    raise NotImplementedError(f"Op of {aten_ir_op=} is not supported: {args=}, {kwargs=}")
NotImplementedError: Op of aten_ir_op=<OpOverload(op='aten.abs', overload='default')> is not supported: args=(ScaleTensorSubclass(dtype=torch.float32, device=cpu, x=FakeTensor(..., size=(4, 4)), scale=FakeTensor(..., size=())),), kwargs={}
```

```python
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
    DynamoThunderExecutor,
    TorchExecutor,
    instantiate,
    nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
    from typing import Any


TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
        return ScaleTensorSubclass(x, scale)

    @staticmethod
    def backward(ctx, grad):
        return grad, None


def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
    return EncapsulateXandScale.apply(x, scale)


@torch._dynamo.allow_in_graph
class ToScaleTensorSubclass(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        return ScaleTensorSubclass.from_tensor(x)

    @staticmethod
    def backward(ctx, grad):
        return grad


def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
    return ToScaleTensorSubclass.apply(x)


class ScaleTensorSubclass(torch.Tensor):
    _x: torch.Tensor
    _scale: torch.Tensor
    __slots__ = ["_x", "_scale"]

    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
        assert scale.numel() == 1, f"Invalid `scale`: {scale}"
        dtype = x.dtype
        device = x.device
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            x.size(),
            dtype=dtype,
            device=device,
            # strides=x.stride(),
            # storage_offset=x.storage_offset(),
            # layout=x.layout,
            requires_grad=x.requires_grad,
        )
        self._x = x
        self._scale = scale

        return self

    # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})"

    def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
        return ["_x", "_scale"], {}

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: dict[str, torch.Tensor],
        metadata: dict[str, Any],
        outer_size,
        outer_stride,
    ) -> ScaleTensorSubclass:
        return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"])

    @staticmethod
    def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass:
        scale = x.abs().max()
        return ScaleTensorSubclass(x, scale)

    @classmethod
    def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None):

        def allowed_subclass(typ):
            return (
                issubclass(cls, typ)
                or issubclass(torch._subclasses.FakeTensor, typ)
                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ)
            )

        def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
            if isinstance(t, ScaleTensorSubclass):
                if t.is_floating_point():
                    return t._x * t._scale
                else:
                    return t._x
            return t

        if not all(allowed_subclass(t) for t in types):
            raise NotImplementedError(f"Unsupported types are included: {types}")

        if aten_ir_op in (torch.ops.aten.abs.default, torch.ops.aten.abs):
            raise NotImplementedError(f"Op of {aten_ir_op=} is not supported: {args=}, {kwargs=}")

        scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass))
        unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs))
        out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs)
        return out


def h(x: ScaleTensorSubclass, data: ScaleTensorSubclass | torch.Tensor):
    print(f"$$$ `data` has the type of {type(data)}")
    if not isinstance(data, ScaleTensorSubclass):
        scale = data.abs().amax()
        y = EncapsulateXandScale.apply(data, scale)
    else:
        y = data
    return x + y


def main():
    requires_grad = False
    shape = (4, 4)
    dtype = torch.float32
    device = torch.device("cpu")
    jitted = thunder.jit(h)

    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    data = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    actual = jitted(x, data)


if __name__ == "__main__":
    main()
```
