diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 3d10b6666..15a7f11fd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -135,13 +135,13 @@ class MultiprocessingOptions(_ActiveContextGuard): host, only a subset of processes are active, and a custom barrier key is used to prevent collisions with other concurrent checkpointers:: - from orbax.checkpoint.v1.options import - MultiprocessingOptions + from orbax.checkpoint.v1.options import MultiprocessingOptions options = MultiprocessingOptions( primary_host=1, active_processes={1, 2, 3}, barrier_sync_key_prefix="model_a_sync_" + ) Attributes: primary_host: The host id of the primary host. Default to 0. If it's set @@ -500,7 +500,8 @@ class PathwaysOptions(_ActiveContextGuard): """Options used to configure Pathways saving and loading. Attributes: - checkpointing_impl: The implementation to use for Pathways checkpointing. + checkpointing_impl: The implementation mode to use for Pathways + checkpointing. """ checkpointing_impl: pathways_types.CheckpointingImpl | None = None diff --git a/docs/api_reference/checkpoint.v1.options.rst b/docs/api_reference/checkpoint.v1.options.rst index 4002924d4..dba52d3a8 100644 --- a/docs/api_reference/checkpoint.v1.options.rst +++ b/docs/api_reference/checkpoint.v1.options.rst @@ -50,3 +50,14 @@ MemoryOptions ------------------------------------------------------------ .. autoclass:: MemoryOptions :members: + +SafetensorsOptions +------------------------------------------------------------ +.. autoclass:: SafetensorsOptions + :members: + +CheckpointLayout +------------------------------------------------------------ +.. 
autoclass:: CheckpointLayout + :members: + diff --git a/docs/guides/checkpoint/v1/context.ipynb b/docs/guides/checkpoint/v1/context.ipynb new file mode 100644 index 000000000..70665496a --- /dev/null +++ b/docs/guides/checkpoint/v1/context.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "metadata": { + "id": "i1y3dnJObIc-" + }, + "cell_type": "markdown", + "source": [ + "# Configuring Specialized Features\n", + "\n", + "This guide provides a comprehensive overview of `ocp.Context` in Orbax v1. It explains the underlying architecture, demonstrates basic and advanced usage patterns, and outlines best practices for managing environment, runtime, and I/O configuration in your training loops." + ] + }, + { + "metadata": { + "id": "zNlANtqlbL4S" + }, + "cell_type": "markdown", + "source": [ + "## 1. Configuration Behavior\n", + "\n", + "`Context` objects and their underlying configuration option dataclasses (e.g., `ArrayOptions`, `AsyncOptions`, `FileOptions`) serve two distinct roles during their lifecycle:\n", + "\n", + "1. **Standalone Configuration Templates (Mutable):** When you instantiate `ctx = ocp.Context()`, it acts as an in-memory template. You can freely build, modify, and inspect its configuration parameters using a clean, mutable dot-notation syntax.\n", + "2. 
**Active Execution Policies (Frozen):** Once a `Context` is bound to a context manager (`with ctx:`), it becomes the active runtime policy for all Orbax operations executed within that block.\n", + "\n", + "### Strict immutability during use\n", + "To guarantee thread safety and prevent unpredictable mid-flight side effects, Orbax enforces a strict immutability invariant on active contexts.\n", + "\n", + "Attempting to mutate any configuration parameter while the context is active will immediately be intercepted and raise a `RuntimeError`.\n", + "\n", + "First, let's set up a minimal environment for our examples:" + ] + }, + { + "metadata": { + "id": "1Mf0YMYVbOUn" + }, + "cell_type": "code", + "source": [ + "from etils import epath\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from orbax.checkpoint import v1 as ocp\n", + "\n", + "# Minimal setup for examples\n", + "directory = epath.Path('/tmp/my_checkpoint_dir')\n", + "params = {'w': jnp.zeros((2, 3)), 'b': jnp.ones((3, 4))}\n", + "params_tree = {'params': params}\n", + "step = 0\n", + "\n", + "print('Setup complete.')" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "yP6KZk31bQuy" + }, + "cell_type": "code", + "source": [ + "ctx = ocp.Context()\n", + "ctx.asynchronous.timeout_secs = 600 # Perfectly valid (configuring template)\n", + "\n", + "with ctx:\n", + " # Executing checkpoint operations...\n", + " ocp.save(directory / 'sync', params_tree)\n", + "\n", + " # Attempting to mutate an active context is strictly prohibited:\n", + " try:\n", + " ctx.asynchronous.timeout_secs = 1200\n", + " except RuntimeError as e:\n", + " print(f\"Caught expected error: {e}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "5WNxPz1zbS0Y" + }, + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "## 2. Basic Usage & Configuration\n", + "\n", + "Configuring a `Context` relies on a hierarchical, dot-notation namespace. 
You do not need to construct or pass complex option dataclasses directly.\n", + "\n", + "Below is a summary of the available option types.\n", + "\n", + "| Option Type | Dot Path | Description |\n", + "| :--- | :--- | :--- |\n", + "| {py:class}`AsyncOptions <orbax.checkpoint.v1.options.AsyncOptions>` | `ctx.asynchronous` | Asynchronous checkpoint saving operations, including timeouts, workers, and background finalization behavior. |\n", + "| {py:class}`MultiprocessingOptions <orbax.checkpoint.v1.options.MultiprocessingOptions>` | `ctx.multiprocessing` | Multi-host and multi-process checkpointing behavior and barrier synchronization. |\n", + "| {py:class}`FileOptions <orbax.checkpoint.v1.options.FileOptions>` | `ctx.file` | Underlying filesystem interactions, directory permissions, atomicity protocols, and customized path implementations. |\n", + "| {py:class}`PyTreeOptions <orbax.checkpoint.v1.options.PyTreeOptions>` | `ctx.pytree` | PyTree-level saving, loading, and structural restoration behavior. |\n", + "| {py:class}`ArrayOptions <orbax.checkpoint.v1.options.ArrayOptions>` | `ctx.array` | High-performance tensor and array I/O, storage formats, compression, sharding, and multi-replica load-and-broadcast behavior. |\n", + "| {py:class}`CheckpointablesOptions <orbax.checkpoint.v1.options.CheckpointablesOptions>` | `ctx.checkpointables` | Handler resolution registries for custom checkpointable types. |\n", + "| {py:class}`PathwaysOptions <orbax.checkpoint.v1.options.PathwaysOptions>` | `ctx.pathways` | Pathways-specific distributed checkpointing implementations. |\n", + "| {py:class}`DeletionOptions <orbax.checkpoint.v1.options.DeletionOptions>` | `ctx.deletion` | Checkpoint cleanup and soft-deletion behavior across storage backends. |\n", + "| {py:class}`MemoryOptions <orbax.checkpoint.v1.options.MemoryOptions>` | `ctx.memory` | Concurrent I/O memory limits and prioritized transfer scheduling to prevent out-of-memory (OOM) errors during large checkpoint operations. |\n", + "| {py:class}`SafetensorsOptions <orbax.checkpoint.v1.options.SafetensorsOptions>` | `ctx.safetensors` | HuggingFace SafeTensors loading and conversion behavior. |\n", + "| {py:class}`CheckpointLayout <orbax.checkpoint.v1.options.CheckpointLayout>` | `ctx.checkpoint_layout` | Permanent on-disk serialization layout format. 
|" + ] + }, + { + "metadata": { + "id": "7KWR2LqmbWKs" + }, + "cell_type": "markdown", + "source": [ + "### Example: Basic Configuration & Execution\n", + "\n", + "When using `Checkpointer`, you do not need to wrap your training loop inside `with ctx:`. You can pass the `Context` object directly into the `Checkpointer` constructor (`context=ctx`). For standalone free functions (`ocp.save`), you use `with ctx:` to bind the active context." + ] + }, + { + "metadata": { + "id": "Ujyzb-NsbY9N" + }, + "cell_type": "code", + "source": [ + "# 1. Instantiate the root Context\n", + "ctx = ocp.Context()\n", + "\n", + "# 2. Configure options via mutable dot-notation\n", + "ctx.asynchronous.timeout_secs = 1200\n", + "ctx.asynchronous.create_directories_asynchronously = True\n", + "\n", + "ctx.array.saving.use_zarr3 = True\n", + "ctx.array.saving.use_compression = False\n", + "ctx.array.loading.enable_padding_and_truncation = True\n", + "\n", + "ctx.pytree.loading.partial_load = True\n", + "\n", + "# 3a. Using Checkpointer (Pass context directly into constructor)\n", + "with ocp.training.Checkpointer(directory / 'ckptr', context=ctx) as ckptr:\n", + " ckptr.save_checkpointables(step, {'params': params})\n", + "\n", + "# 3b. Using Free Functions (Bind context via with block)\n", + "with ctx:\n", + " ocp.save(directory / 'free', params_tree)" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "MlP_097mbaKS" + }, + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "## 3. Advanced Usage: Inheritance & Customization\n", + "\n", + "Orbax `Context` supports powerful inheritance patterns, allowing you to branch configurations for specialized sub-tasks without duplicating code or risking side effects." 
+ ] + }, + { + "metadata": { + "id": "vID0rfD0bcIN" + }, + "cell_type": "markdown", + "source": [ + "### 3.1 Context Inheritance\n", + "To inherit properties from an existing parent `Context`, pass the parent context directly to the constructor (`ctx2 = ocp.Context(ctx1)`).\n", + "\n", + "Orbax performs a deep copy of the parent's option tree (while safely sharing immutable functions and callbacks by reference). The resulting child context inherits all parent properties but is completely decoupled, allowing you to mutate `ctx2` independently without affecting `ctx1`.\n", + "\n", + "#### Inheritance in Checkpointer\n", + "Note that when you pass a `Context` object into `Checkpointer(..., context=ctx)` (or when `Checkpointer` inherits an active context from a `with ctx:` block), Orbax automatically executes `ocp.Context(ctx)` under the hood. This means `Checkpointer` inherits all properties from your context but operates on a completely independent, unfrozen child copy, preserving perfect isolation." + ] + }, + { + "metadata": { + "id": "L-Y0KA4WbrdT" + }, + "cell_type": "code", + "source": [ + "# Parent context configures baseline rules\n", + "base_ctx = ocp.Context()\n", + "base_ctx.pytree.loading.partial_load = True\n", + "base_ctx.asynchronous.timeout_secs = 1200\n", + "\n", + "# Checkpointer automatically branches a child context ocp.Context(base_ctx)\n", + "with ocp.training.Checkpointer(directory / 'child_ckptr', context=base_ctx) as ckptr:\n", + " ckptr.save_checkpointables(step, {'params': params})" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "jwaYEENObtam" + }, + "cell_type": "markdown", + "source": [ + "### 3.2 Scoped Storage Options (Per-Leaf Configuration)\n", + "When saving complex PyTrees, you may want certain parameter leaves (e.g., large weight matrices) to use different storage rules (e.g., lower precision dtypes or specific chunk shapes) than smaller parameters (e.g., biases). 
You can achieve this using `scoped_storage_options_creator`." + ] + }, + { + "metadata": { + "id": "FfnqkGewbu1v" + }, + "cell_type": "code", + "source": [ + "def custom_storage_rules(keypath, value):\n", + " # Downcast large weights to float16, leave biases as default\n", + " if 'weight' in jax.tree_util.keystr(keypath):\n", + " return ocp.options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float16)\n", + " return None # Fall back to global storage_options\n", + "\n", + "ctx = ocp.Context()\n", + "ctx.array.saving.storage_options.dtype = jnp.float32\n", + "ctx.array.saving.scoped_storage_options_creator = custom_storage_rules\n", + "\n", + "with ctx:\n", + " ocp.save(directory / 'scoped', params_tree)" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "YXzPWKsqbwPu" + }, + "cell_type": "markdown", + "source": [ + "### 3.3 Custom Handler Registration\n", + "Orbax provides a global registry for standard handlers (like PyTrees and JSON). However, when defining custom\n", + "`CheckpointableHandler` types (as detailed in the {doc}`Customization Guide <customization>`), you can configure\n", + "context-local registries to avoid polluting the global registry or causing conflicts across different modules.\n", + "\n", + "When you attach a custom `CheckpointableHandlerRegistry` to `ctx.checkpointables.registry`, Orbax resolves handlers for your custom checkpointables exclusively within that context's scope." + ] + }, + { + "metadata": { + "id": "-f0fKFM9b0Sp" + }, + "cell_type": "code", + "source": [ + "# 1. Create a standalone local registry\n", + "local_registry = ocp.handlers.local_registry()\n", + "\n", + "# 2. Register your custom handler (assuming MyCustomHandler is defined)\n", + "# local_registry.add(MyCustomHandler, checkpointable_name='custom_state')\n", + "\n", + "# 3. Attach the local registry to your Context\n", + "ctx = ocp.Context()\n", + "ctx.checkpointables.registry = local_registry\n", + "\n", + "# 4. 
Execute within the context scope; Orbax will now use local_registry\n", + "with ctx:\n", + " # ocp.save_checkpointables(directory, dict(custom_state=my_custom_object))\n", + " pass" + ], + "outputs": [], + "execution_count": 10 + }, + { + "metadata": { + "id": "nNSYCLSqb1fG" + }, + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "## 4. Context Best Practices\n", + "\n", + "To ensure clean architectural design and prevent subtle concurrency or scoping bugs across multi-threaded, multi-host, or asynchronous workflows, adhere to the following core principles." + ] + }, + { + "metadata": { + "id": "Ph7AwSNSb3mx" + }, + "cell_type": "markdown", + "source": [ + "### 4.1 Concurrency in Asynchronous Workflows\n", + "\n", + "For standard asynchronous workflows where you configure a `Context` once at the start of your training script, spawning a dedicated child context (sub-context) is **not required**. The active context safely manages the background save operation, and `Checkpointer` automatically creates an isolated child context under the hood.\n", + "\n", + "However, note that any background task or coroutine inheriting the active context receives a reference to the exact same `Context` instance in memory. If your advanced workflow requires the main thread to actively mutate the shared `Context` instance mid-flight after launching an async save, those mutations would propagate to the background task. 
In those rare cases, branching a child context (`ctx2 = ocp.Context(main_ctx)`) ensures perfect isolation.\n", + "\n", + "Below is the standard, recommended pattern for asynchronous saving:" + ] + }, + { + "metadata": { + "id": "fPVDI-fHdoNW" + }, + "cell_type": "code", + "source": [ + "main_ctx = ocp.Context()\n", + "main_ctx.asynchronous.timeout_secs = 600\n", + "\n", + "# Standard async save operates directly within the active context\n", + "with main_ctx:\n", + " response = ocp.save_async(directory / 'async', params_tree)\n", + " # Perform other concurrent training work...\n", + " response.result()" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "VtSsoHt4cPFv" + }, + "cell_type": "markdown", + "source": [ + "### 4.2 Thread Safety and Background Threads\n", + "\n", + "Orbax is not thread-safe in general. Calling Orbax operations from a custom background thread (e.g., via `threading.Thread` or `ThreadPoolExecutor`) is not supported and should be avoided. If you believe your specific architecture requires calling Orbax from a background thread, please consult with the Orbax team to discuss your use case." + ] + }, + { + "metadata": { + "id": "8DfavNrCcuKj" + }, + "cell_type": "markdown", + "source": [ + "### 4.3 Consistency in Distributed Multi-Host Environments\n", + "\n", + "In distributed training setups (e.g., multi-slice TPU pods or multi-node GPU clusters), `Context` settings like `MultiprocessingOptions`, `AsyncOptions`, and underlying storage sharding rules must be configured identically across all participating hosts. Mismatched configurations can lead to barrier deadlocks or corrupted metadata." 
+ ] + }, + { + "metadata": { + "id": "7VGE49-Tc0rn" + }, + "cell_type": "code", + "source": [ + "# Execute identical configuration logic across all participating hosts\n", + "cluster_ctx = ocp.Context()\n", + "cluster_ctx.multiprocessing.barrier_sync_key_prefix = \"model_v1_sync_\"\n", + "cluster_ctx.asynchronous.timeout_secs = 1200\n", + "\n", + "# Ensure all hosts enter the managed block synchronously\n", + "with cluster_ctx:\n", + " ocp.save(directory / 'cluster', params_tree)" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "ldztvAnXcUvx" + }, + "cell_type": "markdown", + "source": [ + "### 4.4 Configuration Scoping in Multi-Actor Scripts\n", + "When authoring scripts or test suites that mix `Checkpointer` objects and standalone free functions (`ocp.save`), guarantee isolation by wrapping free functions in dedicated `with ocp.Context(...):` blocks or passing explicit contexts into `Checkpointer` constructors. This ensures each independent actor executes with its intended settings regardless of the outer scope." + ] + }, + { + "metadata": { + "id": "9RkzMtWxcV3D" + }, + "cell_type": "markdown", + "source": [ + "### 4.5 Frozen Lineage Protection\n", + "When instantiating a child context from a parent context (`ctx2 = ocp.Context(ctx1)`), Orbax guarantees strict lineage safety. If the parent `ctx1` is currently bound to an active `with` block, `ctx1` remains fully frozen against mutation.\n", + "\n", + "Creating `ctx2` produces a completely independent, unfrozen child copy that you can customize freely, but it does not unfreeze or compromise the parent `ctx1`. Attempting to mutate `ctx1` while it is active is strictly blocked by the freeze guard." 
+ ] + }, + { + "metadata": { + "id": "W-yxJNpydCMB" + }, + "cell_type": "code", + "source": [ + "parent_ctx = ocp.Context()\n", + "parent_ctx.asynchronous.timeout_secs = 600\n", + "\n", + "with parent_ctx:\n", + " # Safely branch an unfrozen child context from the active parent\n", + " child_ctx = ocp.Context(parent_ctx)\n", + " child_ctx.asynchronous.timeout_secs = 300 # Perfectly valid (modifying child)\n", + "\n", + " # The active parent context remains strictly frozen against mutation:\n", + " try:\n", + " parent_ctx.asynchronous.timeout_secs = 1200\n", + " except RuntimeError as e:\n", + " print(f\"Caught expected error: {e}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "06cb5acd" + }, + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "## 5. Summary of Best Practices\n", + "\n", + "1. **Configure Early, Execute Late:** Instantiate and configure your `ocp.Context` once at the top of your training script or sub-task.\n", + "2. **Dot-Notation:** Use `ctx.asynchronous.timeout_secs = ...` rather than constructing complex option dataclasses.\n", + "3. **Branch via Inheritance:** Use `ctx2 = ocp.Context(ctx1)` when you need modified settings for a specific evaluation or export step.\n", + "4. **Respect Immutability:** Never attempt to mutate a `Context` while it is actively bound to a `with` block or being used by a background asynchronous save.\n", + "5. **Manage Checkpointer Lifecycle:** Prefer `with ocp.training.Checkpointer(...) as ckptr:` to guarantee clean resource cleanup and complete outstanding writes.\n", + "6. **Consult API Reference:** For in-depth configuration fields, sub-options, and advanced usage, refer to the [Orbax Context Options API Reference](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.v1.options.html)." 
+ ] + } + ], + "metadata": { + "colab": { + "private_outputs": true, + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/index.rst b/docs/index.rst index 6c2cbc79b..e2e594bde 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -113,6 +113,7 @@ API that is **easy to use**, **highly performant**, and **maximimally compatible :caption: Advanced Usage guides/checkpoint/v1/customization + guides/checkpoint/v1/context guides/checkpoint/v1/partial_saving guides/checkpoint/v1/model_surgery