From 82b3f58229f8a98e61c3c7d1f1c1337189eac115 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:55:46 -0700 Subject: [PATCH 01/22] Fix endian Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 86badf60..4c97a8cb 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -185,6 +185,17 @@ def meta(self) -> _metadata.MetadataStore: self._metadata = _metadata.MetadataStore() return self._metadata + def write(self, file) -> None: + """Write the tensor to a binary file. + + This method writes the raw bytes of the tensor to a file-like object. + The file-like object must have a ``write`` method that accepts bytes. + + Args: + file: A file-like object with a ``write`` method that accepts bytes. + """ + file.write(self.tobytes()) + def display(self, *, page: bool = False) -> None: rich = _display.require_rich() @@ -520,9 +531,17 @@ def tobytes(self) -> bytes: else: assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) + array = array.astype(array.dtype.newbyteorder("<")) return array.tobytes() + def write(self, file) -> None: + """Write the tensor to a binary file. + + Args: + file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. + """ + file.write(self.tobytes()) + class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors """An immutable concrete tensor with its data store on disk. @@ -1110,7 +1129,7 @@ def tobytes(self) -> bytes: """ array = self.numpy_packed() if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) + array = array.astype(array.dtype.newbyteorder("<")) return array.tobytes() From 42d8edc6fb94f4bd6cb841a26bbf466a9be14447 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:56:05 -0700 Subject: [PATCH 02/22] nvm Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 4c97a8cb..6c729f3e 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -531,7 +531,7 @@ def tobytes(self) -> bytes: else: assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" if not _IS_LITTLE_ENDIAN: - array = array.astype(array.dtype.newbyteorder("<")) + array = array.view(array.dtype.newbyteorder("<")) return array.tobytes() def write(self, file) -> None: @@ -1129,7 +1129,7 @@ def tobytes(self) -> bytes: """ array = self.numpy_packed() if not _IS_LITTLE_ENDIAN: - array = array.astype(array.dtype.newbyteorder("<")) + array = array.view(array.dtype.newbyteorder("<")) return array.tobytes() From 63310c10dfa2f9aaa4357557a9111c9efb16e084 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 16:14:31 -0700 Subject: [PATCH 03/22] More implementations Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 55 ++++++++++++++++++++++++++++++++-- src/onnx_ir/tensor_adapters.py | 20 ++++++++----- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 6c729f3e..197fa2cb 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -540,7 +540,23 @@ def write(self, file) -> None: Args: file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. 
""" - file.write(self.tobytes()) + if hasattr(file, "fileno"): + # This is a duplication of tobytes() for handling edge cases + array = self.numpy() + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: + # Pack the array into int4 + array = _type_casting.pack_4bitx2(array) + else: + assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" + if not _IS_LITTLE_ENDIAN: + array = array.view(array.dtype.newbyteorder("<")) + array.tofile(file) + else: + file.write(self.tobytes()) class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors @@ -609,7 +625,7 @@ def __init__( length: The length of the data in bytes. dtype: The data type of the tensor. shape: The shape of the tensor. - name: The name of the tensor.. + name: The name of the tensor. doc_string: The documentation string. metadata_props: The metadata properties. base_dir: The base directory for the external data. It is used to resolve relative paths. @@ -765,6 +781,18 @@ def tobytes(self) -> bytes: length = self._length or self.nbytes return self.raw[offset : offset + length] + def write(self, file) -> None: + self._check_validity() + with open(self.path, "rb") as src: + if self._offset is not None: + src.seek(self._offset) + bytes_to_copy = self._length or self.nbytes + chunk_size = 1024 * 1024 # 1MB + while bytes_to_copy > 0: + chunk = src.read(min(chunk_size, bytes_to_copy)) + file.write(chunk) + bytes_to_copy -= len(chunk) + def valid(self) -> bool: """Check if the tensor is valid. @@ -998,6 +1026,14 @@ def tobytes(self) -> bytes: """Return the bytes of the tensor.""" return self._evaluate().tobytes() + def write(self, file) -> None: + """Write the tensor to a binary file.""" + tensor = self._evaluate() + if hasattr(tensor, "write"): + tensor.write(file) + else: + super().write(file) + class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors """A tensor that stores 4bit datatypes in packed format. @@ -1132,6 +1168,21 @@ def tobytes(self) -> bytes: array = array.view(array.dtype.newbyteorder("<")) return array.tobytes() + def write(self, file) -> None: + """Write the tensor to a binary file. + + Args: + file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. + """ + if hasattr(file, "fileno"): + # This is a duplication of tobytes() for handling edge cases + array = self.numpy_packed() + if not _IS_LITTLE_ENDIAN: + array = array.view(array.dtype.newbyteorder("<")) + array.tofile(file) + else: + file.write(self.tobytes()) + class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): """Immutable symbolic dimension that can be shared across multiple shapes. 
diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py
index cb4e0ccf..42b48c82 100644
--- a/src/onnx_ir/tensor_adapters.py
+++ b/src/onnx_ir/tensor_adapters.py
@@ -168,10 +168,7 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
             return self.numpy()
         return self.numpy().__array__(dtype)
 
-    def tobytes(self) -> bytes:
-        # Implement tobytes to support native PyTorch types so we can use types like bfloat16
-        # Reading from memory directly is also more efficient because
-        # it avoids copying to a NumPy array
+    def _get_data_chunk(self):
         import torch._subclasses.fake_tensor
 
         with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
@@ -185,8 +182,15 @@ def tobytes(self) -> bytes:
                 "or save the model without initializers by setting include_initializers=False."
             )
 
-        return bytes(
-            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
-                tensor.data_ptr()
-            )
+        return (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+            tensor.data_ptr()
         )
+
+    def tobytes(self) -> bytes:
+        # Implement tobytes to support native PyTorch types so we can use types like bfloat16
+        # Reading from memory directly is also more efficient because
+        # it avoids copying to a NumPy array
+        return bytes(self._get_data_chunk())
+
+    def write(self, file) -> None:
+        return file.write(self._get_data_chunk())

From 290ab6c7081dd92ce092821c18ef0fead3dd83da Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 3 Oct 2025 16:24:18 -0700
Subject: [PATCH 04/22] tofile

Signed-off-by: Justin Chu
---
 src/onnx_ir/_core.py         | 14 +++++++-------
 src/onnx_ir/external_data.py | 12 ++++++++----
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py
index 197fa2cb..bb553435 100644
--- a/src/onnx_ir/_core.py
+++ b/src/onnx_ir/_core.py
@@ -185,7 +185,7 @@ def meta(self) -> _metadata.MetadataStore:
         self._metadata = _metadata.MetadataStore()
         return self._metadata
 
-    def write(self, file) -> None:
+    def tofile(self, file) -> None:
         """Write the tensor to a binary file.
 
         This method writes the raw bytes of the tensor to a file-like object.
@@ -534,7 +534,7 @@ def tobytes(self) -> bytes:
             array = array.view(array.dtype.newbyteorder("<"))
         return array.tobytes()
 
-    def write(self, file) -> None:
+    def tofile(self, file) -> None:
         """Write the tensor to a binary file.
 
         Args:
@@ -781,7 +781,7 @@ def tobytes(self) -> bytes:
         length = self._length or self.nbytes
         return self.raw[offset : offset + length]
 
-    def write(self, file) -> None:
+    def tofile(self, file) -> None:
         self._check_validity()
         with open(self.path, "rb") as src:
@@ -1026,13 +1026,13 @@ def tobytes(self) -> bytes:
         """Return the bytes of the tensor."""
         return self._evaluate().tobytes()
 
-    def write(self, file) -> None:
+    def tofile(self, file) -> None:
         """Write the tensor to a binary file."""
         tensor = self._evaluate()
         if hasattr(tensor, "write"):
-            tensor.write(file)
+            tensor.tofile(file)
         else:
-            super().write(file)
+            super().tofile(file)
 
 
 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
     """A tensor that stores 4bit datatypes in packed format.
@@ -1168,7 +1168,7 @@ def tobytes(self) -> bytes:
         array = array.view(array.dtype.newbyteorder("<"))
         return array.tobytes()
 
-    def write(self, file) -> None:
+    def tofile(self, file) -> None:
         """Write the tensor to a binary file.
Args: diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index ab4d504c..b7f2e8c2 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -205,14 +205,18 @@ def _write_external_data( ) current_offset = tensor_info.offset assert tensor is not None - raw_data = tensor.tobytes() - if isinstance(tensor, _core.ExternalTensor): - tensor.release() # Pad file to required offset if needed file_size = data_file.tell() if current_offset > file_size: data_file.write(b"\0" * (current_offset - file_size)) - data_file.write(raw_data) + + if hasattr(tensor, "write"): + tensor.tofile(data_file) + else: + raw_data = tensor.tobytes() + if isinstance(tensor, _core.ExternalTensor): + tensor.release() + data_file.write(raw_data) def _create_external_tensor( From 1b53a6a0c676820c09785af52bbe49f76b144ece Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 16:28:22 -0700 Subject: [PATCH 05/22] hasattr Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- src/onnx_ir/external_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index bb553435..c915ccc1 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1029,7 +1029,7 @@ def tobytes(self) -> bytes: def tofile(self, file) -> None: """Write the tensor to a binary file.""" tensor = self._evaluate() - if hasattr(tensor, "write"): + if hasattr(tensor, "tofile"): tensor.tofile(file) else: super().tofile(file) diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index b7f2e8c2..1fccb62b 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -210,7 +210,7 @@ def _write_external_data( if current_offset > file_size: data_file.write(b"\0" * (current_offset - file_size)) - if hasattr(tensor, "write"): + if hasattr(tensor, "tofile"): tensor.tofile(data_file) else: raw_data = tensor.tobytes() From c05e1894f2482853ba5e3c61a5197c5b086888df Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 17:19:18 -0700 Subject: [PATCH 06/22] tofile! 
Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 42b48c82..4ce5f2be 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -192,5 +192,5 @@ def tobytes(self) -> bytes: # it avoids copying to a NumPy array return bytes(self._get_data_chunk()) - def write(self, file) -> None: - return file.write(self._get_data_chunk()) + def tofile(self, file) -> None: + return file.tofile(self._get_data_chunk()) From 6377435a2e101905740ff7e337d67e94923eea0d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 17:20:18 -0700 Subject: [PATCH 07/22] write Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 4ce5f2be..8a475bee 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -193,4 +193,4 @@ def tobytes(self) -> bytes: return bytes(self._get_data_chunk()) def tofile(self, file) -> None: - return file.tofile(self._get_data_chunk()) + return file.write(self._get_data_chunk()) From 3dc57041f5a54d7cf211a573ea7a8d2d5c3b4e93 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 17:22:57 -0700 Subject: [PATCH 08/22] always write numpy Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index c915ccc1..be593881 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -540,8 +540,8 @@ def tofile(self, file) -> None: Args: file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. """ - if hasattr(file, "fileno"): - # This is a duplication of tobytes() for handling edge cases + if hasattr(file, "fileno") and isinstance(self._raw, np.ndarray): + # This is a duplication of tobytes() for handling special cases array = self.numpy() if self.dtype in { _enums.DataType.INT4, From 7fd35d7748f930c6f95918a166312d0f1ea97d69 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 17:38:14 -0700 Subject: [PATCH 09/22] Maintain reference Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 8a475bee..9b074ce2 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -182,7 +182,7 @@ def _get_data_chunk(self): "or save the model without initializers by setting include_initializers=False." 
) - return (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( + return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( tensor.data_ptr() ) @@ -190,7 +190,9 @@ def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because # it avoids copying to a NumPy array - return bytes(self._get_data_chunk()) + _, address = self._get_data_chunk() + return bytes(address) def tofile(self, file) -> None: - return file.write(self._get_data_chunk()) + _, address = self._get_data_chunk() + return file.write(address) From 909344d80bfba1b1d6430a907f947f9782ef2315 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 13:10:48 -0700 Subject: [PATCH 10/22] Fix fileno Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index fd4ad954..79c1cf88 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -348,6 +348,17 @@ def _maybe_view_np_array_with_ml_dtypes( return array +def _supports_fileno(file: Any) -> bool: + """Check if the file-like object supports fileno().""" + if not hasattr(file, "fileno"): + return False + try: + file.fileno() + except Exception: # pylint: disable=broad-except + return False + return True + + class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors """An immutable concrete tensor. @@ -540,7 +551,7 @@ def tofile(self, file) -> None: Args: file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. """ - if hasattr(file, "fileno") and isinstance(self._raw, np.ndarray): + if _supports_fileno(file) and isinstance(self._raw, np.ndarray): # This is a duplication of tobytes() for handling special cases array = self.numpy() if self.dtype in { @@ -553,7 +564,7 @@ def tofile(self, file) -> None: else: assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) + array = array.astype(array.dtype.newbyteorder("<")) array.tofile(file) else: file.write(self.tobytes()) @@ -1174,11 +1185,11 @@ def tofile(self, file) -> None: Args: file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. 
""" - if hasattr(file, "fileno"): + if _supports_fileno(file): # This is a duplication of tobytes() for handling edge cases array = self.numpy_packed() if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) + array = array.astype(array.dtype.newbyteorder("<")) array.tofile(file) else: file.write(self.tobytes()) From e7dc3013229df9508c67525dfb9ddb61689b08ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 13:11:57 -0700 Subject: [PATCH 11/22] Test Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters_test.py | 52 +++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/onnx_ir/tensor_adapters_test.py b/src/onnx_ir/tensor_adapters_test.py index 6e759081..6d4ad851 100644 --- a/src/onnx_ir/tensor_adapters_test.py +++ b/src/onnx_ir/tensor_adapters_test.py @@ -5,6 +5,7 @@ from __future__ import annotations import importlib.util +import tempfile import unittest import ml_dtypes @@ -83,6 +84,57 @@ def test_tobytes(self, dtype: torch.dtype): tensor = tensor_adapters.TorchTensor(torch.tensor([1], dtype=dtype)) self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) + def test_tofile_method_exists_and_works(self): + """Test that tofile() method exists and works correctly.""" + import io + + torch_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + tensor = tensor_adapters.TorchTensor(torch_tensor) + + # Test with BytesIO buffer + buffer = io.BytesIO() + tensor.tofile(buffer) + result_bytes = buffer.getvalue() + + expected_bytes = tensor.tobytes() + self.assertEqual(result_bytes, expected_bytes) + + @parameterized.parameterized.expand( + [ + (torch.bfloat16,), + (torch.bool,), + (torch.complex128,), + (torch.complex64,), + (torch.float16,), + (torch.float32,), + (torch.float64,), + (torch.float8_e4m3fn,), + (torch.float8_e4m3fnuz,), + (torch.float8_e5m2,), + (torch.float8_e5m2fnuz,), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.int8,), + (torch.uint16,), + (torch.uint32,), + (torch.uint64,), + (torch.uint8,), + ], + ) + def test_tofile(self, dtype: torch.dtype): + """Test tofile() method for various data types.""" + torch_tensor = torch.tensor([1], dtype=dtype) + tensor = tensor_adapters.TorchTensor(torch_tensor) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + expected_bytes = tensor.tobytes() + self.assertEqual(result_bytes, expected_bytes) + class TorchDtypeConversionTest(unittest.TestCase): @parameterized.parameterized.expand( From 8f832b3789c14307907ee92696bcc85e9b3d1506 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 13:13:15 -0700 Subject: [PATCH 12/22] test Signed-off-by: Justin Chu --- src/onnx_ir/_core_test.py | 167 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index a32f2bb7..0ff6daf0 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import io import pathlib import tempfile import unittest @@ -194,6 +195,112 @@ def test_metadata(self): tensor.metadata_props["test"] = "any string" self.assertEqual(tensor.metadata_props["test"], "any string") + def test_tobytes_big_endian_handling(self): + """Test that tobytes() correctly handles byte order conversion on big endian systems.""" + import unittest.mock + + array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + tensor = _core.Tensor(array) + + # Mock _IS_LITTLE_ENDIAN to 
simulate big endian system + with unittest.mock.patch("onnx_ir._core._IS_LITTLE_ENDIAN", False): + result_bytes = tensor.tobytes() + + # Verify that the result is in little endian format regardless of system endianness + expected_bytes = array.astype(array.dtype.newbyteorder("<")).tobytes() + self.assertEqual(result_bytes, expected_bytes) + + def test_tobytes_packed_types_big_endian_handling(self): + """Test that tobytes() handles byte order conversion for packed 4-bit types.""" + import unittest.mock + + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) + + # Mock _IS_LITTLE_ENDIAN to simulate big endian system + with unittest.mock.patch("onnx_ir._core._IS_LITTLE_ENDIAN", False): + result_bytes = tensor.tobytes() + + # For packed types, the result should be the same as the packed data in little endian + packed_array = _type_casting.pack_4bitx2(array.view(ir.DataType.UINT4.numpy())) + expected_bytes = packed_array.astype(packed_array.dtype.newbyteorder("<")).tobytes() + self.assertEqual(result_bytes, expected_bytes) + + def test_tofile_with_fileno_numpy_array(self): + """Test tofile() with file-like object that has fileno() method and numpy array.""" + import tempfile + + array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + tensor = _core.Tensor(array) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + self.assertEqual(result_bytes, array.tobytes()) + + def test_tofile_with_fileno_non_numpy_array(self): + """Test tofile() with file-like object that has fileno() method but non-numpy array.""" + import tempfile + + array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + torch_tensor = torch.tensor(array) + tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should use tobytes() path since _raw is not a numpy array + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_without_fileno(self): + """Test tofile() with file-like object without fileno() method.""" + array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + tensor = _core.Tensor(array) + + buffer = io.BytesIO() + tensor.tofile(buffer) + result_bytes = buffer.getvalue() + + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_packed_types_with_fileno(self): + """Test tofile() with packed types and file with fileno().""" + import tempfile + + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should be the same as tobytes() for packed types + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_big_endian_handling_with_fileno(self): + """Test tofile() big endian handling when file has fileno() method.""" + import tempfile + import unittest.mock + + array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + tensor = _core.Tensor(array) + + with tempfile.NamedTemporaryFile() as temp_file: + # Mock _IS_LITTLE_ENDIAN to simulate big endian system + with unittest.mock.patch("onnx_ir._core._IS_LITTLE_ENDIAN", False): + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should still produce little endian output + expected_bytes = array.astype(array.dtype.newbyteorder("<")).tobytes() + 
self.assertEqual(result_bytes, expected_bytes) + def _to_external_tensor(tensor_proto, dir: str, filename: str): onnx.external_data_helper.set_external_data(tensor_proto, location=filename) @@ -2628,6 +2735,66 @@ def test_integration_with_regular_tensor_operations(self): result = tensor.numpy() self.assertEqual(result.sum(), 10) # 1+2+3+4 = 10 + @parameterized.parameterized.expand( + [ + ("INT4", ir.DataType.INT4), + ("UINT4", ir.DataType.UINT4), + ("FLOAT4E2M1", ir.DataType.FLOAT4E2M1), + ] + ) + def test_tobytes_big_endian_handling(self, _: str, dtype: ir.DataType): + """Test that PackedTensor.tobytes() correctly handles byte order conversion.""" + import unittest.mock + + # Create packed data + packed_data = np.array([0x21, 0x43], dtype=np.uint8) + shape = _core.Shape([4]) + tensor = _core.PackedTensor(packed_data, dtype=dtype, shape=shape) + + # Mock _IS_LITTLE_ENDIAN to simulate big endian system + with unittest.mock.patch("onnx_ir._core._IS_LITTLE_ENDIAN", False): + result_bytes = tensor.tobytes() + + # Verify that the result is in little endian format regardless of system endianness + expected_bytes = packed_data.astype(packed_data.dtype.newbyteorder("<")).tobytes() + self.assertEqual(result_bytes, expected_bytes) + + def test_tofile_packed_tensor(self): + """Test tofile() method works correctly for PackedTensor.""" + import tempfile + + packed_data = np.array([0x21, 0x43], dtype=np.uint8) + shape = _core.Shape([4]) + tensor = _core.PackedTensor(packed_data, dtype=ir.DataType.UINT4, shape=shape) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should be the same as tobytes() + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_packed_tensor_big_endian_handling(self): + """Test tofile() big endian handling for PackedTensor.""" + import tempfile + import unittest.mock + + packed_data = np.array([0x21, 0x43], dtype=np.uint8) + shape = _core.Shape([4]) + tensor = _core.PackedTensor(packed_data, dtype=ir.DataType.UINT4, shape=shape) + + with tempfile.NamedTemporaryFile() as temp_file: + # Mock _IS_LITTLE_ENDIAN to simulate big endian system + with unittest.mock.patch("onnx_ir._core._IS_LITTLE_ENDIAN", False): + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should still produce little endian output + expected_bytes = packed_data.astype(packed_data.dtype.newbyteorder("<")).tobytes() + self.assertEqual(result_bytes, expected_bytes) + class StringTensorTest(unittest.TestCase): def test_nbytes(self): From 9afc144a614707d74ea1e898e7c635349ace8d8a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 20:51:22 -0700 Subject: [PATCH 13/22] Create tests Signed-off-by: Justin Chu --- src/onnx_ir/_core_test.py | 128 ++++++++++++++++++++++++++++++++------ 1 file changed, 108 insertions(+), 20 deletions(-) diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 0ff6daf0..3d9a2913 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -7,6 +7,7 @@ import pathlib import tempfile import unittest +import unittest.mock from typing import Any import ml_dtypes @@ -197,8 +198,6 @@ def test_metadata(self): def test_tobytes_big_endian_handling(self): """Test that tobytes() correctly handles byte order conversion on big endian systems.""" - import unittest.mock - array = np.array([1.0, 2.0, 3.0], dtype=np.float32) tensor = _core.Tensor(array) @@ -212,8 +211,6 @@ def test_tobytes_big_endian_handling(self): def 
test_tobytes_packed_types_big_endian_handling(self): """Test that tobytes() handles byte order conversion for packed 4-bit types.""" - import unittest.mock - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) @@ -228,8 +225,6 @@ def test_tobytes_packed_types_big_endian_handling(self): def test_tofile_with_fileno_numpy_array(self): """Test tofile() with file-like object that has fileno() method and numpy array.""" - import tempfile - array = np.array([1.0, 2.0, 3.0], dtype=np.float32) tensor = _core.Tensor(array) @@ -242,8 +237,6 @@ def test_tofile_with_fileno_numpy_array(self): def test_tofile_with_fileno_non_numpy_array(self): """Test tofile() with file-like object that has fileno() method but non-numpy array.""" - import tempfile - array = np.array([1.0, 2.0, 3.0], dtype=np.float32) torch_tensor = torch.tensor(array) tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) @@ -269,8 +262,6 @@ def test_tofile_without_fileno(self): def test_tofile_packed_types_with_fileno(self): """Test tofile() with packed types and file with fileno().""" - import tempfile - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) @@ -284,9 +275,6 @@ def test_tofile_packed_types_with_fileno(self): def test_tofile_big_endian_handling_with_fileno(self): """Test tofile() big endian handling when file has fileno() method.""" - import tempfile - import unittest.mock - array = np.array([1.0, 2.0, 3.0], dtype=np.float32) tensor = _core.Tensor(array) @@ -301,6 +289,113 @@ def test_tofile_big_endian_handling_with_fileno(self): expected_bytes = array.astype(array.dtype.newbyteorder("<")).tobytes() self.assertEqual(result_bytes, expected_bytes) + def test_tofile_empty_tensor(self): + """Test tofile() with an empty tensor.""" + # Test with numpy empty array + empty_array = np.array([], dtype=np.float32) + tensor = _core.Tensor(empty_array) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Empty tensor should write empty bytes + self.assertEqual(result_bytes, b"") + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_empty_tensor_torch(self): + """Test tofile() with an empty torch tensor.""" + # Test with torch empty tensor + empty_torch_tensor = torch.tensor([], dtype=torch.float32) + tensor = _core.Tensor(empty_torch_tensor, dtype=ir.DataType.FLOAT) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + result_bytes = temp_file.read() + + # Empty tensor should write empty bytes + self.assertEqual(result_bytes, b"") + self.assertEqual(result_bytes, tensor.tobytes()) + + def test_tofile_consecutive_writes_same_file(self): + """Test tofile() with three tensors writing consecutively to the same file.""" + # Create three different tensors + array1 = np.array([1.0, 2.0], dtype=np.float32) + array2 = np.array([3.0, 4.0, 5.0], dtype=np.float32) + array3 = np.array([6.0], dtype=np.float32) + + tensor1 = _core.Tensor(array1) + tensor2 = _core.Tensor(array2) + tensor3 = _core.Tensor(array3) + + with tempfile.NamedTemporaryFile() as temp_file: + # Write three tensors consecutively + tensor1.tofile(temp_file) + tensor2.tofile(temp_file) + tensor3.tofile(temp_file) + + # Read the entire file + temp_file.seek(0) + result_bytes = temp_file.read() + + # The file should contain all three tensors' data concatenated + expected_bytes = array1.tobytes() + array2.tobytes() + 
array3.tobytes() + self.assertEqual(result_bytes, expected_bytes) + + # Verify each part + bytes1 = array1.tobytes() + bytes2 = array2.tobytes() + bytes3 = array3.tobytes() + + self.assertEqual(result_bytes[: len(bytes1)], bytes1) + self.assertEqual(result_bytes[len(bytes1) : len(bytes1) + len(bytes2)], bytes2) + self.assertEqual(result_bytes[len(bytes1) + len(bytes2) :], bytes3) + + def test_tofile_consecutive_writes_mixed_types(self): + """Test tofile() with mixed tensor types (numpy and torch) writing consecutively.""" + # Create tensors with different underlying types + numpy_array = np.array([1.0, 2.0], dtype=np.float32) + torch_array = np.array([3.0, 4.0], dtype=np.float32) + torch_tensor_raw = torch.tensor(torch_array) + + numpy_tensor = _core.Tensor(numpy_array) + torch_tensor = _core.Tensor(torch_tensor_raw, dtype=ir.DataType.FLOAT) + + with tempfile.NamedTemporaryFile() as temp_file: + # Write numpy tensor first, then torch tensor + numpy_tensor.tofile(temp_file) + torch_tensor.tofile(temp_file) + + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should be equivalent to concatenating their tobytes() + expected_bytes = numpy_tensor.tobytes() + torch_tensor.tobytes() + self.assertEqual(result_bytes, expected_bytes) + + def test_tofile_consecutive_writes_packed_types(self): + """Test tofile() with packed tensor types writing consecutively.""" + # Create packed tensors + array1 = np.array([0, 1, 2, 7], dtype=np.uint8) + array2 = np.array([8, 9, 10, 15], dtype=np.uint8) + + tensor1 = _core.Tensor(array1, dtype=ir.DataType.UINT4) + tensor2 = _core.Tensor(array2, dtype=ir.DataType.UINT4) + + with tempfile.NamedTemporaryFile() as temp_file: + # Write packed tensors consecutively + tensor1.tofile(temp_file) + tensor2.tofile(temp_file) + + temp_file.seek(0) + result_bytes = temp_file.read() + + # Should be equivalent to concatenating their tobytes() + expected_bytes = tensor1.tobytes() + tensor2.tobytes() + self.assertEqual(result_bytes, expected_bytes) + def _to_external_tensor(tensor_proto, dir: str, filename: str): onnx.external_data_helper.set_external_data(tensor_proto, location=filename) @@ -2744,8 +2839,6 @@ def test_integration_with_regular_tensor_operations(self): ) def test_tobytes_big_endian_handling(self, _: str, dtype: ir.DataType): """Test that PackedTensor.tobytes() correctly handles byte order conversion.""" - import unittest.mock - # Create packed data packed_data = np.array([0x21, 0x43], dtype=np.uint8) shape = _core.Shape([4]) @@ -2761,8 +2854,6 @@ def test_tobytes_big_endian_handling(self, _: str, dtype: ir.DataType): def test_tofile_packed_tensor(self): """Test tofile() method works correctly for PackedTensor.""" - import tempfile - packed_data = np.array([0x21, 0x43], dtype=np.uint8) shape = _core.Shape([4]) tensor = _core.PackedTensor(packed_data, dtype=ir.DataType.UINT4, shape=shape) @@ -2777,9 +2868,6 @@ def test_tofile_packed_tensor(self): def test_tofile_packed_tensor_big_endian_handling(self): """Test tofile() big endian handling for PackedTensor.""" - import tempfile - import unittest.mock - packed_data = np.array([0x21, 0x43], dtype=np.uint8) shape = _core.Shape([4]) tensor = _core.PackedTensor(packed_data, dtype=ir.DataType.UINT4, shape=shape) From 1f87be162a4c8d782c456cb88d388fc73b21ae3f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 20:53:54 -0700 Subject: [PATCH 14/22] naming Signed-off-by: Justin Chu --- src/onnx_ir/tensor_adapters.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git 
a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py
index 9b074ce2..8be17108 100644
--- a/src/onnx_ir/tensor_adapters.py
+++ b/src/onnx_ir/tensor_adapters.py
@@ -168,7 +168,8 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
             return self.numpy()
         return self.numpy().__array__(dtype)
 
-    def _get_data_chunk(self):
+    def _get_cbytes(self):
+        """Get a ctypes byte array pointing to the tensor data."""
         import torch._subclasses.fake_tensor
 
         with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
@@ -182,6 +183,7 @@ def _get_data_chunk(self):
                 "or save the model without initializers by setting include_initializers=False."
             )
 
+        # Return the tensor to ensure it is not garbage collected while the ctypes array is in use
        return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
             tensor.data_ptr()
         )
@@ -190,9 +192,9 @@ def tobytes(self) -> bytes:
         # Implement tobytes to support native PyTorch types so we can use types like bfloat16
         # Reading from memory directly is also more efficient because
         # it avoids copying to a NumPy array
-        _, address = self._get_data_chunk()
-        return bytes(address)
+        _, data = self._get_cbytes()
+        return bytes(data)
 
     def tofile(self, file) -> None:
-        _, address = self._get_data_chunk()
-        return file.write(address)
+        _, data = self._get_cbytes()
+        return file.write(data)

From 2e06f508376f6c596f12539311cd7f75a3ad5164 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 9 Oct 2025 20:59:03 -0700
Subject: [PATCH 15/22] versionadded

Signed-off-by: Justin Chu
---
 src/onnx_ir/_core.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py
index 79c1cf88..bddbdbd0 100644
--- a/src/onnx_ir/_core.py
+++ b/src/onnx_ir/_core.py
@@ -191,6 +191,8 @@ def tofile(self, file) -> None:
         This method writes the raw bytes of the tensor to a file-like object.
         The file-like object must have a ``write`` method that accepts bytes.
 
+        .. versionadded:: 0.1.11
+
         Args:
             file: A file-like object with a ``write`` method that accepts bytes.
         """
@@ -548,6 +550,8 @@ def tobytes(self) -> bytes:
     def tofile(self, file) -> None:
         """Write the tensor to a binary file.
 
+        .. versionadded:: 0.1.11
+
         Args:
             file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
         """
@@ -1038,7 +1042,6 @@ def tobytes(self) -> bytes:
         return self._evaluate().tobytes()
 
     def tofile(self, file) -> None:
-        """Write the tensor to a binary file."""
         tensor = self._evaluate()
         if hasattr(tensor, "tofile"):
             tensor.tofile(file)
@@ -1182,6 +1185,8 @@ def tobytes(self) -> bytes:
     def tofile(self, file) -> None:
         """Write the tensor to a binary file.
 
+        .. versionadded:: 0.1.11
+
         Args:
             file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
""" From dafeaf722b9d44ac72f0685c954a02eae354e753 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 21:08:52 -0700 Subject: [PATCH 16/22] Add tests Signed-off-by: Justin Chu --- src/onnx_ir/_core_test.py | 164 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 3d9a2913..5c45f561 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -710,6 +710,170 @@ def test_external_tensor_empty_tensor(self): # about permission errors del tensor + def test_tofile_basic(self): + """Test ExternalTensor.tofile() with basic functionality.""" + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + base_dir=self.base_path, + name="input", + shape=_core.Shape(external_tensor.dims), + ) + + # Test writing to BytesIO + output = io.BytesIO() + tensor.tofile(output) + output.seek(0) + written_data = output.read() + + # Verify the written data matches expected + expected_data = self.data.tobytes() + self.assertEqual(written_data, expected_data) + + def test_tofile_with_offset(self): + """Test ExternalTensor.tofile() with offset handling.""" + # Use the second tensor which has an offset + external_tensor2 = self.model.graph.initializer[1] + external_info2 = onnx.external_data_helper.ExternalDataInfo(external_tensor2) + tensor2 = _core.ExternalTensor( + external_info2.location, + offset=external_info2.offset, + length=external_info2.length, + dtype=ir.DataType.FLOAT16, + base_dir=self.base_path, + name="input2", + shape=_core.Shape(external_tensor2.dims), + ) + + # Test writing to BytesIO + output = io.BytesIO() + tensor2.tofile(output) + output.seek(0) + written_data = output.read() + + # Verify the written data matches expected + expected_data = self.data_float16.tobytes() + self.assertEqual(written_data, expected_data) + + def test_tofile_with_file_object(self): + """Test ExternalTensor.tofile() writing to a file.""" + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + base_dir=self.base_path, + name="input", + shape=_core.Shape(external_tensor.dims), + ) + + with tempfile.NamedTemporaryFile() as temp_file: + tensor.tofile(temp_file) + temp_file.seek(0) + written_data = temp_file.read() + + # Verify the written data matches expected + expected_data = self.data.tobytes() + self.assertEqual(written_data, expected_data) + + def test_tofile_empty_tensor(self): + """Test ExternalTensor.tofile() with empty tensor.""" + expected_array = np.array([], dtype=np.float32) + tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array)) + with tempfile.TemporaryDirectory() as temp_dir: + _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") + tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) + + self.assertIsInstance(tensor, _core.ExternalTensor) + + # Test writing empty tensor to BytesIO + output = io.BytesIO() + tensor.tofile(output) + output.seek(0) + written_data = output.read() + + # Should write empty bytes + self.assertEqual(written_data, b"") + del tensor + + def test_tofile_large_chunks(self): + """Test ExternalTensor.tofile() 
handles large data with chunking.""" + # Create a larger array to test the chunking mechanism + large_data = np.random.rand(1100, 1100).astype(np.float32) + tensor_proto = ir.serde.serialize_tensor(ir.Tensor(large_data)) + with tempfile.TemporaryDirectory() as temp_dir: + _to_external_tensor(tensor_proto, temp_dir, "large_tensor.bin") + tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) + + self.assertIsInstance(tensor, _core.ExternalTensor) + + # Test writing to BytesIO + output = io.BytesIO() + tensor.tofile(output) + output.seek(0) + written_data = output.read() + + # Verify the written data matches expected + expected_data = large_data.tobytes() + self.assertEqual(written_data, expected_data) + del tensor + + def test_tofile_invalidated_tensor_raises_error(self): + """Test that tofile() raises error on invalidated tensor.""" + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + base_dir=self.base_path, + name="input", + shape=_core.Shape(external_tensor.dims), + ) + + # Invalidate the tensor + tensor.invalidate() + + # Should raise ValueError when trying to write + output = io.BytesIO() + with self.assertRaisesRegex(ValueError, "invalidated"): + tensor.tofile(output) + + def test_tofile_consecutive_writes(self): + """Test ExternalTensor.tofile() with consecutive writes to same file.""" + external_tensor = self.model.graph.initializer[0] + external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) + tensor = _core.ExternalTensor( + external_info.location, + offset=external_info.offset, + length=external_info.length, + dtype=ir.DataType.FLOAT, + base_dir=self.base_path, + name="input", + shape=_core.Shape(external_tensor.dims), + ) + + # Write tensor three times consecutively to BytesIO + output = io.BytesIO() + tensor.tofile(output) + tensor.tofile(output) + tensor.tofile(output) + + output.seek(0) + written_data = output.read() + + # Should have written the data three times + expected_data = self.data.tobytes() + expected_triple = expected_data + expected_data + expected_data + self.assertEqual(written_data, expected_triple) + class SymbolicDimTest(unittest.TestCase): def test_init_raises_when_value_is_int(self): From ff2df132c7d32aeba61226dabbc8ea0968da6af2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:16:56 -0700 Subject: [PATCH 17/22] docstring Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index bddbdbd0..9d8f1084 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -585,7 +585,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable= the tensor is recommended if IO overhead and memory usage is a concern. To obtain an array, call :meth:`numpy`. To obtain the bytes, - call :meth:`tobytes`. + call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`. The :attr:`location` must be a relative path conforming to the ONNX specification. 
Given the correct :attr:`base_dir`, the :attr:`path` is computed From 2220eb3e71e9c4b8bc43a512fd5554ae3d9a415c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:18:12 -0700 Subject: [PATCH 18/22] docs Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 ++ src/onnx_ir/external_data.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 9d8f1084..742d90e9 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1044,6 +1044,8 @@ def tobytes(self) -> bytes: def tofile(self, file) -> None: tensor = self._evaluate() if hasattr(tensor, "tofile"): + # Some existing implementation of TensorProtocol may not have tofile() + # as it was introduced in v0.1.11 tensor.tofile(file) else: super().tofile(file) diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index 1fccb62b..23d3c362 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -211,6 +211,8 @@ def _write_external_data( data_file.write(b"\0" * (current_offset - file_size)) if hasattr(tensor, "tofile"): + # Some existing implementation of TensorProtocol may not have tofile() + # as it was introduced in v0.1.11 tensor.tofile(data_file) else: raw_data = tensor.tobytes() From 24d6e658d278d082a6536e94e4728845893a0365 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:19:49 -0700 Subject: [PATCH 19/22] docs Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 4 ++-- src/onnx_ir/external_data.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 742d90e9..f54fd192 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1044,8 +1044,8 @@ def tobytes(self) -> bytes: def tofile(self, file) -> None: tensor = self._evaluate() if hasattr(tensor, "tofile"): - # Some existing implementation of TensorProtocol may not have tofile() - # as it was introduced in v0.1.11 + # Some existing implementation (e.g. PyTorch <2.10) of TensorProtocol + # may not have tofile() as it was introduced in v0.1.11 tensor.tofile(file) else: super().tofile(file) diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index 23d3c362..043d983f 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -211,8 +211,8 @@ def _write_external_data( data_file.write(b"\0" * (current_offset - file_size)) if hasattr(tensor, "tofile"): - # Some existing implementation of TensorProtocol may not have tofile() - # as it was introduced in v0.1.11 + # Some existing implementation (e.g. PyTorch <2.10) of TensorProtocol + # may not have tofile() as it was introduced in v0.1.11 tensor.tofile(data_file) else: raw_data = tensor.tobytes() From ef9b697f3265a11d6b803187df12e1c58aa29340 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:22:54 -0700 Subject: [PATCH 20/22] use function Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 47 ++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index f54fd192..e1d8fb58 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -361,6 +361,27 @@ def _supports_fileno(file: Any) -> bool: return True +def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray: + """Create a numpy array for the byte representation of the tensor. + + This function is used for serializing the tensor to bytes. It handles the + special cases for 4-bit data types and endianness. 
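+
+    For example, eight UINT4 elements are packed two per byte into four uint8
+    values, and on a big-endian host the array is byte-swapped so that the
+    serialized bytes are always little-endian.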
+ """ + array = tensor.numpy() + if tensor.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: + # Pack the array into int4 + array = _type_casting.pack_4bitx2(array) + else: + assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" + if not _IS_LITTLE_ENDIAN: + array = array.astype(array.dtype.newbyteorder("<")) + return array + + class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors """An immutable concrete tensor. @@ -533,18 +554,7 @@ def tobytes(self) -> bytes: value is not a numpy array. """ # TODO(justinchuby): Support DLPack - array = self.numpy() - if self.dtype in { - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - }: - # Pack the array into int4 - array = _type_casting.pack_4bitx2(array) - else: - assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" - if not _IS_LITTLE_ENDIAN: - array = array.astype(array.dtype.newbyteorder("<")) + array = _create_np_array_for_byte_representation(self) return array.tobytes() def tofile(self, file) -> None: @@ -557,18 +567,7 @@ def tofile(self, file) -> None: """ if _supports_fileno(file) and isinstance(self._raw, np.ndarray): # This is a duplication of tobytes() for handling special cases - array = self.numpy() - if self.dtype in { - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - }: - # Pack the array into int4 - array = _type_casting.pack_4bitx2(array) - else: - assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" - if not _IS_LITTLE_ENDIAN: - array = array.astype(array.dtype.newbyteorder("<")) + array = _create_np_array_for_byte_representation(self) array.tofile(file) else: file.write(self.tobytes()) From 6036735d05755674d520b57fc291dd0e64df5725 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:37:57 -0700 Subject: [PATCH 21/22] update docs Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- src/onnx_ir/external_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index e1d8fb58..2c383f4c 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1043,7 +1043,7 @@ def tobytes(self) -> bytes: def tofile(self, file) -> None: tensor = self._evaluate() if hasattr(tensor, "tofile"): - # Some existing implementation (e.g. PyTorch <2.10) of TensorProtocol + # Some existing implementation of TensorProtocol # may not have tofile() as it was introduced in v0.1.11 tensor.tofile(file) else: diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index 043d983f..c33bcf5e 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -211,7 +211,7 @@ def _write_external_data( data_file.write(b"\0" * (current_offset - file_size)) if hasattr(tensor, "tofile"): - # Some existing implementation (e.g. 
PyTorch <2.10) of TensorProtocol + # Some existing implementation of TensorProtocol # may not have tofile() as it was introduced in v0.1.11 tensor.tofile(data_file) else: From e3df4c99428368cd8fe48bedb89b3209662d9de6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 12:49:14 -0700 Subject: [PATCH 22/22] Apply suggestion from @justinchuby Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 2c383f4c..cc21f204 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -565,7 +565,7 @@ def tofile(self, file) -> None: Args: file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method. """ - if _supports_fileno(file) and isinstance(self._raw, np.ndarray): + if isinstance(self._raw, np.ndarray) and _supports_fileno(file): # This is a duplication of tobytes() for handling special cases array = _create_np_array_for_byte_representation(self) array.tofile(file)
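
Taken together, the series gives every tensor class a ``tofile()`` method for streaming a tensor's raw little-endian bytes to a file. A minimal usage sketch (assuming onnx-ir >= 0.1.11, where ``tofile()`` was introduced, and that ``Tensor`` is re-exported at the package root as in ``import onnx_ir as ir``):

    import io

    import numpy as np
    import onnx_ir as ir

    tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))

    # Any binary file-like object with a write() method works; a real file
    # that exposes a usable fileno() (and a NumPy-backed tensor) takes the
    # numpy.ndarray.tofile() fast path instead of going through tobytes().
    buffer = io.BytesIO()
    tensor.tofile(buffer)
    assert buffer.getvalue() == tensor.tobytes()

For ``ExternalTensor``, the data is instead copied from disk in 1 MB chunks (patch 03), so a large on-disk initializer is never loaded into memory in one piece when ``external_data._write_external_data`` re-serializes it.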