
Support Replicate Parallel Operator on CPUs for Realm backend#1640

Open
seemamirch wants to merge 8 commits into flexflow:master from seemamirch:sm/realm-parallel-operators-replicate

Conversation

@seemamirch

@seemamirch seemamirch commented Apr 9, 2026

Description of changes:

Add support for replicate op in distributed training & Realm backend

  • Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
  • Add perform_shard_expansion_for_replicate and _bwd for shard expansion
  • Add build_replicate_invocation in make_dynamic_open_dataflow_graph
  • Add is_replicate_attrs helper and guard replicate in copy_insertion
  • Add ReplicateAttrs to TrainingOperationAttrs
  • Add SumReductionFloat/Double for backward replicate reduce operation
  • Add issue_replicate_bwd in spawn_dynamic_node_invocation
  • Fix per_device_op_state init race condition with direct write
  • Fix .value() calls on optional per_device_op_state across op unary/binary impls
  • Update issue_copy to support optional reduction op
  • Fix Relu to allow discard_copy_degree > 1
  • Add testcase for Replicate Op


Seema Mirchandaney added 2 commits April 9, 2026 15:49
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
@seemamirch
Author

@lockshaw @elliottslaughter - Please review

@seemamirch seemamirch force-pushed the sm/realm-parallel-operators-replicate branch from b18c75d to 3405621 on April 9, 2026 at 23:13
@lockshaw lockshaw self-requested a review April 14, 2026 19:12
Collaborator

@lockshaw lockshaw left a comment


@lockshaw reviewed 9 files and all commit messages, and made 12 comments.
Reviewable status: 9 of 17 files reviewed, 12 unresolved discussions (waiting on seemamirch).


lib/op-attrs/test/src/op-attrs/ops/element_unary.cc line 65 at r1 (raw file):

    }

    SUBCASE("discard copy degree > 1") {

Minor: Ideally add a test case for the correct behavior rather than removing it. I'm also happy to contribute this if you'd prefer.


lib/realm-execution/include/realm-execution/realm_context.h line 69 at r1 (raw file):

  ///\{
  Realm::Event
      issue_copy(ParallelTensorShape const &src_shape,

Minor: It would be good to get a docstring with an explanation of all these parameters at some point


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 1 at r1 (raw file):

#pragma once

For consistency with the rest of the codebase

Suggestion:

#ifndef ...

lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 8 at r1 (raw file):

// Sum reduction for float
struct SumReductionFloat {

Minor: It looks(?) like this API comes from Realm. Is there a link to some docs somewhere that we could include in the docstrings for this file, for people not as familiar with the API?


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):

// Sum reduction for float
struct SumReductionFloat {
  using LHS = float;

Why use aliases here? Since they're just constant (i.e., not dependent on template params or anything) it seems like it would be cleaner to just omit them and use the type directly


lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):

                     DeviceSpecificPtr<PerDeviceOpState> *>
      device_state_map;
  std::vector<Realm::Event> completion_events;

Minor: It seems like the pattern of pair<T, Realm::Event> is becoming pretty common in realm-execution, maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate datastructures, it feels like creating opportunities for one to get out of sync with the other and create a bunch of bugs
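Roughly the kind of structured future type I have in mind, as a sketch only (`EventFuture` is a hypothetical name, and `Event` here stands in for `Realm::Event`):

```cpp
// Hypothetical sketch: keep the value and the event that guards it in one
// object, so the two parallel containers can never drift out of sync.
template <typename Event, typename T>
struct EventFuture {
  T value;      // payload produced by the deferred work
  Event ready;  // event that must trigger before `value` may be read
};
```

A single `std::vector` of these would then replace the separate `device_state_map` and `completion_events` structures.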


lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml line 30 at r1 (raw file):

[[values]]
type = "::FlexFlow::ReplicateAttrs"

Isn't ReplicateAttrs already part of PCGOperatorAttrs?


lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc line 36 at r1 (raw file):

    return true;
  }
  return false;

In general we avoid assigning to refs unless it has a provable (i.e., profiling has been run) performance benefit as it creates more opportunities for lifetime/memory issues

Suggestion:

  TrainingOperationAttrs op_attrs = i.node_attrs.op_attrs.value();
  return op_attrs.is_replicate();

lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 25 at r1 (raw file):

  // find the layer that produces this tensor
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {

I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h


lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 41 at r1 (raw file):

                            parallel_tensor_guid_t const &tensor) {
  std::unordered_map<parallel_layer_guid_t, TensorSlotName> result;
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {

I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h


lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 79 at r1 (raw file):

static DynamicNodeInvocation
    build_replicate_invocation(parallel_layer_guid_t const &layer,
                               ParallelLayerAttrs const &attrs,

If this should only handle ReplicateAttrs, change the type of the attrs parameter to ReplicateAttrs


lib/task-spec/src/task-spec/ops/impl/element_binary.cc line 39 at r1 (raw file):

  ProfilingSettings profiling = acc.get_profiling_settings();
  DeviceType kernel_device_type = acc.get_kernel_device_type();
  std::optional<ElementBinaryPerDeviceState> per_device_state =

Am I correct in understanding that this is because CPU implementations for these ops don't need the per device state?

Collaborator

@elliottslaughter elliottslaughter left a comment


I got part of the way through reviewing this and then realized that it's the old PR 😅. I'll pick back up on the new PR next week, but in the meantime am publishing my incomplete thoughts on this one, which may or may not apply to the other.

@elliottslaughter reviewed 1 file and made 10 comments.
Reviewable status: 8 of 18 files reviewed, 19 unresolved discussions (waiting on lockshaw and seemamirch).


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why use aliases here? Since they're just constant (i.e., not dependent on template params or anything) it seems like it would be cleaner to just omit them and use the type directly

LHS/RHS are part of the Realm API, Realm expects them to be defined.
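For readers not familiar with the convention: a Realm/Legion reduction operator class is expected to expose the element types as `LHS`/`RHS` aliases plus `identity`, `apply`, and `fold` members, roughly like this (a from-memory sketch of the convention, not the exact code in this PR):

```cpp
// Sketch of the member shape Realm's reduction-op registration expects.
struct SumReductionFloat {
  using LHS = float;  // destination element type
  using RHS = float;  // type of the value being folded in
  static constexpr float identity = 0.0f;  // identity of the reduction

  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    lhs += rhs;  // EXCLUSIVE == false would need an atomic update
  }

  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    rhs1 += rhs2;
  }
};
```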


lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Minor: It seems like the pattern of pair<T, Realm::Event> is becoming pretty common in realm-execution, maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate datastructures, it feels like creating opportunities for one to get out of sync with the other and create a bunch of bugs

All of this code can be removed, the Realm wrapper tracks the outstanding events so we never need to do so manually. This is something that would be required only if we weren't using a Realm API wrapper.

The larger point may apply elsewhere but we can revisit if we see other examples.


lib/realm-execution/include/realm-execution/realm_context.h line 77 at r4 (raw file):

                 int priority = 0,
                 std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
                 bool exclusive = false);

Not sure this is a huge deal but exclusive is only relevant when redop_id is supplied. As written, this is idiomatic for Realm, but it doesn't capture the invariant faithfully (a user could be confused and pass exclusive when it's meaningless to do so). Probably won't cause incorrect execution even if they get this wrong. Thoughts @lockshaw?
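One way to capture the invariant in the type, as a sketch (`ReductionSpec` is a hypothetical name, and the `ReductionOpID` alias stands in for `Realm::ReductionOpID`):

```cpp
#include <optional>

using ReductionOpID = int;  // stand-in for Realm::ReductionOpID

// Hypothetical sketch: `exclusive` is only meaningful alongside a
// reduction op, so grouping them makes the invalid state unrepresentable.
struct ReductionSpec {
  ReductionOpID redop_id;
  bool exclusive = false;
};

// issue_copy would then take a single std::optional<ReductionSpec>
// parameter defaulting to std::nullopt, instead of separate
// redop_id and exclusive parameters.
```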


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 4 at r4 (raw file):

#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#include "op-attrs/datatype.dtg.h"
#include <realm.h>

We need to include our Realm wrapper instead of upstream, to ensure we always go through PRealm (when enabled).


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 30 at r4 (raw file):

      lhs += rhs;
    } else {
      // Atomic float add via CAS loop

Instead of building this ourselves I'm inclined to copy-and-paste the Legion reduction operators, which are battle tested (possibly with some of the code paths for old compilers cut out): https://gitlab.com/StanfordLegion/legion/-/blob/master/runtime/legion/api/redop.inl?ref_type=heads#L1333-1365

E.g., Legion has a code path for std::atomic_ref which is a lot cleaner than this.

Also Legion has the GPU versions we could copy from as well.

I think we should also try to separate the task implementation from the IDs for reductions that we may register. There is no reason to be including all of the task implementations everywhere we want an ID, and the reductions themselves are only required at the point where we register them with Realm.
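For reference, the CAS-loop shape under discussion looks roughly like this on a `std::atomic<float>` (illustrative only; the PR works on raw buffers through Realm accessors, and C++20's `std::atomic_ref` would allow the same update on plain `float` storage, which is the cleaner path Legion takes on newer compilers):

```cpp
#include <atomic>

// CAS loop: retry compare_exchange until no other thread raced the update.
inline void atomic_add(std::atomic<float> &target, float rhs) {
  float expected = target.load(std::memory_order_relaxed);
  while (!target.compare_exchange_weak(expected, expected + rhs,
                                       std::memory_order_relaxed)) {
    // on failure, compare_exchange_weak reloads `expected` for us
  }
}
```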


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 143 at r4 (raw file):

 * \throws PANIC if no sum reduction is registered for the given datatype
 */
inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {

No need for this to live in a header file.


lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 70 at r4 (raw file):

  // wait for all init tasks — direct write to *result_ptr happens
  // before each init task event fires so result is ready after this
  Realm::Event::merge_events(completion_events).wait();

This can be done via ctx.get_outstanding_events() if you want to block on work issued up to this point.


lib/realm-execution/src/realm-execution/pcg_instance.cc line 219 at r4 (raw file):

  };

  // issue_replicate_bwd lambda

Would prefer not to include comments that restate the source code.


lib/realm-execution/src/realm-execution/pcg_instance.cc line 221 at r4 (raw file):

  // issue_replicate_bwd lambda
  auto issue_replicate_bwd = [&]() {
    std::optional<DynamicValueAttrs> output_grad_opt;

Marking this for myself: I want to figure out if it's possible to write this in a way that's agnostic to the specific slot names of the inputs, which are being hard-coded here. Without thinking about this in depth yet, it feels like it could be largely agnostic.

