Skip to content

Commit 130cded

Browse files
lukebaumanncopybara-github
authored andcommitted
Rename read_timeout to timeout in Orbax handler.
This change renames the `read_timeout` parameter to a more general `timeout` in `CloudPathwaysArrayHandler` and `register_pathways_handlers`. A default timeout of 1 hour is also added within the handler's constructor. PiperOrigin-RevId: 813431195
1 parent 2fa0623 commit 130cded

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pathwaysutils/_initialize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def initialize() -> None:
9292
profiling.monkey_patch_jax()
9393
# TODO: b/365549911 - Remove when OCDBT-compatible
9494
if _is_persistence_enabled():
95-
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
95+
orbax_handler.register_pathways_handlers(
96+
timeout=datetime.timedelta(hours=1),
97+
)
9698

9799
# Turn off JAX compilation cache because Pathways handles its own
98100
# compilation cache.

pathwaysutils/persistence/orbax_handler.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,19 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4949

5050
def __init__(
5151
self,
52-
read_timeout: datetime.timedelta | None = None,
52+
timeout: datetime.timedelta | None = None,
5353
use_ocdbt: bool = False,
5454
):
55-
"""Constructor.
55+
"""Orbax array handler for Pathways on Cloud with Persistence API.
5656
5757
Args:
58-
read_timeout: Duration indicating the timeout for reading arrays
58+
timeout: Duration indicating the timeout for reading and writing arrays.
59+
Default is 1 hour.
5960
use_ocdbt: allows using Tensorstore OCDBT driver.
6061
"""
61-
self._read_timeout = read_timeout
62+
if timeout is None:
63+
timeout = datetime.timedelta(hours=1)
64+
self.timeout = timeout
6265

6366
if use_ocdbt:
6467
raise ValueError("OCDBT not supported for Pathways.")
@@ -92,7 +95,7 @@ async def serialize(
9295

9396
self._wait_for_directory_creation_signals()
9497
locations, names = extract_parent_dir_and_name(infos)
95-
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
98+
f = functools.partial(helper.write_one_array, timeout=self.timeout)
9699
futures_results = list(map(f, locations, names, values))
97100

98101
return [
@@ -181,7 +184,7 @@ async def deserialize(
181184
grouped_global_shapes,
182185
grouped_shardings,
183186
global_mesh.devices,
184-
timeout=self._read_timeout,
187+
timeout=self.timeout,
185188
)
186189
# each persistence call is awaited serially.
187190
read_future.result()
@@ -191,7 +194,7 @@ async def deserialize(
191194

192195

193196
def register_pathways_handlers(
194-
read_timeout: datetime.timedelta | None = None,
197+
timeout: datetime.timedelta | None = None,
195198
):
196199
"""Function that must be called before saving or restoring with Pathways."""
197200
logger.debug(
@@ -200,7 +203,7 @@ def register_pathways_handlers(
200203
type_handlers.register_type_handler(
201204
jax.Array,
202205
CloudPathwaysArrayHandler(
203-
read_timeout=read_timeout,
206+
timeout=timeout,
204207
),
205208
override=True,
206209
)

0 commit comments

Comments
 (0)