Support Replicate Parallel Operator on CPUs for Realm backend #1640
seemamirch wants to merge 8 commits into
Conversation
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
@lockshaw @elliottslaughter - Please review
b18c75d to 3405621
lockshaw
left a comment
@lockshaw reviewed 9 files and all commit messages, and made 12 comments.
Reviewable status: 9 of 17 files reviewed, 12 unresolved discussions (waiting on seemamirch).
lib/op-attrs/test/src/op-attrs/ops/element_unary.cc line 65 at r1 (raw file):
}
SUBCASE("discard copy degree > 1") {
Minor: Ideally add a test case for the correct behavior rather than removing it. I'm also happy to contribute this if you'd prefer
lib/realm-execution/include/realm-execution/realm_context.h line 69 at r1 (raw file):
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
Minor: It would be good to get a docstring with an explanation of all these parameters at some point
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 1 at r1 (raw file):
#pragma once
For consistency with the rest of the codebase
Suggestion:
#ifndef ...
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 8 at r1 (raw file):
// Sum reduction for float
struct SumReductionFloat {
Minor: It looks(?) like this API comes from Realm. Is there a link to some docs somewhere that we could include in the docstrings for this file, for people not as familiar with the API?
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):
// Sum reduction for float
struct SumReductionFloat {
  using LHS = float;
Why use aliases here? Since they're just constant (i.e., not dependent on template params or anything) it seems like it would be cleaner to just omit them and use the type directly
lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):
DeviceSpecificPtr<PerDeviceOpState> *> device_state_map;
std::vector<Realm::Event> completion_events;
Minor: It seems like the pattern of pair<T, Realm::Event> is becoming pretty common in realm-execution, maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate datastructures, it feels like creating opportunities for one to get out of sync with the other and create a bunch of bugs
lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml line 30 at r1 (raw file):
[[values]]
type = "::FlexFlow::ReplicateAttrs"
Isn't ReplicateAttrs already part of PCGOperatorAttrs?
lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc line 36 at r1 (raw file):
  return true;
}
return false;
In general we avoid assigning to refs unless it has a provable (i.e., profiling has been run) performance benefit as it creates more opportunities for lifetime/memory issues
Suggestion:
TrainingOperationAttrs op_attrs = i.node_attrs.op_attrs.value();
return op_attrs.is_replicate();
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 25 at r1 (raw file):
// find the layer that produces this tensor
for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {
I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 41 at r1 (raw file):
parallel_tensor_guid_t const &tensor) {
  std::unordered_map<parallel_layer_guid_t, TensorSlotName> result;
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {
I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 79 at r1 (raw file):
static DynamicNodeInvocation build_replicate_invocation(parallel_layer_guid_t const &layer,
                                                        ParallelLayerAttrs const &attrs,
If this should only handle ReplicateAttrs, change the type of the attrs parameter to ReplicateAttrs
lib/task-spec/src/task-spec/ops/impl/element_binary.cc line 39 at r1 (raw file):
ProfilingSettings profiling = acc.get_profiling_settings();
DeviceType kernel_device_type = acc.get_kernel_device_type();
std::optional<ElementBinaryPerDeviceState> per_device_state =
Am I correct in understanding that this is because CPU implementations for these ops don't need the per device state?
elliottslaughter
left a comment
I got part of the way through reviewing this and then realized that it's the old PR 😅. I'll pick back up on the new PR next week, but in the meantime am publishing my incomplete thoughts on this one, which may or may not apply to the other.
@elliottslaughter reviewed 1 file and made 10 comments.
Reviewable status: 8 of 18 files reviewed, 19 unresolved discussions (waiting on lockshaw and seemamirch).
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why use aliases here? Since they're just constant (i.e., not dependent on template params or anything) it seems like it would be cleaner to just omit them and use the type directly
LHS/RHS are part of the Realm API, Realm expects them to be defined.
lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Minor: It seems like the pattern of `pair<T, Realm::Event>` is becoming pretty common in `realm-execution`, maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate datastructures, it feels like creating opportunities for one to get out of sync with the other and create a bunch of bugs
All of this code can be removed, the Realm wrapper tracks the outstanding events so we never need to do so manually. This is something that would be required only if we weren't using a Realm API wrapper.
The larger point may apply elsewhere but we can revisit if we see other examples.
lib/realm-execution/include/realm-execution/realm_context.h line 77 at r4 (raw file):
int priority = 0,
std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
bool exclusive = false);
Not sure this is a huge deal but exclusive is only relevant when redop_id is supplied. As written, this is idiomatic for Realm, but it doesn't capture the invariant faithfully (a user could be confused and pass exclusive when it's meaningless to do so). Probably won't cause incorrect execution even if they get this wrong. Thoughts @lockshaw?
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 4 at r4 (raw file):
#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#include "op-attrs/datatype.dtg.h"
#include <realm.h>
We need to include our Realm wrapper instead of upstream, to ensure we always go through PRealm (when enabled).
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 30 at r4 (raw file):
  lhs += rhs;
} else {
  // Atomic float add via CAS loop
Instead of building this ourselves I'm inclined to copy-and-paste the Legion reduction operators, which are battle tested (possibly with some of the code paths for old compilers cut out): https://gitlab.com/StanfordLegion/legion/-/blob/master/runtime/legion/api/redop.inl?ref_type=heads#L1333-1365
E.g., Legion has a code path for std::atomic_ref which is a lot cleaner than this.
Also Legion has the GPU versions we could copy from as well.
I think we should also try to separate the task implementation from the IDs for reductions that we may register. There is no reason to be including all of the task implementations everywhere we want an ID, and the reductions themselves are only required at the point where we register them with Realm.
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 143 at r4 (raw file):
 * \throws PANIC if no sum reduction is registered for the given datatype */
inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
No need for this to live in a header file.
lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 70 at r4 (raw file):
// wait for all init tasks — direct write to *result_ptr happens
// before each init task event fires so result is ready after this
Realm::Event::merge_events(completion_events).wait();
This can be done via ctx.get_outstanding_events() if you want to block on work issued up to this point.
lib/realm-execution/src/realm-execution/pcg_instance.cc line 219 at r4 (raw file):
}; // issue_replicate_bwd lambda
Would prefer not to include comments that restate the source code.
lib/realm-execution/src/realm-execution/pcg_instance.cc line 221 at r4 (raw file):
// issue_replicate_bwd lambda
auto issue_replicate_bwd = [&]() {
  std::optional<DynamicValueAttrs> output_grad_opt;
Marking this for myself: I want to figure out if it's possible to write this in a way that's agnostic to the specific slot names of the inputs, which are being hard-coded here. Without thinking about this in depth yet, it feels like it could be largely agnostic.
Description of changes:
Add support for replicate op in distributed training & Realm backend