Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ray_ascend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
on Ascend NPU accelerators.
"""

from typing import List

from ray_ascend import _version

__all__ = [
Expand All @@ -17,7 +19,7 @@
__version__ = _version.version


def register_yr_tensor_transport(devices=["npu", "cpu"]) -> None:
def register_yr_tensor_transport(devices: List[str] = ["npu", "cpu"]) -> None:
"""
Register YR tensor transport for Ray and initialize YR backend.

Expand Down
167 changes: 137 additions & 30 deletions ray_ascend/collective/hccl_collective_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import logging
import time
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import ray
import torch
Expand Down Expand Up @@ -62,6 +62,17 @@ class HcclDataTypeEnum:

@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
"""Convert a torch dtype to the corresponding HCCL data type enum value.

Args:
dtype: A torch dtype to convert.

Returns:
The corresponding HCCL data type enum value.

Raises:
ValueError: If the dtype is not supported by HCCL.
"""
if dtype == torch.int8:
return cls.HCCL_DATA_TYPE_INT8
if dtype == torch.int16:
Expand Down Expand Up @@ -91,6 +102,17 @@ class HcclRedOpTypeEnum:

@classmethod
def from_ray(cls, op: ReduceOp) -> int:
"""Convert a Ray reduce op to the corresponding HCCL reduce op enum value.

Args:
op: A Ray ReduceOp to convert.

Returns:
The corresponding HCCL reduce op enum value.

Raises:
ValueError: If the op is not supported by HCCL.
"""
if op == ReduceOp.SUM:
return cls.HCCL_REDUCE_SUM
if op == ReduceOp.PRODUCT:
Expand All @@ -115,14 +137,27 @@ class HCCLRootInfoStore:
"""

def __init__(self, name: str):
self.name = name
self.name: str = name
self.root_info_bytes: Optional[bytes] = None

def set_root_info_bytes(self, root_info_bytes: bytes) -> bytes:
"""Set the root info bytes in the store.

Args:
root_info_bytes: The root info bytes to store.

Returns:
The stored root info bytes.
"""
self.root_info_bytes = root_info_bytes
return self.root_info_bytes

def get_root_info_bytes(self) -> Optional[bytes]:
"""Get the root info bytes from the store.

Returns:
The stored root info bytes, or None if not yet set.
"""
if not self.root_info_bytes:
logger.warning(
"The HCCLRootInfo has not been set yet for store {}.".format(self.name)
Expand All @@ -131,13 +166,19 @@ def get_root_info_bytes(self) -> Optional[bytes]:


class HCCLGroup(BaseGroup):
def __init__(self, world_size: int, rank: int, group_name: str) -> None:
"""Init an HCCL collective group."""
def __init__(self, world_size: int, rank: int, group_name: str):
"""Init an HCCL collective group.

Args:
world_size: The number of processes in the collective group.
rank: The rank of this process in the collective group.
group_name: The name of the collective group.
"""
super(HCCLGroup, self).__init__(world_size, rank, group_name)
# Initialize single communicator/stream used for both collective and p2p ops.
self._comm: Optional[hcclComm_t] = None
self._stream: Optional[torch.npu.Stream] = None
self._store_name = get_store_name(f"{self.group_name}@collective")
self._store_name: str = get_store_name(f"{self.group_name}@collective")
self._device: Optional[int] = None
self._barrier_tensor: Optional[torch.Tensor] = None

Expand Down Expand Up @@ -197,10 +238,13 @@ def broadcast(
root_rank = broadcast_options.root_rank

def collective_fn(
input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream
):
# HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, aclrtStream stream)
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
comm: hcclComm_t,
stream: torch.npu.Stream,
) -> None:
with torch.npu.device(input_tensor.device):
# HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
stream.wait_stream(current_stream)
exec_result = self.libhccl.HcclBroadcast(
Expand Down Expand Up @@ -230,17 +274,23 @@ def allgather(
"""Allgather a tensor across NPUs into a list of tensors.

Args:
tensor_list (List[Tensor]): output list for gathered tensors.
tensor_list: output list for gathered tensors.
tensor: the input tensor to allgather across the group.
allgather_options: allgather options.

Returns:
None
"""
# Handle case where tensor_list is wrapped in another list by Ray's collective.py
if tensor_list and isinstance(tensor_list[0], list):
tensor_list = tensor_list[0]

def collective_fn(
input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream
):
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
comm: hcclComm_t,
stream: torch.npu.Stream,
) -> None:
with torch.npu.device(input_tensor.device):
# HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
Expand All @@ -258,10 +308,6 @@ def collective_fn(
current_stream.wait_event(event)
logger.debug(f"HcclAllGather execute result : {exec_result}")

# Handle case where tensor_list is wrapped in another list by Ray's collective.py
if tensor_list and isinstance(tensor_list[0], list):
tensor_list = tensor_list[0]

output_flattened = [_flatten_for_scatter_gather(tensor_list, copy=False)]

input_tensor = self._validate_tensor(tensor)
Expand All @@ -288,8 +334,11 @@ def allreduce(
"""

def collective_fn(
input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream
):
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
comm: hcclComm_t,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the hcclComm_t pass the syntax check?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because hcclComm_t equal to ctypes.c_void_p.

hcclComm_t = ctypes.c_void_p

stream: torch.npu.Stream,
) -> None:
with torch.npu.device(input_tensor.device):
# HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
Expand Down Expand Up @@ -346,8 +395,11 @@ def reduce(
root_rank = reduce_options.root_rank

def collective_fn(
input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream
):
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
comm: hcclComm_t,
stream: torch.npu.Stream,
) -> None:
with torch.npu.device(input_tensor.device):
# HcclResult HcclReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, uint32_t root, HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
Expand Down Expand Up @@ -382,7 +434,7 @@ def reducescatter(

Args:
tensor: the output tensor to receive this rank's shard.
tensor_list (List[Tensor]): the list of tensors to be reduced then scattered.
tensor_list: the list of tensors to be reduced then scattered.
reducescatter_options: reduce-scatter options.

Returns:
Expand All @@ -403,8 +455,11 @@ def reducescatter(
copy_event.record(copy_stream)

def collective_fn(
input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream
):
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
comm: hcclComm_t,
stream: torch.npu.Stream,
) -> None:
with torch.npu.device(input_tensor.device):
# Wait for copy operations to complete
stream.wait_event(copy_event)
Expand Down Expand Up @@ -446,7 +501,9 @@ def send(
None
"""

def p2p_fn(tensor: torch.Tensor, comm, stream, peer):
def p2p_fn(
tensor: torch.Tensor, comm: hcclComm_t, stream: torch.npu.Stream, peer: int
) -> None:
with torch.npu.device(f"npu:{tensor.device.index}"):
# HcclResult HcclSend(void* sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank,HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
Expand Down Expand Up @@ -481,7 +538,9 @@ def recv(
None
"""

def p2p_fn(tensor: torch.Tensor, comm, stream, peer):
def p2p_fn(
tensor: torch.Tensor, comm: hcclComm_t, stream: torch.npu.Stream, peer: int
) -> None:
with torch.npu.device(f"npu:{tensor.device.index}"):
# HcclResult HcclRecv(void* recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank,HcclComm comm, aclrtStream stream)
current_stream = torch.npu.current_stream()
Expand All @@ -504,6 +563,15 @@ def p2p_fn(tensor: torch.Tensor, comm, stream, peer):
p2p_fn(tensor, comm, stream, recv_options.src_rank)

def _generate_hccl_root_info(self, store_name: str, dev: int = 0) -> HcclRootInfo:
"""Generate HCCL root info and store it in a named actor.

Args:
store_name: The unique store key for the named actor.
dev: The NPU device index to use for root info generation.

Returns:
The generated HcclRootInfo.
"""
root_info = HcclRootInfo()
# NPU need set device before HcclGetRootInfo
with torch.npu.device(f"npu:{dev}"):
Expand All @@ -519,7 +587,9 @@ def _generate_hccl_root_info(self, store_name: str, dev: int = 0) -> HcclRootInf

return root_info

def _get_store_ref(self, store_name: str, timeout_s: int = 30) -> Any:
def _get_store_ref(
self, store_name: str, timeout_s: int = 30
) -> "ray.actor.ActorHandle":
"""Get the reference of the named actor store.

Args:
Expand Down Expand Up @@ -563,11 +633,18 @@ def _get_store_ref(self, store_name: str, timeout_s: int = 30) -> Any:

@staticmethod
def _destroy_store(store_name: str) -> None:
"""Destroy the named actor store.

Args:
store_name: The unique store key for the named actor.
"""
store = ray.get_actor(store_name)
# ray.get([store.__ray_terminate__.remote()])
ray.kill(store)

def _get_hccl_root_info(self, store_ref: Any, timeout_s: int = 30) -> HcclRootInfo:
def _get_hccl_root_info(
self, store_ref: "ray.actor.ActorHandle", timeout_s: int = 30
) -> HcclRootInfo:
"""Get the HcclRootInfo from the store through Ray.

Args:
Expand Down Expand Up @@ -606,7 +683,7 @@ def _init_collective_communicator(self) -> None:
root_info = self._get_hccl_root_info(store_ref)

with torch.npu.device(f"npu:{device}"):
comm: hcclComm_t = hcclComm_t()
comm = hcclComm_t()
result = self.libhccl.HcclCommInitRootInfo(
self.world_size,
ctypes.byref(root_info),
Expand All @@ -629,6 +706,15 @@ def _validate_tensor(

Accepts a Tensor or a single-element list (unwrapped automatically).
Enforces single-device constraint against the communicator's device.

Args:
tensor: The tensor to validate.

Returns:
The validated tensor.

Raises:
RuntimeError: If the tensor is not a torch.Tensor or device mismatch.
"""
# If the input is a list of tensors, we only support single tensor list for now.
# We will extract the single tensor out for validation.
Expand All @@ -646,14 +732,32 @@ def _validate_tensor(
return tensor

def _validate_collective_state(self) -> Tuple[hcclComm_t, torch.npu.Stream]:
"""Validate communicator and stream state and return them."""
"""Validate communicator and stream state and return them.

Returns:
Tuple of (communicator, stream).

Raises:
RuntimeError: If communicator or stream is not initialized.
"""
if self._comm is None or self._stream is None:
raise RuntimeError("Collective communicator is not initialized.")
return self._comm, self._stream


def get_tensor_device(tensor: torch.Tensor) -> int:
"""Return the NPU index of a tensor."""
"""Return the NPU index of a tensor.

Args:
tensor: The tensor to get the device index from.

Returns:
The NPU device index of the tensor.

Raises:
RuntimeError: If the tensor is not on a valid NPU.
ValueError: If the input is not a torch.Tensor.
"""
if isinstance(tensor, torch.Tensor):
device = tensor.device.index
if not isinstance(device, int):
Expand All @@ -674,10 +778,13 @@ def _flatten_for_scatter_gather(

Returns:
The flattened tensor buffer.

Raises:
RuntimeError: If tensor_list is empty.
"""
if not tensor_list:
raise RuntimeError("Received an empty list.")
t: torch.Tensor = tensor_list[0]
t = tensor_list[0]
buffer_shape = [len(tensor_list)] + list(t.shape)

buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
Expand Down
Loading
Loading