[feat] Add HIXL tensor transport for RDT #67
Open
Artimislyy wants to merge 1 commit into
Open
Conversation
- HixlTensorTransport class (one-sided RDMA READ for Ascend NPU) - HixlTransportMetadata / HixlFetchRequest / HixlTensorDesc data classes - hixl_wrapper pybind11 integration (RegisterMem, TransferAsync, GetTransferStatus) - LRU remote engine connection cache - Design docs and implementation prompts Co-Authored-By: Claude <noreply@anthropic.com>
CLA Signature PassArtimislyy, thanks for your pull request. All authors of the commits have signed the CLA. 👍 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Provides single-sided RDMA READ transfer capabilities for Ray RDT based on the HIXL Engine. The receiver directly reads from the sender's memory, thereby shortening the transfer path and reducing latency.
Implement
HixlTensorTransport继承自TensorTransportManager,源端通过register_mem注册tensor内存,并使用pickle序列化(data_ptr, nbytes, mem_type)描述,接收端反序列化后构造(local_addr, remote_addr, len)传输描述符,然后通过_connect_remote_engine建立/复用远程引擎连接,调用transfer_async发起RDMA READ,wait_fetch_complete轮询状态直到完成,garbage_collect在引用计数达到零时自动释放资源。此外,已完成pyproject.toml中[hixl]可选依赖组、init.py注册入口以及用户文档的集成。
Sequence Diagram
sequenceDiagram autonumber participant SRC as Source Actor participant RDT as Ray RDT<br/>(框架) participant DST as Dest Actor %% ===== Phase 1: 源端 —— 注册 + 元数据提取 ===== rect rgb(230, 245, 255) Note over SRC: Phase 1 — 源端:注册内存 + 提取元数据 SRC->>RDT: 任务执行完成,产出 tensors RDT->>SRC: extract_tensor_transport_metadata(obj_id, tensors) SRC->>SRC: torch.npu.synchronize() SRC->>SRC: _add_tensor_descs(tensors) Note over SRC: loop 每个 tensor:<br/>hixl_wrapper.register_mem()<br/>→ cache → _tensor_desc_cache[data_ptr]=HixlTensorDesc SRC->>SRC: pickle.dumps([(data_ptr, nbytes, mem_type_str)...]) SRC-->>RDT: HixlTransportMetadata<br/>(serialized_mem_descs, engine_id, meta_version)<br/> SRC->>SRC: _put_meta(obj_id, metadata) end %% ===== Phase 2: 元数据传递 ===== rect rgb(255, 245, 230) Note over RDT,DST: Phase 2 — Ray RDT 传递元数据 RDT->>DST: ObjectRef + HixlTransportMetadata<br/>(跨节点 RPC)<br/> end %% ===== Phase 3: 接收端 —— RDMA READ ===== rect rgb(230, 255, 230) Note over DST: Phase 3 — 接收端:建立连接 + RDMA READ RDT->>DST: fetch_multiple_tensors(obj_id, metadata, ...) DST->>DST: _ensure_hixl_initialized() Note over DST: hixl_wrapper.initialize(local_engine_id, {}) DST->>DST: allocate target tensors DST->>DST: _add_tensor_descs(target_tensors) Note over DST: loop 每个 target tensor:<br/>hixl_wrapper.register_mem() DST->>DST: pickle.loads(serialized_mem_descs)<br/>→ remote_mem_descs DST->>DST: _connect_remote_engine(remote_engine_id, version) Note over DST: LRU 缓存决策 DST->>DST: 构建 op_descs = [(local_addr, remote_addr, len)...] DST->>DST: hixl_wrapper.transfer_async(remote_engine_id, "READ", op_descs) DST-->>RDT: HixlFetchRequest<br/>(transfer_req, remote_engine_id, ...)<br/> end %% ===== Phase 4: 接收端 —— 轮询等待 ===== rect rgb(255, 230, 230) Note over DST: Phase 4 — 接收端:轮询等待传输完成 RDT->>DST: wait_fetch_complete(fetch_request, timeout) loop 轮询 until COMPLETED DST->>DST: hixl_wrapper.get_transfer_status(transfer_req) alt COMPLETED DST->>DST: break else FAILED / TIMEOUT DST-->>RDT: RayDirectTransportError else WAITING DST->>DST: 检查 abort 集合 + deadline 超时<br/>sleep(1ms) 避免忙等 end end DST-->>RDT: 返回 transferred tensors end %% ===== Phase 5: 资源回收 ===== rect rgb(245, 245, 245) Note over SRC,DST: Phase 5 — 资源回收 Note over DST: HixlFetchRequest.__del__() DST->>DST: _cleanup_transfer(obj_id, tensors, transfer_req, ...) DST->>DST: _remove_tensor_descs(target_tensors) Note over DST: loop (metadata_count → 0):<br/>hixl_wrapper.deregister_mem(mem_handle) Note over DST: if LRU 缓存关闭 (MAXSIZE=0):<br/>hixl_wrapper.disconnect(remote_engine_id) RDT->>SRC: garbage_collect(obj_id, metadata, tensors) SRC->>SRC: _managed_meta_hixl.pop(obj_id) SRC->>SRC: _remove_tensor_descs(tensors) Note over SRC: loop (metadata_count → 0):<br/>hixl_wrapper.deregister_mem(mem_handle)<br/> _hixl_engine_meta_version += 1<br/>→ 触发远端重连 endRelated issues
#59 #57
problem
那 RDMA 怎么找到那张 NPU 的物理内存?
靠 HCCS/CANN 底层自动完成:
HIXL 内部用 HCCL 的 one-sided API(HcclBatchGet/Put、HcclRegisterGlobalMem、HcclCommBindMem),不需要通信组,直接单侧 RDMA READ。
发送端 register_mem 时:HCCL 已经把这块虚拟地址的物理内存 pin 在了那张 NPU 上,并在 HCCS 网络上注册了地址映射
接收端 connect(remote_engine_id) 时:HCCS 建立了两台机器之间的 RDMA 通路
接收端 transfer_async 时:HCCS 网络根据 remote_addr 查找发送端 Engine 已注册的内存区域,通过 HCCS 硬件直接从那张 NPU 的物理内存 READ 数据
所以整条链路是:
发送端进程 (绑 NPU 0)
└─ register_mem → HCCL 在 NPU 0 上 pin 内存 + 注册到 HCCS 网络
└─ 把 engine_id + remote_addr 告给接收端
接收端进程 (绑 NPU 1)
└─ connect(engine_id) → HCCS 建立节点间 RDMA 通路
└─ transfer_async(local_addr, remote_addr, len)
→ HCCS 网络自动路由到发送端 NPU 0 的物理内存
→ RDMA READ 到接收端 NPU 1
接收端不需要知道发送端的 NPU 号,因为 remote_addr + HCCS 网络的注册映射已经隐式确定了物理位置。 这就像你写信只需要写对方地址,邮局自己知道怎么送到具体的楼/房间。
所以一个 connect 调用内部做了三件大事:
TCP 握手:两端交换各自的内存信息和通信资源
创建 HcclComm:基于 rank_table 建立 RDMA 通信器
绑定所有已注册内存:把之前 register_mem 产生的所有 mem_handle 绑到这个通信器上
之后 transfer_async 就能直接用这个通信器 + 已绑定的内存做 RDMA READ,不需要再传 NPU 设备号——因为 HcclComm + HcclCommBindMem 已经把一切映射关系都建立好了。