diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 49298f392..0700a6e3c 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- #v1 Rename `save/load_pytree` to `save/load`. Eliminate most user-facing +"pytree" terminology in favor of "state" as a more specific term. +Add `deprecations.py` for handling deprecated public functions. + ## [0.11.40] - 2026-05-18 ### Removed diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py index 81adc54bc..c2fc8f159 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark.py @@ -126,12 +126,12 @@ def test_fn( resolved_path: str = self._client.resolve(self._xid, step) assert resolved_path.startswith(LUSTRE_PATH_PREFIX), resolved_path with metrics.measure("save_cache", metrics_to_measure): - ocp.save_pytree(resolved_path, context.pytree) + ocp.save(resolved_path, context.pytree) with metrics.measure("finalize_cache", metrics_to_measure): self._client.finalize(self._xid, step) with metrics.measure("save", metrics_to_measure): - ocp.save_pytree(context.path / str(step), context.pytree) + ocp.save(context.path / str(step), context.pytree) abstract_pytree = jax.tree.map( ocp.arrays.to_shape_dtype_struct, context.pytree @@ -148,12 +148,14 @@ def test_fn( with metrics.measure("wait_prefetch_cache", metrics_to_measure): self._client.await_transfer(self._xid, step - 1) with metrics.measure("restore_cache", metrics_to_measure): - restored_pytree = ocp.load_pytree(resolved_path, abstract_pytree) + restored_pytree = ocp.load( + resolved_path, abstract_state=abstract_pytree + ) restored_pytree = self._clear_pytree(restored_pytree) del restored_pytree with metrics.measure("restore", metrics_to_measure): - restored_pytree = ocp.load_pytree( - context.path / str(step), abstract_pytree + restored_pytree = ocp.load( + context.path / str(step), abstract_state=abstract_pytree ) restored_pytree = self._clear_pytree(restored_pytree) del restored_pytree diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark_test.py index 48dd3fc5c..32c951e83 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/lustre_benchmark_test.py @@ -53,7 +53,7 @@ def test_benchmark_test_fn(self): # Create temporary directories for the test # Ensure all processes use the same directory base_dir = self.create_tempdir(name='benchmark') - # cache_dir should not exist before save_pytree is called + # cache_dir should not exist before save is called cache_dir_path = epath.Path(base_dir.full_path) / 'cache' work_dir = base_dir.mkdir('work') @@ -156,7 +156,7 @@ def resolve_side_effect(xid, step): ) # Create a checkpoint at the "GCS" location for step 0 so restore works - ocp.save_pytree(gcs_dir_path / '0', pytree) + ocp.save(gcs_dir_path / '0', pytree) # Run the test function result = generator.test_fn(context) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py index 2f964a0c0..c1aa1dd58 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py @@ -164,12 +164,12 @@ def test_fn( jax.profiler.start_trace(context.path / "trace_save") if options.async_enabled: with metrics.measure("save_blocking", metrics_to_measure): - f = ocp.save_pytree_async(save_path, pytree) + f = ocp.save_async(save_path, pytree) with metrics.measure("save_background", metrics_to_measure): f.result() else: with metrics.measure("save_blocking", metrics_to_measure): - ocp.save_pytree(save_path, pytree) + ocp.save(save_path, pytree) with metrics.measure("save_background", metrics_to_measure): pass context.pytree = clear_pytree(context.pytree) @@ -179,7 +179,7 @@ def test_fn( if options.enable_trace: jax.profiler.start_trace(context.path / "trace_load") with metrics.measure("load", metrics_to_measure): - restored_pytree = ocp.load_pytree(save_path, abstract_pytree) + restored_pytree = ocp.load(save_path, abstract_state=abstract_pytree) clear_pytree(restored_pytree) if options.enable_trace: jax.profiler.stop_trace() diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util.py index 8cd36a758..47cbfd063 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util.py @@ -35,7 +35,7 @@ def get_multi_slice_abstract_state( ) -> Any: """Returns the abstract state for all replicas.""" with ocp.Context(context=context): - metadata = ocp.pytree_metadata(reference_checkpoint_path) + metadata = ocp.metadata(reference_checkpoint_path) # Abstract tree has shardings on a single replica. single_replica_abstract_state = ( checkpoint_generation.get_abstract_state_from_sharding_config( diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util_test.py index 3aa4ff5d0..4d26a46af 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/multi_slice_util_test.py @@ -56,7 +56,7 @@ def test_get_multi_slice_abstract_state(self): # Setup real checkpoint and sharding config pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}} ref_ckpt_path = self.directory / 'ref_ckpt' - ocp.save_pytree(ref_ckpt_path, pytree) + ocp.save(ref_ckpt_path, pytree) sharding_config = { 'a': { diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py index 9501b7444..9bfc53f64 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark.py @@ -116,8 +116,8 @@ def test_fn( ) with ocp.Context(context=options.context): - loaded_pytree = ocp.load_pytree( - reference_checkpoint_path, abstract_pytree + loaded_pytree = ocp.load( + reference_checkpoint_path, abstract_state=abstract_pytree ) for step in range(options.num_savings): @@ -132,12 +132,12 @@ def test_fn( "ReplicaParallelMultislice: Async Saving pytree to %s.", save_path, ) - f = ocp.save_pytree_async(save_path, loaded_pytree) + f = ocp.save_async(save_path, loaded_pytree) with metrics.measure("save_background", metrics_to_measure): f.result() else: with metrics.measure("save_blocking", metrics_to_measure): - ocp.save_pytree(save_path, loaded_pytree) + ocp.save(save_path, loaded_pytree) with metrics.measure("save_background", metrics_to_measure): pass diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark_test.py index 000e2b09b..350229051 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/replica_parallel_multislice_benchmark_test.py @@ -86,7 +86,7 @@ def test_benchmark_test_fn(self): # Setup real checkpoint and sharding config pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}} ref_ckpt_path = self.directory / 'ref_ckpt' - ocp.save_pytree(ref_ckpt_path, pytree) + ocp.save(ref_ckpt_path, pytree) sharding_config = { 'a': { diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py index 936df3e76..ec32dbdd2 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py @@ -96,7 +96,7 @@ def test_fn( ) with ocp.Context(context=options.context): - metadata = ocp.pytree_metadata(reference_checkpoint_path) + metadata = ocp.metadata(reference_checkpoint_path) abstract_pytree = ( checkpoint_generation.get_abstract_state_from_sharding_config( reference_sharding_path, @@ -108,8 +108,8 @@ def test_fn( if options.enable_trace: jax.profiler.start_trace(context.path / "trace_load") with metrics.measure("load", metrics_to_measure): - restored_pytree = ocp.load_pytree( - reference_checkpoint_path, abstract_pytree + restored_pytree = ocp.load( + reference_checkpoint_path, abstract_state=abstract_pytree ) benchmark.clear_pytree(restored_pytree) if options.enable_trace: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark_test.py index 7ebc9127e..179bef37d 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark_test.py @@ -68,7 +68,7 @@ def test_benchmark_test_fn(self): # Setup real checkpoint and sharding config pytree = {'a': jnp.arange(8), 'b': {'c': jnp.ones((4, 4))}} ref_ckpt_path = self.directory / 'ref_ckpt' - ocp.save_pytree(ref_ckpt_path, pytree) + ocp.save(ref_ckpt_path, pytree) sharding_config = { 'a': { diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py index 66b16d542..b6668dcf8 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py @@ -122,8 +122,8 @@ def test_fn( if options.enable_trace: jax.profiler.start_trace(context.path / "trace_load") with metrics.measure("load", metrics_to_measure): - restored_pytree = ocp.load_pytree( - reference_checkpoint_path, abstract_pytree + restored_pytree = ocp.load( + reference_checkpoint_path, abstract_state=abstract_pytree ) benchmark.clear_pytree(restored_pytree) if options.enable_trace: diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py index 11c1c7c6d..f275eb1bd 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py @@ -86,7 +86,7 @@ def test_benchmark_test_fn(self): # Setup real checkpoint and sharding config pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}} ref_ckpt_path = self.directory / 'ref_ckpt' - ocp.save_pytree(ref_ckpt_path, pytree) + ocp.save(ref_ckpt_path, pytree) sharding_config = { 'a': { diff --git a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py index 2d4c1e9f0..c3ad8c7cb 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/__init__.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/__init__.py @@ -48,13 +48,13 @@ from orbax.checkpoint.experimental.v1._src.loading.loading import ( load_checkpointables, load_checkpointables_async, - load_pytree, - load_pytree_async, + load, + load_async, PLACEHOLDER, ) from orbax.checkpoint.experimental.v1._src.metadata.loading import ( checkpointables_metadata, - pytree_metadata, + metadata, ) from orbax.checkpoint.experimental.v1._src.metadata.types import ( CheckpointMetadata, @@ -63,6 +63,15 @@ from orbax.checkpoint.experimental.v1._src.saving.saving import ( save_checkpointables, save_checkpointables_async, + save, + save_async, +) + +### DEPRECATED APIS ### +from orbax.checkpoint.experimental.v1._src.deprecations.deprecations import ( save_pytree, save_pytree_async, + load_pytree, + load_pytree_async, + pytree_metadata, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py index eb27f71df..f0decd624 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py @@ -73,7 +73,7 @@ class Context(epy.ContextManager): # Basic usage with ocp.Context(pytree_options=ocp.options.PyTreeOptions()): - ocp.save_pytree(directory, tree) + ocp.save(directory, tree) # Inheriting properties from an existing context with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx: @@ -81,7 +81,7 @@ class Context(epy.ContextManager): with ocp.Context(outer_ctx, array_options=ocp.options.ArrayOptions() ) as inner_ctx: - ocp.save_pytree(directory, tree) + ocp.save(directory, tree) Context is not shared across threads:: @@ -92,9 +92,9 @@ class Context(epy.ContextManager): with ocp.Context( pytree_options=ocp.options.PyTreeOptions() ): # Thread #1 creates Context. - # The following save_pytree call is executed in Thread #2, which sees + # The following save call is executed in Thread #2, which sees # a "default" Context, NOT the one created above. - executor.submit(ocp.save_pytree, directory, tree) + executor.submit(ocp.save, directory, tree) Attributes: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py index fecb1a529..764419d50 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options_test.py @@ -71,9 +71,9 @@ def is_prioritized_key_fn(path): pytree = {'a': jnp.ones((1,))} - # save_pytree will eventually call BasePyTreeCheckpointHandler + # save will eventually call BasePyTreeCheckpointHandler try: - saving.save_pytree('/tmp/test', pytree) + saving.save('/tmp/test', pytree) except Exception: # pylint: disable=broad-except # We might get some errors because we mocked too much, # but we check if mock_handler_class was called. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations.py new file mode 100644 index 000000000..7f5b65db9 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations.py @@ -0,0 +1,70 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines deprecated legacy aliases for Orbax V1 API.""" + +import functools +from typing import Any, Callable, TypeVar +import warnings + +from orbax.checkpoint.experimental.v1._src.loading import loading +from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading +from orbax.checkpoint.experimental.v1._src.saving import saving + +_FuncT = TypeVar('_FuncT', bound=Callable[..., Any]) + + +def deprecated(*, new: _FuncT) -> Callable[[_FuncT], _FuncT]: + """Decorator to mark a function as a deprecated alias of another function.""" + + def decorator(deprecated_func: _FuncT) -> _FuncT: + @functools.wraps(deprecated_func) + def wrapper(*args, **kwargs): + alias_name = getattr(deprecated_func, '__name__', 'unknown') + new_name = getattr(new, '__name__', 'unknown') + warnings.warn( + f'`{alias_name}` is deprecated, use `{new_name}` instead.', + DeprecationWarning, + stacklevel=2, + ) + return deprecated_func(*args, **kwargs) + + return wrapper + + return decorator + + +@deprecated(new=saving.save) +def save_pytree(*args, **kwargs): + return saving.save(*args, **kwargs) + + +@deprecated(new=saving.save_async) +def save_pytree_async(*args, **kwargs): + return saving.save_async(*args, **kwargs) + + +@deprecated(new=loading.load) +def load_pytree(*args, **kwargs): + return loading.load(*args, **kwargs) + + +@deprecated(new=loading.load_async) +def load_pytree_async(*args, **kwargs): + return loading.load_async(*args, **kwargs) + + +@deprecated(new=metadata_loading.metadata) +def pytree_metadata(*args, **kwargs): + return metadata_loading.metadata(*args, **kwargs) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations_test.py new file mode 100644 index 000000000..3ee281e20 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/deprecations/deprecations_test.py @@ -0,0 +1,81 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from orbax.checkpoint.experimental.v1._src.deprecations import deprecations +from orbax.checkpoint.experimental.v1._src.loading import loading +from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading +from orbax.checkpoint.experimental.v1._src.saving import saving + + +class DeprecationsTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='save_pytree', + alias_func=deprecations.save_pytree, + alias_name='save_pytree', + target_module=saving, + target_name='save', + ), + dict( + testcase_name='save_pytree_async', + alias_func=deprecations.save_pytree_async, + alias_name='save_pytree_async', + target_module=saving, + target_name='save_async', + ), + dict( + testcase_name='load_pytree', + alias_func=deprecations.load_pytree, + alias_name='load_pytree', + target_module=loading, + target_name='load', + ), + dict( + testcase_name='load_pytree_async', + alias_func=deprecations.load_pytree_async, + alias_name='load_pytree_async', + target_module=loading, + target_name='load_async', + ), + dict( + testcase_name='pytree_metadata', + alias_func=deprecations.pytree_metadata, + alias_name='pytree_metadata', + target_module=metadata_loading, + target_name='metadata', + ), + ) + def test_deprecated_alias( + self, alias_func, alias_name, target_module, target_name + ): + with mock.patch.object(target_module, target_name) as mock_target: + mock_target.return_value = 'expected_result' + + with self.assertWarnsRegex( + DeprecationWarning, + f'`{alias_name}` is deprecated, use `{target_name}` instead.', + ): + result = alias_func('arg1', kwarg1='val1') + + self.assertEqual(result, 'expected_result') + mock_target.assert_called_once_with('arg1', kwarg1='val1') + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py index 78578b5f9..76c13b2d3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py @@ -31,6 +31,8 @@ from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY + def _try_register_handler( handler_type: type[handler_types.CheckpointableHandler], @@ -80,6 +82,4 @@ def _try_register_handler( 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', ], ) -_try_register_handler( - pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY -) +_try_register_handler(pytree_handler.PyTreeHandler, STATE_CHECKPOINTABLE_KEY) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index b1ec35e73..703e38922 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -35,6 +35,7 @@ from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import compatibility @@ -55,7 +56,7 @@ base_pytree_checkpoint_handler.PartialSaveReplacementError ) -PYTREE_CHECKPOINTABLE_KEY = 'pytree' +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY def _get_remaining_timeout( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py index e55a51d03..367f4d7a9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/resolution.py @@ -22,7 +22,9 @@ from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY InternalCheckpointMetadata = ( step_metadata_serialization.InternalCheckpointMetadata ) @@ -64,7 +66,8 @@ def _resolve_single_handler_for_load( The handler for the checkpointable. Raises: - registration.NoEntryError: If no handler is resolved and 'pytree' name is + registration.NoEntryError: If no handler is resolved and + STATE_CHECKPOINTABLE_KEY name is not registered. """ # 1. Resolve the checkpointable's handler using handler discovery. @@ -86,16 +89,17 @@ def _resolve_single_handler_for_load( # 2. If no handler is resolved yet, try to resolve using the default # pytree handler. pytree_handler = registration.get_registered_handler_by_name( - handler_registry, "pytree" + handler_registry, STATE_CHECKPOINTABLE_KEY ) if not pytree_handler: raise registration.NoEntryError( f"Could not resolve a handler for '{checkpointable_name}' and no" - f" 'pytree' handler found in {handler_registry})." + f" '{STATE_CHECKPOINTABLE_KEY}' handler found in {handler_registry})." " Please inspect the checkpoint contents via" " `loading.checkpointables_metadata`. You may need to provide an" " abstract_checkpointable or register a missing handler for this name" - " or for 'pytree' name which is used as a fallback." + f" or for '{STATE_CHECKPOINTABLE_KEY}' name which is used as a" + " fallback." ) return pytree_handler diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py index 49301179d..737b11fa5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py @@ -96,7 +96,7 @@ class CheckpointableHandler(Protocol[_Checkpointable, _AbstractCheckpointable]): In most contexts, when dealing with just a PyTree, the API of choice is:: - ocp.save_pytree(directory, pytree) + ocp.save(directory, pytree) The concept of "checkpointable" is not so obvious in this case. When dealing with multiple objects, we can use:: @@ -120,7 +120,7 @@ class CheckpointableHandler(Protocol[_Checkpointable, _AbstractCheckpointable]): ), ) # Equivalently, - ocp.load_pytree(directory, abstract_model_params) + ocp.load(directory, abstract_model_params) With the methods defined in this Protocol (`save`, `load`), logic within the method itself is executed in the main thread, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py index af4791e00..2993ea979 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/checkpoint_layout.py @@ -24,7 +24,7 @@ ### Constants shared by all layouts. ### -PYTREE_CHECKPOINTABLE_KEY = "pytree" +STATE_CHECKPOINTABLE_KEY = "state" EMPTY_CHECKPOINTABLE_KEY = "" AUTO_CHECKPOINTABLE_KEY = "AUTO" @@ -107,18 +107,18 @@ async def validate_pytree( """ ... - async def load_pytree( + async def load( self, path: Path, checkpointable_name: str | None = None, - abstract_pytree: Any | None = None, + abstract_state: Any | None = None, ) -> Awaitable[Any]: """Loads a PyTree from the checkpoint. Args: path: The path to the checkpoint. checkpointable_name: The name of the checkpointable to load. - abstract_pytree: The abstract PyTree structure. + abstract_state: The abstract PyTree structure. Returns: An awaitable PyTree. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 75421d413..ceb50be27 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -35,6 +35,9 @@ from orbax.checkpoint.experimental.v1._src.tree import types as tree_types +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY + + class CheckpointVersion(enum.Enum): V0 = 0 V1 = 1 @@ -206,14 +209,12 @@ async def get_checkpointable_names(self, path: Path) -> list[str]: for n in existing_names if n not in checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS ] - if checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY in checkpointable_names: + if STATE_CHECKPOINTABLE_KEY in checkpointable_names: # Prioritize 'pytree' checkpointable name if present. - other_names = sorted([ - n - for n in checkpointable_names - if n != checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY - ]) - names = [checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY] + other_names + other_names = sorted( + [n for n in checkpointable_names if n != STATE_CHECKPOINTABLE_KEY] + ) + names = [STATE_CHECKPOINTABLE_KEY] + other_names else: names = sorted(checkpointable_names) return names @@ -251,7 +252,7 @@ async def metadata( ) async def _validate_pytree(self, path: Path, checkpointable_name: str | None): - """Validates checkpoint written by `save_pytree` or `save_checkpointables`. + """Validates checkpoint written by `save` or `save_checkpointables`. Validates that checkpointable_name is a Pytree checkpoint by verifying its path contains the required metadata files. @@ -371,11 +372,11 @@ async def validate_pytree( f"Failed to interpret path {path} as a V1 Orbax PyTree." ) from e - async def load_pytree( + async def load( self, path: Path, checkpointable_name: str | None = None, - abstract_pytree: ( + abstract_state: ( tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, ) -> Awaitable[Any]: @@ -384,7 +385,7 @@ async def load_pytree( Args: path: The path to the checkpoint. checkpointable_name: The name of the pytree checkpointable to load. - abstract_pytree: The abstract pytree to load. + abstract_state: The abstract pytree to load. Returns: An awaitable containing the loaded pytree. @@ -392,14 +393,14 @@ async def load_pytree( checkpoint_metadata = await read_checkpoint_metadata(path) handlers_for_load = await handler_resolution.get_handlers_for_load( self._handler_registry, - {checkpointable_name: abstract_pytree}, + {checkpointable_name: abstract_state}, checkpoint_metadata, ) handler_for_load = handlers_for_load[checkpointable_name] result = await handler_for_load.load( path / checkpointable_name, - abstract_pytree, + abstract_state, ) return result @@ -445,19 +446,19 @@ async def load_checkpointables( # Read checkpoint metadata and resolve handlers for loading. checkpoint_metadata = await read_checkpoint_metadata(path) # TODO(b/484400394): Find a better way to inform the user that they need - # to use load_pytree(..., checkpointable_name=None) when item_handlers is + # to use load(..., checkpointable_name=None) when item_handlers is # a str. An idea is to create a seperate validate_checkpointables method # and we can read in checkpoint metadata at validation time for both # validate_pytree and validate_checkpointables operations and warn the user # know if they are trying to load a composite checkpoint by calling - # load_pytree(checkpointable_name=None) or trying to load a composite + # load(checkpointable_name=None) or trying to load a composite # checkpoint as a pytree checkpoint respectively. if isinstance(checkpoint_metadata.item_handlers, str): logging.warning( "Checkpoint looks like a legacy V0 checkpoint. This is only" " supported for legacy V0 checkpoints. If you intended to load a" " pytree checkpoint from the given path, then please consider using" - " `loading.load_pytree(..., checkpointable_name=None)` instead." + " `loading.load(..., checkpointable_name=None)` instead." ) handlers_for_load = await handler_resolution.get_handlers_for_load( self._handler_registry, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py index b880ad7f6..71a4e9ba3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py @@ -30,6 +30,8 @@ from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout + from orbax.checkpoint.experimental.v1._src.layout import orbax_layout from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization from orbax.checkpoint.experimental.v1._src.partial import path as partial_path_lib @@ -39,6 +41,7 @@ from orbax.checkpoint.experimental.v1._src.testing import path_utils +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY CHECKPOINT_METADATA = orbax_layout.CHECKPOINT_METADATA ORBAX_CHECKPOINT_INDICATOR_FILE = orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE InternalCheckpointMetadata = ( @@ -175,8 +178,11 @@ def test_init(self): self.assertTrue(layout._handler_registry.has('pytree_foo')) self.assertEqual(layout._handler_registry.get('pytree_foo'), PyTreeHandler) - self.assertTrue(layout._handler_registry.has('pytree')) - self.assertEqual(layout._handler_registry.get('pytree'), PyTreeHandler) + self.assertTrue(layout._handler_registry.has(STATE_CHECKPOINTABLE_KEY)) + self.assertEqual( + layout._handler_registry.get(STATE_CHECKPOINTABLE_KEY), + PyTreeHandler, + ) @parameterized.product( save_checkpointables=({'foo': {'a': 1}, 'bar': {'x': 5}},), @@ -231,7 +237,7 @@ def test_save_load_checkpointables( ): if with_name: pairs_to_register = [ - (PyTreeHandler, 'pytree'), + (PyTreeHandler, 'state'), (FooHandler, 'foo'), ] else: @@ -245,7 +251,7 @@ def test_save_load_checkpointables( layout = OrbaxLayout() layout._handler_registry = registry - checkpointables = {'pytree': {'a': 1}, 'foo': Foo(x=1, y='foo')} + checkpointables = {'state': {'a': 1}, 'foo': Foo(x=1, y='foo')} self.save( layout, self.directory, @@ -341,8 +347,8 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): 'foo_list': PartialSavePyTree([{}, {'b2': 4}]), } merged_checkpointables = tree_structure_utils.merge_trees( - {k: v.pytree for k, v in first_save_checkpointables.items()}, - {k: v.pytree for k, v in second_save_checkpointables.items()}, + {k: v.state for k, v in first_save_checkpointables.items()}, + {k: v.state for k, v in second_save_checkpointables.items()}, ) registry = self.create_registry(include_global_registry=False) registry.add(StatefulCheckpointableHandler) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py index 3091a5880..1a81b793d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py @@ -31,6 +31,7 @@ import safetensors.numpy +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY np_save_file = safetensors.numpy.save_file OrbaxLayout = orbax_layout.OrbaxLayout InvalidLayoutError = orbax_layout.InvalidLayoutError @@ -72,7 +73,7 @@ def setUp(self): } self.custom_metadata = {'framework': 'JAX', 'version': '1.0'} np_save_file(self.object_to_save, self.safetensors_path) - saving.save_pytree( + saving.save( self.orbax_path / '0', self.object_to_save, custom_metadata=self.custom_metadata, @@ -117,7 +118,9 @@ async def test_validate_no_indicator_or_metadata_files(self): indicator_path.rmtree() # Remove the indicator file metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' metadata_path.rmtree() - pytree_metadata_path = self.orbax_path / '0' / 'pytree' / '_METADATA' + pytree_metadata_path = ( + self.orbax_path / '0' / STATE_CHECKPOINTABLE_KEY / '_METADATA' + ) pytree_metadata_path.rmtree() with self.assertRaises(InvalidLayoutError): await layout.validate(self.orbax_path / '0') @@ -137,12 +140,14 @@ async def test_load_orbax_checkpoint(self): ) restored_checkpointables = await restored_checkpointables_await test_utils.assert_tree_equal( - self, restored_checkpointables['pytree'], self.object_to_save + self, + restored_checkpointables[STATE_CHECKPOINTABLE_KEY], + self.object_to_save, ) async def test_save_with_custom_name(self): custom_path = epath.Path(self.test_dir.full_path) / 'custom_checkpoint' - saving.save_pytree( + saving.save( custom_path, self.object_to_save, checkpointable_name='my_custom_name', @@ -164,7 +169,7 @@ async def test_metadata(self): self.assertIsInstance(result_metadata, metadata_types.CheckpointMetadata) expected_structs = { - checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: { + STATE_CHECKPOINTABLE_KEY: { 'a': numpy_leaf_handler.NumpyMetadata( shape=(9,), dtype=np.dtype(np.int32), @@ -200,7 +205,7 @@ def setUp(self): / 'ckpt' ) self.pytree, _ = array_test_utils.create_numpy_pytree() - saving.save_pytree(self.directory, self.pytree) + saving.save(self.directory, self.pytree) async def test_nonexistent_path(self): with self.assertRaises(FileNotFoundError): @@ -217,21 +222,25 @@ async def test_missing_indicator_file(self): await OrbaxLayout()._validate(self.directory) async def test_deleted_pytree(self): - (self.directory / 'pytree').rmtree() + (self.directory / STATE_CHECKPOINTABLE_KEY).rmtree() await OrbaxLayout()._validate(self.directory) with self.assertRaises(FileNotFoundError): - await OrbaxLayout()._validate_pytree(self.directory, 'pytree') + await OrbaxLayout()._validate_pytree( + self.directory, STATE_CHECKPOINTABLE_KEY + ) async def test_missing_checkpointable_matching_name(self): with self.assertRaises(FileNotFoundError): await OrbaxLayout()._validate_pytree(self.directory, 'foo') async def test_no_pytree_metadata(self): - await _unlink_pytree_metadata(self.directory / 'pytree') + await _unlink_pytree_metadata(self.directory / STATE_CHECKPOINTABLE_KEY) with self.assertRaises(FileNotFoundError): - await OrbaxLayout()._validate_pytree(self.directory, 'pytree') + await OrbaxLayout()._validate_pytree( + self.directory, STATE_CHECKPOINTABLE_KEY + ) class IsOrbaxV1CheckpointTest(parameterized.TestCase): @@ -251,7 +260,7 @@ def setUp(self): } self.custom_metadata = {'framework': 'JAX', 'version': '1.0'} np_save_file(self.object_to_save, self.safetensors_path) - saving.save_pytree( + saving.save( self.orbax_path / '0', self.object_to_save, custom_metadata=self.custom_metadata, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py index 25d92ef8a..1365ddad5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py @@ -304,24 +304,24 @@ async def validate_pytree( f"Failed to interpret path {path} as a V0 Orbax PyTree." ) from e - async def load_pytree( + async def load( self, path: Path, checkpointable_name: str | None = None, - abstract_pytree: ( + abstract_state: ( tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, ) -> Awaitable[Any]: """Loads a V0 PyTree checkpoint. Attempts to load `checkpointable_name` pytree by finding its corresponding - handler from the metadata. If `abstract_pytree` is provided, it attempts to + handler from the metadata. If `abstract_state` is provided, it attempts to load the checkpoint as a PyTree of the given abstract pytree. Args: path: The path to the checkpoint directory. checkpointable_name: The name of the pytree checkpointable to load. - abstract_pytree: The abstract pytree to load. + abstract_state: The abstract pytree to load. Returns: An awaitable containing the loaded PyTree. @@ -332,12 +332,12 @@ async def load_pytree( if checkpointable_name is None: # Read checkpoint metadata and resolve pytree handler for loading. # TODO(b/484400394): Find a better way to inform the user that they need - # to use load_pytree(..., checkpointable_name=None) when item_handlers is + # to use load(..., checkpointable_name=None) when item_handlers is # a str. An idea is to create a seperate validate_checkpointables method # and we can read in checkpoint metadata at validation time for both # validate_pytree and validate_checkpointables operations and warn the # user if they are trying to load a composite checkpoint by calling - # load_pytree(checkpointable_name=None) or trying to load a composite + # load(checkpointable_name=None) or trying to load a composite # checkpoint as a pytree checkpoint respectively. checkpoint_metadata = await orbax_layout.read_checkpoint_metadata( path @@ -345,27 +345,27 @@ async def load_pytree( if isinstance(checkpoint_metadata.item_handlers, dict): logging.warning( "Checkpoint looks like a V1 checkpoint. Calling" - " `loading.load_pytree(..., checkpointable_name=None)` is only" + " `loading.load(..., checkpointable_name=None)` is only" " supported for loading legacy V0 checkpoints. If you intended to" " load a specific checkpointable from the given path, then please" - " consider using `load_pytree` or `load_checkpointables` instead." + " consider using `load` or `load_checkpointables` instead." ) handler_for_load = ( await handler_resolution.get_handler_for_load_direct_pytree( path.name, self._handler_registry, - abstract_pytree, + abstract_state, checkpoint_metadata, ) ) result = await handler_for_load.load( path, - abstract_pytree, + abstract_state, ) return result else: - return await self._orbax_layout.load_pytree( - path, checkpointable_name, abstract_pytree + return await self._orbax_layout.load( + path, checkpointable_name, abstract_state ) async def load_checkpointables( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout_test.py index 9df0b8aed..a68ad247a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout_test.py @@ -205,33 +205,35 @@ async def test_valid_pytree(self): self.directory, 'state' ) - async def test_load_pytree(self): + async def test_load(self): layout = orbax_v0_layout.OrbaxV0Layout() loaded = await ( - await layout.load_pytree(self.directory, 'state', self.pytree) + await layout.load(self.directory, 'state', abstract_state=self.pytree) ) test_utils.assert_tree_equal(self, self.pytree, loaded) - async def test_load_pytree_no_checkpoint_metadata(self): + async def test_load_no_checkpoint_metadata(self): await async_path.unlink(self.directory / '_CHECKPOINT_METADATA') layout = orbax_v0_layout.OrbaxV0Layout() loaded = await ( - await layout.load_pytree(self.directory, 'state', self.pytree) + await layout.load(self.directory, 'state', abstract_state=self.pytree) ) test_utils.assert_tree_equal(self, self.pytree, loaded) - async def test_load_pytree_no_checkpoint_metadata_or_target_pytree(self): + async def test_load_no_checkpoint_metadata_or_target_pytree(self): await async_path.unlink(self.directory / '_CHECKPOINT_METADATA') layout = orbax_v0_layout.OrbaxV0Layout() - loaded = await (await layout.load_pytree(self.directory, 'state')) + loaded = await (await layout.load(self.directory, 'state')) test_utils.assert_tree_equal(self, self.pytree, loaded) - async def test_load_pytree_v0_checkpoint(self): + async def test_load_v0_checkpoint(self): layout = orbax_v0_layout.OrbaxV0Layout() loaded = await ( - await layout.load_pytree(self.v0_pytree_directory, None, self.pytree) + await layout.load( + self.v0_pytree_directory, None, abstract_state=self.pytree + ) ) test_utils.assert_tree_equal(self, self.pytree, loaded) @@ -245,7 +247,9 @@ async def test_v0_pytree_no_checkpoint_metadata(self): await orbax_v0_layout.OrbaxV0Layout()._validate(self.v0_pytree_directory) loaded = await ( - await layout.load_pytree(self.v0_pytree_directory, None, self.pytree) + await layout.load( + self.v0_pytree_directory, None, abstract_state=self.pytree + ) ) # Passes because we still have the pytree metadata. test_utils.assert_tree_equal(self, self.pytree, loaded) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py index 7356197ef..ceb0da1aa 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py @@ -22,12 +22,16 @@ from orbax.checkpoint._src.handlers import composite_checkpoint_handler from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint.checkpoint_manager import CheckpointManager +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import orbax_layout from orbax.checkpoint.experimental.v1._src.layout import orbax_v0_layout from orbax.checkpoint.experimental.v1._src.layout import registry from orbax.checkpoint.experimental.v1._src.saving import saving +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY + + async def get_pytree_spec(path, layout_enum, pytree_name=None): resolver = await registry.CheckpointLayoutResolver.resolve( path, layout_enum, pytree_name=pytree_name @@ -44,7 +48,7 @@ def setUp(self): super().setUp() self.root_directory = epath.Path(self.create_tempdir()) self.v1_directory = self.root_directory / 'v1' - saving.save_pytree( + saving.save( self.v1_directory, {'a': 1, 'b': 2}, ) @@ -68,13 +72,13 @@ async def test_root_directory(self): async def test_v1_valid_name(self): layout, resolved_name = await get_pytree_spec( - self.v1_directory, CheckpointLayoutEnum.ORBAX, 'pytree' + self.v1_directory, CheckpointLayoutEnum.ORBAX, STATE_CHECKPOINTABLE_KEY ) - self.assertEqual(resolved_name, 'pytree') + self.assertEqual(resolved_name, STATE_CHECKPOINTABLE_KEY) self.assertIsInstance(layout, orbax_layout.OrbaxLayout) self.assertTrue(await orbax_layout.has_indicator_file(self.v1_directory)) - @parameterized.parameters([None, 'state', 'params']) + @parameterized.parameters([None, 'pytree', 'params']) async def test_v1_invalid_name(self, checkpointable_name): with self.assertRaises(registry.InvalidLayoutError): await get_pytree_spec( @@ -99,7 +103,9 @@ async def test_v0_invalid_name(self, checkpointable_name): async def test_v1_direct_path(self): with self.assertRaises(registry.InvalidLayoutError): await get_pytree_spec( - self.v1_directory / 'pytree', CheckpointLayoutEnum.ORBAX, None + self.v1_directory / STATE_CHECKPOINTABLE_KEY, + CheckpointLayoutEnum.ORBAX, + None, ) async def test_v0_child_path_load_failure(self): @@ -123,7 +129,7 @@ async def test_v0_checkpoint_path(self): async def test_v1_checkpoint_path_missing_pytree_metadata(self): (self.v1_directory / orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE).unlink() - (self.v1_directory / 'pytree' / '_METADATA').unlink() + (self.v1_directory / STATE_CHECKPOINTABLE_KEY / '_METADATA').unlink() with self.assertRaises(registry.InvalidLayoutError): await get_pytree_spec( self.v1_directory, CheckpointLayoutEnum.ORBAX, None @@ -179,7 +185,7 @@ def setUp(self): super().setUp() self.root_directory = epath.Path(self.create_tempdir()) self.v1_directory = self.root_directory / 'v1' - saving.save_pytree( + saving.save( self.v1_directory, {'a': 1, 'b': 2}, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index cbaa5ce02..693c4dec9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -32,6 +32,7 @@ from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY CheckpointLayout = checkpoint_layout.CheckpointLayout InvalidLayoutError = checkpoint_layout.InvalidLayoutError Path = types.Path @@ -458,7 +459,7 @@ async def load_single_host(self) -> dict[str, np.ndarray]: return tensors async def load_multi_host( - self, abstract_pytree: dict[str, Any] + self, abstract_state: dict[str, Any] ) -> tuple[dict[str, Any], dict[str, float]]: """Loads tensors from a single safetensors file in multi-host mode. @@ -474,7 +475,7 @@ async def load_multi_host( requires no data processing. Args: - abstract_pytree: A flat dictionary mapping tensor names to + abstract_state: A flat dictionary mapping tensor names to jax.ShapeDtypeStruct objects specifying target shape and sharding. Returns: @@ -486,7 +487,7 @@ async def load_multi_host( ValueError: If the number of devices is not uniform across hosts, or if a sharded tensor's first dimension is not divisible by the number of devices per host. - KeyError: If a tensor in the abstract_pytree is not found in the + KeyError: If a tensor in the abstract_state is not found in the safetensors file. """ io_time = 0.0 @@ -523,7 +524,7 @@ async def load_multi_host( if not context_lib.get_context().safetensors_options.ignore_load_sharding: max_shard_shape_per_dtype = _calculate_max_shard_shapes( - abstract_pytree, + abstract_state, header, ) zero_buffers = _create_shared_zero_buffers( @@ -543,7 +544,7 @@ async def load_multi_host( ) # Process each tensor in the requested PyTree. - for name, abstract_leaf in abstract_pytree.items(): + for name, abstract_leaf in abstract_state.items(): if name not in header: continue @@ -578,7 +579,7 @@ async def _get_loaders(self) -> list[_SingleFileLoader]: paths = [self.path] return [_SingleFileLoader(path) for path in paths] - async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any: + async def _load_single_host(self, abstract_state: dict[str, Any]) -> Any: """Loads a safetensors checkpoint on a single host.""" # Return NumPy arrays. # Load from all files and merge. @@ -598,20 +599,20 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any: "[safetensors][single-host] Loaded tensors in %.0fs", time.time() - start, ) - if not abstract_pytree: + if not abstract_state: return restored_pytree start = time.time() - for k in abstract_pytree: + for k in abstract_state: if k not in restored_pytree: raise KeyError(f"Tensor '{k}' not found in Safetensors checkpoint.") restored_pytree = { k: jax.device_put( restored_pytree[k], - device=abstract_pytree[k].sharding, + device=abstract_state[k].sharding, ) - for k in abstract_pytree + for k in abstract_state } logging.info( "[safetensors][single-host] Host-to-device transfer in %.0fs", @@ -620,12 +621,12 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any: return restored_pytree async def _load_multi_host( - self, abstract_pytree: dict[str, Any] | None + self, abstract_state: dict[str, Any] | None ) -> Any: """Loads a safetensors checkpoint on multiple hosts.""" - if not abstract_pytree: + if not abstract_state: raise ValueError( - "abstract_pytree must be provided for multi-host loading." + "abstract_state must be provided for multi-host loading." ) loaders = await self._get_loaders() @@ -635,7 +636,7 @@ async def _load_multi_host( start = time.time() load_ops = [] for loader in loaders: - load_ops.append(loader.load_multi_host(abstract_pytree)) + load_ops.append(loader.load_multi_host(abstract_state)) restored_pytree = {} total_io_time = 0.0 @@ -649,7 +650,7 @@ async def _load_multi_host( restored_pytree[name] = arr # Validate that all requested tensors were found in at least one file. - for k in abstract_pytree: + for k in abstract_state: if k not in restored_pytree: raise KeyError(f"Tensor '{k}' not found in Safetensors checkpoint.") @@ -667,18 +668,18 @@ async def _load_multi_host( async def load_safetensors( self, - abstract_pytree: dict[str, Any] | None = None, + abstract_state: dict[str, Any] | None = None, ) -> Any: """Calls the correct safetensors loading function.""" - if abstract_pytree is not None and not tree_utils.is_flat_dict( - abstract_pytree + if abstract_state is not None and not tree_utils.is_flat_dict( + abstract_state ): raise ValueError("The PyTree is not a flat dictionary.") if multihost.process_count() > 1: - return await self._load_multi_host(abstract_pytree) + return await self._load_multi_host(abstract_state) else: - return await self._load_single_host(abstract_pytree) + return await self._load_single_host(abstract_state) async def load_metadata(self): """Loads the metadata from a safetensors checkpoint.""" @@ -711,7 +712,7 @@ async def load_metadata(self): logging.info("[safetensors] Loaded metadata in %.0fs", time.time() - start) return metadata_types.CheckpointMetadata[dict[str, Any]]( path=self.path, - metadata={checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: metadata}, + metadata={STATE_CHECKPOINTABLE_KEY: metadata}, commit_timestamp_nsecs=commit_timestamp_nsecs, custom_metadata=custom_metadata, ) @@ -764,29 +765,29 @@ async def validate(self, path: Path): ) async def get_checkpointable_names(self, path: Path) -> list[str]: - return [checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY] + return [STATE_CHECKPOINTABLE_KEY] async def validate_pytree( self, path: Path, checkpointable_name: str | None ) -> None: return - async def load_pytree( + async def load( self, path: Path, checkpointable_name: str | None = None, - abstract_pytree: Any | None = None, + abstract_state: Any | None = None, ) -> Awaitable[Any]: """Loads a NumPy checkpoint file. - If `abstract_pytree` is provided, it attempts to load numpy arrays as + If `abstract_state` is provided, it attempts to load numpy arrays as sharded `jax.Arrays` onto devices. Args: path: The path to load the checkpoint from. checkpointable_name: The name of the pytree checkpointable to load, unsused in this case. - abstract_pytree: An optional PyTree of abstract arrays specifying sharding + abstract_state: An optional PyTree of abstract arrays specifying sharding information. Returns: @@ -794,7 +795,7 @@ async def load_pytree( """ del checkpointable_name self._loader = _MultiFileLoader(path) - return self._loader.load_safetensors(abstract_pytree) + return self._loader.load_safetensors(abstract_state) async def save( self, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py index 8cb6456b6..65cd6ee18 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py @@ -157,9 +157,7 @@ async def test_sharding_scenarios( expected_tensor = jax.device_put(tensor_to_save, abstract_sharding) layout = SafetensorsLayout() - restore_fn = await layout.load_pytree( - st_path, abstract_pytree=abstract_state - ) + restore_fn = await layout.load(st_path, abstract_state=abstract_state) restored_tensor = await restore_fn restored_tensor = restored_tensor["params.tensor"] @@ -190,9 +188,7 @@ async def test_load_without_global_reshard_single_tensor(self): ignore_load_sharding=True ) ): - restore_fn = await layout.load_pytree( - st_path, abstract_pytree=abstract_state - ) + restore_fn = await layout.load(st_path, abstract_state=abstract_state) restored_pytree = await restore_fn restored_tensor = restored_pytree["params.tensor"] @@ -234,9 +230,7 @@ async def test_load_without_global_reshard_multi_tensor(self): ignore_load_sharding=True ) ): - restore_fn = await layout.load_pytree( - st_path, abstract_pytree=abstract_state - ) + restore_fn = await layout.load(st_path, abstract_state=abstract_state) restored_pytree = await restore_fn # Tensors are expected to be distributed among hosts. @@ -288,9 +282,9 @@ async def test_load_multi_host_memory_efficiency(self): tracemalloc.start() - restore_fn = await layout.load_pytree( + restore_fn = await layout.load( file_path, - abstract_pytree=abstract_pytree, + abstract_state=abstract_pytree, ) pytree = await restore_fn @@ -348,9 +342,9 @@ async def test_load_without_global_reshard_memory_efficiency(self): ignore_load_sharding=True ) ): - restore_fn = await layout.load_pytree( + restore_fn = await layout.load( file_path, - abstract_pytree=abstract_pytree, + abstract_state=abstract_pytree, ) pytree = await restore_fn @@ -441,8 +435,8 @@ async def test_load_sharded_fails_with_nested_abstract_pytree(self): with self.assertRaisesRegex( ValueError, "The PyTree is not a flat dictionary." ): - test_awaitable = await layout.load_pytree( - st_path, abstract_pytree=nested_abstract_pytree + test_awaitable = await layout.load( + st_path, abstract_state=nested_abstract_pytree ) await test_awaitable @@ -469,8 +463,8 @@ async def test_load_sharded_fails_with_wrong_key_abstract_pytree(self): with self.assertRaisesRegex( KeyError, "not found in Safetensors checkpoint" ): - test_awaitable = await layout.load_pytree( - st_path, abstract_pytree=wrong_key_abstract_pytree + test_awaitable = await layout.load( + st_path, abstract_state=wrong_key_abstract_pytree ) await test_awaitable diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py index 0bb8c2178..5ebd85a5d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_test.py @@ -27,6 +27,7 @@ from orbax.checkpoint.experimental.v1._src.saving import saving import safetensors.numpy +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY SafetensorsLayout = safetensors_layout.SafetensorsLayout np_save_file = safetensors.numpy.save_file InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -55,7 +56,7 @@ def setUp(self): self.safetensors_path, metadata=self.custom_metadata, ) - saving.save_pytree(self.orbax_path, self.object_to_save) + saving.save(self.orbax_path, self.object_to_save) async def test_valid_safetensors_checkpoint(self): layout = SafetensorsLayout() @@ -103,7 +104,7 @@ async def test_load_safetensors_checkpoint(self, dtype: np.dtype): # Load the checkpoint layout = SafetensorsLayout() - restore_fn = await layout.load_pytree(test_path) + restore_fn = await layout.load(test_path) pytree = await restore_fn # Verify restored data @@ -129,7 +130,7 @@ async def test_load_fails_with_incomplete_dtypes(self): return_value=incomplete_dtypes, spec=True, ): - awaitable_fn = await layout.load_pytree(self.safetensors_path) + awaitable_fn = await layout.load(self.safetensors_path) _ = await awaitable_fn async def test_metadata(self): @@ -139,7 +140,7 @@ async def test_metadata(self): self.assertEqual( metadata.metadata, { - checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: { + STATE_CHECKPOINTABLE_KEY: { 'b': jax.ShapeDtypeStruct(shape=(3,), dtype=np.float32), 'a': jax.ShapeDtypeStruct(shape=(9,), dtype=np.int32), } @@ -194,7 +195,7 @@ async def test_validate_directory_fails_empty(self): async def test_load_directory(self): layout = SafetensorsLayout() - restore_fn = await layout.load_pytree(self.checkpoint_dir) + restore_fn = await layout.load(self.checkpoint_dir) pytree = await restore_fn np.testing.assert_array_equal(pytree['a'], self.data1['a']) np.testing.assert_array_equal(pytree['b'], self.data2['b']) @@ -204,7 +205,7 @@ async def test_load_directory(self): async def test_metadata_directory(self): layout = SafetensorsLayout() metadata = await layout.metadata(self.checkpoint_dir) - pytree_meta = metadata.metadata[checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY] + pytree_meta = metadata.metadata[STATE_CHECKPOINTABLE_KEY] self.assertIn('a', pytree_meta) self.assertIn('b', pytree_meta) self.assertIn('c', pytree_meta) @@ -226,9 +227,7 @@ async def test_load_directory_abstract_tree_all_keys(self): 'c': jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), 'd': jax.ShapeDtypeStruct(shape=(2,), dtype=np.float32), } - restore_fn = await layout.load_pytree( - self.checkpoint_dir, abstract_pytree=tree - ) + restore_fn = await layout.load(self.checkpoint_dir, abstract_state=tree) pytree = await restore_fn self.assertLen(pytree, 4) np.testing.assert_array_equal(pytree['a'], self.data1['a']) @@ -252,9 +251,7 @@ async def test_load_directory_abstract_tree_sharding(self): shape=(2,), dtype=np.int32, sharding=sharding ), } - restore_fn = await layout.load_pytree( - self.checkpoint_dir, abstract_pytree=tree - ) + restore_fn = await layout.load(self.checkpoint_dir, abstract_state=tree) pytree = await restore_fn self.assertLen(pytree, 2) np.testing.assert_array_equal( @@ -270,9 +267,7 @@ async def test_load_directory_abstract_tree_subset_one_file(self): 'a': jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), 'c': jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), } - restore_fn = await layout.load_pytree( - self.checkpoint_dir, abstract_pytree=tree - ) + restore_fn = await layout.load(self.checkpoint_dir, abstract_state=tree) pytree = await restore_fn self.assertLen(pytree, 2) self.assertIn('a', pytree) @@ -286,9 +281,7 @@ async def test_load_directory_abstract_tree_subset_many_files(self): 'a': jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), 'b': jax.ShapeDtypeStruct(shape=(2,), dtype=np.float32), } - restore_fn = await layout.load_pytree( - self.checkpoint_dir, abstract_pytree=tree - ) + restore_fn = await layout.load(self.checkpoint_dir, abstract_state=tree) pytree = await restore_fn self.assertLen(pytree, 2) self.assertIn('a', pytree) @@ -301,9 +294,7 @@ async def test_load_directory_abstract_tree_key_not_found(self): tree = { 'e': jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), } - restore_fn = await layout.load_pytree( - self.checkpoint_dir, abstract_pytree=tree - ) + restore_fn = await layout.load(self.checkpoint_dir, abstract_state=tree) with self.assertRaisesRegex(KeyError, "Tensor 'e' not found"): await restore_fn diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py index 36da779fe..76f516dbe 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py @@ -33,6 +33,7 @@ import safetensors.numpy +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY NamedSharding = sharding.NamedSharding Mesh = sharding.Mesh P = sharding.PartitionSpec @@ -63,7 +64,7 @@ def setUp(self): 'b': np.array([0, 1, 0.2], dtype=np.float32), } np_save_file(self.object_to_save, self.safetensors_path) - saving.save_pytree(self.orbax_pytree_path, self.object_to_save) + saving.save(self.orbax_pytree_path, self.object_to_save) # Create a mock Orbax checkpoint checkpointables self.checkpointables_to_save = { @@ -78,14 +79,14 @@ def test_load_safetensors_checkpoint(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS ): - pytree = loading.load_pytree(self.safetensors_path) + pytree = loading.load(self.safetensors_path) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) # TODO(b/430651483) np.testing.assert_allclose(pytree['b'], self.object_to_save['b']) def test_load_orbax_pytree_checkpoint(self): - pytree = loading.load_pytree(self.orbax_pytree_path) + pytree = loading.load(self.orbax_pytree_path) test_utils.assert_tree_equal(self, self.object_to_save, pytree) def test_load_orbax_checkpointables_checkpoint(self): @@ -99,7 +100,7 @@ def test_load_bad_path_orbax_ckpt(self, layout_enum): # User provides a directory of Orbax checkpoints, not specific one. with context_lib.Context(checkpoint_layout=layout_enum): with self.assertRaises(InvalidLayoutError): - loading.load_pytree( + loading.load( epath.Path(self.test_dir.full_path), ) @@ -110,7 +111,7 @@ def test_load_bad_path_safetensors_ckpt(self, layout_enum): # User provides a empty directory of SafeTensors checkpoints, not a file. with context_lib.Context(checkpoint_layout=layout_enum): with self.assertRaises(InvalidLayoutError): - loading.load_pytree( + loading.load( epath.Path(self.test_dir_safetensors.full_path), ) @@ -121,7 +122,7 @@ def test_load_safetensors_ckpt_from_dir(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS ): - pytree = loading.load_pytree(safetensors_dir) + pytree = loading.load(safetensors_dir) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) np.testing.assert_allclose(pytree['b'], self.object_to_save['b']) @@ -129,19 +130,17 @@ def test_load_safetensors_ckpt_from_dir(self): def test_nonexistent_path(self): # User provides a path that does not exist. with self.assertRaises(InvalidLayoutError): - loading.load_pytree( + loading.load( epath.Path(self.test_dir.full_path) / 'nonexistent_path', ) - def test_load_pytree_with_checkpoint_metadata(self): + def test_load_with_checkpoint_metadata(self): abstract_pytree = self.object_to_save metadata = metadata_types.CheckpointMetadata( path=self.orbax_pytree_path, metadata=abstract_pytree ) - loaded = loading.load_pytree( - self.orbax_pytree_path, abstract_pytree=metadata - ) + loaded = loading.load(self.orbax_pytree_path, abstract_state=metadata) test_utils.assert_tree_equal(self, self.object_to_save, loaded) def test_load_checkpointables_with_checkpoint_metadata(self): @@ -159,7 +158,7 @@ def test_load_checkpointables_with_checkpoint_metadata(self): (options_lib.CheckpointLayout.SAFETENSORS,), (options_lib.CheckpointLayout.ORBAX,), ) - def test_load_pytree_async(self, layout: options_lib.CheckpointLayout): + def test_load_async(self, layout: options_lib.CheckpointLayout): original_finalize_load = loading._LoadPyTreeResponse._finalize_load async def sleep_and_load(*args, **kwargs): @@ -183,11 +182,11 @@ async def sleep_and_load(*args, **kwargs): with context_lib.Context(checkpoint_layout=layout): if layout != options_lib.CheckpointLayout.SAFETENSORS: with self.assertRaises(NotImplementedError): - loading.load_pytree_async(directory) + loading.load_async(directory) return start = time.time() - response = loading.load_pytree_async(directory) + response = loading.load_async(directory) self.assertLess(time.time() - start, 1) loaded = response.result() @@ -200,7 +199,7 @@ def test_load_auto_resolution_mode_orbax(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.ORBAX ): - loaded_orbax = loading.load_pytree( + loaded_orbax = loading.load( self.orbax_pytree_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, ) @@ -210,7 +209,7 @@ def test_load_auto_resolution_mode_safetensors(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS ): - loaded_safe = loading.load_pytree( + loaded_safe = loading.load( self.safetensors_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, ) @@ -220,19 +219,21 @@ def test_load_auto_multiple_checkpointables_priority(self): # Save a checkpoint structure containing multiple checkpointable names. checkpointables = { 'analytics': {'a': np.array([1, 2, 3])}, - 'pytree': {'a': np.array([1, 2, 3])}, - 'state': {'b': np.array([4, 5, 6])}, + STATE_CHECKPOINTABLE_KEY: {'a': np.array([1, 2, 3])}, + 'pytree': {'b': np.array([4, 5, 6])}, } multiple_path = epath.Path(self.test_dir.full_path) / 'multi_checkpoint' saving.save_checkpointables(multiple_path, checkpointables) - # Triggering AUTO loading mode should prioritize resolving 'pytree'. + # Triggering AUTO loading mode should prioritize resolving state. with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.ORBAX ): - loaded = loading.load_pytree(multiple_path) + loaded = loading.load(multiple_path) - test_utils.assert_tree_equal(self, checkpointables['pytree'], loaded) + test_utils.assert_tree_equal( + self, checkpointables[STATE_CHECKPOINTABLE_KEY], loaded + ) def test_load_auto_non_pytree_fallback(self): # Save a checkpoint that intentionally omits the standard 'pytree' key. @@ -245,7 +246,7 @@ def test_load_auto_non_pytree_fallback(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.ORBAX ): - loaded = loading.load_pytree( + loaded = loading.load( fallback_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, ) @@ -255,8 +256,8 @@ def test_load_auto_non_pytree_fallback(self): self, custom_checkpointables['custom_state'], loaded ) - def test_load_pytree_with_abstract_mesh(self): - # Tests that load_pytree works with abstract_pytree containing an + def test_load_with_abstract_mesh(self): + # Tests that load works with abstract_state containing an # AbstractMesh, similar to the pattern used whent loading NNX models. # See: # https://flax.readthedocs.io/en/stable/guides/checkpointing.html#restore-checkpoints @@ -269,7 +270,7 @@ def test_load_pytree_with_abstract_mesh(self): data = {'w': np.ones((8, 8), dtype=np.float32)} save_path = epath.Path(self.test_dir.full_path) / 'abstract_mesh_checkpoint' - saving.save_pytree(save_path, data) + saving.save(save_path, data) # Construct an AbstractMesh and NamedSharding and an abstract_pytree that # uses it, simulating the output of nnx.eval_shape(). @@ -286,7 +287,7 @@ def test_load_pytree_with_abstract_mesh(self): # Load with a concrete mesh context and validate. concrete_mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y')) with jax.set_mesh(concrete_mesh): - loaded = loading.load_pytree(save_path, abstract_pytree=abstract_pytree) + loaded = loading.load(save_path, abstract_state=abstract_pytree) # Convert to numpy for comparison with the original numpy 'data' loaded_np = jax.tree.map(np.array, loaded) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index f22ff71fd..bee25138b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -39,7 +39,7 @@ from orbax.checkpoint.experimental.v1._src.tree import types as tree_types -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY AUTO_CHECKPOINTABLE_KEY = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY AbstractPyTree = tree_types.PyTreeOf[tree_types.AbstractLeaf] CheckpointMetadata = metadata_types.CheckpointMetadata @@ -102,9 +102,9 @@ def _resolve_abstract_mesh(leaf): return abstract_checkpointables -def load_pytree( +def load( path: path_types.PathLike, - abstract_pytree: ( + abstract_state: ( AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None ) = None, *, @@ -119,25 +119,25 @@ def load_pytree( This function must be called on all available controller processes. The operation blocks until complete. For improved performance, consider using - :py:func:`.load_pytree_async` instead. + :py:func:`.load_async` instead. - If `abstract_pytree` is not provided, the `PyTree` will be loaded exactly as + If `abstract_state` is not provided, the `PyTree` will be loaded exactly as saved. IMPORTANT: Loading is more brittle and error-prone when not providing - `abstract_pytree`. Always provide `abstract_pytree` if possible. Note that + `abstract_state`. Always provide `abstract_state` if possible. Note that you can always obtain the tree structure from a saved checkpoint using - :py:func:`.pytree_metadata`. + :py:func:`.metadata`. - Providing the `abstract_pytree` guarantees two things: + Providing the `abstract_state` guarantees two things: - 1. The restored tree will exactly match the structure of `abstract_pytree` (or + 1. The restored tree will exactly match the structure of `abstract_state` (or raise an error if it is impossible to guarantee this). For example, if - `abstract_pytree` is a custom object registered as a `PyTree`, the checkpoint + `abstract_state` is a custom object registered as a `PyTree`, the checkpoint will be restored as the same object, if possible. 2. The leaves of the restored tree will be restored with the properties - indicated by the abstract leaves. For example, if a leaf in `abstract_pytree` + indicated by the abstract leaves. For example, if a leaf in `abstract_state` is a `jax.ShapeDtypeStruct`, the restored leaf will be a `jax.Array` with the same shape and `dtype`. Each `AbstractLeaf` has a corresponding `Leaf` that is restored. See `orbax.checkpoint.v1.tree` for a table @@ -150,26 +150,26 @@ def load_pytree( path = '/tmp/my_checkpoint' # Save a checkpoint - pytree = {'a': jnp.arange(8), 'b': jnp.zeros(4)} - ocp.save_pytree(path, pytree) + state = {'a': jnp.arange(8), 'b': jnp.zeros(4)} + ocp.save(path, state) # Load the checkpoint # Highly recommended to provide the abstract pytree (structure/shapes) - abstract_pytree = jax.eval_shape(lambda: pytree) + abstract_state = jax.eval_shape(lambda: state) # Method A: Load using the abstract structure. # This automatically looks for the 'pytree' subdirectory inside 'path'. - restored = ocp.load_pytree(path, abstract_pytree) + restored = ocp.load(path, abstract_state) # Method B: Infer structure from file (Not recommended for production use) # cases or for complex trees. - restored_inferred = ocp.load_pytree(path) + restored_inferred = ocp.load(path) Args: path: The path to load the checkpoint from. This path must contain a subdirectory with name provided by `checkpointable_name`. See `checkpointable_name` for more details. - abstract_pytree: Provides a tree structure for the checkpoint to be restored + abstract_state: Provides a tree structure for the checkpoint to be restored into. May be omitted to load exactly as saved, but this is much more brittle than providing the tree. checkpointable_name: The name of the checkpointable to load. A subdirectory @@ -193,7 +193,7 @@ def load_pytree( async_origin=False, ).record_start() - abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree) + abstract_state = _standardize_abstract_checkpointables(abstract_state) validation.validate_pytree_checkpointable_name(checkpointable_name) ctx = context_lib.get_context() @@ -207,10 +207,10 @@ def load_pytree( loaded_pytree = _load_impl( path, functools.partial( - resolver.layout.load_pytree, + resolver.layout.load, path=path, checkpointable_name=resolver.pytree_name, - abstract_pytree=abstract_pytree, + abstract_state=abstract_state, ), start_time=start_time, ) @@ -232,7 +232,7 @@ def load_checkpointables( what a checkpointable is. This function can be used to load any checkpoint saved by - :py:func:`.save_checkpointables` (or :py:func:`.save_pytree`). The path should + :py:func:`.save_checkpointables` (or :py:func:`.save`). The path should contain a number of subdirectories - each of these represents the name of a checkpointable. @@ -360,13 +360,12 @@ def _load_impl( load_fn: LoadFn, start_time: float, ) -> dict[str, Checkpointable] | tree_types.PyTreeOf[tree_types.Leaf]: - """Implementation of loading logic for both :py:func:`.load_checkpointables` and :py:func:`.load_pytree`. + """Implementation of loading logic for both :py:func:`.load_checkpointables` and :py:func:`.load`. Args: path: The path to the checkpoint. load_fn: A function that returns an awaitable for loading the checkpoint - based on either :py:func:`.load_checkpointables` or - :py:func:`.load_pytree`. + based on either :py:func:`.load_checkpointables` or :py:func:`.load`. start_time: The time when the loading process started. Returns: @@ -407,7 +406,7 @@ async def _load() -> Checkpointable: class _LoadPyTreeResponse(AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]): - """An :py:class:`.AsyncResponse` for :py:func:`.load_pytree_async`.""" + """An :py:class:`.AsyncResponse` for :py:func:`.load_async`.""" def __init__( self, @@ -487,13 +486,13 @@ def result( return self._thread_runner.result(timeout=timeout) -def load_pytree_async( +def load_async( path: path_types.PathLike, - abstract_pytree: ( + abstract_state: ( AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None ) = None, *, - checkpointable_name: str | None = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str | None = STATE_CHECKPOINTABLE_KEY, ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]: """Loads a PyTree asynchronously. Currently has limited support.""" start_time = time.time() @@ -511,17 +510,17 @@ def load_pytree_async( f'layout, not {ctx.checkpoint_layout}.' ) path = ctx.file_options.path_class(path) - abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree) + abstract_state = _standardize_abstract_checkpointables(abstract_state) validation.validate_pytree_checkpointable_name(checkpointable_name) async def _blocking_load() -> Any: resolver = await layout_registry.CheckpointLayoutResolver.resolve( path, ctx.checkpoint_layout, pytree_name=checkpointable_name ) - return await resolver.layout.load_pytree( + return await resolver.layout.load( path, checkpointable_name=resolver.pytree_name, - abstract_pytree=abstract_pytree, + abstract_state=abstract_state, ) background_awaitable = asyncio_utils.run_sync(_blocking_load()) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py index acb1ffec8..bc23b7f60 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/validation.py @@ -39,7 +39,7 @@ def validate_pytree_checkpointable_name( if checkpointable_name == EMPTY_CHECKPOINTABLE_KEY: raise ValueError( 'Empty string is not supported as a checkpointable name in' - ' `load_pytree`. Checkpointable name must be a valid non-empty string' + ' `load`. Checkpointable name must be a valid non-empty string' ' name or None if loading a legacy V0 direct pytree checkpoint.' ) if checkpointable_name in RESERVED_CHECKPOINTABLE_KEYS: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py index b7641fd6e..461ac8a23 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py @@ -29,13 +29,13 @@ CheckpointMetadata = metadata_types.CheckpointMetadata InvalidLayoutError = errors.InvalidLayoutError PyTreeMetadata = metadata_types.PyTreeMetadata -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY AbstractCheckpointable = handler_types.AbstractCheckpointable -def pytree_metadata( +def metadata( path: path_types.PathLike, checkpointable_name: str | None = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, ) -> CheckpointMetadata[PyTreeMetadata]: @@ -54,7 +54,7 @@ def pytree_metadata( For example:: - metadata = ocp.pytree_metadata(path) # CheckpointMetadata[PyTreeMetadata] + metadata = ocp.metadata(path) # CheckpointMetadata[PyTreeMetadata] metadata.metadata # PyTreeMetadata metadata.init_timestamp_nsecs # Checkpoint creation timestamp. @@ -62,8 +62,8 @@ def pytree_metadata( The metadata can then be used to inform checkpoint loading. For example:: - metadata = ocp.pytree_metadata(path) - restored = ocp.load_pytree(path, metadata) + metadata = ocp.metadata(path) + restored = ocp.load(path, metadata) # Load with altered properties. def _get_abstract_array(arr): @@ -75,7 +75,7 @@ def _get_abstract_array(arr): metadata = dataclasses.replace(metadata, metadata=jax.tree.map(_get_abstract_array, metadata.metadata) ) - ocp.load_pytree(path, metadata) + ocp.load(path, metadata) Args: path: The path to the checkpoint. @@ -87,8 +87,8 @@ def _get_abstract_array(arr): dynamically discovers and resolves a pytree checkpointable. It prioritizes the standard 'pytree' checkpointable name if present, then sorts any other valid pytree checkpointable names alphabetically and returns the first - valid one, and ultimately falls back to interpreting the path as a flat - V0 root layout if no standard pytree exists. + valid one, and ultimately falls back to interpreting the path as a flat V0 + root layout if no standard pytree exists. Returns: A `CheckpointMetadata[PyTreeMetadata]` object. @@ -111,12 +111,12 @@ def _get_abstract_array(arr): # the composite handler into the layout themselves. step_metadata = _checkpointables_metadata_impl(layout, path) if resolved_name is None: - metadata = step_metadata.metadata + tree_metadata = step_metadata.metadata else: - metadata = step_metadata.metadata[resolved_name] + tree_metadata = step_metadata.metadata[resolved_name] return CheckpointMetadata[PyTreeMetadata]( path=path, - metadata=metadata, + metadata=tree_metadata, init_timestamp_nsecs=step_metadata.init_timestamp_nsecs, commit_timestamp_nsecs=step_metadata.commit_timestamp_nsecs, custom_metadata=step_metadata.custom_metadata, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py index 18e44b50f..ebb470097 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py @@ -36,7 +36,7 @@ AbstractFoo = handler_utils.AbstractFoo AbstractBar = handler_utils.AbstractBar InvalidLayoutError = ocp.errors.InvalidLayoutError -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY class PyTreeMetadataTest(absltest.TestCase): @@ -45,7 +45,7 @@ def setUp(self): super().setUp() self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' self.pytree, self.abstract_pytree = array_test_utils.create_numpy_pytree() - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) def _create_value_metadata(self, value): if isinstance(value, np.ndarray): @@ -67,42 +67,40 @@ def _create_value_metadata(self, value): def test_invalid_path(self): with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata(self.directory.parent) + ocp.metadata(self.directory.parent) with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata(self.directory.parent / 'foo') + ocp.metadata(self.directory.parent / 'foo') - def test_pytree_metadata_default_checkpointable_name(self): - expected_pytree_metadata = jax.tree.map( + def test_metadata_default_checkpointable_name(self): + expected_metadata = jax.tree.map( self._create_value_metadata, self.pytree ) - metadata = ocp.pytree_metadata(self.directory) + metadata = ocp.metadata(self.directory) self.assertIsInstance(metadata, metadata_types.CheckpointMetadata) - self.assertEqual(expected_pytree_metadata, metadata.metadata) + self.assertEqual(expected_metadata, metadata.metadata) - def test_pytree_metadata_custom_checkpointable_name(self): + def test_metadata_custom_checkpointable_name(self): self.directory.rmtree() custom_name = 'custom_pytree' ocp.save_checkpointables(self.directory, {custom_name: self.pytree}) - expected_pytree_metadata = jax.tree.map( + expected_metadata = jax.tree.map( self._create_value_metadata, self.pytree ) - metadata = ocp.pytree_metadata( - self.directory, checkpointable_name=custom_name - ) + metadata = ocp.metadata(self.directory, checkpointable_name=custom_name) self.assertIsInstance(metadata, metadata_types.CheckpointMetadata) - self.assertEqual(expected_pytree_metadata, metadata.metadata) + self.assertEqual(expected_metadata, metadata.metadata) - def test_pytree_metadata_checkpointable_name_none(self): - pytree_dir = self.directory / PYTREE_CHECKPOINTABLE_KEY + def test_metadata_checkpointable_name_none(self): + pytree_dir = self.directory / STATE_CHECKPOINTABLE_KEY with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata(pytree_dir, checkpointable_name=None) + ocp.metadata(pytree_dir, checkpointable_name=None) def test_load_with_metadata(self): - metadata = ocp.pytree_metadata(self.directory) + metadata = ocp.metadata(self.directory) def _set_numpy_cast_type(x): if isinstance(x, np.ndarray): @@ -124,15 +122,15 @@ def _set_numpy_cast_type(x): ) expected_pytree = jax.tree.map(_set_numpy_cast_type, self.pytree) - with self.subTest('pytree_metadata'): - loaded_pytree = ocp.load_pytree(self.directory, metadata.metadata) + with self.subTest('metadata'): + loaded_pytree = ocp.load(self.directory, metadata.metadata) test_utils.assert_tree_equal(self, expected_pytree, loaded_pytree) with self.subTest('full_metadata'): - loaded_pytree = ocp.load_pytree(self.directory, metadata) + loaded_pytree = ocp.load(self.directory, metadata) test_utils.assert_tree_equal(self, expected_pytree, loaded_pytree) - def test_pytree_metadata_safetensors(self): + def test_metadata_safetensors(self): st_path = epath.Path(self.create_tempdir().full_path) / 'model.safetensors' tensor_data = { 'x': np.array([[1.0, 2.0]], dtype=np.float32), @@ -149,7 +147,7 @@ def test_pytree_metadata_safetensors(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS ): - ckpt_metadata = ocp.pytree_metadata(st_path) + ckpt_metadata = ocp.metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) self.assertEqual( @@ -168,10 +166,10 @@ def test_pytree_metadata_safetensors(self): with context_lib.Context( checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS ): - ocp.pytree_metadata(self.directory) + ocp.metadata(self.directory) - def test_pytree_metadata_with_incompatible_item(self): + def test_metadata_with_incompatible_item(self): self.directory.rmtree() # Save a valid PyTree to 'state' ocp.save_checkpointables(self.directory, {'state': self.pytree}) @@ -180,7 +178,7 @@ def test_pytree_metadata_with_incompatible_item(self): (self.directory / 'datasets').mkdir() (self.directory / 'datasets' / 'data.txt').write_text('some data') - metadata = ocp.pytree_metadata(self.directory, checkpointable_name='state') + metadata = ocp.metadata(self.directory, checkpointable_name='state') self.assertIsInstance(metadata, metadata_types.CheckpointMetadata) self.assertIsInstance(metadata.metadata, dict) self.assertSetEqual( @@ -253,9 +251,9 @@ def test_checkpointables_metadata_safetensors(self): ckpt_metadata = ocp.checkpointables_metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) - self.assertIn(PYTREE_CHECKPOINTABLE_KEY, ckpt_metadata.metadata) + self.assertIn(STATE_CHECKPOINTABLE_KEY, ckpt_metadata.metadata) - st_pytree_metadata = ckpt_metadata.metadata[PYTREE_CHECKPOINTABLE_KEY] + st_pytree_metadata = ckpt_metadata.metadata[STATE_CHECKPOINTABLE_KEY] self.assertEqual(st_pytree_metadata.keys(), expected_st_metadata.keys()) for key, expected_sds in expected_st_metadata.items(): actual_sds = st_pytree_metadata[key] diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py index 8ef7e6207..ceb97699c 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py @@ -56,12 +56,12 @@ class CheckpointMetadata(Generic[CheckpointableMetadataT]): `dict[str, AbstractCheckpointable]`. `CheckpointMetadata` can be accessed via one of two metadata methods. Please - see :py:func:`.pytree_metadata` and :py:func:`.checkpointables_metadata` for + see :py:func:`.metadata` and :py:func:`.checkpointables_metadata` for more information and usage instructions. If the checkpoint contains a `PyTree`, this metadata can be accessed via:: - metadata = ocp.pytree_metadata(path) + metadata = ocp.metadata(path) # Inspect various properties metadata.init_timestamp_nsecs diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index f36a01801..3e705f5b8 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -35,7 +35,7 @@ from orbax.checkpoint.experimental.v1._src.tree import types as tree_types -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY StatefulCheckpointableHandler = ( stateful_checkpointable_handler.StatefulCheckpointableHandler @@ -46,7 +46,7 @@ class _PartialSavePyTree(handler_types.StatefulCheckpointable): """Wraps a PyTree to signal that it should be saved in partial mode.""" - pytree: tree_types.PyTree + state: tree_types.PyTree def __post_init__(self): self.handler = pytree_handler.PyTreeHandler(partial_save_mode=True) @@ -54,15 +54,15 @@ def __post_init__(self): async def save( self, directory: path_types.PathAwaitingCreation ) -> Awaitable[None]: - return await self.handler.save(directory, self.pytree) + return await self.handler.save(directory, self.state) async def load(self, directory: path_types.Path) -> Awaitable[None]: raise NotImplementedError('Partial load is not supported via this wrapper.') -def save_pytree( +def save( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, custom_metadata: tree_types.JsonType | None = None, ): @@ -83,7 +83,7 @@ def save_pytree( ### Workflow A typical partial save workflow involves one or more calls to - :py:func:`.save_pytree` followed by a single call to :py:func:`~.finalize`:: + :py:func:`.save` followed by a single call to :py:func:`~.finalize`:: path = '/path/to/my/checkpoint' @@ -91,12 +91,12 @@ def save_pytree( # '/path/to/my/checkpoint.partial_save' # Note: the exact temporary directory name is an implementation detail that # depends on the file system and should not be relied on. - ocp.partial.save_pytree(path, {'layer1': ..., 'step': 1}) + ocp.partial.save(path, {'layer1': ..., 'step': 1}) # A subsequent call reads the previous version and applies new updates # to the temporary directory: # '/path/to/my/checkpoint.partial_save' - ocp.partial.save_pytree(path, {'layer2': ..., 'metrics': ...}) + ocp.partial.save(path, {'layer2': ..., 'metrics': ...}) # This call commits the latest version to the final destination at # '/path/to/my/checkpoint'. @@ -104,42 +104,42 @@ def save_pytree( ### Additions vs. Replacements - The provided `pytree` represents a set of updates. - - If a key in `pytree` (e.g., 'metrics') does not exist in the on-disk + The provided `state` represents a set of updates. + - If a key in `state` (e.g., 'metrics') does not exist in the on-disk checkpoint, it is treated as an **addition**. In other words, the sets of - keys of the on-disk PyTree and the provided `pytree` are disjoint. + keys of the on-disk PyTree and the provided `state` are disjoint. - If a key (e.g., 'step') already exists, its value is **replaced**. In other - words, the sets of keys of the on-disk PyTree and the provided `pytree` + words, the sets of keys of the on-disk PyTree and the provided `state` overlap. Replacements are currently NOT supported. Please reach out to the Orbax team if you need this functionality. - See :py:func:`~.v1.save_pytree` for general + See :py:func:`~.v1.save` for general PyTree saving documentation. Args: path: The path to save the checkpoint to. - pytree: A PyTree representing the additions to be applied to the on-disk + state: A PyTree representing the additions to be applied to the on-disk checkpoint. custom_metadata: User-provided custom metadata. This will be merged with any existing custom metadata. Values from this dictionary will overwrite existing values if keys conflict. """ - save_pytree_async( + save_async( path, - pytree, + state, custom_metadata=custom_metadata, ).result() -def save_pytree_async( +def save_async( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, custom_metadata: tree_types.JsonType | None = None, ) -> async_types.AsyncResponse[None]: """Partially saves a PyTree asynchronously. - Unlike :py:func:`.save_pytree`, this function returns an + Unlike :py:func:`.save`, this function returns an :py:class:`.AsyncResponse` immediately after scheduling the save operation. The actual writing to disk happens in a background thread. You can use `response.result()` to block @@ -161,17 +161,17 @@ def save_pytree_async( ### Workflow A typical partial save workflow involves one or more calls to - :py:func:`.save_pytree_async` followed by a single call to + :py:func:`.save_async` followed by a single call to :py:func:`.finalize`:: path = '/path/to/my/checkpoint' # The first call creates a temporary directory and returns immediately. - response1 = ocp.partial.save_pytree_async(path, {'layer1': ..., 'step': 1}) + response1 = ocp.partial.save_async(path, {'layer1': ..., 'step': 1}) # A subsequent call also returns immediately. Orbax ensures that this # operation waits for the first one to complete before starting. - response2 = ocp.partial.save_pytree_async( + response2 = ocp.partial.save_async( path, {'layer2': ..., 'metrics': ...} ) @@ -185,23 +185,22 @@ def save_pytree_async( ### Additions vs. Replacements - The provided `pytree` represents a set of updates. - - If a key in `pytree` (e.g., 'metrics') does not exist in the on-disk + The provided `state` represents a set of updates. + - If a key in `state` (e.g., 'metrics') does not exist in the on-disk checkpoint, it is treated as an **addition**. - If a key (e.g., 'step') already exists, its value is **replaced**. Replacements are currently NOT supported. Please reach out to the Orbax team if you need this functionality. - See :py:func:`~.v1.save_pytree_async` for general + See :py:func:`~.v1.save_async` for general PyTree saving documentation. Args: path: The path to save the checkpoint to. - pytree: The PyTree to save. This may be any JAX PyTree (including custom - objects registered as PyTrees) consisting of supported leaf types (see - :py:class:`~.v1.tree.Leaf`). Default supported leaf types include - `jax.Array`, `np.ndarray`, simple types like - `int`, `float`, `str`, and empty nodes. + state: The PyTree to save. This may be any JAX PyTree consisting of + supported leaf types (see :py:class:`~.v1.tree.Leaf`). + Default supported leaf types include `jax.Array`, `np.ndarray`, + simple types like `int`, `float`, `str`, and empty nodes. custom_metadata: User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax. @@ -222,7 +221,7 @@ def save_pytree_async( return execution.save_checkpointables_impl( partial_path_lib.add_partial_save_suffix(path), - {PYTREE_CHECKPOINTABLE_KEY: _PartialSavePyTree(pytree)}, + {STATE_CHECKPOINTABLE_KEY: _PartialSavePyTree(state)}, overwrite=False, custom_metadata=custom_metadata, async_origin=True, @@ -235,7 +234,7 @@ def finalize(path: path_types.PathLike) -> None: This function commits all changes made during a partial save session, concluding the transaction. It should be called once after all desired - :py:func:`.save_pytree` operations are complete. + :py:func:`.save` operations are complete. The finalization process is atomic. It renames the temporary, versioned partial save directory to the final target `path`, making the updated @@ -250,7 +249,7 @@ def finalize(path: path_types.PathLike) -> None: path = '/path/to/my/checkpoint' # These calls write to a temporary, versioned directory, not the final path. - ocp.partial.save_pytree(path, {'step': 1}) + ocp.partial.save(path, {'step': 1}) ocp.partial.save_checkpointables(path, {'metrics': ...}) # This call performs the atomic rename, making the checkpoint available at @@ -259,13 +258,13 @@ def finalize(path: path_types.PathLike) -> None: Args: path: The final, target path of the checkpoint to be finalized. This should - be the same path that was passed to :py:func:`~.save_pytree` calls. + be the same path that was passed to :py:func:`~.save` calls. Raises: FileExistsError: If a finalized checkpoint already exists at `path`. To overwrite, it must be deleted first. FileNotFoundError: If no partial save session is found for the given `path`. - This can happen if :py:func:`.save_pytree` was not called first. + This can happen if :py:func:`.save` was not called first. """ context = context_lib.get_context() path = context.file_options.path_class(path) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index f5f20373b..b5fcdea9e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -85,7 +85,7 @@ def add_internal_checkpointables( class _SaveResponse(AsyncResponse[None]): - """An :py:class:`.AsyncResponse` representing the result of:py:func:`.save_pytree_async`.""" + """An :py:class:`.AsyncResponse` representing the result of :py:func:`.save_async`.""" def __init__( self, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py index 7775a345a..466ccd1d3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py @@ -29,27 +29,27 @@ from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY Checkpointable = handler_types.Checkpointable -def save_pytree( +def save( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, - checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str = STATE_CHECKPOINTABLE_KEY, overwrite: bool = False, custom_metadata: tree_types.JsonType | None = None, ): """Saves a `PyTree`. The operation blocks until complete. For improved performance, consider using - :py:func:`.save_pytree_async` instead. This function should be called on + :py:func:`.save_async` instead. This function should be called on all available controller processes. Example usage: Simple save of a dictionary containing JAX arrays:: - pytree = { + state = { 'params': { 'w': jnp.ones((8, 8)), 'b': jnp.zeros(8), @@ -57,11 +57,11 @@ def save_pytree( 'step': 100 } # Saves to /tmp/my_checkpoint/ - ocp.save_pytree('/tmp/my_checkpoint', pytree) + ocp.save('/tmp/my_checkpoint', state) Args: path: The path to save the checkpoint to. - pytree: The `PyTree` to save. This may be any JAX `PyTree` (including custom + state: The `PyTree` to save. This may be any JAX `PyTree` (including custom objects registered as `PyTrees`) consisting of supported leaf types. See `orbax.checkpoint.experimental.v1.tree` for a table of standard supported leaf types. @@ -75,7 +75,7 @@ def save_pytree( """ execution.save_checkpointables_impl( path, - {checkpointable_name: pytree}, + {checkpointable_name: state}, overwrite=overwrite, custom_metadata=custom_metadata, async_origin=False, @@ -147,17 +147,17 @@ def save_checkpointables( # TODO(b/396190818): Test modification of the context by the user after the # save operation is scheduled. -def save_pytree_async( +def save_async( path: path_types.PathLike, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, - checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str = STATE_CHECKPOINTABLE_KEY, overwrite: bool = False, custom_metadata: tree_types.JsonType | None = None, ) -> async_types.AsyncResponse[None]: """Saves a `PyTree` asynchronously. - Unlike :py:func:`.save_pytree`, this function returns immediately after the + Unlike :py:func:`.save`, this function returns immediately after the save operation is scheduled (except for certain operations, like device-to-host copying of on-device arrays, which must happen on the main thread). Further writing operations @@ -172,7 +172,7 @@ def save_pytree_async( Example usage: Simple save of a dictionary containing JAX arrays asynchronously:: - pytree = { + state = { 'params': { 'w': jnp.ones((8, 8)), 'b': jnp.zeros(8), @@ -180,8 +180,8 @@ def save_pytree_async( 'step': 100 } # Saves to /tmp/my_checkpoint/ - future = ocp.experimental.v1.save_pytree_async( - '/tmp/my_checkpoint', pytree + future = ocp.experimental.v1.save_async( + '/tmp/my_checkpoint', state ) # Perform other work here... @@ -191,7 +191,7 @@ def save_pytree_async( Args: path: The path to save the checkpoint to. - pytree: The `PyTree` to save. This may be any JAX `PyTree` (including custom + state: The `PyTree` to save. This may be any JAX `PyTree` (including custom objects registered as `PyTrees`) consisting of supported leaf types. See `orbax.checkpoint.v1.tree` for a table of standard supported leaf types. checkpointable_name: The name of the checkpointable to save a pytree under. @@ -208,7 +208,7 @@ def save_pytree_async( """ return execution.save_checkpointables_impl( path, - {checkpointable_name: pytree}, + {checkpointable_name: state}, overwrite=overwrite, custom_metadata=custom_metadata, async_origin=True, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py index 54f601de9..b46455688 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py @@ -30,6 +30,9 @@ from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils +STATE_CHECKPOINTABLE_KEY = checkpoint_layout_lib.STATE_CHECKPOINTABLE_KEY + + CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -58,7 +61,9 @@ def setUp(self) -> None: def setup_registry(self) -> registration.CheckpointableHandlerRegistry: """Ensures we only have what we explicitly add.""" registry = ocp.handlers.local_registry(include_global_registry=False) - registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='pytree') + registry.add( + ocp.handlers.PyTreeHandler, checkpointable_name=STATE_CHECKPOINTABLE_KEY + ) return registry def _determine_expected_outcome( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py index ffe275ea0..2ae9c5396 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py @@ -31,6 +31,9 @@ from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils +STATE_CHECKPOINTABLE_KEY = checkpoint_layout_lib.STATE_CHECKPOINTABLE_KEY + + CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -101,7 +104,11 @@ def setup_registry( registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state') registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata') - registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='pytree') + if not registry.has(STATE_CHECKPOINTABLE_KEY): + registry.add( + ocp.handlers.PyTreeHandler, + checkpointable_name=STATE_CHECKPOINTABLE_KEY, + ) return registry diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py index c9a63def2..e7ae95296 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for V1 load_pytree API against generated V0 and V1 Checkpoints.""" +"""Tests for V1 load API against generated V0 and V1 Checkpoints.""" import os from typing import Tuple, Type @@ -33,6 +33,7 @@ from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils +STATE_CHECKPOINTABLE_KEY = checkpoint_layout_lib.STATE_CHECKPOINTABLE_KEY CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -41,7 +42,7 @@ class LoadPytreeCompatibilityTestBase(parameterized.TestCase): - """Tests for V1 load_pytree API against generated Checkpoints.""" + """Tests for V1 load API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -100,11 +101,14 @@ def setup_registry( else: registry.add(ocp.handlers.PyTreeHandler, checkpointable_name=path.name) - if pytree_registered: + if pytree_registered and not registry.has(STATE_CHECKPOINTABLE_KEY): # Register to scoped 'pytree' handler for fallback resolution. # Note this should standardly be present, though testing its presence to # ensure resolution works as expected without always relying on it. - registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='pytree') + registry.add( + ocp.handlers.PyTreeHandler, + checkpointable_name=STATE_CHECKPOINTABLE_KEY, + ) return registry @@ -174,8 +178,9 @@ def _determine_expected_outcome( return ( registration.NoEntryError, ( - r'Could not resolve a handler for .* and no \'pytree\' handler' - r' found in .*' + r'Could not resolve a handler for .* and no \'' + + STATE_CHECKPOINTABLE_KEY + + r'\' handler found in .*' ), ) @@ -192,7 +197,7 @@ def _determine_expected_outcome( handler_registered=[True, False], pytree_registered=[True, False], ) - def test_load_pytree_compatibility( + def test_load_compatibility( self, version: str, checkpointable_name: str | None, @@ -204,13 +209,13 @@ def test_load_pytree_compatibility( handler_registered: bool, pytree_registered: bool, ) -> None: - """Tests load_pytree against various checkpoint configurations. + """Tests load against various checkpoint configurations. Args: version: The checkpoint version to test against. checkpointable_name: The name of the checkpointable to load. abstract_pytree_provided: Whether an abstract pytree is provided to - ocp.load_pytree. + ocp.load. name_registered: Whether a handler is registered for the checkpointable_name. metadata_present: Whether the checkpoint has metadata. @@ -263,18 +268,18 @@ def test_load_pytree_compatibility( ) ): if error_type is None: - loaded = ocp.load_pytree( + loaded = ocp.load( path, checkpointable_name=checkpointable_name, - abstract_pytree=actual_abstract_pytree, + abstract_state=actual_abstract_pytree, ) test_utils.assert_tree_equal(self, loaded, self.expected_state) else: with self.assertRaisesRegex(error_type, expected_error_msg): - ocp.load_pytree( + ocp.load( path, checkpointable_name=checkpointable_name, - abstract_pytree=actual_abstract_pytree, + abstract_state=actual_abstract_pytree, ) @parameterized.product( @@ -290,10 +295,10 @@ def test_load_pytree_compatibility( 'missing_pytree_data_file__sharding', ], ) - def test_load_pytree_non_critical_corruptions( + def test_load_non_critical_corruptions( self, version: str, alteration: str ) -> None: - """Tests load_pytree against checkpoints with non-critical corruptions. + """Tests load against checkpoints with non-critical corruptions. Args: version: The version of the checkpoint to load. @@ -305,8 +310,8 @@ def test_load_pytree_non_critical_corruptions( 'non_critical_metadata_alterations', alteration, ) - loaded = ocp.load_pytree( - path, abstract_pytree=self.abstract_state, checkpointable_name='state' + loaded = ocp.load( + path, abstract_state=self.abstract_state, checkpointable_name='state' ) test_utils.assert_tree_equal(self, loaded, self.expected_state) @@ -317,10 +322,10 @@ def test_load_pytree_non_critical_corruptions( 'missing_pytree_data_dir_d', ], ) - def test_load_pytree_critical_corruptions( + def test_load_critical_corruptions( self, version: str, alteration: str ) -> None: - """Tests load_pytree against checkpoints with critical corruptions. + """Tests load against checkpoints with critical corruptions. Args: version: The version of the checkpoint to load. @@ -335,17 +340,17 @@ def test_load_pytree_critical_corruptions( error_type = ValueError error_msg = r'Error opening .* driver:' with self.assertRaisesRegex(error_type, error_msg): - ocp.load_pytree( + ocp.load( path, checkpointable_name='state', - abstract_pytree=self.abstract_state, + abstract_state=self.abstract_state, ) @parameterized.product( version=['v0', 'v1'], ) def test_load_incorrect_path(self, version: str) -> None: - """Tests load_pytree against checkpoints with incorrect paths. + """Tests load against checkpoints with incorrect paths. Args: version: The version of the checkpoint to test against. @@ -363,12 +368,12 @@ def test_load_incorrect_path(self, version: str) -> None: InvalidLayoutError, r'Could not recognize the checkpoint at .* as a valid Orbax checkpoint' ): - ocp.load_pytree(child_path, checkpointable_name='state') + ocp.load(child_path, checkpointable_name='state') with self.assertRaisesRegex( InvalidLayoutError, r'Could not recognize the checkpoint at .* as a valid Orbax checkpoint' ): - ocp.load_pytree(parent_path, checkpointable_name='state') + ocp.load(parent_path, checkpointable_name='state') if __name__ == '__main__': diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py index 8cfeb3e75..2238b3aaf 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py @@ -117,16 +117,7 @@ def setUp(self) -> None: def setup_registry(self) -> registration.CheckpointableHandlerRegistry: """Sets up a registry for the test.""" - registry = ocp.handlers.local_registry() - registry.add( - ocp.handlers.PyTreeHandler, - checkpointable_name='state', - secondary_typestrs=[ - 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler', - ], - ) - - return registry + return ocp.handlers.local_registry() def _create_temporary_checkpoint( self, @@ -270,7 +261,7 @@ def test_root_metadata( step.standard_name_format(step_prefix='checkpoint'), ], ) - def test_pytree_metadata( + def test_metadata( self, version: str, metrics_status: bool, @@ -301,7 +292,7 @@ def test_pytree_metadata( checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) - metadata = checkpointer.pytree_metadata(0) + metadata = checkpointer.metadata(0) self.assertIsInstance(metadata, CheckpointMetadata) actual = compatibility_test_utils.strip_sharding_metadata( metadata.metadata @@ -425,14 +416,14 @@ def test_load_checkpointables( step.standard_name_format(step_prefix='checkpoint'), ], ) - def test_load_pytree( + def test_load( self, version: str, metrics_status: bool, root_metadata_status: bool, name_format: step.NameFormat | None = None, ) -> None: - """Verifies load_pytree API against generated Checkpoints. + """Verifies load API against generated Checkpoints. Args: version: The checkpoint version to load. @@ -456,7 +447,7 @@ def test_load_pytree( checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) - loaded = checkpointer.load_pytree( - 0, abstract_pytree=self.abstract_state, checkpointable_name='state' + loaded = checkpointer.load( + 0, abstract_state=self.abstract_state, checkpointable_name='state' ) test_utils.assert_tree_equal(self, self.expected_state, loaded) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py index 986c4633e..51427199c 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for V1 pytree_metadata API against generated V0 and V1 Checkpoints.""" +"""Tests for V1 metadata API against generated V0 and V1 Checkpoints.""" import os from typing import Tuple, Type @@ -30,6 +30,9 @@ from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils +STATE_CHECKPOINTABLE_KEY = checkpoint_layout_lib.STATE_CHECKPOINTABLE_KEY + + CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -38,7 +41,7 @@ class PytreeMetadataCompatibilityTestBase(parameterized.TestCase): - """Tests for V1 pytree_metadata API against generated Checkpoints.""" + """Tests for V1 metadata API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -87,7 +90,11 @@ def setup_registry( registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state') registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata') - registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='pytree') + if not registry.has(STATE_CHECKPOINTABLE_KEY): + registry.add( + ocp.handlers.PyTreeHandler, + checkpointable_name=STATE_CHECKPOINTABLE_KEY, + ) return registry def _determine_expected_outcome( @@ -146,7 +153,7 @@ def _determine_expected_outcome( is_pytree=[True, False], handler_registered=[True, False], ) - def test_pytree_metadata_compatibility( + def test_metadata_compatibility( self, version: str, checkpointable_name: str | None, @@ -156,7 +163,7 @@ def test_pytree_metadata_compatibility( is_pytree: bool, handler_registered: bool, ) -> None: - """Tests pytree_metadata compatibility across V0 and V1 checkpoints. + """Tests metadata compatibility across V0 and V1 checkpoints. Args: version: The checkpoint version to test against. @@ -192,7 +199,7 @@ def test_pytree_metadata_compatibility( ) ): if error_type is None: - loaded = ocp.pytree_metadata( + loaded = ocp.metadata( path, checkpointable_name=checkpointable_name, ) @@ -204,7 +211,7 @@ def test_pytree_metadata_compatibility( test_utils.assert_tree_equal(self, expected, actual) else: with self.assertRaisesRegex(error_type, error_msg): - ocp.pytree_metadata( + ocp.metadata( path, checkpointable_name=checkpointable_name, ) @@ -221,10 +228,10 @@ def test_pytree_metadata_compatibility( 'missing_pytree_data_dir_array_metadatas', ], ) - def test_pytree_metadata_non_critical_corruptions( + def test_metadata_non_critical_corruptions( self, version: str, alteration: str ) -> None: - """Tests pytree_metadata with non-critical corruptions. + """Tests metadata with non-critical corruptions. Args: version: The checkpoint version to test against. @@ -236,7 +243,7 @@ def test_pytree_metadata_non_critical_corruptions( 'non_critical_metadata_alterations', alteration, ) - loaded = ocp.pytree_metadata(path, checkpointable_name='state') + loaded = ocp.metadata(path, checkpointable_name='state') expected = self.expected_state_metadata actual = loaded.metadata if multihost.is_pathways_backend() or jax.process_count() > 1: @@ -247,10 +254,10 @@ def test_pytree_metadata_non_critical_corruptions( @parameterized.product( version=['v0', 'v1'], ) - def test_pytree_metadata_missing_sharding_corruption( + def test_metadata_missing_sharding_corruption( self, version: str ) -> None: - """Tests pytree_metadata with missing sharding corruption. + """Tests metadata with missing sharding corruption. Args: version: The checkpoint version to test against. @@ -263,7 +270,7 @@ def test_pytree_metadata_missing_sharding_corruption( ) # Missing sharding metadata results in a pytree identical to expected # values except sharding metadata is None. - loaded = ocp.pytree_metadata(path, checkpointable_name='state') + loaded = ocp.metadata(path, checkpointable_name='state') self.assertIsNone(loaded.metadata['a'].sharding_metadata) @parameterized.product( @@ -273,10 +280,10 @@ def test_pytree_metadata_missing_sharding_corruption( 'missing_pytree_data_dir_d', ], ) - def test_pytree_metadata_critical_corruptions( + def test_metadata_critical_corruptions( self, version: str, alteration: str ) -> None: - """Tests pytree_metadata with critical corruptions. + """Tests metadata with critical corruptions. Args: version: The checkpoint version to test against. @@ -289,7 +296,7 @@ def test_pytree_metadata_critical_corruptions( alteration, ) # Doesnt fail as we are just accessing the metadata. - loaded = ocp.pytree_metadata(path, checkpointable_name='state') + loaded = ocp.metadata(path, checkpointable_name='state') self.assertIsNone(loaded.metadata) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 8987b9385..cf2d9f80a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -49,6 +49,7 @@ from orbax.checkpoint.experimental.v1._src.testing import tree_utils as tree_test_utils from orbax.checkpoint.experimental.v1._src.tree import types as tree_types + PyTree = tree_types.PyTree Path = path_types.Path InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -58,7 +59,7 @@ create_sharded_pytree = array_test_utils.create_sharded_pytree as_abstract_type = array_test_utils.as_abstract_type -PYTREE_CHECKPOINTABLE_KEY = 'pytree' +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY PLACEHOLDER = tree_types.PLACEHOLDER Foo = handler_utils.Foo @@ -98,18 +99,18 @@ def tearDown(self): def save_and_wait(self, *args, use_async: bool, **kwargs): if use_async: - response = ocp.save_pytree_async(*args, **kwargs) + response = ocp.save_async(*args, **kwargs) self.assertIsNotNone(response) self.assertIsNone(response.result()) else: - self.assertIsNone(ocp.save_pytree(*args, **kwargs)) + self.assertIsNone(ocp.save(*args, **kwargs)) def load_and_wait(self, *args, use_async: bool, **kwargs): del use_async - return ocp.load_pytree(*args, **kwargs) + return ocp.load(*args, **kwargs) @parameterized.parameters((True,), (False,)) - def test_save_load_pytree(self, use_async): + def test_save_load(self, use_async): self.save_and_wait(self.directory, self.pytree, use_async=use_async) loaded = self.load_and_wait( self.directory, self.abstract_pytree, use_async=use_async @@ -124,7 +125,7 @@ def test_load_default(self, use_async): loaded = self.load_and_wait(self.directory, use_async=use_async) test_utils.assert_tree_equal(self, self.pytree, loaded) - def test_save_pytree_async(self): + def test_save_async(self): start_serialize = threading.Event() original_serialize = serialization_v0.async_serialize_from_host @@ -141,7 +142,7 @@ def mock_serialize(*args, **kwargs): ) ) - response = ocp.save_pytree_async(self.directory, self.pytree) + response = ocp.save_async(self.directory, self.pytree) initial_d_files_mtimes = tree_test_utils.get_d_files_mtimes( self.directory ) @@ -158,7 +159,7 @@ def mock_serialize(*args, **kwargs): tree_test_utils.is_pytree_checkpoint_complete(self.directory) ) - restored = ocp.load_pytree( + restored = ocp.load( self.directory, self.abstract_pytree, ) @@ -207,7 +208,7 @@ async def mock_finalize(self_handler, directory): self.enter_context(context) start = time.time() - response = ocp.save_pytree_async(self.directory / 'timeout', self.pytree) + response = ocp.save_async(self.directory / 'timeout', self.pytree) is_primary = multihost.is_primary_host(0) msg = expected_msg_primary if is_primary else expected_msg_non_primary @@ -227,13 +228,13 @@ async def mock_finalize(self_handler, directory): ) def test_empty_tree(self, tree): with self.assertRaisesRegex(ValueError, 'Found empty item'): - ocp.save_pytree(self.directory, tree) + ocp.save(self.directory, tree) def test_none_tree(self): with self.assertRaisesRegex( ValueError, 'checkpointable must not be None for saving' ): - ocp.save_pytree(self.directory, None) + ocp.save(self.directory, None) # Note the ommission of jax.Array, since this is covered in # several other tests. @@ -245,8 +246,8 @@ def test_none_tree(self): (np.asarray(3.14),), ) def test_standard_leaf_types(self, value): - ocp.save_pytree(self.directory, dict(k=value)) - loaded = ocp.load_pytree(self.directory) + ocp.save(self.directory, dict(k=value)) + loaded = ocp.load(self.directory) if isinstance(value, np.ndarray): np.testing.assert_array_equal(loaded['k'], value) else: @@ -276,14 +277,14 @@ def test_jax_array_leaf_types(self): } for k, v in values.items(): with self.subTest(k): - ocp.save_pytree(self.directory / k, [v]) + ocp.save(self.directory / k, [v]) with self.subTest('with_abstract_pytree'): - loaded = ocp.load_pytree(self.directory / k, [as_abstract_type(v)]) + loaded = ocp.load(self.directory / k, [as_abstract_type(v)]) test_utils.assert_tree_equal(self, [v], loaded) with self.subTest('without_abstract_pytree'): if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = ocp.load_pytree(self.directory / k) + loaded = ocp.load(self.directory / k) test_utils.assert_tree_equal(self, [v], loaded) @parameterized.parameters( @@ -294,13 +295,9 @@ def test_jax_array_leaf_types(self): (np.asarray(3.14),), ) def test_standard_leaf_types_as_checkpointable(self, value): - with self.subTest('save_pytree'): - ocp.save_pytree( - self.directory / 'pytree', value, checkpointable_name='leaf' - ) - loaded = ocp.load_pytree( - self.directory / 'pytree', checkpointable_name='leaf' - ) + with self.subTest('save'): + ocp.save(self.directory / 'pytree', value, checkpointable_name='leaf') + loaded = ocp.load(self.directory / 'pytree', checkpointable_name='leaf') if isinstance(value, np.ndarray): np.testing.assert_array_equal(loaded, value) else: @@ -325,13 +322,9 @@ def test_jax_array_as_checkpointable(self): jax.sharding.PartitionSpec(), ), ) - with self.subTest('save_pytree'): - ocp.save_pytree( - self.directory / 'pytree', value, checkpointable_name='leaf' - ) - loaded = ocp.load_pytree( - self.directory / 'pytree', checkpointable_name='leaf' - ) + with self.subTest('save'): + ocp.save(self.directory / 'pytree', value, checkpointable_name='leaf') + loaded = ocp.load(self.directory / 'pytree', checkpointable_name='leaf') test_utils.assert_tree_equal(self, value, loaded) with self.subTest('save_checkpointables'): ocp.save_checkpointables( @@ -344,7 +337,7 @@ def test_jax_array_as_checkpointable(self): def test_save_unregistered_type_as_pytree(self): with self.assertRaises(serialization_registry.UnregisteredTypeError): - ocp.save_pytree(self.directory, handler_utils.Foo(1, 'hi')) + ocp.save(self.directory, handler_utils.Foo(1, 'hi')) @parameterized.parameters( ({},), @@ -364,51 +357,45 @@ def test_leaf_change_type(self): with self.subTest('numpy_to_jax'): subdir = 'numpy' - ocp.save_pytree(self.directory / subdir, [numpy_arr]) + ocp.save(self.directory / subdir, [numpy_arr]) test_utils.assert_tree_equal( self, [jax_arr], - ocp.load_pytree( - self.directory / subdir, [as_abstract_type(jax_arr)] - ), + ocp.load(self.directory / subdir, [as_abstract_type(jax_arr)]), ) if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') test_utils.assert_tree_equal( - self, [numpy_arr], ocp.load_pytree(self.directory / subdir) + self, [numpy_arr], ocp.load(self.directory / subdir) ) with self.subTest('jax_to_numpy'): subdir = 'jax' - ocp.save_pytree(self.directory / subdir, [jax_arr]) + ocp.save(self.directory / subdir, [jax_arr]) test_utils.assert_tree_equal( self, [numpy_arr], - ocp.load_pytree( - self.directory / subdir, [as_abstract_type(numpy_arr)] - ), + ocp.load(self.directory / subdir, [as_abstract_type(numpy_arr)]), ) if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') test_utils.assert_tree_equal( - self, [jax_arr], ocp.load_pytree(self.directory / subdir) + self, [jax_arr], ocp.load(self.directory / subdir) ) with self.subTest('jax_to_numpy_by_value'): subdir = 'jax_to_np_by_value' - ocp.save_pytree(self.directory / subdir, [jax_arr]) + ocp.save(self.directory / subdir, [jax_arr]) test_utils.assert_tree_equal( self, [numpy_arr], - ocp.load_pytree( - self.directory / subdir, [np.array([], dtype=np.int64)] - ), + ocp.load(self.directory / subdir, [np.array([], dtype=np.int64)]), ) def test_empty_array(self): value = np.ones(shape=(0,)) with self.assertRaisesRegex(ValueError, 'zero size'): - ocp.save_pytree(self.directory, dict(k=value)) + ocp.save(self.directory, dict(k=value)) @parameterized.parameters( (complex(1.1, 2.2),), @@ -419,7 +406,7 @@ def test_empty_array(self): def test_invalid_leaf_types(self, value): # TODO(cpgaffney): Consider improving the error message raised. with self.assertRaises(ValueError): - ocp.save_pytree(self.directory, dict(k=value)) + ocp.save(self.directory, dict(k=value)) # TODO(b/337105122): Add tests passing invalid abstract PyTrees. def test_flax_model(self): @@ -442,17 +429,17 @@ def make_state_with_nones(): ) state = make_state_with_optax() - ocp.save_pytree(self.directory, state) + ocp.save(self.directory, state) with self.subTest('with_abstract_state'): abstract_state = jax.tree.map(as_abstract_type, state) - loaded = ocp.load_pytree(self.directory, abstract_state) + loaded = ocp.load(self.directory, abstract_state) test_utils.assert_tree_equal(self, state, loaded) with self.subTest('without_abstract_state'): if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = ocp.load_pytree(self.directory) + loaded = ocp.load(self.directory) expected_tree = tree_utils.serialize_tree( make_state_with_nones(), keep_empty_nodes=True, @@ -490,14 +477,14 @@ def test_reshard(self, save_spec, load_spec): ) save_sharding = jax.sharding.NamedSharding(mesh, save_spec) tree = {'x': create_sharded_array(np.arange(len_devices), save_sharding)} - ocp.save_pytree(self.directory, tree) + ocp.save(self.directory, tree) load_sharding = jax.sharding.NamedSharding(mesh, load_spec) expected_tree = { 'x': create_sharded_array(np.arange(len_devices), load_sharding) } abstract_tree = {'x': as_abstract_type(expected_tree['x'])} - loaded = ocp.load_pytree(self.directory, abstract_pytree=abstract_tree) + loaded = ocp.load(self.directory, abstract_state=abstract_tree) test_utils.assert_tree_equal(self, expected_tree, loaded) @parameterized.parameters( @@ -552,56 +539,56 @@ def test_casting(self, original_dtype, save_dtype, load_dtype): ) ) ): - ocp.save_pytree(self.directory, tree) + ocp.save(self.directory, tree) with self.subTest('with_abstract_tree'): abstract_tree = jax.tree.map(as_abstract_type, load_casted_tree) - loaded = ocp.load_pytree(self.directory, abstract_tree) + loaded = ocp.load(self.directory, abstract_tree) test_utils.assert_tree_equal(self, load_casted_tree, loaded) with self.subTest('without_abstract_tree'): if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = ocp.load_pytree(self.directory) + loaded = ocp.load(self.directory) test_utils.assert_tree_equal(self, save_casted_tree, loaded) # TODO(b/295313820): Improve mismatched-tree error messages. def test_mismatched_abstract_tree(self): - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) with self.subTest('subset_of_keys'): abstract_pytree = dict(self.abstract_pytree) del abstract_pytree['a'], abstract_pytree['b'] with self.assertRaisesRegex(ValueError, 'User-provided restore item'): - ocp.load_pytree(self.directory, abstract_pytree) + ocp.load(self.directory, abstract_pytree) with self.subTest('superset_of_keys'): abstract_pytree = dict(self.abstract_pytree) abstract_pytree['z'] = as_abstract_type(np.arange(16)) with self.assertRaisesRegex(ValueError, 'User-provided restore item'): - ocp.load_pytree(self.directory, abstract_pytree) + ocp.load(self.directory, abstract_pytree) with self.subTest('renamed_key'): abstract_pytree = dict(self.abstract_pytree) abstract_pytree['z'] = abstract_pytree.pop('a') with self.assertRaisesRegex(ValueError, 'User-provided restore item'): - ocp.load_pytree(self.directory, abstract_pytree) + ocp.load(self.directory, abstract_pytree) def test_overwrites(self): - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) with self.assertLogs(level='INFO') as cm: - ocp.save_pytree(self.directory, self.numpy_pytree, overwrite=True) + ocp.save(self.directory, self.numpy_pytree, overwrite=True) found_log = any( 'Specified `overwrite`: removing existing path.' in log for log in cm.output ) self.assertEqual(found_log, multihost.is_primary_host(0)) test_utils.assert_tree_equal( - self, self.numpy_pytree, ocp.load_pytree(self.directory) + self, self.numpy_pytree, ocp.load(self.directory) ) def test_auto_overwrite_tmp_checkpoint(self): - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) if multihost.is_primary_host(0): self.directory.rename( self.directory.parent @@ -610,9 +597,9 @@ def test_auto_overwrite_tmp_checkpoint(self): test_utils.sync_global_processes( 'test_auto_overwrite_tmp_checkpoint:rename' ) - ocp.save_pytree(self.directory, self.numpy_pytree) + ocp.save(self.directory, self.numpy_pytree) test_utils.assert_tree_equal( - self, self.numpy_pytree, ocp.load_pytree(self.directory) + self, self.numpy_pytree, ocp.load(self.directory) ) def test_multiple_pytrees(self): @@ -633,8 +620,8 @@ def test_multiple_pytrees(self): self.directory, abstract_checkpointables ) test_utils.assert_tree_equal(self, checkpointables, loaded) - with self.subTest('load_pytree'): - loaded = ocp.load_pytree(self.directory, self.abstract_pytree) + with self.subTest('load'): + loaded = ocp.load(self.directory, self.abstract_pytree) test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('load_numpy_pytree'): loaded = ocp.load_checkpointables( @@ -667,12 +654,14 @@ def test_missing_keys(self): } ocp.save_checkpointables(self.directory, checkpointables) - with self.subTest('load_pytree'): - loaded = ocp.load_pytree(self.directory) + with self.subTest('load'): + loaded = ocp.load(self.directory) test_utils.assert_tree_equal(self, self.numpy_pytree, loaded) with self.subTest('load_checkpointables'): - with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'): + with self.assertRaisesRegex( + KeyError, 'Requested checkpointables:' + ): ocp.load_checkpointables( self.directory, {'foo': handler_utils.AbstractFoo()} ) @@ -782,7 +771,9 @@ def test_save_checkpointables_deleted(self): loaded = ocp.load_checkpointables(self.directory) self.assertSameElements(['two'], loaded.keys()) - with self.assertRaisesRegex(KeyError, 'Requested checkpointables:'): + with self.assertRaisesRegex( + KeyError, 'Requested checkpointables:' + ): ocp.load_checkpointables(self.directory, {'one': None}) @@ -790,22 +781,18 @@ def test_abstract_pytree_types(self): # TODO(b/408241116): Enable tests on Pathways. if multihost.is_pathways_backend(): self.skipTest('Sharding metadata not present in Pathways.') - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) with self.subTest('checkpoint_metadata'): - loaded = ocp.load_pytree( - self.directory, ocp.pytree_metadata(self.directory) - ) + loaded = ocp.load(self.directory, ocp.metadata(self.directory)) test_utils.assert_tree_equal(self, self.pytree, loaded) - with self.subTest('pytree_metadata'): - loaded = ocp.load_pytree( - self.directory, ocp.pytree_metadata(self.directory).metadata - ) + with self.subTest('metadata'): + loaded = ocp.load(self.directory, ocp.metadata(self.directory).metadata) test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('abstract_pytree'): - loaded = ocp.load_pytree(self.directory, self.abstract_pytree) + loaded = ocp.load(self.directory, self.abstract_pytree) test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('none'): - loaded = ocp.load_pytree(self.directory) + loaded = ocp.load(self.directory) test_utils.assert_tree_equal(self, self.pytree, loaded) def test_abstract_checkpointables_types(self): @@ -994,7 +981,7 @@ def test_async_save_completes_without_result(self): self.assertTrue(self.directory.exists()) def test_partial_restore_placeholder(self): - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) reference_pytree = jax.tree.map(lambda x: x, self.abstract_pytree) reference_pytree['b'] = PLACEHOLDER @@ -1012,11 +999,11 @@ def test_partial_restore_placeholder(self): 'y': self.pytree['y'], } - loaded = ocp.load_pytree(self.directory, reference_pytree) + loaded = ocp.load(self.directory, reference_pytree) test_utils.assert_tree_equal(self, expected, loaded) def test_partial_restore_omission(self): - ocp.save_pytree(self.directory, self.pytree) + ocp.save(self.directory, self.pytree) reference_pytree = jax.tree.map(lambda x: x, self.abstract_pytree) del reference_pytree['b'] @@ -1038,7 +1025,7 @@ def test_partial_restore_omission(self): ) ) ): - loaded = ocp.load_pytree(self.directory, reference_pytree) + loaded = ocp.load(self.directory, reference_pytree) test_utils.assert_tree_equal(self, expected, loaded) @@ -1053,8 +1040,8 @@ def test_save_with_global_mesh(self, use_same_mesh: bool): mesh = jax.sharding.Mesh(devices, axis_names) jax.sharding.set_mesh(mesh) - ocp.save_pytree(self.directory, self.pytree) - loaded = ocp.load_pytree(self.directory, self.abstract_pytree) + ocp.save(self.directory, self.pytree) + loaded = ocp.load(self.directory, self.abstract_pytree) test_utils.assert_tree_equal(self, self.pytree, loaded) @parameterized.parameters((3,), (8,)) @@ -1078,7 +1065,7 @@ def _assert_false(*args, **kwargs): return_value=timeout, ), ): - r = ocp.save_pytree_async(self.directory, self.pytree) + r = ocp.save_async(self.directory, self.pytree) start = time.time() if multihost.is_primary_host(primary_host=0): with self.assertRaises(AssertionError): @@ -1101,7 +1088,7 @@ def test_save_checkpointables_directory_consistency_failure(self): with self.assertRaisesRegex( ValueError, 'Directory path mismatch in multi-process save' ): - ocp.save_pytree(directory, self.pytree) + ocp.save(directory, self.pytree) def test_load_and_broadcast(self): replica_count = 2 @@ -1125,9 +1112,9 @@ def test_load_and_broadcast(self): ) ) ): - ocp.save_pytree(self.directory, [arr]) + ocp.save(self.directory, [arr]) with self.subTest('with_abstract_pytree'): - loaded = ocp.load_pytree( + loaded = ocp.load( self.directory, [array_test_utils.as_abstract_type(arr)] ) test_utils.assert_tree_equal(self, [arr], loaded) @@ -1137,7 +1124,7 @@ def test_load_and_broadcast(self): 'Must provide `sharding` to restore with' ' `SingleReplicaArrayHandler`', ): - ocp.load_pytree(self.directory) + ocp.load(self.directory) def test_subchunking(self): self.assertEqual(jax.device_count(), 8) @@ -1165,17 +1152,14 @@ def test_subchunking(self): ) ) ): - ocp.save_pytree(self.directory / 'global_setting', pytree) - metadata = ocp.pytree_metadata( - self.directory / 'global_setting' - ).metadata + ocp.save(self.directory / 'global_setting', pytree) + metadata = ocp.metadata(self.directory / 'global_setting').metadata for k in pytree: self.assertEqual(metadata[k].shape, (32,)) self.assertEqual(metadata[k].storage_metadata.write_shape, (4,)) self.assertEqual(metadata[k].storage_metadata.chunk_shape, (2,)) with self.subTest('per_key_setting'): - def scoped_storage_options_creator(key, value): del value if 'a' in tree_utils.str_keypath(key): @@ -1185,7 +1169,6 @@ def scoped_storage_options_creator(key, value): return ocp.options.ArrayOptions.Saving.StorageOptions( chunk_byte_size=8, # force divide in 2 subchunks ) - with ocp.Context( array_options=ocp.options.ArrayOptions( saving=ocp.options.ArrayOptions.Saving( @@ -1193,10 +1176,8 @@ def scoped_storage_options_creator(key, value): ) ), ): - ocp.save_pytree(self.directory / 'per_key_setting', pytree) - metadata = ocp.pytree_metadata( - self.directory / 'per_key_setting' - ).metadata + ocp.save(self.directory / 'per_key_setting', pytree) + metadata = ocp.metadata(self.directory / 'per_key_setting').metadata self.assertEqual(metadata['a'].shape, (32,)) self.assertEqual(metadata['a'].storage_metadata.write_shape, (4,)) self.assertEqual(metadata['a'].storage_metadata.chunk_shape, (1,)) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/tree_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/tree_utils.py index e35933f3e..a8c81ce54 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/tree_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/tree_utils.py @@ -19,10 +19,11 @@ from orbax.checkpoint.experimental.v1._src.synchronization import multihost +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY + + def is_pytree_checkpoint_complete(directory): - return ( - directory / checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY / 'manifest.ocdbt' - ).exists() + return (directory / STATE_CHECKPOINTABLE_KEY / 'manifest.ocdbt').exists() def get_d_files_mtimes(path: path_types.Path) -> list[int]: @@ -31,7 +32,7 @@ def get_d_files_mtimes(path: path_types.Path) -> list[int]: Assumes a structure like:: path/ - / + / ocdbt.process_0/ d/ @@ -57,9 +58,7 @@ def get_d_files_mtimes(path: path_types.Path) -> list[int]: len(matching_dirs) == 1 ), f'Expected exactly one matching directory, got {matching_dirs}.' tmpdir = matching_dirs[0] - matching_pytree_dirs = list( - tmpdir.glob(f'{checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY}*') - ) + matching_pytree_dirs = list(tmpdir.glob(f'{STATE_CHECKPOINTABLE_KEY}*')) if not matching_pytree_dirs: # Temp path not created yet. return [] diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 3a21929fd..a700622b0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -45,7 +45,7 @@ RootMetadata = training_metadata_types.RootMetadata -PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY class _AsyncSaveResponse(async_types.AsyncResponse[bool]): @@ -130,15 +130,15 @@ def __init__( if ckptr.latest is None: model_state = init_from_scratch(rng) else: - model_state = ckptr.load_pytree() # Loads latest checkpoint. + model_state = ckptr.load() # Loads latest checkpoint. # Note: prefer to specify the abstract tree if available. - model_state = ckptr.load_pytree( - ckptr.latest, abstract_pytree=abstract_model_state) + model_state = ckptr.load( + ckptr.latest, abstract_state=abstract_model_state) start_step = ckptr.latest.step if ckptr.latest else 0 for step in range(start_step, num_steps): model_state = train_step(model_state) # Saves a checkpoint if needed (according to `save_decision_policy`). - ckptr.save_pytree(step, model_state) + ckptr.save(step, model_state) Prefer to use the context manager style as shown above, which ensures that the Checkpointer is closed properly and any outstanding async operations are @@ -230,7 +230,7 @@ def checkpoints(self) -> Sequence[CheckpointMetadata[None]]: The method returns a list of :py:class:`.CheckpointMetadata` objects, which contain selected properties describing the checkpoint. Contrast this with - the methods :py:func:`.pytree_metadata` and + the methods :py:func:`.metadata` and :py:func:`.checkpointables_metadata`, which may perform a more expensive disk read to retrieve additional information. This method only returns cheap cacheable properties like step and timestamp. The return value is @@ -279,12 +279,12 @@ def should_save(self, step: int) -> bool: step = _resolve_integer_step(step) return self._manager.should_save(step) - def save_pytree( + def save( self, step: int, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, - checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str = STATE_CHECKPOINTABLE_KEY, force: bool = False, overwrite: bool = False, metrics: tree_types.JsonType | None = None, @@ -293,7 +293,7 @@ def save_pytree( """Saves a checkpoint, if dictated by :py:class:`.SaveDecisionPolicy`. This method behaves similarly to the standalone free function - :py:func:`~orbax.checkpoint.v1.save_pytree` (see + :py:func:`~orbax.checkpoint.v1.save` (see documentation), but performs additional tasks related to managing a sequence of checkpoint steps. @@ -325,7 +325,7 @@ def save_pytree( ckptr = training.Checkpointer(directory) # Save the tree at step 0. - saved = ckptr.save_pytree(step=0, pytree=tree) + saved = ckptr.save(step=0, state=tree) # Clean up background threads gracefully when the training loop ends ckptr.close() @@ -338,9 +338,9 @@ def save_pytree( ckptr = training.Checkpointer(directory) - ckptr.save_pytree( + ckptr.save( step=1, - pytree=tree, + state=tree, metrics={'loss': 0.12, 'accuracy': 0.95}, custom_metadata={'description': 'Model after epoch 1'}, ) @@ -350,7 +350,7 @@ def save_pytree( Args: step: The step number to save. - pytree: The PyTree to save. + state: The PyTree to save. checkpointable_name: The name of the checkpointable to save a pytree under. Defaults to 'pytree'. force: If True, ignores all :py:class:`.SaveDecisionPolicy` checks, and @@ -366,9 +366,9 @@ def save_pytree( Returns: Whether a checkpoint was saved or not. """ - return self.save_pytree_async( + return self.save_async( step, - pytree, + state, checkpointable_name=checkpointable_name, force=force, overwrite=overwrite, @@ -392,7 +392,7 @@ def save_checkpointables( names to values. See `the guide on Checkpointables `_ for more details on checkpointables. Also see documentation for - :py:func:`~orbax.checkpoint.v1.save_pytree`. + :py:func:`~orbax.checkpoint.v1.save`. Example: 1. Basic Usage: @@ -462,12 +462,12 @@ def save_checkpointables( custom_metadata=custom_metadata, ).result() - def save_pytree_async( + def save_async( self, step: int, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], *, - checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str = STATE_CHECKPOINTABLE_KEY, force: bool = False, overwrite: bool = False, metrics: tree_types.JsonType | None = None, @@ -476,7 +476,7 @@ def save_pytree_async( """Saves a checkpoint asynchronously. This function is the asynchronous equivalent of - :py:meth:`~.save_pytree`. It accepts the exact same + :py:meth:`~.save`. It accepts the exact same arguments; please refer to that method for detailed descriptions. This method executes mostly in the background, blocking the main thread for @@ -485,18 +485,18 @@ def save_pytree_async( Example: :: - async_response = ckptr.save_pytree_async(step=0, pytree=tree) + async_response = ckptr.save_async(step=0, state=tree) saved = async_response.result() Args: step: The step number to save. - pytree: The PyTree to save. + state: The PyTree to save. checkpointable_name: The name of the checkpointable to save a pytree under. Defaults to 'pytree'. - force: See `save_pytree`. - overwrite: See `save_pytree`. - metrics: See `save_pytree`. - custom_metadata: See `save_pytree`. + force: See `save`. + overwrite: See `save`. + metrics: See `save`. + custom_metadata: See `save`. Returns: An `AsyncResponse`, which can be awaited via `result()`, which returns a @@ -504,7 +504,7 @@ def save_pytree_async( """ return self.save_checkpointables_async( step, - {checkpointable_name: pytree}, + {checkpointable_name: state}, force=force, overwrite=overwrite, metrics=metrics, @@ -582,23 +582,23 @@ def save_checkpointables_async( ) return _AsyncSaveResponse(self._manager, saved) - def load_pytree( + def load( self, step: int | CheckpointMetadata | None = None, - abstract_pytree: ( + abstract_state: ( tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, *, - checkpointable_name: str = PYTREE_CHECKPOINTABLE_KEY, + checkpointable_name: str = STATE_CHECKPOINTABLE_KEY, ) -> tree_types.PyTreeOf[tree_types.Leaf]: """Loads a PyTree checkpoint at the given step. This method behaves similarly to the standalone free function - :py:func:`~orbax.checkpoint.v1.load_pytree`. + :py:func:`~orbax.checkpoint.v1.load`. - **Note:** Loading a PyTree without providing an `abstract_pytree` is + **Note:** Loading a PyTree without providing an `abstract_state` is provided purely for convenience. For serious or production use cases, it is - STRONGLY recommended to always provide an `abstract_pytree` to ensure the + STRONGLY recommended to always provide an `abstract_state` to ensure the restored PyTree strictly matches the expected shapes, dtypes, and sharding. Example: @@ -612,7 +612,7 @@ def load_pytree( ckptr = training.Checkpointer(directory) # Load the saved PyTree from latest step - restored_tree = ckptr.load_pytree(step=None) + restored_tree = ckptr.load(step=None) 2. Loading with an Abstract PyTree: Provide an abstract structure (such as target shapes and dtypes) @@ -631,16 +631,16 @@ def load_pytree( } # Restore exactly matching the target structure - restored_tree = ckptr.load_pytree( + restored_tree = ckptr.load( step=1, - abstract_pytree=target_structure + abstract_state=target_structure ) Args: step: The step number or :py:class:`.CheckpointMetadata` to load. If None, the checkpointer will attempt to resolve and load the latest existing checkpoint. - abstract_pytree: The abstract PyTree to load. + abstract_state: The abstract PyTree to load. checkpointable_name: The name of the checkpointable to load a pytree under. Defaults to 'pytree'. @@ -648,7 +648,7 @@ def load_pytree( The loaded PyTree. """ return self.load_checkpointables( - step, {checkpointable_name: abstract_pytree} + step, {checkpointable_name: abstract_state} )[checkpointable_name] def load_checkpointables( @@ -768,10 +768,10 @@ def load_checkpointables( abstract_checkpointables, ) - def load_pytree_async( + def load_async( self, step: int | CheckpointMetadata | None = None, - abstract_pytree: ( + abstract_state: ( tree_types.PyTreeOf[tree_types.AbstractLeaf] | None ) = None, ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]: @@ -786,7 +786,7 @@ def load_checkpointables_async( """Loads a set of checkpointables asynchronously at the given step.""" raise NotImplementedError() - def pytree_metadata( + def metadata( self, step: int | CheckpointMetadata | None = None ) -> training_metadata_types.CheckpointMetadata[ metadata_types.PyTreeMetadata @@ -810,7 +810,7 @@ def pytree_metadata( with context_lib.get_context(self._context): checkpoint = self._resolve_existing_checkpoint(step) del step - checkpoint_metadata = metadata_loading.pytree_metadata( + checkpoint_metadata = metadata_loading.metadata( self._manager.directory / self._step_name_format.build_name(checkpoint.step) ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 68f3e8474..7a0275cbd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -26,6 +26,7 @@ from orbax.checkpoint import test_utils from orbax.checkpoint._src.serialization import serialization import orbax.checkpoint.experimental.v1 as ocp +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.path import step as path_step_lib from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils @@ -33,6 +34,7 @@ from orbax.checkpoint.experimental.v1._src.testing import tree_utils as tree_test_utils from orbax.checkpoint.experimental.v1._src.tree import types as tree_types +STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY Checkpointer = ocp.training.Checkpointer save_decision_policies = ocp.training.save_decision_policies preservation_policies = ocp.training.preservation_policies @@ -82,18 +84,18 @@ def tearDown(self): test_utils.sync_global_processes('CheckpointerTest:tests_complete') super().tearDown() - def save_pytree( + def save( self, checkpointer: Checkpointer, step: int, - pytree: tree_types.PyTreeOf[tree_types.Leaf], + state: tree_types.PyTreeOf[tree_types.Leaf], metrics: tree_types.JsonType | None = None, custom_metadata: tree_types.JsonType | None = None, ) -> bool: """Saves pytree with v1 Checkpointer.""" - return checkpointer.save_pytree( + return checkpointer.save( step, - pytree, + state, metrics=metrics, custom_metadata=custom_metadata, ) @@ -124,46 +126,46 @@ def test_properties(self): def test_save_restore_pytree(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) + self.save(checkpointer, 0, self.pytree) with self.subTest('with_abstract_pytree'): - loaded = checkpointer.load_pytree(0, self.abstract_pytree) + loaded = checkpointer.load(0, self.abstract_pytree) test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('without_abstract_pytree'): if multihost.is_pathways_backend(): self.skipTest('Must provide abstract_pytree for Pathways.') - loaded = checkpointer.load_pytree(0) + loaded = checkpointer.load(0) test_utils.assert_tree_equal(self, self.pytree, loaded) - def test_save_load_pytree_with_specific_name(self): + def test_save_load_with_specific_name(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - checkpointer.save_pytree(0, self.pytree, checkpointable_name='state') - loaded = checkpointer.load_pytree( - 0, abstract_pytree=self.abstract_pytree, checkpointable_name='state' + checkpointer.save(0, self.pytree, checkpointable_name='state') + loaded = checkpointer.load( + 0, abstract_state=self.abstract_pytree, checkpointable_name='state' ) test_utils.assert_tree_equal(self, self.pytree, loaded) - def test_save_checkpointables_load_pytree(self): + def test_save_checkpointables_load(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) checkpointer.save_checkpointables(0, {'state': self.pytree}) - loaded = checkpointer.load_pytree( - 0, abstract_pytree=self.abstract_pytree, checkpointable_name='state' + loaded = checkpointer.load( + 0, abstract_state=self.abstract_pytree, checkpointable_name='state' ) test_utils.assert_tree_equal(self, self.pytree, loaded) @parameterized.parameters((True,), (False,)) - def test_load_latest_pytree(self, latest_arg_is_none): + def test_load_latest(self, latest_arg_is_none): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree, metrics={'loss': 0.5}) + self.save(checkpointer, 0, self.pytree, metrics={'loss': 0.5}) new_pytree = { 'jax_array': self.pytree['jax_array'], 'numpy_array': array_test_utils.create_numpy_pytree(add=1)[0], } - self.save_pytree(checkpointer, 1, new_pytree, metrics={'loss': 0.4}) + self.save(checkpointer, 1, new_pytree, metrics={'loss': 0.4}) latest = checkpointer.latest self.assertIsInstance(latest, ocp.training.CheckpointMetadata) @@ -171,9 +173,9 @@ def test_load_latest_pytree(self, latest_arg_is_none): self.assertDictEqual(latest.metrics, {'loss': 0.4}) if latest_arg_is_none: - loaded = checkpointer.load_pytree(abstract_pytree=self.abstract_pytree) + loaded = checkpointer.load(abstract_state=self.abstract_pytree) else: - loaded = checkpointer.load_pytree(latest, self.abstract_pytree) + loaded = checkpointer.load(latest, self.abstract_pytree) test_utils.assert_tree_equal(self, new_pytree, loaded) @@ -181,30 +183,28 @@ def test_overwrites(self): plus_one_pytree, _ = array_test_utils.create_numpy_pytree(add=1) checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) - checkpointer.save_pytree(0, plus_one_pytree, overwrite=True) - test_utils.assert_tree_equal( - self, plus_one_pytree, checkpointer.load_pytree(0) - ) + self.save(checkpointer, 0, self.pytree) + checkpointer.save(0, plus_one_pytree, overwrite=True) + test_utils.assert_tree_equal(self, plus_one_pytree, checkpointer.load(0)) def test_step_already_exists(self): plus_one_pytree, _ = array_test_utils.create_numpy_pytree(add=1) checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) + self.save(checkpointer, 0, self.pytree) with self.assertRaises(ocp.training.errors.StepAlreadyExistsError): - checkpointer.save_pytree(0, plus_one_pytree) + checkpointer.save(0, plus_one_pytree) def test_load_non_existent_step(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) - with self.subTest('load_pytree'): + self.save(checkpointer, 0, self.pytree) + with self.subTest('load'): with self.assertRaises(FileNotFoundError): - checkpointer.load_pytree(1) - with self.subTest('pytree_metadata'): + checkpointer.load(1) + with self.subTest('metadata'): with self.assertRaises(FileNotFoundError): - checkpointer.pytree_metadata(1) + checkpointer.metadata(1) @parameterized.parameters( (save_decision_policies.ContinuousCheckpointingPolicy(), range(10)), @@ -228,7 +228,7 @@ def test_steps( self.enter_context(checkpointer) for step in range(num_steps): self.assertEqual(checkpointer.should_save(step), step in expected_steps) - saved = self.save_pytree(checkpointer, step, self.pytree) + saved = self.save(checkpointer, step, self.pytree) self.assertEqual(saved, step in expected_steps) self.assertLen(checkpointer.checkpoints, len(expected_steps)) @@ -245,18 +245,18 @@ def test_force_save_ignores_save_decision_policy(self): ) self.enter_context(checkpointer) - self.assertTrue(checkpointer.save_pytree(0, self.pytree)) - self.assertFalse(checkpointer.save_pytree(1, self.pytree)) - self.assertTrue(checkpointer.save_pytree(2, self.pytree)) + self.assertTrue(checkpointer.save(0, self.pytree)) + self.assertFalse(checkpointer.save(1, self.pytree)) + self.assertTrue(checkpointer.save(2, self.pytree)) self.assertLen(checkpointer.checkpoints, 2) self.assertSequenceEqual( [c.step for c in checkpointer.checkpoints], [0, 2] ) - self.assertTrue(checkpointer.save_pytree(3, self.pytree, force=True)) - self.assertTrue(checkpointer.save_pytree(4, self.pytree, force=True)) - self.assertTrue(checkpointer.save_pytree(5, self.pytree, force=True)) + self.assertTrue(checkpointer.save(3, self.pytree, force=True)) + self.assertTrue(checkpointer.save(4, self.pytree, force=True)) + self.assertTrue(checkpointer.save(5, self.pytree, force=True)) self.assertLen(checkpointer.checkpoints, 5) self.assertSequenceEqual( @@ -269,8 +269,8 @@ def test_garbage_collection(self): def test_reload(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) - self.save_pytree(checkpointer, 1, self.pytree) + self.save(checkpointer, 0, self.pytree) + self.save(checkpointer, 1, self.pytree) self.assertLen(checkpointer.checkpoints, 2) assert checkpointer.latest is not None self.assertEqual(checkpointer.latest.step, 1) @@ -291,15 +291,15 @@ def test_reload(self): def test_skips_when_ongoing_save(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - saved_0 = checkpointer.save_pytree_async(0, self.pytree) - saved_1 = checkpointer.save_pytree(1, self.pytree) + saved_0 = checkpointer.save_async(0, self.pytree) + saved_1 = checkpointer.save(1, self.pytree) self.assertTrue(saved_0.result()) self.assertFalse(saved_1) self.assertLen(checkpointer.checkpoints, 1) assert checkpointer.latest is not None self.assertEqual(checkpointer.latest.step, 0) - def test_save_pytree_async(self): + def test_save_async(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) @@ -318,7 +318,7 @@ def mock_serialize(*args, **kwargs): ) step = 0 - response = checkpointer.save_pytree_async(step, self.pytree) + response = checkpointer.save_async(step, self.pytree) initial_d_files_mtimes = tree_test_utils.get_d_files_mtimes( self.directory / str(step) ) @@ -341,7 +341,7 @@ def mock_serialize(*args, **kwargs): ) ) - def test_save_pytree_async_on_complete(self): + def test_save_async_on_complete(self): save_policy = save_decision_policies.FixedIntervalPolicy(100) checkpointer = Checkpointer( self.directory, save_decision_policy=save_policy @@ -357,7 +357,7 @@ def callback(saved): condition.notify_all() # Step 0 should save - response = checkpointer.save_pytree_async(0, self.pytree) + response = checkpointer.save_async(0, self.pytree) response.on_complete(callback) with condition: @@ -370,7 +370,7 @@ def callback(saved): results.clear() # Step 1 should not save - response = checkpointer.save_pytree_async(1, self.pytree) + response = checkpointer.save_async(1, self.pytree) response.on_complete(callback) with condition: @@ -384,7 +384,7 @@ def callback(saved): def test_close(self): checkpointer = Checkpointer(self.directory) step_path = self.directory / '0' - checkpointer.save_pytree_async(0, self.pytree) + checkpointer.save_async(0, self.pytree) self.assertFalse(step_path.exists()) # Not finalized yet. # But a tmp dir should have been created. self.assertNotEmpty(list(self.directory.iterdir())) @@ -394,7 +394,7 @@ def test_close(self): def test_context_manager_close(self): step_path = self.directory / '0' with Checkpointer(self.directory) as checkpointer: - checkpointer.save_pytree_async(0, self.pytree) + checkpointer.save_async(0, self.pytree) self.assertFalse(step_path.exists()) # Not finalized yet. # But a tmp dir should have been created. self.assertNotEmpty(list(self.directory.iterdir())) @@ -408,7 +408,7 @@ def test_step_name_format(self): ), ) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) + self.save(checkpointer, 0, self.pytree) self.assertTrue((self.directory / 'foo_0').exists()) self.assertFalse((self.directory / '0').exists()) @@ -430,9 +430,9 @@ def test_root_metadata(self, reinitialize_checkpointer): @parameterized.product( reinitialize_checkpointer=(True, False), ) - def test_pytree_metadata(self, reinitialize_checkpointer): + def test_metadata(self, reinitialize_checkpointer): checkpointer = Checkpointer(self.directory) - self.save_pytree( + self.save( checkpointer, 0, self.pytree, @@ -443,7 +443,7 @@ def test_pytree_metadata(self, reinitialize_checkpointer): checkpointer.close() checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) - checkpoint_metadata = checkpointer.pytree_metadata(0) + checkpoint_metadata = checkpointer.metadata(0) self.assertIsInstance(checkpoint_metadata, CheckpointMetadata) self.assertDictEqual(checkpoint_metadata.custom_metadata, {'baz': 'qux'}) self.assertDictEqual(checkpoint_metadata.metrics, {'loss': 0.5}) @@ -462,7 +462,7 @@ def test_checkpointables_metadata(self, reinitialize_checkpointer): self.save_checkpointables( checkpointer, 0, - {'pytree': self.pytree, 'baz': Baz(123, 'hi')}, + {STATE_CHECKPOINTABLE_KEY: self.pytree, 'baz': Baz(123, 'hi')}, metrics={'loss': 0.5}, custom_metadata={'baz': 'qux'}, ) @@ -478,10 +478,10 @@ def test_checkpointables_metadata(self, reinitialize_checkpointer): self.assertIsNotNone(checkpoint_metadata.commit_timestamp_nsecs) self.assertIsInstance(checkpoint_metadata.metadata, dict) self.assertSameElements( - checkpoint_metadata.metadata.keys(), ['pytree', 'baz'] + checkpoint_metadata.metadata.keys(), [STATE_CHECKPOINTABLE_KEY, 'baz'] ) self.assertSameElements( - checkpoint_metadata.metadata['pytree'].keys(), + checkpoint_metadata.metadata[STATE_CHECKPOINTABLE_KEY].keys(), ['jax_array', 'numpy_array'], ) # Saved with v1 save_checkpointables, so v1 handler registry can resolve @@ -502,7 +502,7 @@ def test_custom_checkpointables(self): ocp.Context(checkpointables_options=checkpointables_options) ) checkpointables = { - 'pytree': self.pytree, + STATE_CHECKPOINTABLE_KEY: self.pytree, 'foo': Foo(123, 'hi'), 'bar': Bar(456, 'bye'), } @@ -514,9 +514,13 @@ def test_custom_checkpointables(self): if multihost.is_pathways_backend(): self.skipTest('Sharding metadata not present in Pathways.') loaded = checkpointer.load_checkpointables(0) - self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) + self.assertSameElements( + loaded.keys(), [STATE_CHECKPOINTABLE_KEY, 'foo', 'bar'] + ) test_utils.assert_tree_equal( - self, checkpointables['pytree'], loaded['pytree'] + self, + checkpointables[STATE_CHECKPOINTABLE_KEY], + loaded[STATE_CHECKPOINTABLE_KEY], ) self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) @@ -531,35 +535,47 @@ def test_custom_checkpointables(self): ) with ocp.Context(checkpointables_options=checkpointables_options): loaded = ocp.load_checkpointables(self.directory / '0') - self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) + self.assertSameElements( + loaded.keys(), [STATE_CHECKPOINTABLE_KEY, 'foo', 'bar'] + ) test_utils.assert_tree_equal( - self, checkpointables['pytree'], loaded['pytree'] + self, + checkpointables[STATE_CHECKPOINTABLE_KEY], + loaded[STATE_CHECKPOINTABLE_KEY], ) self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_abstract_checkpointables'): abstract_checkpointables = { - 'pytree': self.abstract_pytree, + STATE_CHECKPOINTABLE_KEY: self.abstract_pytree, 'foo': AbstractFoo(), 'bar': AbstractBar(), } loaded = checkpointer.load_checkpointables(0, abstract_checkpointables) - self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) - test_utils.assert_tree_equal(self, self.pytree, loaded['pytree']) + self.assertSameElements( + loaded.keys(), [STATE_CHECKPOINTABLE_KEY, 'foo', 'bar'] + ) + test_utils.assert_tree_equal( + self, self.pytree, loaded[STATE_CHECKPOINTABLE_KEY] + ) self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_abstract_checkpointables_none_values'): if multihost.is_pathways_backend(): self.skipTest('Sharding metadata not present in Pathways.') abstract_checkpointables = { - 'pytree': None, + STATE_CHECKPOINTABLE_KEY: None, 'foo': None, 'bar': None, } loaded = checkpointer.load_checkpointables(0, abstract_checkpointables) - self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) + self.assertSameElements( + loaded.keys(), [STATE_CHECKPOINTABLE_KEY, 'foo', 'bar'] + ) test_utils.assert_tree_equal( - self, checkpointables['pytree'], loaded['pytree'] + self, + checkpointables[STATE_CHECKPOINTABLE_KEY], + loaded[STATE_CHECKPOINTABLE_KEY], ) self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) @@ -583,7 +599,7 @@ def test_load_with_switched_abstract_checkpointables(self): ocp.Context(checkpointables_options=checkpointables_options) ) checkpointables = { - 'pytree': self.pytree, + STATE_CHECKPOINTABLE_KEY: self.pytree, 'foo': Foo(123, 'hi'), 'bar': Bar(456, 'bye'), } @@ -658,12 +674,12 @@ def should_save( ) self.enter_context(checkpointer) for step in range(0, 30): - self.save_pytree(checkpointer, step, self.pytree) + self.save(checkpointer, step, self.pytree) self.assertNotEmpty(checkpointer.checkpoints) self.assertLess(len(checkpointer.checkpoints), 30) checkpointer_metadata = [ - checkpointer.pytree_metadata(metadata.step) + checkpointer.metadata(metadata.step) for metadata in checkpointer.checkpoints ] for i in range(1, len(checkpointer.checkpoints)): @@ -705,7 +721,7 @@ def now(cls, tz=None): # ) # mock_dt.fromtimestamp.side_effect = datetime.datetime.fromtimestamp # mock_dt.timestamp.return_value = checkpoint_times[step] - checkpointer.save_pytree(step, self.pytree) + checkpointer.save(step, self.pytree) self.assertLen(checkpointer.checkpoints, len(expected_steps)) self.assertSequenceEqual( @@ -759,7 +775,7 @@ def test_preservation_metrics(self, policy, expected_steps): ] checkpointer = Checkpointer(self.directory, preservation_policy=policy) for step in range(num_steps): - checkpointer.save_pytree(step, self.pytree, metrics=all_metrics[step]) + checkpointer.save(step, self.pytree, metrics=all_metrics[step]) self.assertLen(checkpointer.checkpoints, len(expected_steps)) self.assertSequenceEqual( @@ -795,24 +811,24 @@ def test_context_constructor_override(self): ) checkpointer = Checkpointer(self.directory, context=ctx1) self.enter_context(checkpointer) - self.save_pytree(checkpointer, 0, self.pytree) + self.save(checkpointer, 0, self.pytree) with self.subTest('constructor_override_ocdbt'): # Default use_ocdbt is True, so set to False to prove constructor arg is # used. - pytree_dir = self.directory / '0' / 'pytree' + pytree_dir = self.directory / '0' / STATE_CHECKPOINTABLE_KEY self.assertFalse( (pytree_dir / 'manifest.ocdbt').exists(), f'Expected NO manifest.ocdbt under {pytree_dir}', ) with self.subTest('constructor_override_partial_load'): - loaded = checkpointer.load_pytree(0, self.abstract_pytree) + loaded = checkpointer.load(0, self.abstract_pytree) test_utils.assert_tree_equal(self, self.pytree, loaded) # Test partial load override. partial_abstract = {'jax_array': self.abstract_pytree['jax_array']} - loaded_partial = checkpointer.load_pytree(0, partial_abstract) + loaded_partial = checkpointer.load(0, partial_abstract) expected_pytree = {'jax_array': self.pytree['jax_array']} test_utils.assert_tree_equal(self, expected_pytree, loaded_partial) @@ -824,9 +840,9 @@ def test_context_constructor_override(self): ) ) with ctx2: - self.save_pytree(checkpointer, 1, self.pytree) + self.save(checkpointer, 1, self.pytree) - pytree_dir_1 = self.directory / '1' / 'pytree' + pytree_dir_1 = self.directory / '1' / STATE_CHECKPOINTABLE_KEY self.assertTrue( (pytree_dir_1 / 'manifest.ocdbt').exists(), f'Expected manifest.ocdbt under {pytree_dir_1}', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/metadata/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/metadata/types.py index 64117ce8d..bd9afe386 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/metadata/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/metadata/types.py @@ -49,7 +49,7 @@ class CheckpointMetadata( 2. **Lower level** (individual path API): Accessed via free functions. `CheckpointMetadata` objects are returned by both API levels using the same - core methods (:py:func:`~.v1.pytree_metadata` and + core methods (:py:func:`~.v1.metadata` and :py:func:`~.v1.checkpointables_metadata`), reflecting this inherent symmetry. See superclass documentation for more information, and for a list of base @@ -63,10 +63,10 @@ class CheckpointMetadata( # Higher level (sequence-of-steps API) with ocp.training.Checkpointer('/path/to/my/checkpoints') as ckptr: - ckpt_meta = ckptr.pytree_metadata(100) + ckpt_meta = ckptr.metadata(100) # Lower level (individual path API) - ckpt_meta = ocp.pytree_metadata('/path/to/my/checkpoints/100') + ckpt_meta = ocp.metadata('/path/to/my/checkpoints/100') # Inspect checkpoint-level properties print(f'Init time (ns): {ckpt_meta.init_timestamp_nsecs}') diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py index 4a0d2be72..9b47b9b63 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/pathways/snapshotter.py @@ -39,6 +39,7 @@ def __init__(self, *, replica_axis_index: int = 0): self._lock = threading.Lock() self._queue = queue.Queue(maxsize=1) self.replica_axis_index = replica_axis_index + self._worker_thread = threading.Thread(target=self._worker, daemon=True) self._worker_thread.start() @@ -52,7 +53,7 @@ def _worker(self): finally: self._queue.task_done() - def save_pytree(self, step: int, state: tree_types.PyTree) -> None: + def save(self, step: int, state: tree_types.PyTree) -> None: """Backs up JAX array states to pinned host memory, asynchronously. If previous snapshotting requests are still in progress, this request may @@ -79,7 +80,7 @@ def save_pytree(self, step: int, state: tree_types.PyTree) -> None: self._queue.put((pinned_state, step)) - def load_pytree( + def load( self, abstract_state: tree_types.PyTree, *, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/partial.py b/checkpoint/orbax/checkpoint/experimental/v1/partial.py index 858e16eb4..b00c736cc 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/partial.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/partial.py @@ -17,7 +17,7 @@ # pylint: disable=g-importing-member, unused-import, g-multiple-import from orbax.checkpoint.experimental.v1._src.partial.saving import ( - save_pytree, - save_pytree_async, + save, + save_async, finalize, ) diff --git a/docs/api_reference/checkpoint.v1.loading.rst b/docs/api_reference/checkpoint.v1.loading.rst index 31f912057..9c75aaa7e 100644 --- a/docs/api_reference/checkpoint.v1.loading.rst +++ b/docs/api_reference/checkpoint.v1.loading.rst @@ -7,7 +7,7 @@ Checkpoint Loading Loading functions ------------------------------------------------------------ -.. autofunction:: load_pytree +.. autofunction:: load .. autofunction:: load_checkpointables -.. autofunction:: load_pytree_async +.. autofunction:: load_async .. autofunction:: load_checkpointables_async diff --git a/docs/api_reference/checkpoint.v1.metadata.rst b/docs/api_reference/checkpoint.v1.metadata.rst index 9b56b3e54..5e8a4e33e 100644 --- a/docs/api_reference/checkpoint.v1.metadata.rst +++ b/docs/api_reference/checkpoint.v1.metadata.rst @@ -16,5 +16,5 @@ CheckpointMetadata Loading functions ------------------------------------------------------------ -.. autofunction:: pytree_metadata +.. autofunction:: metadata .. autofunction:: checkpointables_metadata diff --git a/docs/api_reference/checkpoint.v1.partial.rst b/docs/api_reference/checkpoint.v1.partial.rst index 7a07bced5..30e2eacd3 100644 --- a/docs/api_reference/checkpoint.v1.partial.rst +++ b/docs/api_reference/checkpoint.v1.partial.rst @@ -8,6 +8,6 @@ Saving ------------------------------------------------------------ -.. autofunction:: save_pytree -.. autofunction:: save_pytree_async +.. autofunction:: save +.. autofunction:: save_async .. autofunction:: finalize diff --git a/docs/api_reference/checkpoint.v1.rst b/docs/api_reference/checkpoint.v1.rst index 88eae462f..cfbf0ecba 100644 --- a/docs/api_reference/checkpoint.v1.rst +++ b/docs/api_reference/checkpoint.v1.rst @@ -33,21 +33,21 @@ Types Loading ~~~~~~~ -.. autofunction:: load_pytree -.. autofunction:: load_pytree_async +.. autofunction:: load +.. autofunction:: load_async .. autofunction:: load_checkpointables .. autofunction:: load_checkpointables_async Saving ~~~~~~ -.. autofunction:: save_pytree -.. autofunction:: save_pytree_async +.. autofunction:: save +.. autofunction:: save_async .. autofunction:: save_checkpointables .. autofunction:: save_checkpointables_async Metadata ~~~~~~~~ -.. autofunction:: pytree_metadata +.. autofunction:: metadata .. autofunction:: checkpointables_metadata .. autoclass:: PyTreeMetadata .. autoclass:: CheckpointMetadata diff --git a/docs/api_reference/checkpoint.v1.saving.rst b/docs/api_reference/checkpoint.v1.saving.rst index a5abca02c..c8fa32cfd 100644 --- a/docs/api_reference/checkpoint.v1.saving.rst +++ b/docs/api_reference/checkpoint.v1.saving.rst @@ -7,7 +7,7 @@ Checkpoint Saving Saving functions ------------------------------------------------------------ -.. autofunction:: save_pytree +.. autofunction:: save .. autofunction:: save_checkpointables -.. autofunction:: save_pytree_async +.. autofunction:: save_async .. autofunction:: save_checkpointables_async diff --git a/docs/guides/checkpoint/v1/async_checkpointing.ipynb b/docs/guides/checkpoint/v1/async_checkpointing.ipynb index c41fef6ab..da315d7c1 100644 --- a/docs/guides/checkpoint/v1/async_checkpointing.ipynb +++ b/docs/guides/checkpoint/v1/async_checkpointing.ipynb @@ -91,9 +91,9 @@ }, "cell_type": "code", "source": [ + "from etils import epath\n", "import numpy as np\n", "import orbax.checkpoint.experimental.v1 as ocp\n", - "from etils import epath\n", "\n", "root_dir = epath.Path('/tmp/async_checkpointing')\n", "\n", @@ -128,7 +128,7 @@ "path = root_dir / 'sync'\n", "path.rmtree(missing_ok=True)\n", "\n", - "ocp.save_pytree(path, train_state)" + "ocp.save(path, train_state)" ], "outputs": [], "execution_count": 6 @@ -150,7 +150,7 @@ }, "cell_type": "markdown", "source": [ - "For async save, simply use `save_pytree_async(...)` instead of `save_pytree(...)`. Calling it will kick off the checkpoint save in a background thread, and return a `response` object without waiting for completion. At this point, other work can be performed in the main thread, and `response.result()` can be called to block until completion." + "For async save, simply use `save_async(...)` instead of `save(...)`. Calling it will kick off the checkpoint save in a background thread, and return a `response` object without waiting for completion. At this point, other work can be performed in the main thread, and `response.result()` can be called to block until completion." ] }, { @@ -162,7 +162,7 @@ "path = root_dir / 'async'\n", "path.rmtree(missing_ok=True)\n", "\n", - "response = ocp.save_pytree_async(path, train_state)\n", + "response = ocp.save_async(path, train_state)\n", "### Do some other work...\n", "response.result()" ], @@ -189,8 +189,8 @@ "To save multiple checkpointables together, Orbax provides free functions in both blocking and async flavors: `save_checkpointables(...)` and `save_checkpointables_async(...)`.\n", "\n", "And the same goes with {py:class}`training.Checkpointer ` class:\n", - "* `training.Checkpointer.save_pytree(...)`\n", - "* `training.Checkpointer.save_pytree_async(...)`\n", + "* `training.Checkpointer.save(...)`\n", + "* `training.Checkpointer.save_async(...)`\n", "* `training.Checkpointer.save_checkpointables(...)`\n", "* `training.Checkpointer.save_checkpointables_async(...)`\n" ] diff --git a/docs/guides/checkpoint/v1/checkpoint_format.ipynb b/docs/guides/checkpoint/v1/checkpoint_format.ipynb index ab05be85b..96003a724 100644 --- a/docs/guides/checkpoint/v1/checkpoint_format.ipynb +++ b/docs/guides/checkpoint/v1/checkpoint_format.ipynb @@ -124,7 +124,7 @@ "Similarly, we can use a different API (see {doc}`Checkpointing PyTrees`):\n", "\n", "```\n", - "ocp.save_pytree(\n", + "ocp.save(\n", " '/path/to/my/checkpoint/',\n", " pytree_of_arrays,\n", ")\n", @@ -444,7 +444,7 @@ }, "cell_type": "code", "source": [ - "pprint.pp(ocp.pytree_metadata(directory / 'ckpt-0').metadata)" + "pprint.pp(ocp.metadata(directory / 'ckpt-0').metadata)" ], "outputs": [], "execution_count": null diff --git a/docs/guides/checkpoint/v1/checkpointables.ipynb b/docs/guides/checkpoint/v1/checkpointables.ipynb index 7de0ceb43..aa33ec485 100644 --- a/docs/guides/checkpoint/v1/checkpointables.ipynb +++ b/docs/guides/checkpoint/v1/checkpointables.ipynb @@ -35,7 +35,7 @@ "source": [ "An Orbax checkpoint is not just a single, opaque unit. The checkpoint actually consists of a bundle of named objects. In this sense, the checkpoint is really more like a dictionary.\n", "\n", - "When you save a checkpoint using {py:func}`ocp.save_pytree `, this produces a bundle with one key/value:\n", + "When you save a checkpoint using {py:func}`ocp.save `, this produces a bundle with one key/value:\n", "\n", "```\n", "{\n", @@ -252,7 +252,7 @@ "outputs": [], "source": [ "train_state = fake_train_step(0)\n", - "ocp.save_pytree(directory / 'just_pytree', train_state)" + "ocp.save(directory / 'just_pytree', train_state)" ] }, { @@ -306,7 +306,7 @@ "id": "7HTGhs3yzEgT" }, "source": [ - "### Interoperating with `save_pytree` / `load_pytree`\n", + "### Interoperating with `save` / `load`\n", "\n" ] }, @@ -316,7 +316,7 @@ "id": "nSWf2rPrAvLj" }, "source": [ - "The APIs `save_pytree` and `load_pytree` interoperate with `save_checkpointables` and `load_checkpointables`. Recall that `save_pytree` just produces a special bundle where the only key is \"pytree\" and the only checkpointable is the PyTree object. Calling `load_pytree` will only load the key named `pytree`, regardless of what other checkpointables are present." + "The APIs `save` and `load` interoperate with `save_checkpointables` and `load_checkpointables`. Recall that `save` just produces a special bundle where the only key is \"pytree\" and the only checkpointable is the PyTree object. Calling `load` will only load the key named `pytree`, regardless of what other checkpointables are present." ] }, { @@ -327,7 +327,7 @@ }, "outputs": [], "source": [ - "ocp.load_pytree(directory / 'pytree_and_dataset')" + "ocp.load(directory / 'pytree_and_dataset')" ] }, { @@ -369,7 +369,7 @@ "id": "y5HfP9YeAKGC" }, "source": [ - "These are still loadable as normal with {py:func}`load_checkpointables `, but {py:func}`load_pytree ` will fail because there is no key named \"pytree\"." + "These are still loadable as normal with {py:func}`load_checkpointables `, but {py:func}`load ` will fail because there is no key named \"pytree\"." ] }, { @@ -392,7 +392,7 @@ "outputs": [], "source": [ "try:\n", - " ocp.load_pytree(directory / 'train_state_and_dataset')\n", + " ocp.load(directory / 'train_state_and_dataset')\n", "except BaseException as e:\n", " print(e)" ] diff --git a/docs/guides/checkpoint/v1/checkpointing_and_exporting_jax_models.ipynb b/docs/guides/checkpoint/v1/checkpointing_and_exporting_jax_models.ipynb index 29a86455f..32e3013d6 100644 --- a/docs/guides/checkpoint/v1/checkpointing_and_exporting_jax_models.ipynb +++ b/docs/guides/checkpoint/v1/checkpointing_and_exporting_jax_models.ipynb @@ -422,8 +422,8 @@ " for _ in range(num_training_steps):\n", " step_to_save_at = current_loop_state['step']\n", "\n", - " # `save_pytree` takes the current step, the state to save, and optional metrics.\n", - " saved = ckptr.save_pytree(step_to_save_at, current_loop_state, metrics={'accuracy': 0.85})\n", + " # `save` takes the current step, the state to save, and optional metrics.\n", + " saved = ckptr.save(step_to_save_at, current_loop_state, metrics={'accuracy': 0.85})\n", "\n", " if saved: # Will be True if the save_decision_policy decided to save.\n", " print(f\" Saved checkpoint for step {step_to_save_at}...\")\n", @@ -467,8 +467,8 @@ " print(f\"Restore from the latest checkpoint in {training_ckpt_dir}...\")\n", "\n", " # It returns None if no checkpoint is found.\n", - " resumed_train_state = ckptr.load_pytree(\n", - " abstract_pytree=simulated_train_state # Provide an abstract state for structure and sharding.\n", + " resumed_train_state = ckptr.load(\n", + " abstract_state=simulated_train_state # Provide an abstract state for structure and sharding.\n", " )\n", "\n", "# If a checkpoint was successfully loaded, resumed_train_state will not be None.\n", @@ -565,12 +565,12 @@ "cleanup_directory_if_exists(str(final_params_save_dir))\n", "\n", "print(f\"Saving final parameters to: {final_params_save_dir}...\")\n", - "ocp.save_pytree(\n", + "ocp.save(\n", " path=final_params_save_dir,\n", - " pytree=final_model_params_to_save,\n", + " state=final_model_params_to_save,\n", " overwrite=True # overwrites an existing checkpoint in directory\n", ")\n", - "print(\"Final model parameters saved via `save_pytree`.\")" + "print(\"Final model parameters saved via `save`.\")" ] }, { @@ -607,9 +607,9 @@ "source": [ "if final_params_save_dir.exists() and len(os.listdir(str(final_params_save_dir))) > 0:\n", " print(f\"Loading parameters from {final_params_save_dir} for verification...\")\n", - " loaded_final_params = ocp.load_pytree(\n", + " loaded_final_params = ocp.load(\n", " final_params_save_dir,\n", - " abstract_pytree=final_model_params_to_save # Use instance as a template for structure and sharding.\n", + " abstract_state=final_model_params_to_save # Use instance as a template for structure and sharding.\n", " )\n", " # Check that the loaded parameters match the original ones.\n", " params_match = jax.tree_util.tree_all(\n", diff --git a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb index 659bd814e..9a429295a 100644 --- a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb +++ b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb @@ -37,16 +37,16 @@ "\n", "To save:\n", "\n", - "* `ocp.save_pytree(...)`\n", - "* `ocp.save_pytree_async(...)`\n", - "* `training.Checkpointer.save_pytree(...)`\n", - "* `training.Checkpointer.save_pytree_async(...)`\n", + "* `ocp.save(...)`\n", + "* `ocp.save_async(...)`\n", + "* `training.Checkpointer.save(...)`\n", + "* `training.Checkpointer.save_async(...)`\n", "\n", "To load:\n", - "* `ocp.load_pytree(...)`\n", - "* `ocp.load_pytree_async(...)`\n", - "* `training.Checkpointer.load_pytree(...)`\n", - "* `training.Checkpointer.load_pytree_async(...)`\n", + "* `ocp.load(...)`\n", + "* `ocp.load_async(...)`\n", + "* `training.Checkpointer.load(...)`\n", + "* `training.Checkpointer.load_async(...)`\n", "\n", "Of course, the `save_checkpointables(...)` and `load_checkpointables(...)`\n", "flavor APIs can be used to save a PyTree too." @@ -128,7 +128,7 @@ "path.rmtree(missing_ok=True)\n", "\n", "# Simple save using default options:\n", - "ocp.save_pytree(path, pytree)" + "ocp.save(path, pytree)" ] }, { @@ -150,7 +150,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path)\n", + "loaded = ocp.load(path)\n", "loaded" ] }, @@ -200,11 +200,11 @@ }, "outputs": [], "source": [ - "abstract_pytree = {\n", + "abstract_state = {\n", " 'a': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),\n", " 'b': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),\n", "}\n", - "abstract_pytree" + "abstract_state" ] }, { @@ -215,8 +215,8 @@ }, "outputs": [], "source": [ - "# Load using abstract_pytree.\n", - "loaded = ocp.load_pytree(path, abstract_pytree)\n", + "# Load using abstract_state.\n", + "loaded = ocp.load(path, abstract_state)\n", "loaded" ] }, @@ -237,7 +237,7 @@ "id": "tQ6L_wtnVq_8" }, "source": [ - "The `pytree_metadata` method returns a `CheckpointMetadata` object with a number of properties, but the core `metadata` property is just an abstract PyTree. This can also be used for loading as shown below." + "The `metadata` method returns a `CheckpointMetadata` object with a number of properties, but the core `metadata` property is just an abstract PyTree. This can also be used for loading as shown below." ] }, { @@ -248,8 +248,8 @@ }, "outputs": [], "source": [ - "pytree_metadata = ocp.pytree_metadata(path).metadata\n", - "pytree_metadata" + "metadata = ocp.metadata(path).metadata\n", + "metadata" ] }, { @@ -260,7 +260,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path, pytree_metadata)\n", + "loaded = ocp.load(path, metadata)\n", "loaded" ] }, @@ -295,7 +295,7 @@ }, "outputs": [], "source": [ - "ocp.load_pytree(path, pytree)" + "ocp.load(path, pytree)" ] }, { @@ -357,7 +357,7 @@ "unnecessary metadata reads.\n", "\n", "```\n", - "abstract_pytree = {\n", + "abstract_state = {\n", " 'a': jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=jax.sharding.NamedSharding(...))\n", "}\n", "```\n", @@ -368,7 +368,7 @@ "will be used to restore specific properties for each leaf.\n", "\n", "```\n", - "abstract_pytree = {\n", + "abstract_state = {\n", " 'a': jax.ShapeDtypeStruct,\n", " 'b': int,\n", " 'c': np.ndarray,\n", @@ -381,7 +381,7 @@ "which type each leaf should be loaded as.\n", "\n", "```\n", - "abstract_pytree = {\n", + "abstract_state = {\n", " 'a': None,\n", " 'b': None,\n", "}\n", @@ -393,7 +393,7 @@ "in your code if the checkpoint does not have the structure you expect.\n", "\n", "```\n", - "abstract_pytree = None\n", + "abstract_state = None\n", "```" ] }, @@ -427,8 +427,8 @@ " return x.update(dtype=np.int16)\n", "\n", "\n", - "cast_dtype_abstract_pytree = jax.tree_util.tree_map(\n", - " set_loading_dtype, abstract_pytree\n", + "cast_dtype_abstract_state = jax.tree_util.tree_map(\n", + " set_loading_dtype, abstract_state\n", ")" ] }, @@ -440,7 +440,7 @@ }, "outputs": [], "source": [ - "ocp.load_pytree(path, cast_dtype_abstract_pytree)" + "ocp.load(path, cast_dtype_abstract_state)" ] }, { @@ -476,7 +476,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path)" + "loaded = ocp.load(path)" ] }, { @@ -517,10 +517,10 @@ " return x.update(sharding=sharding)\n", "\n", "\n", - "change_sharding_abstract_pytree = jax.tree_util.tree_map(\n", - " set_sharding, abstract_pytree\n", + "change_sharding_abstract_state = jax.tree_util.tree_map(\n", + " set_sharding, abstract_state\n", ")\n", - "loaded = ocp.load_pytree(path, change_sharding_abstract_pytree)" + "loaded = ocp.load(path, change_sharding_abstract_state)" ] }, { @@ -551,11 +551,11 @@ }, "outputs": [], "source": [ - "pytree_metadata = ocp.pytree_metadata(path).metadata\n", - "change_sharding_pytree_metadata = jax.tree_util.tree_map(\n", - " lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=sharding), pytree_metadata\n", + "metadata = ocp.metadata(path).metadata\n", + "change_sharding_metadata = jax.tree_util.tree_map(\n", + " lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=sharding), metadata\n", ")\n", - "loaded = ocp.load_pytree(path, change_sharding_pytree_metadata)\n", + "loaded = ocp.load(path, change_sharding_metadata)\n", "(loaded['a'].sharding, loaded['b'].sharding)" ] }, @@ -593,7 +593,7 @@ " 'b': 13.5,\n", " 'c': create_sharded_array(np.arange(8)),\n", "}\n", - "ocp.save_pytree(path, pytree_with_scalars)" + "ocp.save(path, pytree_with_scalars)" ] }, { @@ -604,12 +604,12 @@ }, "outputs": [], "source": [ - "abstract_pytree_with_scalars = {\n", + "abstract_state_with_scalars = {\n", " 'a': float,\n", " 'b': int,\n", " 'c': np.empty((8,)),\n", "}\n", - "ocp.load_pytree(path, abstract_pytree_with_scalars)" + "ocp.load(path, abstract_state_with_scalars)" ] }, { @@ -649,7 +649,7 @@ "path = epath.Path('/tmp/checkpointing-pytrees/partial/')\n", "path.rmtree(missing_ok=True)\n", "\n", - "ocp.save_pytree(path / '1', original_item)" + "ocp.save(path / '1', original_item)" ] }, { @@ -695,7 +695,7 @@ " 'step': 101,\n", "}\n", "\n", - "ocp.load_pytree(path / '1', reference_item)" + "ocp.load(path / '1', reference_item)" ] }, { @@ -725,7 +725,7 @@ " pytree_options=PyTreeOptions(...),\n", " file_options=FileOptions(...),\n", "):\n", - " ocp.save_pytree(path, pytree)\n", + " ocp.save(path, pytree)\n", "```\n", "\n", "Let's explore few examples. Please also take a look at API Reference for specific option details." @@ -778,7 +778,7 @@ }, "outputs": [], "source": [ - "ocp.save_pytree(path / '1', pytree)" + "ocp.save(path / '1', pytree)" ] }, { @@ -789,7 +789,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path / '1')" + "loaded = ocp.load(path / '1')" ] }, { @@ -838,7 +838,7 @@ " )\n", " )\n", "):\n", - " ocp.save_pytree(path / '2', pytree, overwrite=True)" + " ocp.save(path / '2', pytree, overwrite=True)" ] }, { @@ -849,7 +849,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path / '2')" + "loaded = ocp.load(path / '2')" ] }, { @@ -892,7 +892,7 @@ " )\n", " )\n", "):\n", - " ocp.save_pytree(path / '3', pytree, overwrite=True)" + " ocp.save(path / '3', pytree, overwrite=True)" ] }, { @@ -903,7 +903,7 @@ }, "outputs": [], "source": [ - "loaded = ocp.load_pytree(path / '3')" + "loaded = ocp.load(path / '3')" ] }, { @@ -950,7 +950,7 @@ " )\n", " )\n", "):\n", - " ocp.save_pytree(path / '4', pytree, overwrite=True)" + " ocp.save(path / '4', pytree, overwrite=True)" ] }, { @@ -997,7 +997,7 @@ " )\n", " )\n", "):\n", - " ocp.save_pytree(path / '5', pytree, overwrite=True)\n", + " ocp.save(path / '5', pytree, overwrite=True)\n", "\n", "!ls /tmp/checkpointing-pytrees/advanced/5/pytree" ] @@ -1048,7 +1048,7 @@ "outputs": [], "source": [ "# Original shape.\n", - "loaded = ocp.load_pytree(path / '1')\n", + "loaded = ocp.load(path / '1')\n", "\n", "(loaded['a'].shape, loaded['b'].shape)" ] @@ -1061,21 +1061,21 @@ }, "outputs": [], "source": [ - "different_shape_abstract_pytree = {\n", + "different_shape_abstract_state = {\n", " 'a': jax.ShapeDtypeStruct(\n", " shape=(8,),\n", - " dtype=abstract_pytree['a'].dtype,\n", - " sharding=abstract_pytree['a'].sharding,\n", + " dtype=abstract_state['a'].dtype,\n", + " sharding=abstract_state['a'].sharding,\n", " ),\n", " 'b': jax.ShapeDtypeStruct(\n", " shape=(32,),\n", - " dtype=abstract_pytree['b'].dtype,\n", - " sharding=abstract_pytree['b'].sharding,\n", + " dtype=abstract_state['b'].dtype,\n", + " sharding=abstract_state['b'].sharding,\n", " ),\n", "}\n", "\n", "try:\n", - " ocp.load_pytree(path / '1', different_shape_abstract_pytree)\n", + " ocp.load(path / '1', different_shape_abstract_state)\n", "except BaseException as e:\n", " print(e)" ] @@ -1104,7 +1104,7 @@ " )\n", " )\n", "):\n", - " loaded = ocp.load_pytree(path / '1', different_shape_abstract_pytree)" + " loaded = ocp.load(path / '1', different_shape_abstract_state)" ] }, { diff --git a/docs/guides/checkpoint/v1/model_surgery.ipynb b/docs/guides/checkpoint/v1/model_surgery.ipynb index c956858ed..2c0379bb1 100644 --- a/docs/guides/checkpoint/v1/model_surgery.ipynb +++ b/docs/guides/checkpoint/v1/model_surgery.ipynb @@ -64,7 +64,7 @@ " ),\n", " pytree,\n", ")\n", - "ocp.save_pytree(path, pytree, overwrite=True)" + "ocp.save(path, pytree, overwrite=True)" ], "outputs": [], "execution_count": 11 @@ -126,7 +126,7 @@ " 'step': np.array([]),\n", "}\n", "\n", - "ocp.load_pytree(path, abstract_tree)" + "ocp.load(path, abstract_tree)" ], "outputs": [], "execution_count": 12 @@ -159,7 +159,7 @@ "}\n", "\n", "try:\n", - " ocp.load_pytree(path, bad_abstract_tree)\n", + " ocp.load(path, bad_abstract_tree)\n", "except Exception as e:\n", " print(e)" ], @@ -193,7 +193,7 @@ " pytree\n", ")\n", "\n", - "ocp.load_pytree(path, easy_abstract_tree)" + "ocp.load(path, easy_abstract_tree)" ], "outputs": [], "execution_count": 14 @@ -204,7 +204,7 @@ }, "cell_type": "markdown", "source": [ - "We may not have direct access to the original PyTree when creating the abstract counterpart, and in that case, we'll need to use the on-disk `pytree_metadata`." + "We may not have direct access to the original PyTree when creating the abstract counterpart, and in that case, we'll need to use the on-disk `metadata`." ] }, { @@ -213,14 +213,14 @@ }, "cell_type": "code", "source": [ - "on_disk_pytree_structure = ocp.pytree_metadata(path).metadata\n", + "on_disk_pytree_structure = ocp.metadata(path).metadata\n", "\n", "abstract_tree_from_metadata = jax.tree.map_with_path(\n", " _create_abstract_leaf_for_partial_load,\n", " on_disk_pytree_structure\n", ")\n", "\n", - "ocp.load_pytree(path, abstract_tree_from_metadata)" + "ocp.load(path, abstract_tree_from_metadata)" ], "outputs": [], "execution_count": 15 @@ -262,7 +262,7 @@ "\n", "# Loading PyTrees with certain leaves missing is unsafe\n", "try:\n", - " ocp.load_pytree(path, abstract_tree)\n", + " ocp.load(path, abstract_tree)\n", "except ValueError as e:\n", " print(e)\n", "\n", @@ -274,7 +274,7 @@ " ),\n", " ),\n", "):\n", - " ocp.load_pytree(path, abstract_tree)" + " ocp.load(path, abstract_tree)" ], "outputs": [], "execution_count": 16 diff --git a/docs/guides/checkpoint/v1/orbax_checkpoint_101.ipynb b/docs/guides/checkpoint/v1/orbax_checkpoint_101.ipynb index fd6258729..2ce451c2b 100644 --- a/docs/guides/checkpoint/v1/orbax_checkpoint_101.ipynb +++ b/docs/guides/checkpoint/v1/orbax_checkpoint_101.ipynb @@ -129,7 +129,7 @@ "cell_type": "code", "source": [ "checkpoint_name = next_checkpoint_name()\n", - "ocp.save_pytree(directory / checkpoint_name, pytree)" + "ocp.save(directory / checkpoint_name, pytree)" ], "outputs": [], "execution_count": null @@ -149,7 +149,7 @@ }, "cell_type": "code", "source": [ - "ocp.load_pytree(directory / checkpoint_name)" + "ocp.load(directory / checkpoint_name)" ], "outputs": [], "execution_count": null @@ -160,7 +160,7 @@ }, "cell_type": "markdown", "source": [ - "We can inspect the tree structure and array properties using `pytree_metadata`." + "We can inspect the tree structure and array properties using `metadata`." ] }, { @@ -169,7 +169,7 @@ }, "cell_type": "code", "source": [ - "ocp.pytree_metadata(directory / checkpoint_name).metadata" + "ocp.metadata(directory / checkpoint_name).metadata" ], "outputs": [], "execution_count": null @@ -180,7 +180,7 @@ }, "cell_type": "markdown", "source": [ - "Note that we are accessing the property: `pytree_metadata(...).metadata`. This is the metadata specific to the PyTree itself. Other properties are general to the entire checkpoint, such as timestamps." + "Note that we are accessing the property: `metadata(...).metadata`. This is the metadata specific to the PyTree itself. Other properties are general to the entire checkpoint, such as timestamps." ] }, { @@ -237,7 +237,7 @@ " # to initialize the current experiment. The latest checkpoint comes from this\n", " # experiment, and allows us to resume after interruption.\n", " if source_checkpoint_path:\n", - " return ocp.load_pytree(source_checkpoint_path)\n", + " return ocp.load(source_checkpoint_path)\n", " # Otherwise, init from scratch\n", " else:\n", " return initialize_state()" @@ -261,12 +261,12 @@ " train_state = init_or_restore(directory / checkpoint_name)\n", " start_step = 0\n", " else:\n", - " train_state = ckptr.load_pytree()\n", + " train_state = ckptr.load()\n", " start_step = ckptr.latest.step\n", "\n", " for step in range(start_step, total_steps):\n", " train_state = train_step(train_state)\n", - " ckptr.save_pytree(step, train_state)" + " ckptr.save(step, train_state)" ], "outputs": [], "execution_count": null @@ -337,9 +337,9 @@ "in isolation.\n", "\n", "Both API levels share the same core API's:\n", - "* **Saving**: `save_pytree` / `save_checkpointables`\n", - "* **Loading**: `load_pytree` / `load_checkpointables`\n", - "* **Metadata**: `pytree_metadata` / `checkpointables_metadata`\n", + "* **Saving**: `save` / `save_checkpointables`\n", + "* **Loading**: `load` / `load_checkpointables`\n", + "* **Metadata**: `metadata` / `checkpointables_metadata`\n", "\n", "These conceptually perform the same tasks for both API levels, but are accessed\n", "slightly differently. In a training loop, we must create a\n", @@ -351,16 +351,16 @@ "\n", "**Higher level** (sequence-of-steps):\n", "```py\n", - "ocp.save_pytree(100, state)\n", - "ocp.load_pytree(100, abstract_state)\n", - "ocp.pytree_metadata(100)\n", + "ocp.save(100, state)\n", + "ocp.load(100, abstract_state)\n", + "ocp.metadata(100)\n", "```\n", "\n", "**Lower level** (individual paths):\n", "```py\n", - "ocp.save_pytree('/tmp/my/checkpoint', state)\n", - "ocp.load_pytree('/tmp/my/checkpoint', abstract_state)\n", - "ocp.pytree_metadata('/tmp/my/checkpoint')\n", + "ocp.save('/tmp/my/checkpoint', state)\n", + "ocp.load('/tmp/my/checkpoint', abstract_state)\n", + "ocp.metadata('/tmp/my/checkpoint')\n", "```\n", "\n", "See additional documentation on {doc}`Training`\n", diff --git a/docs/guides/checkpoint/v1/orbax_v0_to_v1_migration.ipynb b/docs/guides/checkpoint/v1/orbax_v0_to_v1_migration.ipynb index 7aa5955ca..69d5cc66e 100644 --- a/docs/guides/checkpoint/v1/orbax_v0_to_v1_migration.ipynb +++ b/docs/guides/checkpoint/v1/orbax_v0_to_v1_migration.ipynb @@ -306,7 +306,7 @@ }, "cell_type": "markdown", "source": [ - "An pytree checkpoint in the above layout can be loaded using `ocp.load_pytree(...)` function." + "An pytree checkpoint in the above layout can be loaded using `ocp.load(...)` function." ] }, { @@ -328,7 +328,7 @@ "source": [ "# Load a pytree from a directory with no checkpointables.\n", "\n", - "loaded = ocp.load_pytree(my_checkpoint_dir, checkpointable_name=None)\n", + "loaded = ocp.load(my_checkpoint_dir, checkpointable_name=None)\n", "# Use the loaded pytree.\n", "print('loaded=', loaded)" ], @@ -358,7 +358,7 @@ }, "cell_type": "markdown", "source": [ - "#### Loading pytree checkpoint with `load_pytree(...)`" + "#### Loading pytree checkpoint with `load(...)`" ] }, { @@ -369,10 +369,10 @@ "source": [ "| Restore API | Response\n", ":------- | :-------- |\n", - "|ocp.load_pytree(`step_1234`)|Loads PyTree under subdirectory, `pytree`|\n", - "|ocp.load_pytree(`step_1234`, `checkpointable_name='pytree'`)|Loads PyTree under subdirectory, `pytree`|\n", - "|ocp.load_pytree(`step_1234`, `checkpointable_name='state'`)|Loads PyTree under subdirectory, `state`|\n", - "|ocp.load_pytree(`my_checkpoint`, `checkpointable_name=None`)|Loads PyTree directly from `my_checkpoint`|" + "|ocp.load(`step_1234`)|Loads PyTree under subdirectory, `pytree`|\n", + "|ocp.load(`step_1234`, `checkpointable_name='pytree'`)|Loads PyTree under subdirectory, `pytree`|\n", + "|ocp.load(`step_1234`, `checkpointable_name='state'`)|Loads PyTree under subdirectory, `state`|\n", + "|ocp.load(`my_checkpoint`, `checkpointable_name=None`)|Loads PyTree directly from `my_checkpoint`|" ] }, { @@ -385,12 +385,12 @@ "\n", "| Restore API | Response\n", ":------- | :-------- |\n", - "|ocp.load_pytree(`root_dir`)|Error: expecting a subdir named `pytree`|\n", - "|ocp.load_pytree(`root_dir`, `checkpointable_name='pytree'`)|Error: expecting a subdir named `pytree`|\n", - "|ocp.load_pytree(`root_dir`, `checkpointable_name=None`)|Error: expecting pytree metadata file|\n", - "|ocp.load_pytree(`step_1234`, `checkpointable_name=None`)|Error: expecting pytree metadata file|\n", - "|ocp.load_pytree(`my_checkpoint`)|Error: expecting a subdir named `pytree`|\n", - "|ocp.load_pytree(`my_checkpoint`, `checkpointable_name='pytree'`)|Error: expecting a subdir named `pytree`|" + "|ocp.load(`root_dir`)|Error: expecting a subdir named `pytree`|\n", + "|ocp.load(`root_dir`, `checkpointable_name='pytree'`)|Error: expecting a subdir named `pytree`|\n", + "|ocp.load(`root_dir`, `checkpointable_name=None`)|Error: expecting pytree metadata file|\n", + "|ocp.load(`step_1234`, `checkpointable_name=None`)|Error: expecting pytree metadata file|\n", + "|ocp.load(`my_checkpoint`)|Error: expecting a subdir named `pytree`|\n", + "|ocp.load(`my_checkpoint`, `checkpointable_name='pytree'`)|Error: expecting a subdir named `pytree`|" ] }, { @@ -425,9 +425,9 @@ "| Restore API | Response\n", ":------- | :-------- |\n", "|ocp.load_checkpointables(`root_dir`)|Error: suggesting to try a subdir instead|\n", - "|ocp.load_checkpointables(`my_checkpoint`)|Error: suggesting to use load_pytree instead|\n", + "|ocp.load_checkpointables(`my_checkpoint`)|Error: suggesting to use load instead|\n", "|ocp.load_checkpointables(`root_dir`, `dict(state=abstract_tree, pytree=abstract_tree)`)|Error: suggesting to try a subdir instead|\n", - "|ocp.load_checkpointables(`my_checkpoint`, `dict(state=abstract_tree, pytree=abstract_tree)`)|Error: suggesting to use load_pytree instead|\n" + "|ocp.load_checkpointables(`my_checkpoint`, `dict(state=abstract_tree, pytree=abstract_tree)`)|Error: suggesting to use load instead|\n" ] }, { @@ -464,15 +464,15 @@ "|`latest_step()`|`latest`|\n", "|`reload()`|`reload()`|\n", "|`should_save(step)`|`should_save(step)`|\n", - "|`save(...)`|`save_pytree(...)`, `save_checkpointables(...)`|\n", + "|`save(...)`|`save(...)`, `save_checkpointables(...)`|\n", "||and `save_*_async(...)`|\n", - "|`restore(...)`|`load_pytree(...)`, `load_checkpointables(...)` |\n", + "|`restore(...)`|`load(...)`, `load_checkpointables(...)` |\n", "||and `load_*_async(...)`|\n", - "|`item_metadata(step)`|`pytree_metadata(step)`,|\n", + "|`item_metadata(step)`|`metadata(step)`,|\n", "||`checkpointables_metadata(step)`|\n", - "|`metrics(step)`|`pytree_metadata(step).metrics`,|\n", + "|`metrics(step)`|`metadata(step).metrics`,|\n", "||`checkpointables_metadata(step).metrics`|\n", - "|`metadata(step)`|`pytree_metadata(step)`,|\n", + "|`metadata(step)`|`metadata(step)`,|\n", "||`checkpointables_metadata(step)`|\n", "|`metadata(None)` or `metadata()`|`root_metadata()`|\n", "|`wait_until_finished`|Call `AsyncResponse.result()`|\n", diff --git a/docs/guides/checkpoint/v1/partial_saving.ipynb b/docs/guides/checkpoint/v1/partial_saving.ipynb index 2b05d7450..885c08b5c 100644 --- a/docs/guides/checkpoint/v1/partial_saving.ipynb +++ b/docs/guides/checkpoint/v1/partial_saving.ipynb @@ -16,7 +16,7 @@ "\n", "Partial saving operates on a \"session\" or \"transaction\" model. Instead of overwriting your checkpoint directly, Orbax stages all changes in a temporary, in-progress location. The workflow consists of two stages:\n", "\n", - "1. **Incremental Updates**: Calls to functions like {py:func}`ocp.partial.save_pytree ` contribute data to an in-progress checkpointing session. These changes are staged in a temporary location and are not yet visible at the final checkpoint path. From the user's perspective, the first save call simply begins this incremental process, and subsequent calls add to it.\n", + "1. **Incremental Updates**: Calls to functions like {py:func}`ocp.partial.save ` contribute data to an in-progress checkpointing session. These changes are staged in a temporary location and are not yet visible at the final checkpoint path. From the user's perspective, the first save call simply begins this incremental process, and subsequent calls add to it.\n", "2. **Finalization**: A concluding call to {py:func}`ocp.partial.finalize ` completes the session. This action commits all the staged changes, making the checkpoint available at its final destination and ready for consumption.\n", "\n", "This approach ensures that the modification process is safe and atomic. If the process is interrupted before finalization, your original checkpoint remains untouched.\n", @@ -79,7 +79,7 @@ " 'step': 10000,\n", "}\n", "\n", - "ocp.partial.save_pytree(path, initial_state)\n", + "ocp.partial.save(path, initial_state)\n", "assert not path.exists()\n", "assert (path.parent / (path.name + '.partial_save')).exists()" ] @@ -109,7 +109,7 @@ " },\n", "}\n", "\n", - "ocp.partial.save_pytree(path, new_state)\n", + "ocp.partial.save(path, new_state)\n", "assert not path.exists()\n", "assert (path.parent / (path.name + '.partial_save')).exists()" ] @@ -132,7 +132,7 @@ "cell_type": "code", "source": [ "try:\n", - " ocp.load_pytree(path)\n", + " ocp.load(path)\n", "except Exception as e:\n", " print(\"LOAD ERROR\")\n", " print(e)" @@ -183,7 +183,7 @@ }, "outputs": [], "source": [ - "restored_state = ocp.load_pytree(path)\n", + "restored_state = ocp.load(path)\n", "\n", "expected_state = {\n", " 'params': {\n", @@ -211,7 +211,7 @@ "source": [ "### API Reference\n", "\n", - " - `ocp.partial.save_pytree()` / `ocp.partial.save_pytree_async()`: Saves a PyTree to the temporary partial save location. These functions can be called multiple times.\n", + " - `ocp.partial.save()` / `ocp.partial.save_async()`: Saves a PyTree to the temporary partial save location. These functions can be called multiple times.\n", " - `ocp.partial.finalize()`: Commits the transaction, making the checkpoint permanent at the specified path. This must be called to complete the process.\n", "\n", "### Advanced Workflow: Combining Partial Saving and Partial Restore\n", @@ -279,13 +279,13 @@ " 'optimizer_state': [np.random.rand(128) for _ in range(16)],\n", "}\n", "base_model_state = jax.tree.map(create_sharded_array, base_model_state)\n", - "ocp.save_pytree(base_path, base_model_state)\n", + "ocp.save(base_path, base_model_state)\n", "\n", "abstract_base_model_state = jax.tree.map(\n", " ocp.arrays.to_shape_dtype_struct,\n", " base_model_state\n", ")\n", - "init_ckpt = ocp.load_pytree(base_path, abstract_base_model_state)\n", + "init_ckpt = ocp.load(base_path, abstract_base_model_state)\n", "print(\"\\n--- Setup ---\")\n", "print(f\"Optimizer state exists in initial checkpoint: {'optimizer_state' in init_ckpt}\")\n", "print(f\"Model version exists in initial checkpoint: {'model_version' in init_ckpt}\")\n", @@ -324,7 +324,7 @@ " loading=ocp.options.PyTreeOptions.Loading(partial_load=True)\n", " )\n", "):\n", - " loaded_params = ocp.load_pytree(base_path, abstract_params)" + " loaded_params = ocp.load(base_path, abstract_params)" ] }, { @@ -352,7 +352,7 @@ "\n", "metadata = {'model_version': 'v1.2-finetuned'}\n", "save_params = ocp.tree.merge(save_params, metadata)\n", - "ocp.partial.save_pytree(inference_path, metadata) # Initial partial save for metadata\n", + "ocp.partial.save(inference_path, metadata) # Initial partial save for metadata\n", "\n", "for layer, weights in loaded_params['params']['encoder_stack'].items():\n", " new_weights = weights + np.random.rand(2)\n", @@ -366,7 +366,7 @@ " },\n", " }\n", " save_params = ocp.tree.merge(save_params, stack_layer)\n", - " ocp.partial.save_pytree(inference_path, stack_layer) # One partial save per layer\n", + " ocp.partial.save(inference_path, stack_layer) # One partial save per layer\n", "\n", "ocp.partial.finalize(inference_path)" ] @@ -396,7 +396,7 @@ " ),\n", " save_params\n", ")\n", - "final_ckpt = ocp.load_pytree(inference_path, abstract_params)\n", + "final_ckpt = ocp.load(inference_path, abstract_params)\n", "\n", "print(\"\\n--- Verification ---\")\n", "print(f\"Optimizer state exists in final checkpoint: {'optimizer_state' in final_ckpt}\")\n", diff --git a/docs/guides/checkpoint/v1/training.ipynb b/docs/guides/checkpoint/v1/training.ipynb index 0a3546229..8f5b8262c 100644 --- a/docs/guides/checkpoint/v1/training.ipynb +++ b/docs/guides/checkpoint/v1/training.ipynb @@ -170,7 +170,7 @@ "with training.Checkpointer(root_directory) as ckptr:\n", " num_steps = 10\n", " for step in range(num_steps):\n", - " saved = ckptr.save_pytree(step, pytree)\n", + " saved = ckptr.save(step, pytree)\n", " assert saved\n", " pytree = train_step(pytree)" ] @@ -194,7 +194,7 @@ "outputs": [], "source": [ "with training.Checkpointer(root_directory) as ckptr:\n", - " print(ckptr.load_pytree())" + " print(ckptr.load())" ] }, { @@ -302,7 +302,7 @@ " save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),\n", ") as ckptr:\n", " for step in range(10):\n", - " ckptr.save_pytree(step, pytree)" + " ckptr.save(step, pytree)" ] }, { @@ -357,7 +357,7 @@ " num_steps = 10\n", " for step in range(num_steps):\n", " is_final = step == num_steps - 1\n", - " ckptr.save_pytree(\n", + " ckptr.save(\n", " step,\n", " pytree,\n", " metrics={'accuracy': 0.85},\n", @@ -462,7 +462,7 @@ }, "source": [ "In many cases, we wish to cheaply gain information about checkpoint properties\n", - "without loading the entire model. Using the `pytree_metadata` API, we can learn\n", + "without loading the entire model. Using the `metadata` API, we can learn\n", "about the tree structure of our PyTree, as well as information about each array\n", "in the tree." ] @@ -488,11 +488,11 @@ "outputs": [], "source": [ "# Loads metadata from the latest checkpoint.\n", - "ckptr.pytree_metadata()\n", + "ckptr.metadata()\n", "# Loads metadata corresponding to the first step.\n", - "ckptr.pytree_metadata(ckptr.checkpoints[0])\n", + "ckptr.metadata(ckptr.checkpoints[0])\n", "# Loads metadata from a specific integer step.\n", - "ckptr.pytree_metadata(3)\n", + "ckptr.metadata(3)\n", "\n", "print()" ] @@ -514,7 +514,7 @@ }, "outputs": [], "source": [ - "ckptr.pytree_metadata()" + "ckptr.metadata()" ] }, { @@ -535,8 +535,8 @@ }, "outputs": [], "source": [ - "print(ckptr.pytree_metadata().metrics)\n", - "print(ckptr.pytree_metadata().custom_metadata)" + "print(ckptr.metadata().metrics)\n", + "print(ckptr.metadata().custom_metadata)" ] }, { @@ -560,7 +560,7 @@ "source": [ "import pprint\n", "\n", - "pprint.pprint(ckptr.pytree_metadata().metadata)" + "pprint.pprint(ckptr.metadata().metadata)" ] }, { @@ -621,7 +621,7 @@ "root_directory.rmtree(missing_ok=True)\n", "with training.Checkpointer(root_directory) as ckptr:\n", " for step in range(10):\n", - " ckptr.save_pytree(step, pytree)\n", + " ckptr.save(step, pytree)\n", " print([c.step for c in ckptr.checkpoints])" ] }, @@ -653,7 +653,7 @@ ") as ckptr:\n", " print([c.step for c in ckptr.checkpoints])\n", " assert ckptr.latest.step == 9\n", - " ckptr.save_pytree(10, pytree)\n", + " ckptr.save(10, pytree)\n", " print([c.step for c in ckptr.checkpoints])" ] }, @@ -696,11 +696,11 @@ "outputs": [], "source": [ "# Loads from the latest checkpoint.\n", - "ckptr.load_pytree()\n", + "ckptr.load()\n", "# Loads the first available checkpoint in the root directory.\n", - "ckptr.load_pytree(ckptr.checkpoints[0])\n", + "ckptr.load(ckptr.checkpoints[0])\n", "# Loads from a specific integer step.\n", - "ckptr.load_pytree(4)\n", + "ckptr.load(4)\n", "\n", "print()" ] @@ -745,7 +745,7 @@ "sharding = jax.sharding.NamedSharding(\n", " jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec('x')\n", ")\n", - "abstract_pytree = {\n", + "abstract_state = {\n", " 'params': {\n", " 'layer0': jax.ShapeDtypeStruct((8, 2), np.float32, sharding=sharding),\n", " },\n", @@ -762,7 +762,7 @@ }, "outputs": [], "source": [ - "ckptr.load_pytree(None, abstract_pytree)" + "ckptr.load(None, abstract_state)" ] }, { @@ -1084,7 +1084,7 @@ "Now, we can define our main `train()` function, throughout which we will demonstrate checkpointing. A couple notes:
\n", "\n", "* We use `FixedIntervalPolicy` so that our checkpoint is saved every 10 training steps.\n", - "* We use `nnx.state()` to convert the model object (DotReluDot) and optimizer to a checkpointable PyTree, which can then be checkpointed with `ckptr.save_pytree()`\n", + "* We use `nnx.state()` to convert the model object (DotReluDot) and optimizer to a checkpointable PyTree, which can then be checkpointed with `ckptr.save()`\n", "\n", "When actually loading a checkpoint, we do the following:\n", "* If a checkpoint exists in our current checkpoints directory, we restore the latest one.\n", @@ -1122,9 +1122,9 @@ " if ckpt_path or ckptr.latest:\n", " # If a checkpoint already exists, we restore it.\n", " if ckptr.latest:\n", - " loaded_state = ckptr.load_pytree(abstract_pytree=abs_state)\n", + " loaded_state = ckptr.load(abstract_state=abs_state)\n", " else:\n", - " loaded_state = ocp.load_pytree(path=ckpt_path, abstract_pytree=abs_state)\n", + " loaded_state = ocp.load(path=ckpt_path, abstract_state=abs_state)\n", " # Update model and optimizer separately\n", " nnx.update(model, loaded_state['params'])\n", " nnx.update(optimizer, loaded_state['optimizer'])\n", @@ -1157,7 +1157,7 @@ " 'params': nnx.state(model),\n", " 'optimizer': nnx.state(optimizer, nnx.optimizer.OptState),\n", " }\n", - " ckptr.save_pytree(step, state)" + " ckptr.save(step, state)" ] }, {