diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py index a1dfe8d33..1d87ded07 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py @@ -351,7 +351,7 @@ def test_save_format(self): paths = [self.directory / name for name in fnames] for p in paths: self.assertTrue(p.exists()) - self.assertTrue((p / '.zarray').exists()) + def test_shape_mismatch(self): with self.ocdbt_checkpoint_handler(True) as checkpoint_handler: diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils_test.py b/checkpoint/orbax/checkpoint/checkpoint_utils_test.py index 1005eb5d1..8803a6c49 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils_test.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils_test.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +from unittest import mock from absl.testing import absltest from absl.testing import parameterized @@ -467,5 +468,7 @@ def test_wait_for_directory_creation(self): self.assertEqual(step, 0) + + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index 32fc7ac01..bd09814d9 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -21,6 +21,7 @@ from typing import Any +from absl import logging from etils import epath import jax import numpy as np @@ -187,3 +188,5 @@ def fully_replicated_host_local_array_to_global_array( global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs ) return result + +