Files
cppcoro/test/static_thread_pool_tests.cpp
jeanlemotan bbeaa887cd First
2024-07-02 18:13:47 +02:00

291 lines
6.4 KiB
C++

///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/static_thread_pool.hpp>
#include <cppcoro/task.hpp>
#include <cppcoro/sync_wait.hpp>
#include <cppcoro/when_all.hpp>
#include <vector>
#include <thread>
#include <cassert>
#include <chrono>
#include <iostream>
#include <numeric>
#include "doctest/cppcoro_doctest.h"
TEST_SUITE_BEGIN("static_thread_pool");
TEST_CASE("construct/destruct")
{
cppcoro::static_thread_pool threadPool;
CHECK(threadPool.thread_count() == std::thread::hardware_concurrency());
}
TEST_CASE("construct/destruct to specific thread count")
{
cppcoro::static_thread_pool threadPool{ 5 };
CHECK(threadPool.thread_count() == 5);
}
TEST_CASE("run one task")
{
cppcoro::static_thread_pool threadPool{ 2 };
auto initiatingThreadId = std::this_thread::get_id();
cppcoro::sync_wait([&]() -> cppcoro::task<void>
{
co_await threadPool.schedule();
if (std::this_thread::get_id() == initiatingThreadId)
{
FAIL("schedule() did not switch threads");
}
}());
}
TEST_CASE("launch many tasks remotely")
{
cppcoro::static_thread_pool threadPool;
auto makeTask = [&]() -> cppcoro::task<>
{
co_await threadPool.schedule();
};
std::vector<cppcoro::task<>> tasks;
for (std::uint32_t i = 0; i < 100; ++i)
{
tasks.push_back(makeTask());
}
cppcoro::sync_wait(cppcoro::when_all(std::move(tasks)));
}
cppcoro::task<std::uint64_t> sum_of_squares(
std::uint32_t start,
std::uint32_t end,
cppcoro::static_thread_pool& tp)
{
co_await tp.schedule();
auto count = end - start;
if (count > 1000)
{
auto half = start + count / 2;
auto[a, b] = co_await cppcoro::when_all(
sum_of_squares(start, half, tp),
sum_of_squares(half, end, tp));
co_return a + b;
}
else
{
std::uint64_t sum = 0;
for (std::uint64_t x = start; x < end; ++x)
{
sum += x * x;
}
co_return sum;
}
}
TEST_CASE("launch sub-task with many sub-tasks")
{
using namespace std::chrono_literals;
constexpr std::uint64_t limit = 1'000'000'000;
cppcoro::static_thread_pool tp;
// Wait for the thread-pool thread to start up.
std::this_thread::sleep_for(1ms);
auto start = std::chrono::high_resolution_clock::now();
auto result = cppcoro::sync_wait(sum_of_squares(0, limit , tp));
auto end = std::chrono::high_resolution_clock::now();
std::uint64_t sum = 0;
for (std::uint64_t i = 0; i < limit; ++i)
{
sum += i * i;
}
auto end2 = std::chrono::high_resolution_clock::now();
auto toNs = [](auto time)
{
return std::chrono::duration_cast<std::chrono::nanoseconds>(time).count();
};
std::cout
<< "multi-threaded version took " << toNs(end - start) << "ns\n"
<< "single-threaded version took " << toNs(end2 - end) << "ns" << std::endl;
CHECK(result == sum);
}
struct fork_join_operation
{
std::atomic<std::size_t> m_count;
cppcoro::coroutine_handle<> m_coro;
fork_join_operation() : m_count(1) {}
void begin_work() noexcept
{
m_count.fetch_add(1, std::memory_order_relaxed);
}
void end_work() noexcept
{
if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1)
{
m_coro.resume();
}
}
bool await_ready() noexcept { return m_count.load(std::memory_order_acquire) == 1; }
bool await_suspend(cppcoro::coroutine_handle<> coro) noexcept
{
m_coro = coro;
return m_count.fetch_sub(1, std::memory_order_acq_rel) != 1;
}
void await_resume() noexcept {};
};
template<typename FUNC, typename RANGE, typename SCHEDULER>
cppcoro::task<void> for_each_async(SCHEDULER& scheduler, RANGE& range, FUNC func)
{
using reference_type = decltype(*range.begin());
// TODO: Use awaiter_t here instead. This currently assumes that
// result of scheduler.schedule() doesn't have an operator co_await().
using schedule_operation = decltype(scheduler.schedule());
struct work_operation
{
fork_join_operation& m_forkJoin;
FUNC& m_func;
reference_type m_value;
schedule_operation m_scheduleOp;
work_operation(fork_join_operation& forkJoin, SCHEDULER& scheduler, FUNC& func, reference_type&& value)
: m_forkJoin(forkJoin)
, m_func(func)
, m_value(static_cast<reference_type&&>(value))
, m_scheduleOp(scheduler.schedule())
{
}
bool await_ready() noexcept { return false; }
CPPCORO_NOINLINE
void await_suspend(cppcoro::coroutine_handle<> coro) noexcept
{
fork_join_operation& forkJoin = m_forkJoin;
FUNC& func = m_func;
reference_type value = static_cast<reference_type&&>(m_value);
static_assert(std::is_same_v<decltype(m_scheduleOp.await_suspend(coro)), void>);
forkJoin.begin_work();
// Schedule the next iteration of the loop to run
m_scheduleOp.await_suspend(coro);
func(static_cast<reference_type&&>(value));
forkJoin.end_work();
}
void await_resume() noexcept {}
};
co_await scheduler.schedule();
fork_join_operation forkJoin;
for (auto&& x : range)
{
co_await work_operation{
forkJoin,
scheduler,
func,
static_cast<decltype(x)>(x)
};
}
co_await forkJoin;
}
std::uint64_t collatz_distance(std::uint64_t number)
{
std::uint64_t count = 0;
while (number > 1)
{
if (number % 2 == 0) number /= 2;
else number = number * 3 + 1;
++count;
}
return count;
}
TEST_CASE("for_each_async")
{
cppcoro::static_thread_pool tp;
{
std::vector<std::uint64_t> values(1'000'000);
std::iota(values.begin(), values.end(), 1);
cppcoro::sync_wait([&]() -> cppcoro::task<>
{
auto start = std::chrono::high_resolution_clock::now();
co_await for_each_async(tp, values, [](std::uint64_t& value)
{
value = collatz_distance(value);
});
auto end = std::chrono::high_resolution_clock::now();
std::cout << "for_each_async of " << values.size()
<< " took " << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count()
<< "us" << std::endl;
for (std::size_t i = 0; i < 1'000'000; ++i)
{
CHECK(values[i] == collatz_distance(i + 1));
}
}());
}
{
std::vector<std::uint64_t> values(1'000'000);
std::iota(values.begin(), values.end(), 1);
auto start = std::chrono::high_resolution_clock::now();
for (auto&& x : values)
{
x = collatz_distance(x);
}
auto end = std::chrono::high_resolution_clock::now();
std::cout << "single-threaded for loop of " << values.size()
<< " took " << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count()
<< "us" << std::endl;
}
}
TEST_SUITE_END();