diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py
index e24bce026e..a5cbb87e56 100644
--- a/onnxscript/ir/tensor_adapters.py
+++ b/onnxscript/ir/tensor_adapters.py
@@ -30,6 +30,7 @@
 __all__ = [
     "TorchTensor",
+    "MlxTensor",
 ]
 
 import ctypes
@@ -39,14 +40,20 @@
 
 from onnxscript import ir
 from onnxscript.ir import _core
 
+import ml_dtypes
+import numpy as np
 if TYPE_CHECKING:
     import torch
+    import mlx.core as mx
 
 
 class TorchTensor(_core.Tensor):
     def __init__(
-        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        doc_string: str | None = None,
     ):
         # Pass the tensor as the raw data to ir.Tensor's constructor
         import torch
@@ -73,7 +80,10 @@ def __init__(
             torch.uint64: ir.DataType.UINT64,
         }
         super().__init__(
-            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+            tensor,
+            dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype],
+            name=name,
+            doc_string=doc_string,
         )
 
     def numpy(self) -> npt.NDArray:
@@ -81,15 +91,22 @@ def numpy(self) -> npt.NDArray:
 
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return (
+                self.raw.view(torch.uint16)
+                .numpy(force=True)
+                .view(dtype=self.dtype.numpy())
+            )
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True)
+            return (
+                self.raw.view(torch.uint8)
+                .numpy(force=True)
+                .view(dtype=self.dtype.numpy())
+            )
         return self.raw.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
@@ -120,3 +137,47 @@ def tobytes(self) -> bytes:
                 tensor.data_ptr()
             )
         )
+
+
+class MlxTensor(_core.Tensor):
+    def __init__(
+        self, tensor: mx.array, name: str | None = None, doc_string: str | None = None
+    ):
+        import mlx.core as mx
+
+        _MLX_DTYPE_TO_ONNX: dict[mx.Dtype, ir.DataType] = {
+            mx.bfloat16: ir.DataType.BFLOAT16,
+            mx.complex64: ir.DataType.COMPLEX64,
+            mx.float16: ir.DataType.FLOAT16,
+            mx.float32: ir.DataType.FLOAT,
+            mx.int16: ir.DataType.INT16,
+            mx.int32: ir.DataType.INT32,
+            mx.int64: ir.DataType.INT64,
+            mx.int8: ir.DataType.INT8,
+            mx.uint8: ir.DataType.UINT8,
+            mx.uint16: ir.DataType.UINT16,
+            mx.uint32: ir.DataType.UINT32,
+            mx.uint64: ir.DataType.UINT64,
+        }
+        super().__init__(
+            tensor,
+            dtype=_MLX_DTYPE_TO_ONNX[tensor.dtype],
+            name=name,
+            doc_string=doc_string,
+        )
+
+    def numpy(self) -> npt.NDArray:
+        import mlx.core as mx
+
+        self.raw: mx.array
+        if self.dtype == ir.DataType.BFLOAT16:
+            return np.array(self.raw.view(mx.uint16), copy=False).view(
+                dtype=self.dtype.numpy()
+            )
+        return np.array(self.raw, copy=False)
+
+    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
+        del copy  # Unused, but needed for the signature
+        if dtype is None:
+            return self.numpy()
+        return self.numpy().__array__(dtype)
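
For reviewers: a minimal usage sketch of the new adapter and the `ml_dtypes`-backed `numpy()` behavior. This is illustrative only (the array values and the `name` argument are made up) and assumes `mlx` and `ml_dtypes` are installed:

```python
import mlx.core as mx

from onnxscript import ir
from onnxscript.ir.tensor_adapters import MlxTensor

# Wrap an MLX array; its dtype is mapped to the corresponding ONNX data type.
weight = MlxTensor(mx.array([1.0, 2.0, 3.0], dtype=mx.bfloat16), name="weight")
assert weight.dtype == ir.DataType.BFLOAT16

# numpy() reinterprets the bfloat16 payload as ml_dtypes.bfloat16 rather than
# exposing the raw uint16 code points.
print(weight.numpy().dtype)  # bfloat16
```

Note the behavior change on the Torch side as well: `TorchTensor.numpy()` now returns arrays typed as `ml_dtypes` bfloat16/float8 instead of raw `uint16`/`uint8` views, so callers that compared against the integer dtypes will need updating.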