Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from absl import logging
from etils import epath
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import utils as ocp_path_utils

Expand Down Expand Up @@ -89,6 +90,8 @@ async def create_snapshot(self) -> None:
if not await async_path.exists(self._source):
raise ValueError(f"Snapshot source does not exist: {self._source}'.")

event_tracking.record_read_metadata_event(self._source)

t = ocp_path_utils.Timer()
await asyncio.to_thread(
ocp_path_utils.recursively_copy_files,
Expand Down
16 changes: 16 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def test_replace_source_recovers_on_failure(self):
# It should contain original source data.
self.assertEqual('data', (recovery_path / 'data.txt').read_text())

def test_create_snapshot_records_read_event(self):
default_snapshot = DefaultSnapshot(self.source_path, self.dest_path)
with mock.patch.object(
snapshot.event_tracking, 'record_read_metadata_event'
) as mock_record:
asyncio.run(default_snapshot.create_snapshot())
mock_record.assert_called_once_with(self.source_path)


class EmptySnapshotTest(absltest.TestCase):

Expand Down Expand Up @@ -317,6 +325,14 @@ def test_replace_source_with_relative_source_path_fails(self):
with self.assertRaisesRegex(ValueError, 'Snapshot source must be absolute'):
asyncio.run(empty_snapshot.replace_source())

def test_create_snapshot_does_not_record_read_event(self):
empty_snapshot = EmptySnapshot(self.source_path, self.dest_path)
with mock.patch.object(
snapshot.event_tracking, 'record_read_metadata_event'
) as mock_record:
asyncio.run(empty_snapshot.create_snapshot())
mock_record.assert_not_called()


if __name__ == '__main__':
absltest.main()
Loading