-
Notifications
You must be signed in to change notification settings - Fork 14
Implement tofile on tensors to reduce data write time by 40% #210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
82b3f58
42d8edc
63310c1
290ab6c
1b53a6a
c05e189
6377435
3dc5704
7fd35d7
40cb60d
909344d
e7dc301
8f832b3
9afc144
1f87be1
2e06f50
dafeaf7
ff2df13
2220eb3
24d6e65
ef9b697
6036735
e3df4c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,6 +185,17 @@ def meta(self) -> _metadata.MetadataStore: | |
self._metadata = _metadata.MetadataStore() | ||
return self._metadata | ||
|
||
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    This method writes the raw bytes of the tensor to a file-like object.
    The file-like object must have a ``write`` method that accepts bytes.

    The payload is produced by ``self.tobytes()`` and handed to
    ``file.write`` in a single call.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes.
    """
    file.write(self.tobytes())
|
||
def display(self, *, page: bool = False) -> None: | ||
rich = _display.require_rich() | ||
|
||
|
@@ -523,6 +534,30 @@ def tobytes(self) -> bytes: | |
array = array.astype(array.dtype.newbyteorder("<")) | ||
return array.tobytes() | ||
|
||
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    When the tensor is backed by a NumPy array and ``file`` exposes a usable
    OS-level file descriptor, the array is written with ``ndarray.tofile`` so
    the data is not first copied into a ``bytes`` object. In every other case
    the result of ``tobytes()`` is passed to ``file.write``.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes,
            or one that has a working ``fileno()`` method.
    """
    has_descriptor = False
    if isinstance(self._raw, np.ndarray):
        # hasattr(file, "fileno") is not enough: in-memory streams such as
        # io.BytesIO define fileno() only to raise io.UnsupportedOperation
        # (a subclass of both OSError and ValueError), so actually call it.
        try:
            file.fileno()
            has_descriptor = True
        except (AttributeError, OSError, ValueError):
            has_descriptor = False

    if not has_descriptor:
        file.write(self.tobytes())
        return

    # This is a duplication of tobytes() for handling special cases
    array = self.numpy()
    if self.dtype in {
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
    }:
        # Pack the array into int4
        array = _type_casting.pack_4bitx2(array)
    else:
        assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
    if not _IS_LITTLE_ENDIAN:
        # astype() byte-swaps the data; view() would only relabel the dtype
        # and leave big-endian bytes in memory for tofile() to write out.
        array = array.astype(array.dtype.newbyteorder("<"))
    array.tofile(file)
|
||
|
||
class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors | ||
"""An immutable concrete tensor with its data store on disk. | ||
|
@@ -590,7 +625,7 @@ def __init__( | |
length: The length of the data in bytes. | ||
dtype: The data type of the tensor. | ||
shape: The shape of the tensor. | ||
name: The name of the tensor.. | ||
name: The name of the tensor. | ||
doc_string: The documentation string. | ||
metadata_props: The metadata properties. | ||
base_dir: The base directory for the external data. It is used to resolve relative paths. | ||
|
@@ -746,6 +781,18 @@ def tobytes(self) -> bytes: | |
length = self._length or self.nbytes | ||
return self.raw[offset : offset + length] | ||
|
||
def tofile(self, file) -> None:
    """Copy the tensor's bytes from its external data file into ``file``.

    The data is streamed in chunks of at most 1 MiB, so the whole tensor is
    never held in memory at once.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes.

    Raises:
        EOFError: If the external data file ends before the expected number
            of bytes has been read.
    """
    self._check_validity()
    chunk_size = 1024 * 1024  # 1MB
    with open(self.path, "rb") as src:
        if self._offset is not None:
            src.seek(self._offset)
        bytes_to_copy = self._length or self.nbytes

        while bytes_to_copy > 0:
            chunk = src.read(min(chunk_size, bytes_to_copy))
            if not chunk:
                # An empty read means end-of-file; without this check the
                # loop would spin forever on a truncated data file.
                raise EOFError(
                    f"Unexpected end of file '{self.path}': "
                    f"{bytes_to_copy} more bytes were expected"
                )
            file.write(chunk)
            bytes_to_copy -= len(chunk)
|
||
def valid(self) -> bool: | ||
"""Check if the tensor is valid. | ||
|
||
|
@@ -979,6 +1026,14 @@ def tobytes(self) -> bytes: | |
"""Return the bytes of the tensor.""" | ||
return self._evaluate().tobytes() | ||
|
||
def tofile(self, file) -> None: | ||
Review comment: I am wondering whether … (the rest of this sentence was lost when the page was extracted).
Reply: Can you say more?
Reply: I just thought the tensor is not even real until it is evaluated, so intuitively it does not seem well suited to tofile(), whose purpose is to write data to disk. But I guess the general expectation is that all tensors have this method, which is understandable.
Reply: It is actually useful: even when the tensor is lazily evaluated, we still want to avoid tobytes() making a copy of the tensor data before writing to file. The screenshots in the PR description show lazy tensors. |
||
"""Write the tensor to a binary file.""" | ||
tensor = self._evaluate() | ||
if hasattr(tensor, "tofile"): | ||
tensor.tofile(file) | ||
else: | ||
super().tofile(file) | ||
|
||
|
||
class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors | ||
"""A tensor that stores 4bit datatypes in packed format. | ||
|
@@ -1113,6 +1168,21 @@ def tobytes(self) -> bytes: | |
array = array.astype(array.dtype.newbyteorder("<")) | ||
return array.tobytes() | ||
|
||
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    When ``file`` exposes a usable OS-level file descriptor, the packed array
    is written with ``ndarray.tofile``. Otherwise the result of ``tobytes()``
    is passed to ``file.write``.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes,
            or one that has a working ``fileno()`` method.
    """
    # hasattr(file, "fileno") is not enough: in-memory streams such as
    # io.BytesIO define fileno() only to raise io.UnsupportedOperation
    # (a subclass of both OSError and ValueError), so actually call it.
    try:
        file.fileno()
        has_descriptor = True
    except (AttributeError, OSError, ValueError):
        has_descriptor = False

    if not has_descriptor:
        file.write(self.tobytes())
        return

    # This is a duplication of tobytes() for handling edge cases
    array = self.numpy_packed()
    if not _IS_LITTLE_ENDIAN:
        # astype() byte-swaps the data; view() would only relabel the dtype
        # and leave big-endian bytes in memory for tofile() to write out.
        array = array.astype(array.dtype.newbyteorder("<"))
    array.tofile(file)
|
||
|
||
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): | ||
"""Immutable symbolic dimension that can be shared across multiple shapes. | ||
|
Uh oh!
There was an error while loading. Please reload this page.