Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/op-attrs/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees(
ElementUnaryAttrs const &attrs,
ParallelTensorDimDegrees const &input_degrees) {
ASSERT(input_degrees.sum_degree.value == 1);
ASSERT(input_degrees.discard_copy_degree.value == 1);

return input_degrees;
}
Expand Down
8 changes: 0 additions & 8 deletions lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) {
SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)));
}

SUBCASE("discard copy degree > 1") {
positive_int degree = 2_p;

CHECK_THROWS(get_output_shape(
attrs,
make_input(
SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)));
}
}
}
19 changes: 11 additions & 8 deletions lib/realm-execution/include/realm-execution/realm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ struct RealmContext {
int priority = 0);
///\}

/** \name Data movement */
/** \name Data movement and reduction */
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0);
Realm::Event
issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0,
std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
bool exclusive = false);
///\}

/** \name Instance management */
Expand Down
154 changes: 154 additions & 0 deletions lib/realm-execution/include/realm-execution/tasks/realm_reduction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#include "op-attrs/datatype.dtg.h"
#include <realm.h>

namespace FlexFlow {

/**
 * \brief Realm sum reduction for float
 * \see https://legion.stanford.edu/tutorial/realm/reductions.html
 *
 * Provides the LHS/RHS types, the identity element, and the apply/fold
 * operations of a sum reduction over single-precision values.
 */
struct SumReductionFloat {
  using LHS = float;
  using RHS = float;

  /** \brief Identity element for addition (0.0) */
  static constexpr RHS identity = 0.0f;

  /**
   * \brief Apply reduction: lhs += rhs
   * \tparam EXCLUSIVE If true, plain addition; if false, atomic
   *         compare-exchange loop
   * \param lhs Left-hand side accumulator (modified in place)
   * \param rhs Value to add
   */
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    accumulate<EXCLUSIVE>(lhs, rhs);
  }

  /**
   * \brief Fold two RHS values: rhs1 += rhs2
   * \tparam EXCLUSIVE If true, plain addition; if false, atomic
   *         compare-exchange loop
   * \param rhs1 Accumulator (modified in place)
   * \param rhs2 Value to fold in
   */
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    accumulate<EXCLUSIVE>(rhs1, rhs2);
  }

private:
  /**
   * \brief Shared implementation of apply/fold: target += addend
   *
   * The non-exclusive path uses the generic __atomic builtins, which operate
   * on the float object directly. This avoids reading an inactive union
   * member or casting float* to an integer pointer, and makes the read of
   * the shared accumulator itself atomic.
   */
  template <bool EXCLUSIVE>
  static void accumulate(float &target, float addend) {
    if constexpr (EXCLUSIVE) {
      target += addend;
    } else {
      float observed;
      __atomic_load(&target, &observed, __ATOMIC_RELAXED);
      float updated;
      do {
        updated = observed + addend;
        // On failure the builtin refreshes `observed` with the value that is
        // currently stored, so the next iteration recomputes from fresh data.
      } while (!__atomic_compare_exchange(&target,
                                          &observed,
                                          &updated,
                                          /*weak=*/true,
                                          __ATOMIC_SEQ_CST,
                                          __ATOMIC_RELAXED));
    }
  }
};

/**
 * \brief Realm sum reduction for double
 * \see https://legion.stanford.edu/tutorial/realm/reductions.html
 *
 * Provides the LHS/RHS types, the identity element, and the apply/fold
 * operations of a sum reduction over double-precision values.
 */
struct SumReductionDouble {
  using LHS = double;
  using RHS = double;

  /** \brief Identity element for addition (0.0) */
  static constexpr RHS identity = 0.0;

  /**
   * \brief Apply reduction: lhs += rhs
   * \tparam EXCLUSIVE If true, plain addition; if false, atomic
   *         compare-exchange loop
   * \param lhs Left-hand side accumulator (modified in place)
   * \param rhs Value to add
   */
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    accumulate<EXCLUSIVE>(lhs, rhs);
  }

  /**
   * \brief Fold two RHS values: rhs1 += rhs2
   * \tparam EXCLUSIVE If true, plain addition; if false, atomic
   *         compare-exchange loop
   * \param rhs1 Accumulator (modified in place)
   * \param rhs2 Value to fold in
   */
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    accumulate<EXCLUSIVE>(rhs1, rhs2);
  }

private:
  /**
   * \brief Shared implementation of apply/fold: target += addend
   *
   * The non-exclusive path uses the generic __atomic builtins, which operate
   * on the double object directly. This avoids reading an inactive union
   * member or casting double* to an integer pointer, and makes the read of
   * the shared accumulator itself atomic.
   */
  template <bool EXCLUSIVE>
  static void accumulate(double &target, double addend) {
    if constexpr (EXCLUSIVE) {
      target += addend;
    } else {
      double observed;
      __atomic_load(&target, &observed, __ATOMIC_RELAXED);
      double updated;
      do {
        updated = observed + addend;
        // On failure the builtin refreshes `observed` with the value that is
        // currently stored, so the next iteration recomputes from fresh data.
      } while (!__atomic_compare_exchange(&target,
                                          &observed,
                                          &updated,
                                          /*weak=*/true,
                                          __ATOMIC_SEQ_CST,
                                          __ATOMIC_RELAXED));
    }
  }
};

/**
 * \brief Numeric IDs under which the sum reductions are identified
 * \warning Each ID has to stay distinct from every other registered
 *          reduction op ID
 */
enum SumReductionOpIDs {
  /** ID of the float sum reduction */
  REDOP_SUM_FLOAT = 1,
  /** ID of the double sum reduction */
  REDOP_SUM_DOUBLE = 2,
};

/**
 * \brief Maps a datatype to the reduction op ID of its sum reduction
 * \param dtype Datatype whose sum reduction is requested
 * \return REDOP_SUM_FLOAT for DataType::FLOAT, REDOP_SUM_DOUBLE for
 *         DataType::DOUBLE
 * \throws PANIC for any other datatype
 */
inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
  if (dtype == DataType::FLOAT) {
    return REDOP_SUM_FLOAT;
  }
  if (dtype == DataType::DOUBLE) {
    return REDOP_SUM_DOUBLE;
  }
  // Every remaining datatype has no sum reduction defined here.
  PANIC("no sum reduction registered for datatype {}", dtype);
}
} // namespace FlexFlow
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
std::unordered_map<DynamicNodeInvocation,
DeviceSpecificPtr<PerDeviceOpState> *>
device_state_map;
std::vector<Realm::Event> completion_events;
for (DynamicNodeInvocation const &invocation : dg.invocations) {
Realm::Processor target_proc = ctx.map_device_coord_to_processor(
assert_unwrap(invocation.node_attrs.device_coord));
Expand All @@ -56,14 +57,17 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
precondition);

if (completion_event.has_value()) {
completion_events.push_back(completion_event.value());
device_state_map.insert(std::pair{invocation, device_state_ptr});
} else {
// Task doesn't require initialization, clean up and don't store result
delete device_state_ptr;
}
}

ctx.get_outstanding_events().wait();
// Wait for all init tasks. Each init task writes directly to *result_ptr
// before its completion event fires, so every result is ready once this
// wait returns.
Realm::Event::merge_events(completion_events).wait();

auto deref = [](DeviceSpecificPtr<PerDeviceOpState> *const &p) { return *p; };
std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr<PerDeviceOpState>>
Expand Down
49 changes: 49 additions & 0 deletions lib/realm-execution/src/realm-execution/pcg_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "realm-execution/instance_allocation.h"
#include "realm-execution/realm_context.h"
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tensor_instance_backing.h"
#include "task-spec/dynamic_graph/copy_insertion.h"
#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
Expand Down Expand Up @@ -215,13 +216,61 @@ static Realm::Event spawn_dynamic_node_invocation(
precondition);
};

// issue_replicate_bwd: backward pass of a replicate node.
// Sum-reduces every replica of the output gradient into the single input
// gradient instance by issuing one reduction copy per entry of the output
// gradient's mapping. Returns the event of the last copy issued (or
// `precondition` unchanged when the mapping has no entries).
auto issue_replicate_bwd = [&]() {
// Locate the GRADIENT-role input; if several inputs match, the last one
// visited is kept because the loop does not stop at the first hit.
std::optional<DynamicValueAttrs> output_grad_opt;
for (auto const &[slot, value] : invocation.inputs) {
if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) {
output_grad_opt = value;
}
}
// assert_unwrap presumably fails when no gradient input was found
DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt);
// The invocation's only output is the destination of the reduction.
DynamicValueAttrs input_grad = get_only(invocation.outputs).second;
Realm::RegionInstance dst_inst =
tensor_instance_backing.backing.at(input_grad).first;

// The reduction operator is selected by the output gradient's datatype.
Realm::ReductionOpID redop_id = get_sum_reduction_op_id(
assert_unwrap(output_grad.parallel_tensor_shape).data_type);

// chain reductions sequentially to avoid write races on dst: each copy
// waits on the event returned by the previous one, starting from
// `precondition`
Realm::Event e = precondition;
for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) {
// Build the lookup key for this replica: the output gradient narrowed to
// the single (p, m) mapping entry, with shard_coord set to p.
DynamicValueAttrs replica_key = output_grad;
replica_key.mapping =
bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{{p, m}};
replica_key.shard_coord = p;

Realm::RegionInstance src_inst =
tensor_instance_backing.backing.at(replica_key).first;

// Trailing positional arguments: wait_on = e, priority = 0,
// redop_id = redop_id, exclusive = false.
e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape),
src_inst,
assert_unwrap(input_grad.parallel_tensor_shape),
dst_inst,
Realm::ProfilingRequestSet{},
e,
0,
redop_id,
false);
}
return e;
};

TrainingOperationAttrs op_attrs =
assert_unwrap(invocation.node_attrs.op_attrs);
return op_attrs.visit<Realm::Event>(overload{
[&](PCGOperatorAttrs const &pcg_op_attrs) {
return pcg_op_attrs.visit<Realm::Event>(overload{
[&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
[&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
[&](ReplicateAttrs const &) {
if (invocation.node_attrs.task_type.has_value() &&
invocation.node_attrs.task_type.value() ==
DynamicTaskType::BWD) {
return issue_replicate_bwd();
}
return issue_copy(); // forward
},
[&](auto const &) { return spawn_task(); },
});
},
Expand Down
9 changes: 8 additions & 1 deletion lib/realm-execution/src/realm-execution/realm_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ Realm::Event
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on,
int priority) {
int priority,
std::optional<Realm::ReductionOpID> redop_id,
bool exclusive) {
TensorShape src_piece_shape = get_piece_shape(src_shape);
TensorShape dst_piece_shape = get_piece_shape(dst_shape);
ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match
Expand All @@ -183,6 +185,11 @@ Realm::Event
size_of_datatype(src_piece_shape.data_type).int_from_positive_int()),
/*subfield_offset=*/0);

// set reduction op on dst field if provided
if (redop_id.has_value()) {
dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive);
}

Realm::Event result;
switch (src_piece_shape.dims.ff_ordered.num_dims()) {
#if REALM_MAX_DIM >= 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args,
result_state, ctx.get_current_device_idx())};
DeviceSpecificPtr<PerDeviceOpState> result_device_specific{
ctx.get_current_device_idx(), result_state_ptr};
spawn_per_device_op_state_init_return_task(ctx,
task_args.origin_proc,
result_device_specific,
task_args.origin_result_ptr,
Realm::Event::NO_EVENT);

// Write the result directly into the origin's result slot instead of
// spawning per_device_op_state_init_return_task.
// NOTE/TODO(SM): the direct write assumes a single-node shared address
// space. For multi-node, replace it with a UserEvent trigger pattern.
*task_args.origin_result_ptr = result_device_specific;

// spawn_per_device_op_state_init_return_task(ctx,
// task_args.origin_proc,
// result_device_specific,
// task_args.origin_result_ptr,
// Realm::Event::NO_EVENT);
}

std::optional<Realm::Event> spawn_per_device_op_state_init_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tasks/task_id_t.h"
#include "utils/exception.h"

Expand All @@ -30,9 +31,18 @@ Realm::Event register_task(Realm::Processor::Kind target_kind,
Realm::ProfilingRequestSet());
}

static void register_reductions() {
// register sum reduction ops
Realm::Runtime rt = Realm::Runtime::get_runtime();
rt.register_reduction<SumReductionFloat>(REDOP_SUM_FLOAT);
rt.register_reduction<SumReductionDouble>(REDOP_SUM_DOUBLE);
// register_reduction is synchronous — no event returned
}

Realm::Event register_all_tasks() {
std::vector<Realm::Event> pending_registrations;

register_reductions();
std::vector<task_id_t> init_task_ids = {
// Init tasks
task_id_t::BATCHNORM_INIT_TASK_ID,
Expand Down
Loading