Skip to content

Commit 383142f

Browse files
lukebaumanncopybara-github
authored andcommitted
Support jax.random.PRNGKey serialization in Pathways Orbax handler.
This change allows `CloudPathwaysArrayHandler` to correctly save and restore `jax.random.PRNGKey` objects by extracting and wrapping the key data, and storing metadata about the key implementation using an `ArrayMetadataStore`. This change introduces a dependency on Orbax's internal API. PiperOrigin-RevId: 813796155
1 parent 2fa0623 commit 383142f

File tree

2 files changed

+96
-12
lines changed

2 files changed

+96
-12
lines changed

pathwaysutils/_initialize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
import jax
21+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2122
from pathwaysutils import profiling
2223
from pathwaysutils import proxy_backend
2324
from pathwaysutils.persistence import orbax_handler
@@ -92,7 +93,10 @@ def initialize() -> None:
9293
profiling.monkey_patch_jax()
9394
# TODO: b/365549911 - Remove when OCDBT-compatible
9495
if _is_persistence_enabled():
95-
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
96+
orbax_handler.register_pathways_handlers(
97+
read_timeout=datetime.timedelta(hours=1),
98+
array_metadata_store=array_metadata_store_lib.Store(),
99+
)
96100

97101
# Turn off JAX compilation cache because Pathways handles its own
98102
# compilation cache.

pathwaysutils/persistence/orbax_handler.py

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
"""TypeHandlers supporting Pathways backend."""
1515

1616
import collections
17-
from collections.abc import Sequence
17+
from collections.abc import Coroutine, Sequence
1818
import concurrent.futures
1919
import datetime
2020
import functools
2121
import logging
22-
import typing
22+
from typing import Any, cast
2323

2424
import jax
2525
from orbax.checkpoint import future
2626
from orbax.checkpoint import type_handlers
27+
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
28+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2729
from pathwaysutils.persistence import helper
2830

2931

@@ -33,6 +35,7 @@
3335
SaveArgs = type_handlers.SaveArgs
3436
RestoreArgs = type_handlers.RestoreArgs
3537
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
38+
ArrayMetadata = array_metadata_lib.ArrayMetadata
3639

3740

3841
def extract_parent_dir_and_name(
@@ -51,23 +54,30 @@ def __init__(
5154
self,
5255
read_timeout: datetime.timedelta | None = None,
5356
use_ocdbt: bool = False,
57+
array_metadata_store: array_metadata_store_lib.Store | None = None,
5458
):
5559
"""Constructor.
5660
5761
Args:
5862
read_timeout: Duration indicating the timeout for reading arrays
5963
use_ocdbt: allows using Tensorstore OCDBT driver.
64+
array_metadata_store: An optional store for writing and reading array
65+
metadata. Only required for saving new-style jax random keys.
6066
"""
6167
self._read_timeout = read_timeout
6268

6369
if use_ocdbt:
6470
raise ValueError("OCDBT not supported for Pathways.")
65-
super().__init__()
71+
super().__init__(array_metadata_store=array_metadata_store)
6672

6773
async def _background_serialize(
6874
self,
6975
futures_results: Sequence[concurrent.futures.Future[None]],
76+
metadata_coroutine: Coroutine[Any, Any, None] | None = None,
7077
) -> None:
78+
if metadata_coroutine:
79+
await metadata_coroutine
80+
7181
for future_result in futures_results:
7282
future_result.result()
7383

@@ -90,14 +100,53 @@ async def serialize(
90100
if any([arg.dtype is not None for arg in args]):
91101
raise ValueError("Casting during save not supported for Pathways.")
92102

103+
array_metadatas = []
104+
arrays = []
105+
for v, info, arg in zip(values, infos, args):
106+
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
107+
# a JAX random key
108+
arrays.append(jax.random.key_data(v))
109+
array_metadatas.append(
110+
ArrayMetadata(
111+
param_name=info.name,
112+
shape=v.shape,
113+
dtype=(arg.dtype if arg is not None else v.dtype),
114+
write_shape=getattr(v, "local_shape", v.shape),
115+
chunk_shape=getattr(v, "local_shape", v.shape),
116+
use_ocdbt=False,
117+
use_zarr3=False,
118+
ext_metadata={
119+
array_metadata_lib.RANDOM_KEY_IMPL: str(
120+
jax.random.key_impl(v)
121+
)
122+
},
123+
)
124+
)
125+
else:
126+
arrays.append(v)
127+
128+
metadata_coroutine = None
129+
if array_metadatas:
130+
if self._array_metadata_store is None:
131+
raise ValueError(
132+
"Array metadata store is not set with a checkpoint that requires"
133+
f" it. Array metadata: {array_metadatas}"
134+
)
135+
136+
metadata_coroutine = self._array_metadata_store.write(
137+
checkpoint_dir=infos[0].parent_dir,
138+
array_metadatas=array_metadatas,
139+
process_index=0,
140+
)
141+
93142
self._wait_for_directory_creation_signals()
94143
locations, names = extract_parent_dir_and_name(infos)
95144
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
96-
futures_results = list(map(f, locations, names, values))
145+
futures_results = list(map(f, locations, names, arrays))
97146

98147
return [
99148
future.CommitFutureAwaitingContractedSignals(
100-
self._background_serialize(futures_results),
149+
self._background_serialize(futures_results, metadata_coroutine),
101150
name="cloud_pathways_array_handler",
102151
)
103152
]
@@ -106,7 +155,7 @@ async def deserialize(
106155
self,
107156
infos: Sequence[ParamInfo],
108157
args: Sequence[RestoreArgs] | None = None,
109-
) -> Sequence[jax.Array]:
158+
) -> list[jax.Array]:
110159
"""Uses Pathways Persistence API to deserialize a jax array."""
111160
if args is None:
112161
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
@@ -125,7 +174,7 @@ async def deserialize(
125174
"To restore jax.Array, provide ArrayRestoreArgs; found"
126175
f" {type(arg).__name__}"
127176
)
128-
arg = typing.cast(ArrayRestoreArgs, arg)
177+
arg = cast(ArrayRestoreArgs, arg)
129178
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
130179
raise ValueError(
131180
"Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -140,7 +189,7 @@ async def deserialize(
140189
else:
141190
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
142191
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
143-
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
192+
sharding = cast(jax.sharding.NamedSharding, arg.sharding)
144193
global_meshes.append(sharding.mesh)
145194
mesh_axes.append(sharding.spec)
146195
shardings.append(sharding)
@@ -160,13 +209,30 @@ async def deserialize(
160209
]
161210
dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]
162211

212+
array_metadatas_cache = {}
213+
if self._array_metadata_store is not None:
214+
array_metadatas = await self._array_metadata_store.read(
215+
checkpoint_dir=infos[0].parent_dir,
216+
process_index=0,
217+
)
218+
if not isinstance(array_metadatas, list):
219+
raise ValueError(
220+
"Array metadata store returned unexpected result:"
221+
f" {array_metadatas}"
222+
)
223+
224+
array_metadatas_cache = {
225+
array_metadata.param_name: array_metadata
226+
for array_metadata in array_metadatas
227+
}
228+
163229
# Group inputs by global_mesh so that we can perform batched Array
164230
# construction for each global_mesh.
165231
inputs_by_global_mesh = collections.defaultdict(list)
166232
for i, global_mesh in enumerate(global_meshes):
167233
inputs_by_global_mesh[global_mesh].append(i)
168234

169-
results = [None] * len(infos)
235+
results = cast(list[jax.Array], [None] * len(infos))
170236

171237
for global_mesh, idxs in inputs_by_global_mesh.items():
172238
grouped_infos = [infos[idx] for idx in idxs]
@@ -185,13 +251,26 @@ async def deserialize(
185251
)
186252
# each persistence call is awaited serially.
187253
read_future.result()
188-
for idx, arr in zip(idxs, grouped_arrays):
254+
for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
255+
if meta := array_metadatas_cache.get(info.name):
256+
assert isinstance(
257+
meta, array_metadata_lib.SerializedArrayMetadata
258+
), f"Expecting SerializedArrayMetadata but got {type(meta)}."
259+
assert isinstance(meta.ext_metadata, dict), (
260+
"Expecting ext_metadata to be a dict but got"
261+
f" {type(meta.ext_metadata)}."
262+
)
263+
264+
if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL):
265+
arr = jax.random.wrap_key_data(arr, impl=impl)
189266
results[idx] = arr
190-
return results # pytype: disable=bad-return-type
267+
268+
return results
191269

192270

193271
def register_pathways_handlers(
194272
read_timeout: datetime.timedelta | None = None,
273+
array_metadata_store: array_metadata_store_lib.Store | None = None,
195274
):
196275
"""Function that must be called before saving or restoring with Pathways."""
197276
logger.debug(
@@ -201,6 +280,7 @@ def register_pathways_handlers(
201280
jax.Array,
202281
CloudPathwaysArrayHandler(
203282
read_timeout=read_timeout,
283+
array_metadata_store=array_metadata_store,
204284
),
205285
override=True,
206286
)

0 commit comments

Comments
 (0)