72 changes: 71 additions & 1 deletion src/onnx_ir/_core.py
@@ -185,6 +185,17 @@ def meta(self) -> _metadata.MetadataStore:
self._metadata = _metadata.MetadataStore()
return self._metadata

def tofile(self, file) -> None:
"""Write the tensor to a binary file.

This method writes the raw bytes of the tensor to a file-like object.
The file-like object must have a ``write`` method that accepts bytes.

Args:
file: A file-like object with a ``write`` method that accepts bytes.
"""
file.write(self.tobytes())

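A minimal usage sketch of the new API (assuming the package is importable as ``onnx_ir`` and that ``ir.Tensor`` wraps a NumPy array, as elsewhere in this library):

import numpy as np
import onnx_ir as ir

tensor = ir.Tensor(np.arange(10, dtype=np.float32))
with open("weights.bin", "wb") as f:
    tensor.tofile(f)  # writes the same raw little-endian payload as tobytes()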
def display(self, *, page: bool = False) -> None:
rich = _display.require_rich()

@@ -523,6 +534,30 @@ def tobytes(self) -> bytes:
array = array.astype(array.dtype.newbyteorder("<"))
return array.tobytes()

def tofile(self, file) -> None:
"""Write the tensor to a binary file.

Args:
file: A file-like object with a ``write`` method that accepts bytes, or with a ``fileno()`` method.
"""
if hasattr(file, "fileno") and isinstance(self._raw, np.ndarray):
# This is a duplication of tobytes() for handling special cases
array = self.numpy()
if self.dtype in {
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
}:
# Pack pairs of 4-bit elements into single bytes
array = _type_casting.pack_4bitx2(array)
else:
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
if not _IS_LITTLE_ENDIAN:
array = array.view(array.dtype.newbyteorder("<"))
array.tofile(file)
else:
file.write(self.tobytes())

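Why the branch on ``fileno``: NumPy's ``ndarray.tofile`` writes through a real OS-level file descriptor, so it cannot target arbitrary in-memory file-like objects. A standalone illustration (not library code):

import io
import numpy as np

arr = np.arange(4, dtype=np.uint8)
buffer = io.BytesIO()
try:
    arr.tofile(buffer)  # BytesIO has no usable file descriptor
except (io.UnsupportedOperation, OSError):
    buffer.write(arr.tobytes())  # fall back to an in-memory copy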

class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor with its data store on disk.
@@ -590,7 +625,7 @@ def __init__(
length: The length of the data in bytes.
dtype: The data type of the tensor.
shape: The shape of the tensor.
-        name: The name of the tensor..
+        name: The name of the tensor.
doc_string: The documentation string.
metadata_props: The metadata properties.
base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -746,6 +781,18 @@ def tobytes(self) -> bytes:
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def tofile(self, file) -> None:
    """Stream the tensor to a binary file by copying from the backing file in chunks."""
    self._check_validity()
    with open(self.path, "rb") as src:
        if self._offset is not None:
            src.seek(self._offset)
        bytes_to_copy = self._length or self.nbytes
        chunk_size = 1024 * 1024  # 1MB
        while bytes_to_copy > 0:
            chunk = src.read(min(chunk_size, bytes_to_copy))
            if not chunk:
                # The backing file ended early; stop instead of looping forever
                break
            file.write(chunk)
            bytes_to_copy -= len(chunk)

def valid(self) -> bool:
"""Check if the tensor is valid.

Expand Down Expand Up @@ -979,6 +1026,14 @@ def tobytes(self) -> bytes:
"""Return the bytes of the tensor."""
return self._evaluate().tobytes()

An inline review thread is attached to this method:

Collaborator: I am wondering whether tofile() makes sense for LazyTensor. Hmm.

Member Author: Can you say more?

Collaborator: I just thought it's not even real until it's evaluated, so intuitively it doesn't seem like a great fit for tofile(), which writes to disk. But I suppose we generally expect all tensors to have this method, so it's understandable.

Member Author (@justinchuby, Oct 10, 2025): It is actually useful: even when the tensor is lazily evaluated, we still want to avoid tobytes() making a copy of the tensor data before writing to file. The screenshots in the PR description show lazy tensors.

def tofile(self, file) -> None:
"""Write the tensor to a binary file."""
tensor = self._evaluate()
if hasattr(tensor, "tofile"):
tensor.tofile(file)
else:
super().tofile(file)

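To make the trade-off in the thread above concrete, here is a rough sketch of the two write paths (illustrative only; ``evaluated`` stands in for the result of ``self._evaluate()``):

with open("weights.bin", "wb") as f:
    f.write(evaluated.tobytes())  # materializes one full extra copy of the payload first
    # vs.
    evaluated.tofile(f)           # streams from the existing buffer without that copy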

class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""A tensor that stores 4bit datatypes in packed format.
@@ -1113,6 +1168,21 @@ def tobytes(self) -> bytes:
array = array.astype(array.dtype.newbyteorder("<"))
return array.tobytes()

def tofile(self, file) -> None:
"""Write the tensor to a binary file.

Args:
file: A file-like object with a ``write`` method that accepts bytes, or with a ``fileno()`` method.
"""
if hasattr(file, "fileno"):
# This is a duplication of tobytes() for handling edge cases
array = self.numpy_packed()
if not _IS_LITTLE_ENDIAN:
array = array.view(array.dtype.newbyteorder("<"))
array.tofile(file)
else:
file.write(self.tobytes())

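For reference, the packed layout stores two 4-bit elements per byte, with the first element in the low nibble. A standalone NumPy sketch of that packing (illustrative, not the library's ``pack_4bitx2``):

import numpy as np

nibbles = np.array([1, 2, 3, 4], dtype=np.uint8)        # four 4-bit values
packed = (nibbles[0::2] & 0x0F) | (nibbles[1::2] << 4)  # first element in the low nibble
assert packed.tolist() == [0x21, 0x43]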

class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
"""Immutable symbolic dimension that can be shared across multiple shapes.
12 changes: 8 additions & 4 deletions src/onnx_ir/external_data.py
@@ -205,14 +205,18 @@ def _write_external_data(
)
current_offset = tensor_info.offset
assert tensor is not None
-        raw_data = tensor.tobytes()
-        if isinstance(tensor, _core.ExternalTensor):
-            tensor.release()
        # Pad file to required offset if needed
        file_size = data_file.tell()
        if current_offset > file_size:
            data_file.write(b"\0" * (current_offset - file_size))
-        data_file.write(raw_data)
+        if hasattr(tensor, "tofile"):
+            tensor.tofile(data_file)
+        else:
+            raw_data = tensor.tobytes()
+            if isinstance(tensor, _core.ExternalTensor):
+                tensor.release()
+            data_file.write(raw_data)

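The padding logic retained above zero-fills up to each tensor's precomputed offset. The same idea in isolation, assuming offsets fall on some alignment boundary such as 4096 bytes (boundary chosen here for illustration):

ALIGNMENT = 4096                    # assumed boundary, illustrative only
position = data_file.tell()
padding = (-position) % ALIGNMENT   # bytes needed to reach the next boundary
data_file.write(b"\0" * padding)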

def _create_external_tensor(
22 changes: 14 additions & 8 deletions src/onnx_ir/tensor_adapters.py
@@ -168,10 +168,7 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
return self.numpy()
return self.numpy().__array__(dtype)

-def tobytes(self) -> bytes:
-    # Implement tobytes to support native PyTorch types so we can use types like bloat16
-    # Reading from memory directly is also more efficient because
-    # it avoids copying to a NumPy array
+def _get_data_chunk(self):
import torch._subclasses.fake_tensor

with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
@@ -185,8 +182,17 @@ def tobytes(self) -> bytes:
"or save the model without initializers by setting include_initializers=False."
)

-    return bytes(
-        (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
-            tensor.data_ptr()
-        )
+    return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+        tensor.data_ptr()
    )

def tobytes(self) -> bytes:
# Implement tobytes to support native PyTorch types so we can use types like bfloat16
# Reading from memory directly is also more efficient because
# it avoids copying to a NumPy array
_, address = self._get_data_chunk()
return bytes(address)

def tofile(self, file) -> None:
    # The first element of the pair keeps the torch tensor alive until the write completes
    _, address = self._get_data_chunk()
    file.write(address)
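The ``ctypes`` construction above creates a view over the tensor's memory without copying. A standalone sketch of the same trick, using a ``bytearray`` in place of a torch tensor so it runs without PyTorch:

import ctypes

buf = bytearray(b"abcd")                             # stands in for tensor memory
view = (ctypes.c_ubyte * len(buf)).from_buffer(buf)  # zero-copy view over buf
buf[0] = ord("z")
assert bytes(view) == b"zbcd"                        # copying happens only at bytes()/write time

Note that ``from_address`` (used in the diff) does not keep the source object alive, which is why ``_get_data_chunk`` returns the tensor alongside the ctypes view.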