Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- #v1 Rename `save/load_pytree` to `save/load`. Eliminate most user-facing
"pytree" terminology in favor of "state" as a more specific term.
Add `deprecations.py` for handling deprecated public functions.

## [0.11.40] - 2026-05-18

### Removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ def test_fn(
resolved_path: str = self._client.resolve(self._xid, step)
assert resolved_path.startswith(LUSTRE_PATH_PREFIX), resolved_path
with metrics.measure("save_cache", metrics_to_measure):
ocp.save_pytree(resolved_path, context.pytree)
ocp.save(resolved_path, context.pytree)
with metrics.measure("finalize_cache", metrics_to_measure):
self._client.finalize(self._xid, step)

with metrics.measure("save", metrics_to_measure):
ocp.save_pytree(context.path / str(step), context.pytree)
ocp.save(context.path / str(step), context.pytree)

abstract_pytree = jax.tree.map(
ocp.arrays.to_shape_dtype_struct, context.pytree
Expand All @@ -148,12 +148,14 @@ def test_fn(
with metrics.measure("wait_prefetch_cache", metrics_to_measure):
self._client.await_transfer(self._xid, step - 1)
with metrics.measure("restore_cache", metrics_to_measure):
restored_pytree = ocp.load_pytree(resolved_path, abstract_pytree)
restored_pytree = ocp.load(
resolved_path, abstract_state=abstract_pytree
)
restored_pytree = self._clear_pytree(restored_pytree)
del restored_pytree
with metrics.measure("restore", metrics_to_measure):
restored_pytree = ocp.load_pytree(
context.path / str(step), abstract_pytree
restored_pytree = ocp.load(
context.path / str(step), abstract_state=abstract_pytree
)
restored_pytree = self._clear_pytree(restored_pytree)
del restored_pytree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_benchmark_test_fn(self):
# Create temporary directories for the test
# Ensure all processes use the same directory
base_dir = self.create_tempdir(name='benchmark')
# cache_dir should not exist before save_pytree is called
# cache_dir should not exist before save is called
cache_dir_path = epath.Path(base_dir.full_path) / 'cache'
work_dir = base_dir.mkdir('work')

Expand Down Expand Up @@ -156,7 +156,7 @@ def resolve_side_effect(xid, step):
)

# Create a checkpoint at the "GCS" location for step 0 so restore works
ocp.save_pytree(gcs_dir_path / '0', pytree)
ocp.save(gcs_dir_path / '0', pytree)

# Run the test function
result = generator.test_fn(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ def test_fn(
jax.profiler.start_trace(context.path / "trace_save")
if options.async_enabled:
with metrics.measure("save_blocking", metrics_to_measure):
f = ocp.save_pytree_async(save_path, pytree)
f = ocp.save_async(save_path, pytree)
with metrics.measure("save_background", metrics_to_measure):
f.result()
else:
with metrics.measure("save_blocking", metrics_to_measure):
ocp.save_pytree(save_path, pytree)
ocp.save(save_path, pytree)
with metrics.measure("save_background", metrics_to_measure):
pass
context.pytree = clear_pytree(context.pytree)
Expand All @@ -179,7 +179,7 @@ def test_fn(
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_load")
with metrics.measure("load", metrics_to_measure):
restored_pytree = ocp.load_pytree(save_path, abstract_pytree)
restored_pytree = ocp.load(save_path, abstract_state=abstract_pytree)
clear_pytree(restored_pytree)
if options.enable_trace:
jax.profiler.stop_trace()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_multi_slice_abstract_state(
) -> Any:
"""Returns the abstract state for all replicas."""
with ocp.Context(context=context):
metadata = ocp.pytree_metadata(reference_checkpoint_path)
metadata = ocp.metadata(reference_checkpoint_path)
# Abstract tree has shardings on a single replica.
single_replica_abstract_state = (
checkpoint_generation.get_abstract_state_from_sharding_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_get_multi_slice_abstract_state(self):
# Setup real checkpoint and sharding config
pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}}
ref_ckpt_path = self.directory / 'ref_ckpt'
ocp.save_pytree(ref_ckpt_path, pytree)
ocp.save(ref_ckpt_path, pytree)

sharding_config = {
'a': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def test_fn(
)

with ocp.Context(context=options.context):
loaded_pytree = ocp.load_pytree(
reference_checkpoint_path, abstract_pytree
loaded_pytree = ocp.load(
reference_checkpoint_path, abstract_state=abstract_pytree
)

for step in range(options.num_savings):
Expand All @@ -132,12 +132,12 @@ def test_fn(
"ReplicaParallelMultislice: Async Saving pytree to %s.",
save_path,
)
f = ocp.save_pytree_async(save_path, loaded_pytree)
f = ocp.save_async(save_path, loaded_pytree)
with metrics.measure("save_background", metrics_to_measure):
f.result()
else:
with metrics.measure("save_blocking", metrics_to_measure):
ocp.save_pytree(save_path, loaded_pytree)
ocp.save(save_path, loaded_pytree)
with metrics.measure("save_background", metrics_to_measure):
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_benchmark_test_fn(self):
# Setup real checkpoint and sharding config
pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}}
ref_ckpt_path = self.directory / 'ref_ckpt'
ocp.save_pytree(ref_ckpt_path, pytree)
ocp.save(ref_ckpt_path, pytree)

sharding_config = {
'a': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_fn(
)

with ocp.Context(context=options.context):
metadata = ocp.pytree_metadata(reference_checkpoint_path)
metadata = ocp.metadata(reference_checkpoint_path)
abstract_pytree = (
checkpoint_generation.get_abstract_state_from_sharding_config(
reference_sharding_path,
Expand All @@ -108,8 +108,8 @@ def test_fn(
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_load")
with metrics.measure("load", metrics_to_measure):
restored_pytree = ocp.load_pytree(
reference_checkpoint_path, abstract_pytree
restored_pytree = ocp.load(
reference_checkpoint_path, abstract_state=abstract_pytree
)
benchmark.clear_pytree(restored_pytree)
if options.enable_trace:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_benchmark_test_fn(self):
# Setup real checkpoint and sharding config
pytree = {'a': jnp.arange(8), 'b': {'c': jnp.ones((4, 4))}}
ref_ckpt_path = self.directory / 'ref_ckpt'
ocp.save_pytree(ref_ckpt_path, pytree)
ocp.save(ref_ckpt_path, pytree)

sharding_config = {
'a': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def test_fn(
if options.enable_trace:
jax.profiler.start_trace(context.path / "trace_load")
with metrics.measure("load", metrics_to_measure):
restored_pytree = ocp.load_pytree(
reference_checkpoint_path, abstract_pytree
restored_pytree = ocp.load(
reference_checkpoint_path, abstract_state=abstract_pytree
)
benchmark.clear_pytree(restored_pytree)
if options.enable_trace:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_benchmark_test_fn(self):
# Setup real checkpoint and sharding config
pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}}
ref_ckpt_path = self.directory / 'ref_ckpt'
ocp.save_pytree(ref_ckpt_path, pytree)
ocp.save(ref_ckpt_path, pytree)

sharding_config = {
'a': {
Expand Down
15 changes: 12 additions & 3 deletions checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
from orbax.checkpoint.experimental.v1._src.loading.loading import (
load_checkpointables,
load_checkpointables_async,
load_pytree,
load_pytree_async,
load,
load_async,
PLACEHOLDER,
)
from orbax.checkpoint.experimental.v1._src.metadata.loading import (
checkpointables_metadata,
pytree_metadata,
metadata,
)
from orbax.checkpoint.experimental.v1._src.metadata.types import (
CheckpointMetadata,
Expand All @@ -63,6 +63,15 @@
from orbax.checkpoint.experimental.v1._src.saving.saving import (
save_checkpointables,
save_checkpointables_async,
save,
save_async,
)

### DEPRECATED APIS ###
from orbax.checkpoint.experimental.v1._src.deprecations.deprecations import (
save_pytree,
save_pytree_async,
load_pytree,
load_pytree_async,
pytree_metadata,
)
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ class Context(epy.ContextManager):

# Basic usage
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()):
ocp.save_pytree(directory, tree)
ocp.save(directory, tree)

# Inheriting properties from an existing context
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx:
# inner_ctx inherits pytree_options, but overrides/adds array_options
with ocp.Context(outer_ctx,
array_options=ocp.options.ArrayOptions()
) as inner_ctx:
ocp.save_pytree(directory, tree)
ocp.save(directory, tree)

Context is not shared across threads::

Expand All @@ -92,9 +92,9 @@ class Context(epy.ContextManager):
with ocp.Context(
pytree_options=ocp.options.PyTreeOptions()
): # Thread #1 creates Context.
# The following save_pytree call is executed in Thread #2, which sees
# The following save call is executed in Thread #2, which sees
# a "default" Context, NOT the one created above.
executor.submit(ocp.save_pytree, directory, tree)
executor.submit(ocp.save, directory, tree)


Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def is_prioritized_key_fn(path):

pytree = {'a': jnp.ones((1,))}

# save_pytree will eventually call BasePyTreeCheckpointHandler
# save will eventually call BasePyTreeCheckpointHandler
try:
saving.save_pytree('/tmp/test', pytree)
saving.save('/tmp/test', pytree)
except Exception: # pylint: disable=broad-except
# We might get some errors because we mocked too much,
# but we check if mock_handler_class was called.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines deprecated legacy aliases for Orbax V1 API."""

import functools
from typing import Any, Callable, TypeVar
import warnings

from orbax.checkpoint.experimental.v1._src.loading import loading
from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading
from orbax.checkpoint.experimental.v1._src.saving import saving

_FuncT = TypeVar('_FuncT', bound=Callable[..., Any])


def deprecated(*, new: _FuncT) -> Callable[[_FuncT], _FuncT]:
"""Decorator to mark a function as a deprecated alias of another function."""

def decorator(deprecated_func: _FuncT) -> _FuncT:
@functools.wraps(deprecated_func)
def wrapper(*args, **kwargs):
alias_name = getattr(deprecated_func, '__name__', 'unknown')
new_name = getattr(new, '__name__', 'unknown')
warnings.warn(
f'`{alias_name}` is deprecated, use `{new_name}` instead.',
DeprecationWarning,
stacklevel=2,
)
return deprecated_func(*args, **kwargs)

return wrapper

return decorator


@deprecated(new=saving.save)
def save_pytree(*args, **kwargs):
return saving.save(*args, **kwargs)


@deprecated(new=saving.save_async)
def save_pytree_async(*args, **kwargs):
return saving.save_async(*args, **kwargs)


@deprecated(new=loading.load)
def load_pytree(*args, **kwargs):
return loading.load(*args, **kwargs)


@deprecated(new=loading.load_async)
def load_pytree_async(*args, **kwargs):
return loading.load_async(*args, **kwargs)


@deprecated(new=metadata_loading.metadata)
def pytree_metadata(*args, **kwargs):
return metadata_loading.metadata(*args, **kwargs)
Loading
Loading