Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _make_on_commit_callback(
tmpdir: atomicity_types.TemporaryPath,
custom_metadata: dict[str, Any] | None,
checkpoint_start_time: float,
ckpt_args: checkpoint_args.CheckpointArgs,
) -> Callable[[], None]:
# Directory is the final directory.

Expand Down Expand Up @@ -453,9 +454,9 @@ def _callback() -> None:
async def _save(
self,
tmpdir: atomicity_types.TemporaryPath,
*args,
*,
ckpt_args: checkpoint_args.CheckpointArgs,
force: bool = False,
**kwargs,
):
directory = tmpdir.get_final()
if await async_path.exists(directory):
Expand Down Expand Up @@ -483,10 +484,6 @@ async def _save(
else:
await self.create_temporary_path(tmpdir)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
if isinstance(
self._handler,
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler,
Expand Down Expand Up @@ -554,15 +551,17 @@ def save(
tmpdir = self.get_temporary_path(directory)
self.wait_until_finished()
self.synchronize_next_awaitable_signal_operation_id()
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
on_commit_callback = self._make_on_commit_callback(
tmpdir, custom_metadata, checkpoint_start_time
tmpdir, custom_metadata, checkpoint_start_time, ckpt_args
)
commit_ops = asyncio_utils.run_sync(
self._save(
tmpdir,
*args,
ckpt_args=ckpt_args,
force=force,
**kwargs,
)
)
operation_recorder.record_blocking_completion(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Synchronous Checkpointer implementation."""

import os
import time
from typing import Any, Iterable, Optional, Type

Expand Down Expand Up @@ -72,6 +73,8 @@ def construct_checkpoint_args(
return restore_arg_cls(*args, **kwargs)




class Checkpointer(
abstract_checkpointer.AbstractCheckpointer, epy.ContextManager
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import hashlib
import os
import time
from typing import Any, Awaitable, Callable, Iterable
import uuid
Expand Down Expand Up @@ -85,6 +86,8 @@ def add_internal_checkpointables(
return checkpointables




class _SaveResponse(AsyncResponse[None]):
"""An :py:class:`.AsyncResponse` representing the result of :py:func:`.save_async`."""

Expand All @@ -98,6 +101,7 @@ def __init__(
start_time: float,
custom_metadata: tree_types.JsonType | None,
context: context_lib.Context,
checkpointables: dict[str, Any],
async_origin: bool,
partial_save: bool = False,
):
Expand All @@ -108,6 +112,7 @@ def __init__(
self._start_time = start_time
self._custom_metadata = custom_metadata
self._context = context
self._checkpointables = checkpointables
self._async_origin = async_origin
self._partial_save = partial_save
self._thread_runner = thread_utils.BackgroundThreadRunner[None](
Expand Down Expand Up @@ -150,6 +155,7 @@ def create(
start_time=start_time,
custom_metadata=custom_metadata,
context=context,
checkpointables=checkpointables,
async_origin=async_origin,
partial_save=partial_save,
)
Expand Down
Loading