Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pathwaysutils/_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def initialize() -> None:
profiling.monkey_patch_jax()
# TODO: b/365549911 - Remove when OCDBT-compatible
if _is_persistence_enabled():
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
orbax_handler.register_pathways_handlers(
timeout=datetime.timedelta(hours=1),
)

# Turn off JAX compilation cache because Pathways handles its own
# compilation cache.
Expand Down
19 changes: 11 additions & 8 deletions pathwaysutils/persistence/orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,19 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):

def __init__(
self,
read_timeout: datetime.timedelta | None = None,
timeout: datetime.timedelta | None = None,
use_ocdbt: bool = False,
):
"""Constructor.
"""Orbax array handler for Pathways on Cloud with Persistence API.

Args:
read_timeout: Duration indicating the timeout for reading arrays
timeout: Duration indicating the timeout for reading and writing arrays.
Default is 1 hour.
use_ocdbt: allows using Tensorstore OCDBT driver.
"""
self._read_timeout = read_timeout
if timeout is None:
timeout = datetime.timedelta(hours=1)
self.timeout = timeout

if use_ocdbt:
raise ValueError("OCDBT not supported for Pathways.")
Expand Down Expand Up @@ -92,7 +95,7 @@ async def serialize(

self._wait_for_directory_creation_signals()
locations, names = extract_parent_dir_and_name(infos)
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
f = functools.partial(helper.write_one_array, timeout=self.timeout)
futures_results = list(map(f, locations, names, values))

return [
Expand Down Expand Up @@ -181,7 +184,7 @@ async def deserialize(
grouped_global_shapes,
grouped_shardings,
global_mesh.devices,
timeout=self._read_timeout,
timeout=self.timeout,
)
# each persistence call is awaited serially.
read_future.result()
Expand All @@ -191,7 +194,7 @@ async def deserialize(


def register_pathways_handlers(
read_timeout: datetime.timedelta | None = None,
timeout: datetime.timedelta | None = None,
):
"""Function that must be called before saving or restoring with Pathways."""
logger.debug(
Expand All @@ -200,7 +203,7 @@ def register_pathways_handlers(
type_handlers.register_type_handler(
jax.Array,
CloudPathwaysArrayHandler(
read_timeout=read_timeout,
timeout=timeout,
),
override=True,
)