Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -556,11 +556,11 @@ async def load_multi_host(
restored_pytree[name] = _build_array_on_single_device(info, owner, ctx)
else:
global_transient_array = _build_transient_array(name, info, owner, ctx)
t0 = time.time()
restored_pytree[name] = _reshard_transient_array(
global_transient_array, target_sharding, global_mesh
)
reshard_time += time.time() - t0
restored_pytree[name] = {
"transient_array": global_transient_array,
"target_sharding": target_sharding,
"global_mesh": global_mesh,
}

return restored_pytree, {"io_time": io_time, "reshard_time": reshard_time}

Expand Down Expand Up @@ -631,28 +631,42 @@ async def _load_multi_host(

loaders = await self._get_loaders()

# Call load_multi_host on each loader concurrently.
# Phase 1: Call load_multi_host on each loader concurrently for parallel IO.
# Each loader handles loading from a single file.
start = time.time()
load_ops = []
for loader in loaders:
load_ops.append(loader.load_multi_host(abstract_state))

restored_pytree = {}
restored_pytree_partial = {}
total_io_time = 0.0
total_reshard_time = 0.0
for file_tensors, metrics in await asyncio.gather(*load_ops):
total_io_time += metrics["io_time"]
total_reshard_time += metrics["reshard_time"]
for name, arr in file_tensors.items():
if name in restored_pytree:
for name, arr_dict in file_tensors.items():
if name in restored_pytree_partial:
raise ValueError(f"Duplicate tensor {name} found in multiple files.")
restored_pytree[name] = arr
restored_pytree_partial[name] = arr_dict

# Validate that all requested tensors were found in at least one file.
for k in abstract_state:
if k not in restored_pytree:
restored_pytree = {}
total_reshard_time = 0.0

# Phase 2: Process JAX collectives completely sequentially!
# To ensure exact uniform ordering across all hosts to prevent deadlock, we
# iterate deterministically.
for k in sorted(abstract_state.keys()):
if k not in restored_pytree_partial:
raise KeyError(f"Tensor '{k}' not found in Safetensors checkpoint.")
item = restored_pytree_partial[k]
if context_lib.get_context().safetensors_options.ignore_load_sharding:
restored_pytree[k] = item
else:
t0 = time.time()
restored_pytree[k] = _reshard_transient_array(
item["transient_array"],
item["target_sharding"],
item["global_mesh"],
)
total_reshard_time += time.time() - t0

logging.info(
"[safetensors][multi-host] Loaded and resharded %d tensors from %d"
Expand Down
Loading