From 3cd3283a2442989196510668b366620694563e1f Mon Sep 17 00:00:00 2001
From: Colin Gaffney
Date: Fri, 29 May 2026 10:52:34 -0700
Subject: [PATCH] #v1 Add a test for saving and loading grain iterator.

PiperOrigin-RevId: 923497446
---
 .../v1/_src/testing/save_load_test_base.py | 63 +++++++++++++++++--
 1 file changed, 57 insertions(+), 6 deletions(-)

diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
index aca332e8a..0f4bc46e8 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py
@@ -28,6 +28,7 @@
 from absl.testing import parameterized
 from etils import epath
 import flax
+import grain.python as pygrain
 import jax
 from jax import numpy as jnp
 import numpy as np
@@ -858,6 +859,56 @@ def test_async_directory_creation(self):
     self.assertTrue(self.directory.exists())
     self.assertLen(list(self.directory.parent.iterdir()), 1)
 
+  def test_grain_dataset_checkpointable(self):
+    # Create a simple dataset
+    ds = pygrain.MapDataset.source(list(range(10)))
+    dl = pygrain.DataLoader(
+        data_source=ds,
+        sampler=pygrain.SequentialSampler(10, pygrain.ShardOptions(0, 1)),
+        operations=[pygrain.Batch(1)],
+    )
+    data_iter = iter(dl)
+
+    # Advance it
+    next(data_iter)  # [0]
+    next(data_iter)  # [1]
+
+    checkpointables = {
+        STATE_CHECKPOINTABLE_KEY: self.pytree,
+        'dataset': data_iter,
+    }
+
+    # Save
+    ocp.save_checkpointables(self.directory, checkpointables)
+
+    # Advance original to change state
+    next(data_iter)  # [2]
+
+    # Restore
+    new_dl = pygrain.DataLoader(
+        data_source=ds,
+        sampler=pygrain.SequentialSampler(10, pygrain.ShardOptions(0, 1)),
+        operations=[pygrain.Batch(1)],
+    )
+    new_data_iter = iter(new_dl)
+
+    abstract_checkpointables = {
+        STATE_CHECKPOINTABLE_KEY: self.abstract_pytree,
+        'dataset': new_data_iter,
+    }
+
+    # Load
+    loaded = ocp.load_checkpointables(
+        self.directory, abstract_checkpointables
+    )
+
+    # The restored iterator should be at the saved position
+    # (after [1], so next is [2])
+    self.assertEqual(next(loaded['dataset']), [2])
+    test_utils.assert_tree_equal(
+        self, self.pytree, loaded[STATE_CHECKPOINTABLE_KEY]
+    )
+
   def test_background_error(self):
 
     async def raise_background_error(*args, **kwargs):
@@ -1195,24 +1246,24 @@ class SynchronizationTest(_TestSetup):
   def test_sync_save_increments_operation_id(self):
     initial_op_id = int(synchronization.get_operation_id())
     save_dir = self.directory / 'sync_save'
-    ocp.save_pytree(save_dir, self.pytree)
+    ocp.save(save_dir, self.pytree)
     post_save_op_id = int(synchronization.get_operation_id())
     self.assertEqual(post_save_op_id, initial_op_id + 1)
 
   def test_async_save_increments_operation_id(self):
     initial_op_id = int(synchronization.get_operation_id())
     save_dir = self.directory / 'async_save'
-    response = ocp.save_pytree_async(save_dir, self.pytree)
+    response = ocp.save_async(save_dir, self.pytree)
     response.result()
     post_save_op_id = int(synchronization.get_operation_id())
     self.assertEqual(post_save_op_id, initial_op_id + 1)
 
   def test_sync_load_increments_operation_id(self):
     save_dir = self.directory / 'sync_load'
-    ocp.save_pytree(save_dir, self.pytree)
+    ocp.save(save_dir, self.pytree)
 
     initial_op_id = int(synchronization.get_operation_id())
-    restored = ocp.load_pytree(save_dir, self.abstract_pytree)
+    restored = ocp.load(save_dir, self.abstract_pytree)
     self.assertIsNotNone(restored)
     post_load_op_id = int(synchronization.get_operation_id())
     self.assertEqual(post_load_op_id, initial_op_id + 1)
@@ -1222,12 +1273,12 @@ def test_save_then_load_increments_operation_id_sequentially(self):
     save_dir = self.directory / 'save_then_load_seq'
 
     # 1. Perform save (should increment by 1)
-    ocp.save_pytree(save_dir, self.pytree)
+    ocp.save(save_dir, self.pytree)
     post_save_op_id = int(synchronization.get_operation_id())
     self.assertEqual(post_save_op_id, initial_op_id + 1)
 
     # 2. Perform load (should increment by 1 again)
-    restored = ocp.load_pytree(save_dir, self.abstract_pytree)
+    restored = ocp.load(save_dir, self.abstract_pytree)
     self.assertIsNotNone(restored)
     post_load_op_id = int(synchronization.get_operation_id())
     self.assertEqual(post_load_op_id, initial_op_id + 2)