diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 07d546d5d..d3b5364f3 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -333,6 +333,7 @@ def __init__( temporary_path_class: Optional[ Type[atomicity_types.TemporaryPath] ] = None, + signals_prefix: Optional[str] = None, ): jax.monitoring.record_event('/jax/orbax/async_checkpointer/init') if not checkpoint_args.has_registered_args(handler): @@ -360,6 +361,7 @@ def __init__( ) self._barrier_sync_key_prefix = barrier_sync_key_prefix self._file_options = file_options + self._signals_prefix = signals_prefix self._metadata_store = ( checkpoint_metadata_store or checkpoint.metadata_store(enable_write=True) @@ -445,7 +447,7 @@ def _callback() -> None: # no longer needed. if self._create_directories_asynchronously: future.AwaitableSignalsContract.remove_all_awaitable_signals( - current_operation_id + current_operation_id, prefix=self._signals_prefix ) return _callback @@ -478,6 +480,7 @@ async def _save( [tmpdir], completion_signals=_DIRECTORY_CREATION_SIGNALS, multiprocessing_options=self._multiprocessing_options, + signals_prefix=self._signals_prefix, ) ) else: diff --git a/checkpoint/orbax/checkpoint/_src/futures/future.py b/checkpoint/orbax/checkpoint/_src/futures/future.py index 341f1f1c3..eab274a8c 100644 --- a/checkpoint/orbax/checkpoint/_src/futures/future.py +++ b/checkpoint/orbax/checkpoint/_src/futures/future.py @@ -44,30 +44,35 @@ def get_unique_awaitable_singal_key( cls, signal: synchronization.HandlerAwaitableSignal, operation_id: str, + prefix: str | None = None, ) -> str: """Returns a unique barrier key for the signal. Args: signal: The signal to generate a barrier key for. operation_id: The operation id to use as a suffix for the barrier key. + prefix: The prefix to use for the barrier key. If None, the default is + used. Returns: A unique barrier key for the signal with operation id as key directory. """ - return ( - f'{cls.awaitable_signals_contract_prefix}_{operation_id}/{signal.value}' - ) + prefix = prefix or cls.awaitable_signals_contract_prefix + return f'{prefix}_{operation_id}/{signal.value}' @classmethod def get_awaitable_signals_from_contract( cls, operation_id: str | None = None, + prefix: str | None = None, ) -> Sequence[synchronization.HandlerAwaitableSignal]: """Gets the awaitable signals that may be sent for the current operation id. Args: operation_id: The operation id to use for the barrier keys. If None, the current operation id is used. + prefix: The prefix to use for the barrier key. If None, the default is + used. Returns: A list of awaitable signals that may be sent for the current operation id. @@ -80,6 +85,7 @@ def get_awaitable_signals_from_contract( barrier_key = cls.get_unique_awaitable_singal_key( synchronization.HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT, operation_id, + prefix=prefix, ) values_str = client.key_value_try_get(barrier_key) if values_str is None: @@ -103,6 +109,7 @@ def get_awaitable_signals_from_contract( def add_to_awaitable_signals_contract( cls, signals: Sequence[synchronization.HandlerAwaitableSignal], + prefix: str | None = None, ): """Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on. @@ -111,11 +118,15 @@ def add_to_awaitable_signals_contract( Args: signals: The signals to add to the list of awaitable signals. + prefix: The prefix to use for the barrier key. If None, the default is + used. """ if not signals: return - current_signals = list(cls.get_awaitable_signals_from_contract()) + current_signals = list( + cls.get_awaitable_signals_from_contract(prefix=prefix) + ) current_signals.extend(signals) keys = ','.join( [current_signal.value for current_signal in current_signals] @@ -127,6 +138,7 @@ def add_to_awaitable_signals_contract( barrier_key = cls.get_unique_awaitable_singal_key( synchronization.HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT, operation_id, + prefix=prefix, ) client.key_value_set(barrier_key, keys, allow_overwrite=True) logging.vlog( @@ -139,15 +151,18 @@ def add_to_awaitable_signals_contract( ) @classmethod - def remove_all_awaitable_signals(cls, operation_id: str | None = None): + def remove_all_awaitable_signals( + cls, operation_id: str | None = None, prefix: str | None = None + ): """Removes all awaitable signals for the current / given operation id.""" operation_id = ( operation_id or synchronization.OperationIdGenerator.get_current_operation_id() ) + prefix = prefix or cls.awaitable_signals_contract_prefix client = signaling_client.get_signaling_client() client.key_value_delete( - f'{cls.awaitable_signals_contract_prefix}_{operation_id}/' + f'{prefix}_{operation_id}/' ) logging.vlog( 1, diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py index 01c3ae9d1..1ba4c14c0 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -731,6 +731,7 @@ def create_all_async( *, multiprocessing_options: options_lib.MultiprocessingOptions | None = None, subdirectories: Sequence[str] | None = None, + signals_prefix: str | None = None, ) -> future.Future: """Creates all temporary paths in parallel asynchronously. @@ -743,6 +744,8 @@ def create_all_async( subdirectories: Sequence of subdirectories to create under `paths`. If not provided, no subdirectories will be created. The same set of subdirectories will be created under each path in `paths`. + signals_prefix: The prefix to use for the barrier key. If None, the default + is used. Returns: A future that which sends the completion signals when all paths are created. @@ -775,7 +778,7 @@ def create_all_async( timeout_secs=multihost.coordination_timeout(), ) future.AwaitableSignalsContract.add_to_awaitable_signals_contract( - completion_signals + completion_signals, prefix=signals_prefix ) # Sync to enusre that all hosts have the same awaitable signals contract. diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 1875d1add..b2544a9b3 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -719,11 +719,6 @@ def __init__( self._options = options or CheckpointManagerOptions() self._multiprocessing_options = self._options.multiprocessing_options - if self._options.enable_per_process_directory_creation: - future.AwaitableSignalsContract.awaitable_signals_contract_prefix += ( - f'_{multihost.process_index()}' - ) - self._save_decision_policy = ( self._options.save_decision_policy or _get_default_save_decision_policy(self._options) @@ -952,6 +947,11 @@ def _configure_checkpointer_common( use_async: bool, ) -> Checkpointer: if use_async: + signals_prefix = ( + f'awaitable_signals_contract_{multihost.process_index()}' + if options.enable_per_process_directory_creation + else None + ) return async_checkpointer.AsyncCheckpointer( handler, multiprocessing_options=options.multiprocessing_options, @@ -959,6 +959,7 @@ def _configure_checkpointer_common( file_options=options.file_options, checkpoint_metadata_store=self._non_blocking_metadata_store, temporary_path_class=options.temporary_path_class, + signals_prefix=signals_prefix, ) else: return Checkpointer( diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/checkpoint_manager_test.py index 92ecea130..6e25d71a4 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager_test.py @@ -153,6 +153,33 @@ def tearDown(self): test_utils.sync_global_processes('CheckpointManagerTest:tests_complete') super().tearDown() + def test_per_process_directory_creation_prefix(self): + options = checkpoint_manager.CheckpointManagerOptions( + enable_per_process_directory_creation=True, + multiprocessing_options=checkpoint_manager.MultiprocessingOptions( + primary_host=None + ), + ) + signals_contract = checkpoint_manager.future.AwaitableSignalsContract + original_prefix = signals_contract.awaitable_signals_contract_prefix + try: + _ = checkpoint_manager.CheckpointManager(self.directory, options=options) + self.assertEqual( + signals_contract.awaitable_signals_contract_prefix, + original_prefix + ) + + options2 = checkpoint_manager.CheckpointManagerOptions( + enable_per_process_directory_creation=False + ) + _ = checkpoint_manager.CheckpointManager(self.directory, options=options2) + self.assertEqual( + checkpoint_manager.future.AwaitableSignalsContract.awaitable_signals_contract_prefix, + original_prefix + ) + finally: + pass + def save_params(self, step, manager, params, metrics=None, force=False): return manager.save( step,