Skip to content

docs: add versionadded 1.0.0 to multi-device APIs#441

Merged
justinchuby merged 41 commits into
onnx:mainfrom
justinchuby:justinchu/versionadded-multi-device-docs
Jun 12, 2026
Merged

docs: add versionadded 1.0.0 to multi-device APIs#441
justinchuby merged 41 commits into
onnx:mainfrom
justinchuby:justinchu/versionadded-multi-device-docs

Conversation

@justinchuby

@justinchuby justinchuby commented Jun 11, 2026

Copy link
Copy Markdown
Member

Summary

Follow-up docs-only change for the multi-device feature merged in #437.
Adds .. versionadded:: 1.0.0 to the public multi-device APIs so the
generated API docs record when they were introduced.

Changes

  • Add .. versionadded:: 1.0.0 to the multi-device dataclasses:
    ModelConfiguration, NodeDeviceConfiguration, ShardingSpec,
    ShardedDim, SimpleShardedDim, IndexToDeviceGroupMapEntry
  • Add .. versionadded:: 1.0.0 to the new methods:
    Node.shard, Node.sharding_of, Node.set_pipeline_stage,
    Model.add_device_configuration, Model.remove_device_configuration
  • Mention the new device_configurations parameter on Node and Model
    with a brief versionadded note

Validation

  • python -m pytest src/onnx_ir/_multi_device_test.py src/onnx_ir/serde_test.py (277 passed)

Copilot AI and others added 30 commits May 22, 2026 23:35
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
…new fields on Model/Node, update protocols and clone

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
…lds to device_configurations

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Bind sharding and node device configurations directly to Value and
ModelConfiguration objects instead of name strings. The proto tensor_name
and configuration_id are derived from value.name / configuration.name on
serialization, so references follow renames and there is a single source
of truth.

- _multi_device.py: object-bound dataclasses, context-aware deserialization
  (value resolution + configuration placeholders), and check_device_configurations
- serde.py: resolve values from node scope and configurations via a model
  post-pass; lossless round-trip for dangling references
- _core.py: Model.add_device_configuration, Node.shard, Node.sharding_of
- _cloner.py: remap sharding Value references through the value map on clone
- tests: rewritten and expanded for the object-bound API

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add tests for serde helper edge cases (unspecified/symbolic dims, empty
tensor_name, value-less specs, empty configuration names), every
check_device_configurations branch (function nodes, passthrough configs,
missing/empty configuration references, structural violations), clone
edge cases (configless nodes, passthrough bytes, unmapped spec values),
convenience-method branches (None value, shapeless value, pipeline-stage
merge, passthrough skip, explicit num_devices), and a dangling
configuration_id round-trip.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add docs/multi_device.md with a prose guide and runnable examples for
the object-bound multi-device API (add_device_configuration, Node.shard,
Node.sharding_of, check_device_configurations) plus serialization
behavior. Register it in the index toctree and add the multi-device
classes and check_device_configurations to the API reference.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Provide the counterpart to add_device_configuration. Accepts a
ModelConfiguration object or its name. By default it removes only the
model-level configuration and leaves node references intact (surfaced by
check_device_configurations); cascade=True also strips every node sharding
that referenced it, across the graph and functions, leaving no dangling
references.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Remove the bytes / raw-proto passthrough from device configuration
serialization. Model.device_configurations now accepts only
ModelConfiguration and Node.device_configurations only
NodeDeviceConfiguration; anything else is rejected at the serialize
boundary with a TypeError. This makes the type annotations honest and
removes the defensive isinstance checks that the passthrough required
across serde, the checker, the cloner, and the convenience methods.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sharding the same value along multiple axes for one configuration now
builds a single ShardingSpec with one ShardedDim per axis (the canonical
ONNX representation for a multi-axis device mesh), instead of emitting
multiple ShardingSpecs for the same tensor. Devices are unioned across
those calls, and sharding the same axis twice raises.

Enrich docs/multi_device.md with a 'Common sharding patterns' section
covering 1D row/column sharding, 2D device meshes, replication across
device groups (index_to_device_group_map), and querying shardings back
from a node. Patterns follow the Shardy sharding vocabulary.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…review)

Address issues found in a thorough review of the multi-device feature:

- Negative axes: ONNX ShardedDimProto.axis allows [-rank, rank-1]
  (counting from the back). Node.shard and check_device_configurations
  previously rejected all negative axes, so a model that round-tripped a
  negative axis would fail validation with a false positive. Both now
  accept negative axes in range, and the shard() axis-conflict check
  normalizes so axis=-1 and axis=rank-1 are recognized as the same axis.
- pipeline_stage: shard() now raises on a conflicting non-None
  pipeline_stage for an existing configuration instead of silently
  overwriting it.
- check_device_configurations now reports the same value being sharded
  along the same axis more than once within a spec.
- Cloner._remap_device_configurations distinguishes an absent value-map
  entry (keep the reference, e.g. outer-scope values) from an entry
  mapped to None (propagate None so the reference is not left stale).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rename check_device_configurations to _check_device_configurations and
stop exporting it from the top-level onnx_ir namespace and from
_multi_device.__all__. The checker is still used internally and by tests
but is no longer part of the public API. Update tests, docstrings, and
the user docs accordingly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sharding splits one tensor across devices; pipeline parallelism instead
places whole blocks of the graph on different devices (e.g. an LLM
decoder with early layers on an NPU and later layers on a GPU). This is
expressed with NodeDeviceConfiguration.pipeline_stage, but there was no
convenience method for pure placement since Node.shard always attaches a
sharding spec.

Node.set_pipeline_stage(configuration, stage) attaches or updates a
placement-only configuration. It shares a single NodeDeviceConfiguration
per configuration with Node.shard, so a node can be both sharded and
staged. Add tests and a 'Pipeline parallelism' docs section with a
10-layer decoder example.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The ONNX multi-device proposal represents replication with ShardingSpec
device entries that are (typically negative) keys into
index_to_device_group_map, each naming a group of real device ids the
shard is replicated across ("Sharding as a Broadcast"). The checker
previously treated every device entry as a direct id and required it to
be in [0, num_devices), so a spec-legal replication (device=[-1],
index_to_device_group_map={-1: [0, 1]}) was wrongly flagged as out of
range, even though it round-trips losslessly.

The checker now resolves group keys via index_to_device_group_map and
validates the group's member device ids, falling back to direct-id
validation otherwise. Update the docs device-group examples to the
spec's negative-key convention and add a pure-replication example.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rework the examples so the device split reflects how heterogeneous
hardware is actually used:

- Distinguish the two reasons to place parts of a model on different
  devices: capacity (split a long layer chain across identical devices,
  e.g. two GPUs, to fit memory) versus capability/affinity (route each
  op to the hardware that runs it best). The pipeline example now uses
  two identical GPUs split by position.
- Rewrite the end-to-end example as a realistic heterogeneous plan: the
  embedding Gather on CPU (memory-bound lookup over a large table),
  the decoder layers on the NPU (dense, quantization-friendly compute),
  and the LM head on the GPU (large vocab projection), with the rationale
  spelled out. Note that tensor parallelism is normally kept within a
  homogeneous device group while pipeline placement spans device types.
- Add a 'Patterns at a glance' lookup table and drop fragile heading
  anchor links that broke the docs build.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Editorial pass: remove a redundant 'grouping' example (covered by the 2D
mesh and querying examples), de-duplicate the object-bound / follows-rename
explanation, compress the pipeline split-strategy and CPU/NPU/GPU rationale
bullets to one line each, and merge the overlapping 'shard + stage' notes.
No example code logic changed; docs build clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rubber-duck API review flagged that 'devices' meant two different things:
Node.shard took device *indices* (ints) while
Model.add_device_configuration took device *names* (strings), inviting
confusion. Rename Node.shard's parameter to device_indices and
Model.add_device_configuration's to device_names, and cross-reference
them in the docstrings. Update call sites in tests and docs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Relocate the proto <-> dataclass conversion functions
(serialize/deserialize_model_configuration,
serialize/deserialize_node_device_configuration and their private
helpers) from _multi_device.py to serde.py, placed in the deserialize
and serialize sections alongside the existing glue. _multi_device.py now
holds only the object-bound dataclasses and the private checker, and no
longer imports onnx — a cleaner data-model module. serde.py exports the
four public converters in __all__. Update test references to call them
via serde.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Give the six public multi-device dataclasses thorough docstrings: a prose
overview per class plus an inline docstring on every field explaining what
each value means (e.g. negative device-group keys vs direct device
indices, empty sharded_dim meaning full replication, pipeline_stage=None
meaning no pipeline participation, the int/SymbolicDim/None dim
convention). Use inline attribute docstrings rather than a napoleon
Attributes: section so autodoc does not emit duplicate field descriptions
in the API reference (docs build stays warning-clean).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The in-function 'from onnx_ir import _multi_device' was originally lazy to
avoid an import cycle. Since the serde functions moved out, _multi_device
has no runtime dependencies (it only imports _core under TYPE_CHECKING),
so _cloner can import it at module level safely. Verified no cycle by
importing _cloner first.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- _cloner: drop the defensive isinstance guard and the unreachable
  'mapped is spec.value' branch from _remap_device_configurations. The
  isinstance was a leftover from before device_configurations was
  tightened to only hold NodeDeviceConfiguration, and clone never maps a
  value to itself. Removing both eliminates the uncovered lines Codecov
  flagged and drops the now-unused _multi_device import.
- serde_test: add tests covering the device-configuration resolution
  post-pass over function nodes and its skip of nodes that have no device
  configuration (the previously-uncovered branches in
  _resolve_node_device_configurations).

The remaining flagged lines are old-onnx hasattr version guards (not
reachable in this environment) and a multi-line docstring/raise, which
are not meaningful coverage gaps.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Checker now enforces the object-bound invariant: a node must reference
  the exact ModelConfiguration object registered on the model, not a
  same-named imposter, and device indices are validated against the
  *registered* configuration's num_devices (finding #1).
- Model.add_device_configuration rejects num_devices < 1 and a
  device_names length that does not match num_devices, per the proto's
  MUST rule and to avoid the num_devices=0 placeholder sentinel collision
  (finding onnx#4).
- Node.shard validates pipeline_stage >= 0, matching set_pipeline_stage
  (finding onnx#5).
- Trim the inaccurate 'often negative' note on device-group keys.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
serialize_node_device_configuration and _serialize_sharding_spec now
raise when configuration or value is None instead of silently omitting
the MUST-be-present configuration_id / tensor_name. The cloner drops a
ShardingSpec whose value maps to None (a value dropped from the clone)
rather than producing an unserializable value=None spec.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Node.replace_input_with and Node.resize_outputs now drop any ShardingSpec
that targeted a value which is no longer one of the node's inputs or
outputs, via a new Node._drop_sharding_for_value helper. Without this,
device_configurations could point at detached values and fail
serialization. A value still referenced elsewhere on the node keeps its
spec.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Per review: SimpleShardedDim.dim now has type int | SymbolicDim (no
None). An unspecified size is SymbolicDim(None) — the field default and
what deserialization produces when the proto sets neither dim_value nor
dim_param — so SymbolicDim(None) round-trips to itself instead of
collapsing to None. Node.shard uses SymbolicDim(None) when the value's
shape is unknown. _multi_device imports SymbolicDim at runtime for the
default factory.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The dataclass field holding device *name* strings is renamed from
'device' to 'device_names' to match the add_device_configuration
constructor kwarg and to stop colliding (by name) with ShardingSpec.device,
which holds device *indices*. Update serde, docstrings, docs, and tests.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
justinchuby and others added 8 commits June 9, 2026 21:12
- serde: warn (instead of silently dropping) when device metadata is
  present but the target proto can't hold it (finding onnx#6).
- Model.remove_device_configuration: when removal is requested by name
  with cascade=True, also drop node configurations bound to same-named
  imposter objects (finding onnx#7).
- Naming cleanup: rename the node_multi_device local to
  device_configurations, extract a module-level _node_label helper, and
  drop the undefined INV-x labels from the checker docstring.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…sses)

Node.__repr__ appends device_configurations when non-empty (default
output unchanged), and Node.display prints a count, so the metadata is
visible when debugging. Node-replacement passes still drop
device_configurations, which is generally correct since a replaced node's
sharding/placement no longer applies; left as-is intentionally.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
GPT-5.5 review: the ModelConfiguration class docstring still pointed
device indices at :attr:`device`, which was renamed to device_names.
Update the cross-reference. (The :attr:`device` references inside
ShardingSpec correctly point at ShardingSpec.device, which is unchanged.)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Rename the following tuple-typed dataclass fields for consistency with
naming conventions that use plural forms for collections:

- NodeDeviceConfiguration.sharding_spec → sharding_specs
- ShardingSpec.sharded_dim → sharded_dims
- ShardedDim.simple_sharding → simple_shardings

This aligns the IR API with Python conventions where plural names indicate
collection-type fields. All references updated across:

- src/onnx_ir/_multi_device.py: Field definitions and usages in validation
- src/onnx_ir/_core.py: API references in node placement methods
- src/onnx_ir/serde.py: Serialization/deserialization functions
- src/onnx_ir/_cloner.py: Configuration cloning
- Test files: Updated test cases and assertions

All 277 tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Update all code examples to use the renamed IR object model fields:
- sharding_spec → sharding_specs
- sharded_dim → sharded_dims
- simple_sharding → simple_shardings

Also update prose references from singular to plural to reflect the
refactored field naming convention. Proto field accesses remain
singular (proto.sharding_spec, proto.sharded_dim, proto.simple_sharding).

Fixes documentation compatibility after field naming refactoring.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add .. versionadded:: 1.0.0 to the public multi-device APIs introduced
in the previous PR:
- ModelConfiguration
- IndexToDeviceGroupMapEntry
- SimpleShardedDim
- ShardedDim
- ShardingSpec
- NodeDeviceConfiguration
- Node.shard
- Node.sharding_of
- Node.set_pipeline_stage
- Model.add_device_configuration
- Model.remove_device_configuration

Also improve documentation for the new device_configurations attributes on
Node and Model, and fix a stale ShardingSpec doc reference to sharded_dims.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@justinchuby justinchuby requested review from a team and titaiwangms as code owners June 11, 2026 19:42
…ded-multi-device-docs

# Conflicts:
#	src/onnx_ir/_core.py
#	src/onnx_ir/_multi_device.py
@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 85.53%. Comparing base (af4bdfe) to head (57d08aa).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #441   +/-   ##
=======================================
  Coverage   85.53%   85.53%           
=======================================
  Files          53       53           
  Lines        7029     7029           
  Branches     1459     1459           
=======================================
  Hits         6012     6012           
  Misses        650      650           
  Partials      367      367           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@justinchuby justinchuby changed the title Object-bound multi-device configuration metadata (IRv11) docs: add versionadded 1.0.0 to multi-device APIs Jun 12, 2026
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby enabled auto-merge (squash) June 12, 2026 17:36
@justinchuby justinchuby merged commit aface84 into onnx:main Jun 12, 2026
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants