From 366eabf0c677c6d5ff30ead6cb0155030725336c Mon Sep 17 00:00:00 2001 From: Artimislyy <2249614312@qq.com> Date: Tue, 16 Jun 2026 15:07:37 +0800 Subject: [PATCH] Add HIXL tensor transport implementation and design docs - HixlTensorTransport class (one-sided RDMA READ for Ascend NPU) - HixlTransportMetadata / HixlFetchRequest / HixlTensorDesc data classes - hixl_wrapper pybind11 integration (RegisterMem, TransferAsync, GetTransferStatus) - LRU remote engine connection cache - Design docs and implementation prompts Co-Authored-By: Claude --- docs/hixl-implementation-prompt-phase3.md | 55 + docs/hixl-implementation-prompt-phase4.md | 826 ++++++++++++++ docs/hixl-tensor-transport-design.md | 1000 +++++++++++++++++ docs/hixl-wrapper-bindings-plan.md | 878 +++++++++++++++ pyproject.toml | 7 + ray_ascend/__init__.py | 74 ++ .../direct_transport/hixl_tensor_transport.py | 847 ++++++++++++++ 7 files changed, 3687 insertions(+) create mode 100644 docs/hixl-implementation-prompt-phase3.md create mode 100644 docs/hixl-implementation-prompt-phase4.md create mode 100644 docs/hixl-tensor-transport-design.md create mode 100644 docs/hixl-wrapper-bindings-plan.md create mode 100644 ray_ascend/direct_transport/hixl_tensor_transport.py diff --git a/docs/hixl-implementation-prompt-phase3.md b/docs/hixl-implementation-prompt-phase3.md new file mode 100644 index 0000000..e6a09e5 --- /dev/null +++ b/docs/hixl-implementation-prompt-phase3.md @@ -0,0 +1,55 @@ +# Phase 3 Prompt:注册集成 + ray-ascend 入口 + +你需要在 ray-ascend 项目中完成 HIXL tensor transport 的注册集成,确保 `register_hixl_tensor_transport()` 被调用后 `@ray.method(tensor_transport="HIXL")` 能正常工作。 + +## 背景 + +- `pyproject.toml` 缺少 `hixl` 可选依赖组 + +## 参考:现有注册集成模式 + +请严格参考以下现有代码的模式和风格: + +1. **ray-ascend 注册入口风格**:参考 `ray-ascend/ray_ascend/__init__.py` 中已有的 `register_yr_tensor_transport()`(23-97 行)和 `register_hccl_tensor_transport()`(144-173 行) + - 参数校验(devices is None → raise ValueError) + - ImportError 处理(依赖不装 → 明确提示安装命令) + - 最终调用 `register_tensor_transport(name, devices, Class, torch.Tensor)` + +2. **pyproject.toml 可选依赖格式**:参考 `ray-ascend/pyproject.toml` 中已有的 `[project.optional-dependencies.yr]`(42-48 行) + +## 需要完成的修改 + + +### 修改 3:在 pyproject.toml 添加 hixl 可选依赖组 + +参考 `[project.optional-dependencies.yr]`(第 42-48 行)的格式,在 `ray-ascend/pyproject.toml` 的 `[project.optional-dependencies]` 中新增: + +```toml +hixl = [ + "hixl_engine>=0.0.1", + "torch>=2.7.1; platform_machine == 'x86_64'", + "torch>=2.7.1; platform_machine == 'aarch64'", + "torch-npu>=2.7.1.post2", +] +``` + +> 注意:`hixl_engine` 的实际包名和版本号需根据 wheel 包确认。如果 wheel 尚未发布到 PyPI,可暂用 URL 引用:`"hixl_engine @ https:///hixl_engine-0.0.1-py3-none-any.whl"` + +### 修改 4:确认 __init__.py 注册入口 + +检查 `ray-ascend/ray_ascend/__init__.py` 中已有的 `register_hixl_tensor_transport()` 函数(176-246 行)是否与设计文档一致。设计文档在 `ray-ascend/docs/hixl-tensor-transport-design.md` 第 7 节。 + +重点确认: +1. `devices is None` 校验存在 +2. 导入路径正确:`from ray_ascend.direct_transport.hixl_tensor_transport import HixlTensorTransport` +3. `hixl_wrapper` 可导入性检查存在 +4. 最终调用 `register_tensor_transport("HIXL", devices, HixlTensorTransport, torch.Tensor)` — 这要求 `HixlTensorTransport` 必须是 `TensorTransportManager` 的子类 + +如果以上全部正确则无需改动;如有遗漏请修正。 + +## 输出要求 + +请修改以下文件: + +1. `ray-ascend/pyproject.toml` — 添加 hixl 可选依赖组 +2. `ray-ascend/ray_ascend/__init__.py` — 如需修正则修改(否则不动) \ No newline at end of file diff --git a/docs/hixl-implementation-prompt-phase4.md b/docs/hixl-implementation-prompt-phase4.md new file mode 100644 index 0000000..016cc73 --- /dev/null +++ b/docs/hixl-implementation-prompt-phase4.md @@ -0,0 +1,826 @@ +# Phase 4 Prompt:测试 + 调试 + +你需要在 ray-ascend 项目中为 HIXL tensor transport 创建完整的单元测试文件,覆盖所有核心方法的逻辑正确性。 + +## 背景 + +- HIXL tensor transport 目前**完全没有测试文件** +- `tests/direct_transport/` 下只有 `test_yr_transport.py`、`test_yr_transport_util.py` 和 `test_hccl_tensor_transport.py` +- HIXL 依赖 NPU 硬件和 RDMA 链路,测试需要分层:L1(纯 mock 单元测试,无硬件)和 L2(硬件集成测试,skipif 标记) +- Phase 3 已完成基类继承修复和命名统一(`agent → engine`),测试中应使用 `hixl_engine_meta_version` 等命名 + +## 参考:现有测试模式 + +请严格参考以下现有代码的测试风格和模式: + +1. **NIXL 引用计数单元测试风格**:参考 `ray/python/ray/tests/rdt/test_rdt_nixl.py` 第 434-527 行 + - 直接 `NixlTensorTransport()` 实例化(不需要 Ray 集群) + - 验证 `metadata_count` 的增减逻辑 + - 验证 `_remove_tensor_descs` 在 `metadata_count == 0` 时 deregister + - 验证 `_managed_meta_nixl` 的 pop 和清理逻辑 + +2. **NIXL register/deregister 测试风格**:参考 `ray/python/ray/tests/rdt/test_rdt_nixl.py` 第 577-609 行 + - `register_nixl_memory → deregister_nixl_memory` 流程 + - 验证 GC 后引用计数归零 + +## Mock 设计 + +### MockHixlWrapper(手动模拟类) + +模拟 `hixl_wrapper` 模块的全部 API。这是一个**手写的类**(不是 MagicMock),因为 `HixlTensorTransport` 直接调用其方法名和属性名,MagicMock 无法保证方法名匹配。 + +```python +class MockHixlWrapper: + """Simulates hixl_wrapper module for unit testing. + + Key design decisions: + - Hand-written class (not MagicMock) so method names match real hixl_wrapper + - Internal state tracking via dicts so tests can verify registration/connect behavior + - Auto-progression: get_transfer_status advances WAITING → COMPLETED on first call + """ + + # Status codes — match real hixl_wrapper constants + kSuccess = 0 + kAlreadyConnected = 103903 + kFailed = 503900 + kParamInvalid = 103900 + kTimeout = 103901 + kNotConnected = 103902 + + _registered_mems: Dict[int, int] = {} # addr → mem_handle + _connected_engines: Dict[str, bool] = {} # engine_id → True + _next_handle: int = 1 + _transfer_reqs: Dict[int, str] = {} # req_id → status_str + + @classmethod + def initialize(cls, engine_id: str, options: dict) -> int: + """Initialize HIXL engine. Always succeeds in mock.""" + return cls.kSuccess + + @classmethod + def register_mem(cls, mem_desc: tuple, mem_type: str) -> tuple: + """Register memory region. Returns (kSuccess, handle_int).""" + addr, nbytes = mem_desc + handle = cls._next_handle + cls._next_handle += 1 + cls._registered_mems[addr] = handle + return (cls.kSuccess, handle) + + @classmethod + def deregister_mem(cls, mem_handle: int) -> int: + """Deregister memory region. Returns kSuccess.""" + cls._registered_mems = { + k: v for k, v in cls._registered_mems.items() if v != mem_handle + } + return cls.kSuccess + + @classmethod + def connect(cls, remote_engine: str, timeout_ms: int = 1000) -> int: + """Connect to remote engine. Returns kSuccess.""" + cls._connected_engines[remote_engine] = True + return cls.kSuccess + + @classmethod + def disconnect(cls, remote_engine: str, timeout_ms: int = 1000) -> int: + """Disconnect from remote engine. Returns kSuccess.""" + cls._connected_engines.pop(remote_engine, None) + return cls.kSuccess + + @classmethod + def transfer_async(cls, remote_engine: str, operation: str, op_descs: list) -> tuple: + """Initiate async transfer. Returns (kSuccess, req_id_int).""" + req_id = cls._next_handle + cls._next_handle += 1 + cls._transfer_reqs[req_id] = "WAITING" + return (cls.kSuccess, req_id) + + @classmethod + def get_transfer_status(cls, req_id: int) -> tuple: + """Poll transfer status. Auto-progresses WAITING → COMPLETED.""" + if req_id not in cls._transfer_reqs: + return (cls.kFailed, "FAILED") + current = cls._transfer_reqs[req_id] + if current == "WAITING": + cls._transfer_reqs[req_id] = "COMPLETED" + return (cls.kSuccess, "WAITING") # First poll returns WAITING + return (cls.kSuccess, current) + + @classmethod + def reset(cls): + """Reset all mock state between tests.""" + cls._registered_mems.clear() + cls._connected_engines.clear() + cls._transfer_reqs.clear() + cls._next_handle = 1 +``` + +> **注意**:`get_transfer_status` 的 mock 需要模拟真实 HIXL 行为——第一次轮询返回 `"WAITING"`,随后自动推进到 `"COMPLETED"`。这样 `wait_fetch_complete` 的轮询循环可以正常退出,不需要 `time.sleep` 模拟。 + +### Patch 目标 + +```python +PATCH_TARGET = "ray_ascend.direct_transport.hixl_tensor_transport.hixl_wrapper" +``` + +这是 `hixl_tensor_transport.py` 第 20-22 行的 lazy import 变量名。patch 这个位置会让 `HixlTensorTransport` 内所有对 `hixl_wrapper.xxx` 的调用走 mock。 + +### Ray mock + +`_ensure_hixl_initialized()` 调用 `ray.get_runtime_context()` 和 `ray.util.get_node_ip_address()`,需要 patch: + +```python +with patch("ray.get_runtime_context") as mock_ctx, \ + patch("ray.util.get_node_ip_address", return_value="10.0.0.1"): + mock_ctx.return_value.get_actor_id.return_value = "test_actor_123" +``` + +### Fixture 设计 + +```python +@pytest.fixture +def mock_hixl_wrapper(): + """Patch hixl_wrapper with a MockHixlWrapper instance.""" + mock = MockHixlWrapper() + with patch(PATCH_TARGET, mock): + mock.reset() + yield mock + +@pytest.fixture +def transport(mock_hixl_wrapper): + """Create HixlTensorTransport with mocked hixl_wrapper and Ray runtime.""" + with ( + patch(PATCH_TARGET, mock_hixl_wrapper), + patch("ray.get_runtime_context") as mock_ctx, + patch("ray.util.get_node_ip_address", return_value="10.0.0.1"), + ): + mock_ctx.return_value.get_actor_id.return_value = "test_actor_123" + t = HixlTensorTransport() + t._ensure_hixl_initialized() + yield t +``` + +> **注意**:`transport` fixture 中需要同时 patch `hixl_wrapper` 和 Ray API,且 `_ensure_hixl_initialized()` 在 fixture 内调用,这样后续测试中 transport 已处于初始化状态。 + +## 需要创建的测试文件 + +创建文件 `ray-ascend/tests/direct_transport/test_hixl_tensor_transport.py`,包含以下 11 个测试套件: + +### Suite 1:TestDataClasses(纯 Python,无 mock) + +验证数据类定义和继承关系: + +```python +class TestDataClasses: + """Verify data class definitions and inheritance.""" + + def test_hixl_communicator_metadata_inherits(self): + assert issubclass(HixlCommunicatorMetadata, CommunicatorMetadata) + + def test_hixl_transport_metadata_inherits(self): + assert issubclass(HixlTransportMetadata, TensorTransportMetadata) + + def test_hixl_transport_metadata_fields(self): + meta = HixlTransportMetadata( + tensor_meta=[((2, 3), torch.float32)], + tensor_device="npu", + hixl_serialized_mem_descs=b"fake", + hixl_engine_id="10.0.0.1:12345", + hixl_engine_meta_version=0, + ) + assert meta.hixl_serialized_mem_descs is not None + assert meta.hixl_engine_id is not None + assert meta.hixl_engine_meta_version == 0 + + def test_hixl_transport_metadata_no_duplicate_base_fields(self): + """子类不应重复定义基类的 tensor_meta 和 tensor_device 字段。 + dataclass 继承中如果子类重新定义了基类字段,会导致字段顺序错误。 + """ + base_fields = [f.name for f in TensorTransportMetadata.__dataclass_fields__.values()] + child_fields = [f.name for f in HixlTransportMetadata.__dataclass_fields__.values()] + # 基类字段应出现在子类字段列表的前面(继承顺序),但不应有重复定义 + # dataclass 继承的正确行为:子类字段列表 = 基类字段 + 新增字段 + new_fields = [f for f in child_fields if f not in base_fields] + assert "hixl_serialized_mem_descs" in new_fields + assert "hixl_engine_id" in new_fields + assert "hixl_engine_meta_version" in new_fields + + def test_hixl_tensor_desc_fields(self): + desc = HixlTensorDesc(mem_handle=42, nbytes=1024, mem_type_str="npu", metadata_count=1) + assert desc.mem_handle == 42 + assert desc.nbytes == 1024 + assert desc.mem_type_str == "npu" + assert desc.metadata_count == 1 + + def test_hixl_fetch_request_inherits(self): + assert issubclass(HixlFetchRequest, FetchRequest) + + def test_hixl_fetch_request_custom_fields(self): + req = HixlFetchRequest( + obj_id="test_obj", + tensors=[], + transfer_req=123, + remote_engine_id="10.0.0.1:12345", + remove_tensor_descs=True, + transport=None, # None so __del__ won't crash + ) + assert req.transfer_req == 123 + assert req.remote_engine_id == "10.0.0.1:12345" + assert req.remove_tensor_descs is True +``` + +### Suite 2:TestTransportProperties(纯 Python,无 mock) + +```python +class TestTransportProperties: + """Test static properties without hardware or mock.""" + + def test_tensor_transport_backend(self): + t = HixlTensorTransport() + assert t.tensor_transport_backend() == "HIXL" + + def test_is_one_sided(self): + assert HixlTensorTransport.is_one_sided() is True + + def test_can_abort_transport(self): + assert HixlTensorTransport.can_abort_transport() is True + + def test_inherits_tensor_transport_manager(self): + assert issubclass(HixlTensorTransport, TensorTransportManager) + + def test_send_multiple_tensors_raises(self): + t = HixlTensorTransport() + with pytest.raises(NotImplementedError, match="one-sided"): + t.send_multiple_tensors([], HixlTransportMetadata(tensor_meta=[], tensor_device=None), HixlCommunicatorMetadata()) +``` + +### Suite 3:TestMemoryRegistration(mock hixl_wrapper) + +```python +class TestMemoryRegistration: + """Test _add_tensor_descs and _remove_tensor_descs with mock.""" + + def test_register_new_cpu_tensor(self, transport, mock_hixl_wrapper): + """New CPU tensor → register_mem called, metadata_count=1, mem_type_str='cpu'.""" + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + key = t.untyped_storage().data_ptr() + assert key in transport._tensor_desc_cache + desc = transport._tensor_desc_cache[key] + assert desc.metadata_count == 1 + assert desc.mem_type_str == "cpu" + assert desc.mem_handle in mock_hixl_wrapper._registered_mems.values() + + def test_register_same_tensor_twice_bumps_ref_count(self, transport, mock_hixl_wrapper): + """Same tensor registered twice → metadata_count=2, no second register_mem call.""" + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + transport._add_tensor_descs([t]) + key = t.untyped_storage().data_ptr() + assert transport._tensor_desc_cache[key].metadata_count == 2 + # Only one register_mem call (check mock handle count) + assert len(mock_hixl_wrapper._registered_mems) == 1 + + def test_register_multiple_tensors(self, transport, mock_hixl_wrapper): + """Multiple different tensors → each gets its own HixlTensorDesc.""" + t1 = torch.randn(2, 3, device="cpu") + t2 = torch.randn(4, 5, device="cpu") + transport._add_tensor_descs([t1, t2]) + assert len(transport._tensor_desc_cache) == 2 + + def test_deregister_when_ref_count_zero(self, transport, mock_hixl_wrapper): + """metadata_count→0 → deregister_mem called, entry removed from cache.""" + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + key = t.untyped_storage().data_ptr() + handle = transport._tensor_desc_cache[key].mem_handle + + transport._remove_tensor_descs([t]) + assert key not in transport._tensor_desc_cache + assert handle not in mock_hixl_wrapper._registered_mems.values() + + def test_partial_deregister_keeps_registration(self, transport, mock_hixl_wrapper): + """metadata_count 2→1 → entry stays, no deregister_mem call.""" + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + transport._add_tensor_descs([t]) # ref_count = 2 + + transport._remove_tensor_descs([t]) # ref_count = 1 + key = t.untyped_storage().data_ptr() + assert key in transport._tensor_desc_cache + assert transport._tensor_desc_cache[key].metadata_count == 1 + # Handle still in mock registered_mems + assert len(mock_hixl_wrapper._registered_mems) == 1 + + def test_deregister_unknown_tensor_is_noop(self, transport, mock_hixl_wrapper): + """Removing a tensor not in cache should skip silently.""" + t = torch.randn(2, 3, device="cpu") + transport._remove_tensor_descs([t]) # Never registered + # Should not crash + + def test_engine_meta_version_bumps_on_full_deregister(self, transport, mock_hixl_wrapper): + """_hixl_engine_meta_version increments when memory is fully deregistered.""" + initial = transport._hixl_engine_meta_version + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + transport._remove_tensor_descs([t]) + assert transport._hixl_engine_meta_version > initial + + def test_engine_meta_version_no_bump_on_partial_deregister(self, transport, mock_hixl_wrapper): + """_hixl_engine_meta_version does NOT change when metadata_count > 0 after deregister.""" + initial = transport._hixl_engine_meta_version + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + transport._add_tensor_descs([t]) + transport._remove_tensor_descs([t]) + assert transport._hixl_engine_meta_version == initial + + def test_tensor_memory_registered(self, transport, mock_hixl_wrapper): + """_tensor_memory_registered returns True for registered, False for unregistered.""" + t = torch.randn(2, 3, device="cpu") + assert transport._tensor_memory_registered(t) is False + transport._add_tensor_descs([t]) + assert transport._tensor_memory_registered(t) is True +``` + +### Suite 4:TestMetadataExtraction(mock hixl_wrapper) + +```python +class TestMetadataExtraction: + """Test extract_tensor_transport_metadata with mock.""" + + def test_basic_extraction_cpu(self, transport, mock_hixl_wrapper): + """CPU tensor extraction → registers memory, returns metadata with correct fields.""" + tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + assert isinstance(meta, HixlTransportMetadata) + assert meta.tensor_device == "cpu" + assert len(meta.tensor_meta) == 1 + assert meta.hixl_serialized_mem_descs is not None + assert meta.hixl_engine_id == transport._local_engine_id + assert meta.hixl_engine_meta_version == transport._hixl_engine_meta_version + + def test_metadata_stored_in_managed_meta(self, transport, mock_hixl_wrapper): + """extract_ stores metadata in _managed_meta_hixl.""" + tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + assert transport._get_meta("obj1") == meta + + def test_serialized_mem_descs_format(self, transport, mock_hixl_wrapper): + """Serialized descs = pickle([(data_ptr, nbytes, mem_type_str)]).""" + tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + descs = pickle.loads(meta.hixl_serialized_mem_descs) + assert len(descs) == 1 + data_ptr, nbytes, mem_type = descs[0] + assert mem_type == "cpu" + assert nbytes == tensors[0].untyped_storage().nbytes() + + def test_multiple_tensors_serialization(self, transport, mock_hixl_wrapper): + """Multiple tensors → multiple entries in serialized descs.""" + tensors = [torch.randn(2, 3, device="cpu"), torch.randn(4, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + descs = pickle.loads(meta.hixl_serialized_mem_descs) + assert len(descs) == 2 + + def test_contiguous_check_raises(self, transport, mock_hixl_wrapper): + """Non-contiguous tensor → ValueError.""" + t = torch.randn(2, 4, device="cpu").t() # Transposed → non-contiguous + with pytest.raises(ValueError, match="contiguous"): + transport.extract_tensor_transport_metadata("obj1", [t]) + + def test_empty_object_returns_none_fields(self, transport, mock_hixl_wrapper): + """Empty rdt_object → metadata with None serialized fields.""" + meta = transport.extract_tensor_transport_metadata("obj1", []) + assert meta.hixl_serialized_mem_descs is None + assert meta.hixl_engine_id is None + assert meta.hixl_engine_meta_version is None + assert meta.tensor_meta == [] + assert meta.tensor_device is None + + def test_get_communicator_metadata(self, transport, mock_hixl_wrapper): + """get_communicator_metadata returns empty HixlCommunicatorMetadata.""" + comm = transport.get_communicator_metadata(None, None) + assert isinstance(comm, HixlCommunicatorMetadata) +``` + +### Suite 5:TestFetchAndWait(mock hixl_wrapper) + +```python +class TestFetchAndWait: + """Test fetch_multiple_tensors and wait_fetch_complete with mock.""" + + def test_basic_fetch_and_wait(self, transport, mock_hixl_wrapper): + """Complete flow: extract → fetch → wait → result.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + fetch_req = transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + assert isinstance(fetch_req, HixlFetchRequest) + assert fetch_req.transfer_req is not None + assert fetch_req.remote_engine_id == meta.hixl_engine_id + + result = transport.wait_fetch_complete(fetch_req) + assert len(result) == 1 + assert result[0].shape == src_tensors[0].shape + + def test_fetch_registers_target_tensors(self, transport, mock_hixl_wrapper): + """Fetch should register target tensors' memory with HIXL.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + fetch_req = transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + + # Target tensors should be registered + for t in fetch_req.tensors: + key = t.untyped_storage().data_ptr() + assert key in transport._tensor_desc_cache + + def test_fetch_connects_remote_engine(self, transport, mock_hixl_wrapper): + """Fetch should connect to remote engine specified in metadata.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + fetch_req = transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + remote_engine_id = meta.hixl_engine_id + assert remote_engine_id in mock_hixl_wrapper._connected_engines + + def test_fetch_caches_remote_engine(self, transport, mock_hixl_wrapper): + """Second fetch to same engine should reuse connection (LRU cache).""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta1 = transport.extract_tensor_transport_metadata("obj1", src_tensors) + meta2 = transport.extract_tensor_transport_metadata("obj2", src_tensors) + + transport.fetch_multiple_tensors("obj1", meta1, HixlCommunicatorMetadata()) + transport.fetch_multiple_tensors("obj2", meta2, HixlCommunicatorMetadata()) + + assert meta1.hixl_engine_id in transport._remote_engines + + def test_fetch_size_mismatch_raises(self, transport, mock_hixl_wrapper): + """Local vs remote nbytes mismatch → RuntimeError.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + # Tamper with serialized descs to create size mismatch + descs = pickle.loads(meta.hixl_serialized_mem_descs) + descs[0] = (descs[0][0], descs[0][1] + 100, descs[0][2]) # wrong nbytes + meta.hixl_serialized_mem_descs = pickle.dumps(descs) + + with pytest.raises(Exception, match="size mismatch"): + transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + + def test_recv_multiple_tensors_is_sync_wrapper(self, transport, mock_hixl_wrapper): + """recv_multiple_tensors = fetch + wait.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + result = transport.recv_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + assert len(result) == 1 + assert result[0].shape == src_tensors[0].shape +``` + +### Suite 6:TestGarbageCollection(mock hixl_wrapper) + +```python +class TestGarbageCollection: + """Test garbage_collect with mock.""" + + def test_gc_removes_meta_and_deregisters(self, transport, mock_hixl_wrapper): + """GC pops metadata and deregisters tensor when ref_count → 0.""" + tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + transport.garbage_collect("obj1", meta, tensors) + + assert transport._get_meta("obj1") is None + key = tensors[0].untyped_storage().data_ptr() + assert key not in transport._tensor_desc_cache + + def test_gc_unknown_obj_id_is_noop(self, transport, mock_hixl_wrapper): + """GC for unknown obj_id → returns without error.""" + meta = HixlTransportMetadata( + tensor_meta=[], tensor_device=None, + hixl_serialized_mem_descs=None, + ) + transport.garbage_collect("unknown_obj", meta, []) + # Should not raise + + def test_gc_shared_tensor_keeps_registration(self, transport, mock_hixl_wrapper): + """GC of one metadata keeps tensor if another metadata still references it.""" + t = torch.randn(2, 3, device="cpu") + meta1 = transport.extract_tensor_transport_metadata("obj1", [t]) + meta2 = transport.extract_tensor_transport_metadata("obj2", [t]) + + transport.garbage_collect("obj1", meta1, [t]) + key = t.untyped_storage().data_ptr() + assert key in transport._tensor_desc_cache + assert transport._tensor_desc_cache[key].metadata_count == 1 + + def test_gc_second_time_removes_registration(self, transport, mock_hixl_wrapper): + """GC of both metadatas → tensor fully deregistered.""" + t = torch.randn(2, 3, device="cpu") + meta1 = transport.extract_tensor_transport_metadata("obj1", [t]) + meta2 = transport.extract_tensor_transport_metadata("obj2", [t]) + + transport.garbage_collect("obj1", meta1, [t]) + transport.garbage_collect("obj2", meta2, [t]) + key = t.untyped_storage().data_ptr() + assert key not in transport._tensor_desc_cache + + def test_gc_bumps_engine_meta_version(self, transport, mock_hixl_wrapper): + """GC that fully deregisters memory bumps _hixl_engine_meta_version.""" + initial = transport._hixl_engine_meta_version + tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", tensors) + transport.garbage_collect("obj1", meta, tensors) + assert transport._hixl_engine_meta_version > initial +``` + +### Suite 7:TestAbortTransport(mock hixl_wrapper) + +```python +class TestAbortTransport: + """Test abort_transport mechanism.""" + + def test_abort_marks_obj_id(self, transport, mock_hixl_wrapper): + """abort_transport adds obj_id to _aborted_transfer_obj_ids.""" + transport.abort_transport("obj1", HixlCommunicatorMetadata()) + assert "obj1" in transport._aborted_transfer_obj_ids + + def test_aborted_fetch_raises_error(self, transport, mock_hixl_wrapper): + """Fetch on aborted obj_id → RuntimeError with 'aborted' message.""" + transport.abort_transport("obj1", HixlCommunicatorMetadata()) + + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + with pytest.raises(RuntimeError, match="aborted"): + transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + + def test_abort_removed_after_fetch_error(self, transport, mock_hixl_wrapper): + """Aborted obj_id is removed from set after the RuntimeError is raised.""" + transport.abort_transport("obj1", HixlCommunicatorMetadata()) + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + try: + transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + except RuntimeError: + pass + assert "obj1" not in transport._aborted_transfer_obj_ids +``` + +### Suite 8:TestRemoteEngineCache(mock hixl_wrapper) + +```python +class TestRemoteEngineCache: + """Test LRU remote engine connection caching.""" + + def test_lru_eviction(self, transport, mock_hixl_wrapper): + """When cache is full, least recently used engine is evicted.""" + import ray_ascend.direct_transport.hixl_tensor_transport as hixl_mod + original = hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE + hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = 2 + + try: + transport._connect_remote_engine("engine_A", 0) + transport._connect_remote_engine("engine_B", 0) + transport._connect_remote_engine("engine_C", 0) # evicts engine_A + assert "engine_A" not in transport._remote_engines + assert "engine_B" in transport._remote_engines + assert "engine_C" in transport._remote_engines + finally: + hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = original + + def test_version_mismatch_reconnects(self, transport, mock_hixl_wrapper): + """Different meta version → disconnect + reconnect with new version.""" + transport._connect_remote_engine("engine_A", 0) + assert transport._remote_engines["engine_A"] == 0 + + transport._connect_remote_engine("engine_A", 5) + assert transport._remote_engines["engine_A"] == 5 + # Engine was reconnected (disconnect called then connect called) + + def test_version_match_reuses_connection(self, transport, mock_hixl_wrapper): + """Same meta version → reuse cached connection, no new connect call.""" + transport._connect_remote_engine("engine_A", 0) + initial_connected_count = len(mock_hixl_wrapper._connected_engines) + + transport._connect_remote_engine("engine_A", 0) + # Should not call connect again + assert len(mock_hixl_wrapper._connected_engines) == initial_connected_count + + def test_no_cache_mode_connects_fresh(self, transport, mock_hixl_wrapper): + """HIXL_REMOTE_ENGINE_CACHE_MAXSIZE=0 → no caching, connect fresh each time.""" + import ray_ascend.direct_transport.hixl_tensor_transport as hixl_mod + original = hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE + hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = 0 + + try: + transport._connect_remote_engine("engine_A", 0) + assert len(transport._remote_engines) == 0 + finally: + hixl_mod.HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = original + + def test_disconnect_is_best_effort(self, transport, mock_hixl_wrapper): + """_disconnect_remote_engine should not raise even if disconnect fails.""" + # Make disconnect raise + mock_hixl_wrapper.disconnect = MagicMock(side_effect=RuntimeError("connection lost")) + # Should not raise, just log warning + transport._disconnect_remote_engine("engine_A") +``` + +### Suite 9:TestActorHealthCheck(mock Ray actor) + +```python +class TestActorHealthCheck: + """Test actor_has_tensor_transport with mock Ray actor.""" + + def test_success(self, transport, mock_hixl_wrapper): + mock_actor = MagicMock() + mock_actor.__ray_call__ = MagicMock() + mock_actor.__ray_call__.options.return_value = mock_actor.__ray_call__ + mock_actor.__ray_call__.remote.return_value = "mock_ref" + + with patch("ray.get", return_value=True): + result = transport.actor_has_tensor_transport(mock_actor) + assert result is True + mock_actor.__ray_call__.options.assert_called_once_with( + concurrency_group="_ray_system" + ) + + def test_failure(self, transport, mock_hixl_wrapper): + mock_actor = MagicMock() + mock_actor.__ray_call__ = MagicMock() + mock_actor.__ray_call__.options.return_value = mock_actor.__ray_call__ + mock_actor.__ray_call__.remote.return_value = "mock_ref" + + with patch("ray.get", return_value=False): + result = transport.actor_has_tensor_transport(mock_actor) + assert result is False +``` + +### Suite 10:TestErrorHandling(mock hixl_wrapper) + +```python +class TestErrorHandling: + """Test error paths and edge cases.""" + + def test_hixl_wrapper_not_installed_raises_import_error(self): + """hixl_wrapper=None → _ensure_hixl_initialized raises ImportError.""" + with patch(PATCH_TARGET, None): + t = HixlTensorTransport() + with pytest.raises(ImportError, match="hixl_wrapper"): + t._ensure_hixl_initialized() + + def test_register_mem_failure_raises_runtime_error(self, transport, mock_hixl_wrapper): + """register_mem returns kFailed → RuntimeError.""" + mock_hixl_wrapper.register_mem = MagicMock( + return_value=(mock_hixl_wrapper.kFailed, None) + ) + t = torch.randn(2, 3, device="cpu") + with pytest.raises(RuntimeError, match="RegisterMem"): + transport._add_tensor_descs([t]) + + def test_connect_failure_raises_runtime_error(self, transport, mock_hixl_wrapper): + """connect returns kFailed → RuntimeError.""" + mock_hixl_wrapper.connect = MagicMock(return_value=mock_hixl_wrapper.kFailed) + with pytest.raises(RuntimeError, match="Connect"): + transport._connect_remote_engine("bad_engine", 0) + + def test_transfer_async_failure_raises_ray_direct_transport_error(self, transport, mock_hixl_wrapper): + """transfer_async returns kFailed → RayDirectTransportError.""" + mock_hixl_wrapper.transfer_async = MagicMock( + return_value=(mock_hixl_wrapper.kFailed, None) + ) + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + with pytest.raises(Exception, match="HIXL transfer failed"): + transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + + def test_get_transfer_status_failed_state_raises(self, transport, mock_hixl_wrapper): + """get_transfer_status returns FAILED → RuntimeError in wait_fetch_complete.""" + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + fetch_req = transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + + # Override mock to return FAILED + mock_hixl_wrapper.get_transfer_status = MagicMock( + return_value=(mock_hixl_wrapper.kSuccess, "FAILED") + ) + with pytest.raises(Exception, match="FAILED"): + transport.wait_fetch_complete(fetch_req) + + def test_cleanup_on_fetch_failure_removes_target_descs(self, transport, mock_hixl_wrapper): + """When fetch fails, target tensor descs should be cleaned up.""" + mock_hixl_wrapper.transfer_async = MagicMock( + return_value=(mock_hixl_wrapper.kFailed, None) + ) + src_tensors = [torch.randn(2, 3, device="cpu")] + meta = transport.extract_tensor_transport_metadata("obj1", src_tensors) + + try: + transport.fetch_multiple_tensors("obj1", meta, HixlCommunicatorMetadata()) + except Exception: + pass + + # Source tensor should still be registered (from extract_) + src_key = src_tensors[0].untyped_storage().data_ptr() + assert src_key in transport._tensor_desc_cache + + def test_cleanup_without_hixl_initialized_is_noop(self): + """_cleanup_transfer when _hixl_initialized=False → returns immediately.""" + t = HixlTensorTransport() + # _hixl_initialized is False by default + t._cleanup_transfer("obj1", [], None, None, False) + # Should not raise + + def test_deregister_mem_failure_logs_warning(self, transport, mock_hixl_wrapper): + """deregister_mem returning error → warning log, not exception.""" + mock_hixl_wrapper.deregister_mem = MagicMock(return_value=mock_hixl_wrapper.kFailed) + t = torch.randn(2, 3, device="cpu") + transport._add_tensor_descs([t]) + # Should not raise, just log + transport._remove_tensor_descs([t]) +``` + +### Suite 11:TestNPUIntegration(硬件依赖,skipif 标记) + +```python +NPU_AVAILABLE = False +try: + import torch_npu + if torch.npu.is_available(): + NPU_AVAILABLE = True +except ImportError: + pass + +skip_no_npu = pytest.mark.skipif(not NPU_AVAILABLE, reason="NPU hardware not available") + + +@skip_no_npu +class TestNPUIntegration: + """Integration tests requiring real NPU hardware and hixl_wrapper.""" + + @pytest.fixture + def transport_real(self): + """Create HixlTensorTransport with real hixl_wrapper (no mock).""" + pytest.importorskip("hixl_wrapper") + import ray + ray.init(ignore_reinit_error=True) + t = HixlTensorTransport() + t._ensure_hixl_initialized() + yield t + ray.shutdown() + + def test_npu_tensor_registration(self, transport_real): + """Register and deregister a real NPU tensor.""" + t = torch.randn(2, 3, device="npu") + transport_real._add_tensor_descs([t]) + key = t.untyped_storage().data_ptr() + assert key in transport_real._tensor_desc_cache + assert transport_real._tensor_desc_cache[key].mem_type_str == "npu" + + transport_real._remove_tensor_descs([t]) + assert key not in transport_real._tensor_desc_cache + + def test_npu_extract_metadata(self, transport_real): + """Extract metadata for NPU tensors with real hixl_wrapper.""" + tensors = [torch.randn(2, 3, device="npu")] + meta = transport_real.extract_tensor_transport_metadata("obj1", tensors) + assert meta.tensor_device == "npu" + assert meta.hixl_serialized_mem_descs is not None +``` + +## 输出要求 + +创建文件 `ray-ascend/tests/direct_transport/test_hixl_tensor_transport.py`,包含以上所有 11 个测试套件。 + +**关键约束**: +- L1 mock 测试全部使用 CPU tensor(`device="cpu"`),不需要 NPU +- NPU 测试只在 `TestNPUIntegration` 中,用 `pytest.mark.skipif` 标记 +- 不要 mock torch — `untyped_storage().data_ptr()`、`nbytes()` 等是真实 Python 操作 +- 创建测试用 `HixlFetchRequest` 时设 `transport=None`,防止 `__del__` 调用 cleanup +- LRU 缓存测试中临时修改 `HIXL_REMOTE_ENGINE_CACHE_MAXSIZE`,测试后恢复原值 + +## 验证标准 + +生成的测试文件应该能运行(假设依赖已安装): + +```bash +cd /home/lyy/code/hixl/ray-ascend + +# L1 单元测试(不需要 NPU 硬件) +python -m pytest tests/direct_transport/test_hixl_tensor_transport.py -v -k "not NPUIntegration" + +# 只跑数据类和静态属性(最快验证) +python -m pytest tests/direct_transport/test_hixl_tensor_transport.py::TestDataClasses -v +python -m pytest tests/direct_transport/test_hixl_tensor_transport.py::TestTransportProperties -v + +# 覆盖率检查 +python -m pytest tests/direct_transport/test_hixl_tensor_transport.py \ + --cov=ray_ascend.direct_transport.hixl_tensor_transport \ + --cov-report=term-missing \ + -k "not NPUIntegration" + +# L2 集成测试(需要 NPU + RDMA 环境) +python -m pytest tests/direct_transport/test_hixl_tensor_transport.py::TestNPUIntegration -v +``` diff --git a/docs/hixl-tensor-transport-design.md b/docs/hixl-tensor-transport-design.md new file mode 100644 index 0000000..b3c50b1 --- /dev/null +++ b/docs/hixl-tensor-transport-design.md @@ -0,0 +1,1000 @@ +# HixlTensorTransport 实现设计文档 + +## 1. 概述 + +| 项目 | 内容 | +|---|---| +| 类名 | `HixlTensorTransport` | +| 基类 | `TensorTransportManager`(Ray RDT 插件接口) | +| 参考实现 | `NixlTensorTransport`(NIXL 一侧 RDMA 传输) | +| 传输类型 | 一侧 RDMA READ(`is_one_sided=True`) | +| 设备 | Ascend NPU (`"npu"`) + CPU (`"cpu"`) | +| C++ 绑定 | `hixl_wrapper`(pybind11 模块,绑定 `hixl::Hixl`) | +| 文件路径 | `ray-ascend/ray_ascend/direct_transport/hixl_tensor_transport.py` | + +--- + +## 2. 与 NIXL 的关键差异 + +以下差异决定了 HIXL 实现不能照抄 NIXL 代码: + +| 环节 | NIXL 做法 | HIXL 做法 | 设计影响 | +|---|---|---|---| +| 内存描述序列化 | `nixl_agent.get_serialized_descs()` 内置方法 | **无内置序列化**,需自定义 | `HixlTransportMetadata` 用 `pickle.dumps([(data_ptr, nbytes, mem_type_str)])` | +| 远端 agent 添加 | `nixl_agent.add_remote_agent(meta)` 自动 | **显式 Connect** `hixl_wrapper.connect(remote_engine_id)` | 新增 `_connect_remote_engine` / `_disconnect_remote_engine` 方法,LRU 缓存 `_remote_engines` | +| 传输描述交换 | `nixl_agent.get_xfer_descs()` + `deserialize_descs()` 自动 | **需从 metadata 中提取远端 addr/len**,手动构造 `TransferOpDesc` | `fetch_multiple_tensors` 中反序列化 `remote_mem_descs`,构建 `[(local_addr, remote_addr, len)]` | +| 设备类型 | `"cuda"` / `"cpu"` | `"npu"` / `"cpu"` | `torch.npu.synchronize()` 替代 `torch.cuda.synchronize()` | +| 异步状态 | `check_xfer_state` → `"PROC"/"DONE"/"ERR"` | `get_transfer_status` → `"WAITING"/"COMPLETED"/"TIMEOUT"/"FAILED"` | `wait_fetch_complete` 轮询逻辑不同 | +| Agent 初始化 | `nixl_agent(actor_id, config)` 创建实例对象,同进程可有多个 agent | `hixl_wrapper.initialize(engine_id, options)` 创建**进程级全局单例**,同进程只能有一个引擎 | 惰初始化用 `_ensure_hixl_initialized()` + `_local_engine_id`,不存 agent 实例引用。NIXL 的 `self._nixl_agent` 存的是实例对象,HIXL 不需要——`hixl_wrapper` 是全局模块,所有方法都是模块级函数,直接调用即可。**此差异不影响 Python 端设计**——Ray RDT 的 `get_tensor_transport_manager("HIXL")` 本身也只创建一个 `HixlTensorTransport` Python 单例 | +| 注册句柄 | NIXL `register_memory` 返回 `reg_desc`(可序列化) | HIXL `register_mem` 返回 `mem_handle`(`int`,不可序列化) | `HixlTensorDesc` 需额外存 `nbytes` 和 `mem_type_str`,因为 `mem_handle` 不包含地址信息 | + +--- + +## 3. 数据类设计 + +### 3.1 `HixlCommunicatorMetadata` + +```python +@dataclass +class HixlCommunicatorMetadata(CommunicatorMetadata): + """Metadata for the HIXL communicator.""" +``` + +### 3.2 `HixlTransportMetadata` + +```python +@dataclass +class HixlTransportMetadata(TensorTransportMetadata): + """Metadata for tensors stored in the NPU/CPU object store for HIXL transport. + + Args: + hixl_serialized_mem_descs: Pickle-serialized list of + (data_ptr, nbytes, mem_type_str) tuples describing the source + tensors' registered memory regions. + hixl_engine_id: The local HIXL engine identifier (format: "host_ip:port") + that the remote side uses to Connect back. + hixl_engine_meta_version: Monotonically increasing version number bumped + whenever memory is deregistered, so the receiver can detect stale + descriptors. + """ + + hixl_serialized_mem_descs: Optional[bytes] = None + hixl_engine_id: Optional[str] = None + hixl_engine_meta_version: Optional[int] = 0 + + __eq__ = object.__eq__ + __hash__ = object.__hash__ +``` + +### 3.3 `HixlTensorDesc` + +```python +@dataclass +class HixlTensorDesc: + """Cached registration info for a single tensor storage. + + Attributes: + mem_handle: The opaque handle returned by hixl_wrapper.register_mem. + Represented as a Python int (uintptr_t under the hood). + nbytes: Size of the registered memory region in bytes. + mem_type_str: "npu" or "cpu" — used when building TransferOpDesc and + for serialization into HixlTransportMetadata. + metadata_count: Number of HixlTransportMetadata objects that reference + this tensor. When it reaches zero, we call DeregisterMem. + """ + + mem_handle: Any + nbytes: int + mem_type_str: str + metadata_count: int +``` + +### 3.4 `HixlFetchRequest` + +```python +@dataclass +class HixlFetchRequest(FetchRequest): + """HIXL-specific fetch request carrying the async transfer state. + + Returned by fetch_multiple_tensors and consumed by wait_fetch_complete. + Resource cleanup happens in __del__ so that handles are released even if + the caller never waits on the request. + + Args: + obj_id: Inherited. The object ID for the transfer, used for abort checks. + tensors: Inherited. Pre-allocated output tensors (populated before the + transfer starts). + transfer_req: HIXL TransferReq handle (uintptr_t → Python int). + remote_engine_id: The remote engine ID (ip:port) that was connected + for this transfer. + remove_tensor_descs: Whether to remove tensor descriptors from the + cache during cleanup (True when fetch_multiple_tensors added them). + transport: Reference to the HixlTensorTransport instance for cleanup. + """ + + transfer_req: Any = None + remote_engine_id: Optional[str] = None + remove_tensor_descs: bool = False + transport: Any = None + + def __del__(self): + if self.transport is not None: + self.transport._cleanup_transfer( + self.obj_id, + self.tensors, + self.transfer_req, + self.remote_engine_id, + self.remove_tensor_descs, + ) +``` + +--- + +## 4. `HixlTensorTransport` 类完整实现 + +### 4.1 类声明与 `__init__` + +```python +class HixlTensorTransport(TensorTransportManager): + """HIXL Engine-based one-sided RDMA tensor transport for Ray RDT.""" + + def __init__(self): + # Lazily initialized because hixl_wrapper may not be installed on + # nodes that are only coordinating (not participating in transfers). + self._hixl_initialized = False + self._local_engine_id: Optional[str] = None + + # Object IDs whose transfers have been aborted. + self._aborted_transfer_obj_ids: set = set() + self._aborted_transfer_obj_ids_lock = threading.Lock() + + # Mapping from tensor storage data_ptr → HixlTensorDesc. + # Unlike _managed_meta_hixl, we only deregister tensors when ALL + # metadata containing the tensor is freed (reference counting via + # metadata_count). + self._tensor_desc_cache: Dict[int, HixlTensorDesc] = {} + + # Mapping from object ID → HixlTransportMetadata. + # Lifetime is tied to the object ref; freed when the ref goes out of + # scope (garbage_collect is called). + self._managed_meta_hixl: Dict[str, Any] = {} + + # Lock protecting _tensor_desc_cache and _managed_meta_hixl since they + # can be accessed from the main task execution thread or the + # _ray_system thread. + self._cache_lock = threading.RLock() + + # LRU cache of remote engine IDs. When full, the least recently used + # remote engine is evicted and Disconnect is called. + self._remote_engines: OrderedDict = OrderedDict() + + # Incremented whenever memory is deregistered so receivers can detect + # stale descriptors. + self._hixl_engine_meta_version: int = 0 +``` + +### 4.2 `tensor_transport_backend` + +```python + def tensor_transport_backend(self) -> str: + return "HIXL" +``` + +### 4.3 `is_one_sided` + +```python + @staticmethod + def is_one_sided() -> bool: + return True # HIXL RDMA: receiver initiates READ (one-sided) +``` + +### 4.4 `can_abort_transport` + +```python + @staticmethod + def can_abort_transport() -> bool: + return True # TransferAsync can be interrupted via abort flag +``` + +### 4.5 `_ensure_hixl_initialized` + +```python + def _ensure_hixl_initialized(self): + """Lazily initializes the HIXL engine via hixl_wrapper. + + The engine ID is constructed from the Ray actor's node IP + actor_id + as the port component, ensuring uniqueness per actor. + + Raises: + ImportError: If hixl_wrapper is not installed. + RuntimeError: If HIXL initialization fails. + """ + if self._hixl_initialized: + return + + if hixl_wrapper is None: + raise ImportError( + "hixl_wrapper module not found. " + "Please install the HIXL Engine wheel: " + "pip install hixl_engine-0.0.1-py3-none-any.whl" + ) + + # Build a local engine ID from the Ray actor's IP address. + # The port component uses the actor_id to ensure uniqueness. + ctx = ray.get_runtime_context() + actor_id = ctx.get_actor_id() + if actor_id is None: + # Driver process — generate a unique ID. + import uuid + actor_id = f"RAY-DRIVER-{uuid.uuid4()}" + + node_ip = ray.util.get_node_ip_address() + self._local_engine_id = f"{node_ip}:{actor_id}" + + status = hixl_wrapper.initialize(self._local_engine_id, {}) + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"Failed to initialize HIXL engine with id " + f"'{self._local_engine_id}', status={status}. " + f"Common causes:\n" + f" - HIXL library not installed or incompatible version\n" + f" - RDMA hardware not available on this node\n" + f" - CANN driver/runtime version mismatch" + ) + + self._hixl_initialized = True + logger.info( + f"HIXL engine initialized with local_engine_id=" + f"{self._local_engine_id}" + ) +``` + +### 4.6 `actor_has_tensor_transport` + +```python + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: + """Check if a remote actor has the HIXL transport available.""" + # TODO: This is called on a .remote RDT call, so it's quite expensive. + def __ray_actor_has_tensor_transport__( + self: "ray.actor.ActorHandle", + ) -> bool: + # Check if hixl_wrapper is installed and can initialize + try: + from ray.experimental.rdt.util import ( + get_tensor_transport_manager, + ) + + get_tensor_transport_manager("HIXL")._ensure_hixl_initialized() + return True + except Exception: + return False + + return ray.get( + actor.__ray_call__.options(concurrency_group="_ray_system").remote( + __ray_actor_has_tensor_transport__ + ) + ) +``` + +### 4.7 `register_hixl_memory` + +```python + def register_hixl_memory(self, tensor: "torch.Tensor") -> None: + """Registers the tensor's memory with HIXL and bumps the reference + count so the memory region is never deregistered. + + Mirrors NixlTensorTransport.register_nixl_memory(). + Call this to pre-register a tensor's memory for the lifetime of the + process, which can improve performance if the same tensor is re-used + in multiple RDT objects. + """ + self._add_tensor_descs([tensor]) +``` + +### 4.8 `deregister_hixl_memory` + +```python + def deregister_hixl_memory(self, tensor: "torch.Tensor") -> None: + """Decrements the reference count for the tensor's HIXL memory + registration added by register_hixl_memory. + + If the reference count reaches 0, the memory is deregistered from + HIXL. This should only be called after register_hixl_memory has been + called for this tensor. Any existing ObjectRef instances that reference + this tensor's memory will keep the HIXL registration alive independently + until they go out of scope. + + Mirrors NixlTensorTransport.deregister_nixl_memory(). + """ + self._remove_tensor_descs([tensor]) +``` + +### 4.9 `extract_tensor_transport_metadata` + +```python + def extract_tensor_transport_metadata( + self, + obj_id: str, + rdt_object: List["torch.Tensor"], + ) -> HixlTransportMetadata: + """Source side: register tensor memory and serialize descriptors. + + Called on the source actor immediately after the task creates the + result tensors. We: + 1. Synchronize the device to ensure data is written. + 2. Register each tensor's storage with HIXL (RegisterMem). + 3. Serialize the memory descriptions as pickle bytes. + 4. Return HixlTransportMetadata with the serialized descs, the + local engine ID, and the current meta version. + + Args: + obj_id: The object ID for the RDT object. + rdt_object: The RDT object (list of tensors). + + Returns: + HixlTransportMetadata containing serialized memory descriptions + and the local engine ID. + """ + import torch + + with self._cache_lock: + device = None + tensor_meta = [] + mem_descs_for_serialization = [] + + if rdt_object: + # All tensors must share the same device type, + # but we don't assume they're all on the same device index. + devices = set() + device = rdt_object[0].device + for t in rdt_object: + if t.device.type != device.type: + raise ValueError( + "All tensors in an RDT object must have the same " + "device type." + ) + if not t.is_contiguous(): + raise ValueError( + "All tensors in an RDT object must be contiguous." + ) + tensor_meta.append((t.shape, t.dtype)) + devices.add(t.device) + + if device.type == "npu": + # Synchronize before registration to assure the data has + # been written — HIXL does not guarantee this. + for dev in devices: + torch.npu.synchronize(dev) + + self._add_tensor_descs(rdt_object) + + # Build serialization payload: for each registered tensor, + # we pack (data_ptr, nbytes, mem_type_str). The receiver + # uses these to construct TransferOpDesc tuples. + for t in rdt_object: + key = t.untyped_storage().data_ptr() + desc = self._tensor_desc_cache[key] + mem_descs_for_serialization.append( + (key, desc.nbytes, desc.mem_type_str) + ) + + serialized_mem_descs = pickle.dumps(mem_descs_for_serialization) + engine_id = self._local_engine_id + engine_meta_version = self._hixl_engine_meta_version + else: + serialized_mem_descs = None + engine_id = None + engine_meta_version = None + + ret = HixlTransportMetadata( + tensor_meta=tensor_meta, + tensor_device=device.type if device else None, + hixl_serialized_mem_descs=serialized_mem_descs, + hixl_engine_id=engine_id, + hixl_engine_meta_version=engine_meta_version, + ) + self._put_meta(obj_id, ret) + return ret +``` + +### 4.10 `get_communicator_metadata` + +```python + def get_communicator_metadata( + self, + src_actor: "ray.actor.ActorHandle", + dst_actor: "ray.actor.ActorHandle", + backend: Optional[str] = None, + ) -> HixlCommunicatorMetadata: + """One-sided RDMA transport: no communicator metadata needed.""" + return HixlCommunicatorMetadata() +``` + +### 4.11 `fetch_multiple_tensors` + +```python + def fetch_multiple_tensors( + self, + obj_id: str, + tensor_transport_metadata: TensorTransportMetadata, + communicator_metadata: CommunicatorMetadata, + target_buffers: Optional[List["torch.Tensor"]] = None, + ) -> HixlFetchRequest: + """Receiver side: initiate an RDMA READ transfer. + + This triggers the transfer but does not wait for completion. Call + wait_fetch_complete(fetch_request) to retrieve the tensors. + + Steps: + 1. Allocate target tensors (or use provided buffers). + 2. Register target memory with HIXL. + 3. Deserialize the source memory descriptions from metadata. + 4. Connect to the remote HIXL engine (using engine_id from metadata). + 5. Build TransferOpDesc tuples: (local_addr, remote_addr, len). + 6. Call hixl_wrapper.transfer_async("READ", op_descs, remote_engine_id). + 7. Return HixlFetchRequest with the async transfer handle. + + Args: + obj_id: The object ID for the transfer. + tensor_transport_metadata: Source-side metadata containing + serialized memory descriptions and the remote engine ID. + communicator_metadata: Empty HixlCommunicatorMetadata. + target_buffers: Optional pre-allocated buffers to receive into. + + Returns: + HixlFetchRequest carrying the async transfer state. + """ + from ray.experimental.rdt.util import ( + create_empty_tensors_from_metadata, + ) + + tensors = target_buffers or create_empty_tensors_from_metadata( + tensor_transport_metadata + ) + + assert isinstance(tensor_transport_metadata, HixlTransportMetadata) + assert isinstance(communicator_metadata, HixlCommunicatorMetadata) + + serialized_mem_descs = tensor_transport_metadata.hixl_serialized_mem_descs + remote_engine_id = tensor_transport_metadata.hixl_engine_id + + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"HIXL transfer aborted for object id: {obj_id}" + ) + + transfer_req = None + added_tensor_descs = False + + assert tensors + + try: + self._ensure_hixl_initialized() + + # Register local target tensors with HIXL. + self._add_tensor_descs(tensors) + added_tensor_descs = True + + # Deserialize the source-side memory descriptions. + remote_mem_descs = pickle.loads(serialized_mem_descs) + + # Connect to the remote HIXL engine (or reuse cached connection). + remote_engine_meta_version = ( + tensor_transport_metadata.hixl_engine_meta_version + ) + + self._connect_remote_engine( + remote_engine_id, remote_engine_meta_version + ) + + # Build TransferOpDesc tuples for RDMA READ. + # For each tensor pair (local target, remote source): + # local_addr = target tensor's storage data_ptr + # remote_addr = source tensor's data_ptr (from deserialized mem desc) + # len = nbytes (must match; we validate this) + op_descs = [] + for i, t in enumerate(tensors): + remote_addr, remote_nbytes, _ = remote_mem_descs[i] + local_addr = t.untyped_storage().data_ptr() + local_nbytes = t.untyped_storage().nbytes() + if local_nbytes != remote_nbytes: + raise RuntimeError( + f"HIXL transfer size mismatch for tensor {i}: " + f"local={local_nbytes} bytes vs remote={remote_nbytes} bytes" + ) + op_descs.append((local_addr, remote_addr, remote_nbytes)) + + # Initiate async RDMA READ from remote engine. + status, transfer_req = hixl_wrapper.transfer_async( + remote_engine_id, "READ", op_descs + ) + + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL TransferAsync returned error status={status} " + f"for object id: {obj_id}" + ) + + return HixlFetchRequest( + obj_id=obj_id, + tensors=tensors, + transfer_req=transfer_req, + remote_engine_id=remote_engine_id, + remove_tensor_descs=added_tensor_descs, + transport=self, + ) + except Exception: + self._cleanup_transfer( + obj_id, tensors, transfer_req, remote_engine_id, + added_tensor_descs, + ) + # Import here to avoid circular dependency on startup. + from ray.exceptions import RayDirectTransportError + + raise RayDirectTransportError( + f"The HIXL transfer failed for object id: {obj_id}. " + f"The source actor may have died during the transfer. " + f"The exception thrown from HIXL transfer was:\n " + f"{traceback.format_exc()}" + ) from None +``` + +> **关键差异**:NIXL 的 `fetch_multiple_tensors` 流程是: +> 1. `nixl_agent.deserialize_descs(serialized_descs)` → 得到远端描述 +> 2. `_add_tensor_descs(tensors)` → 注册本地目标内存 +> 3. `nixl_agent.get_xfer_descs(tensors)` → 得到本地描述 +> 4. `nixl_agent.add_remote_agent(remote_meta)` → 添加远端 agent +> 5. `nixl_agent.initialize_xfer("READ", local, remote, remote_name, UUID)` → 初始化传输 +> 6. `nixl_agent.transfer(xfer_handle)` → 启动传输 +> +> HIXL 的流程是: +> 1. `pickle.loads(serialized_mem_descs)` → 得到远端 `(data_ptr, nbytes, mem_type_str)` 列表 +> 2. `_add_tensor_descs(tensors)` → 注册本地目标内存 +> 3. 手动构建 `[(local_addr, remote_addr, len)]` → 构造 TransferOpDesc +> 4. `_connect_remote_engine(remote_engine_id, version)` → 显式建链(替代 add_remote_agent) +> 5. `hixl_wrapper.transfer_async(remote_engine_id, "READ", op_descs)` → 启动传输(一步完成初始化+启动) + +### 4.12 `wait_fetch_complete` + +```python + def wait_fetch_complete( + self, fetch_request: FetchRequest, timeout: float = -1 + ) -> List["torch.Tensor"]: + """Wait for a previously initiated HIXL fetch to complete. + + Polls hixl_wrapper.get_transfer_status until the state is "COMPLETED", + "TIMEOUT", or "FAILED". Supports abort via _aborted_transfer_obj_ids. + + Args: + fetch_request: The HixlFetchRequest returned by + fetch_multiple_tensors. + timeout: Maximum time in seconds to wait. -1 means wait + indefinitely. 0 means return immediately if not ready. + + Returns: + List of tensors that were transferred. + + Raises: + RayDirectTransportError: If the transfer failed. + TimeoutError: If the timeout is exceeded. + """ + assert isinstance(fetch_request, HixlFetchRequest) + obj_id = fetch_request.obj_id + + if not fetch_request.tensors: + return fetch_request.tensors + + try: + # Poll transfer status until completion. + deadline = None if timeout < 0 else time.monotonic() + timeout + while True: + self._ensure_hixl_initialized() + status, transfer_status = hixl_wrapper.get_transfer_status( + fetch_request.transfer_req + ) + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL GetTransferStatus returned error status={status} " + f"for object id: {obj_id}" + ) + + if transfer_status == "FAILED": + raise RuntimeError( + f"HIXL transfer got FAILED state for object id: {obj_id}" + ) + if transfer_status == "TIMEOUT": + raise RuntimeError( + f"HIXL transfer got TIMEOUT state for object id: {obj_id}" + ) + if transfer_status == "WAITING": + if deadline is not None and time.monotonic() >= deadline: + raise TimeoutError( + f"HIXL transfer timed out after {timeout}s " + f"for object id: {obj_id}" + ) + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"HIXL transfer aborted for object id: {obj_id}" + ) + time.sleep(0.001) # Avoid busy waiting + elif transfer_status == "COMPLETED": + break + + return fetch_request.tensors + except TimeoutError: + raise + except Exception: + from ray.exceptions import RayDirectTransportError + + raise RayDirectTransportError( + f"The HIXL transfer failed for object id: {obj_id}. " + f"The source actor may have died during the transfer. " + f"The exception thrown from HIXL transfer was:\n " + f"{traceback.format_exc()}" + ) from None +``` + +### 4.13 `_cleanup_transfer` + +```python + def _cleanup_transfer( + self, + obj_id: str, + tensors: List["torch.Tensor"], + transfer_req: Any, + remote_engine_id: Optional[str], + remove_tensor_descs: bool, + ) -> None: + """Best-effort cleanup after a transfer completes or fails. + + We may encounter errors or HIXL may raise errors like connection + loss, so we do best-effort cleanup without raising further errors. + """ + if not self._hixl_initialized: + return + + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.discard(obj_id) + + # HIXL does not have an explicit release_xfer_handle API; + # the TransferReq is consumed by GetTransferStatus polling. + + # Evict remote engine from LRU cache if caching is disabled. + if HIXL_REMOTE_ENGINE_CACHE_MAXSIZE == 0 and remote_engine_id: + self._disconnect_remote_engine(remote_engine_id) + + if remove_tensor_descs: + self._remove_tensor_descs(tensors) +``` + +### 4.14 `recv_multiple_tensors` + +```python + def recv_multiple_tensors( + self, + obj_id: str, + tensor_transport_metadata: TensorTransportMetadata, + communicator_metadata: CommunicatorMetadata, + target_buffers: Optional[List["torch.Tensor"]] = None, + ) -> List["torch.Tensor"]: + """Receives multiple tensors synchronously (fetch + wait).""" + fetch_request = self.fetch_multiple_tensors( + obj_id, tensor_transport_metadata, communicator_metadata, + target_buffers, + ) + return self.wait_fetch_complete(fetch_request) +``` + +### 4.15 `send_multiple_tensors` + +```python + def send_multiple_tensors( + self, + tensors: List["torch.Tensor"], + tensor_transport_metadata: TensorTransportMetadata, + communicator_metadata: CommunicatorMetadata, + ): + """Not implemented — HIXL is a one-sided transport.""" + raise NotImplementedError( + "HIXL transport does not support send_multiple_tensors, " + "since it is a one-sided transport." + ) +``` + +### 4.16 `garbage_collect` + +```python + def garbage_collect( + self, + obj_id: str, + tensor_transport_meta: TensorTransportMetadata, + tensors: List["torch.Tensor"], + ): + """Release source-side resources for an RDT object. + + Called on the source actor after Ray's distributed ref counting + determines the object is out of scope. We: + 1. Pop the metadata from _managed_meta_hixl. + 2. Remove tensor descriptors (decrement ref count; deregister + when it reaches zero). + """ + with self._cache_lock: + assert isinstance(tensor_transport_meta, HixlTransportMetadata) + if obj_id not in self._managed_meta_hixl: + return + self._managed_meta_hixl.pop(obj_id, None) + self._remove_tensor_descs(tensors) +``` + +### 4.17 `abort_transport` + +```python + def abort_transport( + self, + obj_id: str, + communicator_metadata: CommunicatorMetadata, + ): + """Mark a transfer as aborted so wait_fetch_complete can exit.""" + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.add(obj_id) +``` + +### 4.18 `_add_tensor_descs` + +```python + def _add_tensor_descs(self, tensors: List["torch.Tensor"]): + """Register tensor memory with HIXL and bump reference counts. + + If a tensor's storage is already registered (keyed by data_ptr), we + only increment the metadata_count. Otherwise we call + hixl_wrapper.register_mem and cache the handle + registration params. + """ + self._ensure_hixl_initialized() + + with self._cache_lock: + for tensor in tensors: + key = tensor.untyped_storage().data_ptr() + if key in self._tensor_desc_cache: + self._tensor_desc_cache[key].metadata_count += 1 + continue + + # Determine memory type: NPU tensors → "npu", CPU → "cpu". + mem_type_str = "npu" if tensor.device.type == "npu" else "cpu" + + # Register the full underlying storage with HIXL. + addr = tensor.untyped_storage().data_ptr() + nbytes = tensor.untyped_storage().nbytes() + + try: + status, mem_handle = hixl_wrapper.register_mem( + (addr, nbytes), mem_type_str + ) + except Exception as e: + raise RuntimeError( + f"Failed to register {mem_type_str} memory with HIXL " + f"(addr=0x{addr:x}, size={nbytes} bytes). " + f"Common causes:\n" + f" - CANN driver/runtime not installed\n" + f" - RDMA device not available\n" + f" - HCCS link not established\n" + f" - Container privilege level too low" + ) from e + + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL RegisterMem returned error status={status} " + f"for {mem_type_str} memory " + f"(addr=0x{addr:x}, size={nbytes} bytes)" + ) + + self._tensor_desc_cache[key] = HixlTensorDesc( + mem_handle=mem_handle, + nbytes=nbytes, + mem_type_str=mem_type_str, + metadata_count=1, + ) +``` + +### 4.19 `_remove_tensor_descs` + +```python + def _remove_tensor_descs(self, tensors: List["torch.Tensor"]): + """Decrement reference counts and deregister when they reach zero. + + When metadata_count drops to zero we call hixl_wrapper.deregister_mem + with the cached MemHandle and bump _hixl_engine_meta_version. + """ + with self._cache_lock: + for tensor in tensors: + key = tensor.untyped_storage().data_ptr() + if key not in self._tensor_desc_cache: + continue + tensor_desc = self._tensor_desc_cache[key] + tensor_desc.metadata_count -= 1 + if tensor_desc.metadata_count == 0: + self._tensor_desc_cache.pop(key) + try: + status = hixl_wrapper.deregister_mem( + tensor_desc.mem_handle + ) + if status != hixl_wrapper.kSuccess: + logger.warning( + f"HIXL DeregisterMem returned status={status} " + f"for handle={tensor_desc.mem_handle}" + ) + except Exception: + logger.warning( + f"HIXL DeregisterMem raised exception for " + f"handle={tensor_desc.mem_handle}", + exc_info=True, + ) + self._hixl_engine_meta_version += 1 +``` + +### 4.20 `_tensor_memory_registered` + +```python + def _tensor_memory_registered(self, t: "torch.Tensor") -> bool: + """Check if the tensor's memory has been registered with HIXL.""" + entry = self._tensor_desc_cache.get(t.untyped_storage().data_ptr()) + return entry is not None +``` + +> **注意**:NIXL 的 `_tensor_memory_registered` 检查 `entry is not None and entry.reg_desc is not None`(区分 pool-managed 和传统注册)。HIXL 没有 pool 概念(初始实现不含 MemoryPool),所以只要 `entry is not None` 就足够了。如果后续加 MemoryPool,需要像 NIXL 那样检查 `mem_handle is not None`。 + +### 4.21 `_get_meta` / `_put_meta` + +```python + def _get_num_managed_meta_hixl(self) -> int: + """Return the number of tracked HixlTransportMetadata objects.""" + with self._cache_lock: + return len(self._managed_meta_hixl) + + def _get_meta(self, object_id: str) -> Optional[HixlTransportMetadata]: + """Get the HIXL transport metadata for the given object ID.""" + with self._cache_lock: + if object_id in self._managed_meta_hixl: + return self._managed_meta_hixl[object_id] + return None + + def _put_meta(self, object_id: str, meta: HixlTransportMetadata): + """Store the HIXL transport metadata for the given object ID.""" + with self._cache_lock: + self._managed_meta_hixl[object_id] = meta +``` + +### 4.22 远端引擎连接管理(LRU 缓存) + +```python + def _connect_remote_engine( + self, remote_engine_id: str, remote_engine_meta_version: int + ) -> None: + """Connect to a remote HIXL engine, with LRU caching. + + Mirrors NixlTensorTransport's _remote_agents logic: + - If the remote engine is already cached and the meta version + matches, we reuse the connection (move to end of LRU). + - If the meta version differs (source deregistered memory), we + disconnect first and reconnect. + - If the cache is full, evict the least recently used engine. + # 情况 1:已在缓存 + 版号一致 → 复用连接,return,不 connect + # 情况 2:已在缓存 + 版号不一致 → 断开 + 重连 + 存缓存 + # 情况 3:不在缓存 + 缓存未满 → connect + 存缓存 + # 情况 4:不在缓存 + 缓存已满 → 淘汰最旧 + connect + 存缓存 + # ===== else 分支(缓存关闭)===== + # 情况只有一种:直接 connect,不查缓存,不存缓存 + """ + if HIXL_REMOTE_ENGINE_CACHE_MAXSIZE > 0: + if remote_engine_id in self._remote_engines: + cached_version = self._remote_engines[remote_engine_id] + if cached_version != remote_engine_meta_version: + # Source deregistered memory — stale descriptors. + # Disconnect before reconnecting. + self._disconnect_remote_engine(remote_engine_id) + else: + # Reuse cached connection; move to end of LRU. + self._remote_engines.move_to_end(remote_engine_id) + return + + elif len(self._remote_engines) >= HIXL_REMOTE_ENGINE_CACHE_MAXSIZE: + # Evict least recently used remote engine. + evicted_engine_id, _ = self._remote_engines.popitem(last=False) + self._disconnect_remote_engine(evicted_engine_id) + + # Establish new connection. + status = hixl_wrapper.connect(remote_engine_id) + if status != hixl_wrapper.kSuccess and status != hixl_wrapper.kAlreadyConnected: + raise RuntimeError( + f"HIXL Connect to '{remote_engine_id}' failed, " + f"status={status}" + ) + + self._remote_engines[remote_engine_id] = remote_engine_meta_version + else: + # No caching — connect fresh each time. + status = hixl_wrapper.connect(remote_engine_id) + if status != hixl_wrapper.kSuccess and status != hixl_wrapper.kAlreadyConnected: + raise RuntimeError( + f"HIXL Connect to '{remote_engine_id}' failed, " + f"status={status}" + ) + + def _disconnect_remote_engine(self, remote_engine_id: str) -> None: + """Disconnect from a remote HIXL engine (best-effort).""" + try: + hixl_wrapper.disconnect(remote_engine_id) + except Exception: + logger.warning( + f"HIXL Disconnect from '{remote_engine_id}' raised exception", + exc_info=True, + ) +``` + +## 6. 导入和模块结构 + +```python +import logging +import pickle +import threading +import time +import traceback +import uuid +from collections import OrderedDict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import ray +from ray.experimental.rdt.tensor_transport_manager import ( + CommunicatorMetadata, + FetchRequest, + TensorTransportManager, + TensorTransportMetadata, +) + +if TYPE_CHECKING: + import torch + +logger = logging.getLogger(__name__) + +# Lazy import: hixl_wrapper may not be installed on all nodes. +try: + import hixl_wrapper +except ImportError: + hixl_wrapper = None + +# Maximum number of cached HIXL remote engine connections. +# When exceeded, the least recently used remote engine is evicted and +# Disconnect is called. Set to 0 to disable remote engine reuse. +HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = 1000 +``` + +> **导入差异**:NIXL 从 `ray._private.ray_constants` 导入 `NIXL_REMOTE_AGENT_CACHE_MAXSIZE`。HIXL 在模块内定义 `HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = 1000`,因为这是 ray-ascend 的自定义常量,不应侵入 Ray 主仓库。后续如果需要,可以移到 `ray_ascend` 的配置模块中。 + +--- + +## 7. 注册方式 + +在 `ray-ascend/ray_ascend/__init__.py` 中已有 `register_hixl_tensor_transport` 函数(已经实现,只需确保导入路径正确): + +```python +def register_hixl_tensor_transport(devices: List[str] = ["npu", "cpu"]) -> None: + import torch + from ray.experimental import register_tensor_transport + from ray_ascend.direct_transport.hixl_tensor_transport import HixlTensorTransport + + # Verify hixl_wrapper is importable before registration. + try: + import hixl_wrapper + except ImportError as e: + raise ImportError( + "hixl_wrapper module not found. HIXL tensor transport requires " + "the HIXL Engine wheel. Please install: " + "pip install hixl_engine-0.0.1-py3-none-any.whl" + ) from e + + register_tensor_transport("HIXL", devices, HixlTensorTransport, torch.Tensor) +``` \ No newline at end of file diff --git a/docs/hixl-wrapper-bindings-plan.md b/docs/hixl-wrapper-bindings-plan.md new file mode 100644 index 0000000..7574f68 --- /dev/null +++ b/docs/hixl-wrapper-bindings-plan.md @@ -0,0 +1,878 @@ +# HIXL Engine Python 绑定实现计划 + +| 项目 | 内容 | +|---|---| +| 模块名 | `hixl_wrapper` | +| 目标 | 为 `hixl::Hixl` 类创建 pybind11 模块,供 Ray RDT 传输层调用 HIXL RDMA 传输能力 | +| 参考实现 | `hixl/src/python/llm_wrapper/`(LLM-DataDist Python 绑定) | +| 打包方式 | 独立 wheel(`hixl_engine-0.0.1-py3-none-any.whl`) | + +--- + +## 1. 设计背景 + +### 1.1 架构层次 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ LLM-DataDist(最高层) │ +│ 类: LlmDataDist / LLMDataDistV2 │ +│ 语义: KV Cache(RegisterCache、PullCache、TransferCache) │ +│ pybind: llm_datadist_wrapper ← 已存在,不需要改动 │ +├─────────────────────────────────────────────────────────────┤ +│ HIXL Engine(中间层) │ +│ 类: hixl::Hixl │ +│ 语义: 通用 RDMA 传输(RegisterMem、Connect、TransferSync) │ +│ pybind: hixl_wrapper ← 本计划要创建的 │ +├─────────────────────────────────────────────────────────────┤ +│ ADXL(最底层) │ +│ 类: adxl::AdxlInnerEngine │ +│ 语义: RDMA/HCCS 硬件级传输 │ +│ pybind: 无,不需要 │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 1.2 与参考实现(llm_wrapper)的关键差异 + +本设计严格遵循 `llm_wrapper` 的代码风格,但 HIXL Engine API 与 LLM-DataDist API 存在以下差异,需要特别注意: + +| 差异点 | LLM-DataDist (llm_wrapper) | HIXL Engine (hixl_wrapper) | 处理方式 | +|---|---|---|---| +| 单例生命周期 | `Init(cluster_id, options)` → 创建并初始化 | `Initialize(local_engine, options)` → 创建并初始化 | 参数类型不同(`uint64_t cluster_id` vs `std::string local_engine`),但模式相同 | +| 资源清理 | `Finalize()` 返回 `void` | `Finalize()` 返回 `void` | **完全一致**,Python 端无返回值 | +| 输出参数数量 | 多个方法有 2 个输出参数(如 `RegisterCache` 返回 `Status + Cache`) | 大部分方法只有 0 或 1 个输出参数 | 仅 `RegisterMem`、`TransferAsync`、`GetTransferStatus`、`GetNotifies` 有输出参数 | +| `void*` handle | LLM-DataDist 不暴露 `void*` 到 Python | `MemHandle`/`TransferReq` 是 `void*`,需要 `uintptr_t` 桥接 | **新增类型映射**:`void*` ↔ `uintptr_t` ↔ Python `int` | +| 枚举类型 | LLM-DataDist 使用 `ge::Status` 等已有类型 | HIXL 使用自有枚举 `MemType`、`TransferOp`、`TransferStatus` | **新增 str ↔ enum 转换** | +| struct 拆包 | `CacheDesc`、`CacheKey` 等复杂 struct | `MemDesc`、`TransferOpDesc`、`NotifyDesc` 较简单 struct | Tuple 别名更简单(2-3 元素 vs 7-10 元素) | +| CMake 链接目标 | `llm_datadist` | `cann_hixl` | 链接目标不同 | +| Wheel 打包 | 合并到 `llm_datadist` wheel | **独立 wheel** | 新增独立打包流程 | + +--- + +## 2. 文件清单 + +| 文件路径 | 作用 | +|---|---| +| `hixl/src/python/hixl_wrapper/hixl_engine_wrapper.h` | Wrapper 类声明(Tuple 类型别名 + 静态方法声明) | +| `hixl/src/python/hixl_wrapper/hixl_engine_wrapper.cc` | Wrapper 类实现(tuple↔C++ 转换 + 委托调用) | +| `hixl/src/python/hixl_wrapper/hixl_wrapper.cc` | pybind11 模块入口(注册函数和常量) | +| `hixl/src/python/hixl_wrapper/CMakeLists.txt` | 构建配置 | +| `hixl/src/python/hixl_engine/CMakeLists.txt` | 独立 wheel 打包配置 | +| `hixl/src/python/hixl_engine/setup.py` | wheel 打包脚本 | +| `hixl/src/python/hixl_engine/MANIFEST.in` | wheel 包含规则 | +| `hixl/src/python/hixl_engine/hixl_engine/__init__.py` | Python 包入口 | +| `hixl/src/python/CMakeLists.txt` | 新增 `add_subdirectory(hixl_wrapper)` 和 `add_subdirectory(hixl_engine)` | + +--- + +## 3. C++ API 方法签名交叉验证 + +逐方法对比 `hixl::Hixl` C++ API(`include/hixl/hixl.h`)与 Wrapper 签名,确认映射正确性: + +| # | C++ 方法 | C++ 签名 | Wrapper 签名 | 输出参数处理 | Python 返回值 | 验证结果 | +|---|---|---|---|---|---|---| +| 1 | `Initialize` | `Status Initialize(const AscendString &local_engine, const std::map &options)` | `Status Initialize(const std::string &local_engine, const std::map &options)` | 无输出参数 | `int` (status) | ✅ `AscendString` → `std::string`,`std::map` 直接映射 | +| 2 | `Finalize` | `void Finalize()` | `void Finalize()` | 无输出参数 | `None` | ✅ **与 C++ API 一致**,Python 端无返回值 | +| 3 | `RegisterMem` | `Status RegisterMem(const MemDesc &mem, MemType type, MemHandle &mem_handle)` | `std::pair RegisterMem(const MemDescTuple &mem_desc, const std::string &mem_type)` | `MemHandle &mem_handle` → 第二个返回值 | `(int, int)` | ✅ `MemDesc` → `MemDescTuple`,`MemType` → `str`,`MemHandle` → `uintptr_t` | +| 4 | `DeregisterMem` | `Status DeregisterMem(MemHandle mem_handle)` | `Status DeregisterMem(uintptr_t mem_handle)` | 无输出参数 | `int` (status) | ✅ `MemHandle` → `uintptr_t` | +| 5 | `Connect` | `Status Connect(const AscendString &remote_engine, int32_t timeout_in_millis = 1000)` | `Status Connect(const std::string &remote_engine, int32_t timeout_ms = 1000)` | 无输出参数 | `int` (status) | ✅ 默认超时 1000ms 已保留 | +| 6 | `Disconnect` | `Status Disconnect(const AscendString &remote_engine, int32_t timeout_in_millis = 1000)` | `Status Disconnect(const std::string &remote_engine, int32_t timeout_ms = 1000)` | 无输出参数 | `int` (status) | ✅ 默认超时 1000ms 已保留 | +| 7 | `TransferSync` | `Status TransferSync(const AscendString &remote_engine, TransferOp operation, const std::vector &op_descs, int32_t timeout_in_millis = 1000)` | `Status TransferSync(const std::string &remote_engine, const std::string &operation, const std::vector &op_descs, int32_t timeout_ms = 1000)` | 无输出参数 | `int` (status) | ✅ `TransferOp` → `str`,默认超时 1000ms 已保留 | +| 8 | `TransferAsync` | `Status TransferAsync(const AscendString &remote_engine, TransferOp operation, const std::vector &op_descs, const TransferArgs &optional_args, TransferReq &req)` | `std::pair TransferAsync(const std::string &remote_engine, const std::string &operation, const std::vector &op_descs)` | `TransferReq &req` → 第二个返回值;`TransferArgs` 内部构造 | `(int, int)` | ✅ `TransferArgs` 不暴露给 Python(reserved 字段默认为 0) | +| 9 | `GetTransferStatus` | `Status GetTransferStatus(const TransferReq &req, TransferStatus &status)` | `std::pair GetTransferStatus(uintptr_t req_id)` | `TransferStatus &status` → 第二个返回值(转为 str) | `(int, str)` | ✅ `TransferReq` → `uintptr_t`,`TransferStatus` → `str` | +| 10 | `SendNotify` | `Status SendNotify(const AscendString &remote_engine, const NotifyDesc ¬ify, int32_t timeout_in_millis = 1000)` | `Status SendNotify(const std::string &remote_engine, const NotifyDescTuple ¬ify, int32_t timeout_ms = 1000)` | 无输出参数 | `int` (status) | ✅ `NotifyDesc` → `NotifyDescTuple`,默认超时 1000ms 已保留 | +| 11 | `GetNotifies` | `Status GetNotifies(std::vector ¬ifies)` | `std::pair> GetNotifies()` | `std::vector ¬ifies` → 第二个返回值(转为 tuple 列表) | `(int, list[tuple])` | ✅ | + +**原设计文档发现的签名问题(已修正):** + +| 问题 | 原设计 | 修正 | 原因 | +|---|---|---|---| +| Finalize 返回值 | `std::tuple` → Python `(status,)` | `void` → Python `None` | C++ API 返回 `void`,与 llm_wrapper 参考实现一致 | +| 单返回值方法风格 | 全部用 `std::tuple` | 单返回值直接返回 `Status`;多返回值用 `std::pair<>` | 与 llm_wrapper 一致(`UnregisterCache` 返回 `ge::Status`,`RegisterCache` 返回 `std::pair<>`) | +| ParseMemType/ParseTransferOp | 非法字符串静默返回默认值 | 返回 `PARAM_INVALID` + ALOG 警告 | 避免掩盖调用错误 | +| GetNotifies 中 AscendString 转换 | 使用 `GetData()` | 使用 `GetString()` | `GetData()` 在本仓库的 stub 中不存在;`GetString()` 有 null 安全保证(返回 `""` 而不是 `nullptr`) | + +--- + +## 4. 类型映射规则 + +| C++ 类型 | Python 类型 | C++ → Python | Python → C++ | 备注 | +|---|---|---|---|---| +| `AscendString` | `str` | `GetString()` → `std::string` | `std::string.c_str()` → `AscendString()` | 使用 `GetString()` 而非 `GetData()`(null 安全) | +| `MemHandle` (`void*`) | `int` | `reinterpret_cast(handle)` | `reinterpret_cast(uintptr_t)` | Python 端持有的是地址整数,**底层释放后 Python 端的 int 变为悬空指针,调用方需自行管理生命周期** | +| `TransferReq` (`void*`) | `int` | `reinterpret_cast(req)` | `reinterpret_cast(uintptr_t)` | 同 MemHandle,悬空指针风险 | +| `Status` (`uint32_t`) | `int` | 直接映射 | 直接映射 | | +| `MemType` (enum) | `str` | `"npu"` / `"cpu"` | `ParseMemType(str)` → enum | 非法字符串返回 `PARAM_INVALID` | +| `TransferOp` (enum) | `str` | `"READ"` / `"WRITE"` | `ParseTransferOp(str)` → enum | 非法字符串返回 `PARAM_INVALID` | +| `TransferStatus` (enum class) | `str` | `TransferStatusToStr()` | 不需要反向转换 | `"WAITING"` / `"COMPLETED"` / `"TIMEOUT"` / `"FAILED"` | +| `MemDesc` (struct) | `tuple(int, int)` | `UnpackMemDesc()` | `UnpackMemDesc()` | `(addr, len)` | +| `TransferOpDesc` (struct) | `tuple(int, int, int)` | `UnpackTransferOpDescs()` | `UnpackTransferOpDescs()` | `(local_addr, remote_addr, len)` | +| `NotifyDesc` (struct) | `tuple(str, str)` | `GetString()` → `std::string` | `UnpackNotifyDesc()` | `(name, msg)` | +| `TransferArgs` (struct) | 不暴露 | 内部构造 `{}` | — | `reserved[128] = {}` 默认为 0 | + +--- + +## 5. 返回值设计 + +与 llm_wrapper 保持一致的风格: +- **单返回值方法**:直接返回 `Status`(Python 端收到 `int`) +- **多返回值方法**:返回 `std::pair<>`(Python 端收到 `tuple`) + +| 方法 | Wrapper 返回类型 | Python 返回值 | 示例 | +|---|---|---|---| +| `Initialize` | `Status` | `int` | `status = hixl_wrapper.initialize(...)` | +| `Finalize` | `void` | `None` | `hixl_wrapper.finalize()` | +| `RegisterMem` | `std::pair` | `(int, int)` | `status, handle = hixl_wrapper.register_mem(...)` | +| `DeregisterMem` | `Status` | `int` | `status = hixl_wrapper.deregister_mem(handle)` | +| `Connect` | `Status` | `int` | `status = hixl_wrapper.connect(...)` | +| `Disconnect` | `Status` | `int` | `status = hixl_wrapper.disconnect(...)` | +| `TransferSync` | `Status` | `int` | `status = hixl_wrapper.transfer_sync(...)` | +| `TransferAsync` | `std::pair` | `(int, int)` | `status, req_id = hixl_wrapper.transfer_async(...)` | +| `GetTransferStatus` | `std::pair` | `(int, str)` | `status, ts = hixl_wrapper.get_transfer_status(req_id)` | +| `SendNotify` | `Status` | `int` | `status = hixl_wrapper.send_notify(...)` | +| `GetNotifies` | `std::pair>` | `(int, list[tuple])` | `status, notifies = hixl_wrapper.get_notifies()` | + +--- + +## 6. `hixl_engine_wrapper.h` — Wrapper 类声明 + +```cpp +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * 请参阅 License 获取详细信息。 + */ + +#ifndef CANN_HIXL_PYTHON_HIXL_WRAPPER_HIXL_ENGINE_WRAPPER_H_ +#define CANN_HIXL_PYTHON_HIXL_WRAPPER_HIXL_ENGINE_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "hixl/hixl.h" +#include "hixl/hixl_types.h" + +namespace hixl_wrapper { + +// Tuple 类型别名:Python ↔ C++ 的桥梁类型 +using MemDescTuple = std::tuple; // (addr, length) +using TransferOpDescTuple = std::tuple; // (local_addr, remote_addr, len) +using NotifyDescTuple = std::tuple; // (name, msg) + +class HixlEngineWrapper { + public: + // 拆包方法:Python tuple → C++ struct + static hixl::MemDesc UnpackMemDesc(const MemDescTuple &mem_desc_tuple); + static std::vector UnpackTransferOpDescs( + const std::vector &op_desc_tuples); + static hixl::NotifyDesc UnpackNotifyDesc(const NotifyDescTuple ¬ify_tuple); + + // 字符串 ↔ 枚举转换(非法输入返回 PARAM_INVALID) + static std::pair ParseMemType(const std::string &mem_type_str); + static std::pair ParseTransferOp(const std::string &op_str); + static std::string TransferStatusToStr(hixl::TransferStatus status); + + // 业务方法(全部 static,与 llm_wrapper 风格一致) + static hixl::Status Initialize(const std::string &local_engine, + const std::map &options); + static void Finalize(); + static std::pair RegisterMem(const MemDescTuple &mem_desc, + const std::string &mem_type); + static hixl::Status DeregisterMem(uintptr_t mem_handle); + static hixl::Status Connect(const std::string &remote_engine, int32_t timeout_ms = 1000); + static hixl::Status Disconnect(const std::string &remote_engine, int32_t timeout_ms = 1000); + static hixl::Status TransferSync(const std::string &remote_engine, + const std::string &operation, + const std::vector &op_descs, + int32_t timeout_ms = 1000); + static std::pair TransferAsync(const std::string &remote_engine, + const std::string &operation, + const std::vector &op_descs); + static std::pair GetTransferStatus(uintptr_t req_id); + static hixl::Status SendNotify(const std::string &remote_engine, + const NotifyDescTuple ¬ify, + int32_t timeout_ms = 1000); + static std::pair> GetNotifies(); + + private: + static std::unique_ptr hixl_engine_; +}; + +} // namespace hixl_wrapper + +#endif // CANN_HIXL_PYTHON_HIXL_WRAPPER_HIXL_ENGINE_WRAPPER_H_ +``` + +### 设计要点 + +- 所有方法都是 `static`——Python 端不需要创建实例,直接调用裸函数(与 llm_wrapper 一致) +- `hixl_engine_` 是 `static unique_ptr` 单例——`Initialize` 创建,`Finalize` 销毁(与 llm_wrapper 一致) +- `MemHandle`/`TransferReq`(都是 `void*`)用 `uintptr_t` 传递给 Python(Python `int`),调用方需自行管理生命周期,避免悬空指针 +- 枚举类型在 Python 端用 `str` 表示("npu"/"cpu","READ"/"WRITE","WAITING"/"COMPLETED" 等) +- `ParseMemType`/`ParseTransferOp` 返回 `std::pair`——非法字符串返回 `PARAM_INVALID` + ALOG 警告,不再静默吞错 +- 带默认超时参数的方法(`Connect`/`Disconnect`/`TransferSync`/`SendNotify`)默认值为 1000ms,与 C++ API 一致 +- `Finalize` 返回 `void`——与 C++ API 一致,Python 端无返回值 +- 单返回值方法直接返回 `Status`(Python `int`),多返回值方法返回 `std::pair<>`(Python `tuple`)——与 llm_wrapper 一致 + +--- + +## 7. `hixl_engine_wrapper.cc` — Wrapper 类实现 + +### 7.1 生命周期管理 + +```cpp +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * 请参阅 License 获取详细信息。 + */ + +#include "hixl_engine_wrapper.h" +#include "alog/alog.h" + +namespace hixl_wrapper { + +std::unique_ptr HixlEngineWrapper::hixl_engine_; + +hixl::Status HixlEngineWrapper::Initialize( + const std::string &local_engine, + const std::map &options) { + if (hixl_engine_ != nullptr) { + ALOG_WARN("HixlEngineWrapper::Initialize: engine already initialized, repeat init"); + return hixl::PARAM_INVALID; + } + hixl_engine_ = std::make_unique(); + hixl::AscendString ascend_local_engine(local_engine.c_str()); + std::map ascend_options; + for (const auto &opt : options) { + ascend_options.emplace(hixl::AscendString(opt.first.c_str()), + hixl::AscendString(opt.second.c_str())); + } + auto ret = hixl_engine_->Initialize(ascend_local_engine, ascend_options); + if (ret != hixl::SUCCESS) { + ALOG_ERROR("HixlEngineWrapper::Initialize: failed, ret=%u", ret); + hixl_engine_.reset(); + } + return ret; +} + +void HixlEngineWrapper::Finalize() { + if (hixl_engine_ != nullptr) { + hixl_engine_->Finalize(); + hixl_engine_.reset(); + } +} +``` + +> **注意**:`Finalize` 返回 `void`,与 C++ API(`void Hixl::Finalize()`)和 llm_wrapper 参考实现(`void LLMDataDistV2Wrapper::Finalize()`)完全一致。Python 端调用 `hixl_wrapper.finalize()` 无返回值。 + +### 7.2 字符串 ↔ 枚举转换 + +```cpp +std::pair HixlEngineWrapper::ParseMemType(const std::string &mem_type_str) { + if (mem_type_str == "npu") return {hixl::SUCCESS, hixl::MEM_DEVICE}; + if (mem_type_str == "cpu") return {hixl::SUCCESS, hixl::MEM_HOST}; + ALOG_WARN("HixlEngineWrapper::ParseMemType: invalid mem_type '%s', expected 'npu' or 'cpu'", + mem_type_str.c_str()); + return {hixl::PARAM_INVALID, hixl::MEM_DEVICE}; +} + +std::pair HixlEngineWrapper::ParseTransferOp(const std::string &op_str) { + if (op_str == "READ") return {hixl::SUCCESS, hixl::READ}; + if (op_str == "WRITE") return {hixl::SUCCESS, hixl::WRITE}; + ALOG_WARN("HixlEngineWrapper::ParseTransferOp: invalid operation '%s', expected 'READ' or 'WRITE'", + op_str.c_str()); + return {hixl::PARAM_INVALID, hixl::READ}; +} + +std::string HixlEngineWrapper::TransferStatusToStr(hixl::TransferStatus status) { + switch (status) { + case hixl::TransferStatus::WAITING: return "WAITING"; + case hixl::TransferStatus::COMPLETED: return "COMPLETED"; + case hixl::TransferStatus::TIMEOUT: return "TIMEOUT"; + case hixl::TransferStatus::FAILED: return "FAILED"; + default: return "UNKNOWN"; + } +} +``` + +> **设计决策**:`ParseMemType`/`ParseTransferOp` 返回 `std::pair`,非法字符串返回 `PARAM_INVALID` + ALOG 警告。调用方需检查 Status 后再使用 enum 值。这避免了原设计中非法字符串被静默吞错的问题。 + +### 7.3 拆包方法 + +```cpp +hixl::MemDesc HixlEngineWrapper::UnpackMemDesc(const MemDescTuple &t) { + hixl::MemDesc mem_desc{}; + mem_desc.addr = std::get<0>(t); + mem_desc.len = std::get<1>(t); + // reserved[128] 默认值为 0,MemDesc 定义已有 = {} + return mem_desc; +} + +std::vector HixlEngineWrapper::UnpackTransferOpDescs( + const std::vector &op_desc_tuples) { + std::vector op_descs; + op_descs.reserve(op_desc_tuples.size()); + for (const auto &t : op_desc_tuples) { + hixl::TransferOpDesc desc{}; + desc.local_addr = std::get<0>(t); + desc.remote_addr = std::get<1>(t); + desc.len = std::get<2>(t); + op_descs.emplace_back(desc); + } + return op_descs; +} + +hixl::NotifyDesc HixlEngineWrapper::UnpackNotifyDesc(const NotifyDescTuple &t) { + hixl::NotifyDesc notify{}; + notify.name = hixl::AscendString(std::get<0>(t).c_str()); + notify.notify_msg = hixl::AscendString(std::get<1>(t).c_str()); + return notify; +} +``` + +### 7.4 RegisterMem — 有输出参数的方法 + +```cpp +std::pair HixlEngineWrapper::RegisterMem( + const MemDescTuple &mem_desc_tuple, const std::string &mem_type_str) { + hixl::MemHandle handle = nullptr; + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + auto [parse_status, mem_type] = ParseMemType(mem_type_str); + if (parse_status != hixl::SUCCESS) { + return {parse_status, reinterpret_cast(handle)}; + } + auto mem_desc = UnpackMemDesc(mem_desc_tuple); + ret = hixl_engine_->RegisterMem(mem_desc, mem_type, handle); + } else { + ALOG_WARN("HixlEngineWrapper::RegisterMem: engine not initialized"); + } + return {ret, reinterpret_cast(handle)}; +} +``` + +### 7.5 DeregisterMem + +```cpp +hixl::Status HixlEngineWrapper::DeregisterMem(uintptr_t mem_handle) { + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + hixl::MemHandle handle = reinterpret_cast(mem_handle); + ret = hixl_engine_->DeregisterMem(handle); + } else { + ALOG_WARN("HixlEngineWrapper::DeregisterMem: engine not initialized"); + } + return ret; +} +``` + +### 7.6 Connect + +```cpp +hixl::Status HixlEngineWrapper::Connect(const std::string &remote_engine, int32_t timeout_ms) { + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + hixl::AscendString ascend_remote(remote_engine.c_str()); + ret = hixl_engine_->Connect(ascend_remote, timeout_ms); + } else { + ALOG_WARN("HixlEngineWrapper::Connect: engine not initialized"); + } + return ret; +} +``` + +> **默认超时**:`timeout_ms` 默认值为 1000ms,与 C++ API `int32_t timeout_in_millis = 1000` 一致。Python 端调用 `hixl_wrapper.connect("IP:PORT")` 不传 timeout 时使用 1000ms。 + +### 7.7 Disconnect + +```cpp +hixl::Status HixlEngineWrapper::Disconnect(const std::string &remote_engine, int32_t timeout_ms) { + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + hixl::AscendString ascend_remote(remote_engine.c_str()); + ret = hixl_engine_->Disconnect(ascend_remote, timeout_ms); + } else { + ALOG_WARN("HixlEngineWrapper::Disconnect: engine not initialized"); + } + return ret; +} +``` + +### 7.8 TransferSync + +```cpp +hixl::Status HixlEngineWrapper::TransferSync( + const std::string &remote_engine, + const std::string &operation, + const std::vector &op_desc_tuples, + int32_t timeout_ms) { + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + auto [parse_status, op] = ParseTransferOp(operation); + if (parse_status != hixl::SUCCESS) { + return parse_status; + } + hixl::AscendString ascend_remote(remote_engine.c_str()); + auto op_descs = UnpackTransferOpDescs(op_desc_tuples); + ret = hixl_engine_->TransferSync(ascend_remote, op, op_descs, timeout_ms); + } else { + ALOG_WARN("HixlEngineWrapper::TransferSync: engine not initialized"); + } + return ret; +} +``` + +### 7.9 TransferAsync — 有输出参数的方法 + +```cpp +std::pair HixlEngineWrapper::TransferAsync( + const std::string &remote_engine, + const std::string &operation, + const std::vector &op_desc_tuples) { + hixl::TransferReq req = nullptr; + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + auto [parse_status, op] = ParseTransferOp(operation); + if (parse_status != hixl::SUCCESS) { + return {parse_status, reinterpret_cast(req)}; + } + hixl::AscendString ascend_remote(remote_engine.c_str()); + auto op_descs = UnpackTransferOpDescs(op_desc_tuples); + hixl::TransferArgs args{}; // reserved[128] 默认为 0 + ret = hixl_engine_->TransferAsync(ascend_remote, op, op_descs, args, req); + } else { + ALOG_WARN("HixlEngineWrapper::TransferAsync: engine not initialized"); + } + return {ret, reinterpret_cast(req)}; +} +``` + +> **悬空指针风险**:`TransferAsync` 返回的 `req_id`(`uintptr_t`)指向底层引擎的内部数据。如果引擎被 `Finalize` 销毁后再用此 `req_id` 调用 `GetTransferStatus`,会导致悬空指针访问。**调用方需确保在 Finalize 前完成所有异步传输查询。** + +### 7.10 GetTransferStatus — 有输出参数的方法 + +```cpp +std::pair HixlEngineWrapper::GetTransferStatus(uintptr_t req_id) { + hixl::TransferStatus status = hixl::TransferStatus::FAILED; + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + hixl::TransferReq req = reinterpret_cast(req_id); + ret = hixl_engine_->GetTransferStatus(req, status); + } else { + ALOG_WARN("HixlEngineWrapper::GetTransferStatus: engine not initialized"); + } + return {ret, TransferStatusToStr(status)}; +} +``` + +### 7.11 SendNotify + +```cpp +hixl::Status HixlEngineWrapper::SendNotify( + const std::string &remote_engine, + const NotifyDescTuple ¬ify_tuple, + int32_t timeout_ms) { + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + hixl::AscendString ascend_remote(remote_engine.c_str()); + auto notify = UnpackNotifyDesc(notify_tuple); + ret = hixl_engine_->SendNotify(ascend_remote, notify, timeout_ms); + } else { + ALOG_WARN("HixlEngineWrapper::SendNotify: engine not initialized"); + } + return ret; +} +``` + +### 7.12 GetNotifies — 有输出参数的方法 + +```cpp +std::pair> HixlEngineWrapper::GetNotifies() { + std::vector notifies; + hixl::Status ret = hixl::FAILED; + if (hixl_engine_ != nullptr) { + ret = hixl_engine_->GetNotifies(notifies); + } else { + ALOG_WARN("HixlEngineWrapper::GetNotifies: engine not initialized"); + } + std::vector notify_tuples; + for (const auto &n : notifies) { + // 使用 GetString() 而非 GetData(): + // GetString() 有 null 安全保证(name_ 为 nullptr 时返回 ""),GetData() 可能返回 nullptr 导致 crash + std::string name(n.name.GetString()); + std::string msg(n.notify_msg.GetString()); + notify_tuples.emplace_back(std::make_tuple(name, msg)); + } + return {ret, notify_tuples}; +} + +} // namespace hixl_wrapper +``` + +> **设计决策**:使用 `GetString()` 而非 `GetData()`。原因: +> 1. 本仓库的 AscendString stub 实现中只有 `GetString()` 方法,没有 `GetData()` 方法 +> 2. `GetString()` 有 null 安全保证:当内部 `name_` 为 nullptr 时返回静态空字符串 `""`,而 `GetData()`(如果 CANN SDK 中存在)可能返回 `nullptr`,传给 `std::string(nullptr)` 会 crash + +--- + +## 8. `hixl_wrapper.cc` — pybind11 模块入口 + +```cpp +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * 请参阅 License 获取详细信息。 + */ + +#include "Python.h" +#ifdef ASCEND_CI_LIMITED_PY37 +#undef PyCFunction_NewEx +#endif + +#include +#include +#include +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "hixl/hixl.h" +#include "hixl/hixl_types.h" +#include "hixl_engine_wrapper.h" + +#undef PYBIND11_CHECK_PYTHON_VERSION +#define PYBIND11_CHECK_PYTHON_VERSION + +namespace hixl_wrapper { +namespace { +namespace py = pybind11; + +void BindStatusCodes(py::module &m) { + // Status 状态码(与 llm_wrapper 的 BindStatusCodes 风格一致) + m.attr("kSuccess") = py::int_(hixl::SUCCESS); + m.attr("kFailed") = py::int_(hixl::FAILED); + m.attr("kParamInvalid") = py::int_(hixl::PARAM_INVALID); + m.attr("kTimeout") = py::int_(hixl::TIMEOUT); + m.attr("kNotConnected") = py::int_(hixl::NOT_CONNECTED); + m.attr("kAlreadyConnected") = py::int_(hixl::ALREADY_CONNECTED); + m.attr("kNotifyFailed") = py::int_(hixl::NOTIFY_FAILED); + m.attr("kUnsupported") = py::int_(hixl::UNSUPPORTED); + m.attr("kResourceExhausted") = py::int_(hixl::RESOURCE_EXHAUSTED); + + // MemType 枚举值(供 Python 端可选使用,虽然主要用 str 传参) + m.attr("kMemDevice") = py::int_(hixl::MEM_DEVICE); + m.attr("kMemHost") = py::int_(hixl::MEM_HOST); + + // TransferOp 枚举值 + m.attr("kRead") = py::int_(hixl::READ); + m.attr("kWrite") = py::int_(hixl::WRITE); + + // 初始化选项常量(与 C++ OPTION_* 一致) + m.attr("kOptionEnableUseFabricMem") = py::str(hixl::OPTION_ENABLE_USE_FABRIC_MEM); + m.attr("kOptionRdmaTrafficClass") = py::str(hixl::OPTION_RDMA_TRAFFIC_CLASS); + m.attr("kOptionRdmaServiceLevel") = py::str(hixl::OPTION_RDMA_SERVICE_LEVEL); + m.attr("kOptionBufferPool") = py::str(hixl::OPTION_BUFFER_POOL); + m.attr("kOptionGlobalResourceConfig") = py::str(hixl::OPTION_GLOBAL_RESOURCE_CONFIG); +} + +void BuildHixlFuncs(py::module &m) { + // 所有方法使用 py::call_guard(): + // C++ 操作不访问 Python 对象,释放 GIL 让其他 Python 线程并发执行 + (void)m.def("initialize", &HixlEngineWrapper::Initialize, py::call_guard()); + (void)m.def("finalize", &HixlEngineWrapper::Finalize, py::call_guard()); + (void)m.def("register_mem", &HixlEngineWrapper::RegisterMem, py::call_guard()); + (void)m.def("deregister_mem", &HixlEngineWrapper::DeregisterMem, py::call_guard()); + (void)m.def("connect", &HixlEngineWrapper::Connect, py::call_guard()); + (void)m.def("disconnect", &HixlEngineWrapper::Disconnect, py::call_guard()); + (void)m.def("transfer_sync", &HixlEngineWrapper::TransferSync, py::call_guard()); + (void)m.def("transfer_async", &HixlEngineWrapper::TransferAsync, py::call_guard()); + (void)m.def("get_transfer_status", &HixlEngineWrapper::GetTransferStatus, py::call_guard()); + (void)m.def("send_notify", &HixlEngineWrapper::SendNotify, py::call_guard()); + (void)m.def("get_notifies", &HixlEngineWrapper::GetNotifies, py::call_guard()); +} + +} // namespace + +PYBIND11_MODULE(hixl_wrapper, m) { + BindStatusCodes(m); + BuildHixlFuncs(m); +} + +} // namespace hixl_wrapper +``` + +> **兼容性处理**(与 llm_wrapper 一致): +> - `#include "Python.h"` + `ASCEND_CI_LIMITED_PY37` 宏:处理 Python 3.7 兼容性 +> - `PYBIND11_CHECK_PYTHON_VERSION`:允许在构建时 Python 版本与运行时不完全匹配(CANN 构建环境可能使用不同 Python 版本) + +--- + +## 9. GIL 释放策略 + +所有方法使用 `py::call_guard()`,理由: + +| 方法 | 释放 GIL 的理由 | +|---|---| +| `Initialize` | 涉及硬件资源初始化,可能耗时 | +| `Finalize` | 涉及硬件资源销毁,可能耗时 | +| `RegisterMem` | 涉及 NPU 内存注册,可能涉及 RDMA 操作 | +| `DeregisterMem` | 涉及内存解注册 | +| `Connect` | 涉及网络建链,需要等待远端响应 | +| `Disconnect` | 涉及网络断链 | +| `TransferSync` | 同步传输,等待硬件完成 | +| `TransferAsync` | 下发传输请求,涉及硬件操作 | +| `GetTransferStatus` | 查询传输状态 | +| `SendNotify` | 涉及网络通信 | +| `GetNotifies` | 查询通知信息 | + +所有 C++ 操作都不访问 Python 对象,释放 GIL 可以让其他 Python 线程并发执行。 + +> **注意**:`Finalize` 返回 `void`,pybind11 对 `void` 返回值自动返回 `None` 到 Python。GIL 释放不影响返回值传递。 + +--- + +## 10. `CMakeLists.txt` — 构建配置 + +```cmake +# ---------------------------------------------------------------------------- +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# 请参阅 License 获取详细信息。 +# ---------------------------------------------------------------------------- + +if (NOT ENABLE_TEST) + set(CMAKE_SKIP_RPATH TRUE) + add_library(hixl_wrapper MODULE + hixl_wrapper.cc + hixl_engine_wrapper.cc) + + target_include_directories(hixl_wrapper PRIVATE + ${HI_PYTHON_INC} + ${pybind11_INCLUDE_DIR} + ${HIXL_CODE_DIR}/include + ${HIXL_CODE_DIR}/src/llm_datadist + ) + + target_link_libraries(hixl_wrapper PRIVATE + $ + $ + $ + $ + $ + $ + alog + cann_hixl + ) + + set_target_properties(hixl_wrapper + PROPERTIES + PREFIX "" + ) + + target_compile_definitions(hixl_wrapper PRIVATE + PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF + ) + + target_compile_options(hixl_wrapper PRIVATE + -Xlinker -export-dynamic + ) + + target_link_options(hixl_wrapper PRIVATE + -s + ) +endif () +``` + +> **与 `llm_wrapper` CMakeLists.txt 的对比**: +> +> | 配置项 | llm_wrapper | hixl_wrapper | 说明 | +> |---|---|---|---| +> | 链接目标 | `llm_datadist` | `cann_hixl` | 绑定不同层的 C++ 库 | +> | 源文件 | `llm_wrapper_v2.cc` + `llm_datadist_v2_wrapper.cc` | `hixl_wrapper.cc` + `hixl_engine_wrapper.cc` | | +> | include 目录 | 相同 | 相同(需 `src/llm_datadist` 因为 `ge_api_error_codes.h` 通过 `metadef_headers` 提供) | | +> | 其他配置 | 相同 | 相同 | `PREFIX ""`, `PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF`, `-Xlinker -export-dynamic`, `-s` | + +--- + +## 11. 更新 `hixl/src/python/CMakeLists.txt` + +```cmake +add_subdirectory(llm_datadist) +add_subdirectory(llm_wrapper) +add_subdirectory(metadef_wrapper) +add_subdirectory(hixl_wrapper) # ← 新增:HIXL Engine Python 绑定 +add_subdirectory(hixl_engine) # ← 新增:独立 wheel 打包 +``` + +--- + +## 12. 独立 Wheel 打包 + +### 12.1 目录结构 + +``` +hixl/src/python/hixl_engine/ +├── CMakeLists.txt # wheel 打包构建配置 +├── setup.py # setuptools 打包脚本 +├── MANIFEST.in # 包含 .so 文件的规则 +└── hixl_engine/ + └── __init__.py # Python 包入口 +``` + +### 12.2 `hixl_engine/__init__.py` + +```python +"""HIXL Engine Python binding package.""" +from hixl_engine import hixl_wrapper # noqa: F401 +``` + +### 12.3 `setup.py` + +```python +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# ---------------------------------------------------------------------------- +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# 请参阅 License 获取详细信息。 +# ---------------------------------------------------------------------------- + +from setuptools import setup, find_packages + +setup( + name='hixl_engine', + version='0.0.1', + description='hixl engine api', + packages=find_packages(), + include_package_data=True, + ext_modules=[] +) +``` + +### 12.4 `MANIFEST.in` + +``` +recursive-include * *.so +``` + +### 12.5 `CMakeLists.txt` — wheel 打包 + +```cmake +# ---------------------------------------------------------------------------- +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# 请参阅 License 获取详细信息。 +# ---------------------------------------------------------------------------- + +add_custom_target(hixl_engine_python ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/hixl_engine-0.0.1-py3-none-any.whl) +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/hixl_engine-0.0.1-py3-none-any.whl + COMMAND echo "package hixl engine whl start" + && mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/wheel1 + && cp -r ${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/wheel1/setup.py + && cp -r ${CMAKE_CURRENT_SOURCE_DIR}/MANIFEST.in ${CMAKE_CURRENT_BINARY_DIR}/wheel1/ + && cp -r ${CMAKE_CURRENT_SOURCE_DIR}/hixl_engine ${CMAKE_CURRENT_BINARY_DIR}/wheel1/ + && cp -r ${CMAKE_CURRENT_BINARY_DIR}/../hixl_wrapper/hixl_wrapper.so ${CMAKE_CURRENT_BINARY_DIR}/wheel1/hixl_engine/ + && cd ${CMAKE_CURRENT_BINARY_DIR}/wheel1 + && ${HI_PYTHON} setup.py bdist_wheel >/dev/null + && cp -f dist/hixl_engine-0.0.1-py3-none-any.whl ${CMAKE_CURRENT_BINARY_DIR}/ + && echo "package hixl engine whl end" + DEPENDS hixl_wrapper +) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/hixl_engine-0.0.1-py3-none-any.whl OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR} +) +``` + +> **与 `llm_datadist` wheel 打包的对比**: +> +> | 配置项 | llm_datadist | hixl_engine | 说明 | +> |---|---|---|---| +> | wheel 名 | `llm_datadist-0.0.1-py3-none-any.whl` | `hixl_engine-0.0.1-py3-none-any.whl` | 独立 wheel | +> | 包含 .so | `llm_datadist_wrapper.so` + `metadef_wrapper.so` | `hixl_wrapper.so` | 只包含一个 .so | +> | Python 包名 | `llm_datadist` | `hixl_engine` | | +> | DEPENDS | `llm_datadist_wrapper` + `metadef_wrapper` + `generate_hixl_version_info` | `hixl_wrapper` | | +> | Python 导入 | `from llm_datadist import llm_datadist_wrapper` | `from hixl_engine import hixl_wrapper` | | + +--- + +## 13. Python 端使用示例 + +```python +import hixl_wrapper + +# 初始化引擎 +status = hixl_wrapper.initialize("192.168.1.1:5000", {hixl_wrapper.kOptionBufferPool: "4G"}) +if status != hixl_wrapper.kSuccess: + raise RuntimeError(f"Initialize failed, status={status}") + +# 注册内存(NPU 设备内存) +status, handle = hixl_wrapper.register_mem((0x1000, 4096), "npu") +if status != hixl_wrapper.kSuccess: + raise RuntimeError(f"RegisterMem failed, status={status}") + +# 连接远端(默认超时 1000ms) +status = hixl_wrapper.connect("192.168.1.2:5000") +# 或显式指定超时: +status = hixl_wrapper.connect("192.168.1.2:5000", 3000) + +# 同步传输(READ:从远端拉到本地,默认超时 1000ms) +status = hixl_wrapper.transfer_sync( + "192.168.1.2:5000", + "READ", + [(local_addr, remote_addr, length)], + 5000 # 显式指定超时 +) + +# 异步传输(WRITE:将本地写到远端) +status, req_id = hixl_wrapper.transfer_async( + "192.168.1.2:5000", + "WRITE", + [(local_addr, remote_addr, length)] +) + +# 查询传输状态(轮询直到 COMPLETED) +status, transfer_status = hixl_wrapper.get_transfer_status(req_id) +# transfer_status 可能是 "WAITING", "COMPLETED", "TIMEOUT", "FAILED" + +# 发送通知(默认超时 1000ms) +status = hixl_wrapper.send_notify( + "192.168.1.2:5000", + ("signal_name", "message_content"), + 3000 +) + +# 获取通知 +status, notifies = hixl_wrapper.get_notifies() +for name, msg in notifies: + print(f"Notify: {name} - {msg}") + +# 清理(必须在 Finalize 前完成所有异步传输查询) +status = hixl_wrapper.deregister_mem(handle) +status = hixl_wrapper.disconnect("192.168.1.2:5000") +hixl_wrapper.finalize() # 返回 None,无返回值 +``` + +--- + +## 14. 安全注意事项 + +### 14.1 悬空指针风险 + +`MemHandle` 和 `TransferReq` 都是 `void*`,通过 `uintptr_t` 桥接传递给 Python。Python 端持有的是地址整数。**如果底层引擎释放了 handle/req,Python 端还持有该整数并再次传回 C++,会导致悬空指针访问**。 + +**调用方需遵守的生命周期规则**: +1. `DeregisterMem(handle)` 后,不可再用该 `handle` 调用任何方法 +2. 异步传输完成后(`GetTransferStatus` 返回 `"COMPLETED"`),`req_id` 不再有效 +3. `Finalize()` 前必须完成所有异步传输查询和资源释放 + +### 14.2 枚举字符串校验 + +`ParseMemType`/`ParseTransferOp` 对非法字符串返回 `PARAM_INVALID` + ALOG 警告,不再静默返回默认值。调用方需检查 Status 后再使用结果。 + +### 14.3 AscendString null 安全 + +`GetNotifies` 中使用 `GetString()` 而非 `GetData()` 转换 `AscendString` → `std::string`。`GetString()` 在 `AscendString` 内部为 null 时返回 `""`(空字符串),避免 `std::string(nullptr)` crash。 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a137b61..7853f9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,13 @@ yr = [ "requests", ] +hixl = [ + "hixl_engine>=0.0.1", + "torch>=2.7.1; platform_machine == 'x86_64'", + "torch>=2.7.1; platform_machine == 'aarch64'", + "torch-npu>=2.7.1.post2", +] + [project.urls] Homepage = "https://github.com/Ascend/ray-ascend" Repository = "https://github.com/Ascend/ray-ascend" diff --git a/ray_ascend/__init__.py b/ray_ascend/__init__.py index 8477487..36d0c99 100644 --- a/ray_ascend/__init__.py +++ b/ray_ascend/__init__.py @@ -13,6 +13,7 @@ "register_yr_tensor_transport", "register_hccl_collective_backend", "register_hccl_tensor_transport", + "register_hixl_tensor_transport", ] __commit__ = _version.commit @@ -171,3 +172,76 @@ def transfer_npu_tensor_via_hccs(self): from .direct_transport.hccl_tensor_transport import HCCLTensorTransport register_tensor_transport("HCCL", ["npu"], HCCLTensorTransport, torch.Tensor) + + +def register_hixl_tensor_transport(devices: List[str] = ["npu", "cpu"]) -> None: + """Register HIXL tensor transport for Ray RDT on Ascend NPU. + + This function should be called in the driver process to register the + HIXL transport backend. It must also be called in each Ray actor's + __init__ to register the transport for that actor process. + + HIXL uses RDMA READ for zero-copy tensor transfer between Ascend NPU + nodes. Unlike YR (which uses a DataSystem intermediary), HIXL is a + true one-sided RDMA transport where the receiver directly reads the + sender's memory. + + Requirements: + - HIXL Engine wheel installed: pip install hixl_engine-0.0.1-py3-none-any.whl + - CANN driver and runtime installed on all NPU nodes + - RDMA/HCCS links established between nodes + + Args: + devices: List of device types to support. Can be: + - ["npu"] for NPU tensors only + - ["npu", "cpu"] for NPU and CPU tensors + - ["cpu"] for CPU tensors only + + Example: + import ray + from ray_ascend import register_hixl_tensor_transport + + ray.init() + register_hixl_tensor_transport(["npu", "cpu"]) + + @ray.remote(resources={"NPU": 1}) + class RayActor: + def __init__(self): + register_hixl_tensor_transport(["npu", "cpu"]) + + @ray.method(tensor_transport="HIXL") + def transfer_npu_tensor_via_rdma(self): + return torch.zeros(1024, device="npu") + """ + if devices is None: + raise ValueError( + "devices cannot be None. Specify a list of device types, " + "e.g., ['npu', 'cpu']" + ) + + import torch + + try: + from ray.experimental import register_tensor_transport + + from ray_ascend.direct_transport.hixl_tensor_transport import ( + HixlTensorTransport, + ) + except ImportError as e: + raise ImportError( + "HIXL tensor transport requires the hixl_engine package. " + "Please install it with: " + "pip install hixl_engine-0.0.1-py3-none-any.whl" + ) from e + + # Verify hixl_wrapper is importable before registration. + try: + import hixl_wrapper + except ImportError as e: + raise ImportError( + "hixl_wrapper module not found. HIXL tensor transport requires " + "the HIXL Engine wheel. Please install: " + "pip install hixl_engine-0.0.1-py3-none-any.whl" + ) from e + + register_tensor_transport("HIXL", devices, HixlTensorTransport, torch.Tensor) diff --git a/ray_ascend/direct_transport/hixl_tensor_transport.py b/ray_ascend/direct_transport/hixl_tensor_transport.py new file mode 100644 index 0000000..fb83784 --- /dev/null +++ b/ray_ascend/direct_transport/hixl_tensor_transport.py @@ -0,0 +1,847 @@ +import logging +import pickle +import threading +import time +import traceback +import uuid +from collections import OrderedDict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import ray +from ray.experimental.rdt.tensor_transport_manager import ( + CommunicatorMetadata, + FetchRequest, + TensorTransportManager, + TensorTransportMetadata, +) + +if TYPE_CHECKING: + import torch + +logger = logging.getLogger(__name__) + +# Lazy import: hixl_wrapper may not be installed on all nodes. +try: + import hixl_wrapper +except ImportError: + hixl_wrapper = None + +# Maximum number of cached HIXL remote engine connections. +# When exceeded, the least recently used remote engine is evicted and +# Disconnect is called. Set to 0 to disable remote engine reuse. +HIXL_REMOTE_ENGINE_CACHE_MAXSIZE = 1000 + + +@dataclass +class HixlCommunicatorMetadata(CommunicatorMetadata): + """Metadata for the HIXL communicator.""" + + +@dataclass +class HixlTransportMetadata(TensorTransportMetadata): + """Metadata for tensors stored in the NPU/CPU object store for HIXL transport. + + Args: + hixl_serialized_mem_descs: Pickle-serialized list of + (data_ptr, nbytes, mem_type_str) tuples describing the source + tensors' registered memory regions. + hixl_engine_id: The local HIXL engine identifier (format: "host_ip:port") + that the remote side uses to Connect back. + hixl_engine_meta_version: Monotonically increasing version number bumped + whenever memory is deregistered, so the receiver can detect stale + descriptors. + """ + + hixl_serialized_mem_descs: Optional[bytes] = None + hixl_engine_id: Optional[str] = None + hixl_engine_meta_version: Optional[int] = 0 + + __eq__ = object.__eq__ + __hash__ = object.__hash__ + + +@dataclass +class HixlTensorDesc: + """Cached registration info for a single tensor storage. + + HIXL's RegisterMem returns only a MemHandle (void*), which does not carry + address or size information. We keep the original registration parameters + alongside the handle so we can: + - Build TransferOpDesc tuples on the source side (addr, len are needed) + - Call DeregisterMem(mem_handle) when the ref count drops to zero + - Serialize (data_ptr, nbytes, mem_type_str) into transport metadata + + Attributes: + mem_handle: The opaque handle returned by hixl_wrapper.register_mem. + Represented as a Python int (uintptr_t under the hood). + nbytes: Size of the registered memory region in bytes. + mem_type_str: "npu" or "cpu" — used when building TransferOpDesc and + for serialization into HixlTransportMetadata. + metadata_count: Number of HixlTransportMetadata objects that reference + this tensor. When it reaches zero, we call DeregisterMem. + """ + + mem_handle: Any # uintptr_t → Python int + nbytes: int + mem_type_str: str # "npu" | "cpu" + metadata_count: int + + +@dataclass +class HixlFetchRequest(FetchRequest): + """HIXL-specific fetch request carrying the async transfer state. + + Returned by fetch_multiple_tensors and consumed by wait_fetch_complete. + Resource cleanup happens in __del__ so that handles are released even if + the caller never waits on the request. + + Args: + obj_id: Inherited. The object ID for the transfer, used for abort checks. + tensors: Inherited. Pre-allocated output tensors (populated before the + transfer starts). + transfer_req: HIXL TransferReq handle (uintptr_t → Python int). + remote_engine_id: The remote engine ID (ip:port) that was connected + for this transfer. + remove_tensor_descs: Whether to remove tensor descriptors from the + cache during cleanup (True when fetch_multiple_tensors added them). + transport: Reference to the HixlTensorTransport instance for cleanup. + """ + + transfer_req: Any = None + remote_engine_id: Optional[str] = None + remove_tensor_descs: bool = False + transport: Any = None + + def __del__(self): + if self.transport is not None: + self.transport._cleanup_transfer( + self.obj_id, + self.tensors, + self.transfer_req, + self.remote_engine_id, + self.remove_tensor_descs, + ) + + +class HixlTensorTransport(TensorTransportManager): + """HIXL Engine-based one-sided RDMA tensor transport for Ray RDT.""" + + def __init__(self): + # Lazily initialized because hixl_wrapper may not be installed on + # nodes that are only coordinating (not participating in transfers). + self._hixl_initialized = False + self._local_engine_id: Optional[str] = None + + # Object IDs whose transfers have been aborted. + self._aborted_transfer_obj_ids: set = set() + self._aborted_transfer_obj_ids_lock = threading.Lock() + + # Mapping from tensor storage data_ptr → HixlTensorDesc. + # Unlike _managed_meta_hixl, we only deregister tensors when ALL + # metadata containing the tensor is freed (reference counting via + # metadata_count). + self._tensor_desc_cache: Dict[int, HixlTensorDesc] = {} + + # Mapping from object ID → HixlTransportMetadata. + # Lifetime is tied to the object ref; freed when the ref goes out of + # scope (garbage_collect is called). + self._managed_meta_hixl: Dict[str, Any] = {} + + # Lock protecting _tensor_desc_cache and _managed_meta_hixl since they + # can be accessed from the main task execution thread or the + # _ray_system thread. + self._cache_lock = threading.RLock() + + # LRU cache of remote engine IDs. When full, the least recently used + # remote engine is evicted and Disconnect is called. + self._remote_engines: OrderedDict = OrderedDict() + + # Incremented whenever memory is deregistered so receivers can detect + # stale descriptors. + self._hixl_engine_meta_version: int = 0 + + def tensor_transport_backend(self) -> str: + return "HIXL" + + @staticmethod + def is_one_sided() -> bool: + return True # HIXL RDMA: receiver initiates READ (one-sided) + + @staticmethod + def can_abort_transport() -> bool: + return True # TransferAsync can be interrupted via abort flag + + # ------------------------------------------------------------------ + # HIXL agent lifecycle + # ------------------------------------------------------------------ + + def _ensure_hixl_initialized(self): + """Lazily initializes the HIXL engine via hixl_wrapper. + + The engine ID is constructed from the Ray actor's node IP + actor_id + as the port component, ensuring uniqueness per actor. + + Raises: + ImportError: If hixl_wrapper is not installed. + RuntimeError: If HIXL initialization fails. + """ + if self._hixl_initialized: + return + + if hixl_wrapper is None: + raise ImportError( + "hixl_wrapper module not found. " + "Please install the HIXL Engine wheel: " + "pip install hixl_engine-0.0.1-py3-none-any.whl" + ) + + # Build a local engine ID from the Ray actor's IP address. + # The port component is generated locally; HIXL uses this as a + # logical identifier for the RDMA endpoint. + ctx = ray.get_runtime_context() + actor_id = ctx.get_actor_id() + if actor_id is None: + # Driver process — generate a unique ID. + actor_id = f"RAY-DRIVER-{uuid.uuid4()}" + + node_ip = ray.util.get_node_ip_address() + # Use actor_id as the port component to ensure uniqueness per actor. + self._local_engine_id = f"{node_ip}:{actor_id}" + + status = hixl_wrapper.initialize(self._local_engine_id, {}) + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"Failed to initialize HIXL engine with id " + f"'{self._local_engine_id}', status={status}. " + f"Common causes:\n" + f" - HIXL library not installed or incompatible version\n" + f" - RDMA hardware not available on this node\n" + f" - CANN driver/runtime version mismatch" + ) + + self._hixl_initialized = True + logger.info( + f"HIXL engine initialized with local_engine_id={self._local_engine_id}" + ) + + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: + """Check if a remote actor has the HIXL transport available.""" + # TODO: This is called on a .remote RDT call, so it's quite expensive. + def __ray_actor_has_tensor_transport__( + self: "ray.actor.ActorHandle", + ) -> bool: + try: + from ray.experimental.rdt.util import get_tensor_transport_manager + + manager = get_tensor_transport_manager("HIXL") + manager._ensure_hixl_initialized() + return True + except Exception: + return False + + return ray.get( + actor.__ray_call__.options(concurrency_group="_ray_system").remote( + __ray_actor_has_tensor_transport__ + ) + ) + + # ------------------------------------------------------------------ + # Public memory registration API + # ------------------------------------------------------------------ + + def register_hixl_memory(self, tensor: "torch.Tensor") -> None: + """Registers the tensor's memory with HIXL and bumps the reference + count so the memory region is never deregistered. + + Call this to pre-register a tensor's memory for the lifetime of the + process, which can improve performance if the same tensor is re-used + in multiple RDT objects. + """ + self._add_tensor_descs([tensor]) + + def deregister_hixl_memory(self, tensor: "torch.Tensor") -> None: + """Decrements the reference count for the tensor's HIXL memory + registration added by register_hixl_memory. + + If the reference count reaches 0, the memory is deregistered from + HIXL. This should only be called after register_hixl_memory has been + called for this tensor. Any existing ObjectRef instances that reference + this tensor's memory will keep the HIXL registration alive independently + until they go out of scope. + """ + self._remove_tensor_descs([tensor]) + + # ------------------------------------------------------------------ + # Memory registration / deregistration helpers + # ------------------------------------------------------------------ + + def _add_tensor_descs(self, tensors: List["torch.Tensor"]): + """Register tensor memory with HIXL and bump reference counts. + + If a tensor's storage is already registered (keyed by data_ptr), we + only increment the metadata_count. Otherwise we call + hixl_wrapper.register_mem and cache the handle + registration params. + """ + self._ensure_hixl_initialized() + + with self._cache_lock: + for tensor in tensors: + key = tensor.untyped_storage().data_ptr() + if key in self._tensor_desc_cache: + self._tensor_desc_cache[key].metadata_count += 1 + continue + + # Determine memory type: NPU tensors → "npu", CPU → "cpu". + mem_type_str = "npu" if tensor.device.type == "npu" else "cpu" + + # Register the full underlying storage with HIXL. + # HIXL register_mem takes (addr, len) tuple + mem_type string. + addr = tensor.untyped_storage().data_ptr() + nbytes = tensor.untyped_storage().nbytes() + + try: + status, mem_handle = hixl_wrapper.register_mem( + (addr, nbytes), mem_type_str + ) + except Exception as e: + raise RuntimeError( + f"Failed to register {mem_type_str} memory with HIXL " + f"(addr=0x{addr:x}, size={nbytes} bytes). " + f"Common causes:\n" + f" - CANN driver/runtime not installed\n" + f" - RDMA device not available\n" + f" - HCCS link not established\n" + f" - Container privilege level too low" + ) from e + + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL RegisterMem returned error status={status} " + f"for {mem_type_str} memory " + f"(addr=0x{addr:x}, size={nbytes} bytes)" + ) + + self._tensor_desc_cache[key] = HixlTensorDesc( + mem_handle=mem_handle, + nbytes=nbytes, + mem_type_str=mem_type_str, + metadata_count=1, + ) + + def _remove_tensor_descs(self, tensors: List["torch.Tensor"]): + """Decrement reference counts and deregister when they reach zero. + + When metadata_count drops to zero we call hixl_wrapper.deregister_mem + with the cached MemHandle and bump _hixl_engine_meta_version. + """ + with self._cache_lock: + for tensor in tensors: + key = tensor.untyped_storage().data_ptr() + if key not in self._tensor_desc_cache: + continue + tensor_desc = self._tensor_desc_cache[key] + tensor_desc.metadata_count -= 1 + if tensor_desc.metadata_count == 0: + self._tensor_desc_cache.pop(key) + try: + status = hixl_wrapper.deregister_mem(tensor_desc.mem_handle) + if status != hixl_wrapper.kSuccess: + logger.warning( + f"HIXL DeregisterMem returned status={status} " + f"for handle={tensor_desc.mem_handle}" + ) + except Exception: + logger.warning( + f"HIXL DeregisterMem raised exception for " + f"handle={tensor_desc.mem_handle}", + exc_info=True, + ) + self._hixl_engine_meta_version += 1 + + def _tensor_memory_registered(self, t: "torch.Tensor") -> bool: + """Check if the tensor's memory has been registered with HIXL.""" + return t.untyped_storage().data_ptr() in self._tensor_desc_cache + + # ------------------------------------------------------------------ + # Core transport methods + # ------------------------------------------------------------------ + + def extract_tensor_transport_metadata( + self, + obj_id: str, + rdt_object: List["torch.Tensor"], + ) -> HixlTransportMetadata: + """Source side: register tensor memory and serialize descriptors. + + Called on the source actor immediately after the task creates the + result tensors. We: + 1. Synchronize the device to ensure data is written. + 2. Register each tensor's storage with HIXL (RegisterMem). + 3. Serialize the memory descriptions as pickle bytes. + 4. Return HixlTransportMetadata with the serialized descs, the + local engine ID, and the current meta version. + + Args: + obj_id: The object ID for the RDT object. + rdt_object: The RDT object (list of tensors). + + Returns: + HixlTransportMetadata containing serialized memory descriptions + and the local engine ID. + """ + import torch + + with self._cache_lock: + device = None + tensor_meta = [] + mem_descs_for_serialization = [] + + if rdt_object: + # All tensors must share the same device type. + device = rdt_object[0].device + devices = set() + for t in rdt_object: + if t.device.type != device.type: + raise ValueError( + "All tensors in an RDT object must have the same " + "device type." + ) + if not t.is_contiguous(): + raise ValueError( + "All tensors in an RDT object must be contiguous." + ) + tensor_meta.append((t.shape, t.dtype)) + devices.add(t.device) + + if device.type == "npu": + # Synchronize before registration to assure the data has + # been written — HIXL does not guarantee this. + for dev in devices: + torch.npu.synchronize(dev) + + self._add_tensor_descs(rdt_object) + + # Build serialization payload: for each registered tensor, + # we pack (data_ptr, nbytes, mem_type_str). The receiver + # uses these to construct TransferOpDesc tuples. + for t in rdt_object: + key = t.untyped_storage().data_ptr() + desc = self._tensor_desc_cache[key] + mem_descs_for_serialization.append( + (key, desc.nbytes, desc.mem_type_str) + ) + + serialized_mem_descs = pickle.dumps(mem_descs_for_serialization) + engine_id = self._local_engine_id + engine_meta_version = self._hixl_engine_meta_version + else: + serialized_mem_descs = None + engine_id = None + engine_meta_version = None + + ret = HixlTransportMetadata( + tensor_meta=tensor_meta, + tensor_device=device.type if device else None, + hixl_serialized_mem_descs=serialized_mem_descs, + hixl_engine_id=engine_id, + hixl_engine_meta_version=engine_meta_version, + ) + self._put_meta(obj_id, ret) + return ret + + def get_communicator_metadata( + self, + src_actor: "ray.actor.ActorHandle", + dst_actor: "ray.actor.ActorHandle", + backend: Optional[str] = None, + ) -> HixlCommunicatorMetadata: + """One-sided RDMA transport: no communicator metadata needed.""" + return HixlCommunicatorMetadata() + + def fetch_multiple_tensors( + self, + obj_id: str, + tensor_transport_metadata: HixlTransportMetadata, + communicator_metadata: HixlCommunicatorMetadata, + target_buffers: Optional[List["torch.Tensor"]] = None, + ) -> HixlFetchRequest: + """Receiver side: initiate an RDMA READ transfer. + + This triggers the transfer but does not wait for completion. Call + wait_fetch_complete(fetch_request) to retrieve the tensors. + + Steps: + 1. Allocate target tensors (or use provided buffers). + 2. Register target memory with HIXL. + 3. Deserialize the source memory descriptions from metadata. + 4. Connect to the remote HIXL engine (using engine_id from metadata). + 5. Build TransferOpDesc tuples: (local_addr, remote_addr, len). + 6. Call hixl_wrapper.transfer_async("READ", op_descs, remote_engine_id). + 7. Return HixlFetchRequest with the async transfer handle. + + Args: + obj_id: The object ID for the transfer. + tensor_transport_metadata: Source-side metadata containing + serialized memory descriptions and the remote engine ID. + communicator_metadata: Empty HixlCommunicatorMetadata. + target_buffers: Optional pre-allocated buffers to receive into. + + Returns: + HixlFetchRequest carrying the async transfer state. + """ + from ray.experimental.rdt.util import create_empty_tensors_from_metadata + + tensors = target_buffers or create_empty_tensors_from_metadata( + tensor_transport_metadata + ) + + assert isinstance(tensor_transport_metadata, HixlTransportMetadata) + assert isinstance(communicator_metadata, HixlCommunicatorMetadata) + + serialized_mem_descs = tensor_transport_metadata.hixl_serialized_mem_descs + remote_engine_id = tensor_transport_metadata.hixl_engine_id + + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"HIXL transfer aborted for object id: {obj_id}" + ) + + transfer_req = None + added_tensor_descs = False + + assert tensors + + try: + self._ensure_hixl_initialized() + + # Register local target tensors with HIXL. + self._add_tensor_descs(tensors) + added_tensor_descs = True + + # Deserialize the source-side memory descriptions. + remote_mem_descs = pickle.loads(serialized_mem_descs) + + # Connect to the remote HIXL engine (or reuse cached connection). + remote_engine_meta_version = ( + tensor_transport_metadata.hixl_engine_meta_version + ) + + self._connect_remote_engine( + remote_engine_id, remote_engine_meta_version + ) + + # Build TransferOpDesc tuples for RDMA READ. + # For each tensor pair (local target, remote source): + # local_addr = target tensor's storage data_ptr + # remote_addr = source tensor's data_ptr (from deserialized mem desc) + # len = nbytes (must match; we validate this) + op_descs = [] + for i, t in enumerate(tensors): + remote_addr, remote_nbytes, _ = remote_mem_descs[i] + local_addr = t.untyped_storage().data_ptr() + local_nbytes = t.untyped_storage().nbytes() + if local_nbytes != remote_nbytes: + raise RuntimeError( + f"HIXL transfer size mismatch for tensor {i}: " + f"local={local_nbytes} bytes vs remote={remote_nbytes} bytes" + ) + op_descs.append((local_addr, remote_addr, remote_nbytes)) + + # Initiate async RDMA READ from remote engine. + status, transfer_req = hixl_wrapper.transfer_async( + remote_engine_id, "READ", op_descs + ) + + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL TransferAsync returned error status={status} " + f"for object id: {obj_id}" + ) + + return HixlFetchRequest( + obj_id=obj_id, + tensors=tensors, + transfer_req=transfer_req, + remote_engine_id=remote_engine_id, + remove_tensor_descs=added_tensor_descs, + transport=self, + ) + except Exception: + self._cleanup_transfer( + obj_id, tensors, transfer_req, remote_engine_id, + added_tensor_descs, + ) + # Import here to avoid circular dependency on startup. + from ray.exceptions import RayDirectTransportError + + raise RayDirectTransportError( + f"The HIXL transfer failed for object id: {obj_id}. " + f"The source actor may have died during the transfer. " + f"The exception thrown from HIXL transfer was:\n " + f"{traceback.format_exc()}" + ) from None + + def wait_fetch_complete( + self, fetch_request: HixlFetchRequest, timeout: float = -1 + ) -> List["torch.Tensor"]: + """Wait for a previously initiated HIXL fetch to complete. + + Polls hixl_wrapper.get_transfer_status until the state is "COMPLETED", + "TIMEOUT", or "FAILED". Supports abort via _aborted_transfer_obj_ids. + + Args: + fetch_request: The HixlFetchRequest returned by + fetch_multiple_tensors. + timeout: Maximum time in seconds to wait. -1 means wait + indefinitely. 0 means return immediately if not ready. + + Returns: + List of tensors that were transferred. + + Raises: + RayDirectTransportError: If the transfer failed. + TimeoutError: If the timeout is exceeded. + """ + assert isinstance(fetch_request, HixlFetchRequest) + obj_id = fetch_request.obj_id + + if not fetch_request.tensors: + return fetch_request.tensors + + try: + # Poll transfer status until completion. + deadline = None if timeout < 0 else time.monotonic() + timeout + while True: + status, transfer_status = hixl_wrapper.get_transfer_status( + fetch_request.transfer_req + ) + if status != hixl_wrapper.kSuccess: + raise RuntimeError( + f"HIXL GetTransferStatus returned error status={status} " + f"for object id: {obj_id}" + ) + + if transfer_status == "FAILED": + raise RuntimeError( + f"HIXL transfer got FAILED state for object id: {obj_id}" + ) + if transfer_status == "TIMEOUT": + raise RuntimeError( + f"HIXL transfer got TIMEOUT state for object id: {obj_id}" + ) + if transfer_status == "WAITING": + if deadline is not None and time.monotonic() >= deadline: + raise TimeoutError( + f"HIXL transfer timed out after {timeout}s " + f"for object id: {obj_id}" + ) + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"HIXL transfer aborted for object id: {obj_id}" + ) + time.sleep(0.001) # Avoid busy waiting + elif transfer_status == "COMPLETED": + break + + return fetch_request.tensors + except TimeoutError: + raise + except Exception: + from ray.exceptions import RayDirectTransportError + + raise RayDirectTransportError( + f"The HIXL transfer failed for object id: {obj_id}. " + f"The source actor may have died during the transfer. " + f"The exception thrown from HIXL transfer was:\n " + f"{traceback.format_exc()}" + ) from None + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def _cleanup_transfer( + self, + obj_id: str, + tensors: List["torch.Tensor"], + transfer_req: Any, + remote_engine_id: Optional[str], + remove_tensor_descs: bool, + ) -> None: + """Best-effort cleanup after a transfer completes or fails. + + We may encounter errors or HIXL may raise errors like connection + loss, so we do best-effort cleanup without raising further errors. + """ + if not self._hixl_initialized: + return + + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.discard(obj_id) + + # HIXL does not have an explicit release_xfer_handle API; + # the TransferReq is consumed by GetTransferStatus polling. + + # Evict remote engine from LRU cache if caching is disabled. + if HIXL_REMOTE_ENGINE_CACHE_MAXSIZE == 0 and remote_engine_id: + self._disconnect_remote_engine(remote_engine_id) + + if remove_tensor_descs: + self._remove_tensor_descs(tensors) + + # ------------------------------------------------------------------ + # Remote engine connection management (LRU cache) + # ------------------------------------------------------------------ + + def _connect_remote_engine( + self, remote_engine_id: str, remote_engine_meta_version: int + ) -> None: + """Connect to a remote HIXL engine, with LRU caching. + + - If the remote engine is already cached and the meta version + matches, we reuse the connection (move to end of LRU). + - If the meta version differs (source deregistered memory), we + disconnect first and reconnect. + - If the cache is full, evict the least recently used engine. + # 情况 1:已在缓存 + 版号一致 → 复用连接,return,不 connect + # 情况 2:已在缓存 + 版号不一致 → 断开 + 重连 + 存缓存 + # 情况 3:不在缓存 + 缓存未满 → connect + 存缓存 + # 情况 4:不在缓存 + 缓存已满 → 淘汰最旧 + connect + 存缓存 + # ===== else 分支(缓存关闭)===== + # 情况只有一种:直接 connect,不查缓存,不存缓存 + """ + if HIXL_REMOTE_ENGINE_CACHE_MAXSIZE > 0: + if remote_engine_id in self._remote_engines: + cached_version = self._remote_engines[remote_engine_id] + if cached_version != remote_engine_meta_version: + # Source deregistered memory — stale descriptors. + # Disconnect before reconnecting. + self._disconnect_remote_engine(remote_engine_id) + else: + # Reuse cached connection; move to end of LRU. + self._remote_engines.move_to_end(remote_engine_id) + return + + elif len(self._remote_engines) >= HIXL_REMOTE_ENGINE_CACHE_MAXSIZE: + # Evict least recently used remote engine. + evicted_engine_id, _ = self._remote_engines.popitem(last=False) + self._disconnect_remote_engine(evicted_engine_id) + + # Establish new connection. + status = hixl_wrapper.connect(remote_engine_id) + if status != hixl_wrapper.kSuccess and status != hixl_wrapper.kAlreadyConnected: + raise RuntimeError( + f"HIXL Connect to '{remote_engine_id}' failed, " + f"status={status}" + ) + + self._remote_engines[remote_engine_id] = remote_engine_meta_version + else: + # No caching — connect fresh each time. + status = hixl_wrapper.connect(remote_engine_id) + if status != hixl_wrapper.kSuccess and status != hixl_wrapper.kAlreadyConnected: + raise RuntimeError( + f"HIXL Connect to '{remote_engine_id}' failed, " + f"status={status}" + ) + + def _disconnect_remote_engine(self, remote_engine_id: str) -> None: + """Disconnect from a remote HIXL engine (best-effort).""" + try: + hixl_wrapper.disconnect(remote_engine_id) + except Exception: + logger.warning( + f"HIXL Disconnect from '{remote_engine_id}' raised exception", + exc_info=True, + ) + + # ------------------------------------------------------------------ + # Synchronous recv fallback + # ------------------------------------------------------------------ + + def recv_multiple_tensors( + self, + obj_id: str, + tensor_transport_metadata: HixlTransportMetadata, + communicator_metadata: HixlCommunicatorMetadata, + target_buffers: Optional[List["torch.Tensor"]] = None, + ) -> List["torch.Tensor"]: + """Receives multiple tensors synchronously (fetch + wait).""" + fetch_request = self.fetch_multiple_tensors( + obj_id, tensor_transport_metadata, communicator_metadata, + target_buffers, + ) + return self.wait_fetch_complete(fetch_request) + + def send_multiple_tensors( + self, + tensors: List["torch.Tensor"], + tensor_transport_metadata: HixlTransportMetadata, + communicator_metadata: HixlCommunicatorMetadata, + ): + """Not implemented — HIXL is a one-sided transport.""" + raise NotImplementedError( + "HIXL transport does not support send_multiple_tensors, " + "since it is a one-sided transport." + ) + + # ------------------------------------------------------------------ + # Garbage collection & abort + # Ray 分布式引用计数发现:所有接收方都不再持有这个 ref,执行garbage_collect + # ------------------------------------------------------------------ + + def garbage_collect( + self, + obj_id: str, + tensor_transport_meta: HixlTransportMetadata, + tensors: List["torch.Tensor"], + ): + """Release source-side resources for an RDT object. + + Called on the source actor after Ray's distributed ref counting + determines the object is out of scope. We: + 1. Pop the metadata from _managed_meta_hixl. + 2. Remove tensor descriptors (decrement ref count; deregister + when it reaches zero). + """ + with self._cache_lock: + assert isinstance(tensor_transport_meta, HixlTransportMetadata) + if obj_id not in self._managed_meta_hixl: + return + self._managed_meta_hixl.pop(obj_id, None) + self._remove_tensor_descs(tensors) + + def abort_transport( + self, + obj_id: str, + communicator_metadata: HixlCommunicatorMetadata, + ): + """Mark a transfer as aborted so wait_fetch_complete can exit.""" + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.add(obj_id) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_num_managed_meta_hixl(self) -> int: + """Return the number of tracked HixlTransportMetadata objects.""" + with self._cache_lock: + return len(self._managed_meta_hixl) + + def _get_meta(self, object_id: str) -> Optional[HixlTransportMetadata]: + """Get the HIXL transport metadata for the given object ID.""" + with self._cache_lock: + if object_id in self._managed_meta_hixl: + return self._managed_meta_hixl[object_id] + return None + + def _put_meta(self, object_id: str, meta: HixlTransportMetadata): + """Store the HIXL transport metadata for the given object ID.""" + with self._cache_lock: + self._managed_meta_hixl[object_id] = meta \ No newline at end of file