Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from absl.testing import parameterized
from etils import epath
import flax
import grain.python as pygrain
import jax
from jax import numpy as jnp
import numpy as np
Expand Down Expand Up @@ -858,6 +859,56 @@ def test_async_directory_creation(self):
self.assertTrue(self.directory.exists())
self.assertLen(list(self.directory.parent.iterdir()), 1)

def test_grain_dataset_checkpointable(self):
  """Round-trips a partially consumed grain iterator as a checkpointable.

  The iterator is stored under the 'dataset' key next to the regular pytree
  state. After loading into a freshly built iterator, the next batch must be
  the one that followed the save point, and the pytree must match what was
  saved.
  """
  num_records = 10
  source = pygrain.MapDataset.source(list(range(num_records)))

  def build_loader():
    # The saved and the restored iterator use the same loader configuration.
    return pygrain.DataLoader(
        data_source=source,
        sampler=pygrain.SequentialSampler(
            num_records, pygrain.ShardOptions(0, 1)
        ),
        operations=[pygrain.Batch(1)],
    )

  saved_loader = build_loader()
  saved_iter = iter(saved_loader)

  # Consume the first two batches ([0] and [1]) before checkpointing.
  for _ in range(2):
    next(saved_iter)

  to_save = {
      STATE_CHECKPOINTABLE_KEY: self.pytree,
      'dataset': saved_iter,
  }
  ocp.save_checkpointables(self.directory, to_save)

  # Move the original iterator past the save point ([2]) so that its live
  # state no longer matches what was written to disk.
  next(saved_iter)

  # A brand-new iterator serves as the restore target.
  fresh_loader = build_loader()
  fresh_iter = iter(fresh_loader)
  to_load = {
      STATE_CHECKPOINTABLE_KEY: self.abstract_pytree,
      'dataset': fresh_iter,
  }

  restored = ocp.load_checkpointables(self.directory, to_load)

  # Saved after [1] was consumed, so the restored iterator resumes at [2].
  following_batch = next(restored['dataset'])
  self.assertEqual(following_batch, [2])
  test_utils.assert_tree_equal(
      self, self.pytree, restored[STATE_CHECKPOINTABLE_KEY]
  )

def test_background_error(self):

async def raise_background_error(*args, **kwargs):
Expand Down Expand Up @@ -1195,24 +1246,24 @@ class SynchronizationTest(_TestSetup):
def test_sync_save_increments_operation_id(self):
  """Checks that one blocking save advances the operation id by one."""
  initial_op_id = int(synchronization.get_operation_id())
  save_dir = self.directory / 'sync_save'
  # Exactly one save is issued: the assertion below expects a single
  # increment, and a second save would also target the same directory.
  ocp.save(save_dir, self.pytree)
  post_save_op_id = int(synchronization.get_operation_id())
  self.assertEqual(post_save_op_id, initial_op_id + 1)

def test_async_save_increments_operation_id(self):
  """Checks that one completed async save advances the operation id by one."""
  initial_op_id = int(synchronization.get_operation_id())
  save_dir = self.directory / 'async_save'
  # Exactly one async save is started, so no earlier response is left
  # un-awaited and the single-increment expectation below holds.
  response = ocp.save_async(save_dir, self.pytree)
  # Wait for the save to finish before reading the operation id again.
  response.result()
  post_save_op_id = int(synchronization.get_operation_id())
  self.assertEqual(post_save_op_id, initial_op_id + 1)

def test_sync_load_increments_operation_id(self):
  """Checks that one blocking load advances the operation id by one."""
  save_dir = self.directory / 'sync_load'
  # Setup: write the checkpoint once so there is something to load.
  ocp.save(save_dir, self.pytree)

  # Read the baseline only after the setup save so that the load is the sole
  # operation between the two readings.
  initial_op_id = int(synchronization.get_operation_id())
  restored = ocp.load(save_dir, self.abstract_pytree)
  self.assertIsNotNone(restored)
  post_load_op_id = int(synchronization.get_operation_id())
  self.assertEqual(post_load_op_id, initial_op_id + 1)
Expand All @@ -1222,12 +1273,12 @@ def test_save_then_load_increments_operation_id_sequentially(self):
save_dir = self.directory / 'save_then_load_seq'

# 1. Perform save (should increment by 1)
ocp.save_pytree(save_dir, self.pytree)
ocp.save(save_dir, self.pytree)
post_save_op_id = int(synchronization.get_operation_id())
self.assertEqual(post_save_op_id, initial_op_id + 1)

# 2. Perform load (should increment by 1 again)
restored = ocp.load_pytree(save_dir, self.abstract_pytree)
restored = ocp.load(save_dir, self.abstract_pytree)
self.assertIsNotNone(restored)
post_load_op_id = int(synchronization.get_operation_id())
self.assertEqual(post_load_op_id, initial_op_id + 2)
Loading