@@ -14,16 +14,18 @@
 """TypeHandlers supporting Pathways backend."""
 
 import collections
-from collections.abc import Sequence
+from collections.abc import Coroutine, Sequence
 import concurrent.futures
 import datetime
 import functools
 import logging
-import typing
+from typing import Any, cast
 
 import jax
 from orbax.checkpoint import future
 from orbax.checkpoint import type_handlers
+from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
+from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
 from pathwaysutils.persistence import helper
 
 
@@ -33,6 +35,7 @@
 SaveArgs = type_handlers.SaveArgs
 RestoreArgs = type_handlers.RestoreArgs
 ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
+ArrayMetadata = array_metadata_lib.ArrayMetadata
 
 
 def extract_parent_dir_and_name(
@@ -51,23 +54,30 @@ def __init__(
       self,
       read_timeout: datetime.timedelta | None = None,
       use_ocdbt: bool = False,
+      array_metadata_store: array_metadata_store_lib.Store | None = None,
   ):
     """Constructor.
 
     Args:
       read_timeout: Duration indicating the timeout for reading arrays
       use_ocdbt: allows using Tensorstore OCDBT driver.
+      array_metadata_store: An optional store for writing and reading array
+        metadata. Only required for saving new-style jax random keys.
     """
     self._read_timeout = read_timeout
 
     if use_ocdbt:
       raise ValueError("OCDBT not supported for Pathways.")
-    super().__init__()
+    super().__init__(array_metadata_store=array_metadata_store)
 
   async def _background_serialize(
       self,
       futures_results: Sequence[concurrent.futures.Future[None]],
+      metadata_coroutine: Coroutine[Any, Any, None] | None = None,
   ) -> None:
+    if metadata_coroutine:
+      await metadata_coroutine
+
     for future_result in futures_results:
       future_result.result()
 
@@ -90,14 +100,53 @@ async def serialize(
     if any([arg.dtype is not None for arg in args]):
       raise ValueError("Casting during save not supported for Pathways.")
 
+    array_metadatas = []
+    arrays = []
+    for v, info, arg in zip(values, infos, args):
+      if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
+        # a JAX random key
+        arrays.append(jax.random.key_data(v))
+        array_metadatas.append(
+            ArrayMetadata(
+                param_name=info.name,
+                shape=v.shape,
+                dtype=(arg.dtype if arg is not None else v.dtype),
+                write_shape=getattr(v, "local_shape", v.shape),
+                chunk_shape=getattr(v, "local_shape", v.shape),
+                use_ocdbt=False,
+                use_zarr3=False,
+                ext_metadata={
+                    array_metadata_lib.RANDOM_KEY_IMPL: str(
+                        jax.random.key_impl(v)
+                    )
+                },
+            )
+        )
+      else:
+        arrays.append(v)
+
+    metadata_coroutine = None
+    if array_metadatas:
+      if self._array_metadata_store is None:
+        raise ValueError(
+            "Array metadata store is not set with a checkpoint that requires"
+            f" it. Array metadata: {array_metadatas}"
+        )
+
+      metadata_coroutine = self._array_metadata_store.write(
+          checkpoint_dir=infos[0].parent_dir,
+          array_metadatas=array_metadatas,
+          process_index=0,
+      )
+
     self._wait_for_directory_creation_signals()
     locations, names = extract_parent_dir_and_name(infos)
     f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
-    futures_results = list(map(f, locations, names, values))
+    futures_results = list(map(f, locations, names, arrays))
 
     return [
         future.CommitFutureAwaitingContractedSignals(
-            self._background_serialize(futures_results),
+            self._background_serialize(futures_results, metadata_coroutine),
             name="cloud_pathways_array_handler",
         )
     ]
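
A minimal, self-contained sketch (not part of this diff) of the random-key
round trip that `serialize` above relies on; every call is a public JAX API,
and the variable names are illustrative only:

    import jax

    key = jax.random.key(42)
    # Typed keys carry a prng_key dtype, which is what serialize() detects.
    assert jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)

    raw = jax.random.key_data(key)   # plain uint32 array, safe to persist
    impl = jax.random.key_impl(key)  # implementation spec; the diff stores str(impl)

    # deserialize() reverses this using the implementation recorded in
    # ArrayMetadata.ext_metadata under RANDOM_KEY_IMPL.
    restored = jax.random.wrap_key_data(raw, impl=impl)
    assert restored.dtype == key.dtype
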
@@ -106,7 +155,7 @@ async def deserialize(
       self,
       infos: Sequence[ParamInfo],
       args: Sequence[RestoreArgs] | None = None,
-  ) -> Sequence[jax.Array]:
+  ) -> list[jax.Array]:
     """Uses Pathways Persistence API to deserialize a jax array."""
     if args is None:
       raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
@@ -125,7 +174,7 @@ async def deserialize(
             "To restore jax.Array, provide ArrayRestoreArgs; found"
             f" {type(arg).__name__}"
         )
-      arg = typing.cast(ArrayRestoreArgs, arg)
+      arg = cast(ArrayRestoreArgs, arg)
       if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
         raise ValueError(
             "Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -140,7 +189,7 @@ async def deserialize(
       else:
         if not isinstance(arg.sharding, jax.sharding.NamedSharding):
           raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
-        sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
+        sharding = cast(jax.sharding.NamedSharding, arg.sharding)
         global_meshes.append(sharding.mesh)
         mesh_axes.append(sharding.spec)
         shardings.append(sharding)
@@ -160,13 +209,30 @@ async def deserialize(
     ]
     dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]
 
+    array_metadatas_cache = {}
+    if self._array_metadata_store is not None:
+      array_metadatas = await self._array_metadata_store.read(
+          checkpoint_dir=infos[0].parent_dir,
+          process_index=0,
+      )
+      if not isinstance(array_metadatas, list):
+        raise ValueError(
+            "Array metadata store returned unexpected result:"
+            f" {array_metadatas}"
+        )
+
+      array_metadatas_cache = {
+          array_metadata.param_name: array_metadata
+          for array_metadata in array_metadatas
+      }
+
     # Group inputs by global_mesh so that we can perform batched Array
     # construction for each global_mesh.
     inputs_by_global_mesh = collections.defaultdict(list)
     for i, global_mesh in enumerate(global_meshes):
       inputs_by_global_mesh[global_mesh].append(i)
 
-    results = [None] * len(infos)
+    results = cast(list[jax.Array], [None] * len(infos))
 
     for global_mesh, idxs in inputs_by_global_mesh.items():
       grouped_infos = [infos[idx] for idx in idxs]
@@ -185,13 +251,26 @@ async def deserialize(
       )
       # each persistence call is awaited serially.
       read_future.result()
-      for idx, arr in zip(idxs, grouped_arrays):
+      for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
+        if meta := array_metadatas_cache.get(info.name):
+          assert isinstance(
+              meta, array_metadata_lib.SerializedArrayMetadata
+          ), f"Expecting SerializedArrayMetadata but got {type(meta)}."
+          assert isinstance(meta.ext_metadata, dict), (
+              "Expecting ext_metadata to be a dict but got"
+              f" {type(meta.ext_metadata)}."
+          )
+
+          if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL):
+            arr = jax.random.wrap_key_data(arr, impl=impl)
         results[idx] = arr
-    return results  # pytype: disable=bad-return-type
+
+    return results
 
 
 def register_pathways_handlers(
     read_timeout: datetime.timedelta | None = None,
+    array_metadata_store: array_metadata_store_lib.Store | None = None,
 ):
   """Function that must be called before saving or restoring with Pathways."""
   logger.debug(
@@ -201,6 +280,7 @@ def register_pathways_handlers(
       jax.Array,
       CloudPathwaysArrayHandler(
           read_timeout=read_timeout,
+          array_metadata_store=array_metadata_store,
      ),
      override=True,
  )
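
A hypothetical wiring example for the new parameter. The module path
`pathwaysutils.persistence.pathways_type_handlers` and the default
`array_metadata_store_lib.Store()` construction are assumptions for
illustration, not taken from this diff:

    import datetime

    from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
    from pathwaysutils.persistence import pathways_type_handlers  # hypothetical path

    pathways_type_handlers.register_pathways_handlers(
        read_timeout=datetime.timedelta(minutes=10),
        # Required when checkpoints contain new-style jax.random keys.
        array_metadata_store=array_metadata_store_lib.Store(),
    )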