diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index 693c4dec9..574e2a05b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -556,11 +556,11 @@ async def load_multi_host( restored_pytree[name] = _build_array_on_single_device(info, owner, ctx) else: global_transient_array = _build_transient_array(name, info, owner, ctx) - t0 = time.time() - restored_pytree[name] = _reshard_transient_array( - global_transient_array, target_sharding, global_mesh - ) - reshard_time += time.time() - t0 + restored_pytree[name] = { + "transient_array": global_transient_array, + "target_sharding": target_sharding, + "global_mesh": global_mesh, + } return restored_pytree, {"io_time": io_time, "reshard_time": reshard_time} @@ -631,28 +631,42 @@ async def _load_multi_host( loaders = await self._get_loaders() - # Call load_multi_host on each loader concurrently. + # Phase 1: Call load_multi_host on each loader concurrently for parallel IO. # Each loader handles loading from a single file. start = time.time() load_ops = [] for loader in loaders: load_ops.append(loader.load_multi_host(abstract_state)) - restored_pytree = {} + restored_pytree_partial = {} total_io_time = 0.0 - total_reshard_time = 0.0 for file_tensors, metrics in await asyncio.gather(*load_ops): total_io_time += metrics["io_time"] - total_reshard_time += metrics["reshard_time"] - for name, arr in file_tensors.items(): - if name in restored_pytree: + for name, arr_dict in file_tensors.items(): + if name in restored_pytree_partial: raise ValueError(f"Duplicate tensor {name} found in multiple files.") - restored_pytree[name] = arr + restored_pytree_partial[name] = arr_dict - # Validate that all requested tensors were found in at least one file. - for k in abstract_state: - if k not in restored_pytree: + restored_pytree = {} + total_reshard_time = 0.0 + + # Phase 2: Process JAX collectives completely sequentially! + # To ensure exact uniform ordering across all hosts to prevent deadlock, we + # iterate deterministically. + for k in sorted(abstract_state.keys()): + if k not in restored_pytree_partial: raise KeyError(f"Tensor '{k}' not found in Safetensors checkpoint.") + item = restored_pytree_partial[k] + if context_lib.get_context().safetensors_options.ignore_load_sharding: + restored_pytree[k] = item + else: + t0 = time.time() + restored_pytree[k] = _reshard_transient_array( + item["transient_array"], + item["target_sharding"], + item["global_mesh"], + ) + total_reshard_time += time.time() - t0 logging.info( "[safetensors][multi-host] Loaded and resharded %d tensors from %d"