diff --git a/pathwaysutils/_initialize.py b/pathwaysutils/_initialize.py index 0e42e59..a8df2a0 100644 --- a/pathwaysutils/_initialize.py +++ b/pathwaysutils/_initialize.py @@ -92,7 +92,9 @@ def initialize() -> None: profiling.monkey_patch_jax() # TODO: b/365549911 - Remove when OCDBT-compatible if _is_persistence_enabled(): - orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1)) + orbax_handler.register_pathways_handlers( + timeout=datetime.timedelta(hours=1), + ) # Turn off JAX compilation cache because Pathways handles its own # compilation cache. diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index c0e72e8..9a35ca8 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -49,16 +49,19 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler): def __init__( self, - read_timeout: datetime.timedelta | None = None, + timeout: datetime.timedelta | None = None, use_ocdbt: bool = False, ): - """Constructor. + """Orbax array handler for Pathways on Cloud with Persistence API. Args: - read_timeout: Duration indicating the timeout for reading arrays + timeout: Duration indicating the timeout for reading and writing arrays. + Default is 1 hour. use_ocdbt: allows using Tensorstore OCDBT driver. """ - self._read_timeout = read_timeout + if timeout is None: + timeout = datetime.timedelta(hours=1) + self.timeout = timeout if use_ocdbt: raise ValueError("OCDBT not supported for Pathways.") @@ -92,7 +95,7 @@ async def serialize( self._wait_for_directory_creation_signals() locations, names = extract_parent_dir_and_name(infos) - f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + f = functools.partial(helper.write_one_array, timeout=self.timeout) futures_results = list(map(f, locations, names, values)) return [ @@ -181,7 +184,7 @@ async def deserialize( grouped_global_shapes, grouped_shardings, global_mesh.devices, - timeout=self._read_timeout, + timeout=self.timeout, ) # each persistence call is awaited serially. read_future.result() @@ -191,7 +194,7 @@ async def deserialize( def register_pathways_handlers( - read_timeout: datetime.timedelta | None = None, + timeout: datetime.timedelta | None = None, ): """Function that must be called before saving or restoring with Pathways.""" logger.debug( @@ -200,7 +203,7 @@ def register_pathways_handlers( type_handlers.register_type_handler( jax.Array, CloudPathwaysArrayHandler( - read_timeout=read_timeout, + timeout=timeout, ), override=True, )