Skip to content

Commit 4f34af8

Browse files
pfackeldeypre-commit-ci[bot]ianna
authored
feat: add to_/from_safetensors (#3685)
* feat: add to_/from_safetensors * style: pre-commit fixes * satisfy pre-commit * add test * satisfy pylint too * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_from_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * Update src/awkward/operations/ak_to_safetensors.py Co-authored-by: Ianna Osborne <[email protected]> * address remaining comments * make sure arrays are packed before serializing to safetensors * use fsspec to allow remote writing and reading --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ianna Osborne <[email protected]>
1 parent 3b9d8f6 commit 4f34af8

File tree

6 files changed

+352
-0
lines changed

6 files changed

+352
-0
lines changed

docs/reference/toctree.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
generated/ak.from_feather
3939
generated/ak.to_feather
4040
generated/ak.from_avro_file
41+
generated/ak.to_safetensors
42+
generated/ak.from_safetensors
4143

4244
.. toctree::
4345
:caption: Conversions for machine learning

requirements-test-full.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ pyarrow>=12.0.0;sys_platform != "win32" and python_version < "3.14"
77
pytest>=6
88
pytest-cov
99
pytest-xdist
10+
safetensors>=0.6.2
1011
uproot>=5

src/awkward/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from awkward.operations.ak_from_raggedtensor import *
4848
from awkward.operations.ak_from_rdataframe import *
4949
from awkward.operations.ak_from_regular import *
50+
from awkward.operations.ak_from_safetensors import *
5051
from awkward.operations.ak_from_tensorflow import *
5152
from awkward.operations.ak_from_torch import *
5253
from awkward.operations.ak_full_like import *
@@ -105,6 +106,7 @@
105106
from awkward.operations.ak_to_raggedtensor import *
106107
from awkward.operations.ak_to_rdataframe import *
107108
from awkward.operations.ak_to_regular import *
109+
from awkward.operations.ak_to_safetensors import *
108110
from awkward.operations.ak_to_tensorflow import *
109111
from awkward.operations.ak_to_torch import *
110112
from awkward.operations.ak_transform import *
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import fsspec
6+
7+
import awkward as ak
8+
from awkward._dispatch import high_level_function
9+
10+
__all__ = ("from_safetensors",)
11+
12+
13+
@high_level_function()
14+
def from_safetensors(
15+
source,
16+
*,
17+
storage_options=None,
18+
virtual=False,
19+
# ak.from_buffers kwargs
20+
buffer_key="{form_key}-{attribute}",
21+
backend="cpu",
22+
byteorder="<",
23+
allow_noncanonical_form=False,
24+
highlevel=True,
25+
behavior=None,
26+
attrs=None,
27+
):
28+
"""
29+
Args:
30+
source (path-like): Name of the input file, file path, or
31+
remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
32+
for remote reading.
33+
storage_options (None or dict): Any additional options to pass to
34+
[fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
35+
to open a remote file for reading.
36+
virtual (bool, optional): If True, create a virtual (lazy) Awkward Array
37+
that references buffers without materializing them. Defaults to False.
38+
buffer_key (str, optional): Template for buffer names, with placeholders
39+
`{form_key}` and `{attribute}`. Defaults to "{form_key}-{attribute}".
40+
backend (str, optional): Backend identifier (e.g., "cpu"). Defaults to "cpu".
41+
byteorder (str, optional): Byte order, "<" (little-endian, default) or ">".
42+
allow_noncanonical_form (bool, optional): If True, normalize
43+
safetensors forms that do not directly match Awkward. Defaults to False.
44+
highlevel (bool, optional): If True, return a high-level ak.Array. If False,
45+
return the low-level layout. Defaults to True.
46+
behavior (Mapping | None, optional): Optional Awkward behavior mapping.
47+
attrs (Mapping | None, optional): Optional metadata to attach to the array.
48+
49+
Returns:
50+
ak.Array or ak.layout.Content: An Awkward Array (or layout) reconstructed
51+
from the safetensors buffers.
52+
53+
Load a safetensors file as an Awkward Array.
54+
55+
Ref: https://huggingface.co/docs/safetensors/.
56+
57+
This function reads data serialized in the safetensors format and reconstructs
58+
an Awkward Array (or low-level layout) from it. Buffers in the safetensors file
59+
are mapped to Awkward buffers according to the `buffer_key` template, and
60+
optional behavior or attributes can be attached to the returned array.
61+
62+
The safetensors file **must contain** `form` and `length` entries in its
63+
metadata, which define the structure and length of the reconstructed array.
64+
65+
Example:
66+
67+
>>> import awkward as ak
68+
>>> arr = ak.from_safetensors("out.safetensors")
69+
>>> arr # doctest: +SKIP
70+
<Array [[1, 2, 3], [], [4]] type='3 * var * int64'>
71+
72+
Create a virtual (lazy) array that references buffers without materializing them:
73+
74+
>>> virtual_arr = ak.from_safetensors("out.safetensors", virtual=True)
75+
>>> virtual_arr # doctest: +SKIP
76+
<Array [??, ??, ??] type='3 * var * int64'>
77+
78+
79+
See also #ak.to_safetensors.
80+
"""
81+
# Implementation
82+
return _impl(
83+
source,
84+
storage_options,
85+
virtual,
86+
buffer_key,
87+
backend,
88+
byteorder,
89+
allow_noncanonical_form,
90+
highlevel,
91+
behavior,
92+
attrs,
93+
)
94+
95+
96+
def _impl(
97+
source,
98+
storage_options,
99+
virtual,
100+
buffer_key,
101+
backend,
102+
byteorder,
103+
allow_noncanonical_form,
104+
highlevel,
105+
behavior,
106+
attrs,
107+
):
108+
try:
109+
from safetensors import _safe_open_handle
110+
except ImportError as err:
111+
raise ImportError(
112+
"""to use ak.from_tensorflow, you must install the 'safetensors' package with:
113+
114+
pip install safetensors
115+
or
116+
conda install -c huggingface safetensors"""
117+
) from err
118+
119+
fs, source = fsspec.core.url_to_fs(source, **(storage_options or {}))
120+
121+
buffers = {}
122+
123+
def maybe_virtualize(x):
124+
return (lambda: x) if virtual else x
125+
126+
with fs.open(source, "rb") as f:
127+
with _safe_open_handle(f, framework="np") as g:
128+
metadata = g.metadata()
129+
for k in g.offset_keys():
130+
buffers[k] = maybe_virtualize(g.get_tensor(k))
131+
132+
if "form" not in metadata or "length" not in metadata:
133+
raise RuntimeError(
134+
"Missing required metadata in safetensors file: 'form' and 'length' are required."
135+
)
136+
form = ak.forms.from_json(metadata["form"])
137+
length = int(metadata["length"])
138+
139+
# reconstruct array
140+
return ak.ak_from_buffers._impl(
141+
form,
142+
length,
143+
buffers,
144+
buffer_key=buffer_key,
145+
backend=backend,
146+
byteorder=byteorder,
147+
simplify=allow_noncanonical_form,
148+
highlevel=highlevel,
149+
behavior=behavior,
150+
attrs=attrs,
151+
)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import fsspec
6+
7+
import awkward as ak
8+
from awkward._dispatch import high_level_function
9+
from awkward._layout import HighLevelContext
10+
11+
__all__ = ("to_safetensors",)
12+
13+
14+
@high_level_function()
15+
def to_safetensors(
16+
array,
17+
destination,
18+
*,
19+
storage_options=None,
20+
# ak.to_buffers kwargs
21+
container=None,
22+
buffer_key="{form_key}-{attribute}",
23+
form_key="node{id}",
24+
id_start=0,
25+
backend=None,
26+
byteorder=ak._util.native_byteorder,
27+
):
28+
"""
29+
Args:
30+
array: An Awkward Array or array-like object to serialize.
31+
destination (path-like): Name of the output file, file path, or
32+
remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
33+
for remote writing.
34+
storage_options (None or dict): Any additional options to pass to
35+
[fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
36+
to open a remote file for writing.
37+
container (dict, optional): Optional mapping to receive the generated buffer
38+
bytes. If None (default), a temporary container is used and discarded
39+
after writing.
40+
buffer_key (str, optional): Format string for naming buffers. May include
41+
`{form_key}` and `{attribute}` placeholders. Defaults to
42+
`"{form_key}-{attribute}"`.
43+
form_key (str, optional): Format string for node forms when generating buffer
44+
keys. Typically includes `"{id}"`. Defaults to `"node{id}"`.
45+
id_start (int, optional): Starting index for node numbering. Defaults to `0`.
46+
backend (str | object, optional): Backend used to convert array data into
47+
buffers. If None, the default backend is used.
48+
byteorder (str, optional): Byte order for numeric buffers. Defaults to the
49+
system's native byte order.
50+
51+
Returns:
52+
None
53+
This function writes the safetensors file to `destination`. If
54+
`container` is provided, it will be populated with the raw buffer bytes.
55+
56+
Serialize an Awkward Array to the safetensors format and write it to `destination`.
57+
58+
Ref: https://huggingface.co/docs/safetensors/.
59+
60+
This function converts the provided Awkward Array (or array-like object) into raw
61+
buffers via `ak.to_buffers` and stores them in the safetensors format. Buffer names
62+
are generated from `buffer_key` and `form_key` templates, allowing downstream
63+
compatibility or layout reuse.
64+
The resulting safetensors file includes metadata containing the Awkward `form` and
65+
array `length`, which are required for `ak.from_safetensors` to reconstruct the array.
66+
67+
Example:
68+
69+
>>> import awkward as ak
70+
>>> arr = ak.Array([[1, 2, 3], [], [4]])
71+
>>> ak.to_safetensors(arr, "out.safetensors")
72+
73+
74+
See also #ak.from_safetensors.
75+
"""
76+
# Implementation
77+
return _impl(
78+
array,
79+
destination,
80+
storage_options,
81+
container,
82+
buffer_key,
83+
form_key,
84+
id_start,
85+
backend,
86+
byteorder,
87+
)
88+
89+
90+
def _impl(
91+
array,
92+
destination,
93+
storage_options,
94+
container,
95+
buffer_key,
96+
form_key,
97+
id_start,
98+
backend,
99+
byteorder,
100+
):
101+
try:
102+
from safetensors.numpy import save
103+
except ImportError as err:
104+
raise ImportError(
105+
"""to use ak.to_safetensors, you must install the 'safetensors' package with:
106+
107+
pip install safetensors
108+
or
109+
conda install -c huggingface safetensors"""
110+
) from err
111+
112+
fs, destination = fsspec.core.url_to_fs(destination, **(storage_options or {}))
113+
114+
with HighLevelContext(behavior=None, attrs=None) as ctx:
115+
layout = ctx.unwrap(array, allow_record=True, primitive_policy="error")
116+
117+
layout = ak.ak_to_packed._impl(
118+
layout,
119+
highlevel=False, # doesn't matter, but we can avoid extra wrapping/unwrapping
120+
behavior=ctx.behavior,
121+
attrs=ctx.attrs,
122+
)
123+
124+
form, length, buffers = ak.ak_to_buffers._impl(
125+
layout,
126+
container=container,
127+
buffer_key=buffer_key,
128+
form_key=form_key,
129+
id_start=id_start,
130+
backend=backend,
131+
byteorder=byteorder,
132+
)
133+
134+
metadata = {
135+
"form": form.to_json(),
136+
"length": str(length),
137+
}
138+
139+
byts = save(buffers, metadata)
140+
# save
141+
try:
142+
with fs.open(destination, "wb") as f:
143+
f.write(byts)
144+
except Exception as err:
145+
raise RuntimeError(
146+
f"Failed to write safetensors file to '{destination}': {err}"
147+
) from err
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
#
3+
from __future__ import annotations
4+
5+
import os
6+
7+
import pytest
8+
9+
safetensors = pytest.importorskip("safetensors")
10+
11+
12+
def test_roundtrip():
13+
import awkward as ak
14+
15+
array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]])
16+
17+
path = "./test.safetensors"
18+
ak.to_safetensors(array, path)
19+
20+
loaded = ak.from_safetensors(path)
21+
virtual_loaded = ak.from_safetensors(path, virtual=True)
22+
23+
os.remove(path)
24+
25+
assert array.layout.is_equal_to(loaded.layout, all_parameters=True)
26+
assert array.layout.is_equal_to(
27+
virtual_loaded.layout.materialize(), all_parameters=True
28+
)
29+
30+
31+
def test_virtual_array_to_safetensors():
32+
import awkward as ak
33+
34+
array = ak.Array([[1, 2, 3], [], [4, 5], [6], [7, 8, 9, 10]])
35+
36+
path = "./test_virtual{}.safetensors".format
37+
38+
ak.to_safetensors(array, path(0))
39+
virtual_loaded = ak.from_safetensors(path(0), virtual=True)
40+
41+
ak.to_safetensors(virtual_loaded, path(1))
42+
loaded = ak.from_safetensors(path(1), virtual=False)
43+
44+
os.remove(path(0))
45+
os.remove(path(1))
46+
47+
assert virtual_loaded.layout.materialize().is_equal_to(
48+
loaded.layout, all_parameters=True
49+
)

0 commit comments

Comments
 (0)