diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d8c6726..8401a3e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,24 @@ repos: name: ruff-check-f401 args: [--select, "F401", --exit-non-zero-on-fix] + - repo: https://github.com/jsh9/pydoclint + rev: "0.8.3" + hooks: + - id: pydoclint + args: [ + --style=google, + # Current settings (not because we think they're right, but because we + # don't want a baseline the size of the codebase) + --arg-type-hints-in-docstring=False, + --skip-checking-raises=True, + --check-return-types=False, + --allow-init-docstring=True, + --check-class-attributes=False, + # --check-style-mismatch=True, # Bring this back once things are a bit cleaner + ] + types: [python] + files: '^ray_ascend/' + - repo: https://github.com/psf/black rev: 26.1.0 hooks: diff --git a/ray_ascend/__init__.py b/ray_ascend/__init__.py index 2bb2603..7d2d9fc 100644 --- a/ray_ascend/__init__.py +++ b/ray_ascend/__init__.py @@ -3,6 +3,8 @@ on Ascend NPU accelerators. """ +from typing import List + from ray_ascend import _version __all__ = [ @@ -17,7 +19,7 @@ __version__ = _version.version -def register_yr_tensor_transport(devices=["npu", "cpu"]) -> None: +def register_yr_tensor_transport(devices: List[str] = ["npu", "cpu"]) -> None: """ Register YR tensor transport for Ray and initialize YR backend. diff --git a/ray_ascend/collective/hccl_collective_group.py b/ray_ascend/collective/hccl_collective_group.py index f79bb8d..4cf224e 100644 --- a/ray_ascend/collective/hccl_collective_group.py +++ b/ray_ascend/collective/hccl_collective_group.py @@ -2,7 +2,7 @@ import datetime import logging import time -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import ray import torch @@ -62,6 +62,17 @@ class HcclDataTypeEnum: @classmethod def from_torch(cls, dtype: torch.dtype) -> int: + """Convert a torch dtype to the corresponding HCCL data type enum value. + + Args: + dtype: A torch dtype to convert. + + Returns: + The corresponding HCCL data type enum value. + + Raises: + ValueError: If the dtype is not supported by HCCL. + """ if dtype == torch.int8: return cls.HCCL_DATA_TYPE_INT8 if dtype == torch.int16: @@ -91,6 +102,17 @@ class HcclRedOpTypeEnum: @classmethod def from_ray(cls, op: ReduceOp) -> int: + """Convert a Ray reduce op to the corresponding HCCL reduce op enum value. + + Args: + op: A Ray ReduceOp to convert. + + Returns: + The corresponding HCCL reduce op enum value. + + Raises: + ValueError: If the op is not supported by HCCL. + """ if op == ReduceOp.SUM: return cls.HCCL_REDUCE_SUM if op == ReduceOp.PRODUCT: @@ -115,14 +137,27 @@ class HCCLRootInfoStore: """ def __init__(self, name: str): - self.name = name + self.name: str = name self.root_info_bytes: Optional[bytes] = None def set_root_info_bytes(self, root_info_bytes: bytes) -> bytes: + """Set the root info bytes in the store. + + Args: + root_info_bytes: The root info bytes to store. + + Returns: + The stored root info bytes. + """ self.root_info_bytes = root_info_bytes return self.root_info_bytes def get_root_info_bytes(self) -> Optional[bytes]: + """Get the root info bytes from the store. + + Returns: + The stored root info bytes, or None if not yet set. + """ if not self.root_info_bytes: logger.warning( "The HCCLRootInfo has not been set yet for store {}.".format(self.name) @@ -131,13 +166,19 @@ def get_root_info_bytes(self) -> Optional[bytes]: class HCCLGroup(BaseGroup): - def __init__(self, world_size: int, rank: int, group_name: str) -> None: - """Init an HCCL collective group.""" + def __init__(self, world_size: int, rank: int, group_name: str): + """Init an HCCL collective group. + + Args: + world_size: The number of processes in the collective group. + rank: The rank of this process in the collective group. + group_name: The name of the collective group. + """ super(HCCLGroup, self).__init__(world_size, rank, group_name) # Initialize single communicator/stream used for both collective and p2p ops. self._comm: Optional[hcclComm_t] = None self._stream: Optional[torch.npu.Stream] = None - self._store_name = get_store_name(f"{self.group_name}@collective") + self._store_name: str = get_store_name(f"{self.group_name}@collective") self._device: Optional[int] = None self._barrier_tensor: Optional[torch.Tensor] = None @@ -197,10 +238,13 @@ def broadcast( root_rank = broadcast_options.root_rank def collective_fn( - input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream - ): - # HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, aclrtStream stream) + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, + comm: hcclComm_t, + stream: torch.npu.Stream, + ) -> None: with torch.npu.device(input_tensor.device): + # HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() stream.wait_stream(current_stream) exec_result = self.libhccl.HcclBroadcast( @@ -230,17 +274,23 @@ def allgather( """Allgather a tensor across NPUs into a list of tensors. Args: - tensor_list (List[Tensor]): output list for gathered tensors. + tensor_list: output list for gathered tensors. tensor: the input tensor to allgather across the group. allgather_options: allgather options. Returns: None """ + # Handle case where tensor_list is wrapped in another list by Ray's collective.py + if tensor_list and isinstance(tensor_list[0], list): + tensor_list = tensor_list[0] def collective_fn( - input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream - ): + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, + comm: hcclComm_t, + stream: torch.npu.Stream, + ) -> None: with torch.npu.device(input_tensor.device): # HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() @@ -258,10 +308,6 @@ def collective_fn( current_stream.wait_event(event) logger.debug(f"HcclAllGather execute result : {exec_result}") - # Handle case where tensor_list is wrapped in another list by Ray's collective.py - if tensor_list and isinstance(tensor_list[0], list): - tensor_list = tensor_list[0] - output_flattened = [_flatten_for_scatter_gather(tensor_list, copy=False)] input_tensor = self._validate_tensor(tensor) @@ -288,8 +334,11 @@ def allreduce( """ def collective_fn( - input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream - ): + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, + comm: hcclComm_t, + stream: torch.npu.Stream, + ) -> None: with torch.npu.device(input_tensor.device): # HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() @@ -346,8 +395,11 @@ def reduce( root_rank = reduce_options.root_rank def collective_fn( - input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream - ): + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, + comm: hcclComm_t, + stream: torch.npu.Stream, + ) -> None: with torch.npu.device(input_tensor.device): # HcclResult HcclReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, uint32_t root, HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() @@ -382,7 +434,7 @@ def reducescatter( Args: tensor: the output tensor to receive this rank's shard. - tensor_list (List[Tensor]): the list of tensors to be reduced then scattered. + tensor_list: the list of tensors to be reduced then scattered. reducescatter_options: reduce-scatter options. Returns: @@ -403,8 +455,11 @@ def reducescatter( copy_event.record(copy_stream) def collective_fn( - input_tensor: torch.Tensor, output_tensor: torch.Tensor, comm, stream - ): + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, + comm: hcclComm_t, + stream: torch.npu.Stream, + ) -> None: with torch.npu.device(input_tensor.device): # Wait for copy operations to complete stream.wait_event(copy_event) @@ -446,7 +501,9 @@ def send( None """ - def p2p_fn(tensor: torch.Tensor, comm, stream, peer): + def p2p_fn( + tensor: torch.Tensor, comm: hcclComm_t, stream: torch.npu.Stream, peer: int + ) -> None: with torch.npu.device(f"npu:{tensor.device.index}"): # HcclResult HcclSend(void* sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank,HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() @@ -481,7 +538,9 @@ def recv( None """ - def p2p_fn(tensor: torch.Tensor, comm, stream, peer): + def p2p_fn( + tensor: torch.Tensor, comm: hcclComm_t, stream: torch.npu.Stream, peer: int + ) -> None: with torch.npu.device(f"npu:{tensor.device.index}"): # HcclResult HcclRecv(void* recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank,HcclComm comm, aclrtStream stream) current_stream = torch.npu.current_stream() @@ -504,6 +563,15 @@ def p2p_fn(tensor: torch.Tensor, comm, stream, peer): p2p_fn(tensor, comm, stream, recv_options.src_rank) def _generate_hccl_root_info(self, store_name: str, dev: int = 0) -> HcclRootInfo: + """Generate HCCL root info and store it in a named actor. + + Args: + store_name: The unique store key for the named actor. + dev: The NPU device index to use for root info generation. + + Returns: + The generated HcclRootInfo. + """ root_info = HcclRootInfo() # NPU need set device before HcclGetRootInfo with torch.npu.device(f"npu:{dev}"): @@ -519,14 +587,16 @@ def _generate_hccl_root_info(self, store_name: str, dev: int = 0) -> HcclRootInf return root_info - def _get_store_ref(self, store_name: str, timeout_s: int = 30) -> Any: + def _get_store_ref( + self, store_name: str, timeout_s: int = 30 + ) -> "ray.actor.ActorHandle": """Get the reference of the named actor store. Args: store_name: the unique store key timeout_s: timeout in seconds. - Return: + Returns: store_ref: reference to store actor """ if timeout_s <= 0: @@ -563,18 +633,25 @@ def _get_store_ref(self, store_name: str, timeout_s: int = 30) -> Any: @staticmethod def _destroy_store(store_name: str) -> None: + """Destroy the named actor store. + + Args: + store_name: The unique store key for the named actor. + """ store = ray.get_actor(store_name) # ray.get([store.__ray_terminate__.remote()]) ray.kill(store) - def _get_hccl_root_info(self, store_ref: Any, timeout_s: int = 30) -> HcclRootInfo: + def _get_hccl_root_info( + self, store_ref: "ray.actor.ActorHandle", timeout_s: int = 30 + ) -> HcclRootInfo: """Get the HcclRootInfo from the store through Ray. Args: store_ref: reference to the rendezvous store actor. timeout_s: timeout in seconds. - Return: + Returns: root_info: the HcclRootInfo if successful. """ root_info_bytes = None @@ -606,7 +683,7 @@ def _init_collective_communicator(self) -> None: root_info = self._get_hccl_root_info(store_ref) with torch.npu.device(f"npu:{device}"): - comm: hcclComm_t = hcclComm_t() + comm = hcclComm_t() result = self.libhccl.HcclCommInitRootInfo( self.world_size, ctypes.byref(root_info), @@ -629,6 +706,15 @@ def _validate_tensor( Accepts a Tensor or a single-element list (unwrapped automatically). Enforces single-device constraint against the communicator's device. + + Args: + tensor: The tensor to validate. + + Returns: + The validated tensor. + + Raises: + RuntimeError: If the tensor is not a torch.Tensor or device mismatch. """ # If the input is a list of tensors, we only support single tensor list for now. # We will extract the single tensor out for validation. @@ -646,14 +732,32 @@ def _validate_tensor( return tensor def _validate_collective_state(self) -> Tuple[hcclComm_t, torch.npu.Stream]: - """Validate communicator and stream state and return them.""" + """Validate communicator and stream state and return them. + + Returns: + Tuple of (communicator, stream). + + Raises: + RuntimeError: If communicator or stream is not initialized. + """ if self._comm is None or self._stream is None: raise RuntimeError("Collective communicator is not initialized.") return self._comm, self._stream def get_tensor_device(tensor: torch.Tensor) -> int: - """Return the NPU index of a tensor.""" + """Return the NPU index of a tensor. + + Args: + tensor: The tensor to get the device index from. + + Returns: + The NPU device index of the tensor. + + Raises: + RuntimeError: If the tensor is not on a valid NPU. + ValueError: If the input is not a torch.Tensor. + """ if isinstance(tensor, torch.Tensor): device = tensor.device.index if not isinstance(device, int): @@ -674,10 +778,13 @@ def _flatten_for_scatter_gather( Returns: The flattened tensor buffer. + + Raises: + RuntimeError: If tensor_list is empty. """ if not tensor_list: raise RuntimeError("Received an empty list.") - t: torch.Tensor = tensor_list[0] + t = tensor_list[0] buffer_shape = [len(tensor_list)] + list(t.shape) buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) diff --git a/ray_ascend/direct_transport/yr_tensor_transport.py b/ray_ascend/direct_transport/yr_tensor_transport.py index 2464112..970c290 100644 --- a/ray_ascend/direct_transport/yr_tensor_transport.py +++ b/ray_ascend/direct_transport/yr_tensor_transport.py @@ -2,7 +2,7 @@ import pickle import uuid from dataclasses import dataclass, field -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import ray import torch @@ -57,7 +57,7 @@ def is_one_sided() -> bool: def can_abort_transport() -> bool: return False - def _get_worker_address(self) -> tuple[str, int]: + def _get_worker_address(self) -> Tuple[str, int]: """Get worker address from coordinator. Returns: @@ -80,7 +80,7 @@ def _get_worker_address(self) -> tuple[str, int]: host, port_str = worker_addr.split(":") return host, int(port_str) - def get_ds_client(self, device_type: str) -> Any: + def get_ds_client(self, device_type: str): """Creates a YR DS client if it does not already exist.""" if self._ds_client.get(device_type) is not None: return self._ds_client[device_type] @@ -221,7 +221,7 @@ def send_multiple_tensors( tensors: List["torch.Tensor"], tensor_transport_metadata: TensorTransportMetadata, communicator_metadata: CommunicatorMetadata, - ): + ) -> None: raise NotImplementedError( "YR DS transport does not support send_multiple_tensors," "since it is a one-sided transport." @@ -232,7 +232,7 @@ def garbage_collect( obj_id: str, tensor_transport_meta: TensorTransportMetadata, tensors: Optional[List[Any]] = None, - ): + ) -> None: assert isinstance(tensor_transport_meta, YRTransportMetadata) serialized_keys = tensor_transport_meta.ds_serialized_keys device_type = tensor_transport_meta.tensor_device @@ -252,5 +252,5 @@ def abort_transport( self, obj_id: str, communicator_metadata: CommunicatorMetadata, - ): + ) -> None: raise NotImplementedError("YR transport does not support aborting.") diff --git a/ray_ascend/direct_transport/yr_tensor_transport_util.py b/ray_ascend/direct_transport/yr_tensor_transport_util.py index 8f4e29c..3f98044 100644 --- a/ray_ascend/direct_transport/yr_tensor_transport_util.py +++ b/ray_ascend/direct_transport/yr_tensor_transport_util.py @@ -1,6 +1,7 @@ import struct import warnings from concurrent.futures import ThreadPoolExecutor +from typing import List import torch @@ -38,78 +39,126 @@ from ray_ascend.utils.serial_utils import _decoder, _encoder -def raise_if_failed(failed_keys, action): +def raise_if_failed(failed_keys: List[str], action: str) -> None: + """Raise RuntimeError if any keys failed. + + Args: + failed_keys: List of keys that failed the operation. + action: Description of the action (e.g., "put", "get", "delete"). + """ if failed_keys: raise RuntimeError(f"Failed to {action} keys: {failed_keys}") class BaseDSAdapter(ABC): - MAX_KEYS_PER_BATCH = 10000 + """Base class for YR DS client adapters with batch processing support.""" + + MAX_KEYS_PER_BATCH: int = 10000 @abstractmethod - def init(self): + def init(self) -> None: + """Initialize the DS client connection.""" pass - def put(self, keys, tensors): - """Store multiple objects with batch processing.""" + def put(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Store multiple objects with batch processing. + + Args: + keys: List of keys to store. + tensors: List of tensors to store. + """ batch_size = self.MAX_KEYS_PER_BATCH for i in range(0, len(keys), batch_size): self._put_batch(keys[i : i + batch_size], tensors[i : i + batch_size]) @abstractmethod - def _put_batch(self, keys, tensors): - """Process a single batch of put operations.""" + def _put_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of put operations. + + Args: + keys: List of keys for this batch. + tensors: List of tensors for this batch. + """ pass - def get(self, keys, tensors): - """Retrieve multiple objects with batch processing.""" + def get(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Retrieve multiple objects with batch processing. + + Args: + keys: List of keys to retrieve. + tensors: List of tensors to populate with retrieved data. + """ batch_size = self.MAX_KEYS_PER_BATCH for i in range(0, len(keys), batch_size): self._get_batch(keys[i : i + batch_size], tensors[i : i + batch_size]) @abstractmethod - def _get_batch(self, keys, tensors): - """Process a single batch of get operations.""" + def _get_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of get operations. + + Args: + keys: List of keys for this batch. + tensors: List of tensors for this batch. + """ pass - def delete(self, keys): - """Delete multiple keys with batch processing.""" + def delete(self, keys: List[str]) -> None: + """Delete multiple keys with batch processing. + + Args: + keys: List of keys to delete. + """ batch_size = self.MAX_KEYS_PER_BATCH for i in range(0, len(keys), batch_size): self._delete_batch(keys[i : i + batch_size]) @abstractmethod - def _delete_batch(self, keys): - """Process a single batch of delete operations.""" + def _delete_batch(self, keys: List[str]) -> None: + """Process a single batch of delete operations. + + Args: + keys: List of keys for this batch. + """ pass class CPUClientAdapter(BaseDSAdapter): + """DS client adapter for CPU tensors using structured binary packing.""" + # Header: number of entries (uint32, little-endian) - HEADER_FMT = " None: + """Initialize the KV client connection.""" self._client.init() @classmethod - def calc_packed_size(cls, items: list[memoryview]) -> int: - """ - Calculate the total size (in bytes) required to pack a list of memoryview items + def calc_packed_size(cls, items: List[memoryview]) -> int: + """Calculate the total size (in bytes) required to pack a list of memoryview items into the structured binary format used by pack_into. Args: @@ -125,9 +174,8 @@ def calc_packed_size(cls, items: list[memoryview]) -> int: ) @classmethod - def pack_into(cls, target: memoryview, items: list[memoryview]): - """ - Pack multiple contiguous buffers into a single buffer. + def pack_into(cls, target: memoryview, items: List[memoryview]) -> None: + """Pack multiple contiguous buffers into a single buffer. ┌───────────────┐ │ item_count │ uint32 ├───────────────┤ @@ -137,10 +185,10 @@ def pack_into(cls, target: memoryview, items: list[memoryview]): └───────────────┘ Args: - target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData(). + target: A writable memoryview returned by StateValueBuffer.MutableData(). It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items. This buffer is usually mapped to shared memory or Zero-Copy memory area. - items (List[memoryview]): List of read-only memory views (e.g., from serialized objects). + items: List of read-only memory views (e.g., from serialized objects). Each item must support the buffer protocol and be readable as raw bytes. """ @@ -163,13 +211,14 @@ def pack_into(cls, target: memoryview, items: list[memoryview]): payload_offset += item.nbytes @classmethod - def unpack_from(cls, source: memoryview) -> list[memoryview]: - """ - Unpack multiple contiguous buffers from a single packed buffer. + def unpack_from(cls, source: memoryview) -> List[memoryview]: + """Unpack multiple contiguous buffers from a single packed buffer. + Args: - source (memoryview): The packed source buffer. + source: The packed source buffer. + Returns: - list[memoryview]: List of unpacked contiguous buffers. + List of unpacked contiguous buffers. """ mv = memoryview(source) item_count = struct.unpack_from(cls.HEADER_FMT, mv, 0)[0] @@ -181,8 +230,16 @@ def unpack_from(cls, source: memoryview) -> list[memoryview]: offsets.append((offset, length)) return [mv[offset : offset + length] for offset, length in offsets] - def _put_batch(self, keys: list[str], tensors: list[torch.Tensor]): - """Process a single batch of put operations.""" + def _put_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of put operations. + + Args: + keys: List of keys for this batch. + tensors: List of tensors for this batch. + + Raises: + RuntimeError: If any keys fail to be put. + """ items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in tensors] packed_sizes = [self.calc_packed_size(items) for items in items_list] buffers = self._client.mcreate(keys, packed_sizes) @@ -199,26 +256,57 @@ def _put_batch(self, keys: list[str], tensors: list[torch.Tensor]): list(executor.map(lambda p: self.pack_into(*p), tasks)) self._client.mset_buffer(buffers) - def _get_batch(self, keys: list[str], tensors: list[torch.Tensor]): - """Process a single batch of get operations.""" + def _get_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of get operations. + + Args: + keys: List of keys for this batch. + tensors: List of tensors to populate with retrieved data. + + Raises: + RuntimeError: If any key fails to be retrieved. + """ buffers = self._client.get_buffers(keys) for i, buffer in enumerate(buffers): if buffer is None: raise RuntimeError(f"Failed to get key: {keys[i]}") tensors[i] = _decoder.decode(self.unpack_from(buffer)) - def _delete_batch(self, keys): - """Process a single batch of delete operations.""" + def _delete_batch(self, keys: List[str]) -> None: + """Process a single batch of delete operations. + + Args: + keys: List of keys for this batch. + + Raises: + RuntimeError: If any keys fail to be deleted. + """ failed_keys = self._client.delete(keys=keys) raise_if_failed(failed_keys, "delete") - def health_check(self): - return self._client.health_check().is_ok() + def health_check(self) -> bool: + """Check if the DS client is healthy. + + Returns: + True if the client is healthy, False otherwise. + """ + is_healthy: bool = self._client.health_check().is_ok() + return is_healthy class NPUClientAdapter(BaseDSAdapter): + """DS client adapter for NPU tensors using device-direct operations.""" + + def __init__(self, host: str, port: int): + """Initialize NPUClientAdapter with DS server address. - def __init__(self, host, port): + Args: + host: DS server host address. + port: DS server port. + + Raises: + RuntimeError: If 'datasystem' or NPU support is not installed. + """ if not NPU_AVAILABLE: raise RuntimeError( "Missing optional dependency 'datasystem' or NPU support. Install with: " @@ -232,20 +320,44 @@ def __init__(self, host, port): connect_timeout_ms=60000, ) - def init(self): + def init(self) -> None: + """Initialize the DsTensorClient connection.""" self._client.init() - def _put_batch(self, keys, tensors): - """Process a single batch of put operations.""" + def _put_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of put operations for NPU tensors. + + Args: + keys: List of keys for this batch. + tensors: List of NPU tensors for this batch. + + Raises: + RuntimeError: If any keys fail to be put. + """ failed_keys = self._client.dev_mset(keys=keys, tensors=tensors) raise_if_failed(failed_keys, "put") - def _get_batch(self, keys, tensors): - """Process a single batch of get operations.""" + def _get_batch(self, keys: List[str], tensors: List["torch.Tensor"]) -> None: + """Process a single batch of get operations for NPU tensors. + + Args: + keys: List of keys for this batch. + tensors: List of NPU tensors to populate with retrieved data. + + Raises: + RuntimeError: If any keys fail to be retrieved. + """ failed_keys = self._client.dev_mget(keys=keys, tensors=tensors) raise_if_failed(failed_keys, "get") - def _delete_batch(self, keys): - """Process a single batch of delete operations.""" + def _delete_batch(self, keys: List[str]) -> None: + """Process a single batch of delete operations for NPU tensors. + + Args: + keys: List of keys for this batch. + + Raises: + RuntimeError: If any keys fail to be deleted. + """ failed_keys = self._client.dev_delete(keys=keys) raise_if_failed(failed_keys, "delete") diff --git a/ray_ascend/utils/serial_utils.py b/ray_ascend/utils/serial_utils.py index 5010279..173d796 100644 --- a/ray_ascend/utils/serial_utils.py +++ b/ray_ascend/utils/serial_utils.py @@ -12,13 +12,19 @@ class SimpleTensorEncoder: It mimics the interface of MsgpackEncoder.encode() but skips msgpack entirely. """ - def encode(self, obj: torch.Tensor) -> Sequence[bytestr]: - """ - Encode a single torch.Tensor in zero-copy mode. + def encode(self, obj: "torch.Tensor") -> Sequence[bytestr]: + """Encode a single torch.Tensor in zero-copy mode. + + Args: + obj: The torch.Tensor to encode. Returns: A list [meta_bytes, raw_data_buffer] which is compatible with the original MsgpackEncoder's return type. + + Raises: + TypeError: If obj is not a torch.Tensor. + ValueError: If obj is sparse, nested, or not dense. """ if not isinstance(obj, torch.Tensor): raise TypeError("SimpleTensorEncoder only supports torch.Tensor") @@ -40,7 +46,7 @@ def encode(self, obj: torch.Tensor) -> Sequence[bytestr]: # Serialize metadata to bytes using pickle (for simplicity and compatibility) meta_bytes = pickle.dumps(meta_tuple, protocol=pickle.HIGHEST_PROTOCOL) - bufs: list[bytestr] = [meta_bytes, raw_data] + bufs = [meta_bytes, raw_data] return bufs @@ -50,15 +56,17 @@ class SimpleTensorDecoder: It mimics the interface of MsgpackDecoder.decode() but skips msgpack entirely. """ - def decode(self, bufs: Sequence[bytestr]) -> torch.Tensor: - """ - Decode a list of bytes into a torch.Tensor. + def decode(self, bufs: Sequence[bytestr]) -> "torch.Tensor": + """Decode a list of bytes into a torch.Tensor. Args: bufs: A sequence [meta_bytes, raw_data_buffer]. Returns: The reconstructed torch.Tensor. + + Raises: + ValueError: If bufs is a single bytes object instead of a sequence. """ if isinstance(bufs, bytestr): raise ValueError( diff --git a/ray_ascend/utils/yr_utils.py b/ray_ascend/utils/yr_utils.py index e49c7aa..0233067 100644 --- a/ray_ascend/utils/yr_utils.py +++ b/ray_ascend/utils/yr_utils.py @@ -7,7 +7,7 @@ import sys import tempfile import time -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple import ray import requests @@ -46,8 +46,12 @@ def get_free_port() -> int: return int(s.getsockname()[1]) -def check_etcd_installed(): - """Raise RuntimeError if 'etcd' is not found in PATH.""" +def check_etcd_installed() -> None: + """Raise RuntimeError if 'etcd' is not found in PATH. + + Raises: + RuntimeError: If etcd is not installed or not found in PATH. + """ if shutil.which("etcd") is None: raise RuntimeError( "'etcd' is not installed or not found in PATH. Please install etcd and ensure it's accessible from the command line." @@ -59,8 +63,21 @@ def start_etcd( client_port: Optional[int] = None, peer_port: Optional[int] = None, max_retries: int = 3, -) -> tuple[str, subprocess.Popen, str]: - """Start etcd in a subprocess and wait until it's healthy.""" +) -> Tuple[str, subprocess.Popen, str]: + """Start etcd in a subprocess and wait until it's healthy. + + Args: + host: The host address for etcd to bind to. + client_port: The client port. If None, a free port is auto-selected. + peer_port: The peer port. If None, a free port is auto-selected. + max_retries: Maximum number of retries for starting etcd. + + Returns: + Tuple of (etcd_addr, etcd_proc, etcd_data_dir). + + Raises: + RuntimeError: If etcd fails to start after max_retries. + """ check_etcd_installed() for attempt in range(max_retries): @@ -101,7 +118,7 @@ def start_etcd( for _ in range(10): try: resp = requests.get(f"{client_addr}/health", timeout=1) - is_etcd_healthy: bool = ( + is_etcd_healthy = ( resp.status_code == requests.codes.ok and resp.json().get("health") == "true" ) @@ -326,7 +343,7 @@ def __init__( worker_args: Optional[str] = None, worker_port: Optional[int] = None, ): - self.init_mode = init_mode + self.init_mode: str = init_mode self._worker_host: Optional[str] = None self._worker_port: Optional[int] = None self._worker_address: Optional[str] = None @@ -334,7 +351,7 @@ def __init__( self.etcd_address: Optional[str] = None self.metastore_address: Optional[str] = None self.is_head: bool = False - self.worker_args = worker_args or "" + self.worker_args: str = worker_args or "" # Get node IP via Ray API self._worker_host = ray.util.get_node_ip_address() @@ -394,7 +411,11 @@ def start(self) -> str: return worker_address def _start_reaper(self, worker_address: str) -> None: - """Start reaper subprocess for Parent Process Death Detection cleanup.""" + """Start reaper subprocess for Parent Process Death Detection cleanup. + + Args: + worker_address: The worker address to monitor. + """ self._reaper_process = subprocess.Popen( [sys.executable, "-c", _REAPER_SCRIPT, worker_address], stdin=subprocess.PIPE, @@ -470,20 +491,24 @@ class YRBackendCoordinator: The coordinator: - Creates Placement Group for worker actors - Creates DataSystemActor on each node - - Manages node_worker_addresses mapping + - Manages node_worker_addresses mapping. """ def __init__(self): - self._initialized = False + self._initialized: bool = False self._init_mode: Optional[str] = None # "etcd" or "metastore" self._worker_args: str = "" - self._placement_group = None + self._placement_group: Optional["ray.util.placement_group.PlacementGroup"] = ( + None + ) self._etcd_address: Optional[str] = None # for etcd mode self._metastore_address: Optional[str] = None # for metastore mode - self._node_worker_addresses = {} # {node_ip: worker_address} - self._worker_actors: list = [] # Store actor handles for cleanup + self._node_worker_addresses: Dict[str, str] = {} # {node_ip: worker_address} + self._worker_actors: List["ray.actor.ActorHandle"] = [] - def _create_placement_group(self, nodes: list): + def _create_placement_group( + self, nodes: List[dict] + ) -> "ray.util.placement_group.PlacementGroup": """Create placement group with STRICT_SPREAD strategy. Args: @@ -529,7 +554,7 @@ def _create_placement_group(self, nodes: list): ) return pg - def _collect_worker_addresses(self, actors: list) -> None: + def _collect_worker_addresses(self, actors: List["ray.actor.ActorHandle"]) -> None: """Collect worker addresses from actors. Args: @@ -540,7 +565,7 @@ def _collect_worker_addresses(self, actors: list) -> None: worker_addr = ray.get(actor.get_worker_address.remote()) self._node_worker_addresses[node_ip] = worker_addr - def _get_backend_info_dict(self) -> dict: + def _get_backend_info_dict(self) -> Dict[str, Any]: """Return backend info dict.""" return { "init_mode": self._init_mode, @@ -580,7 +605,9 @@ def cleanup(self) -> None: except Exception as e: logger.warning(f"Failed to remove placement group: {e}") - def _get_bundle_node_ip(self, pg, bundle_index: int) -> str: + def _get_bundle_node_ip( + self, pg: "ray.util.placement_group.PlacementGroup", bundle_index: int + ) -> str: """Get node IP for a specific bundle in placement group. Args: @@ -607,7 +634,7 @@ def _get_bundle_node_ip(self, pg, bundle_index: int) -> str: raise RuntimeError(f"Node {node_id} not found in cluster") - def _get_alive_nodes(self) -> list: + def _get_alive_nodes(self) -> List[dict]: """Get list of alive Ray nodes. Returns: @@ -629,7 +656,7 @@ def initialize( worker_args: str, etcd_address: Optional[str] = None, metastore_port: Optional[int] = None, - ) -> dict: + ) -> Dict[str, Any]: """Initialize YR backend. All parameters are provided by ensure_yr_backend_initialized. @@ -681,7 +708,9 @@ def initialize( else: raise RuntimeError(f"Unknown init_mode: {init_mode}") - def _initialize_etcd_mode(self, worker_port: int, worker_args: str) -> dict: + def _initialize_etcd_mode( + self, worker_port: int, worker_args: str + ) -> Dict[str, Any]: """Initialize YR backend using user-provided etcd. In etcd mode, all nodes start DS workers that connect to the same etcd. @@ -745,7 +774,7 @@ def _initialize_etcd_mode(self, worker_port: int, worker_args: str) -> dict: def _initialize_metastore_mode( self, worker_port: int, metastore_port: int, worker_args: str - ) -> dict: + ) -> Dict[str, Any]: """Initialize YR backend using metastore mode. In metastore mode, the head node starts a metastore service, @@ -851,7 +880,7 @@ def get_backend_info( worker_args: Optional[str] = None, etcd_address: Optional[str] = None, metastore_port: Optional[int] = None, - ) -> dict: + ) -> Dict[str, Any]: """Get backend info. Behavior depends on initialization state and parameters: @@ -893,9 +922,16 @@ def ensure_yr_backend_initialized( worker_args: Optional[str] = None, etcd_address: Optional[str] = None, metastore_port: Optional[int] = None, -) -> dict: +) -> Dict[str, Any]: """Ensure YR backend is initialized. + Args: + init_mode: Initialization mode, "etcd" or "metastore" (defaults to env var or "metastore"). + worker_port: DS worker port (defaults to env var or 31501). + worker_args: Additional worker arguments (defaults to env var or ""). + etcd_address: Etcd address (required for etcd mode, defaults to env var). + metastore_port: Metastore port (defaults to env var or 2379). + Returns: backend_info dict containing init_mode, worker_args, etcd/metastore_address, and node_worker_addresses @@ -922,7 +958,7 @@ def ensure_yr_backend_initialized( namespace="yr_backend", get_if_exists=True ).remote() - backend_info: dict = ray.get( + backend_info: Dict[str, Any] = ray.get( coordinator.get_backend_info.remote( init_mode_val, worker_port_val, @@ -934,7 +970,7 @@ def ensure_yr_backend_initialized( return backend_info -def get_yr_backend_info() -> dict: +def get_yr_backend_info() -> Dict[str, Any]: """Get YR backend info if already initialized. This function is used internally by YRTensorTransport to get worker addresses. @@ -951,5 +987,5 @@ def get_yr_backend_info() -> dict: namespace="yr_backend", get_if_exists=True ).remote() - backend_info: dict = ray.get(coordinator.get_backend_info.remote()) + backend_info: Dict[str, Any] = ray.get(coordinator.get_backend_info.remote()) return backend_info