/////////////////////////////////////////////////////////////////////////////// // Copyright (c) Lewis Baker // Licenced under MIT license. See LICENSE.txt for details. /////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include #include #include #include #include #include #include "doctest/cppcoro_doctest.h" TEST_SUITE_BEGIN("static_thread_pool"); TEST_CASE("construct/destruct") { cppcoro::static_thread_pool threadPool; CHECK(threadPool.thread_count() == std::thread::hardware_concurrency()); } TEST_CASE("construct/destruct to specific thread count") { cppcoro::static_thread_pool threadPool{ 5 }; CHECK(threadPool.thread_count() == 5); } TEST_CASE("run one task") { cppcoro::static_thread_pool threadPool{ 2 }; auto initiatingThreadId = std::this_thread::get_id(); cppcoro::sync_wait([&]() -> cppcoro::task { co_await threadPool.schedule(); if (std::this_thread::get_id() == initiatingThreadId) { FAIL("schedule() did not switch threads"); } }()); } TEST_CASE("launch many tasks remotely") { cppcoro::static_thread_pool threadPool; auto makeTask = [&]() -> cppcoro::task<> { co_await threadPool.schedule(); }; std::vector> tasks; for (std::uint32_t i = 0; i < 100; ++i) { tasks.push_back(makeTask()); } cppcoro::sync_wait(cppcoro::when_all(std::move(tasks))); } cppcoro::task sum_of_squares( std::uint32_t start, std::uint32_t end, cppcoro::static_thread_pool& tp) { co_await tp.schedule(); auto count = end - start; if (count > 1000) { auto half = start + count / 2; auto[a, b] = co_await cppcoro::when_all( sum_of_squares(start, half, tp), sum_of_squares(half, end, tp)); co_return a + b; } else { std::uint64_t sum = 0; for (std::uint64_t x = start; x < end; ++x) { sum += x * x; } co_return sum; } } TEST_CASE("launch sub-task with many sub-tasks") { using namespace std::chrono_literals; constexpr std::uint64_t limit = 1'000'000'000; cppcoro::static_thread_pool tp; // Wait for the thread-pool thread to start up. std::this_thread::sleep_for(1ms); auto start = std::chrono::high_resolution_clock::now(); auto result = cppcoro::sync_wait(sum_of_squares(0, limit , tp)); auto end = std::chrono::high_resolution_clock::now(); std::uint64_t sum = 0; for (std::uint64_t i = 0; i < limit; ++i) { sum += i * i; } auto end2 = std::chrono::high_resolution_clock::now(); auto toNs = [](auto time) { return std::chrono::duration_cast(time).count(); }; std::cout << "multi-threaded version took " << toNs(end - start) << "ns\n" << "single-threaded version took " << toNs(end2 - end) << "ns" << std::endl; CHECK(result == sum); } struct fork_join_operation { std::atomic m_count; cppcoro::coroutine_handle<> m_coro; fork_join_operation() : m_count(1) {} void begin_work() noexcept { m_count.fetch_add(1, std::memory_order_relaxed); } void end_work() noexcept { if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { m_coro.resume(); } } bool await_ready() noexcept { return m_count.load(std::memory_order_acquire) == 1; } bool await_suspend(cppcoro::coroutine_handle<> coro) noexcept { m_coro = coro; return m_count.fetch_sub(1, std::memory_order_acq_rel) != 1; } void await_resume() noexcept {}; }; template cppcoro::task for_each_async(SCHEDULER& scheduler, RANGE& range, FUNC func) { using reference_type = decltype(*range.begin()); // TODO: Use awaiter_t here instead. This currently assumes that // result of scheduler.schedule() doesn't have an operator co_await(). using schedule_operation = decltype(scheduler.schedule()); struct work_operation { fork_join_operation& m_forkJoin; FUNC& m_func; reference_type m_value; schedule_operation m_scheduleOp; work_operation(fork_join_operation& forkJoin, SCHEDULER& scheduler, FUNC& func, reference_type&& value) : m_forkJoin(forkJoin) , m_func(func) , m_value(static_cast(value)) , m_scheduleOp(scheduler.schedule()) { } bool await_ready() noexcept { return false; } CPPCORO_NOINLINE void await_suspend(cppcoro::coroutine_handle<> coro) noexcept { fork_join_operation& forkJoin = m_forkJoin; FUNC& func = m_func; reference_type value = static_cast(m_value); static_assert(std::is_same_v); forkJoin.begin_work(); // Schedule the next iteration of the loop to run m_scheduleOp.await_suspend(coro); func(static_cast(value)); forkJoin.end_work(); } void await_resume() noexcept {} }; co_await scheduler.schedule(); fork_join_operation forkJoin; for (auto&& x : range) { co_await work_operation{ forkJoin, scheduler, func, static_cast(x) }; } co_await forkJoin; } std::uint64_t collatz_distance(std::uint64_t number) { std::uint64_t count = 0; while (number > 1) { if (number % 2 == 0) number /= 2; else number = number * 3 + 1; ++count; } return count; } TEST_CASE("for_each_async") { cppcoro::static_thread_pool tp; { std::vector values(1'000'000); std::iota(values.begin(), values.end(), 1); cppcoro::sync_wait([&]() -> cppcoro::task<> { auto start = std::chrono::high_resolution_clock::now(); co_await for_each_async(tp, values, [](std::uint64_t& value) { value = collatz_distance(value); }); auto end = std::chrono::high_resolution_clock::now(); std::cout << "for_each_async of " << values.size() << " took " << std::chrono::duration_cast(end - start).count() << "us" << std::endl; for (std::size_t i = 0; i < 1'000'000; ++i) { CHECK(values[i] == collatz_distance(i + 1)); } }()); } { std::vector values(1'000'000); std::iota(values.begin(), values.end(), 1); auto start = std::chrono::high_resolution_clock::now(); for (auto&& x : values) { x = collatz_distance(x); } auto end = std::chrono::high_resolution_clock::now(); std::cout << "single-threaded for loop of " << values.size() << " took " << std::chrono::duration_cast(end - start).count() << "us" << std::endl; } } TEST_SUITE_END();