docs: add versionadded 1.0.0 to multi-device APIs#441
Merged
justinchuby merged 41 commits intoJun 12, 2026
Merged
Conversation
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
…new fields on Model/Node, update protocols and clone Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
…lds to device_configurations Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Bind sharding and node device configurations directly to Value and ModelConfiguration objects instead of name strings. The proto tensor_name and configuration_id are derived from value.name / configuration.name on serialization, so references follow renames and there is a single source of truth. - _multi_device.py: object-bound dataclasses, context-aware deserialization (value resolution + configuration placeholders), and check_device_configurations - serde.py: resolve values from node scope and configurations via a model post-pass; lossless round-trip for dangling references - _core.py: Model.add_device_configuration, Node.shard, Node.sharding_of - _cloner.py: remap sharding Value references through the value map on clone - tests: rewritten and expanded for the object-bound API Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add tests for serde helper edge cases (unspecified/symbolic dims, empty tensor_name, value-less specs, empty configuration names), every check_device_configurations branch (function nodes, passthrough configs, missing/empty configuration references, structural violations), clone edge cases (configless nodes, passthrough bytes, unmapped spec values), convenience-method branches (None value, shapeless value, pipeline-stage merge, passthrough skip, explicit num_devices), and a dangling configuration_id round-trip. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add docs/multi_device.md with a prose guide and runnable examples for the object-bound multi-device API (add_device_configuration, Node.shard, Node.sharding_of, check_device_configurations) plus serialization behavior. Register it in the index toctree and add the multi-device classes and check_device_configurations to the API reference. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Provide the counterpart to add_device_configuration. Accepts a ModelConfiguration object or its name. By default it removes only the model-level configuration and leaves node references intact (surfaced by check_device_configurations); cascade=True also strips every node sharding that referenced it, across the graph and functions, leaving no dangling references. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Remove the bytes / raw-proto passthrough from device configuration serialization. Model.device_configurations now accepts only ModelConfiguration and Node.device_configurations only NodeDeviceConfiguration; anything else is rejected at the serialize boundary with a TypeError. This makes the type annotations honest and removes the defensive isinstance checks that the passthrough required across serde, the checker, the cloner, and the convenience methods. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sharding the same value along multiple axes for one configuration now builds a single ShardingSpec with one ShardedDim per axis (the canonical ONNX representation for a multi-axis device mesh), instead of emitting multiple ShardingSpecs for the same tensor. Devices are unioned across those calls, and sharding the same axis twice raises. Enrich docs/multi_device.md with a 'Common sharding patterns' section covering 1D row/column sharding, 2D device meshes, replication across device groups (index_to_device_group_map), and querying shardings back from a node. Patterns follow the Shardy sharding vocabulary. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…review) Address issues found in a thorough review of the multi-device feature: - Negative axes: ONNX ShardedDimProto.axis allows [-rank, rank-1] (counting from the back). Node.shard and check_device_configurations previously rejected all negative axes, so a model that round-tripped a negative axis would fail validation with a false positive. Both now accept negative axes in range, and the shard() axis-conflict check normalizes so axis=-1 and axis=rank-1 are recognized as the same axis. - pipeline_stage: shard() now raises on a conflicting non-None pipeline_stage for an existing configuration instead of silently overwriting it. - check_device_configurations now reports the same value being sharded along the same axis more than once within a spec. - Cloner._remap_device_configurations distinguishes an absent value-map entry (keep the reference, e.g. outer-scope values) from an entry mapped to None (propagate None so the reference is not left stale). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rename check_device_configurations to _check_device_configurations and stop exporting it from the top-level onnx_ir namespace and from _multi_device.__all__. The checker is still used internally and by tests but is no longer part of the public API. Update tests, docstrings, and the user docs accordingly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sharding splits one tensor across devices; pipeline parallelism instead places whole blocks of the graph on different devices (e.g. an LLM decoder with early layers on an NPU and later layers on a GPU). This is expressed with NodeDeviceConfiguration.pipeline_stage, but there was no convenience method for pure placement since Node.shard always attaches a sharding spec. Node.set_pipeline_stage(configuration, stage) attaches or updates a placement-only configuration. It shares a single NodeDeviceConfiguration per configuration with Node.shard, so a node can be both sharded and staged. Add tests and a 'Pipeline parallelism' docs section with a 10-layer decoder example. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The ONNX multi-device proposal represents replication with ShardingSpec
device entries that are (typically negative) keys into
index_to_device_group_map, each naming a group of real device ids the
shard is replicated across ("Sharding as a Broadcast"). The checker
previously treated every device entry as a direct id and required it to
be in [0, num_devices), so a spec-legal replication (device=[-1],
index_to_device_group_map={-1: [0, 1]}) was wrongly flagged as out of
range, even though it round-trips losslessly.
The checker now resolves group keys via index_to_device_group_map and
validates the group's member device ids, falling back to direct-id
validation otherwise. Update the docs device-group examples to the
spec's negative-key convention and add a pure-replication example.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rework the examples so the device split reflects how heterogeneous hardware is actually used: - Distinguish the two reasons to place parts of a model on different devices: capacity (split a long layer chain across identical devices, e.g. two GPUs, to fit memory) versus capability/affinity (route each op to the hardware that runs it best). The pipeline example now uses two identical GPUs split by position. - Rewrite the end-to-end example as a realistic heterogeneous plan: the embedding Gather on CPU (memory-bound lookup over a large table), the decoder layers on the NPU (dense, quantization-friendly compute), and the LM head on the GPU (large vocab projection), with the rationale spelled out. Note that tensor parallelism is normally kept within a homogeneous device group while pipeline placement spans device types. - Add a 'Patterns at a glance' lookup table and drop fragile heading anchor links that broke the docs build. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Editorial pass: remove a redundant 'grouping' example (covered by the 2D mesh and querying examples), de-duplicate the object-bound / follows-rename explanation, compress the pipeline split-strategy and CPU/NPU/GPU rationale bullets to one line each, and merge the overlapping 'shard + stage' notes. No example code logic changed; docs build clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rubber-duck API review flagged that 'devices' meant two different things: Node.shard took device *indices* (ints) while Model.add_device_configuration took device *names* (strings), inviting confusion. Rename Node.shard's parameter to device_indices and Model.add_device_configuration's to device_names, and cross-reference them in the docstrings. Update call sites in tests and docs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Relocate the proto <-> dataclass conversion functions (serialize/deserialize_model_configuration, serialize/deserialize_node_device_configuration and their private helpers) from _multi_device.py to serde.py, placed in the deserialize and serialize sections alongside the existing glue. _multi_device.py now holds only the object-bound dataclasses and the private checker, and no longer imports onnx — a cleaner data-model module. serde.py exports the four public converters in __all__. Update test references to call them via serde. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Give the six public multi-device dataclasses thorough docstrings: a prose overview per class plus an inline docstring on every field explaining what each value means (e.g. negative device-group keys vs direct device indices, empty sharded_dim meaning full replication, pipeline_stage=None meaning no pipeline participation, the int/SymbolicDim/None dim convention). Use inline attribute docstrings rather than a napoleon Attributes: section so autodoc does not emit duplicate field descriptions in the API reference (docs build stays warning-clean). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The in-function 'from onnx_ir import _multi_device' was originally lazy to avoid an import cycle. Since the serde functions moved out, _multi_device has no runtime dependencies (it only imports _core under TYPE_CHECKING), so _cloner can import it at module level safely. Verified no cycle by importing _cloner first. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- _cloner: drop the defensive isinstance guard and the unreachable 'mapped is spec.value' branch from _remap_device_configurations. The isinstance was a leftover from before device_configurations was tightened to only hold NodeDeviceConfiguration, and clone never maps a value to itself. Removing both eliminates the uncovered lines Codecov flagged and drops the now-unused _multi_device import. - serde_test: add tests covering the device-configuration resolution post-pass over function nodes and its skip of nodes that have no device configuration (the previously-uncovered branches in _resolve_node_device_configurations). The remaining flagged lines are old-onnx hasattr version guards (not reachable in this environment) and a multi-line docstring/raise, which are not meaningful coverage gaps. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Checker now enforces the object-bound invariant: a node must reference the exact ModelConfiguration object registered on the model, not a same-named imposter, and device indices are validated against the *registered* configuration's num_devices (finding #1). - Model.add_device_configuration rejects num_devices < 1 and a device_names length that does not match num_devices, per the proto's MUST rule and to avoid the num_devices=0 placeholder sentinel collision (finding onnx#4). - Node.shard validates pipeline_stage >= 0, matching set_pipeline_stage (finding onnx#5). - Trim the inaccurate 'often negative' note on device-group keys. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
serialize_node_device_configuration and _serialize_sharding_spec now raise when configuration or value is None instead of silently omitting the MUST-be-present configuration_id / tensor_name. The cloner drops a ShardingSpec whose value maps to None (a value dropped from the clone) rather than producing an unserializable value=None spec. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Node.replace_input_with and Node.resize_outputs now drop any ShardingSpec that targeted a value which is no longer one of the node's inputs or outputs, via a new Node._drop_sharding_for_value helper. Without this, device_configurations could point at detached values and fail serialization. A value still referenced elsewhere on the node keeps its spec. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Per review: SimpleShardedDim.dim now has type int | SymbolicDim (no None). An unspecified size is SymbolicDim(None) — the field default and what deserialization produces when the proto sets neither dim_value nor dim_param — so SymbolicDim(None) round-trips to itself instead of collapsing to None. Node.shard uses SymbolicDim(None) when the value's shape is unknown. _multi_device imports SymbolicDim at runtime for the default factory. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The dataclass field holding device *name* strings is renamed from 'device' to 'device_names' to match the add_device_configuration constructor kwarg and to stop colliding (by name) with ShardingSpec.device, which holds device *indices*. Update serde, docstrings, docs, and tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- serde: warn (instead of silently dropping) when device metadata is present but the target proto can't hold it (finding onnx#6). - Model.remove_device_configuration: when removal is requested by name with cascade=True, also drop node configurations bound to same-named imposter objects (finding onnx#7). - Naming cleanup: rename the node_multi_device local to device_configurations, extract a module-level _node_label helper, and drop the undefined INV-x labels from the checker docstring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…sses) Node.__repr__ appends device_configurations when non-empty (default output unchanged), and Node.display prints a count, so the metadata is visible when debugging. Node-replacement passes still drop device_configurations, which is generally correct since a replaced node's sharding/placement no longer applies; left as-is intentionally. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
GPT-5.5 review: the ModelConfiguration class docstring still pointed device indices at :attr:`device`, which was renamed to device_names. Update the cross-reference. (The :attr:`device` references inside ShardingSpec correctly point at ShardingSpec.device, which is unchanged.) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rename the following tuple-typed dataclass fields for consistency with naming conventions that use plural forms for collections: - NodeDeviceConfiguration.sharding_spec → sharding_specs - ShardingSpec.sharded_dim → sharded_dims - ShardedDim.simple_sharding → simple_shardings This aligns the IR API with Python conventions where plural names indicate collection-type fields. All references updated across: - src/onnx_ir/_multi_device.py: Field definitions and usages in validation - src/onnx_ir/_core.py: API references in node placement methods - src/onnx_ir/serde.py: Serialization/deserialization functions - src/onnx_ir/_cloner.py: Configuration cloning - Test files: Updated test cases and assertions All 277 tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Update all code examples to use the renamed IR object model fields: - sharding_spec → sharding_specs - sharded_dim → sharded_dims - simple_sharding → simple_shardings Also update prose references from singular to plural to reflect the refactored field naming convention. Proto field accesses remain singular (proto.sharding_spec, proto.sharded_dim, proto.simple_sharding). Fixes documentation compatibility after field naming refactoring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add .. versionadded:: 1.0.0 to the public multi-device APIs introduced in the previous PR: - ModelConfiguration - IndexToDeviceGroupMapEntry - SimpleShardedDim - ShardedDim - ShardingSpec - NodeDeviceConfiguration - Node.shard - Node.sharding_of - Node.set_pipeline_stage - Model.add_device_configuration - Model.remove_device_configuration Also improve documentation for the new device_configurations attributes on Node and Model, and fix a stale ShardingSpec doc reference to sharded_dims. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ded-multi-device-docs # Conflicts: # src/onnx_ir/_core.py # src/onnx_ir/_multi_device.py
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #441 +/- ##
=======================================
Coverage 85.53% 85.53%
=======================================
Files 53 53
Lines 7029 7029
Branches 1459 1459
=======================================
Hits 6012 6012
Misses 650 650
Partials 367 367 ☔ View full report in Codecov by Harness. |
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms
approved these changes
Jun 12, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up docs-only change for the multi-device feature merged in #437.
Adds
.. versionadded:: 1.0.0to the public multi-device APIs so thegenerated API docs record when they were introduced.
Changes
.. versionadded:: 1.0.0to the multi-device dataclasses:ModelConfiguration,NodeDeviceConfiguration,ShardingSpec,ShardedDim,SimpleShardedDim,IndexToDeviceGroupMapEntry.. versionadded:: 1.0.0to the new methods:Node.shard,Node.sharding_of,Node.set_pipeline_stage,Model.add_device_configuration,Model.remove_device_configurationdevice_configurationsparameter onNodeandModelwith a brief
versionaddednoteValidation
python -m pytest src/onnx_ir/_multi_device_test.py src/onnx_ir/serde_test.py(277 passed)