From b33a652fc6499931c6c822d97707d96905fd59a6 Mon Sep 17 00:00:00 2001 From: Nikhil Bansal Date: Fri, 29 May 2026 10:07:17 -0700 Subject: [PATCH] Include v1 save_load_test in OSS. PiperOrigin-RevId: 923473800 --- .../checkpoint_manager_multiprocess_test.py | 9 ++++ .../emergency/p2p/local_multiprocess_test.py | 10 +++++ .../v1/_src/testing/save_load_test.py | 41 +++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test.py diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_multiprocess_test.py index 73e7bde6a..9b88663d4 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_multiprocess_test.py @@ -48,6 +48,15 @@ class P2PCheckpointManagerMultiprocessTest(multiprocess_test.MultiProcessTest): def setUp(self): super().setUp() self.root_dir = self.create_tempdir('p2p_root') + self._original_prefix = ( + future.AwaitableSignalsContract.awaitable_signals_contract_prefix + ) + + def tearDown(self): + super().tearDown() + future.AwaitableSignalsContract.awaitable_signals_contract_prefix = ( + self._original_prefix + ) def initial_state(self, mesh): jax_processes = jax.process_count() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_multiprocess_test.py index dba1be74b..568a7b80a 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_multiprocess_test.py @@ -18,6 +18,7 @@ import numpy as np from orbax.checkpoint import args as args_lib from orbax.checkpoint import test_utils +from orbax.checkpoint._src.futures import future as futures_lib from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args from orbax.checkpoint.experimental.emergency.p2p import local from orbax.checkpoint.experimental.emergency.p2p import options as options_lib @@ -33,6 +34,15 @@ class LocalMultiprocessTest(multiprocess_test.MultiProcessTest): def setUp(self): super().setUp() self.mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) + self._original_prefix = ( + futures_lib.AwaitableSignalsContract.awaitable_signals_contract_prefix + ) + + def tearDown(self): + super().tearDown() + futures_lib.AwaitableSignalsContract.awaitable_signals_contract_prefix = ( + self._original_prefix + ) def test_save_restore(self): directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test.py new file mode 100644 index 000000000..30a3e5a15 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test.py @@ -0,0 +1,41 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import flags +import jax +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint.experimental.v1._src.testing import save_load_test_base + + +FLAGS = flags.FLAGS + +jax.config.update('jax_enable_x64', True) + + +class SaveLoadTest( + save_load_test_base.SaveLoadTestBase.SaveLoadTest, + multiprocess_test.MultiProcessTest, +): + pass + + +class SynchronizationTest( + save_load_test_base.SaveLoadTestBase.SynchronizationTest, + multiprocess_test.MultiProcessTest, +): + pass + + +if __name__ == '__main__': + multiprocess_test.main()