Trace Transform for Tensor Wrapper Subclasses #1883
base: main
Conversation
Support programs that only call the constructor of tensor wrapper subclasses.
Support `__torch_dispatch__`. Since it extends behavior that is implemented at the C++ level, we need to apply the transform to the split forward and backward traces separately.
- Add `scaled_mm`.
- Change how the lookaside of `torch.autograd.Function.apply` applies dce, taking the failure of apex fused RMS norm into consideration.

```python
@torch.no_grad()
@no_autocast
def FusedRMSNormAffineMixedDtypesFunction(t_0, t_1, tup11, f12, b13):
  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128: weight_ = weight.contiguous()
  # t_0: "cuda:0 f32[4, 5, 3, 2]"
  # t_1: "cuda:0 f32[3, 2]"
  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:127: input_ = input.contiguous()
  t5 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t5: "cuda:0 f32[4, 5, 3, 2]"
    # t5 = prims.stride_order(t_0, (3, 2, 1, 0))  # t5: "cuda:0 f32[4, 5, 3, 2]"
  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128: weight_ = weight.contiguous()
  t6 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[3, 2]"
    # t6 = prims.stride_order(t_1, (1, 0))  # t6: "cuda:0 f32[3, 2]"
  (t10, t9) = apex_fused_rms_norm_forward_affine_mixed_dtypes(t5, (3, 2), t6, 1e-05)
  return t10
```

For this trace, `thunder.core.transforms.dce` replaces `t9` with `_`, and then the augmented forward trace loses access to it. By reusing the augmented forward trace in the basic forward trace, `dce` does not do so.
Also use `pytorch_executor` in `transform_for_execution` of the `prologue_trace`, as it could contain the tensor subclass flattening prim whose definition is only available in the PyTorch executor.
As bsyms of the ad hoc executor are tricky to handle.
Use it when converting proxies into FakeTensors.
As any of them could have strides whose first element is 1.
Avoid a swap-map entry whose key and value have the same name but different `id`s.
As `torchao.float8` linear programs with `bias=True` seem to mandate it.
When `updated_bsym != bsym`, as the new args and kwargs could involve tensor proxies introduced by `__torch_dispatch__` evaluation.
With the tanh approximation, things look more complicated and I'm seeing some errors.
@@ -479,7 +480,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
     prologue_traces += transform_for_execution(
         prologue_trc,
-        executors_list=(pythonex,),
+        executors_list=(pythonex, get_executor("torch")),
To be strict and precise, I want to include torchex only when the prologue takes `SubclassTensorProxy`s, because torchex is only here for `tensor_subclass.__tensor_flatten__()` used inside the prologue.
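A minimal sketch of that idea, assuming `SubclassTensorProxy` is importable from `thunder.core.proxies` and that the prologue trace exposes its flattened arguments; the helper name and traversal below are illustrative, not the PR's actual code:

```python
# Hypothetical sketch: include the torch executor only when the prologue
# actually receives tensor wrapper subclass proxies.
from thunder.core.proxies import SubclassTensorProxy  # assumed import path
from thunder.core.pytree import tree_flatten  # assumed helper

def prologue_executors(prologue_trc, pythonex, torchex):
    # Assumes the trace keeps its (args, kwargs) around; adjust to the real API.
    flat_args, _ = tree_flatten((prologue_trc.args, prologue_trc.kwargs))
    has_subclass_inputs = any(isinstance(a, SubclassTensorProxy) for a in flat_args)
    # torchex is only needed to run tensor_subclass.__tensor_flatten__()
    # inside the prologue, so add it conditionally.
    return (pythonex, torchex) if has_subclass_inputs else (pythonex,)
```

`transform_for_execution(prologue_trc, executors_list=prologue_executors(...))` would then leave the executor list unchanged for programs without subclass inputs.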
def get_default_prefix(self) -> str:
    if (subclass_type := getattr(self, SubclassTensorProxy.SUBCLASS_TYPE_ATTR, None)) is None:
        return super().get_default_prefix()
    return subclass_type.__name__.lower()
I'd say it's quite reasonable to remove this method, but I also think it'd be helpful if we could tell what tensor subclass type a proxy represents.
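If the prefix-based naming goes away, one hypothetical way to still tell which subclass a proxy represents (illustrative only, not the PR's code) is to surface the stored subclass type in the proxy's repr:

```python
# Hypothetical stand-in for a proxy that records its tensor subclass type.
class SubclassAwareProxy:
    SUBCLASS_TYPE_ATTR = "_subclass_type"

    def __init__(self, name: str, subclass_type: type | None = None):
        self.name = name
        if subclass_type is not None:
            setattr(self, self.SUBCLASS_TYPE_ATTR, subclass_type)

    def __repr__(self) -> str:
        subclass_type = getattr(self, self.SUBCLASS_TYPE_ATTR, None)
        tag = f" subclass={subclass_type.__name__}" if subclass_type is not None else ""
        return f"<Proxy {self.name}{tag}>"
```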
bsym = prims.tensor_subclass_ctor.bind(
    self._subclass_type,
    self.name,
    self.shape,
    self.device,
    self.dtype,
    self.requires_grad,
    self._tensors,
    self._non_tensors,
    output=self,
)
I'm speculating that this prim call could cause the interpreter to throw the AssertionError that I mention in the PR description.
# note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`
# seems to cause `new.primal` == `old`, leading to a cycle in swapping.
if (key := variableify(new.primal)) != variableify(old):
I need to revisit this comment
thunder/torch/__init__.py (outdated)
Most of the lines are for core ATen IR ops. The most important change is `torch._scaled_mm`; the second is `memory_format`.
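For context, this is roughly what a `torch._scaled_mm` call looks like. The exact Python signature has shifted across PyTorch releases, so the argument names below (`scale_a`, `scale_b`, `out_dtype`), the column-major requirement, and the hardware guard are assumptions based on recent (2.5-era) PyTorch, not part of this PR:

```python
import torch

def fp8_matmul(a_hp: torch.Tensor, b_hp: torch.Tensor) -> torch.Tensor:
    # Per-tensor dequantization scales; _scaled_mm expects float32 scalar tensors.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale_a = (a_hp.abs().max().float() / fp8_max).clamp(min=1e-12)
    scale_b = (b_hp.abs().max().float() / fp8_max).clamp(min=1e-12)
    a_fp8 = (a_hp / scale_a).to(torch.float8_e4m3fn)
    # The second operand is expected in column-major layout.
    b_fp8 = (b_hp / scale_b).to(torch.float8_e4m3fn).t().contiguous().t()
    return torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=a_hp.dtype)

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
    print(fp8_matmul(a, b).shape)  # torch.Size([16, 64])
```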
 def add_name(self, name: str, *, prefix: str | None = None) -> None:
     from thunder.core.proxies import PREFIXES_ALLOW_NAME_DUPLICATES

     baseutils.check(
-        name not in self.names,
+        name not in self.names or (prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES),
If a program calls a custom `torch.autograd.Function` that makes a tensor subclass instance from `torch.Tensor`s, the lookaside of `torch.autograd.Function` would be called. The lookaside creates a trace of that function, and the trace would have a BoundSymbol of the tensor subclass ctor prim that calls `SubclassTensorProxy.__init__`, which calls `Proxy.__init__`, which calls `TraceCtx.add_name`. The trace's bound symbols would then be passed to `OpExProcessor` and their `sym.meta` would be evaluated, so that trace would try to add the same name for the subclass tensor proxy at least twice. This special-casing allows that duplication.
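As a toy model of that special case (all names hypothetical; the real logic is the `TraceCtx.add_name` change shown above), the same proxy name ends up registered twice, and the prefix allow-list tolerates it:

```python
# Hypothetical stand-in for the duplicate-tolerant name bookkeeping described above.
PREFIXES_ALLOW_NAME_DUPLICATES = {"scaletensorsubclass"}

class Names:
    def __init__(self):
        self.names: set[str] = set()

    def add_name(self, name: str, *, prefix: str | None = None) -> None:
        # The subclass proxy's name is added once when the autograd.Function
        # lookaside traces the ctor prim, and again when the bound symbol's
        # sym.meta is re-evaluated during executor processing.
        allowed_dup = prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES
        assert name not in self.names or allowed_dup, f"duplicate name: {name}"
        self.names.add(name)

names = Names()
names.add_name("scaletensorsubclass0", prefix="scaletensorsubclass")
names.add_name("scaletensorsubclass0", prefix="scaletensorsubclass")  # tolerated, no AssertionError
```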
I noticed one interesting thing.
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
    DynamoThunderExecutor,
    TorchExecutor,
    instantiate,
    nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
    from typing import Any

TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
        return ScaleTensorSubclass(x, scale)

    @staticmethod
    def backward(ctx, grad):
        return grad, None


def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
    return EncapsulateXandScale.apply(x, scale)


@torch._dynamo.allow_in_graph
class ToScaleTensorSubclass(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        return ScaleTensorSubclass.from_tensor(x)

    @staticmethod
    def backward(ctx, grad):
        return grad


def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
    return ToScaleTensorSubclass.apply(x)


class ScaleTensorSubclass(torch.Tensor):
    _x: torch.Tensor
    _scale: torch.Tensor
    __slots__ = ["_x", "_scale"]

    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
        assert scale.numel() == 1, f"Invalid `scale`: {scale}"
        dtype = x.dtype
        device = x.device
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            x.size(),
            dtype=dtype,
            device=device,
            # strides=x.stride(),
            # storage_offset=x.storage_offset(),
            # layout=x.layout,
            requires_grad=x.requires_grad,
        )
        self._x = x
        self._scale = scale
        return self

    # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})"

    def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
        return ["_x", "_scale"], {}

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: dict[str, torch.Tensor],
        metadata: dict[str, Any],
        outer_size,
        outer_stride,
    ) -> ScaleTensorSubclass:
        return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"])

    @staticmethod
    def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass:
        scale = x.abs().max()
        return ScaleTensorSubclass(x, scale)

    @classmethod
    def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None):
        def allowed_subclass(typ):
            return (
                issubclass(cls, typ)
                or issubclass(torch._subclasses.FakeTensor, typ)
                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ)
            )

        def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
            if isinstance(t, ScaleTensorSubclass):
                if t.is_floating_point():
                    return t._x * t._scale
                else:
                    return t._x
            return t

        if not all(allowed_subclass(t) for t in types):
            raise NotImplementedError(f"Unsupported types are included: {types}")

        if aten_ir_op in (torch.ops.aten.abs.default, torch.ops.aten.abs):
            raise NotImplementedError(f"Op of {aten_ir_op=} is not supported: {args=}, {kwargs=}")

        scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass))
        unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs))
        out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs)
        return out


def h(x: ScaleTensorSubclass, data: ScaleTensorSubclass | torch.Tensor):
    print(f"$$$ `data` has the type of {type(data)}")
    if not isinstance(data, ScaleTensorSubclass):
        scale = data.abs().amax()
        y = EncapsulateXandScale.apply(data, scale)
    else:
        y = data
    return x + y


def main():
    requires_grad = False
    shape = (4, 4)
    dtype = torch.float32
    device = torch.device("cpu")

    jitted = thunder.jit(h)

    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    data = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )

    actual = jitted(x, data)


if __name__ == "__main__":
    main()
What does this PR do?
There are about a handful of my pull requests on this topic. This one is based on a recent main commit and could be helpful for some other tensor subclasses.
I try to keep the transform agnostic to tensor subclass implementations, while the tests are focused on `torchao.float8.Float8Tensor` of torchao v0.7.0. There are many caveats:
- `torch._scaled_mm`, especially in backward.
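For context, a minimal sketch of the kind of program the tests target, assuming torchao's `convert_to_float8_training` entry point, an sm89+ GPU, and `bias=True` linears (none of this is the PR's actual test code):

```python
import torch
import torch.nn as nn
import thunder
from torchao.float8 import convert_to_float8_training  # assumed torchao >= 0.7 API

# Hypothetical end-to-end check: swap nn.Linear for Float8Linear, then jit with Thunder.
model = nn.Sequential(
    nn.Linear(64, 64, bias=True),
    nn.GELU(),  # default (non-tanh) approximation; the tanh variant still hits errors
    nn.Linear(64, 32, bias=True),
).to(device="cuda", dtype=torch.bfloat16)
model = convert_to_float8_training(model)
jitted = thunder.jit(model)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = jitted(x)        # forward routes matmuls through Float8Tensor / torch._scaled_mm
out.sum().backward()   # backward is where the torch._scaled_mm caveats show up
```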