diff --git a/docs/reference/toctree.txt b/docs/reference/toctree.txt
index 76e021447a..496b9ff538 100644
--- a/docs/reference/toctree.txt
+++ b/docs/reference/toctree.txt
@@ -38,6 +38,8 @@
     generated/ak.from_feather
     generated/ak.to_feather
     generated/ak.from_avro_file
+    generated/ak.to_safetensors
+    generated/ak.from_safetensors
 
 .. toctree::
     :caption: Conversions for machine learning
diff --git a/requirements-test-full.txt b/requirements-test-full.txt
index d3d934ef95..894277140e 100644
--- a/requirements-test-full.txt
+++ b/requirements-test-full.txt
@@ -7,4 +7,5 @@ pyarrow>=12.0.0;sys_platform != "win32" and python_version < "3.14"
 pytest>=6
 pytest-cov
 pytest-xdist
+safetensors>=0.6.2
 uproot>=5
diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py
index 54c8f9c28d..c2e8723686 100644
--- a/src/awkward/operations/__init__.py
+++ b/src/awkward/operations/__init__.py
@@ -47,6 +47,7 @@
 from awkward.operations.ak_from_raggedtensor import *
 from awkward.operations.ak_from_rdataframe import *
 from awkward.operations.ak_from_regular import *
+from awkward.operations.ak_from_safetensors import *
 from awkward.operations.ak_from_tensorflow import *
 from awkward.operations.ak_from_torch import *
 from awkward.operations.ak_full_like import *
@@ -105,6 +106,7 @@
 from awkward.operations.ak_to_raggedtensor import *
 from awkward.operations.ak_to_rdataframe import *
 from awkward.operations.ak_to_regular import *
+from awkward.operations.ak_to_safetensors import *
 from awkward.operations.ak_to_tensorflow import *
 from awkward.operations.ak_to_torch import *
 from awkward.operations.ak_transform import *
diff --git a/src/awkward/operations/ak_from_safetensors.py b/src/awkward/operations/ak_from_safetensors.py
new file mode 100644
index 0000000000..3eff1a5daf
--- /dev/null
+++ b/src/awkward/operations/ak_from_safetensors.py
@@ -0,0 +1,153 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+
+from __future__ import annotations
+
+import fsspec
+
+import awkward as ak
+from awkward._dispatch import high_level_function
+
+__all__ = ("from_safetensors",)
+
+
+@high_level_function()
+def from_safetensors(
+    source,
+    *,
+    storage_options=None,
+    virtual=False,
+    # ak.from_buffers kwargs
+    buffer_key="{form_key}-{attribute}",
+    backend="cpu",
+    byteorder="<",
+    allow_noncanonical_form=False,
+    highlevel=True,
+    behavior=None,
+    attrs=None,
+):
+    """
+    Args:
+        source (path-like): Name of the input file, file path, or
+            remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
+            for remote reading.
+        storage_options (None or dict): Any additional options to pass to
+            [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
+            to open a remote file for reading.
+        virtual (bool, optional): If True, create a virtual (lazy) Awkward Array
+            that references buffers without materializing them. Defaults to False.
+        buffer_key (str, optional): Template for buffer names, with placeholders
+            `{form_key}` and `{attribute}`. Defaults to "{form_key}-{attribute}".
+        backend (str, optional): Backend identifier (e.g., "cpu"). Defaults to "cpu".
+        byteorder (str, optional): Byte order, "<" (little-endian, default) or ">".
+        allow_noncanonical_form (bool, optional): If True, normalize
+            safetensors forms that do not directly match Awkward. Defaults to False.
+        highlevel (bool, optional): If True, return a high-level ak.Array. If False,
+            return the low-level layout. Defaults to True.
+        behavior (Mapping | None, optional): Optional Awkward behavior mapping.
+        attrs (Mapping | None, optional): Optional metadata to attach to the array.
+
+    Returns:
+        ak.Array or ak.contents.Content: An Awkward Array (or layout) reconstructed
+            from the safetensors buffers.
+
+    Load a safetensors file as an Awkward Array.
+
+    Ref: https://huggingface.co/docs/safetensors/.
+
+    This function reads data serialized in the safetensors format and reconstructs
+    an Awkward Array (or low-level layout) from it. Buffers in the safetensors file
+    are mapped to Awkward buffers according to the `buffer_key` template, and
+    optional behavior or attributes can be attached to the returned array.
+
+    The safetensors file **must contain** `form` and `length` entries in its
+    metadata, which define the structure and length of the reconstructed array.
+
+    Example:
+
+        >>> import awkward as ak
+        >>> arr = ak.from_safetensors("out.safetensors")
+        >>> arr  # doctest: +SKIP
+        <Array [[1, 2, 3], [], [4]] type='3 * var * int64'>
+
+    Create a virtual (lazy) array that references buffers without materializing them:
+
+        >>> virtual_arr = ak.from_safetensors("out.safetensors", virtual=True)
+        >>> virtual_arr  # doctest: +SKIP
+
+
+
+    See also #ak.to_safetensors.
+    """
+    # Implementation
+    return _impl(
+        source,
+        storage_options,
+        virtual,
+        buffer_key,
+        backend,
+        byteorder,
+        allow_noncanonical_form,
+        highlevel,
+        behavior,
+        attrs,
+    )
+
+
+def _impl(
+    source,
+    storage_options,
+    virtual,
+    buffer_key,
+    backend,
+    byteorder,
+    allow_noncanonical_form,
+    highlevel,
+    behavior,
+    attrs,
+):
+    try:
+        from safetensors import _safe_open_handle
+    except ImportError as err:
+        raise ImportError(
+            """to use ak.from_safetensors, you must install the 'safetensors' package with:
+
+        pip install safetensors
+or
+        conda install -c huggingface safetensors"""
+        ) from err
+
+    fs, source = fsspec.core.url_to_fs(source, **(storage_options or {}))
+
+    buffers = {}
+
+    def maybe_virtualize(x):
+        # NOTE: x has already been read from the file by the caller; virtual=True
+        # only wraps it in a zero-argument callable for ak.from_buffers.
+        return (lambda: x) if virtual else x
+
+    with fs.open(source, "rb") as f:
+        with _safe_open_handle(f, framework="np") as g:
+            metadata = g.metadata() or {}  # None when the file has no metadata
+            for k in g.offset_keys():
+                buffers[k] = maybe_virtualize(g.get_tensor(k))
+
+    if "form" not in metadata or "length" not in metadata:
+        raise RuntimeError(
+            "Missing required metadata in safetensors file: 'form' and 'length' are required."
+        )
+    form = ak.forms.from_json(metadata["form"])
+    length = int(metadata["length"])
+
+    # reconstruct array
+    return ak.ak_from_buffers._impl(
+        form,
+        length,
+        buffers,
+        buffer_key=buffer_key,
+        backend=backend,
+        byteorder=byteorder,
+        simplify=allow_noncanonical_form,
+        highlevel=highlevel,
+        behavior=behavior,
+        attrs=attrs,
+    )
diff --git a/src/awkward/operations/ak_to_safetensors.py b/src/awkward/operations/ak_to_safetensors.py
new file mode 100644
index 0000000000..3282b4aa20
--- /dev/null
+++ b/src/awkward/operations/ak_to_safetensors.py
@@ -0,0 +1,147 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+
+from __future__ import annotations
+
+import fsspec
+
+import awkward as ak
+from awkward._dispatch import high_level_function
+from awkward._layout import HighLevelContext
+
+__all__ = ("to_safetensors",)
+
+
+@high_level_function()
+def to_safetensors(
+    array,
+    destination,
+    *,
+    storage_options=None,
+    # ak.to_buffers kwargs
+    container=None,
+    buffer_key="{form_key}-{attribute}",
+    form_key="node{id}",
+    id_start=0,
+    backend=None,
+    byteorder=ak._util.native_byteorder,
+):
+    """
+    Args:
+        array: An Awkward Array or array-like object to serialize.
+        destination (path-like): Name of the output file, file path, or
+            remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
+            for remote writing.
+        storage_options (None or dict): Any additional options to pass to
+            [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
+            to open a remote file for writing.
+        container (dict, optional): Optional mapping to receive the generated buffer
+            bytes. If None (default), a temporary container is used and discarded
+            after writing.
+        buffer_key (str, optional): Format string for naming buffers. May include
+            `{form_key}` and `{attribute}` placeholders. Defaults to
+            `"{form_key}-{attribute}"`.
+        form_key (str, optional): Format string for node forms when generating buffer
+            keys. Typically includes `"{id}"`. Defaults to `"node{id}"`.
+        id_start (int, optional): Starting index for node numbering. Defaults to `0`.
+        backend (str | object, optional): Backend used to convert array data into
+            buffers. If None, the default backend is used.
+        byteorder (str, optional): Byte order for numeric buffers. Defaults to the
+            system's native byte order.
+
+    Returns:
+        None
+            This function writes the safetensors file to `destination`. If
+            `container` is provided, it will be populated with the raw buffer bytes.
+
+    Serialize an Awkward Array to the safetensors format and write it to `destination`.
+
+    Ref: https://huggingface.co/docs/safetensors/.
+
+    This function converts the provided Awkward Array (or array-like object) into raw
+    buffers via `ak.to_buffers` and stores them in the safetensors format. Buffer names
+    are generated from `buffer_key` and `form_key` templates, allowing downstream
+    compatibility or layout reuse.
+    The resulting safetensors file includes metadata containing the Awkward `form` and
+    array `length`, which are required for `ak.from_safetensors` to reconstruct the array.
+
+    Example:
+
+        >>> import awkward as ak
+        >>> arr = ak.Array([[1, 2, 3], [], [4]])
+        >>> ak.to_safetensors(arr, "out.safetensors")
+
+
+    See also #ak.from_safetensors.
+    """
+    # Implementation
+    return _impl(
+        array,
+        destination,
+        storage_options,
+        container,
+        buffer_key,
+        form_key,
+        id_start,
+        backend,
+        byteorder,
+    )
+
+
+def _impl(
+    array,
+    destination,
+    storage_options,
+    container,
+    buffer_key,
+    form_key,
+    id_start,
+    backend,
+    byteorder,
+):
+    try:
+        from safetensors.numpy import save
+    except ImportError as err:
+        raise ImportError(
+            """to use ak.to_safetensors, you must install the 'safetensors' package with:
+
+        pip install safetensors
+or
+        conda install -c huggingface safetensors"""
+        ) from err
+
+    fs, destination = fsspec.core.url_to_fs(destination, **(storage_options or {}))
+
+    with HighLevelContext(behavior=None, attrs=None) as ctx:
+        layout = ctx.unwrap(array, allow_record=True, primitive_policy="error")
+
+    layout = ak.ak_to_packed._impl(
+        layout,
+        highlevel=False,  # doesn't matter, but we can avoid extra wrapping/unwrapping
+        behavior=ctx.behavior,
+        attrs=ctx.attrs,
+    )
+
+    form, length, buffers = ak.ak_to_buffers._impl(
+        layout,
+        container=container,
+        buffer_key=buffer_key,
+        form_key=form_key,
+        id_start=id_start,
+        backend=backend,
+        byteorder=byteorder,
+    )
+
+    metadata = {
+        "form": form.to_json(),
+        "length": str(length),
+    }
+
+    byts = save(buffers, metadata)
+    # save
+    try:
+        with fs.open(destination, "wb") as f:
+            f.write(byts)
+    except Exception as err:
+        raise RuntimeError(
+            f"Failed to write safetensors file to '{destination}': {err}"
+        ) from err
diff --git a/tests/test_3685_to_from_safetensors.py b/tests/test_3685_to_from_safetensors.py
new file mode 100644
index 0000000000..532311475b
--- /dev/null
+++ b/tests/test_3685_to_from_safetensors.py
@@ -0,0 +1,44 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+#
+from __future__ import annotations
+
+import pytest
+
+safetensors = pytest.importorskip("safetensors")
+
+
+def test_roundtrip(tmp_path):
+    import awkward as ak
+
+    array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]])
+
+    # tmp_path keeps the working directory clean even when an assertion fails
+    path = str(tmp_path / "test.safetensors")
+    ak.to_safetensors(array, path)
+
+    loaded = ak.from_safetensors(path)
+    virtual_loaded = ak.from_safetensors(path, virtual=True)
+
+    assert array.layout.is_equal_to(loaded.layout, all_parameters=True)
+    assert array.layout.is_equal_to(
+        virtual_loaded.layout.materialize(), all_parameters=True
+    )
+
+
+def test_virtual_array_to_safetensors(tmp_path):
+    import awkward as ak
+
+    array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]])
+
+    first = str(tmp_path / "test_virtual0.safetensors")
+    second = str(tmp_path / "test_virtual1.safetensors")
+
+    ak.to_safetensors(array, first)
+    virtual_loaded = ak.from_safetensors(first, virtual=True)
+
+    ak.to_safetensors(virtual_loaded, second)
+    loaded = ak.from_safetensors(second, virtual=False)
+
+    assert virtual_loaded.layout.materialize().is_equal_to(
+        loaded.layout, all_parameters=True
+    )