@@ -49,16 +49,19 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
49
49
50
50
def __init__ (
51
51
self ,
52
- read_timeout : datetime .timedelta | None = None ,
52
+ timeout : datetime .timedelta | None = None ,
53
53
use_ocdbt : bool = False ,
54
54
):
55
- """Constructor .
55
+ """Orbax array handler for Pathways on Cloud with Persistence API .
56
56
57
57
Args:
58
- read_timeout: Duration indicating the timeout for reading arrays
58
+ timeout: Duration indicating the timeout for reading and writing arrays.
59
+ Default is 1 hour.
59
60
use_ocdbt: allows using Tensorstore OCDBT driver.
60
61
"""
61
- self ._read_timeout = read_timeout
62
+ if timeout is None :
63
+ timeout = datetime .timedelta (hours = 1 )
64
+ self .timeout = timeout
62
65
63
66
if use_ocdbt :
64
67
raise ValueError ("OCDBT not supported for Pathways." )
@@ -92,7 +95,7 @@ async def serialize(
92
95
93
96
self ._wait_for_directory_creation_signals ()
94
97
locations , names = extract_parent_dir_and_name (infos )
95
- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
98
+ f = functools .partial (helper .write_one_array , timeout = self .timeout )
96
99
futures_results = list (map (f , locations , names , values ))
97
100
98
101
return [
@@ -181,7 +184,7 @@ async def deserialize(
181
184
grouped_global_shapes ,
182
185
grouped_shardings ,
183
186
global_mesh .devices ,
184
- timeout = self ._read_timeout ,
187
+ timeout = self .timeout ,
185
188
)
186
189
# each persistence call is awaited serially.
187
190
read_future .result ()
@@ -191,7 +194,7 @@ async def deserialize(
191
194
192
195
193
196
def register_pathways_handlers (
194
- read_timeout : datetime .timedelta | None = None ,
197
+ timeout : datetime .timedelta | None = None ,
195
198
):
196
199
"""Function that must be called before saving or restoring with Pathways."""
197
200
logger .debug (
@@ -200,7 +203,7 @@ def register_pathways_handlers(
200
203
type_handlers .register_type_handler (
201
204
jax .Array ,
202
205
CloudPathwaysArrayHandler (
203
- read_timeout = read_timeout ,
206
+ timeout = timeout ,
204
207
),
205
208
override = True ,
206
209
)
0 commit comments