From 9d87609163a51eabd71fddbe776668e691e6ef25 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Thu, 28 May 2026 03:42:30 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 922673056 --- .../checkpoint/_src/path/snapshot/snapshot.py | 3 +++ .../_src/path/snapshot/snapshot_test.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py index d22b9ba9b..202676c80 100644 --- a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py +++ b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py @@ -22,6 +22,7 @@ from absl import logging from etils import epath +from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import utils as ocp_path_utils @@ -89,6 +90,8 @@ async def create_snapshot(self) -> None: if not await async_path.exists(self._source): raise ValueError(f"Snapshot source does not exist: {self._source}'.") + event_tracking.record_read_metadata_event(self._source) + t = ocp_path_utils.Timer() await asyncio.to_thread( ocp_path_utils.recursively_copy_files, diff --git a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot_test.py b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot_test.py index 5c70f27a4..656bdce0d 100644 --- a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot_test.py @@ -161,6 +161,14 @@ def test_replace_source_recovers_on_failure(self): # It should contain original source data. self.assertEqual('data', (recovery_path / 'data.txt').read_text()) + def test_create_snapshot_records_read_event(self): + default_snapshot = DefaultSnapshot(self.source_path, self.dest_path) + with mock.patch.object( + snapshot.event_tracking, 'record_read_metadata_event' + ) as mock_record: + asyncio.run(default_snapshot.create_snapshot()) + mock_record.assert_called_once_with(self.source_path) + class EmptySnapshotTest(absltest.TestCase): @@ -317,6 +325,14 @@ def test_replace_source_with_relative_source_path_fails(self): with self.assertRaisesRegex(ValueError, 'Snapshot source must be absolute'): asyncio.run(empty_snapshot.replace_source()) + def test_create_snapshot_does_not_record_read_event(self): + empty_snapshot = EmptySnapshot(self.source_path, self.dest_path) + with mock.patch.object( + snapshot.event_tracking, 'record_read_metadata_event' + ) as mock_record: + asyncio.run(empty_snapshot.create_snapshot()) + mock_record.assert_not_called() + if __name__ == '__main__': absltest.main()