Skip to content

[IR] Refactor TensorBase to simplify implementation #2081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 55 additions & 72 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
"""Convenience Shared methods for classes implementing TensorProtocol."""

__slots__ = ()
__slots__ = (
"_doc_string",
"_metadata",
"_metadata_props",
"_name",
)

def __init__(
self,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = metadata_props
self._name: str | None = name
self._doc_string: str | None = doc_string

def _printable_type_shape(self) -> str:
"""Return a string representation of the shape and data type."""
Expand All @@ -110,6 +126,24 @@ def _repr_base(self) -> str:
"""
return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

@property
def name(self) -> str | None:
"""The name of the tensor."""
return self._name

@name.setter
def name(self, value: str | None) -> None:
self._name = value

@property
def doc_string(self) -> str | None:
"""The documentation string."""
return self._doc_string

@doc_string.setter
def doc_string(self, value: str | None) -> None:
self._doc_string = value

@property
def size(self) -> int:
"""The number of elements in the tensor."""
Expand All @@ -121,6 +155,23 @@ def nbytes(self) -> int:
# Use math.ceil because when dtype is INT4, the itemsize is 0.5
return math.ceil(self.dtype.itemsize * self.size)

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.

Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata

def display(self, *, page: bool = False) -> None:
rich = _display.require_rich()

Expand Down Expand Up @@ -309,12 +360,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):

__slots__ = (
"_dtype",
"_metadata",
"_metadata_props",
"_raw",
"_shape",
"doc_string",
"name",
)

def __init__(
Expand Down Expand Up @@ -347,6 +394,7 @@ def __init__(
ValueError: If the shape is not specified and the value does not have a shape attribute.
ValueError: If the dtype is not specified and the value is not a numpy array.
"""
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
# NOTE: We should not do any copying here for performance reasons
if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
raise TypeError(f"Expected an array compatible object, got {type(value)}")
Expand Down Expand Up @@ -381,10 +429,6 @@ def __init__(
value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]

self._raw = value
self.name = name
self.doc_string = doc_string
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props = metadata_props

def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
Expand Down Expand Up @@ -455,23 +499,6 @@ def tobytes(self) -> bytes:
array = array.view(array.dtype.newbyteorder("<"))
return array.tobytes()

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.

Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor with its data store on disk.
Expand Down Expand Up @@ -512,13 +539,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
"_dtype",
"_length",
"_location",
"_metadata",
"_metadata_props",
"_offset",
"_shape",
"_valid",
"doc_string",
"name",
"raw",
)

Expand Down Expand Up @@ -548,6 +571,7 @@ def __init__(
metadata_props: The metadata properties.
base_dir: The base directory for the external data. It is used to resolve relative paths.
"""
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
# NOTE: Do not verify the location by default. This is because the location field
# in the tensor proto can be anything and we would like deserialization from
# proto to IR to not fail.
Expand Down Expand Up @@ -725,34 +749,13 @@ def release(self) -> None:
self.raw.close()
self.raw = None

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.

Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

__slots__ = (
"_metadata",
"_metadata_props",
"_raw",
"_shape",
"doc_string",
"name",
)

def __init__(
Expand All @@ -773,6 +776,7 @@ def __init__(
doc_string: The documentation string.
metadata_props: The metadata properties.
"""
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
if shape is None:
if not hasattr(value, "shape"):
raise ValueError(
Expand All @@ -784,10 +788,6 @@ def __init__(
self._shape = shape
self._shape._frozen = True
self._raw = value
self.name = name
self.doc_string = doc_string
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props = metadata_props

def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray):
Expand Down Expand Up @@ -835,23 +835,6 @@ def string_data(self) -> Sequence[bytes]:
return self._raw.flatten().tolist()
return self._raw

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.

Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata


class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
__slots__ = ("_value",)
Expand Down
26 changes: 4 additions & 22 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
import onnx
import onnx.external_data_helper

from onnxscript.ir import _core, _enums, _metadata, _protocols, _type_casting
from onnxscript.ir import _core, _enums, _protocols, _type_casting

if typing.TYPE_CHECKING:
import google.protobuf.internal.containers as proto_containers
Expand Down Expand Up @@ -242,12 +242,11 @@ def to_proto(ir_object: object) -> object:
class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
"""A tensor initialized from a tensor proto."""

__slots__ = ("_proto",)

def __init__(self, proto: onnx.TensorProto) -> None:
super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props))
self._proto = proto
self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
proto.metadata_props
)
self._metadata: _metadata.MetadataStore | None = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -438,23 +437,6 @@ def tobytes(self) -> bytes:
# For example, int32_data can be empty and still be a valid tensor.
return b""

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.

Write to the :attr:`metadata_props` if you would like the metadata to be serialized
to the ONNX proto.
"""
if self._metadata is None:
self._metadata = _metadata.MetadataStore()
return self._metadata

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
self._metadata_props = {}
return self._metadata_props


def _get_field(proto: Any, field: str) -> Any:
if proto.HasField(field):
Expand Down
Loading