Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 189 additions & 52 deletions src/windows/wslc/core/AsyncExecution.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,91 +9,228 @@ Module Name:
Abstract:

Provides ForEachAsync, a generic helper for executing a work callback
over a collection concurrently in bounded batches using std::async.
over a collection concurrently using the Windows thread pool with bounded
concurrency and cooperative cancellation.

--*/
#pragma once

#include <algorithm>
#include <future>
#include <chrono>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <wil/resource.h>
#include <wil/result_macros.h>

namespace wsl::windows::wslc {

// Invokes onWork for each element in items concurrently, in batches of batchSize.
// Results are delivered serially to onSuccess. Errors are delivered serially to onError.
//
// This keeps wall time proportional to ceil(N / batchSize) rather than N for operations
// that have inherent per-item latency (e.g. network or IPC calls).
//
// Note: worker threads have no guaranteed per-thread initialization (e.g. COM). Callers
// whose onWork requires per-thread setup (such as CoInitializeEx) are responsible for
// performing it at the start of the onWork lambda.
//
// TWork : TItem -> TResult (called concurrently)
// TSuccess: TResult -> void (called serially)
// TError : (TItem, wil::ResultException) -> void (called serially)
template <typename TItem, typename TWork, typename TSuccess, typename TError>
void ForEachAsync(const std::vector<TItem>& items, TWork onWork, TSuccess onSuccess, TError onError, size_t batchSize = 10)
{
WI_ASSERT(batchSize > 0);
THROW_HR_IF(E_INVALIDARG, batchSize == 0);
namespace detail {

using TResult = decltype(onWork(std::declval<TItem>()));

struct BatchResult
template <typename TItem, typename TResult>
struct WorkerResult
{
explicit BatchResult(TItem capturedItem) : item(std::move(capturedItem))
{
}

TItem item;
std::optional<TResult> result;
wil::ResultException error{S_OK};
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify things, I would recommend storing the error as a std::exception_ptr (null if no error was thrown).

This will have the benefit of allowing us to rethrow non-wil exceptions easily

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and we can get rid of hasError)

bool hasError{false};
};

for (size_t batchStart = 0; batchStart < items.size(); batchStart += batchSize)
// Holds all state for one thread pool worker. Owned via shared_ptr so the memory
// remains valid if ForEachAsync unwinds while a work item is still running.
template <typename TItem, typename TWork, typename TResult>
struct SharedWorker
{
WorkerResult<TItem, TResult> workerResult;
Comment thread
dkbennett marked this conversation as resolved.
TWork* onWork{nullptr};
HANDLE cancelHandle{nullptr};
wil::unique_event done;
wil::unique_threadpool_work work;
};

// Manages a fixed pool of SharedWorkers and a shared cancellation event.
template <typename TItem, typename TWork, typename TSuccess, typename TError>
struct WorkerPool
{
const size_t batchEnd = std::min(batchStart + batchSize, items.size());
NON_COPYABLE(WorkerPool);
NON_MOVABLE(WorkerPool);

using TResult = decltype(std::declval<TWork>()(std::declval<TItem>(), std::declval<HANDLE>()));
using TSharedWorker = SharedWorker<TItem, TWork, TResult>;

std::vector<std::shared_ptr<TSharedWorker>> workers;
std::vector<HANDLE> doneHandles;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove the "timeout on cancel" logic, we can get rid of those

wil::unique_event cancelEvent;
std::chrono::milliseconds timeout;
DWORD cancelDrainMs{};

WorkerPool(size_t poolSize, TWork& onWork, std::chrono::milliseconds timeout_, std::chrono::milliseconds cancelDrainTimeout) :
timeout(timeout_), cancelDrainMs(static_cast<DWORD>(cancelDrainTimeout.count()))
{
cancelEvent.create(wil::EventOptions::ManualReset);

workers.reserve(poolSize);
doneHandles.reserve(poolSize);

for (size_t i = 0; i < poolSize; ++i)
{
auto worker = std::make_shared<TSharedWorker>();
worker->done.create(wil::EventOptions::ManualReset);
worker->onWork = &onWork;
worker->cancelHandle = cancelEvent.get();

// Work item is created once per worker and reused for each dispatched item.
worker->work.reset(::CreateThreadpoolWork(ThreadPoolCallback, worker.get(), nullptr));
THROW_LAST_ERROR_IF(!worker->work);

std::vector<std::future<BatchResult>> futures;
futures.reserve(batchEnd - batchStart);
doneHandles.push_back(worker->done.get());
workers.push_back(std::move(worker));
}
}

void Launch(size_t workerIndex, const TItem& item)
{
auto& worker = workers[workerIndex];
worker->workerResult = WorkerResult<TItem, TResult>{};
worker->workerResult.item = item;
Comment thread
dkbennett marked this conversation as resolved.
Outdated
Comment thread
dkbennett marked this conversation as resolved.
Outdated
Comment thread
dkbennett marked this conversation as resolved.
Outdated
worker->done.ResetEvent();
::SubmitThreadpoolWork(worker->work.get());
}

void Drain(size_t workerIndex, TSuccess& onSuccess, TError& onError)
{
auto& worker = workers[workerIndex];

// Ensure the callback has fully returned before reading results.
::WaitForThreadpoolWorkCallbacks(worker->work.get(), FALSE);
if (worker->workerResult.hasError)
{
onError(worker->workerResult.item, worker->workerResult.error);
}
else if (worker->workerResult.result.has_value())
{
onSuccess(*worker->workerResult.result);
}
}

for (size_t i = batchStart; i < batchEnd; ++i)
// Signals cancellation, waits up to cancelDrainMs for workers to exit, then throws ERROR_TIMEOUT.
// Workers that do not exit within cancelDrainMs are abandoned - they retain shared_ptr ownership
// of their state. onWork implementations must check the cancel event at natural checkpoints and
// exit promptly.
//
// Note: TerminateThread() is not used - it skips C++ destructors, leaves user-mode locks
// permanently held (causing deadlocks), and corrupts COM apartment state.
[[noreturn]] void CancelAndThrow(size_t remainingItems)
{
const auto& item = items[i];
futures.push_back(std::async(std::launch::async, [&onWork, item]() -> BatchResult {
BatchResult result{item};
try
{
result.result = onWork(item);
}
catch (const wil::ResultException& ex)
{
result.hasError = true;
result.error = ex;
}
return result;
}));
cancelEvent.SetEvent();

::WaitForMultipleObjects(static_cast<DWORD>(doneHandles.size()), doneHandles.data(), TRUE, cancelDrainMs);
Comment thread
dkbennett marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We should check the return code of WaitForMultipleObjects() here


THROW_HR_MSG(
HRESULT_FROM_WIN32(ERROR_TIMEOUT),
"ForEachAsync: worker exceeded timeout of %lld ms (%zu items remaining).",
static_cast<long long>(timeout.count()),
remainingItems);
}

for (auto& future : futures)
// Thread pool callback - invoked on a pool thread for each submitted work item.
static void CALLBACK ThreadPoolCallback(PTP_CALLBACK_INSTANCE, void* context, PTP_WORK) noexcept
{
auto batchResult = future.get();
auto& worker = *static_cast<TSharedWorker*>(context);

if (batchResult.hasError)
try
{
onError(batchResult.item, batchResult.error);
worker.workerResult.result = (*worker.onWork)(worker.workerResult.item, worker.cancelHandle);
}
else if (batchResult.result.has_value())
catch (const wil::ResultException& ex)
{
onSuccess(*batchResult.result);
worker.workerResult.hasError = true;
worker.workerResult.error = ex;
}
Comment thread
dkbennett marked this conversation as resolved.

worker.done.SetEvent();
}
};

} // namespace detail

// Invokes onWork for each element in items concurrently using the Windows thread pool,
// with concurrency bounded to poolSize. Results are delivered serially to onSuccess.
// Errors are delivered serially to onError.
//
// onWork receives a HANDLE to a cancellation event and should check it at natural
// checkpoints using WaitForSingleObject(cancel, 0), returning early if it is set.
// On timeout, the event is signalled and ForEachAsync waits up to cancelDrainTimeout
// for workers to exit before throwing HRESULT_FROM_WIN32(ERROR_TIMEOUT).
//
// poolSize must not exceed MAXIMUM_WAIT_OBJECTS (64).
//
// Note: thread pool threads have no guaranteed per-thread initialization. Callers
// whose onWork requires per-thread setup (e.g. CoInitializeEx) must perform it at
// the start of the onWork lambda.
//
// TWork : (TItem, HANDLE cancelEvent) -> TResult (called concurrently)
// TSuccess: TResult -> void (called serially)
// TError : (TItem, wil::ResultException) -> void (called serially)
template <typename TItem, typename TWork, typename TSuccess, typename TError>
void ForEachAsync(
const std::vector<TItem>& items,
TWork onWork,
TSuccess onSuccess,
TError onError,
size_t poolSize = 10,
std::chrono::milliseconds timeout = std::chrono::milliseconds::max(),
std::chrono::milliseconds cancelDrainTimeout = std::chrono::seconds(5))
{
THROW_HR_IF(E_INVALIDARG, poolSize == 0);
THROW_HR_IF(E_INVALIDARG, poolSize > MAXIMUM_WAIT_OBJECTS);
Comment thread
dkbennett marked this conversation as resolved.
if (items.empty())
{
return;
}

const DWORD timeoutMs = (timeout == std::chrono::milliseconds::max()) ? INFINITE : static_cast<DWORD>(timeout.count());
const size_t workerCount = std::min(poolSize, items.size());

detail::WorkerPool<TItem, TWork, TSuccess, TError> pool{workerCount, onWork, timeout, cancelDrainTimeout};

// Fill the pool - submit one item per worker to saturate all workers immediately.
size_t nextItem = 0;
for (; nextItem < workerCount; ++nextItem)
{
pool.Launch(nextItem, items[nextItem]);
}

// Keep the pool full - as each worker completes, drain its result and immediately
// assign it the next pending item. WaitForMultipleObjects(FALSE) wakes on the first
// completion, so no worker idles while work remains.
while (nextItem < items.size())
{
const DWORD waitResult = ::WaitForMultipleObjects(static_cast<DWORD>(workerCount), pool.doneHandles.data(), FALSE, timeoutMs);

if (waitResult == WAIT_TIMEOUT)
{
pool.CancelAndThrow(items.size() - nextItem);
Comment thread
dkbennett marked this conversation as resolved.
}

THROW_LAST_ERROR_IF(waitResult == WAIT_FAILED);
const size_t workerIndex = waitResult - WAIT_OBJECT_0;
pool.Drain(workerIndex, onSuccess, onError);
pool.Launch(workerIndex, items[nextItem++]);
Comment thread
dkbennett marked this conversation as resolved.
}

// Wait for all in-flight workers to finish and collect their final results.
const DWORD finalWait = ::WaitForMultipleObjects(static_cast<DWORD>(workerCount), pool.doneHandles.data(), TRUE, timeoutMs);
Comment thread
dkbennett marked this conversation as resolved.
Outdated
if (finalWait == WAIT_TIMEOUT)
{
pool.CancelAndThrow(0);
}

THROW_LAST_ERROR_IF(finalWait == WAIT_FAILED);
for (size_t i = 0; i < workerCount; ++i)
{
pool.Drain(i, onSuccess, onError);
}
}

Expand Down
21 changes: 18 additions & 3 deletions src/windows/wslc/tasks/ContainerTasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,21 @@ void ShowContainerStats(CLIExecutionContext& context)
}
}

// Fetch stats for all containers concurrently in batches. The Docker engine blocks for ~1s
// Fetch stats for all containers concurrently. The Docker engine blocks for ~1s
// per request to collect a valid precpu_stats sample, so issuing requests in parallel keeps
// wall time proportional to ceil(N / batchSize) rather than N.
Comment thread
dkbennett marked this conversation as resolved.
Outdated
nlohmann::json statsJson = nlohmann::json::array();
wsl::windows::wslc::ForEachAsync<std::wstring>(
containers,
// Work to be done for each container ID on a separate thread.
[&session](const std::wstring& containerId) {
// cancelHandle is signalled if the overall operation times out, check it before
// the blocking Stats call so we exit cooperatively without waiting a full ~1s sample.
[&session, userSpecifiedContainers](const std::wstring& containerId, HANDLE cancelHandle) {
Comment thread
dkbennett marked this conversation as resolved.
Outdated
if (::WaitForSingleObject(cancelHandle, 0) == WAIT_OBJECT_0)
{
THROW_HR(HRESULT_FROM_WIN32(ERROR_CANCELLED));
}

// ContainerService::Stats makes COM calls, so we must ensure COM is initialized on this thread.
auto comCleanup = wil::CoInitializeEx(COINIT_MULTITHREADED);
return ComputeContainerStatsJson(ContainerService::Stats(session, WideToMultiByte(containerId)));
Expand All @@ -582,6 +589,12 @@ void ShowContainerStats(CLIExecutionContext& context)
[&](const nlohmann::json& entry) { statsJson.push_back(entry); },
// On Error
[&](const std::wstring& containerId, wil::ResultException error) {
if (error.GetErrorCode() == HRESULT_FROM_WIN32(ERROR_CANCELLED))
{
// Cancellation due to timeout. Let ForEachAsync surface ERROR_TIMEOUT to the caller.
return;
}

if (!userSpecifiedContainers)
{
switch (error.GetErrorCode())
Expand All @@ -599,7 +612,9 @@ void ShowContainerStats(CLIExecutionContext& context)
LOG_HR_MSG(error.GetErrorCode(), "Failed to get stats for container %ws", containerId.c_str());
throw error;
},
10 // Batch Size - chosen to be around typical expected container use while protecting against extreme cases.
10, // Thread pool size - typical expected container use while protecting against extreme cases.
std::chrono::seconds(30), // Timeout - Docker stats blocks ~1s per sample; 30s gives ample headroom on a taxed system.
std::chrono::seconds(5) // Cancel drain - grace period for workers to observe the cancel event and exit cleanly.
);

FormatType format = FormatType::Table; // Default is table
Expand Down
Loading