Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/engine/data_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "base/hash.h"
#include "base/thread.h"
Expand Down Expand Up @@ -108,11 +107,13 @@ bool DataLoader::RegisterRequest(const EngineReloadRequest& request) {

auto it = std::find_if(requests_.begin(), requests_.end(),
[id](const RequestData& v) { return v.id == id; });
bool is_new = false;
if (it != requests_.end()) {
it->sequence_id = sequence_id_;
} else {
requests_.emplace_back(RequestData{id, sequence_id_, request});
LOG(INFO) << "New request is registered: " << requests_.back();
is_new = true;
}

// Sorts the requests so requests[0] stores the request with
Expand All @@ -125,8 +126,8 @@ bool DataLoader::RegisterRequest(const EngineReloadRequest& request) {
});

// Needs the reloading process only when requests[0] is different from
// current_request_id_.
return current_request_id_ != requests_.front().id;
// current_request_id_ or this is a new request.
return is_new || current_request_id_ != requests_.front().id;
}

void DataLoader::ReportLoadFailure(const DataLoader::RequestData& request) {
Expand Down Expand Up @@ -198,9 +199,15 @@ void DataLoader::StartReloadLoop(DataLoader::ReloadedCallback callback) {
// until a new high priority data is registered. Retry the loop when a new
// high priority data is registered while waiting.
constexpr absl::Duration kTimeout = absl::Milliseconds(100);
if (!high_priority_data_registered_.HasBeenNotified() &&
high_priority_data_registered_.WaitForNotificationWithTimeout(
kTimeout)) {
bool woken_by_high_priority_notification = false;
{
absl::MutexLock lock(&signal_mu_);
if (!high_priority_notified_) {
woken_by_high_priority_notification =
signal_cv_.WaitWithTimeout(&signal_mu_, kTimeout);
}
}
if (woken_by_high_priority_notification) {
continue;
}

Expand Down Expand Up @@ -229,11 +236,14 @@ bool DataLoader::StartNewDataBuildTask(const EngineReloadRequest& request,
return false;
}

// Receives high priority data.
// Wakes up the loading thread to re-evaluate pending requests.
constexpr int kHighPriority = 10;
if (!high_priority_data_registered_.HasBeenNotified() &&
request.priority() <= kHighPriority) {
high_priority_data_registered_.Notify();
{
absl::MutexLock lock(&signal_mu_);
if (!high_priority_notified_ && request.priority() <= kHighPriority) {
high_priority_notified_ = true;
}
signal_cv_.SignalAll();
}

if (!IsRunning()) {
Expand Down
11 changes: 7 additions & 4 deletions src/engine/data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "base/thread.h"
#include "engine/modules.h"
#include "protocol/engine_builder.pb.h"
Expand Down Expand Up @@ -86,7 +85,9 @@ class DataLoader {

// Disables specific handling for high priority data.
void NotifyHighPriorityDataRegisteredForTesting() {
high_priority_data_registered_.Notify();
absl::MutexLock lock(&signal_mu_);
high_priority_notified_ = true;
signal_cv_.SignalAll();
}

private:
Expand Down Expand Up @@ -141,8 +142,10 @@ class DataLoader {
// meaning that the model registered later is preferred.
uint32_t sequence_id_ ABSL_GUARDED_BY(mutex_) = 0;

// Notify when a new high priority data is registered.
absl::Notification high_priority_data_registered_;
// Used to signal the loading thread to re-evaluate pending requests.
mutable absl::Mutex signal_mu_;
bool high_priority_notified_ ABSL_GUARDED_BY(signal_mu_) = false;
absl::CondVar signal_cv_;

TaskManager load_;
};
Expand Down