Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __init__(
temporary_path_class: Optional[
Type[atomicity_types.TemporaryPath]
] = None,
signals_prefix: Optional[str] = None,
):
jax.monitoring.record_event('/jax/orbax/async_checkpointer/init')
if not checkpoint_args.has_registered_args(handler):
Expand Down Expand Up @@ -360,6 +361,7 @@ def __init__(
)
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._file_options = file_options
self._signals_prefix = signals_prefix
self._metadata_store = (
checkpoint_metadata_store
or checkpoint.metadata_store(enable_write=True)
Expand Down Expand Up @@ -445,7 +447,7 @@ def _callback() -> None:
# no longer needed.
if self._create_directories_asynchronously:
future.AwaitableSignalsContract.remove_all_awaitable_signals(
current_operation_id
current_operation_id, prefix=self._signals_prefix
)

return _callback
Expand Down Expand Up @@ -478,6 +480,7 @@ async def _save(
[tmpdir],
completion_signals=_DIRECTORY_CREATION_SIGNALS,
multiprocessing_options=self._multiprocessing_options,
signals_prefix=self._signals_prefix,
)
)
else:
Expand Down
27 changes: 21 additions & 6 deletions checkpoint/orbax/checkpoint/_src/futures/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,35 @@ def get_unique_awaitable_singal_key(
cls,
signal: synchronization.HandlerAwaitableSignal,
operation_id: str,
prefix: str | None = None,
) -> str:
"""Returns a unique barrier key for the signal.

Args:
signal: The signal to generate a barrier key for.
operation_id: The operation id to use as a suffix for the barrier key.
prefix: The prefix to use for the barrier key. If None, the default is
used.

Returns:
A unique barrier key for the signal with operation id as key directory.
"""
return (
f'{cls.awaitable_signals_contract_prefix}_{operation_id}/{signal.value}'
)
prefix = prefix or cls.awaitable_signals_contract_prefix
return f'{prefix}_{operation_id}/{signal.value}'

@classmethod
def get_awaitable_signals_from_contract(
cls,
operation_id: str | None = None,
prefix: str | None = None,
) -> Sequence[synchronization.HandlerAwaitableSignal]:
"""Gets the awaitable signals that may be sent for the current operation id.

Args:
operation_id: The operation id to use for the barrier keys. If None, the
current operation id is used.
prefix: The prefix to use for the barrier key. If None, the default is
used.

Returns:
A list of awaitable signals that may be sent for the current operation id.
Expand All @@ -80,6 +85,7 @@ def get_awaitable_signals_from_contract(
barrier_key = cls.get_unique_awaitable_singal_key(
synchronization.HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT,
operation_id,
prefix=prefix,
)
values_str = client.key_value_try_get(barrier_key)
if values_str is None:
Expand All @@ -103,6 +109,7 @@ def get_awaitable_signals_from_contract(
def add_to_awaitable_signals_contract(
cls,
signals: Sequence[synchronization.HandlerAwaitableSignal],
prefix: str | None = None,
):
"""Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on.

Expand All @@ -111,11 +118,15 @@ def add_to_awaitable_signals_contract(

Args:
signals: The signals to add to the list of awaitable signals.
prefix: The prefix to use for the barrier key. If None, the default is
used.
"""
if not signals:
return

current_signals = list(cls.get_awaitable_signals_from_contract())
current_signals = list(
cls.get_awaitable_signals_from_contract(prefix=prefix)
)
current_signals.extend(signals)
keys = ','.join(
[current_signal.value for current_signal in current_signals]
Expand All @@ -127,6 +138,7 @@ def add_to_awaitable_signals_contract(
barrier_key = cls.get_unique_awaitable_singal_key(
synchronization.HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT,
operation_id,
prefix=prefix,
)
client.key_value_set(barrier_key, keys, allow_overwrite=True)
logging.vlog(
Expand All @@ -139,15 +151,18 @@ def add_to_awaitable_signals_contract(
)

@classmethod
def remove_all_awaitable_signals(cls, operation_id: str | None = None):
def remove_all_awaitable_signals(
cls, operation_id: str | None = None, prefix: str | None = None
):
"""Removes all awaitable signals for the current / given operation id."""
operation_id = (
operation_id
or synchronization.OperationIdGenerator.get_current_operation_id()
)
prefix = prefix or cls.awaitable_signals_contract_prefix
client = signaling_client.get_signaling_client()
client.key_value_delete(
f'{cls.awaitable_signals_contract_prefix}_{operation_id}/'
f'{prefix}_{operation_id}/'
)
logging.vlog(
1,
Expand Down
5 changes: 4 additions & 1 deletion checkpoint/orbax/checkpoint/_src/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ def create_all_async(
*,
multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
subdirectories: Sequence[str] | None = None,
signals_prefix: str | None = None,
) -> future.Future:
"""Creates all temporary paths in parallel asynchronously.

Expand All @@ -743,6 +744,8 @@ def create_all_async(
subdirectories: Sequence of subdirectories to create under `paths`. If not
provided, no subdirectories will be created. The same set of
subdirectories will be created under each path in `paths`.
signals_prefix: The prefix to use for the barrier key. If None, the default
is used.

Returns:
A future that sends the completion signals when all paths are created.
Expand Down Expand Up @@ -775,7 +778,7 @@ def create_all_async(
timeout_secs=multihost.coordination_timeout(),
)
future.AwaitableSignalsContract.add_to_awaitable_signals_contract(
completion_signals
completion_signals, prefix=signals_prefix
)

# Sync to ensure that all hosts have the same awaitable signals contract.
Expand Down
11 changes: 6 additions & 5 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,11 +719,6 @@ def __init__(
self._options = options or CheckpointManagerOptions()
self._multiprocessing_options = self._options.multiprocessing_options

if self._options.enable_per_process_directory_creation:
future.AwaitableSignalsContract.awaitable_signals_contract_prefix += (
f'_{multihost.process_index()}'
)

self._save_decision_policy = (
self._options.save_decision_policy
or _get_default_save_decision_policy(self._options)
Expand Down Expand Up @@ -952,13 +947,19 @@ def _configure_checkpointer_common(
use_async: bool,
) -> Checkpointer:
if use_async:
signals_prefix = (
f'awaitable_signals_contract_{multihost.process_index()}'
if options.enable_per_process_directory_creation
else None
)
return async_checkpointer.AsyncCheckpointer(
handler,
multiprocessing_options=options.multiprocessing_options,
async_options=options.async_options or AsyncOptions(),
file_options=options.file_options,
checkpoint_metadata_store=self._non_blocking_metadata_store,
temporary_path_class=options.temporary_path_class,
signals_prefix=signals_prefix,
)
else:
return Checkpointer(
Expand Down
27 changes: 27 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ def tearDown(self):
test_utils.sync_global_processes('CheckpointManagerTest:tests_complete')
super().tearDown()

def test_per_process_directory_creation_prefix(self):
  """Checks that building a CheckpointManager leaves the class-level prefix alone.

  `AwaitableSignalsContract.awaitable_signals_contract_prefix` is a class
  attribute shared by everything in the process. This test constructs a
  manager with `enable_per_process_directory_creation=True` and then one with
  it disabled, and asserts after each construction that the shared attribute
  still holds the value it had before.
  """
  signals_contract = checkpoint_manager.future.AwaitableSignalsContract
  original_prefix = signals_contract.awaitable_signals_contract_prefix
  try:
    per_process_options = checkpoint_manager.CheckpointManagerOptions(
        enable_per_process_directory_creation=True,
        multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
            primary_host=None
        ),
    )
    _ = checkpoint_manager.CheckpointManager(
        self.directory, options=per_process_options
    )
    self.assertEqual(
        signals_contract.awaitable_signals_contract_prefix,
        original_prefix,
    )

    default_options = checkpoint_manager.CheckpointManagerOptions(
        enable_per_process_directory_creation=False
    )
    _ = checkpoint_manager.CheckpointManager(
        self.directory, options=default_options
    )
    self.assertEqual(
        signals_contract.awaitable_signals_contract_prefix,
        original_prefix,
    )
  finally:
    # If a regression does mutate the shared class attribute, put it back so
    # the failure stays confined to this test instead of leaking into later
    # tests that read the prefix.
    signals_contract.awaitable_signals_contract_prefix = original_prefix

def save_params(self, step, manager, params, metrics=None, force=False):
return manager.save(
step,
Expand Down
Loading