diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc index 32af5af035..1b03ec8028 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc @@ -474,6 +474,12 @@ void TemperedLB::runLB(TimeType total_load) { } } +void TemperedLB::clearDataStructures() { + potential_recipients_.clear(); + load_info_.clear(); + is_overloaded_ = is_underloaded_ = false; +} + void TemperedLB::doLBStages(TimeType start_imb) { decltype(this->cur_objs_) best_objs; LoadType best_load = 0; @@ -483,11 +489,7 @@ void TemperedLB::doLBStages(TimeType start_imb) { auto this_node = theContext()->getNode(); for (trial_ = 0; trial_ < num_trials_; ++trial_) { - // Clear out data structures - selected_.clear(); - underloaded_.clear(); - load_info_.clear(); - is_overloaded_ = is_underloaded_ = false; + clearDataStructures(); TimeType best_imb_this_trial = start_imb + 10; @@ -504,11 +506,7 @@ void TemperedLB::doLBStages(TimeType start_imb) { } this_new_load_ = this_load; } else { - // Clear out data structures from previous iteration - selected_.clear(); - underloaded_.clear(); - load_info_.clear(); - is_overloaded_ = is_underloaded_ = false; + clearDataStructures(); } vt_debug_print( @@ -667,8 +665,8 @@ void TemperedLB::informAsync() { vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero"); auto const this_node = theContext()->getNode(); - if (is_underloaded_) { - underloaded_.insert(this_node); + if (canPropagate()) { + potential_recipients_.insert(this_node); } setup_done_ = false; @@ -682,7 +680,7 @@ void TemperedLB::informAsync() { auto propagate_epoch = theTerm()->makeEpochCollective("TemperedLB: informAsync"); // Underloaded start the round - if (is_underloaded_) { + if (canPropagate()) { uint8_t k_cur_async = 0; propagateRound(k_cur_async, false, propagate_epoch); } @@ -695,7 +693,7 @@ void TemperedLB::informAsync() { vt_debug_print( 
terse, temperedlb, "TemperedLB::informAsync: trial={}, iter={}, known underloaded={}\n", - trial_, iter_, underloaded_.size() + trial_, iter_, potential_recipients_.size() ); } @@ -718,13 +716,13 @@ void TemperedLB::informSync() { vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero"); auto const this_node = theContext()->getNode(); - if (is_underloaded_) { - underloaded_.insert(this_node); + if (canPropagate()) { + potential_recipients_.insert(this_node); } - auto propagate_this_round = is_underloaded_; + auto propagate_this_round = canPropagate(); propagate_next_round_ = false; - new_underloaded_ = underloaded_; + new_potential_recipients_ = potential_recipients_; new_load_info_ = load_info_; setup_done_ = false; @@ -754,7 +752,7 @@ void TemperedLB::informSync() { propagate_this_round = propagate_next_round_; propagate_next_round_ = false; - underloaded_ = new_underloaded_; + potential_recipients_ = new_potential_recipients_; load_info_ = new_load_info_; } @@ -762,7 +760,7 @@ void TemperedLB::informSync() { vt_debug_print( terse, temperedlb, "TemperedLB::informSync: trial={}, iter={}, known underloaded={}\n", - trial_, iter_, underloaded_.size() + trial_, iter_, potential_recipients_.size() ); } @@ -793,8 +791,7 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) { gen_propagate_.seed(seed_()); } - auto& selected = selected_; - selected = underloaded_; + auto& selected = potential_recipients_; if (selected.find(this_node) == selected.end()) { selected.insert(this_node); } @@ -871,7 +868,7 @@ void TemperedLB::propagateIncomingAsync(LoadMsgAsync* msg) { load_info_[elm.first] = elm.second; if (isUnderloaded(elm.second)) { - underloaded_.insert(elm.first); + potential_recipients_.insert(elm.first); } } } @@ -905,7 +902,7 @@ void TemperedLB::propagateIncomingSync(LoadMsgSync* msg) { new_load_info_[elm.first] = elm.second; if (isUnderloaded(elm.second)) { - new_underloaded_.insert(elm.first); + 
new_potential_recipients_.insert(elm.first); } } } @@ -996,7 +993,7 @@ NodeType TemperedLB::sampleFromCMF( return selected_node; } -std::vector TemperedLB::makeUnderloaded() const { +std::vector TemperedLB::getPotentialRecipients() const { std::vector under = {}; for (auto&& elm : load_info_) { if (isUnderloaded(elm.second)) { @@ -1203,11 +1200,11 @@ void TemperedLB::decide() { int n_transfers = 0, n_rejected = 0; - if (is_overloaded_) { - std::vector under = makeUnderloaded(); + if (canMigrate()) { + auto potential_recipients = getPotentialRecipients(); std::unordered_map migrate_objs; - if (under.size() > 0) { + if (not potential_recipients.empty()) { std::vector ordered_obj_ids = orderObjects( obj_ordering_, cur_objs_, this_new_load_, target_max_load_ ); @@ -1219,24 +1216,24 @@ void TemperedLB::decide() { if (cmf_type_ == CMFTypeEnum::Original) { // Rebuild the relaxed underloaded set based on updated load of this node - under = makeUnderloaded(); - if (under.size() == 0) { + potential_recipients = getPotentialRecipients(); + if (potential_recipients.size() == 0) { break; } } else if (cmf_type_ == CMFTypeEnum::NormByMaxExcludeIneligible) { // Rebuild the underloaded set and eliminate processors that will // fail the Criterion for this object - under = makeSufficientlyUnderloaded(obj_load); - if (under.size() == 0) { + potential_recipients = makeSufficientlyUnderloaded(obj_load); + if (potential_recipients.size() == 0) { ++n_rejected; iter++; continue; } } // Rebuild the CMF with the new loads taken into account - auto cmf = createCMF(under); + auto cmf = createCMF(potential_recipients); // Select a node using the CMF - auto const selected_node = sampleFromCMF(under, cmf); + auto const selected_node = sampleFromCMF(potential_recipients, cmf); vt_debug_print( verbose, temperedlb, @@ -1256,13 +1253,13 @@ void TemperedLB::decide() { vt_debug_print( verbose, temperedlb, - "TemperedLB::decide: trial={}, iter={}, under.size()={}, " - "selected_node={}, 
selected_load={:e}, obj_id={:x}, home={}, " - "obj_load={}, target_max_load={}, this_new_load_={}, " - "criterion={}\n", + "TemperedLB::decide: trial={}, iter={}, " + "potential_recipients.size()={}, selected_node={}, " + "selected_load={:e}, obj_id={:x}, home={}, obj_load={}, " + "target_max_load={}, this_new_load_={}, criterion={}\n", trial_, iter_, - under.size(), + potential_recipients.size(), selected_node, selected_load, obj_id.id, @@ -1361,7 +1358,7 @@ void TemperedLB::migrate() { vtAssertExpr(false); } -TimeType TemperedLB::getModeledValue(const elm::ElementIDStruct& obj) { +TimeType TemperedLB::getModeledValue(const elm::ElementIDStruct& obj) const { return load_model_->getModeledLoad( obj, {balance::PhaseOffset::NEXT_PHASE, balance::PhaseOffset::WHOLE_PHASE} ); diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h index 6839ae6eb7..345b907b6f 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h @@ -94,24 +94,37 @@ struct TemperedLB : BaseLB { void informSync(); void decide(); void migrate(); + void clearDataStructures(); + + /** + * \brief Decides whether the rank can perform the migration + */ + virtual bool canMigrate() const { return is_overloaded_; } + /** + * \brief Decides whether the rank can initiate information propagation stage + * + * TemperedLB restricts this to underloaded ranks + */ + virtual bool canPropagate() const { return is_underloaded_; } + bool isDeterministic() const { return deterministic_; } void propagateRound(uint8_t k_cur_async, bool sync, EpochType epoch = no_epoch); void propagateIncomingAsync(LoadMsgAsync* msg); void propagateIncomingSync(LoadMsgSync* msg); - bool isUnderloaded(LoadType load) const; + virtual bool isUnderloaded(LoadType load) const; bool isUnderloadedRelaxed(LoadType over, LoadType under) const; bool isOverloaded(LoadType load) const; std::vector createCMF(NodeSetType 
const& under); NodeType sampleFromCMF(NodeSetType const& under, std::vector const& cmf); - std::vector makeUnderloaded() const; + virtual std::vector getPotentialRecipients() const; std::vector makeSufficientlyUnderloaded( LoadType load_to_accommodate ) const; ElementLoadType::iterator selectObject( LoadType size, ElementLoadType& load, std::set const& available ); - virtual TimeType getModeledValue(const elm::ElementIDStruct& obj); + virtual TimeType getModeledValue(const elm::ElementIDStruct& obj) const; void lazyMigrateObjsTo(EpochType epoch, NodeType node, ObjsType const& objs); void inLazyMigrations(balance::LazyMigrationMsg* msg); @@ -121,6 +134,9 @@ struct TemperedLB : BaseLB { void setupDone(ReduceMsgType* msg); + std::unordered_map load_info_ = {}; + std::unordered_map cur_objs_ = {}; + private: uint16_t f_ = 0; uint8_t k_max_ = 0; @@ -159,15 +175,12 @@ struct TemperedLB : BaseLB { */ bool target_pole_ = false; std::random_device seed_; - std::unordered_map load_info_ = {}; std::unordered_map new_load_info_ = {}; objgroup::proxy::Proxy proxy_ = {}; bool is_overloaded_ = false; bool is_underloaded_ = false; - std::unordered_set selected_ = {}; - std::unordered_set underloaded_ = {}; - std::unordered_set new_underloaded_ = {}; - std::unordered_map cur_objs_ = {}; + std::unordered_set potential_recipients_ = {}; + std::unordered_set new_potential_recipients_ = {}; LoadType this_new_load_ = 0.0; TimeType new_imbalance_ = 0.0; TimeType target_max_load_ = 0.0; diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc index 76362c5466..0b9ad16315 100644 --- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc +++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc @@ -48,6 +48,8 @@ #include "vt/vrt/collection/balance/model/load_model.h" #include "vt/vrt/collection/balance/model/weighted_communication_volume.h" +#include + namespace vt { namespace vrt { namespace collection 
{ namespace lb { TemperedWMin::~TemperedWMin() { @@ -110,7 +112,25 @@ void TemperedWMin::inputParams(balance::ConfigEntry* config) { load_model_ptr = theLBManager()->getLoadModel().get(); } -TimeType TemperedWMin::getModeledValue(const elm::ElementIDStruct& obj) { +std::vector TemperedWMin::getPotentialRecipients() const { + auto const this_node = theContext()->getNode(); + std::vector nodes = {}; + + for (auto&& elm : load_info_) { + auto const node = elm.first; + if (node != this_node) { + nodes.push_back(node); + } + } + + if (isDeterministic()) { + std::sort(nodes.begin(), nodes.end()); + } + + return nodes; +} + +TimeType TemperedWMin::getModeledValue(const elm::ElementIDStruct& obj) const { vtAssert( theLBManager()->getLoadModel().get() == load_model_ptr, "Load model must not change" @@ -121,4 +141,13 @@ TimeType TemperedWMin::getModeledValue(const elm::ElementIDStruct& obj) { return total_work_model_->getModeledLoad(obj, when); } +bool TemperedWMin::canMigrate() const { + auto const this_node = theContext()->getNode(); + auto const another_rank = std::find_if( + load_info_.begin(), load_info_.end(), + [this_node](auto const& elm) { return elm.first != this_node; } + ); + return (not cur_objs_.empty()) and (another_rank != load_info_.end()); +} + }}}} // namespace vt::vrt::collection::lb diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h index dcdaa43bff..a4b604f088 100644 --- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h +++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h @@ -64,7 +64,21 @@ struct TemperedWMin : TemperedLB { void inputParams(balance::ConfigEntry* config) override; protected: - TimeType getModeledValue(const elm::ElementIDStruct& obj) override; + /** + * Allow migration when there are objects to migrate and other ranks are known + */ + bool canMigrate() const override; + /** + * All ranks are allowed to initiate the information 
propagation stage + */ + bool canPropagate() const override { return true; } + /** + * TemperedWMin does not care whether a rank is underloaded: every load counts as underloaded + */ + bool isUnderloaded(LoadType load) const override { return true; } + + TimeType getModeledValue(const elm::ElementIDStruct& obj) const override; + std::vector getPotentialRecipients() const override; private: std::shared_ptr total_work_model_ = nullptr;