From 1c0231641a63b8561aefefb96e964e49f98176ac Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Fri, 22 May 2026 01:00:44 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 919512074 --- .../_src/checkpointers/async_checkpointer.py | 17 ++++++++--------- .../_src/checkpointers/checkpointer.py | 3 +++ .../experimental/v1/_src/saving/execution.py | 6 ++++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 07d546d5d..fe9c60a0b 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -386,6 +386,7 @@ def _make_on_commit_callback( tmpdir: atomicity_types.TemporaryPath, custom_metadata: dict[str, Any] | None, checkpoint_start_time: float, + ckpt_args: checkpoint_args.CheckpointArgs, ) -> Callable[[], None]: # Directory is the final directory. @@ -453,9 +454,9 @@ def _callback() -> None: async def _save( self, tmpdir: atomicity_types.TemporaryPath, - *args, + *, + ckpt_args: checkpoint_args.CheckpointArgs, force: bool = False, - **kwargs, ): directory = tmpdir.get_final() if await async_path.exists(directory): @@ -483,10 +484,6 @@ async def _save( else: await self.create_temporary_path(tmpdir) # Run copy ops. - # Try to save using new CheckpointArgs API if supported by the handler. - ckpt_args = checkpointer.construct_checkpoint_args( - self._handler, True, *args, **kwargs - ) if isinstance( self._handler, async_checkpoint_handler.DeferredPathAsyncCheckpointHandler, @@ -554,15 +551,17 @@ def save( tmpdir = self.get_temporary_path(directory) self.wait_until_finished() self.synchronize_next_awaitable_signal_operation_id() + ckpt_args = checkpointer.construct_checkpoint_args( + self._handler, True, *args, **kwargs + ) on_commit_callback = self._make_on_commit_callback( - tmpdir, custom_metadata, checkpoint_start_time + tmpdir, custom_metadata, checkpoint_start_time, ckpt_args ) commit_ops = asyncio_utils.run_sync( self._save( tmpdir, - *args, + ckpt_args=ckpt_args, force=force, - **kwargs, ) ) operation_recorder.record_blocking_completion( diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py index e92c3416d..4eb64a00c 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py @@ -14,6 +14,7 @@ """Synchronous Checkpointer implementation.""" +import os import time from typing import Any, Iterable, Optional, Type @@ -72,6 +73,8 @@ def construct_checkpoint_args( return restore_arg_cls(*args, **kwargs) + + class Checkpointer( abstract_checkpointer.AbstractCheckpointer, epy.ContextManager ): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index a8d5946ae..1189c5e1d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -17,6 +17,7 @@ from __future__ import annotations import hashlib +import os import time from typing import Any, Awaitable, Callable, Iterable import uuid @@ -85,6 +86,8 @@ def add_internal_checkpointables( return checkpointables + + class _SaveResponse(AsyncResponse[None]): """An :py:class:`.AsyncResponse` representing the result of :py:func:`.save_async`.""" @@ -98,6 +101,7 @@ def __init__( start_time: float, custom_metadata: tree_types.JsonType | None, context: context_lib.Context, + checkpointables: dict[str, Any], async_origin: bool, partial_save: bool = False, ): @@ -108,6 +112,7 @@ def __init__( self._start_time = start_time self._custom_metadata = custom_metadata self._context = context + self._checkpointables = checkpointables self._async_origin = async_origin self._partial_save = partial_save self._thread_runner = thread_utils.BackgroundThreadRunner[None]( @@ -150,6 +155,7 @@ def create( start_time=start_time, custom_metadata=custom_metadata, context=context, + checkpointables=checkpointables, async_origin=async_origin, partial_save=partial_save, )