commit bbeaa887cddd4065f72401cd4f39cb95ea757666 Author: jeanlemotan Date: Tue Jul 2 18:13:47 2024 +0200 First diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..d2edd51 --- /dev/null +++ b/.clang-format @@ -0,0 +1,81 @@ +--- +BasedOnStyle: LLVM +--- +Language: Cpp +Standard: Cpp11 +ColumnLimit: 100 +TabWidth: 4 +IndentWidth: 4 +UseTab: ForContinuationAndIndentation +AccessModifierOffset: -4 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: false +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AllowShortFunctionsOnASingleLine: InlineOnly +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BraceWrapping: { + AfterClass: true, + AfterControlStatement: true, + AfterEnum: true, + AfterFunction: true, + AfterNamespace: true, + AfterStruct: true, + AfterUnion: true, + BeforeCatch: true, + BeforeElse: true, + IndentBraces: false, + #SplitEmptyFunctionBody: false +} +BreakBeforeInheritanceComma: true +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeComma +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +IncludeCategories: +- Regex: '^$' + Priority: 1 +- Regex: '^`](#taskt) + * [`shared_task`](#shared_taskt) + * [`generator`](#generatort) + * [`recursive_generator`](#recursive_generatort) + * [`async_generator`](#async_generatort) +* Awaitable Types + * [`single_consumer_event`](#single_consumer_event) + * [`single_consumer_async_auto_reset_event`](#single_consumer_async_auto_reset_event) + * [`async_mutex`](#async_mutex) + * [`async_manual_reset_event`](#async_manual_reset_event) + * [`async_auto_reset_event`](#async_auto_reset_event) + * [`async_latch`](#async_latch) + * [`sequence_barrier`](#sequence_barrier) + * [`multi_producer_sequencer`](#multi_producer_sequencer) + * [`single_producer_sequencer`](#single_producer_sequencer) +* Functions + * [`sync_wait()`](#sync_wait) + * [`when_all()`](#when_all) + * [`when_all_ready()`](#when_all_ready) + * [`fmap()`](#fmap) + * [`schedule_on()`](#schedule_on) + * [`resume_on()`](#resume_on) +* [Cancellation](#Cancellation) + * `cancellation_token` + * `cancellation_source` + * `cancellation_registration` +* Schedulers and I/O + * [`static_thread_pool`](#static_thread_pool) + * [`io_service` and `io_work_scope`](#io_service-and-io_work_scope) + * [`file`, `readable_file`, `writable_file`](#file-readable_file-writable_file) + * [`read_only_file`, `write_only_file`, `read_write_file`](#read_only_file-write_only_file-read_write_file) +* Networking + * [`socket`](#socket) + * [`ip_address`, `ipv4_address`, `ipv6_address`](#ip_address-ipv4_address-ipv6_address) + * [`ip_endpoint`, `ipv4_endpoint`, `ipv6_endpoint`](#ip_endpoint-ipv4_endpoint-ipv6_endpoint) +* Metafunctions + * [`is_awaitable`](#is_awaitablet) + * [`awaitable_traits`](#awaitable_traitst) +* Concepts + * [`Awaitable`](#Awaitablet-concept) + * [`Awaiter`](#Awaitert-concept) + * [`Scheduler`](#Scheduler-concept) + * [`DelayedScheduler`](#DelayedScheduler-concept) + +This library is an experimental library that is exploring the space of high-performance, +scalable asynchronous programming abstractions that can be built on top of the C++ coroutines +proposal. + +It has been open-sourced in the hope that others will find it useful and that the C++ community +can provide feedback on it and ways to improve it. + +The Linux version is functional except for the `io_context` and file I/O related classes which have not yet been implemented for Linux (see issue [#15](https://github.com/lewissbaker/cppcoro/issues/15) for more info). + +# Class Details + +## `task` + +A task represents an asynchronous computation that is executed lazily in +that the execution of the coroutine does not start until the task is awaited. + +Example: +```c++ +#include +#include + +cppcoro::task count_lines(std::string path) +{ + auto file = co_await cppcoro::read_only_file::open(path); + + int lineCount = 0; + + char buffer[1024]; + size_t bytesRead; + std::uint64_t offset = 0; + do + { + bytesRead = co_await file.read(offset, buffer, sizeof(buffer)); + lineCount += std::count(buffer, buffer + bytesRead, '\n'); + offset += bytesRead; + } while (bytesRead > 0); + + co_return lineCount; +} + +cppcoro::task<> usage_example() +{ + // Calling function creates a new task but doesn't start + // executing the coroutine yet. + cppcoro::task countTask = count_lines("foo.txt"); + + // ... + + // Coroutine is only started when we later co_await the task. + int lineCount = co_await countTask; + + std::cout << "line count = " << lineCount << std::endl; +} +``` + +API Overview: +```c++ +// +namespace cppcoro +{ + template + class task + { + public: + + using promise_type = ; + using value_type = T; + + task() noexcept; + + task(task&& other) noexcept; + task& operator=(task&& other); + + // task is a move-only type. + task(const task& other) = delete; + task& operator=(const task& other) = delete; + + // Query if the task result is ready. + bool is_ready() const noexcept; + + // Wait for the task to complete and return the result or rethrow the + // exception if the operation completed with an unhandled exception. + // + // If the task is not yet ready then the awaiting coroutine will be + // suspended until the task completes. If the the task is_ready() then + // this operation will return the result synchronously without suspending. + Awaiter operator co_await() const & noexcept; + Awaiter operator co_await() const && noexcept; + + // Returns an awaitable that can be co_await'ed to suspend the current + // coroutine until the task completes. + // + // The 'co_await t.when_ready()' expression differs from 'co_await t' in + // that when_ready() only performs synchronization, it does not return + // the result or rethrow the exception. + // + // This can be useful if you want to synchronize with the task without + // the possibility of it throwing an exception. + Awaitable when_ready() const noexcept; + }; + + template + void swap(task& a, task& b); + + // Creates a task that yields the result of co_await'ing the specified awaitable. + // + // This can be used as a form of type-erasure of the concrete awaitable, allowing + // different awaitables that return the same await-result type to be stored in + // the same task type. + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits::await_result_t> + task make_task(AWAITABLE awaitable); +} +``` + +You can create a `task` object by calling a coroutine function that returns +a `task`. + +The coroutine must contain a usage of either `co_await` or `co_return`. +Note that a `task` coroutine may not use the `co_yield` keyword. + +When a coroutine that returns a `task` is called, a coroutine frame +is allocated if necessary and the parameters are captured in the coroutine +frame. The coroutine is suspended at the start of the coroutine body and +execution is returned to the caller and a `task` value that represents +the asynchronous computation is returned from the function call. + +The coroutine body will start executing when the `task` value is +`co_await`ed. This will suspend the awaiting coroutine and start execution +of the coroutine associated with the awaited `task` value. + +The awaiting coroutine will later be resumed on the thread that completes +execution of the awaited `task`'s coroutine. ie. the thread that +executes the `co_return` or that throws an unhandled exception that terminates +execution of the coroutine. + +If the task has already run to completion then awaiting it again will obtain +the already-computed result without suspending the awaiting coroutine. + +If the `task` object is destroyed before it is awaited then the coroutine +never executes and the destructor simply destructs the captured parameters +and frees any memory used by the coroutine frame. + +## `shared_task` + +The `shared_task` class is a coroutine type that yields a single value +asynchronously. + +It is 'lazy' in that execution of the task does not start until it is awaited by some +coroutine. + +It is 'shared' in that the task value can be copied, allowing multiple references to +the result of the task to be created. It also allows multiple coroutines to +concurrently await the result. + +The task will start executing on the thread that first `co_await`s the task. +Subsequent awaiters will either be suspended and be queued for resumption +when the task completes or will continue synchronously if the task has +already run to completion. + +If an awaiter is suspended while waiting for the task to complete then +it will be resumed on the thread that completes execution of the task. +ie. the thread that executes the `co_return` or that throws the unhandled +exception that terminates execution of the coroutine. + +API Summary +```c++ +namespace cppcoro +{ + template + class shared_task + { + public: + + using promise_type = ; + using value_type = T; + + shared_task() noexcept; + shared_task(const shared_task& other) noexcept; + shared_task(shared_task&& other) noexcept; + shared_task& operator=(const shared_task& other) noexcept; + shared_task& operator=(shared_task&& other) noexcept; + + void swap(shared_task& other) noexcept; + + // Query if the task has completed and the result is ready. + bool is_ready() const noexcept; + + // Returns an operation that when awaited will suspend the + // current coroutine until the task completes and the result + // is available. + // + // The type of the result of the 'co_await someTask' expression + // is an l-value reference to the task's result value (unless T + // is void in which case the expression has type 'void'). + // If the task completed with an unhandled exception then the + // exception will be rethrown by the co_await expression. + Awaiter operator co_await() const noexcept; + + // Returns an operation that when awaited will suspend the + // calling coroutine until the task completes and the result + // is available. + // + // The result is not returned from the co_await expression. + // This can be used to synchronize with the task without the + // possibility of the co_await expression throwing an exception. + Awaiter when_ready() const noexcept; + + }; + + template + bool operator==(const shared_task& a, const shared_task& b) noexcept; + template + bool operator!=(const shared_task& a, const shared_task& b) noexcept; + + template + void swap(shared_task& a, shared_task& b) noexcept; + + // Wrap an awaitable value in a shared_task to allow multiple coroutines + // to concurrently await the result. + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits::await_result_t> + shared_task make_shared_task(AWAITABLE awaitable); +} +``` + +All const-methods on `shared_task` are safe to call concurrently with other +const-methods on the same instance from multiple threads. It is not safe to call +non-const methods of `shared_task` concurrently with any other method on the +same instance of a `shared_task`. + +### Comparison to `task` + +The `shared_task` class is similar to `task` in that the task does +not start execution immediately upon the coroutine function being called. +The task only starts executing when it is first awaited. + +It differs from `task` in that the resulting task object can be copied, +allowing multiple task objects to reference the same asynchronous result. +It also supports multiple coroutines concurrently awaiting the result of the task. + +The trade-off is that the result is always an l-value reference to the +result, never an r-value reference (since the result may be shared) which +may limit ability to move-construct the result into a local variable. +It also has a slightly higher run-time cost due to the need to maintain +a reference count and support multiple awaiters. + +## `generator` + +A `generator` represents a coroutine type that produces a sequence of values of type, `T`, +where values are produced lazily and synchronously. + +The coroutine body is able to yield values of type `T` using the `co_yield` keyword. +Note, however, that the coroutine body is not able to use the `co_await` keyword; +values must be produced synchronously. + +For example: +```c++ +cppcoro::generator fibonacci() +{ + std::uint64_t a = 0, b = 1; + while (true) + { + co_yield b; + auto tmp = a; + a = b; + b += tmp; + } +} + +void usage() +{ + for (auto i : fibonacci()) + { + if (i > 1'000'000) break; + std::cout << i << std::endl; + } +} +``` + +When a coroutine function returning a `generator` is called the coroutine is created initially suspended. +Execution of the coroutine enters the coroutine body when the `generator::begin()` method is called and continues until +either the first `co_yield` statement is reached or the coroutine runs to completion. + +If the returned iterator is not equal to the `end()` iterator then dereferencing the iterator will +return a reference to the value passed to the `co_yield` statement. + +Calling `operator++()` on the iterator will resume execution of the coroutine and continue until +either the next `co_yield` point is reached or the coroutine runs to completion(). + +Any unhandled exceptions thrown by the coroutine will propagate out of the `begin()` or +`operator++()` calls to the caller. + +API Summary: +```c++ +namespace cppcoro +{ + template + class generator + { + public: + + using promise_type = ; + + class iterator + { + public: + using iterator_category = std::input_iterator_tag; + using value_type = std::remove_reference_t; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::size_t; + + iterator(const iterator& other) noexcept; + iterator& operator=(const iterator& other) noexcept; + + // If the generator coroutine throws an unhandled exception before producing + // the next element then the exception will propagate out of this call. + iterator& operator++(); + + reference operator*() const noexcept; + pointer operator->() const noexcept; + + bool operator==(const iterator& other) const noexcept; + bool operator!=(const iterator& other) const noexcept; + }; + + // Constructs to the empty sequence. + generator() noexcept; + + generator(generator&& other) noexcept; + generator& operator=(generator&& other) noexcept; + + generator(const generator& other) = delete; + generator& operator=(const generator&) = delete; + + ~generator(); + + // Starts executing the generator coroutine which runs until either a value is yielded + // or the coroutine runs to completion or an unhandled exception propagates out of the + // the coroutine. + iterator begin(); + + iterator end() noexcept; + + // Swap the contents of two generators. + void swap(generator& other) noexcept; + + }; + + template + void swap(generator& a, generator& b) noexcept; + + // Apply function, func, lazily to each element of the source generator + // and yield a sequence of the results of calls to func(). + template + generator> fmap(FUNC func, generator source); +} +``` + +## `recursive_generator` + +A `recursive_generator` is similar to a `generator` except that it is designed to more efficiently +support yielding the elements of a nested sequence as elements of an outer sequence. + +In addition to being able to `co_yield` a value of type `T` you can also `co_yield` a value of type `recursive_generator`. + +When you `co_yield` a `recursive_generator` value the all elements of the yielded generator are yielded as elements of the current generator. +The current coroutine is suspended until the consumer has finished consuming all elements of the nested generator, after which point execution +of the current coroutine will resume execution to produce the next element. + +The benefit of `recursive_generator` over `generator` for iterating over recursive data-structures is that the `iterator::operator++()` +is able to directly resume the leaf-most coroutine to produce the next element, rather than having to resume/suspend O(depth) coroutines for each element. +The down-side is that there is additional overhead + +For example: +```c++ +// Lists the immediate contents of a directory. +cppcoro::generator list_directory(std::filesystem::path path); + +cppcoro::recursive_generator list_directory_recursive(std::filesystem::path path) +{ + for (auto& entry : list_directory(path)) + { + co_yield entry; + if (entry.is_directory()) + { + co_yield list_directory_recursive(entry.path()); + } + } +} +``` + +Note that applying the `fmap()` operator to a `recursive_generator` will yield a `generator` +type rather than a `recursive_generator`. This is because uses of `fmap` are generally not used +in recursive contexts and we try to avoid the extra overhead incurred by `recursive_generator`. + +## `async_generator` + +An `async_generator` represents a coroutine type that produces a sequence of values of type, `T`, where values are produced lazily and values may be produced asynchronously. + +The coroutine body is able to use both `co_await` and `co_yield` expressions. + +Consumers of the generator can use a `for co_await` range-based for-loop to consume the values. + +Example +```c++ +cppcoro::async_generator ticker(int count, threadpool& tp) +{ + for (int i = 0; i < count; ++i) + { + co_await tp.delay(std::chrono::seconds(1)); + co_yield i; + } +} + +cppcoro::task<> consumer(threadpool& tp) +{ + auto sequence = ticker(10, tp); + for co_await(std::uint32_t i : sequence) + { + std::cout << "Tick " << i << std::endl; + } +} +``` + +API Summary +```c++ +// +namespace cppcoro +{ + template + class async_generator + { + public: + + class iterator + { + public: + using iterator_tag = std::forward_iterator_tag; + using difference_type = std::size_t; + using value_type = std::remove_reference_t; + using reference = value_type&; + using pointer = value_type*; + + iterator(const iterator& other) noexcept; + iterator& operator=(const iterator& other) noexcept; + + // Resumes the generator coroutine if suspended + // Returns an operation object that must be awaited to wait + // for the increment operation to complete. + // If the coroutine runs to completion then the iterator + // will subsequently become equal to the end() iterator. + // If the coroutine completes with an unhandled exception then + // that exception will be rethrown from the co_await expression. + Awaitable operator++() noexcept; + + // Dereference the iterator. + pointer operator->() const noexcept; + reference operator*() const noexcept; + + bool operator==(const iterator& other) const noexcept; + bool operator!=(const iterator& other) const noexcept; + }; + + // Construct to the empty sequence. + async_generator() noexcept; + async_generator(const async_generator&) = delete; + async_generator(async_generator&& other) noexcept; + ~async_generator(); + + async_generator& operator=(const async_generator&) = delete; + async_generator& operator=(async_generator&& other) noexcept; + + void swap(async_generator& other) noexcept; + + // Starts execution of the coroutine and returns an operation object + // that must be awaited to wait for the first value to become available. + // The result of co_await'ing the returned object is an iterator that + // can be used to advance to subsequent elements of the sequence. + // + // This method is not valid to be called once the coroutine has + // run to completion. + Awaitable begin() noexcept; + iterator end() noexcept; + + }; + + template + void swap(async_generator& a, async_generator& b); + + // Apply 'func' to each element of the source generator, yielding a sequence of + // the results of calling 'func' on the source elements. + template + async_generator> fmap(FUNC func, async_generator source); +} +``` + +### Early termination of an async_generator + +When the `async_generator` object is destructed it requests cancellation of the underlying coroutine. +If the coroutine has already run to completion or is currently suspended in a `co_yield` expression +then the coroutine is destroyed immediately. Otherwise, the coroutine will continue execution until +it either runs to completion or reaches the next `co_yield` expression. + +When the coroutine frame is destroyed the destructors of all variables in scope at that point will be +executed to ensure the resources of the generator are cleaned up. + +Note that the caller must ensure that the `async_generator` object must not be destroyed while a +consumer coroutine is executing a `co_await` expression waiting for the next item to be produced. + +## `single_consumer_event` + +This is a simple manual-reset event type that supports only a single +coroutine awaiting it at a time. +This can be used to + +API Summary: +```c++ +// +namespace cppcoro +{ + class single_consumer_event + { + public: + single_consumer_event(bool initiallySet = false) noexcept; + bool is_set() const noexcept; + void set(); + void reset() noexcept; + Awaiter operator co_await() const noexcept; + }; +} +``` + +Example: +```c++ +#include + +cppcoro::single_consumer_event event; +std::string value; + +cppcoro::task<> consumer() +{ + // Coroutine will suspend here until some thread calls event.set() + // eg. inside the producer() function below. + co_await event; + + std::cout << value << std::endl; +} + +void producer() +{ + value = "foo"; + + // This will resume the consumer() coroutine inside the call to set() + // if it is currently suspended. + event.set(); +} +``` + +## `single_consumer_async_auto_reset_event` + +This class provides an async synchronization primitive that allows a single coroutine to +wait until the event is signalled by a call to the `set()` method. + +Once the coroutine that is awaiting the event is released by either a prior or subsequent call to `set()` +the event is automatically reset back to the 'not set' state. + +This class is a more efficient version of `async_auto_reset_event` that can be used in cases where +only a single coroutine will be awaiting the event at a time. If you need to support multiple concurrent +awaiting coroutines on the event then use the `async_auto_reset_event` class instead. + +API Summary: +```c++ +// +namespace cppcoro +{ + class single_consumer_async_auto_reset_event + { + public: + + single_consumer_async_auto_reset_event( + bool initiallySet = false) noexcept; + + // Change the event to the 'set' state. If a coroutine is awaiting the + // event then the event is immediately transitioned back to the 'not set' + // state and the coroutine is resumed. + void set() noexcept; + + // Returns an Awaitable type that can be awaited to wait until + // the event becomes 'set' via a call to the .set() method. If + // the event is already in the 'set' state then the coroutine + // continues without suspending. + // The event is automatically reset back to the 'not set' state + // before resuming the coroutine. + Awaiter operator co_await() const noexcept; + + }; +} +``` + +Example Usage: +```c++ +std::atomic value; +cppcoro::single_consumer_async_auto_reset_event valueDecreasedEvent; + +cppcoro::task<> wait_until_value_is_below(int limit) +{ + while (value.load(std::memory_order_relaxed) >= limit) + { + // Wait until there has been some change that we're interested in. + co_await valueDecreasedEvent; + } +} + +void change_value(int delta) +{ + value.fetch_add(delta, std::memory_order_relaxed); + // Notify the waiter if there has been some change. + if (delta < 0) valueDecreasedEvent.set(); +} +``` + +## `async_mutex` + +Provides a simple mutual exclusion abstraction that allows the caller to 'co_await' the mutex +from within a coroutine to suspend the coroutine until the mutex lock is acquired. + +The implementation is lock-free in that a coroutine that awaits the mutex will not +block the thread but will instead suspend the coroutine and later resume it inside +the call to `unlock()` by the previous lock-holder. + +API Summary: +```c++ +// +namespace cppcoro +{ + class async_mutex_lock; + class async_mutex_lock_operation; + class async_mutex_scoped_lock_operation; + + class async_mutex + { + public: + async_mutex() noexcept; + ~async_mutex(); + + async_mutex(const async_mutex&) = delete; + async_mutex& operator(const async_mutex&) = delete; + + bool try_lock() noexcept; + async_mutex_lock_operation lock_async() noexcept; + async_mutex_scoped_lock_operation scoped_lock_async() noexcept; + void unlock(); + }; + + class async_mutex_lock_operation + { + public: + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept; + }; + + class async_mutex_scoped_lock_operation + { + public: + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + [[nodiscard]] async_mutex_lock await_resume() const noexcept; + }; + + class async_mutex_lock + { + public: + // Takes ownership of the lock. + async_mutex_lock(async_mutex& mutex, std::adopt_lock_t) noexcept; + + // Transfer ownership of the lock. + async_mutex_lock(async_mutex_lock&& other) noexcept; + + async_mutex_lock(const async_mutex_lock&) = delete; + async_mutex_lock& operator=(const async_mutex_lock&) = delete; + + // Releases the lock by calling unlock() on the mutex. + ~async_mutex_lock(); + }; +} +``` + +Example usage: +```c++ +#include +#include +#include +#include + +cppcoro::async_mutex mutex; +std::set values; + +cppcoro::task<> add_item(std::string value) +{ + cppcoro::async_mutex_lock lock = co_await mutex.scoped_lock_async(); + values.insert(std::move(value)); +} +``` + +## `async_manual_reset_event` + +A manual-reset event is a coroutine/thread-synchronization primitive that allows one or more threads +to wait until the event is signalled by a thread that calls `set()`. + +The event is in one of two states; *'set'* and *'not set'*. + +If the event is in the *'set'* state when a coroutine awaits the event then the coroutine +continues without suspending. However if the coroutine is in the *'not set'* state then the +coroutine is suspended until some thread subsequently calls the `set()` method. + +Any threads that were suspended while waiting for the event to become *'set'* will be resumed +inside the next call to `set()` by some thread. + +Note that you must ensure that no coroutines are awaiting a *'not set'* event when the +event is destructed as they will not be resumed. + +Example: +```c++ +cppcoro::async_manual_reset_event event; +std::string value; + +void producer() +{ + value = get_some_string_value(); + + // Publish a value by setting the event. + event.set(); +} + +// Can be called many times to create many tasks. +// All consumer tasks will wait until value has been published. +cppcoro::task<> consumer() +{ + // Wait until value has been published by awaiting event. + co_await event; + + consume_value(value); +} +``` + +API Summary: +```c++ +namespace cppcoro +{ + class async_manual_reset_event_operation; + + class async_manual_reset_event + { + public: + async_manual_reset_event(bool initiallySet = false) noexcept; + ~async_manual_reset_event(); + + async_manual_reset_event(const async_manual_reset_event&) = delete; + async_manual_reset_event(async_manual_reset_event&&) = delete; + async_manual_reset_event& operator=(const async_manual_reset_event&) = delete; + async_manual_reset_event& operator=(async_manual_reset_event&&) = delete; + + // Wait until the event becomes set. + async_manual_reset_event_operation operator co_await() const noexcept; + + bool is_set() const noexcept; + + void set() noexcept; + + void reset() noexcept; + + }; + + class async_manual_reset_event_operation + { + public: + async_manual_reset_event_operation(async_manual_reset_event& event) noexcept; + + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept; + }; +} +``` + +## `async_auto_reset_event` + +An auto-reset event is a coroutine/thread-synchronization primitive that allows one or more threads +to wait until the event is signalled by a thread by calling `set()`. + +Once a coroutine that is awaiting the event is released by either a prior or subsequent call to `set()` +the event is automatically reset back to the 'not set' state. + +API Summary: +```c++ +// +namespace cppcoro +{ + class async_auto_reset_event_operation; + + class async_auto_reset_event + { + public: + + async_auto_reset_event(bool initiallySet = false) noexcept; + + ~async_auto_reset_event(); + + async_auto_reset_event(const async_auto_reset_event&) = delete; + async_auto_reset_event(async_auto_reset_event&&) = delete; + async_auto_reset_event& operator=(const async_auto_reset_event&) = delete; + async_auto_reset_event& operator=(async_auto_reset_event&&) = delete; + + // Wait for the event to enter the 'set' state. + // + // If the event is already 'set' then the event is set to the 'not set' + // state and the awaiting coroutine continues without suspending. + // Otherwise, the coroutine is suspended and later resumed when some + // thread calls 'set()'. + // + // Note that the coroutine may be resumed inside a call to 'set()' + // or inside another thread's call to 'operator co_await()'. + async_auto_reset_event_operation operator co_await() const noexcept; + + // Set the state of the event to 'set'. + // + // If there are pending coroutines awaiting the event then one + // pending coroutine is resumed and the state is immediately + // set back to the 'not set' state. + // + // This operation is a no-op if the event was already 'set'. + void set() noexcept; + + // Set the state of the event to 'not-set'. + // + // This is a no-op if the state was already 'not set'. + void reset() noexcept; + + }; + + class async_auto_reset_event_operation + { + public: + explicit async_auto_reset_event_operation(async_auto_reset_event& event) noexcept; + async_auto_reset_event_operation(const async_auto_reset_event_operation& other) noexcept; + + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept; + + }; +} +``` + +## `async_latch` + +An async latch is a synchronization primitive that allows coroutines to asynchronously +wait until a counter has been decremented to zero. + +The latch is a single-use object. Once the counter reaches zero the latch becomes 'ready' +and will remain ready until the latch is destroyed. + +API Summary: +```c++ +// +namespace cppcoro +{ + class async_latch + { + public: + + // Initialise the latch with the specified count. + async_latch(std::ptrdiff_t initialCount) noexcept; + + // Query if the count has reached zero yet. + bool is_ready() const noexcept; + + // Decrement the count by n. + // This will resume any waiting coroutines if the count reaches zero + // as a result of this call. + // It is undefined behaviour to decrement the count below zero. + void count_down(std::ptrdiff_t n = 1) noexcept; + + // Wait until the latch becomes ready. + // If the latch count is not yet zero then the awaiting coroutine will + // be suspended and later resumed by a call to count_down() that decrements + // the count to zero. If the latch count was already zero then the coroutine + // continues without suspending. + Awaiter operator co_await() const noexcept; + + }; +} +``` + +## `sequence_barrier` + +A `sequence_barrier` is a synchronization primitive that allows a single-producer +and multiple consumers to coordinate with respect to a monotonically increasing +sequence number. + +A single producer advances the sequence number by publishing new sequence numbers +in a monotonically increasing order. One or more consumers can query the last +published sequence number and can wait until a particular sequence number has been +published. + +A sequence barrier can be used to represent a cursor into a thread-safe producer/consumer +ring-buffer + +See the LMAX Disruptor pattern for more background: +https://lmax-exchange.github.io/disruptor/files/Disruptor-1.0.pdf + +API Synopsis: +```c++ +namespace cppcoro +{ + template> + class sequence_barrier + { + public: + sequence_barrier(SEQUENCE initialSequence = TRAITS::initial_sequence) noexcept; + ~sequence_barrier(); + + SEQUENCE last_published() const noexcept; + + // Wait until the specified targetSequence number has been published. + // + // If the operation does not complete synchronously then the awaiting + // coroutine is resumed on the specified scheduler. Otherwise, the + // coroutine continues without suspending. + // + // The co_await expression resumes with the updated last_published() + // value, which is guaranteed to be at least 'targetSequence'. + template + [[nodiscard]] + Awaitable wait_until_published(SEQUENCE targetSequence, + SCHEDULER& scheduler) const noexcept; + + void publish(SEQUENCE sequence) noexcept; + }; +} +``` + +## `single_producer_sequencer` + +A `single_producer_sequencer` is a synchronization primitive that can be used to +coordinate access to a ring-buffer for a single producer and one or more consumers. + +A producer first acquires one or more slots in a ring-buffer, writes to the ring-buffer +elements corresponding to those slots, and then finally publishes the values written to +those slots. A producer can never produce more than 'bufferSize' elements in advance +of where the consumer has consumed up to. + +A consumer then waits for certain elements to be published, processes the items and +then notifies the producer when it has finished processing items by publishing the +sequence number it has finished consuming in a `sequence_barrier` object. + + +API Synopsis: +```c++ +// +namespace cppcoro +{ + template< + typename SEQUENCE = std::size_t, + typename TRAITS = sequence_traits> + class single_producer_sequencer + { + public: + using size_type = typename sequence_range::size_type; + + single_producer_sequencer( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + SEQUENCE initialSequence = TRAITS::initial_sequence) noexcept; + + // Publisher API: + + template + [[nodiscard]] + Awaitable claim_one(SCHEDULER& scheduler) noexcept; + + template + [[nodiscard]] + Awaitable> claim_up_to( + std::size_t count, + SCHEDULER& scheduler) noexcept; + + void publish(SEQUENCE sequence) noexcept; + + // Consumer API: + + SEQUENCE last_published() const noexcept; + + template + [[nodiscard]] + Awaitable wait_until_published( + SEQUENCE targetSequence, + SCHEDULER& scheduler) const noexcept; + + }; +} +``` + +Example usage: +```c++ +using namespace cppcoro; +using namespace std::chrono; + +struct message +{ + int id; + steady_clock::time_point timestamp; + float data; +}; + +constexpr size_t bufferSize = 16384; // Must be power-of-two +constexpr size_t indexMask = bufferSize - 1; +message buffer[bufferSize]; + +task producer( + io_service& ioSvc, + single_producer_sequencer& sequencer) +{ + auto start = steady_clock::now(); + for (int i = 0; i < 1'000'000; ++i) + { + // Wait until a slot is free in the buffer. + size_t seq = co_await sequencer.claim_one(ioSvc); + + // Populate the message. + auto& msg = buffer[seq & indexMask]; + msg.id = i; + msg.timestamp = steady_clock::now(); + msg.data = 123; + + // Publish the message. + sequencer.publish(seq); + } + + // Publish a sentinel + auto seq = co_await sequencer.claim_one(ioSvc); + auto& msg = buffer[seq & indexMask]; + msg.id = -1; + sequencer.publish(seq); +} + +task consumer( + static_thread_pool& threadPool, + const single_producer_sequencer& sequencer, + sequence_barrier& consumerBarrier) +{ + size_t nextToRead = 0; + while (true) + { + // Wait until the next message is available + // There may be more than one available. + const size_t available = co_await sequencer.wait_until_published(nextToRead, threadPool); + do { + auto& msg = buffer[nextToRead & indexMask]; + if (msg.id == -1) + { + consumerBarrier.publish(nextToRead); + co_return; + } + + processMessage(msg); + } while (nextToRead++ != available); + + // Notify the producer that we've finished processing + // up to 'nextToRead - 1'. + consumerBarrier.publish(available); + } +} + +task example(io_service& ioSvc, static_thread_pool& threadPool) +{ + sequence_barrier barrier; + single_producer_sequencer sequencer{barrier, bufferSize}; + + co_await when_all( + producer(tp, sequencer), + consumer(tp, sequencer, barrier)); +} +``` + +## `multi_producer_sequencer` + +The `multi_producer_sequencer` class is a synchronization primitive that coordinates +access to a ring-buffer for multiple producers and one or more consumers. + +For a single-producer variant see the `single_producer_sequencer` class. + +Note that the ring-buffer must have a size that is a power-of-two. This is because +the implementation uses bitmasks instead of integer division/modulo to calculate +the offset into the buffer. Also, this allows the sequence number to safely wrap +around the 32-bit/64-bit value. + +API Summary: +```c++ +// +namespace cppcoro +{ + template> + class multi_producer_sequencer + { + public: + multi_producer_sequencer( + const sequence_barrier& consumerBarrier, + SEQUENCE initialSequence = TRAITS::initial_sequence); + + std::size_t buffer_size() const noexcept; + + // Consumer interface + // + // Each consumer keeps track of their own 'lastKnownPublished' value + // and must pass this to the methods that query for an updated last-known + // published sequence number. + + SEQUENCE last_published_after(SEQUENCE lastKnownPublished) const noexcept; + + template + Awaitable wait_until_published( + SEQUENCE targetSequence, + SEQUENCE lastKnownPublished, + SCHEDULER& scheduler) const noexcept; + + // Producer interface + + // Query whether any slots available for claiming (approx.) + bool any_available() const noexcept; + + template + Awaitable claim_one(SCHEDULER& scheduler) noexcept; + + template + Awaitable> claim_up_to( + std::size_t count, + SCHEDULER& scheduler) noexcept; + + // Mark the specified sequence number as published. + void publish(SEQUENCE sequence) noexcept; + + // Mark all sequence numbers in the specified range as published. + void publish(const sequence_range& range) noexcept; + }; +} +``` + +## Cancellation + +A `cancellation_token` is a value that can be passed to a function that allows the caller to subsequently communicate a request to cancel the operation to that function. + +To obtain a `cancellation_token` that is able to be cancelled you must first create a `cancellation_source` object. +The `cancellation_source::token()` method can be used to manufacture new `cancellation_token` values that are linked to that `cancellation_source` object. + +When you want to later request cancellation of an operation you have passed a `cancellation_token` to +you can call `cancellation_source::request_cancellation()` on an associated `cancellation_source` object. + +Functions can respond to a request for cancellation in one of two ways: +1. Poll for cancellation at regular intervals by calling either `cancellation_token::is_cancellation_requested()` or `cancellation_token::throw_if_cancellation_requested()`. +2. Register a callback to be executed when cancellation is requested using the `cancellation_registration` class. + +API Summary: +```c++ +namespace cppcoro +{ + class cancellation_source + { + public: + // Construct a new, independently cancellable cancellation source. + cancellation_source(); + + // Construct a new reference to the same cancellation state. + cancellation_source(const cancellation_source& other) noexcept; + cancellation_source(cancellation_source&& other) noexcept; + + ~cancellation_source(); + + cancellation_source& operator=(const cancellation_source& other) noexcept; + cancellation_source& operator=(cancellation_source&& other) noexcept; + + bool is_cancellation_requested() const noexcept; + bool can_be_cancelled() const noexcept; + void request_cancellation(); + + cancellation_token token() const noexcept; + }; + + class cancellation_token + { + public: + // Construct a token that can't be cancelled. + cancellation_token() noexcept; + + cancellation_token(const cancellation_token& other) noexcept; + cancellation_token(cancellation_token&& other) noexcept; + + ~cancellation_token(); + + cancellation_token& operator=(const cancellation_token& other) noexcept; + cancellation_token& operator=(cancellation_token&& other) noexcept; + + bool is_cancellation_requested() const noexcept; + void throw_if_cancellation_requested() const; + + // Query if this token can ever have cancellation requested. + // Code can use this to take a more efficient code-path in cases + // that the operation does not need to handle cancellation. + bool can_be_cancelled() const noexcept; + }; + + // RAII class for registering a callback to be executed if cancellation + // is requested on a particular cancellation token. + class cancellation_registration + { + public: + + // Register a callback to be executed if cancellation is requested. + // Callback will be called with no arguments on the thread that calls + // request_cancellation() if cancellation is not yet requested, or + // called immediately if cancellation has already been requested. + // Callback must not throw an unhandled exception when called. + template + cancellation_registration(cancellation_token token, CALLBACK&& callback); + + cancellation_registration(const cancellation_registration& other) = delete; + + ~cancellation_registration(); + }; + + class operation_cancelled : public std::exception + { + public: + operation_cancelled(); + const char* what() const override; + }; +} +``` + +Example: Polling Approach +```c++ +cppcoro::task<> do_something_async(cppcoro::cancellation_token token) +{ + // Explicitly define cancellation points within the function + // by calling throw_if_cancellation_requested(). + token.throw_if_cancellation_requested(); + + co_await do_step_1(); + + token.throw_if_cancellation_requested(); + + do_step_2(); + + // Alternatively, you can query if cancellation has been + // requested to allow yourself to do some cleanup before + // returning. + if (token.is_cancellation_requested()) + { + display_message_to_user("Cancelling operation..."); + do_cleanup(); + throw cppcoro::operation_cancelled{}; + } + + do_final_step(); +} +``` + +Example: Callback Approach +```c++ +// Say we already have a timer abstraction that supports being +// cancelled but it doesn't support cancellation_tokens natively. +// You can use a cancellation_registration to register a callback +// that calls the existing cancellation API. e.g. +cppcoro::task<> cancellable_timer_wait(cppcoro::cancellation_token token) +{ + auto timer = create_timer(10s); + + cppcoro::cancellation_registration registration(token, [&] + { + // Call existing timer cancellation API. + timer.cancel(); + }); + + co_await timer; +} +``` + +## `static_thread_pool` + +The `static_thread_pool` class provides an abstraction that lets you schedule work +on a fixed-size pool of threads. + +This class implements the **Scheduler** concept (see below). + +You can enqueue work to the thread-pool by executing `co_await threadPool.schedule()`. +This operation will suspend the current coroutine, enqueue it for execution on the +thread-pool and the thread pool will then resume the coroutine when a thread in the +thread-pool is next free to run the coroutine. **This operation is guaranteed not +to throw and, in the common case, will not allocate any memory**. + +This class makes use of a work-stealing algorithm to load-balance work across multiple +threads. Work enqueued to the thread-pool from a thread-pool thread will be scheduled +for execution on the same thread in a LIFO queue. Work enqueued to the thread-pool from +a remote thread will be enqueued to a global FIFO queue. When a worker thread runs out +of work from its local queue it first tries to dequeue work from the global queue. If +that queue is empty then it next tries to steal work from the back of the queues of +the other worker threads. + +API Summary: +```c++ +namespace cppcoro +{ + class static_thread_pool + { + public: + // Initialise the thread-pool with a number of threads equal to + // std::thread::hardware_concurrency(). + static_thread_pool(); + + // Initialise the thread pool with the specified number of threads. + explicit static_thread_pool(std::uint32_t threadCount); + + std::uint32_t thread_count() const noexcept; + + class schedule_operation + { + public: + schedule_operation(static_thread_pool* tp) noexcept; + + bool await_ready() noexcept; + bool await_suspend(cppcoro::coroutine_handle<> h) noexcept; + bool await_resume() noexcept; + + private: + // unspecified + }; + + // Return an operation that can be awaited by a coroutine. + // + // + [[nodiscard]] + schedule_operation schedule() noexcept; + + private: + + // Unspecified + + }; +} +``` + +Example usage: Simple +```c++ +cppcoro::task do_something_on_threadpool(cppcoro::static_thread_pool& tp) +{ + // First schedule the coroutine onto the threadpool. + co_await tp.schedule(); + + // When it resumes, this coroutine is now running on the threadpool. + do_something(); +} +``` + +Example usage: Doing things in parallel - using `schedule_on()` operator with `static_thread_pool`. +```c++ +cppcoro::task dot_product(static_thread_pool& tp, double a[], double b[], size_t count) +{ + if (count > 1000) + { + // Subdivide the work recursively into two equal tasks + // The first half is scheduled to the thread pool so it can run concurrently + // with the second half which continues on this thread. + size_t halfCount = count / 2; + auto [first, second] = co_await when_all( + schedule_on(tp, dot_product(tp, a, b, halfCount), + dot_product(tp, a + halfCount, b + halfCount, count - halfCount)); + co_return first + second; + } + else + { + double sum = 0.0; + for (size_t i = 0; i < count; ++i) + { + sum += a[i] * b[i]; + } + co_return sum; + } +} +``` + +## `io_service` and `io_work_scope` + +The `io_service` class provides an abstraction for processing I/O completion events +from asynchronous I/O operations. + +When an asynchronous I/O operation completes, the coroutine that was awaiting +that operation will be resumed on an I/O thread inside a call to one of the +event-processing methods: `process_events()`, `process_pending_events()`, +`process_one_event()` or `process_one_pending_event()`. + +The `io_service` class does not manage any I/O threads. +You must ensure that some thread calls one of the event-processing methods for coroutines awaiting I/O +completion events to be dispatched. This can either be a dedicated thread that calls `process_events()` +or mixed in with some other event loop (e.g. a UI event loop) by periodically polling for new events +via a call to `process_pending_events()` or `process_one_pending_event()`. + +This allows integration of the `io_service` event-loop with other event loops, such as a user-interface event loop. + +You can multiplex processing of events across multiple threads by having multiple threads call +`process_events()`. You can specify a hint as to the maximum number of threads to have actively +processing events via an optional `io_service` constructor parameter. + +On Windows, the implementation makes use of the Windows I/O Completion Port facility to dispatch +events to I/O threads in a scalable manner. + +API Summary: +```c++ +namespace cppcoro +{ + class io_service + { + public: + + class schedule_operation; + class timed_schedule_operation; + + io_service(); + io_service(std::uint32_t concurrencyHint); + + io_service(io_service&&) = delete; + io_service(const io_service&) = delete; + io_service& operator=(io_service&&) = delete; + io_service& operator=(const io_service&) = delete; + + ~io_service(); + + // Scheduler methods + + [[nodiscard]] + schedule_operation schedule() noexcept; + + template + [[nodiscard]] + timed_schedule_operation schedule_after( + std::chrono::duration delay, + cppcoro::cancellation_token cancellationToken = {}) noexcept; + + // Event-loop methods + // + // I/O threads must call these to process I/O events and execute + // scheduled coroutines. + + std::uint64_t process_events(); + std::uint64_t process_pending_events(); + std::uint64_t process_one_event(); + std::uint64_t process_one_pending_event(); + + // Request that all threads processing events exit their event loops. + void stop() noexcept; + + // Query if some thread has called stop() + bool is_stop_requested() const noexcept; + + // Reset the event-loop after a call to stop() so that threads can + // start processing events again. + void reset(); + + // Reference-counting methods for tracking outstanding references + // to the io_service. + // + // The io_service::stop() method will be called when the last work + // reference is decremented. + // + // Use the io_work_scope RAII class to manage calling these methods on + // entry-to and exit-from a scope. + void notify_work_started() noexcept; + void notify_work_finished() noexcept; + + }; + + class io_service::schedule_operation + { + public: + schedule_operation(const schedule_operation&) noexcept; + schedule_operation& operator=(const schedule_operation&) noexcept; + + bool await_ready() const noexcept; + void await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() noexcept; + }; + + class io_service::timed_schedule_operation + { + public: + timed_schedule_operation(timed_schedule_operation&&) noexcept; + + timed_schedule_operation(const timed_schedule_operation&) = delete; + timed_schedule_operation& operator=(const timed_schedule_operation&) = delete; + timed_schedule_operation& operator=(timed_schedule_operation&&) = delete; + + bool await_ready() const noexcept; + void await_suspend(cppcoro::coroutine_handle<> awaiter); + void await_resume(); + }; + + class io_work_scope + { + public: + + io_work_scope(io_service& ioService) noexcept; + + io_work_scope(const io_work_scope& other) noexcept; + io_work_scope(io_work_scope&& other) noexcept; + + ~io_work_scope(); + + io_work_scope& operator=(const io_work_scope& other) noexcept; + io_work_scope& operator=(io_work_scope&& other) noexcept; + + io_service& service() const noexcept; + }; + +} +``` + +Example: +```c++ +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace fs = cppcoro::filesystem; + +cppcoro::task count_lines(cppcoro::io_service& ioService, fs::path path) +{ + auto file = cppcoro::read_only_file::open(ioService, path); + + constexpr size_t bufferSize = 4096; + auto buffer = std::make_unique(bufferSize); + + std::uint64_t newlineCount = 0; + + for (std::uint64_t offset = 0, fileSize = file.size(); offset < fileSize;) + { + const auto bytesToRead = static_cast( + std::min(bufferSize, fileSize - offset)); + + const auto bytesRead = co_await file.read(offset, buffer.get(), bytesToRead); + + newlineCount += std::count(buffer.get(), buffer.get() + bytesRead, '\n'); + + offset += bytesRead; + } + + co_return newlineCount; +} + +cppcoro::task<> run(cppcoro::io_service& ioService) +{ + cppcoro::io_work_scope ioScope(ioService); + + auto lineCount = co_await count_lines(ioService, fs::path{"foo.txt"}); + + std::cout << "foo.txt has " << lineCount << " lines." << std::endl;; +} + +cppcoro::task<> process_events(cppcoro::io_service& ioService) +{ + // Process events until the io_service is stopped. + // ie. when the last io_work_scope goes out of scope. + ioService.process_events(); + co_return; +} + +int main() +{ + cppcoro::io_service ioService; + + cppcoro::sync_wait(cppcoro::when_all_ready( + run(ioService), + process_events(ioService))); + + return 0; +} +``` + +### `io_service` as a scheduler + +An `io_service` class implements the interfaces for the `Scheduler` and `DelayedScheduler` concepts. + +This allows a coroutine to suspend execution on the current thread and schedule itself for resumption +on an I/O thread associated with a particular `io_service` object. + +Example: +```c++ +cppcoro::task<> do_something(cppcoro::io_service& ioService) +{ + // Coroutine starts execution on the thread of the task awaiter. + + // A coroutine can transfer execution to an I/O thread by awaiting the + // result of io_service::schedule(). + co_await ioService.schedule(); + + // At this point, the coroutine is now executing on an I/O thread + // inside a call to one of the io_service event processing methods. + + // A coroutine can also perform a delayed-schedule that will suspend + // the coroutine for a specified duration of time before scheduling + // it for resumption on an I/O thread. + co_await ioService.schedule_after(100ms); + + // At this point, the coroutine is executing on a potentially different I/O thread. +} +``` + +## `file`, `readable_file`, `writable_file` + +These types are abstract base-classes for performing concrete file I/O. + +API Summary: +```c++ +namespace cppcoro +{ + class file_read_operation; + class file_write_operation; + + class file + { + public: + + virtual ~file(); + + std::uint64_t size() const; + + protected: + + file(file&& other) noexcept; + + }; + + class readable_file : public virtual file + { + public: + + [[nodiscard]] + file_read_operation read( + std::uint64_t offset, + void* buffer, + std::size_t byteCount, + cancellation_token ct = {}) const noexcept; + + }; + + class writable_file : public virtual file + { + public: + + void set_size(std::uint64_t fileSize); + + [[nodiscard]] + file_write_operation write( + std::uint64_t offset, + const void* buffer, + std::size_t byteCount, + cancellation_token ct = {}) noexcept; + + }; + + class file_read_operation + { + public: + + file_read_operation(file_read_operation&& other) noexcept; + + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter); + std::size_t await_resume(); + + }; + + class file_write_operation + { + public: + + file_write_operation(file_write_operation&& other) noexcept; + + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter); + std::size_t await_resume(); + + }; +} +``` + +## `read_only_file`, `write_only_file`, `read_write_file` + +These types represent concrete file I/O classes. + +API Summary: +```c++ +namespace cppcoro +{ + class read_only_file : public readable_file + { + public: + + [[nodiscard]] + static read_only_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_share_mode shareMode = file_share_mode::read, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + }; + + class write_only_file : public writable_file + { + public: + + [[nodiscard]] + static write_only_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode = file_open_mode::create_or_open, + file_share_mode shareMode = file_share_mode::none, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + }; + + class read_write_file : public readable_file, public writable_file + { + public: + + [[nodiscard]] + static read_write_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode = file_open_mode::create_or_open, + file_share_mode shareMode = file_share_mode::none, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + }; +} +``` + +All `open()` functions throw `std::system_error` on failure. + +# Networking + +NOTE: Networking abstractions are currently only supported on the Windows platform. +Linux support will be coming soon. + +## `socket` + +The socket class can be used to send/receive data over the network asynchronously. + +Currently only supports TCP/IP, UDP/IP over IPv4 and IPv6. + +API Summary: +```c++ +// +namespace cppcoro::net +{ + class socket + { + public: + + static socket create_tcpv4(ip_service& ioSvc); + static socket create_tcpv6(ip_service& ioSvc); + static socket create_updv4(ip_service& ioSvc); + static socket create_udpv6(ip_service& ioSvc); + + socket(socket&& other) noexcept; + + ~socket(); + + socket& operator=(socket&& other) noexcept; + + // Return the native socket handle for the socket + native_handle() noexcept; + + const ip_endpoint& local_endpoint() const noexcept; + const ip_endpoint& remote_endpoint() const noexcept; + + void bind(const ip_endpoint& localEndPoint); + + void listen(); + + [[nodiscard]] + Awaitable connect(const ip_endpoint& remoteEndPoint) noexcept; + [[nodiscard]] + Awaitable connect(const ip_endpoint& remoteEndPoint, + cancellation_token ct) noexcept; + + [[nodiscard]] + Awaitable accept(socket& acceptingSocket) noexcept; + [[nodiscard]] + Awaitable accept(socket& acceptingSocket, + cancellation_token ct) noexcept; + + [[nodiscard]] + Awaitable disconnect() noexcept; + [[nodiscard]] + Awaitable disconnect(cancellation_token ct) noexcept; + + [[nodiscard]] + Awaitable send(const void* buffer, std::size_t size) noexcept; + [[nodiscard]] + Awaitable send(const void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + Awaitable recv(void* buffer, std::size_t size) noexcept; + [[nodiscard]] + Awaitable recv(void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_recv_from_operation recv_from( + void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_recv_from_operation_cancellable recv_from( + void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_send_to_operation send_to( + const ip_endpoint& destination, + const void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_send_to_operation_cancellable send_to( + const ip_endpoint& destination, + const void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + void close_send(); + void close_recv(); + + }; +} +``` + +Example: Echo Server +```c++ +#include +#include +#include +#include +#include + +#include +#include + +cppcoro::task handle_connection(socket s) +{ + try + { + const size_t bufferSize = 16384; + auto buffer = std::make_unique(bufferSize); + size_t bytesRead; + do { + // Read some bytes + bytesRead = co_await s.recv(buffer.get(), bufferSize); + + // Write some bytes + size_t bytesWritten = 0; + while (bytesWritten < bytesRead) { + bytesWritten += co_await s.send( + buffer.get() + bytesWritten, + bytesRead - bytesWritten); + } + } while (bytesRead != 0); + + s.close_send(); + + co_await s.disconnect(); + } + catch (...) + { + std::cout << "connection failed" << std:: + } +} + +cppcoro::task echo_server( + cppcoro::net::ipv4_endpoint endpoint, + cppcoro::io_service& ioSvc, + cancellation_token ct) +{ + cppcoro::async_scope scope; + + std::exception_ptr ex; + try + { + auto listeningSocket = cppcoro::net::socket::create_tcpv4(ioSvc); + listeningSocket.bind(endpoint); + listeningSocket.listen(); + + while (true) { + auto connection = cppcoro::net::socket::create_tcpv4(ioSvc); + co_await listeningSocket.accept(connection, ct); + scope.spawn(handle_connection(std::move(connection))); + } + } + catch (cppcoro::operation_cancelled) + { + } + catch (...) + { + ex = std::current_exception(); + } + + // Wait until all handle_connection tasks have finished. + co_await scope.join(); + + if (ex) std::rethrow_exception(ex); +} + +int main(int argc, const char* argv[]) +{ + cppcoro::io_service ioSvc; + + if (argc != 2) return -1; + + auto endpoint = cppcoro::ipv4_endpoint::from_string(argv[1]); + if (!endpoint) return -1; + + (void)cppcoro::sync_wait(cppcoro::when_all( + [&]() -> task<> + { + // Shutdown the event loop once finished. + auto stopOnExit = cppcoro::on_scope_exit([&] { ioSvc.stop(); }); + + cppcoro::cancellation_source canceller; + co_await cppcoro::when_all( + [&]() -> task<> + { + // Run for 30s then stop accepting new connections. + co_await ioSvc.schedule_after(std::chrono::seconds(30)); + canceller.request_cancellation(); + }(), + echo_server(*endpoint, ioSvc, canceller.token())); + }(), + [&]() -> task<> + { + ioSvc.process_events(); + }())); + + return 0; +} +``` + +## `ip_address`, `ipv4_address`, `ipv6_address` + +Helper classes for representing an IP address. + +API Synopsis: +```c++ +namespace cppcoro::net +{ + class ipv4_address + { + using bytes_t = std::uint8_t[4]; + public: + constexpr ipv4_address(); + explicit constexpr ipv4_address(std::uint32_t integer); + explicit constexpr ipv4_address(const std::uint8_t(&bytes)[4]); + explicit constexpr ipv4_address(std::uint8_t b0, + std::uint8_t b1, + std::uint8_t b2, + std::uint8_t b3); + + constexpr const bytes_t& bytes() const; + + constexpr std::uint32_t to_integer() const; + + static constexpr ipv4_address loopback(); + + constexpr bool is_loopback() const; + constexpr bool is_private_network() const; + + constexpr bool operator==(ipv4_address other) const; + constexpr bool operator!=(ipv4_address other) const; + constexpr bool operator<(ipv4_address other) const; + constexpr bool operator>(ipv4_address other) const; + constexpr bool operator<=(ipv4_address other) const; + constexpr bool operator>=(ipv4_address other) const; + + std::string to_string(); + + static std::optional from_string(std::string_view string) noexcept; + }; + + class ipv6_address + { + using bytes_t = std::uint8_t[16]; + public: + constexpr ipv6_address(); + + explicit constexpr ipv6_address( + std::uint64_t subnetPrefix, + std::uint64_t interfaceIdentifier); + + constexpr ipv6_address( + std::uint16_t part0, + std::uint16_t part1, + std::uint16_t part2, + std::uint16_t part3, + std::uint16_t part4, + std::uint16_t part5, + std::uint16_t part6, + std::uint16_t part7); + + explicit constexpr ipv6_address( + const std::uint16_t(&parts)[8]); + + explicit constexpr ipv6_address( + const std::uint8_t(bytes)[16]); + + constexpr const bytes_t& bytes() const; + + constexpr std::uint64_t subnet_prefix() const; + constexpr std::uint64_t interface_identifier() const; + + static constexpr ipv6_address unspecified(); + static constexpr ipv6_address loopback(); + + static std::optional from_string(std::string_view string) noexcept; + + std::string to_string() const; + + constexpr bool operator==(const ipv6_address& other) const; + constexpr bool operator!=(const ipv6_address& other) const; + constexpr bool operator<(const ipv6_address& other) const; + constexpr bool operator>(const ipv6_address& other) const; + constexpr bool operator<=(const ipv6_address& other) const; + constexpr bool operator>=(const ipv6_address& other) const; + + }; + + class ip_address + { + public: + + // Constructs to IPv4 address 0.0.0.0 + ip_address() noexcept; + + ip_address(ipv4_address address) noexcept; + ip_address(ipv6_address address) noexcept; + + bool is_ipv4() const noexcept; + bool is_ipv6() const noexcept; + + const ipv4_address& to_ipv4() const; + const ipv6_address& to_ipv6() const; + + const std::uint8_t* bytes() const noexcept; + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + bool operator==(const ip_address& rhs) const noexcept; + bool operator!=(const ip_address& rhs) const noexcept; + + // ipv4_address sorts less than ipv6_address + bool operator<(const ip_address& rhs) const noexcept; + bool operator>(const ip_address& rhs) const noexcept; + bool operator<=(const ip_address& rhs) const noexcept; + bool operator>=(const ip_address& rhs) const noexcept; + + }; +} +``` + +## `ip_endpoint`, `ipv4_endpoint` `ipv6_endpoint` + +Helper classes for representing an IP address and port-number. + +API Synopsis: +```c++ +namespace cppcoro::net +{ + class ipv4_endpoint + { + public: + ipv4_endpoint() noexcept; + explicit ipv4_endpoint(ipv4_address address, std::uint16_t port = 0) noexcept; + + const ipv4_address& address() const noexcept; + std::uint16_t port() const noexcept; + + std::string to_string() const; + static std::optional from_string(std::string_view string) noexcept; + }; + + bool operator==(const ipv4_endpoint& a, const ipv4_endpoint& b); + bool operator!=(const ipv4_endpoint& a, const ipv4_endpoint& b); + bool operator<(const ipv4_endpoint& a, const ipv4_endpoint& b); + bool operator>(const ipv4_endpoint& a, const ipv4_endpoint& b); + bool operator<=(const ipv4_endpoint& a, const ipv4_endpoint& b); + bool operator>=(const ipv4_endpoint& a, const ipv4_endpoint& b); + + class ipv6_endpoint + { + public: + ipv6_endpoint() noexcept; + explicit ipv6_endpoint(ipv6_address address, std::uint16_t port = 0) noexcept; + + const ipv6_address& address() const noexcept; + std::uint16_t port() const noexcept; + + std::string to_string() const; + static std::optional from_string(std::string_view string) noexcept; + }; + + bool operator==(const ipv6_endpoint& a, const ipv6_endpoint& b); + bool operator!=(const ipv6_endpoint& a, const ipv6_endpoint& b); + bool operator<(const ipv6_endpoint& a, const ipv6_endpoint& b); + bool operator>(const ipv6_endpoint& a, const ipv6_endpoint& b); + bool operator<=(const ipv6_endpoint& a, const ipv6_endpoint& b); + bool operator>=(const ipv6_endpoint& a, const ipv6_endpoint& b); + + class ip_endpoint + { + public: + // Constructs to IPv4 end-point 0.0.0.0:0 + ip_endpoint() noexcept; + + ip_endpoint(ipv4_endpoint endpoint) noexcept; + ip_endpoint(ipv6_endpoint endpoint) noexcept; + + bool is_ipv4() const noexcept; + bool is_ipv6() const noexcept; + + const ipv4_endpoint& to_ipv4() const; + const ipv6_endpoint& to_ipv6() const; + + ip_address address() const noexcept; + std::uint16_t port() const noexcept; + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + bool operator==(const ip_endpoint& rhs) const noexcept; + bool operator!=(const ip_endpoint& rhs) const noexcept; + + // ipv4_endpoint sorts less than ipv6_endpoint + bool operator<(const ip_endpoint& rhs) const noexcept; + bool operator>(const ip_endpoint& rhs) const noexcept; + bool operator<=(const ip_endpoint& rhs) const noexcept; + bool operator>=(const ip_endpoint& rhs) const noexcept; + }; +} +``` + +# Functions + +## `sync_wait()` + +The `sync_wait()` function can be used to synchronously wait until the specified `awaitable` +completes. + +The specified awaitable will be `co_await`ed on current thread inside a newly created coroutine. + +The `sync_wait()` call will block until the operation completes and will return the result of +the `co_await` expression or rethrow the exception if the `co_await` expression completed with +an unhandled exception. + +The `sync_wait()` function is mostly useful for starting a top-level task from within `main()` +and waiting until the task finishes, in practice it is the only way to start the first/top-level +`task`. + +API Summary: +```c++ +// +namespace cppcoro +{ + template + auto sync_wait(AWAITABLE&& awaitable) + -> typename awaitable_traits::await_result_t; +} +``` + +Examples: +```c++ +void example_task() +{ + auto makeTask = []() -> task + { + co_return "foo"; + }; + + auto task = makeTask(); + + // start the lazy task and wait until it completes + sync_wait(task); // -> "foo" + sync_wait(makeTask()); // -> "foo" +} + +void example_shared_task() +{ + auto makeTask = []() -> shared_task + { + co_return "foo"; + }; + + auto task = makeTask(); + // start the shared task and wait until it completes + sync_wait(task) == "foo"; + sync_wait(makeTask()) == "foo"; +} +``` + +## `when_all_ready()` + +The `when_all_ready()` function can be used to create a new awaitable that completes when +all of the input awaitables complete. + +Input tasks can be any type of awaitable. + +When the returned awaitable is `co_await`ed it will `co_await` each of the input awaitables +in turn on the awaiting thread in the order they are passed to the `when_all_ready()` +function. If these tasks to not complete synchronously then they will execute concurrently. + +Once all of the `co_await` expressions on input awaitables have run to completion the +returned awaitable will complete and resume the awaiting coroutine. The awaiting coroutine +will be resumed on the thread of the input awaitable that is last to complete. + +The returned awaitable is guaranteed not to throw an exception when `co_await`ed, +even if some of the input awaitables fail with an unhandled exception. + +Note, however, that the `when_all_ready()` call itself may throw `std::bad_alloc` if it +was unable to allocate memory for the coroutine frames required to await each of the +input awaitables. It may also throw an exception if any of the input awaitable objects +throw from their copy/move constructors. + +The result of `co_await`ing the returned awaitable is a `std::tuple` or `std::vector` +of `when_all_task` objects. These objects allow you to obtain the result (or exception) +of each input awaitable separately by calling the `when_all_task::result()` +method of the corresponding output task. +This allows the caller to concurrently await multiple awaitables and synchronize on +their completion while still retaining the ability to subsequently inspect the results of +each of the `co_await` operations for success/failure. + +This differs from `when_all()` where the failure of any individual `co_await` operation +causes the overall operation to fail with an exception. This means you cannot determine +which of the component `co_await` operations failed and also prevents you from obtaining +the results of the other `co_await` operations. + +API summary: +```c++ +// +namespace cppcoro +{ + // Concurrently await multiple awaitables. + // + // Returns an awaitable object that, when co_await'ed, will co_await each of the input + // awaitable objects and will resume the awaiting coroutine only when all of the + // component co_await operations complete. + // + // Result of co_await'ing the returned awaitable is a std::tuple of detail::when_all_task, + // one for each input awaitable and where T is the result-type of the co_await expression + // on the corresponding awaitable. + // + // AWAITABLES must be awaitable types and must be movable (if passed as rvalue) or copyable + // (if passed as lvalue). The co_await expression will be executed on an rvalue of the + // copied awaitable. + template + auto when_all_ready(AWAITABLES&&... awaitables) + -> Awaitable::await_result_t>...>>; + + // Concurrently await each awaitable in a vector of input awaitables. + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits::await_result_t> + auto when_all_ready(std::vector awaitables) + -> Awaitable>>; +} +``` + +Example usage: +```c++ +task get_record(int id); + +task<> example1() +{ + // Run 3 get_record() operations concurrently and wait until they're all ready. + // Returns a std::tuple of tasks that can be unpacked using structured bindings. + auto [task1, task2, task3] = co_await when_all_ready( + get_record(123), + get_record(456), + get_record(789)); + + // Unpack the result of each task + std::string& record1 = task1.result(); + std::string& record2 = task2.result(); + std::string& record3 = task3.result(); + + // Use records.... +} + +task<> example2() +{ + // Create the input tasks. They don't start executing yet. + std::vector> tasks; + for (int i = 0; i < 1000; ++i) + { + tasks.emplace_back(get_record(i)); + } + + // Execute all tasks concurrently. + std::vector> resultTasks = + co_await when_all_ready(std::move(tasks)); + + // Unpack and handle each result individually once they're all complete. + for (int i = 0; i < 1000; ++i) + { + try + { + std::string& record = tasks[i].result(); + std::cout << i << " = " << record << std::endl; + } + catch (const std::exception& ex) + { + std::cout << i << " : " << ex.what() << std::endl; + } + } +} +``` + +## `when_all()` + +The `when_all()` function can be used to create a new Awaitable that when `co_await`ed +will `co_await` each of the input awaitables concurrently and return an aggregate of +their individual results. + +When the returned awaitable is awaited, it will `co_await` each of the input awaitables +on the current thread. Once the first awaitable suspends, the second task will be started, +and so on. The operations execute concurrently until they have all run to completion. + +Once all component `co_await` operations have run to completion, an aggregate of the +results is constructed from each individual result. If an exception is thrown by any +of the input tasks or if the construction of the aggregate result throws an exception +then the exception will propagate out of the `co_await` of the returned awaitable. + +If multiple `co_await` operations fail with an exception then one of the exceptions +will propagate out of the `co_await when_all()` expression the other exceptions will be silently +ignored. It is not specified which operation's exception will be chosen. + +If it is important to know which component `co_await` operation failed or to retain +the ability to obtain results of other operations even if some of them fail then you +you should use `when_all_ready()` instead. + +API Summary: +```c++ +// +namespace cppcoro +{ + // Variadic version. + // + // Note that if the result of `co_await awaitable` yields a void-type + // for some awaitables then the corresponding component for that awaitable + // in the tuple will be an empty struct of type detail::void_value. + template + auto when_all(AWAITABLES&&... awaitables) + -> Awaitable::await_result_t...>>; + + // Overload for vector>. + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + auto when_all(std::vector awaitables) + -> Awaitable; + + // Overload for vector> that yield a value when awaited. + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + auto when_all(std::vector awaitables) + -> Awaitable, + std::reference_wrapper>, + std::remove_reference_t>>>; +} +``` + +Examples: +```c++ +task get_a(); +task get_b(); + +task<> example1() +{ + // Run get_a() and get_b() concurrently. + // Task yields a std::tuple which can be unpacked using structured bindings. + auto [a, b] = co_await when_all(get_a(), get_b()); + + // use a, b +} + +task get_record(int id); + +task<> example2() +{ + std::vector> tasks; + for (int i = 0; i < 1000; ++i) + { + tasks.emplace_back(get_record(i)); + } + + // Concurrently execute all get_record() tasks. + // If any of them fail with an exception then the exception will propagate + // out of the co_await expression once they have all completed. + std::vector records = co_await when_all(std::move(tasks)); + + // Process results + for (int i = 0; i < 1000; ++i) + { + std::cout << i << " = " << records[i] << std::endl; + } +} +``` + +## `fmap()` + +The `fmap()` function can be used to apply a callable function to the value(s) contained within +a container-type, returning a new container-type of the results of applying the function the +contained value(s). + +The `fmap()` function can apply a function to values of type `generator`, `recursive_generator` +and `async_generator` as well as any value that supports the `Awaitable` concept (eg. `task`). + +Each of these types provides an overload for `fmap()` that takes two arguments; a function to apply +and the container value. +See documentation for each type for the supported `fmap()` overloads. + +For example, the `fmap()` function can be used to apply a function to the eventual result of +a `task`, producing a new `task` that will complete with the return-value of the function. +```c++ +// Given a function you want to apply that converts +// a value of type A to value of type B. +B a_to_b(A value); + +// And a task that yields a value of type A +cppcoro::task get_an_a(); + +// We can apply the function to the result of the task using fmap() +// and obtain a new task yielding the result. +cppcoro::task bTask = fmap(a_to_b, get_an_a()); + +// An alternative syntax is to use the pipe notation. +cppcoro::task bTask = get_an_a() | cppcoro::fmap(a_to_b); +``` + +API Summary: +```c++ +// +namespace cppcoro +{ + template + struct fmap_transform + { + fmap_transform(FUNC&& func) noexcept(std::is_nothrow_move_constructible_v); + FUNC func; + }; + + // Type-deducing constructor for fmap_transform object that can be used + // in conjunction with operator|. + template + fmap_transform fmap(FUNC&& func); + + // operator| overloads for providing pipe-based syntactic sugar for fmap() + // such that the expression: + // | cppcoro::fmap() + // is equivalent to: + // fmap(, ) + + template + decltype(auto) operator|(T&& value, fmap_transform&& transform); + + template + decltype(auto) operator|(T&& value, fmap_transform& transform); + + template + decltype(auto) operator|(T&& value, const fmap_transform& transform); + + // Generic overload for all awaitable types. + // + // Returns an awaitable that when co_awaited, co_awaits the specified awaitable + // and applies the specified func to the result of the 'co_await awaitable' + // expression as if by 'std::invoke(func, co_await awaitable)'. + // + // If the type of 'co_await awaitable' expression is 'void' then co_awaiting the + // returned awaitable is equivalent to 'co_await awaitable, func()'. + template< + typename FUNC, + typename AWAITABLE, + std::enable_if_t, int> = 0> + auto fmap(FUNC&& func, AWAITABLE&& awaitable) + -> Awaitable::await_result_t>>; +} +``` + +The `fmap()` function is designed to look up the correct overload by argument-dependent +lookup (ADL) so it should generally be called without the `cppcoro::` prefix. + +## `resume_on()` + +The `resume_on()` function can be used to control the execution context that an awaitable +will resume the awaiting coroutine on when awaited. When applied to an `async_generator` +it controls which execution context the `co_await g.begin()` and `co_await ++it` operations +resume the awaiting coroutines on. + +Normally, the awaiting coroutine of an awaitable (eg. a `task`) or `async_generator` will +resume execution on whatever thread the operation completed on. In some cases this may not +be the thread that you want to continue executing on. In these cases you can use the +`resume_on()` function to create a new awaitable or generator that will resume execution +on a thread associated with a specified scheduler. + +The `resume_on()` function can be used either as a normal function returning a new awaitable/generator. +Or it can be used in a pipeline-syntax. + +Example: +```c++ +task load_record(int id); + +ui_thread_scheduler uiThreadScheduler; + +task<> example() +{ + // This will start load_record() on the current thread. + // Then when load_record() completes (probably on an I/O thread) + // it will reschedule execution onto thread pool and call to_json + // Once to_json completes it will transfer execution onto the + // ui thread before resuming this coroutine and returning the json text. + task jsonTask = + load_record(123) + | cppcoro::resume_on(threadpool::default()) + | cppcoro::fmap(to_json) + | cppcoro::resume_on(uiThreadScheduler); + + // At this point, all we've done is create a pipeline of tasks. + // The tasks haven't started executing yet. + + // Await the result. Starts the pipeline of tasks. + std::string jsonText = co_await jsonTask; + + // Guaranteed to be executing on ui thread here. + + someUiControl.set_text(jsonText); +} +``` + +API Summary: +```c++ +// +namespace cppcoro +{ + template + auto resume_on(SCHEDULER& scheduler, AWAITABLE awaitable) + -> Awaitable::await_traits_t>; + + template + async_generator resume_on(SCHEDULER& scheduler, async_generator source); + + template + struct resume_on_transform + { + explicit resume_on_transform(SCHEDULER& scheduler) noexcept; + SCHEDULER& scheduler; + }; + + // Construct a transform/operation that can be applied to a source object + // using "pipe" notation (ie. operator|). + template + resume_on_transform resume_on(SCHEDULER& scheduler) noexcept; + + // Equivalent to 'resume_on(transform.scheduler, std::forward(value))' + template + decltype(auto) operator|(T&& value, resume_on_transform transform) + { + return resume_on(transform.scheduler, std::forward(value)); + } +} +``` + +## `schedule_on()` + +The `schedule_on()` function can be used to change the execution context that a given +awaitable or `async_generator` starts executing on. + +When applied to an `async_generator` it also affects which execution context it resumes +on after `co_yield` statement. + +Note that the `schedule_on` transform does not specify the thread that the awaitable or +`async_generator` will complete or yield results on, that is up to the implementation of +the awaitable or generator. + +See the `resume_on()` operator for a transform that controls the thread the operation completes on. + +For example: +```c++ +task get_value(); +io_service ioSvc; + +task<> example() +{ + // Starts executing get_value() on the current thread. + int a = co_await get_value(); + + // Starts executing get_value() on a thread associated with ioSvc. + int b = co_await schedule_on(ioSvc, get_value()); +} +``` + +API Summary: +```c++ +// +namespace cppcoro +{ + // Return a task that yields the same result as 't' but that + // ensures that 't' is co_await'ed on a thread associated with + // the specified scheduler. Resulting task will complete on + // whatever thread 't' would normally complete on. + template + auto schedule_on(SCHEDULER& scheduler, AWAITABLE awaitable) + -> Awaitable::await_result_t>; + + // Return a generator that yields the same sequence of results as + // 'source' but that ensures that execution of the coroutine starts + // execution on a thread associated with 'scheduler' and resumes + // after a 'co_yield' on a thread associated with 'scheduler'. + template + async_generator schedule_on(SCHEDULER& scheduler, async_generator source); + + template + struct schedule_on_transform + { + explicit schedule_on_transform(SCHEDULER& scheduler) noexcept; + SCHEDULER& scheduler; + }; + + template + schedule_on_transform schedule_on(SCHEDULER& scheduler) noexcept; + + template + decltype(auto) operator|(T&& value, schedule_on_transform transform); +} +``` + +# Metafunctions + +## `awaitable_traits` + +This template metafunction can be used to determine what the resulting type of a `co_await` expression +will be if applied to an expression of type `T`. + +Note that this assumes the value of type `T` is being awaited in a context where it is unaffected by +any `await_transform` applied by the coroutine's promise object. The results may differ if a value +of type `T` is awaited in such a context. + +The `awaitable_traits` template metafunction does not define the `awaiter_t` or `await_result_t` +nested typedefs if type, `T`, is not awaitable. This allows its use in SFINAE contexts that disables +overloads when `T` is not awaitable. + +API Summary: +```c++ +// +namespace cppcoro +{ + template + struct awaitable_traits + { + // The type that results from applying `operator co_await()` to a value + // of type T, if T supports an `operator co_await()`, otherwise is type `T&&`. + typename awaiter_t = ; + + // The type of the result of co_await'ing a value of type T. + typename await_result_t = ; + }; +} +``` + +## `is_awaitable` + +The `is_awaitable` template metafunction allows you to query whether or not a given +type can be `co_await`ed or not from within a coroutine. + +API Summary: +```c++ +// +namespace cppcoro +{ + template + struct is_awaitable : std::bool_constant<...> + {}; + + template + constexpr bool is_awaitable_v = is_awaitable::value; +} +``` + +# Concepts + +## `Awaitable` concept + +An `Awaitable` is a concept that indicates that a type can be `co_await`ed in a coroutine context +that has no `await_transform` overloads and that the result of the `co_await` expression has type, `T`. + +For example, the type `task` implements the concept `Awaitable` whereas the type `task&` +implements the concept `Awaitable`. + +## `Awaiter` concept + +An `Awaiter` is a concept that indicates a type contains the `await_ready`, `await_suspend` and +`await_resume` methods required to implement the protocol for suspending/resuming an awaiting +coroutine. + +A type that satisfies `Awaiter` must have, for an instance of the type, `awaiter`: +- `awaiter.await_ready()` -> `bool` +- `awaiter.await_suspend(cppcoro::coroutine_handle{})` -> `void` or `bool` or `cppcoro::coroutine_handle

` for some `P`. +- `awaiter.await_resume()` -> `T` + +Any type that implements the `Awaiter` concept also implements the `Awaitable` concept. + +## `Scheduler` concept + +A `Scheduler` is a concept that allows scheduling execution of coroutines within some execution context. + +```c++ +concept Scheduler +{ + Awaitable schedule(); +} +``` + +Given a type, `S`, that implements the `Scheduler` concept, and an instance, `s`, of type `S`: +* The `s.schedule()` method returns an awaitable-type such that `co_await s.schedule()` + will unconditionally suspend the current coroutine and schedule it for resumption on the + execution context associated with the scheduler, `s`. +* The result of the `co_await s.schedule()` expression has type `void`. + +```c++ +cppcoro::task<> f(Scheduler& scheduler) +{ + // Execution of the coroutine is initially on the caller's execution context. + + // Suspends execution of the coroutine and schedules it for resumption on + // the scheduler's execution context. + co_await scheduler.schedule(); + + // At this point the coroutine is now executing on the scheduler's + // execution context. +} +``` + +## `DelayedScheduler` concept + +A `DelayedScheduler` is a concept that allows a coroutine to schedule itself for execution on +the scheduler's execution context after a specified duration of time has elapsed. + +```c++ +concept DelayedScheduler : Scheduler +{ + template + Awaitable schedule_after(std::chrono::duration delay); + + template + Awaitable schedule_after( + std::chrono::duration delay, + cppcoro::cancellation_token cancellationToken); +} +``` + +Given a type, `S`, that implements the `DelayedScheduler` and an instance, `s` of type `S`: +* The `s.schedule_after(delay)` method returns an object that can be awaited + such that `co_await s.schedule_after(delay)` suspends the current coroutine + for a duration of `delay` before scheduling the coroutine for resumption on + the execution context associated with the scheduler, `s`. +* The `co_await s.schedule_after(delay)` expression has type `void`. + +# Building + +andreasbuhr/cppcoro uses CMake as a build system. + +## Building on Windows + +This library currently requires Visual Studio 2017 or later and the Windows 10 SDK. + +Support for Linux ([#15](https://github.com/lewissbaker/cppcoro/issues/15)) is planned. + +### Prerequisites + +The CMakeLists requires version 3.13 or later. + +Ensure Python 2.7 interpreter is in your PATH and available as 'python'. + +Ensure Visual Studio 2017 Update 3 or later is installed. +Note that there are some known issues with coroutines in Update 2 or earlier that have been fixed in Update 3. + +You can also use an experimental version of the Visual Studio compiler by downloading a NuGet package from https://vcppdogfooding.azurewebsites.net/ and unzipping the .nuget file to a directory. + +Ensure that you have the Windows 10 SDK installed. +It will use the latest Windows 10 SDK and Universal C Runtime version by default. + +### Cloning the repository + +The cppcoro repository makes use of git submodules to pull in the source for the Cake build system. + +This means you need to pass the `--recursive` flag to the `git clone` command. eg. +``` +c:\Code> git clone --recursive https://github.com/lewissbaker/cppcoro.git +``` + +If you have already cloned cppcoro, then you should update the submodules after pulling changes. +``` +c:\Code\cppcoro> git submodule update --init --recursive +``` + +### Building from the command-line + +#### With CMake + +Cppcoro follows the usual CMake workflow with no custom options added. Notable [standard CMake options](https://cmake.org/cmake/help/latest/manual/cmake-variables.7.html): + +| Flag | Description | Default Value | +|----------------------|------------------------------|------------------------| +| BUILD_TESTING | Build the unit tests | ON | +| BUILD_SHARED_LIBS | Build as a shared library | OFF | +| CMAKE_BUILD_TYPE | Build as `Debug`/`Release` | | +| CMAKE_INSTALL_PREFIX | Where to install the library | `/usr/local` (on Unix) | + +CMake also respects the [conventional environment variables](https://cmake.org/cmake/help/latest/manual/cmake-env-variables.7.html): + +| Environment Variable | Description | +|----------------------|-------------------------------| +| CXX | Path to the C++ compiler | +| CXXFLAGS | C++ compiler flags to prepend | +| LDFLAGS | Linker flags to prepend | + +Example: + +```bash +cd +mkdir build +cd build +export CXX=clang++ +export CXXFLAGS="-stdlib=libc++ -march=native" +export LDFLAGS="-stdlib=libc++ -fuse-ld=lld -Wl,--gdb-index" +cmake .. [-GNinja] -DCMAKE_INSTALL_PREFIX=$HOME/.local -DBUILD_SHARED_LIBS=ON +ninja # or make -jN +ninja test # Run the tests +ninja install +``` + +The CMake build scripts will also install a `cppcoroConfig.cmake` file for consumers to use. +It will check at the consumer site that coroutines are indeed supported by the system and enable the appropriate compiler flag for Clang or MSVC, respectively. +Assuming cppcoro has been installed to `$HOME/.local` like in the example above it can be consumed like this: + +```cmake +find_package(cppcoro REQUIRED) +add_executable(app main.cpp) +target_link_libraries(app PRIVATE cppcoro::cppcoro) +``` + +```bash +$ cmake . -Dcppcoro_ROOT=$HOME/.local +# ... +-- Performing Test _CXX_COROUTINES_SUPPORTS_MS_FLAG +-- Performing Test _CXX_COROUTINES_SUPPORTS_MS_FLAG - Failed +-- Performing Test _CXX_COROUTINES_SUPPORTS_CORO_FLAG +-- Performing Test _CXX_COROUTINES_SUPPORTS_CORO_FLAG - Success +-- Looking for C++ include coroutine +-- Looking for C++ include coroutine - not found +-- Looking for C++ include experimental/coroutine +-- Looking for C++ include experimental/coroutine - found +-- Configuring done +-- Generating done +# ... +``` + +## Building on Linux + +The cppcoro project can also be built under Linux using Clang + libc++ 5.0 or later. + +Building cppcoro has been tested under Ubuntu 17.04. + +### Prerequisities + +Ensure you have the following packages installed: +* Python 2.7 +* Clang >= 5.0 +* LLD >= 5.0 +* libc++ >= 5.0 + + +### Building cppcoro + +This is assuming you have Clang and libc++ built and installed. + +If you don't have Clang configured yet, see the following sections +for details on setting up Clang for building with cppcoro. + +Checkout cppcoro and its submodules: +``` +git clone --recursive https://github.com/lewissbaker/cppcoro.git cppcoro +``` + +# Support + +GitHub issues are the primary mechanism for support, bug reports and feature requests. + +Contributions are welcome and pull-requests will be happily reviewed. +I only ask that you agree to license any contributions that you make under the MIT license. + +If you have general questions about C++ coroutines, you can generally find someone to help +in the `#coroutines` channel on [Cpplang Slack](https://cpplang.slack.com/) group. diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt new file mode 100644 index 0000000..e69de29 diff --git a/cmake/FindCoroutines.cmake b/cmake/FindCoroutines.cmake new file mode 100644 index 0000000..672c549 --- /dev/null +++ b/cmake/FindCoroutines.cmake @@ -0,0 +1,282 @@ +# Copyright (c) 2019-present, Facebook, Inc. +# +# This source code is licensed under the Apache License found in the +# LICENSE.txt file in the root directory of this source tree. + +#[=======================================================================[.rst: + +FindCoroutines +############## + +This module supports the C++ standard support for coroutines. Use +the :imp-target:`std::coroutines` imported target to + +Options +******* + +The ``COMPONENTS`` argument to this module supports the following values: + +.. find-component:: Experimental + :name: coro.Experimental + + Allows the module to find the "experimental" Coroutines TS + version of the coroutines library. This is the library that should be + used with the ``std::experimental`` namespace. + +.. find-component:: Final + :name: coro.Final + + Finds the final C++20 standard version of coroutines. + +If no components are provided, behaves as if the +:find-component:`coro.Final` component was specified. + +If both :find-component:`coro.Experimental` and :find-component:`coro.Final` are +provided, first looks for ``Final``, and falls back to ``Experimental`` in case +of failure. If ``Final`` is found, :imp-target:`std::coroutines` and all +:ref:`variables ` will refer to the ``Final`` version. + + +Imported Targets +**************** + +.. imp-target:: std::coroutines + + The ``std::coroutines`` imported target is defined when any requested + version of the C++ coroutines library has been found, whether it is + *Experimental* or *Final*. + + If no version of the coroutines library is available, this target will not + be defined. + + .. note:: + This target has ``cxx_std_17`` as an ``INTERFACE`` + :ref:`compile language standard feature `. Linking + to this target will automatically enable C++17 if no later standard + version is already required on the linking target. + + +.. _coro.variables: + +Variables +********* + +.. variable:: CXX_COROUTINES_HAVE_COROUTINES + + Set to ``TRUE`` when coroutines are supported in both the language and the + library. + +.. variable:: CXX_COROUTINES_HEADER + + Set to either ``coroutine`` or ``experimental/coroutine`` depending on + whether :find-component:`coro.Final` or :find-component:`coro.Experimental` was + found. + +.. variable:: CXX_COROUTINES_NAMESPACE + + Set to either ``std`` or ``std::experimental`` + depending on whether :find-component:`coro.Final` or + :find-component:`coro.Experimental` was found. + + +Examples +******** + +Using `find_package(Coroutines)` with no component arguments: + +.. code-block:: cmake + + find_package(Coroutines REQUIRED) + + add_executable(my-program main.cpp) + target_link_libraries(my-program PRIVATE std::coroutines) + + +#]=======================================================================] + + +if(TARGET std::coroutines) + # This module has already been processed. Don't do it again. + return() +endif() + +include(CheckCXXCompilerFlag) +include(CMakePushCheckState) +include(CheckIncludeFileCXX) +include(CheckCXXSourceCompiles) + +cmake_push_check_state() + +set(CMAKE_REQUIRED_QUIET ${Coroutines_FIND_QUIETLY}) + +check_cxx_compiler_flag(/await _CXX_COROUTINES_SUPPORTS_MS_FLAG) +check_cxx_compiler_flag(/await:heapelide _CXX_COROUTINES_SUPPORTS_MS_HEAPELIDE_FLAG) +check_cxx_compiler_flag(-fcoroutines-ts _CXX_COROUTINES_SUPPORTS_TS_FLAG) +check_cxx_compiler_flag(-fcoroutines _CXX_COROUTINES_SUPPORTS_CORO_FLAG) + +if(_CXX_COROUTINES_SUPPORTS_MS_FLAG) + set(_CXX_COROUTINES_EXTRA_FLAGS "/await") + if(_CXX_COROUTINES_SUPPORTS_MS_HEAPELIDE_FLAG AND CMAKE_SIZEOF_VOID_P GREATER_EQUAL 8) + list(APPEND _CXX_COROUTINES_EXTRA_FLAGS "/await:heapelide") + endif() +elseif(_CXX_COROUTINES_SUPPORTS_TS_FLAG) + set(_CXX_COROUTINES_EXTRA_FLAGS "-fcoroutines-ts") +elseif(_CXX_COROUTINES_SUPPORTS_CORO_FLAG) + set(_CXX_COROUTINES_EXTRA_FLAGS "-fcoroutines") +endif() + +# Normalize and check the component list we were given +set(want_components ${Coroutines_FIND_COMPONENTS}) +if(Coroutines_FIND_COMPONENTS STREQUAL "") + set(want_components Final) +endif() + +# Warn on any unrecognized components +set(extra_components ${want_components}) +list(REMOVE_ITEM extra_components Final Experimental) +foreach(component IN LISTS extra_components) + message(WARNING "Extraneous find_package component for Coroutines: ${component}") +endforeach() + +# Detect which of Experimental and Final we should look for +set(find_experimental TRUE) +set(find_final TRUE) +if(NOT "Final" IN_LIST want_components) + set(find_final FALSE) +endif() +if(NOT "Experimental" IN_LIST want_components) + set(find_experimental FALSE) +endif() + +if(find_final) + check_include_file_cxx("coroutine" _CXX_COROUTINES_HAVE_HEADER) + if(_CXX_COROUTINES_HAVE_HEADER) + check_cxx_source_compiles("#include \n typedef std::suspend_never blub; \nint main() {} " _CXX_COROUTINES_FINAL_HEADER_COMPILES) + set(_CXX_COROUTINES_HAVE_HEADER "${_CXX_COROUTINES_FINAL_HEADER_COMPILES}") + endif() + + if(NOT _CXX_COROUTINES_HAVE_HEADER) + cmake_push_check_state() + set(CMAKE_REQUIRED_FLAGS "${_CXX_COROUTINES_EXTRA_FLAGS}") + check_include_file_cxx("coroutine" _CXX_COROUTINES_HAVE_HEADER_WITH_FLAG) + if(_CXX_COROUTINES_HAVE_HEADER_WITH_FLAG) + check_cxx_source_compiles("#include \n typedef std::suspend_never blub; \nint main() {} " _CXX_COROUTINES_FINAL_HEADER_COMPILES_WITH_FLAG) + set(_CXX_COROUTINES_HAVE_HEADER_WITH_FLAG "${_CXX_COROUTINES_FINAL_HEADER_COMPILES_WITH_FLAG}") + endif() + set(_CXX_COROUTINES_HAVE_HEADER "${_CXX_COROUTINES_HAVE_HEADER_WITH_FLAG}") + cmake_pop_check_state() + endif() + mark_as_advanced(_CXX_COROUTINES_HAVE_HEADER) + if(_CXX_COROUTINES_HAVE_HEADER) + # We found the non-experimental header. Don't bother looking for the + # experimental one. + set(find_experimental FALSE) + endif() +else() + set(_CXX_COROUTINES_HAVE_HEADER FALSE) +endif() + +if(find_experimental) + check_include_file_cxx("experimental/coroutine" _CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER) + if(NOT _CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER) + cmake_push_check_state() + set(CMAKE_REQUIRED_FLAGS "${_CXX_COROUTINES_EXTRA_FLAGS}") + check_include_file_cxx("experimental/coroutine" _CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER_WITH_FLAG) + set(_CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER "${_CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER_WITH_FLAG}") + cmake_pop_check_state() + endif() + mark_as_advanced(_CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER) +else() + set(_CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER FALSE) +endif() + +if(_CXX_COROUTINES_HAVE_HEADER) + set(_have_coro TRUE) + set(_coro_header coroutine) + set(_coro_namespace std) +elseif(_CXX_COROUTINES_HAVE_EXPERIMENTAL_HEADER) + set(_have_coro TRUE) + set(_coro_header experimental/coroutine) + set(_coro_namespace std::experimental) +else() + set(_have_coro FALSE) +endif() + +set(CXX_COROUTINES_HAVE_COROUTINES ${_have_coro} CACHE BOOL "TRUE if we have the C++ coroutines feature") +set(CXX_COROUTINES_HEADER ${_coro_header} CACHE STRING "The header that should be included to obtain the coroutines APIs") +set(CXX_COROUTINES_NAMESPACE ${_coro_namespace} CACHE STRING "The C++ namespace that contains the coroutines APIs") + +set(_found FALSE) + +if(CXX_COROUTINES_HAVE_COROUTINES) + # We have some coroutines library available. Do link checks + string(CONFIGURE [[ + #include + #include <@CXX_COROUTINES_HEADER@> + + struct present { + struct promise_type { + int result; + present get_return_object() { return present{*this}; } + @CXX_COROUTINES_NAMESPACE@::suspend_never initial_suspend() { return {}; } + @CXX_COROUTINES_NAMESPACE@::suspend_always final_suspend() noexcept { return {}; } + void return_value(int i) { result = i; } + void unhandled_exception() {} + }; + friend struct promise_type; + present(present&& that) : coro_(std::exchange(that.coro_, {})) {} + ~present() { if(coro_) coro_.destroy(); } + bool await_ready() const { return true; } + void await_suspend(@CXX_COROUTINES_NAMESPACE@::coroutine_handle<>) const {} + int await_resume() const { return coro_.promise().result; } + private: + present(promise_type& promise) + : coro_(@CXX_COROUTINES_NAMESPACE@::coroutine_handle::from_promise(promise)) {} + @CXX_COROUTINES_NAMESPACE@::coroutine_handle coro_; + }; + + present f(int n) { + if (n < 2) + co_return 1; + else + co_return n * co_await f(n - 1); + } + + int main() { + return f(5).await_resume() != 120; + } + ]] code @ONLY) + + # Try to compile a simple coroutines program without any compiler flags + check_cxx_source_compiles("${code}" CXX_COROUTINES_NO_AWAIT_NEEDED) + + set(can_link ${CXX_COROUTINES_NO_AWAIT_NEEDED}) + + if(NOT CXX_COROUTINES_NO_AWAIT_NEEDED) + # Add the -fcoroutines-ts (or /await) flag + set(CMAKE_REQUIRED_FLAGS "${_CXX_COROUTINES_EXTRA_FLAGS}") + check_cxx_source_compiles("${code}" CXX_COROUTINES_AWAIT_NEEDED) + set(can_link "${CXX_COROUTINES_AWAIT_NEEDED}") + endif() + + if(can_link) + add_library(std::coroutines INTERFACE IMPORTED) + set(_found TRUE) + + if(CXX_COROUTINES_NO_AWAIT_NEEDED) + # Nothing to add... + elseif(CXX_COROUTINES_AWAIT_NEEDED) + target_compile_options(std::coroutines INTERFACE ${_CXX_COROUTINES_EXTRA_FLAGS}) + endif() + else() + set(CXX_COROUTINES_HAVE_COROUTINES FALSE) + endif() +endif() + +cmake_pop_check_state() + +set(Coroutines_FOUND ${_found} CACHE BOOL "TRUE if we can compile and link a program using std::coroutines" FORCE) + +if(Coroutines_FIND_REQUIRED AND NOT Coroutines_FOUND) + message(FATAL_ERROR "Cannot compile simple program using std::coroutines. Is C++17 or later activated?") +endif() diff --git a/cmake/cppcoroConfig.cmake b/cmake/cppcoroConfig.cmake new file mode 100644 index 0000000..0b9f9c0 --- /dev/null +++ b/cmake/cppcoroConfig.cmake @@ -0,0 +1,6 @@ +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}) + +include(CMakeFindDependencyMacro) +find_dependency(Coroutines QUIET REQUIRED) + +include("${CMAKE_CURRENT_LIST_DIR}/cppcoroTargets.cmake") diff --git a/include/cppcoro/async_auto_reset_event.hpp b/include/cppcoro/async_auto_reset_event.hpp new file mode 100644 index 0000000..b74ae2b --- /dev/null +++ b/include/cppcoro/async_auto_reset_event.hpp @@ -0,0 +1,98 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_AUTO_RESET_EVENT_HPP_INCLUDED +#define CPPCORO_ASYNC_AUTO_RESET_EVENT_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + class async_auto_reset_event_operation; + + /// An async auto-reset event is a coroutine synchronisation abstraction + /// that allows one or more coroutines to wait until some thread calls + /// set() on the event. + /// + /// When a coroutine awaits a 'set' event the event is automatically + /// reset back to the 'not set' state, thus the name 'auto reset' event. + class async_auto_reset_event + { + public: + + /// Initialise the event to either 'set' or 'not set' state. + async_auto_reset_event(bool initiallySet = false) noexcept; + + ~async_auto_reset_event(); + + /// Wait for the event to enter the 'set' state. + /// + /// If the event is already 'set' then the event is set to the 'not set' + /// state and the awaiting coroutine continues without suspending. + /// Otherwise, the coroutine is suspended and later resumed when some + /// thread calls 'set()'. + /// + /// Note that the coroutine may be resumed inside a call to 'set()' + /// or inside another thread's call to 'operator co_await()'. + async_auto_reset_event_operation operator co_await() const noexcept; + + /// Set the state of the event to 'set'. + /// + /// If there are pending coroutines awaiting the event then one + /// pending coroutine is resumed and the state is immediately + /// set back to the 'not set' state. + /// + /// This operation is a no-op if the event was already 'set'. + void set() noexcept; + + /// Set the state of the event to 'not-set'. + /// + /// This is a no-op if the state was already 'not set'. + void reset() noexcept; + + private: + + friend class async_auto_reset_event_operation; + + void resume_waiters(std::uint64_t initialState) const noexcept; + + // Bits 0-31 - Set count + // Bits 32-63 - Waiter count + mutable std::atomic m_state; + + mutable std::atomic m_newWaiters; + + mutable async_auto_reset_event_operation* m_waiters; + + }; + + class async_auto_reset_event_operation + { + public: + + async_auto_reset_event_operation() noexcept; + + explicit async_auto_reset_event_operation(const async_auto_reset_event& event) noexcept; + + async_auto_reset_event_operation(const async_auto_reset_event_operation& other) noexcept; + + bool await_ready() const noexcept { return m_event == nullptr; } + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept {} + + private: + + friend class async_auto_reset_event; + + const async_auto_reset_event* m_event; + async_auto_reset_event_operation* m_next; + cppcoro::coroutine_handle<> m_awaiter; + std::atomic m_refCount; + + }; +} + +#endif diff --git a/include/cppcoro/async_generator.hpp b/include/cppcoro/async_generator.hpp new file mode 100644 index 0000000..50403e7 --- /dev/null +++ b/include/cppcoro/async_generator.hpp @@ -0,0 +1,1088 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_GENERATOR_HPP_INCLUDED +#define CPPCORO_ASYNC_GENERATOR_HPP_INCLUDED + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + template + class async_generator; + +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + + namespace detail + { + template + class async_generator_iterator; + class async_generator_yield_operation; + class async_generator_advance_operation; + + class async_generator_promise_base + { + public: + + async_generator_promise_base() noexcept + : m_exception(nullptr) + { + // Other variables left intentionally uninitialised as they're + // only referenced in certain states by which time they should + // have been initialised. + } + + async_generator_promise_base(const async_generator_promise_base& other) = delete; + async_generator_promise_base& operator=(const async_generator_promise_base& other) = delete; + + cppcoro::suspend_always initial_suspend() const noexcept + { + return {}; + } + + async_generator_yield_operation final_suspend() noexcept; + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void return_void() noexcept + { + } + + /// Query if the generator has reached the end of the sequence. + /// + /// Only valid to call after resuming from an awaited advance operation. + /// i.e. Either a begin() or iterator::operator++() operation. + bool finished() const noexcept + { + return m_currentValue == nullptr; + } + + void rethrow_if_unhandled_exception() + { + if (m_exception) + { + std::rethrow_exception(std::move(m_exception)); + } + } + + protected: + + async_generator_yield_operation internal_yield_value() noexcept; + + private: + + friend class async_generator_yield_operation; + friend class async_generator_advance_operation; + + std::exception_ptr m_exception; + + cppcoro::coroutine_handle<> m_consumerCoroutine; + + protected: + + void* m_currentValue; + }; + + class async_generator_yield_operation final + { + public: + + async_generator_yield_operation(cppcoro::coroutine_handle<> consumer) noexcept + : m_consumer(consumer) + {} + + bool await_ready() const noexcept + { + return false; + } + + cppcoro::coroutine_handle<> + await_suspend([[maybe_unused]] cppcoro::coroutine_handle<> producer) noexcept + { + return m_consumer; + } + + void await_resume() noexcept {} + + private: + + cppcoro::coroutine_handle<> m_consumer; + + }; + + inline async_generator_yield_operation async_generator_promise_base::final_suspend() noexcept + { + m_currentValue = nullptr; + return internal_yield_value(); + } + + inline async_generator_yield_operation async_generator_promise_base::internal_yield_value() noexcept + { + return async_generator_yield_operation{ m_consumerCoroutine }; + } + + class async_generator_advance_operation + { + protected: + + async_generator_advance_operation(std::nullptr_t) noexcept + : m_promise(nullptr) + , m_producerCoroutine(nullptr) + {} + + async_generator_advance_operation( + async_generator_promise_base& promise, + cppcoro::coroutine_handle<> producerCoroutine) noexcept + : m_promise(std::addressof(promise)) + , m_producerCoroutine(producerCoroutine) + { + } + + public: + + bool await_ready() const noexcept { return false; } + + cppcoro::coroutine_handle<> + await_suspend(cppcoro::coroutine_handle<> consumerCoroutine) noexcept + { + m_promise->m_consumerCoroutine = consumerCoroutine; + return m_producerCoroutine; + } + + protected: + + async_generator_promise_base* m_promise; + cppcoro::coroutine_handle<> m_producerCoroutine; + + }; + + template + class async_generator_promise final : public async_generator_promise_base + { + using value_type = std::remove_reference_t; + + public: + + async_generator_promise() noexcept = default; + + async_generator get_return_object() noexcept; + + async_generator_yield_operation yield_value(value_type& value) noexcept + { + m_currentValue = std::addressof(value); + return internal_yield_value(); + } + + async_generator_yield_operation yield_value(value_type&& value) noexcept + { + return yield_value(value); + } + + T& value() const noexcept + { + return *static_cast(m_currentValue); + } + + }; + + template + class async_generator_promise final : public async_generator_promise_base + { + public: + + async_generator_promise() noexcept = default; + + async_generator get_return_object() noexcept; + + async_generator_yield_operation yield_value(T&& value) noexcept + { + m_currentValue = std::addressof(value); + return internal_yield_value(); + } + + T&& value() const noexcept + { + return std::move(*static_cast(m_currentValue)); + } + + }; + + template + class async_generator_increment_operation final : public async_generator_advance_operation + { + public: + + async_generator_increment_operation(async_generator_iterator& iterator) noexcept + : async_generator_advance_operation(iterator.m_coroutine.promise(), iterator.m_coroutine) + , m_iterator(iterator) + {} + + async_generator_iterator& await_resume(); + + private: + + async_generator_iterator& m_iterator; + + }; + + template + class async_generator_iterator final + { + using promise_type = async_generator_promise; + using handle_type = cppcoro::coroutine_handle; + + public: + + using iterator_category = std::input_iterator_tag; + // Not sure what type should be used for difference_type as we don't + // allow calculating difference between two iterators. + using difference_type = std::ptrdiff_t; + using value_type = std::remove_reference_t; + using reference = std::add_lvalue_reference_t; + using pointer = std::add_pointer_t; + + async_generator_iterator(std::nullptr_t) noexcept + : m_coroutine(nullptr) + {} + + async_generator_iterator(handle_type coroutine) noexcept + : m_coroutine(coroutine) + {} + + async_generator_increment_operation operator++() noexcept + { + return async_generator_increment_operation{ *this }; + } + + reference operator*() const noexcept + { + return m_coroutine.promise().value(); + } + + bool operator==(const async_generator_iterator& other) const noexcept + { + return m_coroutine == other.m_coroutine; + } + + bool operator!=(const async_generator_iterator& other) const noexcept + { + return !(*this == other); + } + + private: + + friend class async_generator_increment_operation; + + handle_type m_coroutine; + + }; + + template + async_generator_iterator& async_generator_increment_operation::await_resume() + { + if (m_promise->finished()) + { + // Update iterator to end() + m_iterator = async_generator_iterator{ nullptr }; + m_promise->rethrow_if_unhandled_exception(); + } + + return m_iterator; + } + + template + class async_generator_begin_operation final : public async_generator_advance_operation + { + using promise_type = async_generator_promise; + using handle_type = cppcoro::coroutine_handle; + + public: + + async_generator_begin_operation(std::nullptr_t) noexcept + : async_generator_advance_operation(nullptr) + {} + + async_generator_begin_operation(handle_type producerCoroutine) noexcept + : async_generator_advance_operation(producerCoroutine.promise(), producerCoroutine) + {} + + bool await_ready() const noexcept + { + return m_promise == nullptr || async_generator_advance_operation::await_ready(); + } + + async_generator_iterator await_resume() + { + if (m_promise == nullptr) + { + // Called begin() on the empty generator. + return async_generator_iterator{ nullptr }; + } + else if (m_promise->finished()) + { + // Completed without yielding any values. + m_promise->rethrow_if_unhandled_exception(); + return async_generator_iterator{ nullptr }; + } + + return async_generator_iterator{ + handle_type::from_promise(*static_cast(m_promise)) + }; + } + }; + } + + template + class [[nodiscard]] async_generator + { + public: + + using promise_type = detail::async_generator_promise; + using iterator = detail::async_generator_iterator; + + async_generator() noexcept + : m_coroutine(nullptr) + {} + + explicit async_generator(promise_type& promise) noexcept + : m_coroutine(cppcoro::coroutine_handle::from_promise(promise)) + {} + + async_generator(async_generator&& other) noexcept + : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + ~async_generator() + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + } + + async_generator& operator=(async_generator&& other) noexcept + { + async_generator temp(std::move(other)); + swap(temp); + return *this; + } + + async_generator(const async_generator&) = delete; + async_generator& operator=(const async_generator&) = delete; + + auto begin() noexcept + { + if (!m_coroutine) + { + return detail::async_generator_begin_operation{ nullptr }; + } + + return detail::async_generator_begin_operation{ m_coroutine }; + } + + auto end() noexcept + { + return iterator{ nullptr }; + } + + void swap(async_generator& other) noexcept + { + using std::swap; + swap(m_coroutine, other.m_coroutine); + } + + private: + + cppcoro::coroutine_handle m_coroutine; + + }; + + template + void swap(async_generator& a, async_generator& b) noexcept + { + a.swap(b); + } + + namespace detail + { + template + async_generator async_generator_promise::get_return_object() noexcept + { + return async_generator{ *this }; + } + } +#else // !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + + namespace detail + { + template + class async_generator_iterator; + class async_generator_yield_operation; + class async_generator_advance_operation; + + class async_generator_promise_base + { + public: + + async_generator_promise_base() noexcept + : m_state(state::value_ready_producer_suspended) + , m_exception(nullptr) + { + // Other variables left intentionally uninitialised as they're + // only referenced in certain states by which time they should + // have been initialised. + } + + async_generator_promise_base(const async_generator_promise_base& other) = delete; + async_generator_promise_base& operator=(const async_generator_promise_base& other) = delete; + + cppcoro::suspend_always initial_suspend() const noexcept + { + return {}; + } + + async_generator_yield_operation final_suspend() noexcept; + + void unhandled_exception() noexcept + { + // Don't bother capturing the exception if we have been cancelled + // as there is no consumer that will see it. + if (m_state.load(std::memory_order_relaxed) != state::cancelled) + { + m_exception = std::current_exception(); + } + } + + void return_void() noexcept + { + } + + /// Query if the generator has reached the end of the sequence. + /// + /// Only valid to call after resuming from an awaited advance operation. + /// i.e. Either a begin() or iterator::operator++() operation. + bool finished() const noexcept + { + return m_currentValue == nullptr; + } + + void rethrow_if_unhandled_exception() + { + if (m_exception) + { + std::rethrow_exception(std::move(m_exception)); + } + } + + /// Request that the generator cancel generation of new items. + /// + /// \return + /// Returns true if the request was completed synchronously and the associated + /// producer coroutine is now available to be destroyed. In which case the caller + /// is expected to call destroy() on the coroutine_handle. + /// Returns false if the producer coroutine was not at a suitable suspend-point. + /// The coroutine will be destroyed when it next reaches a co_yield or co_return + /// statement. + bool request_cancellation() noexcept + { + const auto previousState = m_state.exchange(state::cancelled, std::memory_order_acq_rel); + + // Not valid to destroy async_generator object if consumer coroutine still suspended + // in a co_await for next item. + assert(previousState != state::value_not_ready_consumer_suspended); + + // A coroutine should only ever be cancelled once, from the destructor of the + // owning async_generator object. + assert(previousState != state::cancelled); + + return previousState == state::value_ready_producer_suspended; + } + + protected: + + async_generator_yield_operation internal_yield_value() noexcept; + + private: + + friend class async_generator_yield_operation; + friend class async_generator_advance_operation; + + // State transition diagram + // VNRCA - value_not_ready_consumer_active + // VNRCS - value_not_ready_consumer_suspended + // VRPA - value_ready_producer_active + // VRPS - value_ready_producer_suspended + // + // A +--- VNRCA --[C]--> VNRCS yield_value() + // | | | A | A | . + // | [C] [P] | [P] | | . + // | | | [C] | [C] | . + // | | V | V | | . + // operator++/ | VRPS <--[P]--- VRPA V | + // begin() | | | | + // | [C] [C] | + // | +----+ +---+ | + // | | | | + // | V V V + // +--------> cancelled ~async_generator() + // + // [C] - Consumer performs this transition + // [P] - Producer performs this transition + enum class state + { + value_not_ready_consumer_active, + value_not_ready_consumer_suspended, + value_ready_producer_active, + value_ready_producer_suspended, + cancelled + }; + + std::atomic m_state; + + std::exception_ptr m_exception; + + cppcoro::coroutine_handle<> m_consumerCoroutine; + + protected: + + void* m_currentValue; + }; + + class async_generator_yield_operation final + { + using state = async_generator_promise_base::state; + + public: + + async_generator_yield_operation(async_generator_promise_base& promise, state initialState) noexcept + : m_promise(promise) + , m_initialState(initialState) + {} + + bool await_ready() const noexcept + { + return m_initialState == state::value_not_ready_consumer_suspended; + } + + bool await_suspend(cppcoro::coroutine_handle<> producer) noexcept; + + void await_resume() noexcept {} + + private: + async_generator_promise_base& m_promise; + state m_initialState; + }; + + inline async_generator_yield_operation async_generator_promise_base::final_suspend() noexcept + { + m_currentValue = nullptr; + return internal_yield_value(); + } + + inline async_generator_yield_operation async_generator_promise_base::internal_yield_value() noexcept + { + state currentState = m_state.load(std::memory_order_acquire); + assert(currentState != state::value_ready_producer_active); + assert(currentState != state::value_ready_producer_suspended); + + if (currentState == state::value_not_ready_consumer_suspended) + { + // Only need relaxed memory order since we're resuming the + // consumer on the same thread. + m_state.store(state::value_ready_producer_active, std::memory_order_relaxed); + + // Resume the consumer. + // It might ask for another value before returning, in which case it'll + // transition to value_not_ready_consumer_suspended and we can return from + // yield_value without suspending, otherwise we should try to suspend + // the producer in which case the consumer will wake us up again + // when it wants the next value. + m_consumerCoroutine.resume(); + + // Need to use acquire semantics here since it's possible that the + // consumer might have asked for the next value on a different thread + // which executed concurrently with the call to m_consumerCoro on the + // current thread above. + currentState = m_state.load(std::memory_order_acquire); + } + + return async_generator_yield_operation{ *this, currentState }; + } + + inline bool async_generator_yield_operation::await_suspend( + cppcoro::coroutine_handle<> producer) noexcept + { + state currentState = m_initialState; + if (currentState == state::value_not_ready_consumer_active) + { + bool producerSuspended = m_promise.m_state.compare_exchange_strong( + currentState, + state::value_ready_producer_suspended, + std::memory_order_release, + std::memory_order_acquire); + if (producerSuspended) + { + return true; + } + + if (currentState == state::value_not_ready_consumer_suspended) + { + // Can get away with using relaxed memory semantics here since we're + // resuming the consumer on the current thread. + m_promise.m_state.store(state::value_ready_producer_active, std::memory_order_relaxed); + + m_promise.m_consumerCoroutine.resume(); + + // The consumer might have asked for another value before returning, in which case + // it'll transition to value_not_ready_consumer_suspended and we can return without + // suspending, otherwise we should try to suspend the producer, in which case the + // consumer will wake us up again when it wants the next value. + // + // Need to use acquire semantics here since it's possible that the consumer might + // have asked for the next value on a different thread which executed concurrently + // with the call to m_consumerCoro.resume() above. + currentState = m_promise.m_state.load(std::memory_order_acquire); + if (currentState == state::value_not_ready_consumer_suspended) + { + return false; + } + } + } + + // By this point the consumer has been resumed if required and is now active. + + if (currentState == state::value_ready_producer_active) + { + // Try to suspend the producer. + // If we failed to suspend then it's either because the consumer destructed, transitioning + // the state to cancelled, or requested the next item, transitioning the state to value_not_ready_consumer_suspended. + const bool suspendedProducer = m_promise.m_state.compare_exchange_strong( + currentState, + state::value_ready_producer_suspended, + std::memory_order_release, + std::memory_order_acquire); + if (suspendedProducer) + { + return true; + } + + if (currentState == state::value_not_ready_consumer_suspended) + { + // Consumer has asked for the next value. + return false; + } + } + + assert(currentState == state::cancelled); + + // async_generator object has been destroyed and we're now at a + // co_yield/co_return suspension point so we can just destroy + // the coroutine. + producer.destroy(); + + return true; + } + + class async_generator_advance_operation + { + using state = async_generator_promise_base::state; + + protected: + + async_generator_advance_operation(std::nullptr_t) noexcept + : m_promise(nullptr) + , m_producerCoroutine(nullptr) + {} + + async_generator_advance_operation( + async_generator_promise_base& promise, + cppcoro::coroutine_handle<> producerCoroutine) noexcept + : m_promise(std::addressof(promise)) + , m_producerCoroutine(producerCoroutine) + { + state initialState = promise.m_state.load(std::memory_order_acquire); + if (initialState == state::value_ready_producer_suspended) + { + // Can use relaxed memory order here as we will be resuming the producer + // on the same thread. + promise.m_state.store(state::value_not_ready_consumer_active, std::memory_order_relaxed); + + producerCoroutine.resume(); + + // Need to use acquire memory order here since it's possible that the + // coroutine may have transferred execution to another thread and + // completed on that other thread before the call to resume() returns. + initialState = promise.m_state.load(std::memory_order_acquire); + } + + m_initialState = initialState; + } + + public: + + bool await_ready() const noexcept + { + return m_initialState == state::value_ready_producer_suspended; + } + + bool await_suspend(cppcoro::coroutine_handle<> consumerCoroutine) noexcept + { + m_promise->m_consumerCoroutine = consumerCoroutine; + + auto currentState = m_initialState; + if (currentState == state::value_ready_producer_active) + { + // A potential race between whether consumer or producer coroutine + // suspends first. Resolve the race using a compare-exchange. + if (m_promise->m_state.compare_exchange_strong( + currentState, + state::value_not_ready_consumer_suspended, + std::memory_order_release, + std::memory_order_acquire)) + { + return true; + } + + assert(currentState == state::value_ready_producer_suspended); + + m_promise->m_state.store(state::value_not_ready_consumer_active, std::memory_order_relaxed); + + m_producerCoroutine.resume(); + + currentState = m_promise->m_state.load(std::memory_order_acquire); + if (currentState == state::value_ready_producer_suspended) + { + // Producer coroutine produced a value synchronously. + return false; + } + } + + assert(currentState == state::value_not_ready_consumer_active); + + // Try to suspend consumer coroutine, transitioning to value_not_ready_consumer_suspended. + // This could be racing with producer making the next value available and suspending + // (transition to value_ready_producer_suspended) so we use compare_exchange to decide who + // wins the race. + // If compare_exchange succeeds then consumer suspended (and we return true). + // If it fails then producer yielded next value and suspended and we can return + // synchronously without suspended (ie. return false). + return m_promise->m_state.compare_exchange_strong( + currentState, + state::value_not_ready_consumer_suspended, + std::memory_order_release, + std::memory_order_acquire); + } + + protected: + + async_generator_promise_base* m_promise; + cppcoro::coroutine_handle<> m_producerCoroutine; + + private: + + state m_initialState; + + }; + + template + class async_generator_promise final : public async_generator_promise_base + { + using value_type = std::remove_reference_t; + + public: + + async_generator_promise() noexcept = default; + + async_generator get_return_object() noexcept; + + async_generator_yield_operation yield_value(value_type& value) noexcept + { + m_currentValue = std::addressof(value); + return internal_yield_value(); + } + + async_generator_yield_operation yield_value(value_type&& value) noexcept + { + return yield_value(value); + } + + T& value() const noexcept + { + return *static_cast(m_currentValue); + } + + }; + + template + class async_generator_promise final : public async_generator_promise_base + { + public: + + async_generator_promise() noexcept = default; + + async_generator get_return_object() noexcept; + + async_generator_yield_operation yield_value(T&& value) noexcept + { + m_currentValue = std::addressof(value); + return internal_yield_value(); + } + + T&& value() const noexcept + { + return std::move(*static_cast(m_currentValue)); + } + + }; + + template + class async_generator_increment_operation final : public async_generator_advance_operation + { + public: + + async_generator_increment_operation(async_generator_iterator& iterator) noexcept + : async_generator_advance_operation(iterator.m_coroutine.promise(), iterator.m_coroutine) + , m_iterator(iterator) + {} + + async_generator_iterator& await_resume(); + + private: + + async_generator_iterator& m_iterator; + + }; + + template + class async_generator_iterator final + { + using promise_type = async_generator_promise; + using handle_type = cppcoro::coroutine_handle; + + public: + + using iterator_category = std::input_iterator_tag; + // Not sure what type should be used for difference_type as we don't + // allow calculating difference between two iterators. + using difference_type = std::ptrdiff_t; + using value_type = std::remove_reference_t; + using reference = std::add_lvalue_reference_t; + using pointer = std::add_pointer_t; + + async_generator_iterator(std::nullptr_t) noexcept + : m_coroutine(nullptr) + {} + + async_generator_iterator(handle_type coroutine) noexcept + : m_coroutine(coroutine) + {} + + async_generator_increment_operation operator++() noexcept + { + return async_generator_increment_operation{ *this }; + } + + reference operator*() const noexcept + { + return m_coroutine.promise().value(); + } + + bool operator==(const async_generator_iterator& other) const noexcept + { + return m_coroutine == other.m_coroutine; + } + + bool operator!=(const async_generator_iterator& other) const noexcept + { + return !(*this == other); + } + + private: + + friend class async_generator_increment_operation; + + handle_type m_coroutine; + + }; + + template + async_generator_iterator& async_generator_increment_operation::await_resume() + { + if (m_promise->finished()) + { + // Update iterator to end() + m_iterator = async_generator_iterator{ nullptr }; + m_promise->rethrow_if_unhandled_exception(); + } + + return m_iterator; + } + + template + class async_generator_begin_operation final : public async_generator_advance_operation + { + using promise_type = async_generator_promise; + using handle_type = cppcoro::coroutine_handle; + + public: + + async_generator_begin_operation(std::nullptr_t) noexcept + : async_generator_advance_operation(nullptr) + {} + + async_generator_begin_operation(handle_type producerCoroutine) noexcept + : async_generator_advance_operation(producerCoroutine.promise(), producerCoroutine) + {} + + bool await_ready() const noexcept + { + return m_promise == nullptr || async_generator_advance_operation::await_ready(); + } + + async_generator_iterator await_resume() + { + if (m_promise == nullptr) + { + // Called begin() on the empty generator. + return async_generator_iterator{ nullptr }; + } + else if (m_promise->finished()) + { + // Completed without yielding any values. + m_promise->rethrow_if_unhandled_exception(); + return async_generator_iterator{ nullptr }; + } + + return async_generator_iterator{ + handle_type::from_promise(*static_cast(m_promise)) + }; + } + }; + } + + template + class async_generator + { + public: + + using promise_type = detail::async_generator_promise; + using iterator = detail::async_generator_iterator; + + async_generator() noexcept + : m_coroutine(nullptr) + {} + + explicit async_generator(promise_type& promise) noexcept + : m_coroutine(cppcoro::coroutine_handle::from_promise(promise)) + {} + + async_generator(async_generator&& other) noexcept + : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + ~async_generator() + { + if (m_coroutine) + { + if (m_coroutine.promise().request_cancellation()) + { + m_coroutine.destroy(); + } + } + } + + async_generator& operator=(async_generator&& other) noexcept + { + async_generator temp(std::move(other)); + swap(temp); + return *this; + } + + async_generator(const async_generator&) = delete; + async_generator& operator=(const async_generator&) = delete; + + auto begin() noexcept + { + if (!m_coroutine) + { + return detail::async_generator_begin_operation{ nullptr }; + } + + return detail::async_generator_begin_operation{ m_coroutine }; + } + + auto end() noexcept + { + return iterator{ nullptr }; + } + + void swap(async_generator& other) noexcept + { + using std::swap; + swap(m_coroutine, other.m_coroutine); + } + + private: + + cppcoro::coroutine_handle m_coroutine; + + }; + + template + void swap(async_generator& a, async_generator& b) noexcept + { + a.swap(b); + } + + namespace detail + { + template + async_generator async_generator_promise::get_return_object() noexcept + { + return async_generator{ *this }; + } + } +#endif // !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + + template + async_generator::iterator&>())>> fmap( + FUNC func, + async_generator source) + { + static_assert( + !std::is_reference_v, + "Passing by reference to async_generator coroutine is unsafe. " + "Use std::ref or std::cref to explicitly pass by reference."); + + // Explicitly hand-coding the loop here rather than using range-based + // for loop since it's difficult to std::forward the value of a + // range-based for-loop, preserving the value category of operator* + // return-value. + auto it = co_await source.begin(); + const auto itEnd = source.end(); + while (it != itEnd) + { + co_yield std::invoke(func, *it); + (void)co_await ++it; + } + } +} + +#endif diff --git a/include/cppcoro/async_latch.hpp b/include/cppcoro/async_latch.hpp new file mode 100644 index 0000000..9dfbb50 --- /dev/null +++ b/include/cppcoro/async_latch.hpp @@ -0,0 +1,75 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_LATCH_HPP_INCLUDED +#define CPPCORO_ASYNC_LATCH_HPP_INCLUDED + +#include + +#include +#include + +namespace cppcoro +{ + class async_latch + { + public: + + /// Construct the latch with the specified initial count. + /// + /// \param initialCount + /// The initial count of the latch. The latch will become signalled once + /// \c this->count_down() has been called \p initialCount times. + /// The latch will be immediately signalled on construction if this + /// parameter is zero or negative. + async_latch(std::ptrdiff_t initialCount) noexcept + : m_count(initialCount) + , m_event(initialCount <= 0) + {} + + /// Query if the latch has become signalled. + /// + /// The latch is marked as signalled once the count reaches zero. + bool is_ready() const noexcept { return m_event.is_set(); } + + /// Decrement the count by n. + /// + /// Any coroutines awaiting this latch will be resumed once the count + /// reaches zero. ie. when this method has been called at least 'initialCount' + /// times. + /// + /// Any awaiting coroutines that are currently suspended waiting for the + /// latch to become signalled will be resumed inside the last call to this + /// method (ie. the call that decrements the count to zero). + /// + /// \param n + /// The amount to decrement the count by. + void count_down(std::ptrdiff_t n = 1) noexcept + { + if (m_count.fetch_sub(n, std::memory_order_acq_rel) <= n) + { + m_event.set(); + } + } + + /// Allows the latch to be awaited within a coroutine. + /// + /// If the latch is already signalled (ie. the count has been decremented + /// to zero) then the awaiting coroutine will continue without suspending. + /// Otherwise, the coroutine will suspend and will later be resumed inside + /// a call to `count_down()`. + auto operator co_await() const noexcept + { + return m_event.operator co_await(); + } + + private: + + std::atomic m_count; + async_manual_reset_event m_event; + + }; +} + +#endif diff --git a/include/cppcoro/async_manual_reset_event.hpp b/include/cppcoro/async_manual_reset_event.hpp new file mode 100644 index 0000000..fd58282 --- /dev/null +++ b/include/cppcoro/async_manual_reset_event.hpp @@ -0,0 +1,104 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_MANUAL_RESET_EVENT_HPP_INCLUDED +#define CPPCORO_ASYNC_MANUAL_RESET_EVENT_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + class async_manual_reset_event_operation; + + /// An async manual-reset event is a coroutine synchronisation abstraction + /// that allows one or more coroutines to wait until some thread calls + /// set() on the event. + /// + /// When a coroutine awaits a 'set' event the coroutine continues without + /// suspending. Otherwise, if it awaits a 'not set' event the coroutine is + /// suspended and is later resumed inside the call to 'set()'. + /// + /// \seealso async_auto_reset_event + class async_manual_reset_event + { + public: + + /// Initialise the event to either 'set' or 'not set' state. + /// + /// \param initiallySet + /// If 'true' then initialises the event to the 'set' state, otherwise + /// initialises the event to the 'not set' state. + async_manual_reset_event(bool initiallySet = false) noexcept; + + ~async_manual_reset_event(); + + /// Wait for the event to enter the 'set' state. + /// + /// If the event is already 'set' then the coroutine continues without + /// suspending. + /// + /// Otherwise, the coroutine is suspended and later resumed when some + /// thread calls 'set()'. The coroutine will be resumed inside the next + /// call to 'set()'. + async_manual_reset_event_operation operator co_await() const noexcept; + + /// Query if the event is currently in the 'set' state. + bool is_set() const noexcept; + + /// Set the state of the event to 'set'. + /// + /// If there are pending coroutines awaiting the event then all + /// pending coroutines are resumed within this call. + /// Any coroutines that subsequently await the event will continue + /// without suspending. + /// + /// This operation is a no-op if the event was already 'set'. + void set() noexcept; + + /// Set the state of the event to 'not-set'. + /// + /// Any coroutines that subsequently await the event will suspend + /// until some thread calls 'set()'. + /// + /// This is a no-op if the state was already 'not set'. + void reset() noexcept; + + private: + + friend class async_manual_reset_event_operation; + + // This variable has 3 states: + // - this - The state is 'set'. + // - nullptr - The state is 'not set' with no waiters. + // - other - The state is 'not set'. + // Points to an 'async_manual_reset_event_operation' that is + // the head of a linked-list of waiters. + mutable std::atomic m_state; + + }; + + class async_manual_reset_event_operation + { + public: + + explicit async_manual_reset_event_operation(const async_manual_reset_event& event) noexcept; + + bool await_ready() const noexcept; + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept {} + + private: + + friend class async_manual_reset_event; + + const async_manual_reset_event& m_event; + async_manual_reset_event_operation* m_next; + cppcoro::coroutine_handle<> m_awaiter; + + }; +} + +#endif diff --git a/include/cppcoro/async_mutex.hpp b/include/cppcoro/async_mutex.hpp new file mode 100644 index 0000000..2f4fd61 --- /dev/null +++ b/include/cppcoro/async_mutex.hpp @@ -0,0 +1,200 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_MUTEX_HPP_INCLUDED +#define CPPCORO_ASYNC_MUTEX_HPP_INCLUDED + +#include +#include +#include +#include // for std::adopt_lock_t + +namespace cppcoro +{ + class async_mutex_lock; + class async_mutex_lock_operation; + class async_mutex_scoped_lock_operation; + + /// \brief + /// A mutex that can be locked asynchronously using 'co_await'. + /// + /// Ownership of the mutex is not tied to any particular thread. + /// This allows the coroutine owning the lock to transition from + /// one thread to another while holding a lock. + /// + /// Implementation is lock-free, using only std::atomic values for + /// synchronisation. Awaiting coroutines are suspended without blocking + /// the current thread if the lock could not be acquired synchronously. + class async_mutex + { + public: + + /// \brief + /// Construct to a mutex that is not currently locked. + async_mutex() noexcept; + + /// Destroys the mutex. + /// + /// Behaviour is undefined if there are any outstanding coroutines + /// still waiting to acquire the lock. + ~async_mutex(); + + /// \brief + /// Attempt to acquire a lock on the mutex without blocking. + /// + /// \return + /// true if the lock was acquired, false if the mutex was already locked. + /// The caller is responsible for ensuring unlock() is called on the mutex + /// to release the lock if the lock was acquired by this call. + bool try_lock() noexcept; + + /// \brief + /// Acquire a lock on the mutex asynchronously. + /// + /// If the lock could not be acquired synchronously then the awaiting + /// coroutine will be suspended and later resumed when the lock becomes + /// available. If suspended, the coroutine will be resumed inside the + /// call to unlock() from the previous lock owner. + /// + /// \return + /// An operation object that must be 'co_await'ed to wait until the + /// lock is acquired. The result of the 'co_await m.lock_async()' + /// expression has type 'void'. + async_mutex_lock_operation lock_async() noexcept; + + /// \brief + /// Acquire a lock on the mutex asynchronously, returning an object that + /// will call unlock() automatically when it goes out of scope. + /// + /// If the lock could not be acquired synchronously then the awaiting + /// coroutine will be suspended and later resumed when the lock becomes + /// available. If suspended, the coroutine will be resumed inside the + /// call to unlock() from the previous lock owner. + /// + /// \return + /// An operation object that must be 'co_await'ed to wait until the + /// lock is acquired. The result of the 'co_await m.scoped_lock_async()' + /// expression returns an 'async_mutex_lock' object that will call + /// this->mutex() when it destructs. + async_mutex_scoped_lock_operation scoped_lock_async() noexcept; + + /// \brief + /// Unlock the mutex. + /// + /// Must only be called by the current lock-holder. + /// + /// If there are lock operations waiting to acquire the + /// mutex then the next lock operation in the queue will + /// be resumed inside this call. + void unlock(); + + private: + + friend class async_mutex_lock_operation; + + static constexpr std::uintptr_t not_locked = 1; + + // assume == reinterpret_cast(static_cast(nullptr)) + static constexpr std::uintptr_t locked_no_waiters = 0; + + // This field provides synchronisation for the mutex. + // + // It can have three kinds of values: + // - not_locked + // - locked_no_waiters + // - a pointer to the head of a singly linked list of recently + // queued async_mutex_lock_operation objects. This list is + // in most-recently-queued order as new items are pushed onto + // the front of the list. + std::atomic m_state; + + // Linked list of async lock operations that are waiting to acquire + // the mutex. These operations will acquire the lock in the order + // they appear in this list. Waiters in this list will acquire the + // mutex before waiters added to the m_newWaiters list. + async_mutex_lock_operation* m_waiters; + + }; + + /// \brief + /// An object that holds onto a mutex lock for its lifetime and + /// ensures that the mutex is unlocked when it is destructed. + /// + /// It is equivalent to a std::lock_guard object but requires + /// that the result of co_await async_mutex::lock_async() is + /// passed to the constructor rather than passing the async_mutex + /// object itself. + class async_mutex_lock + { + public: + + explicit async_mutex_lock(async_mutex& mutex, std::adopt_lock_t) noexcept + : m_mutex(&mutex) + {} + + async_mutex_lock(async_mutex_lock&& other) noexcept + : m_mutex(other.m_mutex) + { + other.m_mutex = nullptr; + } + + async_mutex_lock(const async_mutex_lock& other) = delete; + async_mutex_lock& operator=(const async_mutex_lock& other) = delete; + + // Releases the lock. + ~async_mutex_lock() + { + if (m_mutex != nullptr) + { + m_mutex->unlock(); + } + } + + private: + + async_mutex* m_mutex; + + }; + + class async_mutex_lock_operation + { + public: + + explicit async_mutex_lock_operation(async_mutex& mutex) noexcept + : m_mutex(mutex) + {} + + bool await_ready() const noexcept { return false; } + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept {} + + protected: + + friend class async_mutex; + + async_mutex& m_mutex; + + private: + + async_mutex_lock_operation* m_next; + cppcoro::coroutine_handle<> m_awaiter; + + }; + + class async_mutex_scoped_lock_operation : public async_mutex_lock_operation + { + public: + + using async_mutex_lock_operation::async_mutex_lock_operation; + + [[nodiscard]] + async_mutex_lock await_resume() const noexcept + { + return async_mutex_lock{ m_mutex, std::adopt_lock }; + } + + }; +} + +#endif diff --git a/include/cppcoro/async_scope.hpp b/include/cppcoro/async_scope.hpp new file mode 100644 index 0000000..9b65e1b --- /dev/null +++ b/include/cppcoro/async_scope.hpp @@ -0,0 +1,102 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ASYNC_SCOPE_HPP_INCLUDED +#define CPPCORO_ASYNC_SCOPE_HPP_INCLUDED + +#include + +#include +#include +#include +#include + +namespace cppcoro +{ + class async_scope + { + public: + + async_scope() noexcept + : m_count(1u) + {} + + ~async_scope() + { + // scope must be co_awaited before it destructs. + assert(m_continuation); + } + + template + void spawn(AWAITABLE&& awaitable) + { + [](async_scope* scope, std::decay_t awaitable) -> oneway_task + { + scope->on_work_started(); + auto decrementOnCompletion = on_scope_exit([scope] { scope->on_work_finished(); }); + co_await std::move(awaitable); + }(this, std::forward(awaitable)); + } + + [[nodiscard]] auto join() noexcept + { + class awaiter + { + async_scope* m_scope; + public: + awaiter(async_scope* scope) noexcept : m_scope(scope) {} + + bool await_ready() noexcept + { + return m_scope->m_count.load(std::memory_order_acquire) == 0; + } + + bool await_suspend(cppcoro::coroutine_handle<> continuation) noexcept + { + m_scope->m_continuation = continuation; + return m_scope->m_count.fetch_sub(1u, std::memory_order_acq_rel) > 1u; + } + + void await_resume() noexcept + {} + }; + + return awaiter{ this }; + } + + private: + + void on_work_finished() noexcept + { + if (m_count.fetch_sub(1u, std::memory_order_acq_rel) == 1) + { + m_continuation.resume(); + } + } + + void on_work_started() noexcept + { + assert(m_count.load(std::memory_order_relaxed) != 0); + m_count.fetch_add(1, std::memory_order_relaxed); + } + + struct oneway_task + { + struct promise_type + { + cppcoro::suspend_never initial_suspend() noexcept { return {}; } + cppcoro::suspend_never final_suspend() noexcept { return {}; } + void unhandled_exception() { std::terminate(); } + oneway_task get_return_object() { return {}; } + void return_void() {} + }; + }; + + std::atomic m_count; + cppcoro::coroutine_handle<> m_continuation; + + }; +} + +#endif diff --git a/include/cppcoro/awaitable_traits.hpp b/include/cppcoro/awaitable_traits.hpp new file mode 100644 index 0000000..5a7465d --- /dev/null +++ b/include/cppcoro/awaitable_traits.hpp @@ -0,0 +1,27 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_AWAITABLE_TRAITS_HPP_INCLUDED +#define CPPCORO_AWAITABLE_TRAITS_HPP_INCLUDED + +#include + +#include + +namespace cppcoro +{ + template + struct awaitable_traits + {}; + + template + struct awaitable_traits()))>> + { + using awaiter_t = decltype(cppcoro::detail::get_awaiter(std::declval())); + + using await_result_t = decltype(std::declval().await_resume()); + }; +} + +#endif diff --git a/include/cppcoro/broken_promise.hpp b/include/cppcoro/broken_promise.hpp new file mode 100644 index 0000000..55462fc --- /dev/null +++ b/include/cppcoro/broken_promise.hpp @@ -0,0 +1,24 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_BROKEN_PROMISE_HPP_INCLUDED +#define CPPCORO_BROKEN_PROMISE_HPP_INCLUDED + +#include + +namespace cppcoro +{ + /// \brief + /// Exception thrown when you attempt to retrieve the result of + /// a task that has been detached from its promise/coroutine. + class broken_promise : public std::logic_error + { + public: + broken_promise() + : std::logic_error("broken promise") + {} + }; +} + +#endif diff --git a/include/cppcoro/cancellation_registration.hpp b/include/cppcoro/cancellation_registration.hpp new file mode 100644 index 0000000..64d267b --- /dev/null +++ b/include/cppcoro/cancellation_registration.hpp @@ -0,0 +1,87 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CANCELLATION_REGISTRATION_HPP_INCLUDED +#define CPPCORO_CANCELLATION_REGISTRATION_HPP_INCLUDED + +#include + +#include +#include +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + class cancellation_state; + struct cancellation_registration_list_chunk; + struct cancellation_registration_state; + } + + class cancellation_registration + { + public: + + /// Registers the callback to be executed when cancellation is requested + /// on the cancellation_token. + /// + /// The callback will be executed if cancellation is requested for the + /// specified cancellation token. If cancellation has already been requested + /// then the callback will be executed immediately, before the constructor + /// returns. If cancellation has not yet been requested then the callback + /// will be executed on the first thread to request cancellation inside + /// the call to cancellation_source::request_cancellation(). + /// + /// \param token + /// The cancellation token to register the callback with. + /// + /// \param callback + /// The callback to be executed when cancellation is requested on the + /// the cancellation_token. Note that callback must not throw an exception + /// if called when cancellation is requested otherwise std::terminate() + /// will be called. + /// + /// \throw std::bad_alloc + /// If registration failed due to insufficient memory available. + template< + typename FUNC, + typename = std::enable_if_t, FUNC&&>>> + cancellation_registration(cancellation_token token, FUNC&& callback) + : m_callback(std::forward(callback)) + { + register_callback(std::move(token)); + } + + cancellation_registration(const cancellation_registration& other) = delete; + cancellation_registration& operator=(const cancellation_registration& other) = delete; + + /// Deregisters the callback. + /// + /// After the destructor returns it is guaranteed that the callback + /// will not be subsequently called during a call to request_cancellation() + /// on the cancellation_source. + /// + /// This may block if cancellation has been requested on another thread + /// is it will need to wait until this callback has finished executing + /// before the callback can be destroyed. + ~cancellation_registration(); + + private: + + friend class detail::cancellation_state; + friend struct detail::cancellation_registration_state; + + void register_callback(cancellation_token&& token); + + detail::cancellation_state* m_state; + std::function m_callback; + detail::cancellation_registration_list_chunk* m_chunk; + std::uint32_t m_entryIndex; + }; +} + +#endif diff --git a/include/cppcoro/cancellation_source.hpp b/include/cppcoro/cancellation_source.hpp new file mode 100644 index 0000000..e0f100e --- /dev/null +++ b/include/cppcoro/cancellation_source.hpp @@ -0,0 +1,71 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CANCELLATION_SOURCE_HPP_INCLUDED +#define CPPCORO_CANCELLATION_SOURCE_HPP_INCLUDED + +namespace cppcoro +{ + class cancellation_token; + + namespace detail + { + class cancellation_state; + } + + class cancellation_source + { + public: + + /// Construct to a new cancellation source. + cancellation_source(); + + /// Create a new reference to the same underlying cancellation + /// source as \p other. + cancellation_source(const cancellation_source& other) noexcept; + + cancellation_source(cancellation_source&& other) noexcept; + + ~cancellation_source(); + + cancellation_source& operator=(const cancellation_source& other) noexcept; + + cancellation_source& operator=(cancellation_source&& other) noexcept; + + /// Query if this cancellation source can be cancelled. + /// + /// A cancellation source object will not be cancellable if it has + /// previously been moved into another cancellation_source instance + /// or was copied from a cancellation_source that was not cancellable. + bool can_be_cancelled() const noexcept; + + /// Obtain a cancellation token that can be used to query if + /// cancellation has been requested on this source. + /// + /// The cancellation token can be passed into functions that you + /// may want to later be able to request cancellation. + cancellation_token token() const noexcept; + + /// Request cancellation of operations that were passed an associated + /// cancellation token. + /// + /// Any cancellation callback registered via a cancellation_registration + /// object will be called inside this function by the first thread to + /// call this method. + /// + /// This operation is a no-op if can_be_cancelled() returns false. + void request_cancellation(); + + /// Query if some thread has called 'request_cancellation()' on this + /// cancellation_source. + bool is_cancellation_requested() const noexcept; + + private: + + detail::cancellation_state* m_state; + + }; +} + +#endif diff --git a/include/cppcoro/cancellation_token.hpp b/include/cppcoro/cancellation_token.hpp new file mode 100644 index 0000000..49e8f82 --- /dev/null +++ b/include/cppcoro/cancellation_token.hpp @@ -0,0 +1,72 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CANCELLATION_TOKEN_HPP_INCLUDED +#define CPPCORO_CANCELLATION_TOKEN_HPP_INCLUDED + +namespace cppcoro +{ + class cancellation_source; + class cancellation_registration; + + namespace detail + { + class cancellation_state; + } + + class cancellation_token + { + public: + + /// Construct to a cancellation token that can't be cancelled. + cancellation_token() noexcept; + + /// Copy another cancellation token. + /// + /// New token will refer to the same underlying state. + cancellation_token(const cancellation_token& other) noexcept; + + cancellation_token(cancellation_token&& other) noexcept; + + ~cancellation_token(); + + cancellation_token& operator=(const cancellation_token& other) noexcept; + + cancellation_token& operator=(cancellation_token&& other) noexcept; + + void swap(cancellation_token& other) noexcept; + + /// Query if it is possible that this operation will be cancelled + /// or not. + /// + /// Cancellable operations may be able to take more efficient code-paths + /// if they don't need to handle cancellation requests. + bool can_be_cancelled() const noexcept; + + /// Query if some thread has requested cancellation on an associated + /// cancellation_source object. + bool is_cancellation_requested() const noexcept; + + /// Throws cppcoro::operation_cancelled exception if cancellation + /// has been requested for the associated operation. + void throw_if_cancellation_requested() const; + + private: + + friend class cancellation_source; + friend class cancellation_registration; + + cancellation_token(detail::cancellation_state* state) noexcept; + + detail::cancellation_state* m_state; + + }; + + inline void swap(cancellation_token& a, cancellation_token& b) noexcept + { + a.swap(b); + } +} + +#endif diff --git a/include/cppcoro/config.hpp b/include/cppcoro/config.hpp new file mode 100644 index 0000000..526bb7f --- /dev/null +++ b/include/cppcoro/config.hpp @@ -0,0 +1,166 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CONFIG_HPP_INCLUDED +#define CPPCORO_CONFIG_HPP_INCLUDED + +///////////////////////////////////////////////////////////////////////////// +// Compiler Detection + +#if defined(_MSC_VER) +# define CPPCORO_COMPILER_MSVC _MSC_FULL_VER +#else +# define CPPCORO_COMPILER_MSVC 0 +#endif + +#if defined(__clang__) +# define CPPCORO_COMPILER_CLANG (__clang_major__ * 10000 + \ + __clang_minor__ * 100 + \ + __clang_patchlevel__) +#else +# define CPPCORO_COMPILER_CLANG 0 +#endif + +#if defined(__GNUC__) +# define CPPCORO_COMPILER_GCC (__GNUC__ * 10000 + \ + __GNUC_MINOR__ * 100 + \ + __GNUC_PATCHLEVEL__) +#else +# define CPPCORO_COMPILER_GCC 0 +#endif + +/// \def CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER +/// Defined to 1 if the compiler supports returning a coroutine_handle from +/// the await_suspend() method as a way of transferring execution +/// to another coroutine with a guaranteed tail-call. +#if CPPCORO_COMPILER_CLANG +# if __clang_major__ >= 7 +# define CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER 1 +# endif +#endif +#ifndef CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER +# define CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER 0 +#endif + +#if CPPCORO_COMPILER_MSVC +# define CPPCORO_ASSUME(X) __assume(X) +#else +# define CPPCORO_ASSUME(X) +#endif + +#if CPPCORO_COMPILER_MSVC +# define CPPCORO_NOINLINE __declspec(noinline) +#elif CPPCORO_COMPILER_CLANG || CPPCORO_COMPILER_GCC +# define CPPCORO_NOINLINE __attribute__((noinline)) +#else +# define CPPCORO_NOINLINE +#endif + +#if CPPCORO_COMPILER_MSVC +# define CPPCORO_FORCE_INLINE __forceinline +#elif CPPCORO_COMPILER_CLANG +# define CPPCORO_FORCE_INLINE __attribute__((always_inline)) +#else +# define CPPCORO_FORCE_INLINE inline +#endif + +///////////////////////////////////////////////////////////////////////////// +// OS Detection + +/// \def CPPCORO_OS_WINNT +/// Defined to non-zero if the target platform is a WindowsNT variant. +/// 0x0500 - Windows 2000 +/// 0x0501 - Windows XP/Server 2003 +/// 0x0502 - Windows XP SP2/Server 2003 SP1 +/// 0x0600 - Windows Vista/Server 2008 +/// 0x0601 - Windows 7 +/// 0x0602 - Windows 8 +/// 0x0603 - Windows 8.1 +/// 0x0A00 - Windows 10 +#if defined(_WIN32_WINNT) || defined(_WIN32) +# if !defined(_WIN32_WINNT) +// Default to targeting Windows 10 if not defined. +# define _WIN32_WINNT 0x0A00 +# endif +# define CPPCORO_OS_WINNT _WIN32_WINNT +#else +# define CPPCORO_OS_WINNT 0 +#endif + +#if defined(__linux__) +# define CPPCORO_OS_LINUX 1 +#else +# define CPPCORO_OS_LINUX 0 +#endif + +///////////////////////////////////////////////////////////////////////////// +// CPU Detection + +/// \def CPPCORO_CPU_X86 +/// Defined to 1 if target CPU is of x86 family. +#if CPPCORO_COMPILER_MSVC +# if defined(_M_IX86) +# define CPPCORO_CPU_X86 1 +# endif +#elif CPPCORO_COMPILER_GCC || CPPCORO_COMPILER_CLANG +# if defined(__i386__) +# define CPPCORO_CPU_X86 1 +# endif +#endif +#if !defined(CPPCORO_CPU_X86) +# define CPPCORO_CPU_X86 0 +#endif + +/// \def CPPCORO_CPU_X64 +/// Defined to 1 if the target CPU is x64 family. +#if CPPCORO_COMPILER_MSVC +# if defined(_M_X64) +# define CPPCORO_CPU_X64 1 +# endif +#elif CPPCORO_COMPILER_GCC || CPPCORO_COMPILER_CLANG +# if defined(__x86_64__) +# define CPPCORO_CPU_X64 1 +# endif +#endif +#if !defined(CPPCORO_CPU_X64) +# define CPPCORO_CPU_X64 0 +#endif + +/// \def CPPCORO_CPU_32BIT +/// Defined if compiling for a 32-bit CPU architecture. +#if CPPCORO_CPU_X86 +# define CPPCORO_CPU_32BIT 1 +#else +# define CPPCORO_CPU_32BIT 0 +#endif + +/// \def CPPCORO_CPU_64BIT +/// Defined if compiling for a 64-bit CPU architecture. +#if CPPCORO_CPU_X64 +# define CPPCORO_CPU_64BIT 1 +#else +# define CPPCORO_CPU_64BIT 0 +#endif + +#if CPPCORO_COMPILER_MSVC +# define CPPCORO_CPU_CACHE_LINE std::hardware_destructive_interference_size +#else +// On most architectures we can assume a 64-byte cache line. +# define CPPCORO_CPU_CACHE_LINE 64 +#endif + +#if CPPCORO_COMPILER_MSVC + #if __has_include() + #include + #ifdef __cpp_lib_coroutine + #define CPPCORO_COROHEADER_FOUND_AND_USABLE + #endif + #endif +#else + #if __has_include() + #define CPPCORO_COROHEADER_FOUND_AND_USABLE + #endif +#endif + +#endif diff --git a/include/cppcoro/coroutine.hpp b/include/cppcoro/coroutine.hpp new file mode 100644 index 0000000..cd8632a --- /dev/null +++ b/include/cppcoro/coroutine.hpp @@ -0,0 +1,35 @@ +#ifndef CPPCORO_COROUTINE_HPP_INCLUDED +#define CPPCORO_COROUTINE_HPP_INCLUDED + +#include + +#ifdef CPPCORO_COROHEADER_FOUND_AND_USABLE + +#include + +namespace cppcoro { + using std::coroutine_handle; + using std::suspend_always; + using std::noop_coroutine; + using std::suspend_never; +} + +#elif __has_include() + +#include + +namespace cppcoro { + using std::experimental::coroutine_handle; + using std::experimental::suspend_always; + using std::experimental::suspend_never; + +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + using std::experimental::noop_coroutine; +#endif +} + +#else +#error Cppcoro requires a C++20 compiler with coroutine support +#endif + +#endif diff --git a/include/cppcoro/detail/any.hpp b/include/cppcoro/detail/any.hpp new file mode 100644 index 0000000..b1ec9cb --- /dev/null +++ b/include/cppcoro/detail/any.hpp @@ -0,0 +1,22 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_ANY_HPP_INCLUDED +#define CPPCORO_DETAIL_ANY_HPP_INCLUDED + +namespace cppcoro +{ + namespace detail + { + // Helper type that can be cast-to from any type. + struct any + { + template + any(T&&) noexcept + {} + }; + } +} + +#endif diff --git a/include/cppcoro/detail/get_awaiter.hpp b/include/cppcoro/detail/get_awaiter.hpp new file mode 100644 index 0000000..57417dc --- /dev/null +++ b/include/cppcoro/detail/get_awaiter.hpp @@ -0,0 +1,49 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_GET_AWAITER_HPP_INCLUDED +#define CPPCORO_DETAIL_GET_AWAITER_HPP_INCLUDED + +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + auto get_awaiter_impl(T&& value, int) + noexcept(noexcept(static_cast(value).operator co_await())) + -> decltype(static_cast(value).operator co_await()) + { + return static_cast(value).operator co_await(); + } + + template + auto get_awaiter_impl(T&& value, long) + noexcept(noexcept(operator co_await(static_cast(value)))) + -> decltype(operator co_await(static_cast(value))) + { + return operator co_await(static_cast(value)); + } + + template< + typename T, + std::enable_if_t::value, int> = 0> + T&& get_awaiter_impl(T&& value, cppcoro::detail::any) noexcept + { + return static_cast(value); + } + + template + auto get_awaiter(T&& value) + noexcept(noexcept(detail::get_awaiter_impl(static_cast(value), 123))) + -> decltype(detail::get_awaiter_impl(static_cast(value), 123)) + { + return detail::get_awaiter_impl(static_cast(value), 123); + } + } +} + +#endif diff --git a/include/cppcoro/detail/is_awaiter.hpp b/include/cppcoro/detail/is_awaiter.hpp new file mode 100644 index 0000000..c5781ce --- /dev/null +++ b/include/cppcoro/detail/is_awaiter.hpp @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_IS_AWAITER_HPP_INCLUDED +#define CPPCORO_DETAIL_IS_AWAITER_HPP_INCLUDED + +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + struct is_coroutine_handle + : std::false_type + {}; + + template + struct is_coroutine_handle> + : std::true_type + {}; + + // NOTE: We're accepting a return value of coroutine_handle

here + // which is an extension supported by Clang which is not yet part of + // the C++ coroutines TS. + template + struct is_valid_await_suspend_return_value : std::disjunction< + std::is_void, + std::is_same, + is_coroutine_handle> + {}; + + template> + struct is_awaiter : std::false_type {}; + + // NOTE: We're testing whether await_suspend() will be callable using an + // arbitrary coroutine_handle here by checking if it supports being passed + // a coroutine_handle. This may result in a false-result for some + // types which are only awaitable within a certain context. + template + struct is_awaiter().await_ready()), + decltype(std::declval().await_suspend(std::declval>())), + decltype(std::declval().await_resume())>> : + std::conjunction< + std::is_constructible().await_ready())>, + detail::is_valid_await_suspend_return_value< + decltype(std::declval().await_suspend(std::declval>()))>> + {}; + } +} + +#endif diff --git a/include/cppcoro/detail/lightweight_manual_reset_event.hpp b/include/cppcoro/detail/lightweight_manual_reset_event.hpp new file mode 100644 index 0000000..fb5b53f --- /dev/null +++ b/include/cppcoro/detail/lightweight_manual_reset_event.hpp @@ -0,0 +1,65 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_LIGHTWEIGHT_MANUAL_RESET_EVENT_HPP_INCLUDED +#define CPPCORO_DETAIL_LIGHTWEIGHT_MANUAL_RESET_EVENT_HPP_INCLUDED + +#include +#include +#include + +#if CPPCORO_OS_LINUX || (CPPCORO_OS_WINNT >= 0x0602) +# include +# include +#elif CPPCORO_OS_WINNT +# include +#else +# include +# include +#endif + +namespace cppcoro +{ + class io_service; + + namespace detail + { + class lightweight_manual_reset_event + { + public: + + lightweight_manual_reset_event(bool initiallySet = false); + + ~lightweight_manual_reset_event(); + + void set() noexcept; + + void reset() noexcept; + + void wait() noexcept; + void wait(std::span srvs, std::chrono::system_clock::duration step) noexcept; + + private: +#if CPPCORO_OS_LINUX + std::atomic m_value; +#elif CPPCORO_OS_WINNT >= 0x0602 + // Windows 8 or newer we can use WaitOnAddress() + std::atomic m_value; +#elif CPPCORO_OS_WINNT + // Before Windows 8 we need to use a WIN32 manual reset event. + cppcoro::detail::win32::handle_t m_eventHandle; +#else + // For other platforms that don't have a native futex + // or manual reset event we can just use a std::mutex + // and std::condition_variable to perform the wait. + // Not so lightweight, but should be portable to all platforms. + std::mutex m_mutex; + std::condition_variable m_cv; + bool m_isSet; +#endif + }; + } +} + +#endif diff --git a/include/cppcoro/detail/manual_lifetime.hpp b/include/cppcoro/detail/manual_lifetime.hpp new file mode 100644 index 0000000..01bb10c --- /dev/null +++ b/include/cppcoro/detail/manual_lifetime.hpp @@ -0,0 +1,120 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_MANUAL_LIFETIME_HPP_INCLUDED +#define CPPCORO_DETAIL_MANUAL_LIFETIME_HPP_INCLUDED + +#include +#include + +namespace cppcoro::detail +{ + template + struct manual_lifetime + { + public: + manual_lifetime() noexcept {} + ~manual_lifetime() noexcept {} + + manual_lifetime(const manual_lifetime&) = delete; + manual_lifetime(manual_lifetime&&) = delete; + manual_lifetime& operator=(const manual_lifetime&) = delete; + manual_lifetime& operator=(manual_lifetime&&) = delete; + + template + std::enable_if_t> construct(Args&&... args) + noexcept(std::is_nothrow_constructible_v) + { + ::new (static_cast(std::addressof(m_value))) T(static_cast(args)...); + } + + void destruct() noexcept(std::is_nothrow_destructible_v) + { + m_value.~T(); + } + + std::add_pointer_t operator->() noexcept { return std::addressof(**this); } + std::add_pointer_t operator->() const noexcept { return std::addressof(**this); } + + T& operator*() & noexcept { return m_value; } + const T& operator*() const & noexcept { return m_value; } + T&& operator*() && noexcept { return static_cast(m_value); } + const T&& operator*() const && noexcept { return static_cast(m_value); } + + private: + union { + T m_value; + }; + }; + + template + struct manual_lifetime + { + public: + manual_lifetime() noexcept {} + ~manual_lifetime() noexcept {} + + manual_lifetime(const manual_lifetime&) = delete; + manual_lifetime(manual_lifetime&&) = delete; + manual_lifetime& operator=(const manual_lifetime&) = delete; + manual_lifetime& operator=(manual_lifetime&&) = delete; + + void construct(T& value) noexcept + { + m_value = std::addressof(value); + } + + void destruct() noexcept {} + + T* operator->() noexcept { return m_value; } + const T* operator->() const noexcept { return m_value; } + + T& operator*() noexcept { return *m_value; } + const T& operator*() const noexcept { return *m_value; } + + private: + T* m_value; + }; + + template + struct manual_lifetime + { + public: + manual_lifetime() noexcept {} + ~manual_lifetime() noexcept {} + + manual_lifetime(const manual_lifetime&) = delete; + manual_lifetime(manual_lifetime&&) = delete; + manual_lifetime& operator=(const manual_lifetime&) = delete; + manual_lifetime& operator=(manual_lifetime&&) = delete; + + void construct(T&& value) noexcept + { + m_value = std::addressof(value); + } + + void destruct() noexcept {} + + T* operator->() noexcept { return m_value; } + const T* operator->() const noexcept { return m_value; } + + T& operator*() & noexcept { return *m_value; } + const T& operator*() const & noexcept { return *m_value; } + T&& operator*() && noexcept { return static_cast(*m_value); } + const T&& operator*() const && noexcept { return static_cast(*m_value); } + + private: + T* m_value; + }; + + template<> + struct manual_lifetime + { + void construct() noexcept {} + void destruct() noexcept {} + void operator*() const noexcept {} + }; +} + +#endif diff --git a/include/cppcoro/detail/remove_rvalue_reference.hpp b/include/cppcoro/detail/remove_rvalue_reference.hpp new file mode 100644 index 0000000..300bca1 --- /dev/null +++ b/include/cppcoro/detail/remove_rvalue_reference.hpp @@ -0,0 +1,29 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_REMOVE_RVALUE_REFERENCE_HPP_INCLUDED +#define CPPCORO_DETAIL_REMOVE_RVALUE_REFERENCE_HPP_INCLUDED + +namespace cppcoro +{ + namespace detail + { + template + struct remove_rvalue_reference + { + using type = T; + }; + + template + struct remove_rvalue_reference + { + using type = T; + }; + + template + using remove_rvalue_reference_t = typename remove_rvalue_reference::type; + } +} + +#endif diff --git a/include/cppcoro/detail/sync_wait_task.hpp b/include/cppcoro/detail/sync_wait_task.hpp new file mode 100644 index 0000000..9f95330 --- /dev/null +++ b/include/cppcoro/detail/sync_wait_task.hpp @@ -0,0 +1,300 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_SYNC_WAIT_TASK_HPP_INCLUDED +#define CPPCORO_DETAIL_SYNC_WAIT_TASK_HPP_INCLUDED + +#include +#include +#include + +#include +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + class sync_wait_task; + + template + class sync_wait_task_promise final + { + using coroutine_handle_t = cppcoro::coroutine_handle>; + + public: + + using reference = RESULT&&; + + sync_wait_task_promise() noexcept + {} + + void start(detail::lightweight_manual_reset_event& event) + { + m_event = &event; + coroutine_handle_t::from_promise(*this).resume(); + } + + auto get_return_object() noexcept + { + return coroutine_handle_t::from_promise(*this); + } + + cppcoro::suspend_always initial_suspend() noexcept + { + return{}; + } + + auto final_suspend() noexcept + { + class completion_notifier + { + public: + + bool await_ready() const noexcept { return false; } + + void await_suspend(coroutine_handle_t coroutine) const noexcept + { + coroutine.promise().m_event->set(); + } + + void await_resume() noexcept {} + }; + + return completion_notifier{}; + } + +#if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000 + // HACK: This is needed to work around a bug in MSVC 2017.7/2017.8. + // See comment in make_sync_wait_task below. + template + Awaitable&& await_transform(Awaitable&& awaitable) + { + return static_cast(awaitable); + } + + struct get_promise_t {}; + static constexpr get_promise_t get_promise = {}; + + auto await_transform(get_promise_t) + { + class awaiter + { + public: + awaiter(sync_wait_task_promise* promise) noexcept : m_promise(promise) {} + bool await_ready() noexcept { + return true; + } + void await_suspend(cppcoro::coroutine_handle<>) noexcept {} + sync_wait_task_promise& await_resume() noexcept + { + return *m_promise; + } + private: + sync_wait_task_promise* m_promise; + }; + return awaiter{ this }; + } +#endif + + auto yield_value(reference result) noexcept + { + m_result = std::addressof(result); + return final_suspend(); + } + + void return_void() noexcept + { + // The coroutine should have either yielded a value or thrown + // an exception in which case it should have bypassed return_void(). + assert(false); + } + + void unhandled_exception() + { + m_exception = std::current_exception(); + } + + reference result() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + + return static_cast(*m_result); + } + + private: + + detail::lightweight_manual_reset_event* m_event; + std::remove_reference_t* m_result; + std::exception_ptr m_exception; + + }; + + template<> + class sync_wait_task_promise + { + using coroutine_handle_t = cppcoro::coroutine_handle>; + + public: + + sync_wait_task_promise() noexcept + {} + + void start(detail::lightweight_manual_reset_event& event) + { + m_event = &event; + coroutine_handle_t::from_promise(*this).resume(); + } + + auto get_return_object() noexcept + { + return coroutine_handle_t::from_promise(*this); + } + + cppcoro::suspend_always initial_suspend() noexcept + { + return{}; + } + + auto final_suspend() noexcept + { + class completion_notifier + { + public: + + bool await_ready() const noexcept { return false; } + + void await_suspend(coroutine_handle_t coroutine) const noexcept + { + coroutine.promise().m_event->set(); + } + + void await_resume() noexcept {} + }; + + return completion_notifier{}; + } + + void return_void() {} + + void unhandled_exception() + { + m_exception = std::current_exception(); + } + + void result() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + private: + + detail::lightweight_manual_reset_event* m_event; + std::exception_ptr m_exception; + + }; + + template + class sync_wait_task final + { + public: + + using promise_type = sync_wait_task_promise; + + using coroutine_handle_t = cppcoro::coroutine_handle; + + sync_wait_task(coroutine_handle_t coroutine) noexcept + : m_coroutine(coroutine) + {} + + sync_wait_task(sync_wait_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_t{})) + {} + + ~sync_wait_task() + { + if (m_coroutine) m_coroutine.destroy(); + } + + sync_wait_task(const sync_wait_task&) = delete; + sync_wait_task& operator=(const sync_wait_task&) = delete; + + void start(lightweight_manual_reset_event& event) noexcept + { + m_coroutine.promise().start(event); + } + + decltype(auto) result() + { + return m_coroutine.promise().result(); + } + + private: + + coroutine_handle_t m_coroutine; + + }; + +#if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000 + // HACK: Work around bug in MSVC where passing a parameter by universal reference + // results in an error when passed a move-only type, complaining that the copy-constructor + // has been deleted. The parameter should be passed by reference and the compiler should + // notcalling the copy-constructor for the argument + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + sync_wait_task make_sync_wait_task(AWAITABLE& awaitable) + { + // HACK: Workaround another bug in MSVC where the expression 'co_yield co_await x' seems + // to completely ignore the co_yield an never calls promise.yield_value(). + // The coroutine seems to be resuming the 'co_await' after the 'co_yield' + // rather than before the 'co_yield'. + // This bug is present in VS 2017.7 and VS 2017.8. + auto& promise = co_await sync_wait_task_promise::get_promise; + co_await promise.yield_value(co_await std::forward(awaitable)); + + //co_yield co_await std::forward(awaitable); + } + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + sync_wait_task make_sync_wait_task(AWAITABLE& awaitable) + { + co_await static_cast(awaitable); + } +#else + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + sync_wait_task make_sync_wait_task(AWAITABLE&& awaitable) + { + co_yield co_await std::forward(awaitable); + } + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + sync_wait_task make_sync_wait_task(AWAITABLE&& awaitable) + { + co_await std::forward(awaitable); + } +#endif + } +} + +#endif diff --git a/include/cppcoro/detail/unwrap_reference.hpp b/include/cppcoro/detail/unwrap_reference.hpp new file mode 100644 index 0000000..08a2159 --- /dev/null +++ b/include/cppcoro/detail/unwrap_reference.hpp @@ -0,0 +1,31 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_UNWRAP_REFERENCE_HPP_INCLUDED +#define CPPCORO_DETAIL_UNWRAP_REFERENCE_HPP_INCLUDED + +#include + +namespace cppcoro +{ + namespace detail + { + template + struct unwrap_reference + { + using type = T; + }; + + template + struct unwrap_reference> + { + using type = T; + }; + + template + using unwrap_reference_t = typename unwrap_reference::type; + } +} + +#endif diff --git a/include/cppcoro/detail/void_value.hpp b/include/cppcoro/detail/void_value.hpp new file mode 100644 index 0000000..420cad6 --- /dev/null +++ b/include/cppcoro/detail/void_value.hpp @@ -0,0 +1,16 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_VOID_VALUE_HPP_INCLUDED +#define CPPCORO_DETAIL_VOID_VALUE_HPP_INCLUDED + +namespace cppcoro +{ + namespace detail + { + struct void_value {}; + } +} + +#endif diff --git a/include/cppcoro/detail/when_all_counter.hpp b/include/cppcoro/detail/when_all_counter.hpp new file mode 100644 index 0000000..0a38e42 --- /dev/null +++ b/include/cppcoro/detail/when_all_counter.hpp @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_WHEN_ALL_COUNTER_HPP_INCLUDED +#define CPPCORO_DETAIL_WHEN_ALL_COUNTER_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + class when_all_counter + { + public: + + when_all_counter(std::size_t count) noexcept + : m_count(count + 1) + , m_awaitingCoroutine(nullptr) + {} + + bool is_ready() const noexcept + { + // We consider this complete if we're asking whether it's ready + // after a coroutine has already been registered. + return static_cast(m_awaitingCoroutine); + } + + bool try_await(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + return m_count.fetch_sub(1, std::memory_order_acq_rel) > 1; + } + + void notify_awaitable_completed() noexcept + { + if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + m_awaitingCoroutine.resume(); + } + } + + protected: + + std::atomic m_count; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + + }; + } +} + +#endif diff --git a/include/cppcoro/detail/when_all_ready_awaitable.hpp b/include/cppcoro/detail/when_all_ready_awaitable.hpp new file mode 100644 index 0000000..54a5263 --- /dev/null +++ b/include/cppcoro/detail/when_all_ready_awaitable.hpp @@ -0,0 +1,258 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_WHEN_ALL_READY_AWAITABLE_HPP_INCLUDED +#define CPPCORO_DETAIL_WHEN_ALL_READY_AWAITABLE_HPP_INCLUDED + +#include + +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + class when_all_ready_awaitable; + + template<> + class when_all_ready_awaitable> + { + public: + + constexpr when_all_ready_awaitable() noexcept {} + explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} + + constexpr bool await_ready() const noexcept { return true; } + void await_suspend(cppcoro::coroutine_handle<>) noexcept {} + std::tuple<> await_resume() const noexcept { return {}; } + + }; + + template + class when_all_ready_awaitable> + { + public: + + explicit when_all_ready_awaitable(TASKS&&... tasks) + noexcept(std::conjunction_v...>) + : m_counter(sizeof...(TASKS)) + , m_tasks(std::move(tasks)...) + {} + + explicit when_all_ready_awaitable(std::tuple&& tasks) + noexcept(std::is_nothrow_move_constructible_v>) + : m_counter(sizeof...(TASKS)) + , m_tasks(std::move(tasks)) + {} + + when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept + : m_counter(sizeof...(TASKS)) + , m_tasks(std::move(other.m_tasks)) + {} + + auto operator co_await() & noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) noexcept + : m_awaitable(awaitable) + {} + + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_awaitable.try_await(awaitingCoroutine); + } + + std::tuple& await_resume() noexcept + { + return m_awaitable.m_tasks; + } + + private: + + when_all_ready_awaitable& m_awaitable; + + }; + + return awaiter{ *this }; + } + + auto operator co_await() && noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) noexcept + : m_awaitable(awaitable) + {} + + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_awaitable.try_await(awaitingCoroutine); + } + + std::tuple&& await_resume() noexcept + { + return std::move(m_awaitable.m_tasks); + } + + private: + + when_all_ready_awaitable& m_awaitable; + + }; + + return awaiter{ *this }; + } + + private: + + bool is_ready() const noexcept + { + return m_counter.is_ready(); + } + + bool try_await(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + start_tasks(std::make_integer_sequence{}); + return m_counter.try_await(awaitingCoroutine); + } + + template + void start_tasks(std::integer_sequence) noexcept + { + (void)std::initializer_list{ + (std::get(m_tasks).start(m_counter), 0)... + }; + } + + when_all_counter m_counter; + std::tuple m_tasks; + + }; + + template + class when_all_ready_awaitable + { + public: + + explicit when_all_ready_awaitable(TASK_CONTAINER&& tasks) noexcept + : m_counter(tasks.size()) + , m_tasks(std::forward(tasks)) + {} + + when_all_ready_awaitable(when_all_ready_awaitable&& other) + noexcept(std::is_nothrow_move_constructible_v) + : m_counter(other.m_tasks.size()) + , m_tasks(std::move(other.m_tasks)) + {} + + when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; + when_all_ready_awaitable& operator=(const when_all_ready_awaitable&) = delete; + + auto operator co_await() & noexcept + { + class awaiter + { + public: + + awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} + + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_awaitable.try_await(awaitingCoroutine); + } + + TASK_CONTAINER& await_resume() noexcept + { + return m_awaitable.m_tasks; + } + + private: + + when_all_ready_awaitable& m_awaitable; + + }; + + return awaiter{ *this }; + } + + + auto operator co_await() && noexcept + { + class awaiter + { + public: + + awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} + + bool await_ready() const noexcept + { + return m_awaitable.is_ready(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_awaitable.try_await(awaitingCoroutine); + } + + TASK_CONTAINER&& await_resume() noexcept + { + return std::move(m_awaitable.m_tasks); + } + + private: + + when_all_ready_awaitable& m_awaitable; + + }; + + return awaiter{ *this }; + } + + private: + + bool is_ready() const noexcept + { + return m_counter.is_ready(); + } + + bool try_await(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + for (auto&& task : m_tasks) + { + task.start(m_counter); + } + + return m_counter.try_await(awaitingCoroutine); + } + + when_all_counter m_counter; + TASK_CONTAINER m_tasks; + + }; + } +} + +#endif diff --git a/include/cppcoro/detail/when_all_task.hpp b/include/cppcoro/detail/when_all_task.hpp new file mode 100644 index 0000000..5787e00 --- /dev/null +++ b/include/cppcoro/detail/when_all_task.hpp @@ -0,0 +1,357 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED +#define CPPCORO_DETAIL_WHEN_ALL_TASK_HPP_INCLUDED + +#include + +#include +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + class when_all_ready_awaitable; + + template + class when_all_task; + + template + class when_all_task_promise final + { + public: + + using coroutine_handle_t = cppcoro::coroutine_handle>; + + when_all_task_promise() noexcept + {} + + auto get_return_object() noexcept + { + return coroutine_handle_t::from_promise(*this); + } + + cppcoro::suspend_always initial_suspend() noexcept + { + return{}; + } + + auto final_suspend() noexcept + { + class completion_notifier + { + public: + + bool await_ready() const noexcept { return false; } + + void await_suspend(coroutine_handle_t coro) const noexcept + { + coro.promise().m_counter->notify_awaitable_completed(); + } + + void await_resume() const noexcept {} + + }; + + return completion_notifier{}; + } + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void return_void() noexcept + { + // We should have either suspended at co_yield point or + // an exception was thrown before running off the end of + // the coroutine. + assert(false); + } + +#if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000 + // HACK: This is needed to work around a bug in MSVC 2017.7/2017.8. + // See comment in make_when_all_task below. + template + Awaitable&& await_transform(Awaitable&& awaitable) + { + return static_cast(awaitable); + } + + struct get_promise_t {}; + static constexpr get_promise_t get_promise = {}; + + auto await_transform(get_promise_t) + { + class awaiter + { + public: + awaiter(when_all_task_promise* promise) noexcept : m_promise(promise) {} + bool await_ready() noexcept { + return true; + } + void await_suspend(cppcoro::coroutine_handle<>) noexcept {} + when_all_task_promise& await_resume() noexcept + { + return *m_promise; + } + private: + when_all_task_promise* m_promise; + }; + return awaiter{ this }; + } +#endif + + + auto yield_value(RESULT&& result) noexcept + { + m_result = std::addressof(result); + return final_suspend(); + } + + void start(when_all_counter& counter) noexcept + { + m_counter = &counter; + coroutine_handle_t::from_promise(*this).resume(); + } + + RESULT& result() & + { + rethrow_if_exception(); + return *m_result; + } + + RESULT&& result() && + { + rethrow_if_exception(); + return std::forward(*m_result); + } + + private: + + void rethrow_if_exception() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + when_all_counter* m_counter; + std::exception_ptr m_exception; + std::add_pointer_t m_result; + + }; + + template<> + class when_all_task_promise final + { + public: + + using coroutine_handle_t = cppcoro::coroutine_handle>; + + when_all_task_promise() noexcept + {} + + auto get_return_object() noexcept + { + return coroutine_handle_t::from_promise(*this); + } + + cppcoro::suspend_always initial_suspend() noexcept + { + return{}; + } + + auto final_suspend() noexcept + { + class completion_notifier + { + public: + + bool await_ready() const noexcept { return false; } + + void await_suspend(coroutine_handle_t coro) const noexcept + { + coro.promise().m_counter->notify_awaitable_completed(); + } + + void await_resume() const noexcept {} + + }; + + return completion_notifier{}; + } + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void return_void() noexcept + { + } + + void start(when_all_counter& counter) noexcept + { + m_counter = &counter; + coroutine_handle_t::from_promise(*this).resume(); + } + + void result() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + private: + + when_all_counter* m_counter; + std::exception_ptr m_exception; + + }; + + template + class when_all_task final + { + public: + + using promise_type = when_all_task_promise; + + using coroutine_handle_t = typename promise_type::coroutine_handle_t; + + when_all_task(coroutine_handle_t coroutine) noexcept + : m_coroutine(coroutine) + {} + + when_all_task(when_all_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_t{})) + {} + + ~when_all_task() + { + if (m_coroutine) m_coroutine.destroy(); + } + + when_all_task(const when_all_task&) = delete; + when_all_task& operator=(const when_all_task&) = delete; + + decltype(auto) result() & + { + return m_coroutine.promise().result(); + } + + decltype(auto) result() && + { + return std::move(m_coroutine.promise()).result(); + } + + decltype(auto) non_void_result() & + { + if constexpr (std::is_void_vresult())>) + { + this->result(); + return void_value{}; + } + else + { + return this->result(); + } + } + + decltype(auto) non_void_result() && + { + if constexpr (std::is_void_vresult())>) + { + std::move(*this).result(); + return void_value{}; + } + else + { + return std::move(*this).result(); + } + } + + private: + + template + friend class when_all_ready_awaitable; + + void start(when_all_counter& counter) noexcept + { + m_coroutine.promise().start(counter); + } + + coroutine_handle_t m_coroutine; + + }; + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + when_all_task make_when_all_task(AWAITABLE awaitable) + { +#if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000 + // HACK: Workaround another bug in MSVC where the expression 'co_yield co_await x' seems + // to completely ignore the co_yield an never calls promise.yield_value(). + // The coroutine seems to be resuming the 'co_await' after the 'co_yield' + // rather than before the 'co_yield'. + // This bug is present in VS 2017.7 and VS 2017.8. + auto& promise = co_await when_all_task_promise::get_promise; + co_await promise.yield_value(co_await std::forward(awaitable)); +#else + co_yield co_await static_cast(awaitable); +#endif + } + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + when_all_task make_when_all_task(AWAITABLE awaitable) + { + co_await static_cast(awaitable); + } + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + when_all_task make_when_all_task(std::reference_wrapper awaitable) + { +#if CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC < 19'20'00000 + // HACK: Workaround another bug in MSVC where the expression 'co_yield co_await x' seems + // to completely ignore the co_yield and never calls promise.yield_value(). + // The coroutine seems to be resuming the 'co_await' after the 'co_yield' + // rather than before the 'co_yield'. + // This bug is present in VS 2017.7 and VS 2017.8. + auto& promise = co_await when_all_task_promise::get_promise; + co_await promise.yield_value(co_await awaitable.get()); +#else + co_yield co_await awaitable.get(); +#endif + } + + template< + typename AWAITABLE, + typename RESULT = typename cppcoro::awaitable_traits::await_result_t, + std::enable_if_t, int> = 0> + when_all_task make_when_all_task(std::reference_wrapper awaitable) + { + co_await awaitable.get(); + } + } +} + +#endif diff --git a/include/cppcoro/detail/win32.hpp b/include/cppcoro/detail/win32.hpp new file mode 100644 index 0000000..a1e68e3 --- /dev/null +++ b/include/cppcoro/detail/win32.hpp @@ -0,0 +1,179 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_WIN32_HPP_INCLUDED +#define CPPCORO_DETAIL_WIN32_HPP_INCLUDED + +#include + +#if !CPPCORO_OS_WINNT +# error is only supported on the Windows platform. +#endif + +#include +#include + +struct _OVERLAPPED; + +namespace cppcoro +{ + namespace detail + { + namespace win32 + { + using handle_t = void*; + using ulongptr_t = std::uintptr_t; + using longptr_t = std::intptr_t; + using dword_t = unsigned long; + using socket_t = std::uintptr_t; + using ulong_t = unsigned long; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4201) // Non-standard anonymous struct/union +#endif + + /// Structure needs to correspond exactly to the builtin + /// _OVERLAPPED structure from Windows.h. + struct overlapped + { + ulongptr_t Internal; + ulongptr_t InternalHigh; + union + { + struct + { + dword_t Offset; + dword_t OffsetHigh; + }; + void* Pointer; + }; + handle_t hEvent; + }; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + + struct wsabuf + { + constexpr wsabuf() noexcept + : len(0) + , buf(nullptr) + {} + + constexpr wsabuf(void* ptr, std::size_t size) + : len(size <= ulong_t(-1) ? ulong_t(size) : ulong_t(-1)) + , buf(static_cast(ptr)) + {} + + ulong_t len; + char* buf; + }; + + struct io_state : win32::overlapped + { + using callback_type = void( + io_state* state, + win32::dword_t errorCode, + win32::dword_t numberOfBytesTransferred, + win32::ulongptr_t completionKey); + + io_state(callback_type* callback = nullptr) noexcept + : io_state(std::uint64_t(0), callback) + {} + + io_state(void* pointer, callback_type* callback) noexcept + : m_callback(callback) + { + this->Internal = 0; + this->InternalHigh = 0; + this->Pointer = pointer; + this->hEvent = nullptr; + } + + io_state(std::uint64_t offset, callback_type* callback) noexcept + : m_callback(callback) + { + this->Internal = 0; + this->InternalHigh = 0; + this->Offset = static_cast(offset); + this->OffsetHigh = static_cast(offset >> 32); + this->hEvent = nullptr; + } + + callback_type* m_callback; + }; + + class safe_handle + { + public: + + safe_handle() + : m_handle(nullptr) + {} + + explicit safe_handle(handle_t handle) + : m_handle(handle) + {} + + safe_handle(const safe_handle& other) = delete; + + safe_handle(safe_handle&& other) noexcept + : m_handle(other.m_handle) + { + other.m_handle = nullptr; + } + + ~safe_handle() + { + close(); + } + + safe_handle& operator=(safe_handle handle) noexcept + { + swap(handle); + return *this; + } + + constexpr handle_t handle() const { return m_handle; } + + /// Calls CloseHandle() and sets the handle to NULL. + void close() noexcept; + + void swap(safe_handle& other) noexcept + { + std::swap(m_handle, other.m_handle); + } + + bool operator==(const safe_handle& other) const + { + return m_handle == other.m_handle; + } + + bool operator!=(const safe_handle& other) const + { + return m_handle != other.m_handle; + } + + bool operator==(handle_t handle) const + { + return m_handle == handle; + } + + bool operator!=(handle_t handle) const + { + return m_handle != handle; + } + + private: + + handle_t m_handle; + + }; + } + } +} + +#endif diff --git a/include/cppcoro/detail/win32_overlapped_operation.hpp b/include/cppcoro/detail/win32_overlapped_operation.hpp new file mode 100644 index 0000000..82cc04c --- /dev/null +++ b/include/cppcoro/detail/win32_overlapped_operation.hpp @@ -0,0 +1,376 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_DETAIL_WIN32_OVERLAPPED_OPERATION_HPP_INCLUDED +#define CPPCORO_DETAIL_WIN32_OVERLAPPED_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + class win32_overlapped_operation_base + : protected detail::win32::io_state + { + public: + + win32_overlapped_operation_base( + detail::win32::io_state::callback_type* callback) noexcept + : detail::win32::io_state(callback) + , m_errorCode(0) + , m_numberOfBytesTransferred(0) + {} + + win32_overlapped_operation_base( + void* pointer, + detail::win32::io_state::callback_type* callback) noexcept + : detail::win32::io_state(pointer, callback) + , m_errorCode(0) + , m_numberOfBytesTransferred(0) + {} + + win32_overlapped_operation_base( + std::uint64_t offset, + detail::win32::io_state::callback_type* callback) noexcept + : detail::win32::io_state(offset, callback) + , m_errorCode(0) + , m_numberOfBytesTransferred(0) + {} + + _OVERLAPPED* get_overlapped() noexcept + { + return reinterpret_cast<_OVERLAPPED*>( + static_cast(this)); + } + + std::size_t get_result() + { + if (m_errorCode != 0) + { + throw std::system_error{ + static_cast(m_errorCode), + std::system_category() + }; + } + + return m_numberOfBytesTransferred; + } + + detail::win32::dword_t m_errorCode; + detail::win32::dword_t m_numberOfBytesTransferred; + + }; + + template + class win32_overlapped_operation + : protected win32_overlapped_operation_base + { + protected: + + win32_overlapped_operation() noexcept + : win32_overlapped_operation_base( + &win32_overlapped_operation::on_operation_completed) + {} + + win32_overlapped_operation(void* pointer) noexcept + : win32_overlapped_operation_base( + pointer, + &win32_overlapped_operation::on_operation_completed) + {} + + win32_overlapped_operation(std::uint64_t offset) noexcept + : win32_overlapped_operation_base( + offset, + &win32_overlapped_operation::on_operation_completed) + {} + + public: + + bool await_ready() const noexcept { return false; } + + CPPCORO_NOINLINE + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) + { + static_assert(std::is_base_of_v); + + m_awaitingCoroutine = awaitingCoroutine; + return static_cast(this)->try_start(); + } + + decltype(auto) await_resume() + { + return static_cast(this)->get_result(); + } + + private: + + static void on_operation_completed( + detail::win32::io_state* ioState, + detail::win32::dword_t errorCode, + detail::win32::dword_t numberOfBytesTransferred, + [[maybe_unused]] detail::win32::ulongptr_t completionKey) noexcept + { + auto* operation = static_cast(ioState); + operation->m_errorCode = errorCode; + operation->m_numberOfBytesTransferred = numberOfBytesTransferred; + operation->m_awaitingCoroutine.resume(); + } + + cppcoro::coroutine_handle<> m_awaitingCoroutine; + + }; + + template + class win32_overlapped_operation_cancellable + : protected win32_overlapped_operation_base + { + // ERROR_OPERATION_ABORTED value from + static constexpr detail::win32::dword_t error_operation_aborted = 995L; + + protected: + + win32_overlapped_operation_cancellable(cancellation_token&& ct) noexcept + : win32_overlapped_operation_base(&win32_overlapped_operation_cancellable::on_operation_completed) + , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started) + , m_cancellationToken(std::move(ct)) + { + m_errorCode = error_operation_aborted; + } + + win32_overlapped_operation_cancellable( + void* pointer, + cancellation_token&& ct) noexcept + : win32_overlapped_operation_base(pointer, &win32_overlapped_operation_cancellable::on_operation_completed) + , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started) + , m_cancellationToken(std::move(ct)) + { + m_errorCode = error_operation_aborted; + } + + win32_overlapped_operation_cancellable( + std::uint64_t offset, + cancellation_token&& ct) noexcept + : win32_overlapped_operation_base(offset, &win32_overlapped_operation_cancellable::on_operation_completed) + , m_state(ct.is_cancellation_requested() ? state::completed : state::not_started) + , m_cancellationToken(std::move(ct)) + { + m_errorCode = error_operation_aborted; + } + + win32_overlapped_operation_cancellable( + win32_overlapped_operation_cancellable&& other) noexcept + : win32_overlapped_operation_base(std::move(other)) + , m_state(other.m_state.load(std::memory_order_relaxed)) + , m_cancellationToken(std::move(other.m_cancellationToken)) + { + assert(m_errorCode == other.m_errorCode); + assert(m_numberOfBytesTransferred == other.m_numberOfBytesTransferred); + } + + public: + + bool await_ready() const noexcept + { + return m_state.load(std::memory_order_relaxed) == state::completed; + } + + CPPCORO_NOINLINE + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) + { + static_assert(std::is_base_of_v); + + m_awaitingCoroutine = awaitingCoroutine; + + // TRICKY: Register cancellation callback before starting the operation + // in case the callback registration throws due to insufficient + // memory. We need to make sure that the logic that occurs after + // starting the operation is noexcept, otherwise we run into the + // problem of not being able to cancel the started operation and + // the dilemma of what to do with the exception. + // + // However, doing this means that the cancellation callback may run + // prior to returning below so in the case that cancellation may + // occur we defer setting the state to 'started' until after + // the operation has finished starting. The cancellation callback + // will only attempt to request cancellation of the operation with + // CancelIoEx() once the state has been set to 'started'. + const bool canBeCancelled = m_cancellationToken.can_be_cancelled(); + if (canBeCancelled) + { + m_cancellationCallback.emplace( + std::move(m_cancellationToken), + [this] { this->on_cancellation_requested(); }); + } + else + { + m_state.store(state::started, std::memory_order_relaxed); + } + + // Now start the operation. + const bool willCompleteAsynchronously = static_cast(this)->try_start(); + if (!willCompleteAsynchronously) + { + // Operation completed synchronously, resume awaiting coroutine immediately. + return false; + } + + if (canBeCancelled) + { + // Need to flag that the operation has finished starting now. + + // However, the operation may have completed concurrently on + // another thread, transitioning directly from not_started -> complete. + // Or it may have had the cancellation callback execute and transition + // from not_started -> cancellation_requested. We use a compare-exchange + // to determine a winner between these potential racing cases. + state oldState = state::not_started; + if (!m_state.compare_exchange_strong( + oldState, + state::started, + std::memory_order_release, + std::memory_order_acquire)) + { + if (oldState == state::cancellation_requested) + { + // Request the operation be cancelled. + // Note that it may have already completed on a background + // thread by now so this request for cancellation may end up + // being ignored. + static_cast(this)->cancel(); + + if (!m_state.compare_exchange_strong( + oldState, + state::started, + std::memory_order_release, + std::memory_order_acquire)) + { + assert(oldState == state::completed); + return false; + } + } + else + { + assert(oldState == state::completed); + return false; + } + } + } + + return true; + } + + decltype(auto) await_resume() + { + // Free memory used by the cancellation callback now that the operation + // has completed rather than waiting until the operation object destructs. + // eg. If the operation is passed to when_all() then the operation object + // may not be destructed until all of the operations complete. + m_cancellationCallback.reset(); + + if (m_errorCode == error_operation_aborted) + { + throw operation_cancelled{}; + } + + return static_cast(this)->get_result(); + } + + private: + + enum class state + { + not_started, + started, + cancellation_requested, + completed + }; + + void on_cancellation_requested() noexcept + { + auto oldState = m_state.load(std::memory_order_acquire); + if (oldState == state::not_started) + { + // This callback is running concurrently with await_suspend(). + // The call to start the operation may not have returned yet so + // we can't safely request cancellation of it. Instead we try to + // notify the await_suspend() thread by transitioning the state + // to state::cancellation_requested so that the await_suspend() + // thread can request cancellation after it has finished starting + // the operation. + const bool transferredCancelResponsibility = + m_state.compare_exchange_strong( + oldState, + state::cancellation_requested, + std::memory_order_release, + std::memory_order_acquire); + if (transferredCancelResponsibility) + { + return; + } + } + + // No point requesting cancellation if the operation has already completed. + if (oldState != state::completed) + { + static_cast(this)->cancel(); + } + } + + static void on_operation_completed( + detail::win32::io_state* ioState, + detail::win32::dword_t errorCode, + detail::win32::dword_t numberOfBytesTransferred, + [[maybe_unused]] detail::win32::ulongptr_t completionKey) noexcept + { + auto* operation = static_cast(ioState); + + operation->m_errorCode = errorCode; + operation->m_numberOfBytesTransferred = numberOfBytesTransferred; + + auto state = operation->m_state.load(std::memory_order_acquire); + if (state == state::started) + { + operation->m_state.store(state::completed, std::memory_order_relaxed); + operation->m_awaitingCoroutine.resume(); + } + else + { + // We are racing with await_suspend() call suspending. + // Try to mark it as completed using an atomic exchange and look + // at the previous value to determine whether the coroutine suspended + // first (in which case we resume it now) or we marked it as completed + // first (in which case await_suspend() will return false and immediately + // resume the coroutine). + state = operation->m_state.exchange( + state::completed, + std::memory_order_acq_rel); + if (state == state::started) + { + // The await_suspend() method returned (or will return) 'true' and so + // we need to resume the coroutine. + operation->m_awaitingCoroutine.resume(); + } + } + } + + std::atomic m_state; + cppcoro::cancellation_token m_cancellationToken; + std::optional m_cancellationCallback; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + + }; + } +} + +#endif diff --git a/include/cppcoro/file.hpp b/include/cppcoro/file.hpp new file mode 100644 index 0000000..1366af8 --- /dev/null +++ b/include/cppcoro/file.hpp @@ -0,0 +1,54 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_HPP_INCLUDED +#define CPPCORO_FILE_HPP_INCLUDED + +#include + +#include +#include +#include + +#if CPPCORO_OS_WINNT +# include +#endif + +#include + +namespace cppcoro +{ + class io_service; + + class file + { + public: + + file(file&& other) noexcept = default; + + virtual ~file(); + + /// Get the size of the file in bytes. + std::uint64_t size() const; + + protected: + +#if CPPCORO_OS_WINNT + file(detail::win32::safe_handle&& fileHandle) noexcept; + + static detail::win32::safe_handle open( + detail::win32::dword_t fileAccess, + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode, + file_share_mode shareMode, + file_buffering_mode bufferingMode); + + detail::win32::safe_handle m_fileHandle; +#endif + + }; +} + +#endif diff --git a/include/cppcoro/file_buffering_mode.hpp b/include/cppcoro/file_buffering_mode.hpp new file mode 100644 index 0000000..40ddc3f --- /dev/null +++ b/include/cppcoro/file_buffering_mode.hpp @@ -0,0 +1,32 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_BUFFERING_MODE_HPP_INCLUDED +#define CPPCORO_FILE_BUFFERING_MODE_HPP_INCLUDED + +namespace cppcoro +{ + enum class file_buffering_mode + { + default_ = 0, + sequential = 1, + random_access = 2, + unbuffered = 4, + write_through = 8, + temporary = 16 + }; + + constexpr file_buffering_mode operator&(file_buffering_mode a, file_buffering_mode b) + { + return static_cast( + static_cast(a) & static_cast(b)); + } + + constexpr file_buffering_mode operator|(file_buffering_mode a, file_buffering_mode b) + { + return static_cast(static_cast(a) | static_cast(b)); + } +} + +#endif diff --git a/include/cppcoro/file_open_mode.hpp b/include/cppcoro/file_open_mode.hpp new file mode 100644 index 0000000..e76f699 --- /dev/null +++ b/include/cppcoro/file_open_mode.hpp @@ -0,0 +1,40 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_OPEN_MODE_HPP_INCLUDED +#define CPPCORO_FILE_OPEN_MODE_HPP_INCLUDED + +namespace cppcoro +{ + enum class file_open_mode + { + /// Open an existing file. + /// + /// If file does not already exist when opening the file then raises + /// an exception. + open_existing, + + /// Create a new file, overwriting an existing file if one exists. + /// + /// If a file exists at the path then it is overwitten with a new file. + /// If no file exists at the path then a new one is created. + create_always, + + /// Create a new file. + /// + /// If the file already exists then raises an exception. + create_new, + + /// Open the existing file if one exists, otherwise create a new empty + /// file. + create_or_open, + + /// Open the existing file, truncating the file size to zero. + /// + /// If the file does not exist then raises an exception. + truncate_existing + }; +} + +#endif diff --git a/include/cppcoro/file_read_operation.hpp b/include/cppcoro/file_read_operation.hpp new file mode 100644 index 0000000..4116c2f --- /dev/null +++ b/include/cppcoro/file_read_operation.hpp @@ -0,0 +1,99 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_READ_OPERATION_HPP_INCLUDED +#define CPPCORO_FILE_READ_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro +{ + class file_read_operation_impl + { + public: + + file_read_operation_impl( + detail::win32::handle_t fileHandle, + void* buffer, + std::size_t byteCount) noexcept + : m_fileHandle(fileHandle) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + + private: + + detail::win32::handle_t m_fileHandle; + void* m_buffer; + std::size_t m_byteCount; + + }; + + class file_read_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + file_read_operation( + detail::win32::handle_t fileHandle, + std::uint64_t fileOffset, + void* buffer, + std::size_t byteCount) noexcept + : cppcoro::detail::win32_overlapped_operation(fileOffset) + , m_impl(fileHandle, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + file_read_operation_impl m_impl; + + }; + + class file_read_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + file_read_operation_cancellable( + detail::win32::handle_t fileHandle, + std::uint64_t fileOffset, + void* buffer, + std::size_t byteCount, + cancellation_token&& cancellationToken) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable( + fileOffset, std::move(cancellationToken)) + , m_impl(fileHandle, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + + file_read_operation_impl m_impl; + + }; + +#endif +} + +#endif diff --git a/include/cppcoro/file_share_mode.hpp b/include/cppcoro/file_share_mode.hpp new file mode 100644 index 0000000..d7679a5 --- /dev/null +++ b/include/cppcoro/file_share_mode.hpp @@ -0,0 +1,45 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_SHARE_MODE_HPP_INCLUDED +#define CPPCORO_FILE_SHARE_MODE_HPP_INCLUDED + +namespace cppcoro +{ + enum class file_share_mode + { + /// Don't allow any other processes to open the file concurrently. + none = 0, + + /// Allow other processes to open the file in read-only mode + /// concurrently with this process opening the file. + read = 1, + + /// Allow other processes to open the file in write-only mode + /// concurrently with this process opening the file. + write = 2, + + /// Allow other processes to open the file in read and/or write mode + /// concurrently with this process opening the file. + read_write = read | write, + + /// Allow other processes to delete the file while this process + /// has the file open. + delete_ = 4 + }; + + constexpr file_share_mode operator|(file_share_mode a, file_share_mode b) + { + return static_cast( + static_cast(a) | static_cast(b)); + } + + constexpr file_share_mode operator&(file_share_mode a, file_share_mode b) + { + return static_cast( + static_cast(a) & static_cast(b)); + } +} + +#endif diff --git a/include/cppcoro/file_write_operation.hpp b/include/cppcoro/file_write_operation.hpp new file mode 100644 index 0000000..40f6854 --- /dev/null +++ b/include/cppcoro/file_write_operation.hpp @@ -0,0 +1,98 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FILE_WRITE_OPERATION_HPP_INCLUDED +#define CPPCORO_FILE_WRITE_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro +{ + class file_write_operation_impl + { + public: + + file_write_operation_impl( + detail::win32::handle_t fileHandle, + const void* buffer, + std::size_t byteCount) noexcept + : m_fileHandle(fileHandle) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + + private: + + detail::win32::handle_t m_fileHandle; + const void* m_buffer; + std::size_t m_byteCount; + + }; + + class file_write_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + file_write_operation( + detail::win32::handle_t fileHandle, + std::uint64_t fileOffset, + const void* buffer, + std::size_t byteCount) noexcept + : cppcoro::detail::win32_overlapped_operation(fileOffset) + , m_impl(fileHandle, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + file_write_operation_impl m_impl; + + }; + + class file_write_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + file_write_operation_cancellable( + detail::win32::handle_t fileHandle, + std::uint64_t fileOffset, + const void* buffer, + std::size_t byteCount, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(fileOffset, std::move(ct)) + , m_impl(fileHandle, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + + file_write_operation_impl m_impl; + + }; +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/filesystem.hpp b/include/cppcoro/filesystem.hpp new file mode 100644 index 0000000..798b8f7 --- /dev/null +++ b/include/cppcoro/filesystem.hpp @@ -0,0 +1,24 @@ +#ifndef CPPCORO_FILESYSTEM_HPP_INCLUDED +#define CPPCORO_FILESYSTEM_HPP_INCLUDED + +#if __has_include() + +#include + +namespace cppcoro { + namespace filesystem = std::filesystem; +} + +#elif __has_include() + +#include + +namespace cppcoro { + namespace filesystem = std::experimental::filesystem; +} + +#else +#error Cppcoro requires a C++20 compiler with filesystem support +#endif + +#endif diff --git a/include/cppcoro/fmap.hpp b/include/cppcoro/fmap.hpp new file mode 100644 index 0000000..6c4e802 --- /dev/null +++ b/include/cppcoro/fmap.hpp @@ -0,0 +1,169 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_FMAP_HPP_INCLUDED +#define CPPCORO_FMAP_HPP_INCLUDED + +#include +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + template + class fmap_awaiter + { + using awaiter_t = typename awaitable_traits::awaiter_t; + FUNC&& m_func; + awaiter_t m_awaiter; + + public: + + fmap_awaiter(FUNC&& func, AWAITABLE&& awaitable) + noexcept( + std::is_nothrow_move_constructible_v && + noexcept(detail::get_awaiter(static_cast(awaitable)))) + : m_func(static_cast(func)) + , m_awaiter(detail::get_awaiter(static_cast(awaitable))) + {} + + decltype(auto) await_ready() + noexcept(noexcept(static_cast(m_awaiter).await_ready())) + { + return static_cast(m_awaiter).await_ready(); + } + + template + decltype(auto) await_suspend(cppcoro::coroutine_handle coro) + noexcept(noexcept(static_cast(m_awaiter).await_suspend(std::move(coro)))) + { + return static_cast(m_awaiter).await_suspend(std::move(coro)); + } + + template< + typename AWAIT_RESULT = decltype(std::declval().await_resume()), + std::enable_if_t, int> = 0> + decltype(auto) await_resume() + noexcept(noexcept(std::invoke(static_cast(m_func)))) + { + static_cast(m_awaiter).await_resume(); + return std::invoke(static_cast(m_func)); + } + + template< + typename AWAIT_RESULT = decltype(std::declval().await_resume()), + std::enable_if_t, int> = 0> + decltype(auto) await_resume() + noexcept(noexcept(std::invoke(static_cast(m_func), static_cast(m_awaiter).await_resume()))) + { + return std::invoke( + static_cast(m_func), + static_cast(m_awaiter).await_resume()); + } + }; + + template + class fmap_awaitable + { + static_assert(!std::is_lvalue_reference_v); + static_assert(!std::is_lvalue_reference_v); + public: + + template< + typename FUNC_ARG, + typename AWAITABLE_ARG, + std::enable_if_t< + std::is_constructible_v && + std::is_constructible_v, int> = 0> + explicit fmap_awaitable(FUNC_ARG&& func, AWAITABLE_ARG&& awaitable) + noexcept( + std::is_nothrow_constructible_v && + std::is_nothrow_constructible_v) + : m_func(static_cast(func)) + , m_awaitable(static_cast(awaitable)) + {} + + auto operator co_await() const & + { + return fmap_awaiter(m_func, m_awaitable); + } + + auto operator co_await() & + { + return fmap_awaiter(m_func, m_awaitable); + } + + auto operator co_await() && + { + return fmap_awaiter( + static_cast(m_func), + static_cast(m_awaitable)); + } + + private: + + FUNC m_func; + AWAITABLE m_awaitable; + + }; + } + + template + struct fmap_transform + { + explicit fmap_transform(FUNC&& f) + noexcept(std::is_nothrow_move_constructible_v) + : func(std::forward(f)) + {} + + FUNC func; + }; + + template< + typename FUNC, + typename AWAITABLE, + std::enable_if_t, int> = 0> + auto fmap(FUNC&& func, AWAITABLE&& awaitable) + { + return detail::fmap_awaitable< + std::remove_cv_t>, + std::remove_cv_t>>( + std::forward(func), + std::forward(awaitable)); + } + + template + auto fmap(FUNC&& func) + { + return fmap_transform{ std::forward(func) }; + } + + template + decltype(auto) operator|(T&& value, fmap_transform&& transform) + { + // Use ADL for finding fmap() overload. + return fmap(std::forward(transform.func), std::forward(value)); + } + + template + decltype(auto) operator|(T&& value, const fmap_transform& transform) + { + // Use ADL for finding fmap() overload. + return fmap(transform.func, std::forward(value)); + } + + template + decltype(auto) operator|(T&& value, fmap_transform& transform) + { + // Use ADL for finding fmap() overload. + return fmap(transform.func, std::forward(value)); + } +} + +#endif diff --git a/include/cppcoro/generator.hpp b/include/cppcoro/generator.hpp new file mode 100644 index 0000000..7b5df49 --- /dev/null +++ b/include/cppcoro/generator.hpp @@ -0,0 +1,260 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_GENERATOR_HPP_INCLUDED +#define CPPCORO_GENERATOR_HPP_INCLUDED + +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + template + class generator; + + namespace detail + { + template + class generator_promise + { + public: + + using value_type = std::remove_reference_t; + using reference_type = std::conditional_t, T, T&>; + using pointer_type = value_type*; + + generator_promise() = default; + + generator get_return_object() noexcept; + + constexpr cppcoro::suspend_always initial_suspend() const noexcept { return {}; } + constexpr cppcoro::suspend_always final_suspend() const noexcept { return {}; } + + template< + typename U = T, + std::enable_if_t::value, int> = 0> + cppcoro::suspend_always yield_value(std::remove_reference_t& value) noexcept + { + m_value = std::addressof(value); + return {}; + } + + cppcoro::suspend_always yield_value(std::remove_reference_t&& value) noexcept + { + m_value = std::addressof(value); + return {}; + } + + void unhandled_exception() + { + m_exception = std::current_exception(); + } + + void return_void() + { + } + + reference_type value() const noexcept + { + return static_cast(*m_value); + } + + // Don't allow any use of 'co_await' inside the generator coroutine. + template + cppcoro::suspend_never await_transform(U&& value) = delete; + + void rethrow_if_exception() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + private: + + pointer_type m_value; + std::exception_ptr m_exception; + + }; + + struct generator_sentinel {}; + + template + class generator_iterator + { + using coroutine_handle = cppcoro::coroutine_handle>; + + public: + + using iterator_category = std::input_iterator_tag; + // What type should we use for counting elements of a potentially infinite sequence? + using difference_type = std::ptrdiff_t; + using value_type = typename generator_promise::value_type; + using reference = typename generator_promise::reference_type; + using pointer = typename generator_promise::pointer_type; + + // Iterator needs to be default-constructible to satisfy the Range concept. + generator_iterator() noexcept + : m_coroutine(nullptr) + {} + + explicit generator_iterator(coroutine_handle coroutine) noexcept + : m_coroutine(coroutine) + {} + + friend bool operator==(const generator_iterator& it, generator_sentinel) noexcept + { + return !it.m_coroutine || it.m_coroutine.done(); + } + + friend bool operator!=(const generator_iterator& it, generator_sentinel s) noexcept + { + return !(it == s); + } + + friend bool operator==(generator_sentinel s, const generator_iterator& it) noexcept + { + return (it == s); + } + + friend bool operator!=(generator_sentinel s, const generator_iterator& it) noexcept + { + return it != s; + } + + generator_iterator& operator++() + { + m_coroutine.resume(); + if (m_coroutine.done()) + { + m_coroutine.promise().rethrow_if_exception(); + } + + return *this; + } + + // Need to provide post-increment operator to implement the 'Range' concept. + void operator++(int) + { + (void)operator++(); + } + + reference operator*() const noexcept + { + return m_coroutine.promise().value(); + } + + pointer operator->() const noexcept + { + return std::addressof(operator*()); + } + + private: + + coroutine_handle m_coroutine; + }; + } + + template + class [[nodiscard]] generator + { + public: + + using promise_type = detail::generator_promise; + using iterator = detail::generator_iterator; + + generator() noexcept + : m_coroutine(nullptr) + {} + + generator(generator&& other) noexcept + : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + generator(const generator& other) = delete; + + ~generator() + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + } + + generator& operator=(generator other) noexcept + { + swap(other); + return *this; + } + + iterator begin() + { + if (m_coroutine) + { + m_coroutine.resume(); + if (m_coroutine.done()) + { + m_coroutine.promise().rethrow_if_exception(); + } + } + + return iterator{ m_coroutine }; + } + + detail::generator_sentinel end() noexcept + { + return detail::generator_sentinel{}; + } + + void swap(generator& other) noexcept + { + std::swap(m_coroutine, other.m_coroutine); + } + + private: + + friend class detail::generator_promise; + + explicit generator(cppcoro::coroutine_handle coroutine) noexcept + : m_coroutine(coroutine) + {} + + cppcoro::coroutine_handle m_coroutine; + + }; + + template + void swap(generator& a, generator& b) + { + a.swap(b); + } + + namespace detail + { + template + generator generator_promise::get_return_object() noexcept + { + using coroutine_handle = cppcoro::coroutine_handle>; + return generator{ coroutine_handle::from_promise(*this) }; + } + } + + template + generator::iterator::reference>> fmap(FUNC func, generator source) + { + for (auto&& value : source) + { + co_yield std::invoke(func, static_cast(value)); + } + } +} + +#endif diff --git a/include/cppcoro/inline_scheduler.hpp b/include/cppcoro/inline_scheduler.hpp new file mode 100644 index 0000000..a0862a0 --- /dev/null +++ b/include/cppcoro/inline_scheduler.hpp @@ -0,0 +1,25 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_INLINE_SCHEDULER_HPP_INCLUDED +#define CPPCORO_INLINE_SCHEDULER_HPP_INCLUDED + +#include + +namespace cppcoro +{ + class inline_scheduler + { + public: + + inline_scheduler() noexcept = default; + + cppcoro::suspend_never schedule() const noexcept + { + return {}; + } + }; +} + +#endif diff --git a/include/cppcoro/io_service.hpp b/include/cppcoro/io_service.hpp new file mode 100644 index 0000000..b6b4dbd --- /dev/null +++ b/include/cppcoro/io_service.hpp @@ -0,0 +1,321 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_IO_SERVICE_HPP_INCLUDED +#define CPPCORO_IO_SERVICE_HPP_INCLUDED + +#include +#include +#include + +#if CPPCORO_OS_WINNT +# include +#endif + +#include +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + class io_service + { + public: + + class schedule_operation; + class timed_schedule_operation; + + /// Initialises the io_service. + /// + /// Does not set a concurrency hint. All threads that enter the + /// event loop will actively process events. + io_service(); + + /// Initialise the io_service with a concurrency hint. + /// + /// \param concurrencyHint + /// Specifies the target maximum number of I/O threads to be + /// actively processing events. + /// Note that the number of active threads may temporarily go + /// above this number. + io_service(std::uint32_t concurrencyHint); + + ~io_service(); + + io_service(io_service&& other) = delete; + io_service(const io_service& other) = delete; + io_service& operator=(io_service&& other) = delete; + io_service& operator=(const io_service& other) = delete; + + /// Returns an operation that when awaited suspends the awaiting + /// coroutine and reschedules it for resumption on an I/O thread + /// associated with this io_service. + [[nodiscard]] + schedule_operation schedule() noexcept; + + /// Returns an operation that when awaited will suspend the + /// awaiting coroutine for the specified delay. Once the delay + /// has elapsed, the coroutine will resume execution on an + /// I/O thread associated with this io_service. + /// + /// \param delay + /// The amount of time to delay scheduling resumption of the coroutine + /// on an I/O thread. There is no guarantee that the coroutine will + /// be resumed exactly after this delay. + /// + /// \param cancellationToken [optional] + /// A cancellation token that can be used to communicate a request to + /// cancel the delayed schedule operation and schedule it for resumption + /// immediately. + /// The co_await operation will throw cppcoro::operation_cancelled if + /// cancellation was requested before the coroutine could be resumed. + template + [[nodiscard]] + timed_schedule_operation schedule_after( + const std::chrono::duration& delay, + cancellation_token cancellationToken = {}) noexcept; + + /// Process events until the io_service is stopped. + /// + /// \return + /// The number of events processed during this call. + std::uint64_t process_events(); + + /// Process events until either the io_service is stopped or + /// there are no more pending events in the queue. + /// + /// \return + /// The number of events processed during this call. + std::uint64_t process_pending_events(); + + /// Block until either one event is processed or the io_service is stopped. + /// + /// \return + /// The number of events processed during this call. + /// This will either be 0 or 1. + std::uint64_t process_one_event(); + + /// Process one event if there are any events pending, otherwise if there + /// are no events pending or the io_service is stopped then return immediately. + /// + /// \return + /// The number of events processed during this call. + /// This will either be 0 or 1. + std::uint64_t process_one_pending_event(); + + /// Shut down the io_service. + /// + /// This will cause any threads currently in a call to one of the process_xxx() methods + /// to return from that call once they finish processing the current event. + /// + /// This call does not wait until all threads have exited the event loop so you + /// must use other synchronisation mechanisms to wait for those threads. + void stop() noexcept; + + /// Reset an io_service to prepare it for resuming processing of events. + /// + /// Call this after a call to stop() to allow calls to process_xxx() methods + /// to process events. + /// + /// After calling stop() you should ensure that all threads have returned from + /// calls to process_xxx() methods before calling reset(). + void reset(); + + bool is_stop_requested() const noexcept; + + void notify_work_started() noexcept; + + void notify_work_finished() noexcept; + +#if CPPCORO_OS_WINNT + detail::win32::handle_t native_iocp_handle() noexcept; + void ensure_winsock_initialised(); +#endif + + private: + + class timer_thread_state; + class timer_queue; + + friend class schedule_operation; + friend class timed_schedule_operation; + + void schedule_impl(schedule_operation* operation) noexcept; + + void try_reschedule_overflow_operations() noexcept; + + bool try_enter_event_loop() noexcept; + void exit_event_loop() noexcept; + + bool try_process_one_event(bool waitForEvent); + + void post_wake_up_event() noexcept; + + timer_thread_state* ensure_timer_thread_started(); + + static constexpr std::uint32_t stop_requested_flag = 1; + static constexpr std::uint32_t active_thread_count_increment = 2; + + // Bit 0: stop_requested_flag + // Bit 1-31: count of active threads currently running the event loop + std::atomic m_threadState; + + std::atomic m_workCount; + +#if CPPCORO_OS_WINNT + detail::win32::safe_handle m_iocpHandle; + + std::atomic m_winsockInitialised; + std::mutex m_winsockInitialisationMutex; +#endif + + // Head of a linked-list of schedule operations that are + // ready to run but that failed to be queued to the I/O + // completion port (eg. due to low memory). + std::atomic m_scheduleOperations; + + std::atomic m_timerState; + + }; + + class io_service::schedule_operation + { + public: + + schedule_operation(io_service& service) noexcept + : m_service(service) + {} + + bool await_ready() const noexcept { return false; } + void await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept; + void await_resume() const noexcept {} + + private: + + friend class io_service; + friend class io_service::timed_schedule_operation; + + io_service& m_service; + cppcoro::coroutine_handle<> m_awaiter; + schedule_operation* m_next; + + }; + + class io_service::timed_schedule_operation + { + public: + + timed_schedule_operation( + io_service& service, + std::chrono::high_resolution_clock::time_point resumeTime, + cppcoro::cancellation_token cancellationToken) noexcept; + + timed_schedule_operation(timed_schedule_operation&& other) noexcept; + + ~timed_schedule_operation(); + + timed_schedule_operation& operator=(timed_schedule_operation&& other) = delete; + timed_schedule_operation(const timed_schedule_operation& other) = delete; + timed_schedule_operation& operator=(const timed_schedule_operation& other) = delete; + + bool await_ready() const noexcept; + void await_suspend(cppcoro::coroutine_handle<> awaiter); + void await_resume(); + + private: + + friend class io_service::timer_queue; + friend class io_service::timer_thread_state; + + io_service::schedule_operation m_scheduleOperation; + std::chrono::high_resolution_clock::time_point m_resumeTime; + + cppcoro::cancellation_token m_cancellationToken; + std::optional m_cancellationRegistration; + + timed_schedule_operation* m_next; + + std::atomic m_refCount; + + }; + + class io_work_scope + { + public: + + explicit io_work_scope(io_service& service) noexcept + : m_service(&service) + { + service.notify_work_started(); + } + + io_work_scope(const io_work_scope& other) noexcept + : m_service(other.m_service) + { + if (m_service != nullptr) + { + m_service->notify_work_started(); + } + } + + io_work_scope(io_work_scope&& other) noexcept + : m_service(other.m_service) + { + other.m_service = nullptr; + } + + ~io_work_scope() + { + if (m_service != nullptr) + { + m_service->notify_work_finished(); + } + } + + void swap(io_work_scope& other) noexcept + { + std::swap(m_service, other.m_service); + } + + io_work_scope& operator=(io_work_scope other) noexcept + { + swap(other); + return *this; + } + + io_service& service() noexcept + { + return *m_service; + } + + private: + + io_service* m_service; + + }; + + inline void swap(io_work_scope& a, io_work_scope& b) + { + a.swap(b); + } +} + +template +cppcoro::io_service::timed_schedule_operation +cppcoro::io_service::schedule_after( + const std::chrono::duration& duration, + cppcoro::cancellation_token cancellationToken) noexcept +{ + return timed_schedule_operation{ + *this, + std::chrono::high_resolution_clock::now() + duration, + std::move(cancellationToken) + }; +} + +#endif diff --git a/include/cppcoro/is_awaitable.hpp b/include/cppcoro/is_awaitable.hpp new file mode 100644 index 0000000..ea9bf21 --- /dev/null +++ b/include/cppcoro/is_awaitable.hpp @@ -0,0 +1,26 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_IS_AWAITABLE_HPP_INCLUDED +#define CPPCORO_IS_AWAITABLE_HPP_INCLUDED + +#include + +#include + +namespace cppcoro +{ + template> + struct is_awaitable : std::false_type {}; + + template + struct is_awaitable()))>> + : std::true_type + {}; + + template + constexpr bool is_awaitable_v = is_awaitable::value; +} + +#endif diff --git a/include/cppcoro/multi_producer_sequencer.hpp b/include/cppcoro/multi_producer_sequencer.hpp new file mode 100644 index 0000000..c656594 --- /dev/null +++ b/include/cppcoro/multi_producer_sequencer.hpp @@ -0,0 +1,829 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_MULTI_PRODUCER_SEQUENCER_HPP_INCLUDED +#define CPPCORO_MULTI_PRODUCER_SEQUENCER_HPP_INCLUDED + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace cppcoro +{ + template + class multi_producer_sequencer_claim_one_operation; + + template + class multi_producer_sequencer_claim_operation; + + template + class multi_producer_sequencer_wait_operation_base; + + template + class multi_producer_sequencer_wait_operation; + + /// A multi-producer sequencer is a thread-synchronisation primitive that can be + /// used to synchronise access to a ring-buffer of power-of-two size where you + /// have multiple producers concurrently claiming slots in the ring-buffer and + /// publishing items. + /// + /// When a writer wants to write to a slot in the buffer it first atomically + /// increments a counter by the number of slots it wishes to allocate. + /// It then waits until all of those slots have become available and then + /// returns the range of sequence numbers allocated back to the caller. + /// The caller then writes to those slots and when done publishes them by + /// writing the sequence numbers published to each of the slots to the + /// corresponding element of an array of equal size to the ring buffer. + /// When a reader wants to check if the next sequence number is available + /// it then simply needs to read from the corresponding slot in this array + /// to check if the value stored there is equal to the sequence number it + /// is wanting to read. + /// + /// This means concurrent writers are wait-free when there is space available + /// in the ring buffer, requiring a single atomic fetch-add operation as the + /// only contended write operation. All other writes are to memory locations + /// owned by a particular writer. Concurrent writers can publish items out of + /// order so that one writer does not hold up other writers until the ring + /// buffer fills up. + template< + typename SEQUENCE = std::size_t, + typename TRAITS = sequence_traits> + class multi_producer_sequencer + { + public: + + multi_producer_sequencer( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + SEQUENCE initialSequence = TRAITS::initial_sequence); + + /// The size of the circular buffer. This will be a power-of-two. + std::size_t buffer_size() const noexcept { return m_sequenceMask + 1; } + + /// Lookup the last-known-published sequence number after the specified + /// sequence number. + SEQUENCE last_published_after(SEQUENCE lastKnownPublished) const noexcept; + + /// Wait until the specified target sequence number has been published. + /// + /// Returns an awaitable type that when co_awaited will suspend the awaiting + /// coroutine until the specified 'targetSequence' number and all prior sequence + /// numbers have been published. + template + multi_producer_sequencer_wait_operation wait_until_published( + SEQUENCE targetSequence, + SEQUENCE lastKnownPublished, + SCHEDULER& scheduler) const noexcept; + + /// Query if there are currently any slots available for claiming. + /// + /// Note that this return-value is only approximate if you have multiple producers + /// since immediately after returning true another thread may have claimed the + /// last available slot. + bool any_available() const noexcept; + + /// Claim a single slot in the buffer and wait until that slot becomes available. + /// + /// Returns an Awaitable type that yields the sequence number of the slot that + /// was claimed. + /// + /// Once the producer has claimed a slot then they are free to write to that + /// slot within the ring buffer. Once the value has been initialised the item + /// must be published by calling the .publish() method, passing the sequence + /// number. + template + multi_producer_sequencer_claim_one_operation + claim_one(SCHEDULER& scheduler) noexcept; + + /// Claim a contiguous range of sequence numbers corresponding to slots within + /// a ring-buffer. + /// + /// This will claim at most the specified count of sequence numbers but may claim + /// fewer if there are only fewer entries available in the buffer. But will claim + /// at least one sequence number. + /// + /// Returns an awaitable that will yield a sequence_range object containing the + /// sequence numbers that were claimed. + /// + /// The caller is responsible for ensuring that they publish every element of the + /// returned sequence range by calling .publish(). + template + multi_producer_sequencer_claim_operation + claim_up_to(std::size_t count, SCHEDULER& scheduler) noexcept; + + /// Publish the element with the specified sequence number, making it available + /// to consumers. + /// + /// Note that different sequence numbers may be published by different producer + /// threads out of order. A sequence number will not become available to consumers + /// until all preceding sequence numbers have also been published. + /// + /// \param sequence + /// The sequence number of the elemnt to publish + /// This sequence number must have been previously acquired via a call to 'claim_one()' + /// or 'claim_up_to()'. + void publish(SEQUENCE sequence) noexcept; + + /// Publish a contiguous range of sequence numbers, making each of them available + /// to consumers. + /// + /// This is equivalent to calling publish(seq) for each sequence number, seq, in + /// the specified range, but is more efficient since it only checks to see if + /// there are coroutines that need to be woken up once. + void publish(const sequence_range& range) noexcept; + + private: + + template + friend class multi_producer_sequencer_wait_operation_base; + + template + friend class multi_producer_sequencer_claim_operation; + + template + friend class multi_producer_sequencer_claim_one_operation; + + void resume_ready_awaiters() noexcept; + void add_awaiter(multi_producer_sequencer_wait_operation_base* awaiter) const noexcept; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4324) // C4324: structure was padded due to alignment specifier +#endif + + const sequence_barrier& m_consumerBarrier; + const std::size_t m_sequenceMask; + const std::unique_ptr[]> m_published; + + alignas(CPPCORO_CPU_CACHE_LINE) + std::atomic m_nextToClaim; + + alignas(CPPCORO_CPU_CACHE_LINE) + mutable std::atomic*> m_awaiters; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + + }; + + template + class multi_producer_sequencer_claim_awaiter + { + public: + + multi_producer_sequencer_claim_awaiter( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + const sequence_range& claimedRange, + SCHEDULER& scheduler) noexcept + : m_barrierWait(consumerBarrier, claimedRange.back() - bufferSize, scheduler) + , m_claimedRange(claimedRange) + {} + + bool await_ready() const noexcept + { + return m_barrierWait.await_ready(); + } + + auto await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_barrierWait.await_suspend(awaitingCoroutine); + } + + sequence_range await_resume() noexcept + { + return m_claimedRange; + } + + private: + + sequence_barrier_wait_operation m_barrierWait; + sequence_range m_claimedRange; + + }; + + template + class multi_producer_sequencer_claim_operation + { + public: + + multi_producer_sequencer_claim_operation( + multi_producer_sequencer& sequencer, + std::size_t count, + SCHEDULER& scheduler) noexcept + : m_sequencer(sequencer) + , m_count(count < sequencer.buffer_size() ? count : sequencer.buffer_size()) + , m_scheduler(scheduler) + { + } + + multi_producer_sequencer_claim_awaiter operator co_await() noexcept + { + // We wait until the awaitable is actually co_await'ed before we claim the + // range of elements. If we claimed them earlier, then it may be possible for + // the caller to fail to co_await the result eg. due to an exception, which + // would leave the sequence numbers unable to be published and would eventually + // deadlock consumers that waited on them. + // + // TODO: We could try and acquire only as many as are available if fewer than + // m_count elements are available. This would complicate the logic here somewhat + // as we'd need to use a compare-exchange instead. + const SEQUENCE first = m_sequencer.m_nextToClaim.fetch_add(m_count, std::memory_order_relaxed); + return multi_producer_sequencer_claim_awaiter{ + m_sequencer.m_consumerBarrier, + m_sequencer.buffer_size(), + sequence_range{ first, first + m_count }, + m_scheduler + }; + } + + private: + + multi_producer_sequencer& m_sequencer; + std::size_t m_count; + SCHEDULER& m_scheduler; + + }; + + template + class multi_producer_sequencer_claim_one_awaiter + { + public: + + multi_producer_sequencer_claim_one_awaiter( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + SEQUENCE claimedSequence, + SCHEDULER& scheduler) noexcept + : m_waitOp(consumerBarrier, claimedSequence - bufferSize, scheduler) + , m_claimedSequence(claimedSequence) + {} + + bool await_ready() const noexcept + { + return m_waitOp.await_ready(); + } + + auto await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_waitOp.await_suspend(awaitingCoroutine); + } + + SEQUENCE await_resume() noexcept + { + return m_claimedSequence; + } + + private: + + sequence_barrier_wait_operation m_waitOp; + SEQUENCE m_claimedSequence; + + }; + + template + class multi_producer_sequencer_claim_one_operation + { + public: + + multi_producer_sequencer_claim_one_operation( + multi_producer_sequencer& sequencer, + SCHEDULER& scheduler) noexcept + : m_sequencer(sequencer) + , m_scheduler(scheduler) + {} + + multi_producer_sequencer_claim_one_awaiter operator co_await() noexcept + { + return multi_producer_sequencer_claim_one_awaiter{ + m_sequencer.m_consumerBarrier, + m_sequencer.buffer_size(), + m_sequencer.m_nextToClaim.fetch_add(1, std::memory_order_relaxed), + m_scheduler + }; + } + + private: + + multi_producer_sequencer& m_sequencer; + SCHEDULER& m_scheduler; + + }; + + template + class multi_producer_sequencer_wait_operation_base + { + public: + + multi_producer_sequencer_wait_operation_base( + const multi_producer_sequencer& sequencer, + SEQUENCE targetSequence, + SEQUENCE lastKnownPublished) noexcept + : m_sequencer(sequencer) + , m_targetSequence(targetSequence) + , m_lastKnownPublished(lastKnownPublished) + , m_readyToResume(false) + {} + + multi_producer_sequencer_wait_operation_base( + const multi_producer_sequencer_wait_operation_base& other) noexcept + : m_sequencer(other.m_sequencer) + , m_targetSequence(other.m_targetSequence) + , m_lastKnownPublished(other.m_lastKnownPublished) + , m_readyToResume(false) + {} + + bool await_ready() const noexcept + { + return !TRAITS::precedes(m_lastKnownPublished, m_targetSequence); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + + m_sequencer.add_awaiter(this); + + // Mark the waiter as ready to resume. + // If it was already marked as ready-to-resume within the call to add_awaiter() or + // on another thread then this exchange() will return true. In this case we want to + // resume immediately and continue execution by returning false. + return !m_readyToResume.exchange(true, std::memory_order_acquire); + } + + SEQUENCE await_resume() noexcept + { + return m_lastKnownPublished; + } + + protected: + + friend class multi_producer_sequencer; + + void resume(SEQUENCE lastKnownPublished) noexcept + { + m_lastKnownPublished = lastKnownPublished; + if (m_readyToResume.exchange(true, std::memory_order_release)) + { + resume_impl(); + } + } + + virtual void resume_impl() noexcept = 0; + + const multi_producer_sequencer& m_sequencer; + SEQUENCE m_targetSequence; + SEQUENCE m_lastKnownPublished; + multi_producer_sequencer_wait_operation_base* m_next; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + std::atomic m_readyToResume; + }; + + template + class multi_producer_sequencer_wait_operation : + public multi_producer_sequencer_wait_operation_base + { + using schedule_operation = decltype(std::declval().schedule()); + + public: + + multi_producer_sequencer_wait_operation( + const multi_producer_sequencer& sequencer, + SEQUENCE targetSequence, + SEQUENCE lastKnownPublished, + SCHEDULER& scheduler) noexcept + : multi_producer_sequencer_wait_operation_base(sequencer, targetSequence, lastKnownPublished) + , m_scheduler(scheduler) + {} + + multi_producer_sequencer_wait_operation( + const multi_producer_sequencer_wait_operation& other) noexcept + : multi_producer_sequencer_wait_operation_base(other) + , m_scheduler(other.m_scheduler) + {} + + ~multi_producer_sequencer_wait_operation() + { + if (m_isScheduleAwaiterCreated) + { + m_scheduleAwaiter.destruct(); + } + if (m_isScheduleOperationCreated) + { + m_scheduleOperation.destruct(); + } + } + + SEQUENCE await_resume() noexcept(noexcept(m_scheduleOperation->await_resume())) + { + if (m_isScheduleOperationCreated) + { + m_scheduleOperation->await_resume(); + } + + return multi_producer_sequencer_wait_operation_base::await_resume(); + } + + private: + + void resume_impl() noexcept override + { + try + { + m_scheduleOperation.construct(m_scheduler.schedule()); + m_isScheduleOperationCreated = true; + + m_scheduleAwaiter.construct(detail::get_awaiter( + static_cast(*m_scheduleOperation))); + m_isScheduleAwaiterCreated = true; + + if (!m_scheduleAwaiter->await_ready()) + { + using await_suspend_result_t = decltype(m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine)); + if constexpr (std::is_void_v) + { + m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine); + return; + } + else if constexpr (std::is_same_v) + { + if (m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine)) + { + return; + } + } + else + { + // Assume it returns a coroutine_handle. + m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine).resume(); + return; + } + } + } + catch (...) + { + // Ignore failure to reschedule and resume inline? + // Should we catch the exception and rethrow from await_resume()? + // Or should we require that 'co_await scheduler.schedule()' is noexcept? + } + + // Resume outside the catch-block. + this->m_awaitingCoroutine.resume(); + } + + SCHEDULER& m_scheduler; + // Can't use std::optional here since T could be a reference. + detail::manual_lifetime m_scheduleOperation; + detail::manual_lifetime::awaiter_t> m_scheduleAwaiter; + bool m_isScheduleOperationCreated = false; + bool m_isScheduleAwaiterCreated = false; + + }; + + template + multi_producer_sequencer::multi_producer_sequencer( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + SEQUENCE initialSequence) + : m_consumerBarrier(consumerBarrier) + , m_sequenceMask(bufferSize - 1) + , m_published(std::make_unique[]>(bufferSize)) + , m_nextToClaim(initialSequence + 1) + , m_awaiters(nullptr) + { + // bufferSize must be a positive power-of-two + assert(bufferSize > 0 && (bufferSize & (bufferSize - 1)) == 0); + // but must be no larger than the max diff value. + using diff_t = typename TRAITS::difference_type; + using unsigned_diff_t = std::make_unsigned_t; + constexpr unsigned_diff_t maxSize = static_cast(std::numeric_limits::max()); + assert(bufferSize <= maxSize); + + SEQUENCE seq = initialSequence - (bufferSize - 1); + do + { +#ifdef __cpp_lib_atomic_value_initialization + m_published[seq & m_sequenceMask].store(seq, std::memory_order_relaxed); +#else // ^^^ __cpp_lib_atomic_value_initialization // !__cpp_lib_atomic_value_initialization vvv + std::atomic_init(&m_published[seq & m_sequenceMask], seq); +#endif // !__cpp_lib_atomic_value_initialization + } while (seq++ != initialSequence); + } + + template + SEQUENCE multi_producer_sequencer::last_published_after( + SEQUENCE lastKnownPublished) const noexcept + { + const auto mask = m_sequenceMask; + SEQUENCE seq = lastKnownPublished + 1; + while (m_published[seq & mask].load(std::memory_order_acquire) == seq) + { + lastKnownPublished = seq++; + } + return lastKnownPublished; + } + + template + template + multi_producer_sequencer_wait_operation + multi_producer_sequencer::wait_until_published( + SEQUENCE targetSequence, + SEQUENCE lastKnownPublished, + SCHEDULER& scheduler) const noexcept + { + return multi_producer_sequencer_wait_operation{ + *this, targetSequence, lastKnownPublished, scheduler + }; + } + + template + bool multi_producer_sequencer::any_available() const noexcept + { + return TRAITS::precedes( + m_nextToClaim.load(std::memory_order_relaxed), + m_consumerBarrier.last_published() + buffer_size()); + } + + template + template + multi_producer_sequencer_claim_one_operation + multi_producer_sequencer::claim_one(SCHEDULER& scheduler) noexcept + { + return multi_producer_sequencer_claim_one_operation{ *this, scheduler }; + } + + template + template + multi_producer_sequencer_claim_operation + multi_producer_sequencer::claim_up_to(std::size_t count, SCHEDULER& scheduler) noexcept + { + return multi_producer_sequencer_claim_operation{ *this, count, scheduler }; + } + + template + void multi_producer_sequencer::publish(SEQUENCE sequence) noexcept + { + m_published[sequence & m_sequenceMask].store(sequence, std::memory_order_seq_cst); + + // Resume any waiters that might have been satisfied by this publish operation. + resume_ready_awaiters(); + } + + template + void multi_producer_sequencer::publish(const sequence_range& range) noexcept + { + if (range.empty()) + { + return; + } + + // Publish all but the first sequence number using relaxed atomics. + // No consumer should be reading those subsequent sequence numbers until they've seen + // that the first sequence number in the range is published. + for (SEQUENCE seq : range.skip(1)) + { + m_published[seq & m_sequenceMask].store(seq, std::memory_order_relaxed); + } + + // Now publish the first sequence number with seq_cst semantics. + m_published[range.front() & m_sequenceMask].store(range.front(), std::memory_order_seq_cst); + + // Resume any waiters that might have been satisfied by this publish operation. + resume_ready_awaiters(); + } + + template + void multi_producer_sequencer::resume_ready_awaiters() noexcept + { + using awaiter_t = multi_producer_sequencer_wait_operation_base; + + awaiter_t* awaiters = m_awaiters.load(std::memory_order_seq_cst); + if (awaiters == nullptr) + { + // No awaiters + return; + } + + // There were some awaiters. Try to acquire the list of waiters with an + // atomic exchange as we might be racing with other consumers/producers. + awaiters = m_awaiters.exchange(nullptr, std::memory_order_seq_cst); + if (awaiters == nullptr) + { + // Didn't acquire the list + // Some other thread is now responsible for resuming them. Our job is done. + return; + } + + SEQUENCE lastKnownPublished; + + awaiter_t* awaitersToResume; + awaiter_t** awaitersToResumeTail = &awaitersToResume; + + awaiter_t* awaitersToRequeue; + awaiter_t** awaitersToRequeueTail = &awaitersToRequeue; + + do + { + using diff_t = typename TRAITS::difference_type; + + lastKnownPublished = last_published_after(awaiters->m_lastKnownPublished); + + // First scan the list of awaiters and split them into 'requeue' and 'resume' lists. + auto minDiff = std::numeric_limits::max(); + do + { + auto diff = TRAITS::difference(awaiters->m_targetSequence, lastKnownPublished); + if (diff > 0) + { + // Not ready yet. + minDiff = diff < minDiff ? diff : minDiff; + *awaitersToRequeueTail = awaiters; + awaitersToRequeueTail = &awaiters->m_next; + } + else + { + *awaitersToResumeTail = awaiters; + awaitersToResumeTail = &awaiters->m_next; + } + awaiters->m_lastKnownPublished = lastKnownPublished; + awaiters = awaiters->m_next; + } while (awaiters != nullptr); + + // Null-terinate the requeue list + *awaitersToRequeueTail = nullptr; + + if (awaitersToRequeue != nullptr) + { + // Requeue the waiters that are not ready yet. + awaiter_t* oldHead = nullptr; + while (!m_awaiters.compare_exchange_weak(oldHead, awaitersToRequeue, std::memory_order_seq_cst, std::memory_order_relaxed)) + { + *awaitersToRequeueTail = oldHead; + } + + // Reset the awaitersToRequeue list + awaitersToRequeueTail = &awaitersToRequeue; + + const SEQUENCE earliestTargetSequence = lastKnownPublished + minDiff; + + // Now we need to check again to see if any of the waiters we just enqueued + // is now satisfied by a concurrent call to publish(). + // + // We need to be a bit more careful here since we are no longer holding any + // awaiters and so producers/consumers may advance the sequence number arbitrarily + // far. If the sequence number advances more than buffer_size() ahead of the + // earliestTargetSequence then the m_published[] array may have sequence numbers + // that have advanced beyond earliestTargetSequence, potentially even wrapping + // sequence numbers around to then be preceding where they were before. If this + // happens then we don't need to worry about resuming any awaiters that were waiting + // for 'earliestTargetSequence' since some other thread has already resumed them. + // So the only case we need to worry about here is when all m_published entries for + // sequence numbers in range [lastKnownPublished + 1, earliestTargetSequence] have + // published sequence numbers that match the range. + const auto sequenceMask = m_sequenceMask; + SEQUENCE seq = lastKnownPublished + 1; + while (m_published[seq & sequenceMask].load(std::memory_order_seq_cst) == seq) + { + lastKnownPublished = seq; + if (seq == earliestTargetSequence) + { + // At least one of the awaiters we just published is now satisfied. + // Reacquire the list of awaiters and continue around the outer loop. + awaiters = m_awaiters.exchange(nullptr, std::memory_order_acquire); + break; + } + ++seq; + } + } + } while (awaiters != nullptr); + + // Null-terminate list of awaiters to resume. + *awaitersToResumeTail = nullptr; + + while (awaitersToResume != nullptr) + { + awaiter_t* next = awaitersToResume->m_next; + awaitersToResume->resume(lastKnownPublished); + awaitersToResume = next; + } + } + + template + void multi_producer_sequencer::add_awaiter( + multi_producer_sequencer_wait_operation_base* awaiter) const noexcept + { + using awaiter_t = multi_producer_sequencer_wait_operation_base; + + SEQUENCE targetSequence = awaiter->m_targetSequence; + SEQUENCE lastKnownPublished = awaiter->m_lastKnownPublished; + + awaiter_t* awaitersToEnqueue = awaiter; + awaiter_t** awaitersToEnqueueTail = &awaiter->m_next; + + awaiter_t* awaitersToResume; + awaiter_t** awaitersToResumeTail = &awaitersToResume; + + const SEQUENCE sequenceMask = m_sequenceMask; + + do + { + // Enqueue the awaiters. + { + awaiter_t* oldHead = m_awaiters.load(std::memory_order_relaxed); + do + { + *awaitersToEnqueueTail = oldHead; + } while (!m_awaiters.compare_exchange_weak( + oldHead, + awaitersToEnqueue, + std::memory_order_seq_cst, + std::memory_order_relaxed)); + } + + // Reset list of waiters + awaitersToEnqueueTail = &awaitersToEnqueue; + + // Check to see if the last-known published sequence number has advanced + // while we were enqueuing the awaiters. Need to use seq_cst memory order + // here to ensure that if there are concurrent calls to publish() that would + // wake up any of the awaiters we just enqueued that either we will see their + // write to m_published slots or they will see our write to m_awaiters. + // + // Note also, that we are assuming that the last-known published sequence is + // not going to advance more than buffer_size() ahead of targetSequence since + // there is at least one consumer that won't be resumed and so thus can't + // publish the sequence number it's waiting for to its sequence_barrier and so + // producers won't be able to claim its slot in the buffer. + // + // TODO: Check whether we can weaken the memory order here to just use 'seq_cst' on the + // first .load() and then use 'acquire' on subsequent .load(). + while (m_published[(lastKnownPublished + 1) & sequenceMask].load(std::memory_order_seq_cst) == (lastKnownPublished + 1)) + { + ++lastKnownPublished; + } + + if (!TRAITS::precedes(lastKnownPublished, targetSequence)) + { + // At least one awaiter we just enqueued has now been satisified. + // To ensure it is woken up we need to reacquire the list of awaiters and resume + awaiter_t* awaiters = m_awaiters.exchange(nullptr, std::memory_order_acquire); + + using diff_t = typename TRAITS::difference_type; + + diff_t minDiff = std::numeric_limits::max(); + + while (awaiters != nullptr) + { + diff_t diff = TRAITS::difference(targetSequence, lastKnownPublished); + if (diff > 0) + { + // Not yet ready. + minDiff = diff < minDiff ? diff : minDiff; + *awaitersToEnqueueTail = awaiters; + awaitersToEnqueueTail = &awaiters->m_next; + awaiters->m_lastKnownPublished = lastKnownPublished; + } + else + { + // Now ready. + *awaitersToResumeTail = awaiters; + awaitersToResumeTail = &awaiters->m_next; + } + awaiters = awaiters->m_next; + } + + // Calculate the earliest sequence number that any awaiters in the + // awaitersToEnqueue list are waiting for. We'll use this next time + // around the loop. + targetSequence = static_cast(lastKnownPublished + minDiff); + } + + // Null-terminate list of awaiters to enqueue. + *awaitersToEnqueueTail = nullptr; + + } while (awaitersToEnqueue != nullptr); + + // Null-terminate awaiters to resume. + *awaitersToResumeTail = nullptr; + + // Finally, resume any awaiters we've found that are ready to go. + while (awaitersToResume != nullptr) + { + // Read m_next before calling .resume() as resuming could destroy the awaiter. + awaiter_t* next = awaitersToResume->m_next; + awaitersToResume->resume(lastKnownPublished); + awaitersToResume = next; + } + } +} + +#endif diff --git a/include/cppcoro/net/ip_address.hpp b/include/cppcoro/net/ip_address.hpp new file mode 100644 index 0000000..f1d847f --- /dev/null +++ b/include/cppcoro/net/ip_address.hpp @@ -0,0 +1,147 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IP_ADDRESS_HPP_INCLUDED +#define CPPCORO_NET_IP_ADDRESS_HPP_INCLUDED + +#include +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace net + { + class ip_address + { + public: + + // Constructs to IPv4 address 0.0.0.0 + ip_address() noexcept; + + ip_address(ipv4_address address) noexcept; + ip_address(ipv6_address address) noexcept; + + bool is_ipv4() const noexcept { return m_family == family::ipv4; } + bool is_ipv6() const noexcept { return m_family == family::ipv6; } + + const ipv4_address& to_ipv4() const; + const ipv6_address& to_ipv6() const; + + const std::uint8_t* bytes() const noexcept; + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + bool operator==(const ip_address& rhs) const noexcept; + bool operator!=(const ip_address& rhs) const noexcept; + + // ipv4_address sorts less than ipv6_address + bool operator<(const ip_address& rhs) const noexcept; + bool operator>(const ip_address& rhs) const noexcept; + bool operator<=(const ip_address& rhs) const noexcept; + bool operator>=(const ip_address& rhs) const noexcept; + + private: + + enum class family + { + ipv4, + ipv6 + }; + + family m_family; + + union + { + ipv4_address m_ipv4; + ipv6_address m_ipv6; + }; + + }; + + inline ip_address::ip_address() noexcept + : m_family(family::ipv4) + , m_ipv4() + {} + + inline ip_address::ip_address(ipv4_address address) noexcept + : m_family(family::ipv4) + , m_ipv4(address) + {} + + inline ip_address::ip_address(ipv6_address address) noexcept + : m_family(family::ipv6) + , m_ipv6(address) + { + } + + inline const ipv4_address& ip_address::to_ipv4() const + { + assert(is_ipv4()); + return m_ipv4; + } + + inline const ipv6_address& ip_address::to_ipv6() const + { + assert(is_ipv6()); + return m_ipv6; + } + + inline const std::uint8_t* ip_address::bytes() const noexcept + { + return is_ipv4() ? m_ipv4.bytes() : m_ipv6.bytes(); + } + + inline bool ip_address::operator==(const ip_address& rhs) const noexcept + { + if (is_ipv4()) + { + return rhs.is_ipv4() && m_ipv4 == rhs.m_ipv4; + } + else + { + return rhs.is_ipv6() && m_ipv6 == rhs.m_ipv6; + } + } + + inline bool ip_address::operator!=(const ip_address& rhs) const noexcept + { + return !(*this == rhs); + } + + inline bool ip_address::operator<(const ip_address& rhs) const noexcept + { + if (is_ipv4()) + { + return !rhs.is_ipv4() || m_ipv4 < rhs.m_ipv4; + } + else + { + return rhs.is_ipv6() && m_ipv6 < rhs.m_ipv6; + } + } + + inline bool ip_address::operator>(const ip_address& rhs) const noexcept + { + return rhs < *this; + } + + inline bool ip_address::operator<=(const ip_address& rhs) const noexcept + { + return !(rhs < *this); + } + + inline bool ip_address::operator>=(const ip_address& rhs) const noexcept + { + return !(*this < rhs); + } + } +} + +#endif diff --git a/include/cppcoro/net/ip_endpoint.hpp b/include/cppcoro/net/ip_endpoint.hpp new file mode 100644 index 0000000..d15422d --- /dev/null +++ b/include/cppcoro/net/ip_endpoint.hpp @@ -0,0 +1,161 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IP_ENDPOINT_HPP_INCLUDED +#define CPPCORO_NET_IP_ENDPOINT_HPP_INCLUDED + +#include +#include +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace net + { + class ip_endpoint + { + public: + + // Constructs to IPv4 end-point 0.0.0.0:0 + ip_endpoint() noexcept; + + ip_endpoint(ipv4_endpoint endpoint) noexcept; + ip_endpoint(ipv6_endpoint endpoint) noexcept; + + bool is_ipv4() const noexcept { return m_family == family::ipv4; } + bool is_ipv6() const noexcept { return m_family == family::ipv6; } + + const ipv4_endpoint& to_ipv4() const; + const ipv6_endpoint& to_ipv6() const; + + ip_address address() const noexcept; + std::uint16_t port() const noexcept; + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + bool operator==(const ip_endpoint& rhs) const noexcept; + bool operator!=(const ip_endpoint& rhs) const noexcept; + + // ipv4_endpoint sorts less than ipv6_endpoint + bool operator<(const ip_endpoint& rhs) const noexcept; + bool operator>(const ip_endpoint& rhs) const noexcept; + bool operator<=(const ip_endpoint& rhs) const noexcept; + bool operator>=(const ip_endpoint& rhs) const noexcept; + + private: + + enum class family + { + ipv4, + ipv6 + }; + + family m_family; + + union + { + ipv4_endpoint m_ipv4; + ipv6_endpoint m_ipv6; + }; + + }; + + inline ip_endpoint::ip_endpoint() noexcept + : m_family(family::ipv4) + , m_ipv4() + {} + + inline ip_endpoint::ip_endpoint(ipv4_endpoint endpoint) noexcept + : m_family(family::ipv4) + , m_ipv4(endpoint) + {} + + inline ip_endpoint::ip_endpoint(ipv6_endpoint endpoint) noexcept + : m_family(family::ipv6) + , m_ipv6(endpoint) + { + } + + inline const ipv4_endpoint& ip_endpoint::to_ipv4() const + { + assert(is_ipv4()); + return m_ipv4; + } + + inline const ipv6_endpoint& ip_endpoint::to_ipv6() const + { + assert(is_ipv6()); + return m_ipv6; + } + + inline ip_address ip_endpoint::address() const noexcept + { + if (is_ipv4()) + { + return m_ipv4.address(); + } + else + { + return m_ipv6.address(); + } + } + + inline std::uint16_t ip_endpoint::port() const noexcept + { + return is_ipv4() ? m_ipv4.port() : m_ipv6.port(); + } + + inline bool ip_endpoint::operator==(const ip_endpoint& rhs) const noexcept + { + if (is_ipv4()) + { + return rhs.is_ipv4() && m_ipv4 == rhs.m_ipv4; + } + else + { + return rhs.is_ipv6() && m_ipv6 == rhs.m_ipv6; + } + } + + inline bool ip_endpoint::operator!=(const ip_endpoint& rhs) const noexcept + { + return !(*this == rhs); + } + + inline bool ip_endpoint::operator<(const ip_endpoint& rhs) const noexcept + { + if (is_ipv4()) + { + return !rhs.is_ipv4() || m_ipv4 < rhs.m_ipv4; + } + else + { + return rhs.is_ipv6() && m_ipv6 < rhs.m_ipv6; + } + } + + inline bool ip_endpoint::operator>(const ip_endpoint& rhs) const noexcept + { + return rhs < *this; + } + + inline bool ip_endpoint::operator<=(const ip_endpoint& rhs) const noexcept + { + return !(rhs < *this); + } + + inline bool ip_endpoint::operator>=(const ip_endpoint& rhs) const noexcept + { + return !(*this < rhs); + } + } +} + +#endif diff --git a/include/cppcoro/net/ipv4_address.hpp b/include/cppcoro/net/ipv4_address.hpp new file mode 100644 index 0000000..7a5dd43 --- /dev/null +++ b/include/cppcoro/net/ipv4_address.hpp @@ -0,0 +1,134 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IPV4_ADDRESS_HPP_INCLUDED +#define CPPCORO_NET_IPV4_ADDRESS_HPP_INCLUDED + +#include +#include +#include +#include + +namespace cppcoro::net +{ + class ipv4_address + { + using bytes_t = std::uint8_t[4]; + + public: + + constexpr ipv4_address() + : m_bytes{ 0, 0, 0, 0 } + {} + + explicit constexpr ipv4_address(std::uint32_t integer) + : m_bytes{ + static_cast(integer >> 24), + static_cast(integer >> 16), + static_cast(integer >> 8), + static_cast(integer) } + {} + + explicit constexpr ipv4_address(const std::uint8_t(&bytes)[4]) + : m_bytes{ bytes[0], bytes[1], bytes[2], bytes[3] } + {} + + explicit constexpr ipv4_address( + std::uint8_t b0, + std::uint8_t b1, + std::uint8_t b2, + std::uint8_t b3) + : m_bytes{ b0, b1, b2, b3 } + {} + + constexpr const bytes_t& bytes() const { return m_bytes; } + + constexpr std::uint32_t to_integer() const + { + return + std::uint32_t(m_bytes[0]) << 24 | + std::uint32_t(m_bytes[1]) << 16 | + std::uint32_t(m_bytes[2]) << 8 | + std::uint32_t(m_bytes[3]); + } + + static constexpr ipv4_address loopback() + { + return ipv4_address(127, 0, 0, 1); + } + + constexpr bool is_loopback() const + { + return m_bytes[0] == 127; + } + + constexpr bool is_private_network() const + { + return m_bytes[0] == 10 || + (m_bytes[0] == 172 && (m_bytes[1] & 0xF0) == 0x10) || + (m_bytes[0] == 192 && m_bytes[2] == 168); + } + + constexpr bool operator==(ipv4_address other) const + { + return + m_bytes[0] == other.m_bytes[0] && + m_bytes[1] == other.m_bytes[1] && + m_bytes[2] == other.m_bytes[2] && + m_bytes[3] == other.m_bytes[3]; + } + + constexpr bool operator!=(ipv4_address other) const + { + return !(*this == other); + } + + constexpr bool operator<(ipv4_address other) const + { + return to_integer() < other.to_integer(); + } + + constexpr bool operator>(ipv4_address other) const + { + return other < *this; + } + + constexpr bool operator<=(ipv4_address other) const + { + return !(other < *this); + } + + constexpr bool operator>=(ipv4_address other) const + { + return !(*this < other); + } + + /// Parse a string representation of an IP address. + /// + /// Parses strings of the form: + /// - "num.num.num.num" where num is an integer in range [0, 255]. + /// - A single integer value in range [0, 2^32). + /// + /// \param string + /// The string to parse. + /// Must be in ASCII, UTF-8 or Latin-1 encoding. + /// + /// \return + /// The IP address if successful, otherwise std::nullopt if the string + /// could not be parsed as an IPv4 address. + static std::optional from_string(std::string_view string) noexcept; + + /// Convert the IP address to dotted decimal notation. + /// + /// eg. "12.67.190.23" + std::string to_string() const; + + private: + + alignas(std::uint32_t) std::uint8_t m_bytes[4]; + + }; +} + +#endif diff --git a/include/cppcoro/net/ipv4_endpoint.hpp b/include/cppcoro/net/ipv4_endpoint.hpp new file mode 100644 index 0000000..11ee72e --- /dev/null +++ b/include/cppcoro/net/ipv4_endpoint.hpp @@ -0,0 +1,82 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IPV4_ENDPOINT_HPP_INCLUDED +#define CPPCORO_NET_IPV4_ENDPOINT_HPP_INCLUDED + +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace net + { + class ipv4_endpoint + { + public: + + // Construct to 0.0.0.0:0 + ipv4_endpoint() noexcept + : m_address() + , m_port(0) + {} + + explicit ipv4_endpoint(ipv4_address address, std::uint16_t port = 0) noexcept + : m_address(address) + , m_port(port) + {} + + const ipv4_address& address() const noexcept { return m_address; } + + std::uint16_t port() const noexcept { return m_port; } + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + private: + + ipv4_address m_address; + std::uint16_t m_port; + + }; + + inline bool operator==(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return a.address() == b.address() && + a.port() == b.port(); + } + + inline bool operator!=(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return !(a == b); + } + + inline bool operator<(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return a.address() < b.address() || + (a.address() == b.address() && a.port() < b.port()); + } + + inline bool operator>(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return b < a; + } + + inline bool operator<=(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return !(b < a); + } + + inline bool operator>=(const ipv4_endpoint& a, const ipv4_endpoint& b) + { + return !(a < b); + } + } +} + +#endif diff --git a/include/cppcoro/net/ipv6_address.hpp b/include/cppcoro/net/ipv6_address.hpp new file mode 100644 index 0000000..46a92d7 --- /dev/null +++ b/include/cppcoro/net/ipv6_address.hpp @@ -0,0 +1,245 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IPV6_ADDRESS_HPP_INCLUDED +#define CPPCORO_NET_IPV6_ADDRESS_HPP_INCLUDED + +#include +#include +#include +#include + +namespace cppcoro::net +{ + class ipv4_address; + + class ipv6_address + { + using bytes_t = std::uint8_t[16]; + + public: + + constexpr ipv6_address(); + + explicit constexpr ipv6_address( + std::uint64_t subnetPrefix, + std::uint64_t interfaceIdentifier); + + constexpr ipv6_address( + std::uint16_t part0, + std::uint16_t part1, + std::uint16_t part2, + std::uint16_t part3, + std::uint16_t part4, + std::uint16_t part5, + std::uint16_t part6, + std::uint16_t part7); + + explicit constexpr ipv6_address( + const std::uint16_t(&parts)[8]); + + explicit constexpr ipv6_address( + const std::uint8_t(&bytes)[16]); + + constexpr const bytes_t& bytes() const { return m_bytes; } + + constexpr std::uint64_t subnet_prefix() const; + + constexpr std::uint64_t interface_identifier() const; + + /// Get the IPv6 unspedified address :: (all zeroes). + static constexpr ipv6_address unspecified(); + + /// Get the IPv6 loopback address ::1. + static constexpr ipv6_address loopback(); + + /// Parse a string representation of an IPv6 address. + /// + /// \param string + /// The string to parse. + /// Must be in ASCII, UTF-8 or Latin-1 encoding. + /// + /// \return + /// The IP address if successful, otherwise std::nullopt if the string + /// could not be parsed as an IPv4 address. + static std::optional from_string(std::string_view string) noexcept; + + /// Convert the IP address to contracted string form. + /// + /// Address is broken up into 16-bit parts, with each part represended in 1-4 + /// lower-case hexadecimal with leading zeroes omitted. Parts are separated + /// by separated by a ':'. The longest contiguous run of zero parts is contracted + /// to "::". + /// + /// For example: + /// ipv6_address::unspecified() -> "::" + /// ipv6_address::loopback() -> "::1" + /// ipv6_address(0x0011223344556677, 0x8899aabbccddeeff) -> + /// "11:2233:4455:6677:8899:aabb:ccdd:eeff" + /// ipv6_address(0x0102030400000000, 0x003fc447ab991011) -> + /// "102:304::3f:c447:ab99:1011" + std::string to_string() const; + + constexpr bool operator==(const ipv6_address& other) const; + constexpr bool operator!=(const ipv6_address& other) const; + constexpr bool operator<(const ipv6_address& other) const; + constexpr bool operator>(const ipv6_address& other) const; + constexpr bool operator<=(const ipv6_address& other) const; + constexpr bool operator>=(const ipv6_address& other) const; + + private: + + alignas(std::uint64_t) std::uint8_t m_bytes[16]; + + }; + + constexpr ipv6_address::ipv6_address() + : m_bytes{ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0 } + {} + + constexpr ipv6_address::ipv6_address( + std::uint64_t subnetPrefix, + std::uint64_t interfaceIdentifier) + : m_bytes{ + static_cast(subnetPrefix >> 56), + static_cast(subnetPrefix >> 48), + static_cast(subnetPrefix >> 40), + static_cast(subnetPrefix >> 32), + static_cast(subnetPrefix >> 24), + static_cast(subnetPrefix >> 16), + static_cast(subnetPrefix >> 8), + static_cast(subnetPrefix), + static_cast(interfaceIdentifier >> 56), + static_cast(interfaceIdentifier >> 48), + static_cast(interfaceIdentifier >> 40), + static_cast(interfaceIdentifier >> 32), + static_cast(interfaceIdentifier >> 24), + static_cast(interfaceIdentifier >> 16), + static_cast(interfaceIdentifier >> 8), + static_cast(interfaceIdentifier) } + {} + + constexpr ipv6_address::ipv6_address( + std::uint16_t part0, + std::uint16_t part1, + std::uint16_t part2, + std::uint16_t part3, + std::uint16_t part4, + std::uint16_t part5, + std::uint16_t part6, + std::uint16_t part7) + : m_bytes{ + static_cast(part0 >> 8), + static_cast(part0), + static_cast(part1 >> 8), + static_cast(part1), + static_cast(part2 >> 8), + static_cast(part2), + static_cast(part3 >> 8), + static_cast(part3), + static_cast(part4 >> 8), + static_cast(part4), + static_cast(part5 >> 8), + static_cast(part5), + static_cast(part6 >> 8), + static_cast(part6), + static_cast(part7 >> 8), + static_cast(part7) } + {} + + constexpr ipv6_address::ipv6_address( + const std::uint16_t(&parts)[8]) + : ipv6_address( + parts[0], parts[1], parts[2], parts[3], + parts[4], parts[5], parts[6], parts[7]) + {} + + constexpr ipv6_address::ipv6_address(const std::uint8_t(&bytes)[16]) + : m_bytes{ + bytes[0], bytes[1], bytes[2], bytes[3], + bytes[4], bytes[5], bytes[6], bytes[7], + bytes[8], bytes[9], bytes[10], bytes[11], + bytes[12], bytes[13], bytes[14], bytes[15] } + {} + + constexpr std::uint64_t ipv6_address::subnet_prefix() const + { + return + static_cast(m_bytes[0]) << 56 | + static_cast(m_bytes[1]) << 48 | + static_cast(m_bytes[2]) << 40 | + static_cast(m_bytes[3]) << 32 | + static_cast(m_bytes[4]) << 24 | + static_cast(m_bytes[5]) << 16 | + static_cast(m_bytes[6]) << 8 | + static_cast(m_bytes[7]); + } + + constexpr std::uint64_t ipv6_address::interface_identifier() const + { + return + static_cast(m_bytes[8]) << 56 | + static_cast(m_bytes[9]) << 48 | + static_cast(m_bytes[10]) << 40 | + static_cast(m_bytes[11]) << 32 | + static_cast(m_bytes[12]) << 24 | + static_cast(m_bytes[13]) << 16 | + static_cast(m_bytes[14]) << 8 | + static_cast(m_bytes[15]); + } + + constexpr ipv6_address ipv6_address::unspecified() + { + return ipv6_address{}; + } + + constexpr ipv6_address ipv6_address::loopback() + { + return ipv6_address{ 0, 0, 0, 0, 0, 0, 0, 1 }; + } + + constexpr bool ipv6_address::operator==(const ipv6_address& other) const + { + for (int i = 0; i < 16; ++i) + { + if (m_bytes[i] != other.m_bytes[i]) return false; + } + return true; + } + + constexpr bool ipv6_address::operator!=(const ipv6_address& other) const + { + return !(*this == other); + } + + constexpr bool ipv6_address::operator<(const ipv6_address& other) const + { + for (int i = 0; i < 16; ++i) + { + if (m_bytes[i] != other.m_bytes[i]) + return m_bytes[i] < other.m_bytes[i]; + } + + return false; + } + + constexpr bool ipv6_address::operator>(const ipv6_address& other) const + { + return (other < *this); + } + + constexpr bool ipv6_address::operator<=(const ipv6_address& other) const + { + return !(other < *this); + } + + constexpr bool ipv6_address::operator>=(const ipv6_address& other) const + { + return !(*this < other); + } +} + +#endif diff --git a/include/cppcoro/net/ipv6_endpoint.hpp b/include/cppcoro/net/ipv6_endpoint.hpp new file mode 100644 index 0000000..d0e50bb --- /dev/null +++ b/include/cppcoro/net/ipv6_endpoint.hpp @@ -0,0 +1,82 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_IPV6_ENDPOINT_HPP_INCLUDED +#define CPPCORO_NET_IPV6_ENDPOINT_HPP_INCLUDED + +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace net + { + class ipv6_endpoint + { + public: + + // Construct to [::]:0 + ipv6_endpoint() noexcept + : m_address() + , m_port(0) + {} + + explicit ipv6_endpoint(ipv6_address address, std::uint16_t port = 0) noexcept + : m_address(address) + , m_port(port) + {} + + const ipv6_address& address() const noexcept { return m_address; } + + std::uint16_t port() const noexcept { return m_port; } + + std::string to_string() const; + + static std::optional from_string(std::string_view string) noexcept; + + private: + + ipv6_address m_address; + std::uint16_t m_port; + + }; + + inline bool operator==(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return a.address() == b.address() && + a.port() == b.port(); + } + + inline bool operator!=(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return !(a == b); + } + + inline bool operator<(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return a.address() < b.address() || + (a.address() == b.address() && a.port() < b.port()); + } + + inline bool operator>(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return b < a; + } + + inline bool operator<=(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return !(b < a); + } + + inline bool operator>=(const ipv6_endpoint& a, const ipv6_endpoint& b) + { + return !(a < b); + } + } +} + +#endif diff --git a/include/cppcoro/net/socket.hpp b/include/cppcoro/net/socket.hpp new file mode 100644 index 0000000..a61eac6 --- /dev/null +++ b/include/cppcoro/net/socket.hpp @@ -0,0 +1,268 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_HPP_INCLUDED + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if CPPCORO_OS_WINNT +# include +#endif + +namespace cppcoro +{ + class io_service; + + namespace net + { + class socket + { + public: + + /// Create a socket that can be used to communicate using TCP/IPv4 protocol. + /// + /// \param ioSvc + /// The I/O service the socket will use for dispatching I/O completion events. + /// + /// \return + /// The newly created socket. + /// + /// \throws std::system_error + /// If the socket could not be created for some reason. + static socket create_tcpv4(io_service& ioSvc); + + /// Create a socket that can be used to communicate using TCP/IPv6 protocol. + /// + /// \param ioSvc + /// The I/O service the socket will use for dispatching I/O completion events. + /// + /// \return + /// The newly created socket. + /// + /// \throws std::system_error + /// If the socket could not be created for some reason. + static socket create_tcpv6(io_service& ioSvc); + + /// Create a socket that can be used to communicate using UDP/IPv4 protocol. + /// + /// \param ioSvc + /// The I/O service the socket will use for dispatching I/O completion events. + /// + /// \return + /// The newly created socket. + /// + /// \throws std::system_error + /// If the socket could not be created for some reason. + static socket create_udpv4(io_service& ioSvc); + + /// Create a socket that can be used to communicate using UDP/IPv6 protocol. + /// + /// \param ioSvc + /// The I/O service the socket will use for dispatching I/O completion events. + /// + /// \return + /// The newly created socket. + /// + /// \throws std::system_error + /// If the socket could not be created for some reason. + static socket create_udpv6(io_service& ioSvc); + + socket(socket&& other) noexcept; + + /// Closes the socket, releasing any associated resources. + /// + /// If the socket still has an open connection then the connection will be + /// reset. The destructor will not block waiting for queueud data to be sent. + /// If you need to ensure that queued data is delivered then you must call + /// disconnect() and wait until the disconnect operation completes. + ~socket(); + + socket& operator=(socket&& other) noexcept; + +#if CPPCORO_OS_WINNT + /// Get the Win32 socket handle assocaited with this socket. + cppcoro::detail::win32::socket_t native_handle() noexcept { return m_handle; } + + /// Query whether I/O operations that complete synchronously will skip posting + /// an I/O completion event to the I/O completion port. + /// + /// The operation class implementations can use this to determine whether or not + /// it should immediately resume the coroutine on the current thread upon an + /// operation completing synchronously or whether it should suspend the coroutine + /// and wait until the I/O completion event is dispatched to an I/O thread. + bool skip_completion_on_success() noexcept { return m_skipCompletionOnSuccess; } +#endif + + /// Get the address and port of the local end-point. + /// + /// If the socket is not bound then this will be the unspecified end-point + /// of the socket's associated address-family. + const ip_endpoint& local_endpoint() const noexcept { return m_localEndPoint; } + + /// Get the address and port of the remote end-point. + /// + /// If the socket is not in the connected state then this will be the unspecified + /// end-point of the socket's associated address-family. + const ip_endpoint& remote_endpoint() const noexcept { return m_remoteEndPoint; } + + /// Bind the local end of this socket to the specified local end-point. + /// + /// \param localEndPoint + /// The end-point to bind to. + /// This can be either an unspecified address (in which case it binds to all available + /// interfaces) and/or an unspecified port (in which case a random port is allocated). + /// + /// \throws std::system_error + /// If the socket could not be bound for some reason. + void bind(const ip_endpoint& localEndPoint); + + /// Put the socket into a passive listening state that will start acknowledging + /// and queueing up new connections ready to be accepted by a call to 'accept()'. + /// + /// The backlog of connections ready to be accepted will be set to some default + /// suitable large value, depending on the network provider. If you need more + /// control over the size of the queue then use the overload of listen() + /// that accepts a 'backlog' parameter. + /// + /// \throws std::system_error + /// If the socket could not be placed into a listening mode. + void listen(); + + /// Put the socket into a passive listening state that will start acknowledging + /// and queueing up new connections ready to be accepted by a call to 'accept()'. + /// + /// \param backlog + /// The maximum number of pending connections to allow in the queue of ready-to-accept + /// connections. + /// + /// \throws std::system_error + /// If the socket could not be placed into a listening mode. + void listen(std::uint32_t backlog); + + /// Connect the socket to the specified remote end-point. + /// + /// The socket must be in a bound but unconnected state prior to this call. + /// + /// \param remoteEndPoint + /// The IP address and port-number to connect to. + /// + /// \return + /// An awaitable object that must be co_await'ed to perform the async connect + /// operation. The result of the co_await expression is type void. + [[nodiscard]] + socket_connect_operation connect(const ip_endpoint& remoteEndPoint) noexcept; + + /// Connect to the specified remote end-point. + /// + /// \param remoteEndPoint + /// The IP address and port of the remote end-point to connect to. + /// + /// \param ct + /// A cancellation token that can be used to communicate a request to + /// later cancel the operation. If the operation is successfully + /// cancelled then it will complete by throwing a cppcoro::operation_cancelled + /// exception. + /// + /// \return + /// An awaitable object that will start the connect operation when co_await'ed + /// and will suspend the coroutine, resuming it when the operation completes. + /// The result of the co_await expression has type 'void'. + [[nodiscard]] + socket_connect_operation_cancellable connect( + const ip_endpoint& remoteEndPoint, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_accept_operation accept(socket& acceptingSocket) noexcept; + [[nodiscard]] + socket_accept_operation_cancellable accept( + socket& acceptingSocket, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_disconnect_operation disconnect() noexcept; + [[nodiscard]] + socket_disconnect_operation_cancellable disconnect(cancellation_token ct) noexcept; + + [[nodiscard]] + socket_send_operation send( + const void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_send_operation_cancellable send( + const void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_recv_operation recv( + void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_recv_operation_cancellable recv( + void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_recv_from_operation recv_from( + void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_recv_from_operation_cancellable recv_from( + void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + [[nodiscard]] + socket_send_to_operation send_to( + const ip_endpoint& destination, + const void* buffer, + std::size_t size) noexcept; + [[nodiscard]] + socket_send_to_operation_cancellable send_to( + const ip_endpoint& destination, + const void* buffer, + std::size_t size, + cancellation_token ct) noexcept; + + void close_send(); + void close_recv(); + + private: + + friend class socket_accept_operation_impl; + friend class socket_connect_operation_impl; + +#if CPPCORO_OS_WINNT + explicit socket( + cppcoro::detail::win32::socket_t handle, + bool skipCompletionOnSuccess) noexcept; +#endif + +#if CPPCORO_OS_WINNT + cppcoro::detail::win32::socket_t m_handle; + bool m_skipCompletionOnSuccess; +#endif + + ip_endpoint m_localEndPoint; + ip_endpoint m_remoteEndPoint; + + }; + } +} + +#endif diff --git a/include/cppcoro/net/socket_accept_operation.hpp b/include/cppcoro/net/socket_accept_operation.hpp new file mode 100644 index 0000000..ae966f0 --- /dev/null +++ b/include/cppcoro/net/socket_accept_operation.hpp @@ -0,0 +1,108 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_ACCEPT_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_ACCEPT_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +# include +# include + +namespace cppcoro +{ + namespace net + { + class socket; + + class socket_accept_operation_impl + { + public: + + socket_accept_operation_impl( + socket& listeningSocket, + socket& acceptingSocket) noexcept + : m_listeningSocket(listeningSocket) + , m_acceptingSocket(acceptingSocket) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void get_result(cppcoro::detail::win32_overlapped_operation_base& operation); + + private: + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4324) // Structure padded due to alignment +#endif + + socket& m_listeningSocket; + socket& m_acceptingSocket; + alignas(8) std::uint8_t m_addressBuffer[88]; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + + }; + + class socket_accept_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_accept_operation( + socket& listeningSocket, + socket& acceptingSocket) noexcept + : m_impl(listeningSocket, acceptingSocket) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_accept_operation_impl m_impl; + + }; + + class socket_accept_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_accept_operation_cancellable( + socket& listeningSocket, + socket& acceptingSocket, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(listeningSocket, acceptingSocket) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_accept_operation_impl m_impl; + + }; + } +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_connect_operation.hpp b/include/cppcoro/net/socket_connect_operation.hpp new file mode 100644 index 0000000..b7eedd3 --- /dev/null +++ b/include/cppcoro/net/socket_connect_operation.hpp @@ -0,0 +1,95 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_CONNECT_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_CONNECT_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro +{ + namespace net + { + class socket; + + class socket_connect_operation_impl + { + public: + + socket_connect_operation_impl( + socket& socket, + const ip_endpoint& remoteEndPoint) noexcept + : m_socket(socket) + , m_remoteEndPoint(remoteEndPoint) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void get_result(cppcoro::detail::win32_overlapped_operation_base& operation); + + private: + + socket& m_socket; + ip_endpoint m_remoteEndPoint; + + }; + + class socket_connect_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_connect_operation( + socket& socket, + const ip_endpoint& remoteEndPoint) noexcept + : m_impl(socket, remoteEndPoint) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_connect_operation_impl m_impl; + + }; + + class socket_connect_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_connect_operation_cancellable( + socket& socket, + const ip_endpoint& remoteEndPoint, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(socket, remoteEndPoint) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_connect_operation_impl m_impl; + + }; + } +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_disconnect_operation.hpp b/include/cppcoro/net/socket_disconnect_operation.hpp new file mode 100644 index 0000000..7bdcc03 --- /dev/null +++ b/include/cppcoro/net/socket_disconnect_operation.hpp @@ -0,0 +1,85 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_DISCONNECT_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_DISCONNECT_OPERATION_HPP_INCLUDED + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro +{ + namespace net + { + class socket; + + class socket_disconnect_operation_impl + { + public: + + socket_disconnect_operation_impl(socket& socket) noexcept + : m_socket(socket) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void get_result(cppcoro::detail::win32_overlapped_operation_base& operation); + + private: + + socket& m_socket; + + }; + + class socket_disconnect_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_disconnect_operation(socket& socket) noexcept + : m_impl(socket) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_disconnect_operation_impl m_impl; + + }; + + class socket_disconnect_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_disconnect_operation_cancellable(socket& socket, cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(socket) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_disconnect_operation_impl m_impl; + + }; + } +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_recv_from_operation.hpp b/include/cppcoro/net/socket_recv_from_operation.hpp new file mode 100644 index 0000000..37f2d01 --- /dev/null +++ b/include/cppcoro/net/socket_recv_from_operation.hpp @@ -0,0 +1,106 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_RECV_FROM_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_RECV_FROM_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro::net +{ + class socket; + + class socket_recv_from_operation_impl + { + public: + + socket_recv_from_operation_impl( + socket& socket, + void* buffer, + std::size_t byteCount) noexcept + : m_socket(socket) + , m_buffer(buffer, byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + std::tuple get_result( + cppcoro::detail::win32_overlapped_operation_base& operation); + + private: + + socket& m_socket; + cppcoro::detail::win32::wsabuf m_buffer; + + static constexpr std::size_t sockaddrStorageAlignment = 4; + + // Storage suitable for either SOCKADDR_IN or SOCKADDR_IN6 + alignas(sockaddrStorageAlignment) std::uint8_t m_sourceSockaddrStorage[28]; + int m_sourceSockaddrLength; + + }; + + class socket_recv_from_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_recv_from_operation( + socket& socket, + void* buffer, + std::size_t byteCount) noexcept + : m_impl(socket, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_recv_from_operation_impl m_impl; + + }; + + class socket_recv_from_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_recv_from_operation_cancellable( + socket& socket, + void* buffer, + std::size_t byteCount, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(socket, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_recv_from_operation_impl m_impl; + + }; + +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_recv_operation.hpp b/include/cppcoro/net/socket_recv_operation.hpp new file mode 100644 index 0000000..c9dca8b --- /dev/null +++ b/include/cppcoro/net/socket_recv_operation.hpp @@ -0,0 +1,94 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_RECV_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_RECV_OPERATION_HPP_INCLUDED + +#include +#include + +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro::net +{ + class socket; + + class socket_recv_operation_impl + { + public: + + socket_recv_operation_impl( + socket& s, + void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_buffer(buffer, byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + + private: + + socket& m_socket; + cppcoro::detail::win32::wsabuf m_buffer; + + }; + + class socket_recv_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_recv_operation( + socket& s, + void* buffer, + std::size_t byteCount) noexcept + : m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_recv_operation_impl m_impl; + + }; + + class socket_recv_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_recv_operation_cancellable( + socket& s, + void* buffer, + std::size_t byteCount, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + + socket_recv_operation_impl m_impl; + + }; + +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_send_operation.hpp b/include/cppcoro/net/socket_send_operation.hpp new file mode 100644 index 0000000..702d2ab --- /dev/null +++ b/include/cppcoro/net/socket_send_operation.hpp @@ -0,0 +1,94 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_SEND_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_SEND_OPERATION_HPP_INCLUDED + +#include +#include + +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro::net +{ + class socket; + + class socket_send_operation_impl + { + public: + + socket_send_operation_impl( + socket& s, + const void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_buffer(const_cast(buffer), byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + + private: + + socket& m_socket; + cppcoro::detail::win32::wsabuf m_buffer; + + }; + + class socket_send_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_send_operation( + socket& s, + const void* buffer, + std::size_t byteCount) noexcept + : m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_send_operation_impl m_impl; + + }; + + class socket_send_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_send_operation_cancellable( + socket& s, + const void* buffer, + std::size_t byteCount, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { return m_impl.cancel(*this); } + + socket_send_operation_impl m_impl; + + }; + +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/net/socket_send_to_operation.hpp b/include/cppcoro/net/socket_send_to_operation.hpp new file mode 100644 index 0000000..60d51b2 --- /dev/null +++ b/include/cppcoro/net/socket_send_to_operation.hpp @@ -0,0 +1,100 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_NET_SOCKET_SEND_TO_OPERATION_HPP_INCLUDED +#define CPPCORO_NET_SOCKET_SEND_TO_OPERATION_HPP_INCLUDED + +#include +#include +#include + +#include + +#if CPPCORO_OS_WINNT +# include +# include + +namespace cppcoro::net +{ + class socket; + + class socket_send_to_operation_impl + { + public: + + socket_send_to_operation_impl( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_destination(destination) + , m_buffer(const_cast(buffer), byteCount) + {} + + bool try_start(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + void cancel(cppcoro::detail::win32_overlapped_operation_base& operation) noexcept; + + private: + + socket& m_socket; + ip_endpoint m_destination; + cppcoro::detail::win32::wsabuf m_buffer; + + }; + + class socket_send_to_operation + : public cppcoro::detail::win32_overlapped_operation + { + public: + + socket_send_to_operation( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount) noexcept + : m_impl(s, destination, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_send_to_operation_impl m_impl; + + }; + + class socket_send_to_operation_cancellable + : public cppcoro::detail::win32_overlapped_operation_cancellable + { + public: + + socket_send_to_operation_cancellable( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount, + cancellation_token&& ct) noexcept + : cppcoro::detail::win32_overlapped_operation_cancellable(std::move(ct)) + , m_impl(s, destination, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::win32_overlapped_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { return m_impl.cancel(*this); } + + socket_send_to_operation_impl m_impl; + + }; + +} + +#endif // CPPCORO_OS_WINNT + +#endif diff --git a/include/cppcoro/on_scope_exit.hpp b/include/cppcoro/on_scope_exit.hpp new file mode 100644 index 0000000..8e2920d --- /dev/null +++ b/include/cppcoro/on_scope_exit.hpp @@ -0,0 +1,147 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ON_SCOPE_EXIT_HPP_INCLUDED +#define CPPCORO_ON_SCOPE_EXIT_HPP_INCLUDED + +#include +#include + +namespace cppcoro +{ + template + class scoped_lambda + { + public: + + scoped_lambda(FUNC&& func) + : m_func(std::forward(func)) + , m_cancelled(false) + {} + + scoped_lambda(const scoped_lambda& other) = delete; + + scoped_lambda(scoped_lambda&& other) + : m_func(std::forward(other.m_func)) + , m_cancelled(other.m_cancelled) + { + other.cancel(); + } + + ~scoped_lambda() + { + if (!m_cancelled) + { + m_func(); + } + } + + void cancel() + { + m_cancelled = true; + } + + void call_now() + { + m_cancelled = true; + m_func(); + } + + private: + + FUNC m_func; + bool m_cancelled; + + }; + + /// A scoped lambda that executes the lambda when the object destructs + /// but only if exiting due to an exception (CALL_ON_FAILURE = true) or + /// only if not exiting due to an exception (CALL_ON_FAILURE = false). + template + class conditional_scoped_lambda + { + public: + + conditional_scoped_lambda(FUNC&& func) + : m_func(std::forward(func)) + , m_uncaughtExceptionCount(std::uncaught_exceptions()) + , m_cancelled(false) + {} + + conditional_scoped_lambda(const conditional_scoped_lambda& other) = delete; + + conditional_scoped_lambda(conditional_scoped_lambda&& other) + noexcept(std::is_nothrow_move_constructible::value) + : m_func(std::forward(other.m_func)) + , m_uncaughtExceptionCount(other.m_uncaughtExceptionCount) + , m_cancelled(other.m_cancelled) + { + other.cancel(); + } + + ~conditional_scoped_lambda() noexcept(CALL_ON_FAILURE || noexcept(std::declval()())) + { + if (!m_cancelled && (is_unwinding_due_to_exception() == CALL_ON_FAILURE)) + { + m_func(); + } + } + + void cancel() noexcept + { + m_cancelled = true; + } + + private: + + bool is_unwinding_due_to_exception() const noexcept + { + return std::uncaught_exceptions() > m_uncaughtExceptionCount; + } + + FUNC m_func; + int m_uncaughtExceptionCount; + bool m_cancelled; + + }; + + /// Returns an object that calls the provided function when it goes out + /// of scope either normally or due to an uncaught exception unwinding + /// the stack. + /// + /// \param func + /// The function to call when the scope exits. + /// The function must be noexcept. + template + auto on_scope_exit(FUNC&& func) + { + return scoped_lambda{ std::forward(func) }; + } + + /// Returns an object that calls the provided function when it goes out + /// of scope due to an uncaught exception unwinding the stack. + /// + /// \param func + /// The function to be called if unwinding due to an exception. + /// The function must be noexcept. + template + auto on_scope_failure(FUNC&& func) + { + return conditional_scoped_lambda{ std::forward(func) }; + } + + /// Returns an object that calls the provided function when it goes out + /// of scope via normal execution (ie. not unwinding due to an exception). + /// + /// \param func + /// The function to call if the scope exits normally. + /// The function does not necessarily need to be noexcept. + template + auto on_scope_success(FUNC&& func) + { + return conditional_scoped_lambda{ std::forward(func) }; + } +} + +#endif diff --git a/include/cppcoro/operation_cancelled.hpp b/include/cppcoro/operation_cancelled.hpp new file mode 100644 index 0000000..2746ae0 --- /dev/null +++ b/include/cppcoro/operation_cancelled.hpp @@ -0,0 +1,24 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_OPERATION_CANCELLED_HPP_INCLUDED +#define CPPCORO_OPERATION_CANCELLED_HPP_INCLUDED + +#include + +namespace cppcoro +{ + class operation_cancelled : public std::exception + { + public: + + operation_cancelled() noexcept + : std::exception() + {} + + const char* what() const noexcept override { return "operation cancelled"; } + }; +} + +#endif diff --git a/include/cppcoro/read_only_file.hpp b/include/cppcoro/read_only_file.hpp new file mode 100644 index 0000000..2527142 --- /dev/null +++ b/include/cppcoro/read_only_file.hpp @@ -0,0 +1,59 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_READ_ONLY_FILE_HPP_INCLUDED +#define CPPCORO_READ_ONLY_FILE_HPP_INCLUDED + +#include +#include +#include + +#include + +namespace cppcoro +{ + class read_only_file : public readable_file + { + public: + + /// Open a file for read-only access. + /// + /// \param ioContext + /// The I/O context to use when dispatching I/O completion events. + /// When asynchronous read operations on this file complete the + /// completion events will be dispatched to an I/O thread associated + /// with the I/O context. + /// + /// \param path + /// Path of the file to open. + /// + /// \param shareMode + /// Specifies the access to be allowed on the file concurrently with this file access. + /// + /// \param bufferingMode + /// Specifies the modes/hints to provide to the OS that affects the behaviour + /// of its file buffering. + /// + /// \return + /// An object that can be used to read from the file. + /// + /// \throw std::system_error + /// If the file could not be opened for read. + [[nodiscard]] + static read_only_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_share_mode shareMode = file_share_mode::read, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + protected: + +#if CPPCORO_OS_WINNT + read_only_file(detail::win32::safe_handle&& fileHandle) noexcept; +#endif + + }; +} + +#endif diff --git a/include/cppcoro/read_write_file.hpp b/include/cppcoro/read_write_file.hpp new file mode 100644 index 0000000..de9e537 --- /dev/null +++ b/include/cppcoro/read_write_file.hpp @@ -0,0 +1,66 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_READ_WRITE_FILE_HPP_INCLUDED +#define CPPCORO_READ_WRITE_FILE_HPP_INCLUDED + +#include +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + class read_write_file : public readable_file, public writable_file + { + public: + + /// Open a file for read-write access. + /// + /// \param ioContext + /// The I/O context to use when dispatching I/O completion events. + /// When asynchronous write operations on this file complete the + /// completion events will be dispatched to an I/O thread associated + /// with the I/O context. + /// + /// \param pathMode + /// Path of the file to open. + /// + /// \param openMode + /// Specifies how the file should be opened and how to handle cases + /// when the file exists or doesn't exist. + /// + /// \param shareMode + /// Specifies the access to be allowed on the file concurrently with this file access. + /// + /// \param bufferingMode + /// Specifies the modes/hints to provide to the OS that affects the behaviour + /// of its file buffering. + /// + /// \return + /// An object that can be used to write to the file. + /// + /// \throw std::system_error + /// If the file could not be opened for write. + [[nodiscard]] + static read_write_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode = file_open_mode::create_or_open, + file_share_mode shareMode = file_share_mode::none, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + protected: + +#if CPPCORO_OS_WINNT + read_write_file(detail::win32::safe_handle&& fileHandle) noexcept; +#endif + + }; +} + +#endif diff --git a/include/cppcoro/readable_file.hpp b/include/cppcoro/readable_file.hpp new file mode 100644 index 0000000..01159df --- /dev/null +++ b/include/cppcoro/readable_file.hpp @@ -0,0 +1,65 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_READABLE_FILE_HPP_INCLUDED +#define CPPCORO_READABLE_FILE_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + class readable_file : virtual public file + { + public: + + /// Read some data from the file. + /// + /// Reads \a byteCount bytes from the file starting at \a offset + /// into the specified \a buffer. + /// + /// \param offset + /// The offset within the file to start reading from. + /// If the file has been opened using file_buffering_mode::unbuffered + /// then the offset must be a multiple of the file-system's sector size. + /// + /// \param buffer + /// The buffer to read the file contents into. + /// If the file has been opened using file_buffering_mode::unbuffered + /// then the address of the start of the buffer must be a multiple of + /// the file-system's sector size. + /// + /// \param byteCount + /// The number of bytes to read from the file. + /// If the file has been opeend using file_buffering_mode::unbuffered + /// then the byteCount must be a multiple of the file-system's sector size. + /// + /// \param ct + /// An optional cancellation_token that can be used to cancel the + /// read operation before it completes. + /// + /// \return + /// An object that represents the read-operation. + /// This object must be co_await'ed to start the read operation. + [[nodiscard]] + file_read_operation read( + std::uint64_t offset, + void* buffer, + std::size_t byteCount) const noexcept; + [[nodiscard]] + file_read_operation_cancellable read( + std::uint64_t offset, + void* buffer, + std::size_t byteCount, + cancellation_token ct) const noexcept; + + protected: + + using file::file; + + }; +} + +#endif diff --git a/include/cppcoro/recursive_generator.hpp b/include/cppcoro/recursive_generator.hpp new file mode 100644 index 0000000..65af46c --- /dev/null +++ b/include/cppcoro/recursive_generator.hpp @@ -0,0 +1,345 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_RECURSIVE_GENERATOR_HPP_INCLUDED +#define CPPCORO_RECURSIVE_GENERATOR_HPP_INCLUDED + +#include + +#include +#include +#include +#include +#include + +namespace cppcoro +{ + template + class [[nodiscard]] recursive_generator + { + public: + + class promise_type final + { + public: + + promise_type() noexcept + : m_value(nullptr) + , m_exception(nullptr) + , m_root(this) + , m_parentOrLeaf(this) + {} + + promise_type(const promise_type&) = delete; + promise_type(promise_type&&) = delete; + + auto get_return_object() noexcept + { + return recursive_generator{ *this }; + } + + cppcoro::suspend_always initial_suspend() noexcept + { + return {}; + } + + cppcoro::suspend_always final_suspend() noexcept + { + return {}; + } + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void return_void() noexcept {} + + cppcoro::suspend_always yield_value(T& value) noexcept + { + m_value = std::addressof(value); + return {}; + } + + cppcoro::suspend_always yield_value(T&& value) noexcept + { + m_value = std::addressof(value); + return {}; + } + + auto yield_value(recursive_generator&& generator) noexcept + { + return yield_value(generator); + } + + auto yield_value(recursive_generator& generator) noexcept + { + struct awaitable + { + + awaitable(promise_type* childPromise) + : m_childPromise(childPromise) + {} + + bool await_ready() noexcept + { + return this->m_childPromise == nullptr; + } + + void await_suspend(cppcoro::coroutine_handle) noexcept + {} + + void await_resume() + { + if (this->m_childPromise != nullptr) + { + this->m_childPromise->throw_if_exception(); + } + } + + private: + promise_type* m_childPromise; + }; + + if (generator.m_promise != nullptr) + { + m_root->m_parentOrLeaf = generator.m_promise; + generator.m_promise->m_root = m_root; + generator.m_promise->m_parentOrLeaf = this; + generator.m_promise->resume(); + + if (!generator.m_promise->is_complete()) + { + return awaitable{ generator.m_promise }; + } + + m_root->m_parentOrLeaf = this; + } + + return awaitable{ nullptr }; + } + + // Don't allow any use of 'co_await' inside the recursive_generator coroutine. + template + cppcoro::suspend_never await_transform(U&& value) = delete; + + void destroy() noexcept + { + cppcoro::coroutine_handle::from_promise(*this).destroy(); + } + + void throw_if_exception() + { + if (m_exception != nullptr) + { + std::rethrow_exception(std::move(m_exception)); + } + } + + bool is_complete() noexcept + { + return cppcoro::coroutine_handle::from_promise(*this).done(); + } + + T& value() noexcept + { + assert(this == m_root); + assert(!is_complete()); + return *(m_parentOrLeaf->m_value); + } + + void pull() noexcept + { + assert(this == m_root); + assert(!m_parentOrLeaf->is_complete()); + + m_parentOrLeaf->resume(); + + while (m_parentOrLeaf != this && m_parentOrLeaf->is_complete()) + { + m_parentOrLeaf = m_parentOrLeaf->m_parentOrLeaf; + m_parentOrLeaf->resume(); + } + } + + private: + + void resume() noexcept + { + cppcoro::coroutine_handle::from_promise(*this).resume(); + } + + std::add_pointer_t m_value; + std::exception_ptr m_exception; + + promise_type* m_root; + + // If this is the promise of the root generator then this field + // is a pointer to the leaf promise. + // For non-root generators this is a pointer to the parent promise. + promise_type* m_parentOrLeaf; + + }; + + recursive_generator() noexcept + : m_promise(nullptr) + {} + + recursive_generator(promise_type& promise) noexcept + : m_promise(&promise) + {} + + recursive_generator(recursive_generator&& other) noexcept + : m_promise(other.m_promise) + { + other.m_promise = nullptr; + } + + recursive_generator(const recursive_generator& other) = delete; + recursive_generator& operator=(const recursive_generator& other) = delete; + + ~recursive_generator() + { + if (m_promise != nullptr) + { + m_promise->destroy(); + } + } + + recursive_generator& operator=(recursive_generator&& other) noexcept + { + if (this != &other) + { + if (m_promise != nullptr) + { + m_promise->destroy(); + } + + m_promise = other.m_promise; + other.m_promise = nullptr; + } + + return *this; + } + + class iterator + { + public: + + using iterator_category = std::input_iterator_tag; + // What type should we use for counting elements of a potentially infinite sequence? + using difference_type = std::ptrdiff_t; + using value_type = std::remove_reference_t; + using reference = std::conditional_t, T, T&>; + using pointer = std::add_pointer_t; + + iterator() noexcept + : m_promise(nullptr) + {} + + explicit iterator(promise_type* promise) noexcept + : m_promise(promise) + {} + + bool operator==(const iterator& other) const noexcept + { + return m_promise == other.m_promise; + } + + bool operator!=(const iterator& other) const noexcept + { + return m_promise != other.m_promise; + } + + iterator& operator++() + { + assert(m_promise != nullptr); + assert(!m_promise->is_complete()); + + m_promise->pull(); + if (m_promise->is_complete()) + { + auto* temp = m_promise; + m_promise = nullptr; + temp->throw_if_exception(); + } + + return *this; + } + + void operator++(int) + { + (void)operator++(); + } + + reference operator*() const noexcept + { + assert(m_promise != nullptr); + return static_cast(m_promise->value()); + } + + pointer operator->() const noexcept + { + return std::addressof(operator*()); + } + + private: + + promise_type* m_promise; + + }; + + iterator begin() + { + if (m_promise != nullptr) + { + m_promise->pull(); + if (!m_promise->is_complete()) + { + return iterator(m_promise); + } + + m_promise->throw_if_exception(); + } + + return iterator(nullptr); + } + + iterator end() noexcept + { + return iterator(nullptr); + } + + void swap(recursive_generator& other) noexcept + { + std::swap(m_promise, other.m_promise); + } + + private: + + friend class promise_type; + + promise_type* m_promise; + + }; + + template + void swap(recursive_generator& a, recursive_generator& b) noexcept + { + a.swap(b); + } + + // Note: When applying fmap operator to a recursive_generator we just yield a non-recursive + // generator since we generally won't be using the result in a recursive context. + template + generator::iterator::reference>> fmap(FUNC func, recursive_generator source) + { + for (auto&& value : source) + { + co_yield std::invoke(func, static_cast(value)); + } + } +} + +#endif diff --git a/include/cppcoro/resume_on.hpp b/include/cppcoro/resume_on.hpp new file mode 100644 index 0000000..b26188b --- /dev/null +++ b/include/cppcoro/resume_on.hpp @@ -0,0 +1,129 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_RESUME_ON_HPP_INCLUDED +#define CPPCORO_RESUME_ON_HPP_INCLUDED + +#include +#include +#include +#include + +#include +#include + +namespace cppcoro +{ + template + struct resume_on_transform + { + explicit resume_on_transform(SCHEDULER& s) noexcept + : scheduler(s) + {} + + SCHEDULER& scheduler; + }; + + template + resume_on_transform resume_on(SCHEDULER& scheduler) noexcept + { + return resume_on_transform(scheduler); + } + + template + decltype(auto) operator|(T&& value, resume_on_transform transform) + { + return resume_on(transform.scheduler, std::forward(value)); + } + + template< + typename SCHEDULER, + typename AWAITABLE, + typename AWAIT_RESULT = detail::remove_rvalue_reference_t::await_result_t>, + std::enable_if_t, int> = 0> + auto resume_on(SCHEDULER& scheduler, AWAITABLE awaitable) + -> task + { + bool rescheduled = false; + std::exception_ptr ex; + try + { + // We manually get the awaiter here so that we can keep + // it alive across the call to `scheduler.schedule()` + // just in case the result is a reference to a value + // in the awaiter that would otherwise be a temporary + // and destructed before the value could be returned. + + auto&& awaiter = detail::get_awaiter(static_cast(awaitable)); + + auto&& result = co_await static_cast(awaiter); + + // Flag as rescheduled before scheduling in case it is the + // schedule() operation that throws an exception as we don't + // want to attempt to schedule twice if scheduling fails. + rescheduled = true; + + co_await scheduler.schedule(); + + co_return static_cast(result); + } + catch (...) + { + ex = std::current_exception(); + } + + // We still want to resume on the scheduler even in the presence + // of an exception. + if (!rescheduled) + { + co_await scheduler.schedule(); + } + + std::rethrow_exception(ex); + } + + template< + typename SCHEDULER, + typename AWAITABLE, + typename AWAIT_RESULT = detail::remove_rvalue_reference_t::await_result_t>, + std::enable_if_t, int> = 0> + auto resume_on(SCHEDULER& scheduler, AWAITABLE awaitable) + -> task<> + { + std::exception_ptr ex; + try + { + co_await static_cast(awaitable); + } + catch (...) + { + ex = std::current_exception(); + } + + // NOTE: We're assuming that `schedule()` operation is noexcept + // here. If it were to throw what would we do if 'ex' was non-null? + // Presumably we'd treat it the same as throwing an exception while + // unwinding and call std::terminate()? + + co_await scheduler.schedule(); + + if (ex) + { + std::rethrow_exception(ex); + } + } + + template + async_generator resume_on(SCHEDULER& scheduler, async_generator source) + { + for (auto iter = co_await source.begin(); iter != source.end(); co_await ++iter) + { + auto& value = *iter; + co_await scheduler.schedule(); + co_yield value; + } + } +} + +#endif diff --git a/include/cppcoro/round_robin_scheduler.hpp b/include/cppcoro/round_robin_scheduler.hpp new file mode 100644 index 0000000..cd15405 --- /dev/null +++ b/include/cppcoro/round_robin_scheduler.hpp @@ -0,0 +1,124 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_ROUND_ROBIN_SCHEDULER_HPP_INCLUDED +#define CPPCORO_ROUND_ROBIN_SCHEDULER_HPP_INCLUDED + +#include + +#include +#include +#include +#include +#include + +namespace cppcoro +{ +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + /// This is a scheduler class that schedules coroutines in a round-robin + /// fashion once N coroutines have been scheduled to it. + /// + /// Only supports access from a single thread at a time so + /// + /// This implementation was inspired by Gor Nishanov's CppCon 2018 talk + /// about nano-coroutines. + /// + /// The implementation relies on symmetric transfer and noop_coroutine() + /// and so only works with a relatively recent version of Clang and does + /// not yet work with MSVC. + template + class round_robin_scheduler + { + static_assert( + N >= 2, + "Round robin scheduler must be configured to support at least two coroutines"); + + class schedule_operation + { + public: + explicit schedule_operation(round_robin_scheduler& s) noexcept : m_scheduler(s) {} + + bool await_ready() noexcept + { + return false; + } + + cppcoro::coroutine_handle<> await_suspend( + cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_scheduler.exchange_next(awaitingCoroutine); + } + + void await_resume() noexcept {} + + private: + round_robin_scheduler& m_scheduler; + }; + + friend class schedule_operation; + + public: + round_robin_scheduler() noexcept + : m_index(0) + , m_noop(cppcoro::noop_coroutine()) + { + for (size_t i = 0; i < N - 1; ++i) + { + m_coroutines[i] = m_noop(); + } + } + + ~round_robin_scheduler() + { + // All tasks should have been joined before calling destructor. + assert(std::all_of( + m_coroutines.begin(), + m_coroutines.end(), + [&](auto h) { return h == m_noop; })); + } + + schedule_operation schedule() noexcept + { + return schedule_operation{ *this }; + } + + /// Resume any queued coroutines until there are no more coroutines. + void drain() noexcept + { + size_t countRemaining = N - 1; + do + { + auto nextToResume = exchange_next(m_noop); + if (nextToResume != m_noop) + { + nextToResume.resume(); + countRemaining = N - 1; + } + else + { + --countRemaining; + } + } while (countRemaining > 0); + } + + private: + + cppcoro::coroutine_handle exchange_next( + cppcoro::coroutine_handle<> coroutine) noexcept + { + auto coroutineToResume = std::exchange( + m_scheduler.m_coroutines[m_scheduler.m_index], + awaitingCoroutine); + m_scheduler.m_index = m_scheduler.m_index < (N - 2) ? m_scheduler.m_index + 1 : 0; + return coroutineToResume; + } + + size_t m_index; + const cppcoro::coroutine_handle<> m_noop; + std::array, N - 1> m_coroutines; + }; +#endif +} + +#endif diff --git a/include/cppcoro/schedule_on.hpp b/include/cppcoro/schedule_on.hpp new file mode 100644 index 0000000..98ecf42 --- /dev/null +++ b/include/cppcoro/schedule_on.hpp @@ -0,0 +1,69 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SCHEDULE_ON_HPP_INCLUDED +#define CPPCORO_SCHEDULE_ON_HPP_INCLUDED + +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + template + struct schedule_on_transform + { + explicit schedule_on_transform(SCHEDULER& scheduler) noexcept + : scheduler(scheduler) + {} + + SCHEDULER& scheduler; + }; + + template + schedule_on_transform schedule_on(SCHEDULER& scheduler) + { + return schedule_on_transform{ scheduler }; + } + + template + decltype(auto) operator|(T&& value, schedule_on_transform transform) + { + return schedule_on(transform.scheduler, std::forward(value)); + } + + template + auto schedule_on(SCHEDULER& scheduler, AWAITABLE awaitable) + -> task::await_result_t>> + { + co_await scheduler.schedule(); + co_return co_await std::move(awaitable); + } + + template + async_generator schedule_on(SCHEDULER& scheduler, async_generator source) + { + // Transfer exection to the scheduler before the implicit calls to + // 'co_await begin()' or subsequent calls to `co_await iterator::operator++()` + // below. This ensures that all calls to the generator's coroutine_handle<>::resume() + // are executed on the execution context of the scheduler. + co_await scheduler.schedule(); + + const auto itEnd = source.end(); + auto it = co_await source.begin(); + while (it != itEnd) + { + co_yield *it; + + co_await scheduler.schedule(); + + (void)co_await ++it; + } + } +} + +#endif diff --git a/include/cppcoro/sequence_barrier.hpp b/include/cppcoro/sequence_barrier.hpp new file mode 100644 index 0000000..e1f0eb8 --- /dev/null +++ b/include/cppcoro/sequence_barrier.hpp @@ -0,0 +1,470 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SEQUENCE_BARRIER_HPP_INCLUDED +#define CPPCORO_SEQUENCE_BARRIER_HPP_INCLUDED + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + template + class sequence_barrier_wait_operation_base; + + template + class sequence_barrier_wait_operation; + + /// A sequence barrier is a synchronisation primitive that allows a single-producer + /// and multiple-consumers to coordinate with respect to a monotonically increasing + /// sequence number. + /// + /// A single producer advances the sequence number by publishing new sequence numbers in a + /// monotonically increasing order. One or more consumers can query the last-published + /// sequence number and can wait until a particular sequence number has been published. + /// + /// A sequence barrier can be used to represent a cursor into a thread-safe producer/consumer + /// ring-buffer. + /// + /// See the LMAX Disruptor pattern for more background: + /// https://lmax-exchange.github.io/disruptor/files/Disruptor-1.0.pdf + template< + typename SEQUENCE = std::size_t, + typename TRAITS = sequence_traits> + class sequence_barrier + { + static_assert( + std::is_integral_v, + "sequence_barrier requires an integral sequence type"); + + using awaiter_t = sequence_barrier_wait_operation_base; + + public: + + /// Construct a sequence barrier with the specified initial sequence number + /// as the initial value 'last_published()'. + sequence_barrier(SEQUENCE initialSequence = TRAITS::initial_sequence) noexcept + : m_lastPublished(initialSequence) + , m_awaiters(nullptr) + {} + + ~sequence_barrier() + { + // Shouldn't be destructing a sequence barrier if there are still waiters. + assert(m_awaiters.load(std::memory_order_relaxed) == nullptr); + } + + /// Query the sequence number that was most recently published by the producer. + /// + /// You can assume that all sequence numbers prior to the returned sequence number + /// have also been published. This means you can safely access all elements with + /// sequence numbers up to and including the returned sequence number without any + /// further synchronisation. + SEQUENCE last_published() const noexcept + { + return m_lastPublished.load(std::memory_order_acquire); + } + + /// Wait until a particular sequence number has been published. + /// + /// If the specified sequence number is not yet published then the awaiting coroutine + /// will be suspended and later resumed inside the call to publish() that publishes + /// the specified sequence number. + /// + /// \param targetSequence + /// The sequence number to wait for. + /// + /// \return + /// An awaitable that when co_await'ed will suspend the awaiting coroutine until + /// the specified target sequence number has been published. + /// The result of the co_await expression will be the last-known published sequence + /// number. This is guaranteed not to precede \p targetSequence but may be a sequence + /// number after \p targetSequence, which indicates that more elements have been + /// published than you were waiting for. + template + [[nodiscard]] + sequence_barrier_wait_operation wait_until_published( + SEQUENCE targetSequence, + SCHEDULER& scheduler) const noexcept; + + /// Publish the specified sequence number to consumers. + /// + /// This publishes all sequence numbers up to and including the specified sequence + /// number. This will resume any coroutine that was suspended waiting for a sequence + /// number that was published by this operation. + /// + /// \param sequence + /// The sequence number to publish. This number must not precede the current + /// last_published() value. ie. the published sequence numbers must be monotonically + /// increasing. + void publish(SEQUENCE sequence) noexcept; + + private: + + friend class sequence_barrier_wait_operation_base; + + void add_awaiter(awaiter_t* awaiter) const noexcept; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4324) // C4324: structure was padded due to alignment specifier +#endif + + // First cache-line is written to by the producer only + alignas(CPPCORO_CPU_CACHE_LINE) + std::atomic m_lastPublished; + + // Second cache-line is written to by both the producer and consumers + alignas(CPPCORO_CPU_CACHE_LINE) + mutable std::atomic m_awaiters; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + + }; + + template + class sequence_barrier_wait_operation_base + { + public: + + explicit sequence_barrier_wait_operation_base( + const sequence_barrier& barrier, + SEQUENCE targetSequence) noexcept + : m_barrier(barrier) + , m_targetSequence(targetSequence) + , m_lastKnownPublished(barrier.last_published()) + , m_readyToResume(false) + {} + + sequence_barrier_wait_operation_base( + const sequence_barrier_wait_operation_base& other) noexcept + : m_barrier(other.m_barrier) + , m_targetSequence(other.m_targetSequence) + , m_lastKnownPublished(other.m_lastKnownPublished) + , m_readyToResume(false) + {} + + bool await_ready() const noexcept + { + return !TRAITS::precedes(m_lastKnownPublished, m_targetSequence); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + m_barrier.add_awaiter(this); + return !m_readyToResume.exchange(true, std::memory_order_acquire); + } + + SEQUENCE await_resume() noexcept + { + return m_lastKnownPublished; + } + + protected: + + friend class sequence_barrier; + + void resume() noexcept + { + // This synchronises with the exchange(true, std::memory_order_acquire) in await_suspend(). + if (m_readyToResume.exchange(true, std::memory_order_release)) + { + resume_impl(); + } + } + + virtual void resume_impl() noexcept = 0; + + const sequence_barrier& m_barrier; + const SEQUENCE m_targetSequence; + SEQUENCE m_lastKnownPublished; + sequence_barrier_wait_operation_base* m_next; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + std::atomic m_readyToResume; + + }; + + template + class sequence_barrier_wait_operation : public sequence_barrier_wait_operation_base + { + using schedule_operation = decltype(std::declval().schedule()); + + public: + sequence_barrier_wait_operation( + const sequence_barrier& barrier, + SEQUENCE targetSequence, + SCHEDULER& scheduler) noexcept + : sequence_barrier_wait_operation_base(barrier, targetSequence) + , m_scheduler(scheduler) + {} + + sequence_barrier_wait_operation( + const sequence_barrier_wait_operation& other) noexcept + : sequence_barrier_wait_operation_base(other) + , m_scheduler(other.m_scheduler) + {} + + ~sequence_barrier_wait_operation() + { + if (m_isScheduleAwaiterCreated) + { + m_scheduleAwaiter.destruct(); + } + if (m_isScheduleOperationCreated) + { + m_scheduleOperation.destruct(); + } + } + + decltype(auto) await_resume() noexcept(noexcept(m_scheduleAwaiter->await_resume())) + { + if (m_isScheduleAwaiterCreated) + { + m_scheduleAwaiter->await_resume(); + } + + return sequence_barrier_wait_operation_base::await_resume(); + } + + private: + + void resume_impl() noexcept override + { + try + { + m_scheduleOperation.construct(m_scheduler.schedule()); + m_isScheduleOperationCreated = true; + + m_scheduleAwaiter.construct(detail::get_awaiter( + static_cast(*m_scheduleOperation))); + m_isScheduleAwaiterCreated = true; + + if (!m_scheduleAwaiter->await_ready()) + { + using await_suspend_result_t = decltype(m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine)); + if constexpr (std::is_void_v) + { + m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine); + return; + } + else if constexpr (std::is_same_v) + { + if (m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine)) + { + return; + } + } + else + { + // Assume it returns a coroutine_handle. + m_scheduleAwaiter->await_suspend(this->m_awaitingCoroutine).resume(); + return; + } + } + } + catch (...) + { + // Ignore failure to reschedule and resume inline? + // Should we catch the exception and rethrow from await_resume()? + // Or should we require that 'co_await scheduler.schedule()' is noexcept? + } + + // Resume outside the catch-block. + this->m_awaitingCoroutine.resume(); + } + + SCHEDULER& m_scheduler; + // Can't use std::optional here since T could be a reference. + detail::manual_lifetime m_scheduleOperation; + detail::manual_lifetime::awaiter_t> m_scheduleAwaiter; + bool m_isScheduleOperationCreated = false; + bool m_isScheduleAwaiterCreated = false; + }; + + template + template + [[nodiscard]] + sequence_barrier_wait_operation sequence_barrier::wait_until_published( + SEQUENCE targetSequence, + SCHEDULER& scheduler) const noexcept + { + return sequence_barrier_wait_operation(*this, targetSequence, scheduler); + } + + template + void sequence_barrier::publish(SEQUENCE sequence) noexcept + { + m_lastPublished.store(sequence, std::memory_order_seq_cst); + + // Cheaper check to see if there are any awaiting coroutines. + auto* awaiters = m_awaiters.load(std::memory_order_seq_cst); + if (awaiters == nullptr) + { + return; + } + + // Acquire the list of awaiters. + // Note we may be racing with add_awaiter() which could also acquire the list of waiters + // so we need to check again whether we won the race and acquired the list. + awaiters = m_awaiters.exchange(nullptr, std::memory_order_acquire); + if (awaiters == nullptr) + { + return; + } + + // Check the list of awaiters for ones that are now satisfied by the sequence number + // we just published. Awaiters are added to either the 'awaitersToResume' list or to + // the 'awaitersToRequeue' list. + awaiter_t* awaitersToResume; + awaiter_t** awaitersToResumeTail = &awaitersToResume; + + awaiter_t* awaitersToRequeue; + awaiter_t** awaitersToRequeueTail = &awaitersToRequeue; + + do + { + if (TRAITS::precedes(sequence, awaiters->m_targetSequence)) + { + // Target sequence not reached. Append to 'requeue' list. + *awaitersToRequeueTail = awaiters; + awaitersToRequeueTail = &awaiters->m_next; + } + else + { + // Target sequence reached. Append to 'resume' list. + *awaitersToResumeTail = awaiters; + awaitersToResumeTail = &awaiters->m_next; + } + awaiters = awaiters->m_next; + } while (awaiters != nullptr); + + // Null-terminate the two lists. + *awaitersToRequeueTail = nullptr; + *awaitersToResumeTail = nullptr; + + if (awaitersToRequeue != nullptr) + { + awaiter_t* oldHead = nullptr; + while (!m_awaiters.compare_exchange_weak( + oldHead, + awaitersToRequeue, + std::memory_order_release, + std::memory_order_relaxed)) + { + *awaitersToRequeueTail = oldHead; + } + } + + while (awaitersToResume != nullptr) + { + auto* next = awaitersToResume->m_next; + awaitersToResume->m_lastKnownPublished = sequence; + awaitersToResume->resume(); + awaitersToResume = next; + } + } + + template + void sequence_barrier::add_awaiter(awaiter_t* awaiter) const noexcept + { + SEQUENCE targetSequence = awaiter->m_targetSequence; + awaiter_t* awaitersToRequeue = awaiter; + awaiter_t** awaitersToRequeueTail = &awaiter->m_next; + + SEQUENCE lastKnownPublished; + awaiter_t* awaitersToResume; + awaiter_t** awaitersToResumeTail = &awaitersToResume; + + do + { + // Enqueue the awaiter(s) + { + auto* oldHead = m_awaiters.load(std::memory_order_relaxed); + do + { + *awaitersToRequeueTail = oldHead; + } while (!m_awaiters.compare_exchange_weak( + oldHead, + awaitersToRequeue, + std::memory_order_seq_cst, + std::memory_order_relaxed)); + } + + // Check that the sequence we were waiting for wasn't published while + // we were enqueueing the waiter. + // This needs to be seq_cst memory order to ensure that in the case that the producer + // publishes a new sequence number concurrently with this call that we either see + // their write to m_lastPublished after enqueueing our awaiter, or they see our + // write to m_awaiters after their write to m_lastPublished. + lastKnownPublished = m_lastPublished.load(std::memory_order_seq_cst); + if (TRAITS::precedes(lastKnownPublished, targetSequence)) + { + // None of the the awaiters we enqueued have been satisfied yet. + break; + } + + // Reset the requeue list to empty + awaitersToRequeueTail = &awaitersToRequeue; + + // At least one of the awaiters we just enqueued is now satisfied by a concurrently + // published sequence number. The producer thread may not have seen our write to m_awaiters + // so we need to try to re-acquire the list of awaiters to ensure that the waiters that + // are now satisfied are woken up. + auto* awaiters = m_awaiters.exchange(nullptr, std::memory_order_acquire); + + auto minDiff = std::numeric_limits::max(); + + while (awaiters != nullptr) + { + const auto diff = TRAITS::difference(awaiters->m_targetSequence, lastKnownPublished); + if (diff > 0) + { + *awaitersToRequeueTail = awaiters; + awaitersToRequeueTail = &awaiters->m_next; + minDiff = diff < minDiff ? diff : minDiff; + } + else + { + *awaitersToResumeTail = awaiters; + awaitersToResumeTail = &awaiters->m_next; + } + + awaiters = awaiters->m_next; + } + + // Null-terminate the list of awaiters to requeue. + *awaitersToRequeueTail = nullptr; + + // Calculate the earliest target sequence required by any of the awaiters to requeue. + targetSequence = static_cast(lastKnownPublished + minDiff); + + } while (awaitersToRequeue != nullptr); + + // Null-terminate the list of awaiters to resume + *awaitersToResumeTail = nullptr; + + // Resume the awaiters that are ready + while (awaitersToResume != nullptr) + { + auto* next = awaitersToResume->m_next; + awaitersToResume->m_lastKnownPublished = lastKnownPublished; + awaitersToResume->resume(); + awaitersToResume = next; + } + } +} + +#endif diff --git a/include/cppcoro/sequence_range.hpp b/include/cppcoro/sequence_range.hpp new file mode 100644 index 0000000..fcc8145 --- /dev/null +++ b/include/cppcoro/sequence_range.hpp @@ -0,0 +1,107 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SEQUENCE_RANGE_HPP_INCLUDED +#define CPPCORO_SEQUENCE_RANGE_HPP_INCLUDED + +#include + +#include +#include + +namespace cppcoro +{ + template> + class sequence_range + { + public: + + using value_type = SEQUENCE; + using difference_type = typename TRAITS::difference_type; + using size_type = typename TRAITS::size_type; + + class const_iterator + { + public: + + using iterator_category = std::random_access_iterator_tag; + using value_type = SEQUENCE; + using difference_type = typename TRAITS::difference_type; + using reference = const SEQUENCE&; + using pointer = const SEQUENCE*; + + explicit constexpr const_iterator(SEQUENCE value) noexcept : m_value(value) {} + + const SEQUENCE& operator*() const noexcept { return m_value; } + const SEQUENCE* operator->() const noexcept { return std::addressof(m_value); } + + const_iterator& operator++() noexcept { ++m_value; return *this; } + const_iterator& operator--() noexcept { --m_value; return *this; } + + const_iterator operator++(int) noexcept { return const_iterator(m_value++); } + const_iterator operator--(int) noexcept { return const_iterator(m_value--); } + + constexpr difference_type operator-(const_iterator other) const noexcept { return TRAITS::difference(m_value, other.m_value); } + constexpr const_iterator operator-(difference_type delta) const noexcept { return const_iterator{ static_cast(m_value - delta) }; } + constexpr const_iterator operator+(difference_type delta) const noexcept { return const_iterator{ static_cast(m_value + delta) }; } + + constexpr bool operator==(const_iterator other) const noexcept { return m_value == other.m_value; } + constexpr bool operator!=(const_iterator other) const noexcept { return m_value != other.m_value; } + + private: + + SEQUENCE m_value; + + }; + + constexpr sequence_range() noexcept + : m_begin() + , m_end() + {} + + constexpr sequence_range(SEQUENCE begin, SEQUENCE end) noexcept + : m_begin(begin) + , m_end(end) + {} + + constexpr const_iterator begin() const noexcept { return const_iterator(m_begin); } + constexpr const_iterator end() const noexcept { return const_iterator(m_end); } + + constexpr SEQUENCE front() const noexcept { return m_begin; } + constexpr SEQUENCE back() const noexcept { return m_end - 1; } + + constexpr size_type size() const noexcept + { + return static_cast(TRAITS::difference(m_end, m_begin)); + } + + constexpr bool empty() const noexcept + { + return m_begin == m_end; + } + + constexpr SEQUENCE operator[](size_type index) const noexcept + { + return m_begin + index; + } + + constexpr sequence_range first(size_type count) const noexcept + { + return sequence_range{ m_begin, static_cast(m_begin + std::min(size(), count)) }; + } + + constexpr sequence_range skip(size_type count) const noexcept + { + return sequence_range{ m_begin + std::min(size(), count), m_end }; + } + + private: + + SEQUENCE m_begin; + SEQUENCE m_end; + + }; +} + +#endif diff --git a/include/cppcoro/sequence_traits.hpp b/include/cppcoro/sequence_traits.hpp new file mode 100644 index 0000000..e06744e --- /dev/null +++ b/include/cppcoro/sequence_traits.hpp @@ -0,0 +1,33 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SEQUENCE_TRAITS_HPP_INCLUDED +#define CPPCORO_SEQUENCE_TRAITS_HPP_INCLUDED + +#include + +namespace cppcoro +{ + template + struct sequence_traits + { + using value_type = SEQUENCE; + using difference_type = std::make_signed_t; + using size_type = std::make_unsigned_t; + + static constexpr value_type initial_sequence = static_cast(-1); + + static constexpr difference_type difference(value_type a, value_type b) + { + return static_cast(a - b); + } + + static constexpr bool precedes(value_type a, value_type b) + { + return difference(a, b) < 0; + } + }; +} + +#endif diff --git a/include/cppcoro/shared_task.hpp b/include/cppcoro/shared_task.hpp new file mode 100644 index 0000000..f447829 --- /dev/null +++ b/include/cppcoro/shared_task.hpp @@ -0,0 +1,511 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED +#define CPPCORO_SHARED_LAZY_TASK_HPP_INCLUDED + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + template + class shared_task; + + namespace detail + { + struct shared_task_waiter + { + cppcoro::coroutine_handle<> m_continuation; + shared_task_waiter* m_next; + }; + + class shared_task_promise_base + { + friend struct final_awaiter; + + struct final_awaiter + { + bool await_ready() const noexcept { return false; } + + template + void await_suspend(cppcoro::coroutine_handle h) noexcept + { + shared_task_promise_base& promise = h.promise(); + + // Exchange operation needs to be 'release' so that subsequent awaiters have + // visibility of the result. Also needs to be 'acquire' so we have visibility + // of writes to the waiters list. + void* const valueReadyValue = &promise; + void* waiters = promise.m_waiters.exchange(valueReadyValue, std::memory_order_acq_rel); + if (waiters != nullptr) + { + shared_task_waiter* waiter = static_cast(waiters); + while (waiter->m_next != nullptr) + { + // Read the m_next pointer before resuming the coroutine + // since resuming the coroutine may destroy the shared_task_waiter value. + auto* next = waiter->m_next; + waiter->m_continuation.resume(); + waiter = next; + } + + // Resume last waiter in tail position to allow it to potentially + // be compiled as a tail-call. + waiter->m_continuation.resume(); + } + } + + void await_resume() noexcept {} + }; + + public: + + shared_task_promise_base() noexcept + : m_refCount(1) + , m_waiters(&this->m_waiters) + , m_exception(nullptr) + {} + + cppcoro::suspend_always initial_suspend() noexcept { return {}; } + final_awaiter final_suspend() noexcept { return {}; } + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + bool is_ready() const noexcept + { + const void* const valueReadyValue = this; + return m_waiters.load(std::memory_order_acquire) == valueReadyValue; + } + + void add_ref() noexcept + { + m_refCount.fetch_add(1, std::memory_order_relaxed); + } + + /// Decrement the reference count. + /// + /// \return + /// true if successfully detached, false if this was the last + /// reference to the coroutine, in which case the caller must + /// call destroy() on the coroutine handle. + bool try_detach() noexcept + { + return m_refCount.fetch_sub(1, std::memory_order_acq_rel) != 1; + } + + /// Try to enqueue a waiter to the list of waiters. + /// + /// \param waiter + /// Pointer to the state from the waiter object. + /// Must have waiter->m_coroutine member populated with the coroutine + /// handle of the awaiting coroutine. + /// + /// \param coroutine + /// Coroutine handle for this promise object. + /// + /// \return + /// true if the waiter was successfully queued, in which case + /// waiter->m_coroutine will be resumed when the task completes. + /// false if the coroutine was already completed and the awaiting + /// coroutine can continue without suspending. + bool try_await(shared_task_waiter* waiter, cppcoro::coroutine_handle<> coroutine) + { + void* const valueReadyValue = this; + void* const notStartedValue = &this->m_waiters; + constexpr void* startedNoWaitersValue = static_cast(nullptr); + + // NOTE: If the coroutine is not yet started then the first waiter + // will start the coroutine before enqueuing itself up to the list + // of suspended waiters waiting for completion. We split this into + // two steps to allow the first awaiter to return without suspending. + // This avoids recursively resuming the first waiter inside the call to + // coroutine.resume() in the case that the coroutine completes + // synchronously, which could otherwise lead to stack-overflow if + // the awaiting coroutine awaited many synchronously-completing + // tasks in a row. + + // Start the coroutine if not already started. + void* oldWaiters = m_waiters.load(std::memory_order_acquire); + if (oldWaiters == notStartedValue && + m_waiters.compare_exchange_strong( + oldWaiters, + startedNoWaitersValue, + std::memory_order_relaxed)) + { + // Start the task executing. + coroutine.resume(); + oldWaiters = m_waiters.load(std::memory_order_acquire); + } + + // Enqueue the waiter into the list of waiting coroutines. + do + { + if (oldWaiters == valueReadyValue) + { + // Coroutine already completed, don't suspend. + return false; + } + + waiter->m_next = static_cast(oldWaiters); + } while (!m_waiters.compare_exchange_weak( + oldWaiters, + static_cast(waiter), + std::memory_order_release, + std::memory_order_acquire)); + + return true; + } + + protected: + + bool completed_with_unhandled_exception() + { + return m_exception != nullptr; + } + + void rethrow_if_unhandled_exception() + { + if (m_exception != nullptr) + { + std::rethrow_exception(m_exception); + } + } + + private: + + std::atomic m_refCount; + + // Value is either + // - nullptr - indicates started, no waiters + // - this - indicates value is ready + // - &this->m_waiters - indicates coroutine not started + // - other - pointer to head item in linked-list of waiters. + // values are of type 'cppcoro::shared_task_waiter'. + // indicates that the coroutine has been started. + std::atomic m_waiters; + + std::exception_ptr m_exception; + + }; + + template + class shared_task_promise : public shared_task_promise_base + { + public: + + shared_task_promise() noexcept = default; + + ~shared_task_promise() + { + if (this->is_ready() && !this->completed_with_unhandled_exception()) + { + reinterpret_cast(&m_valueStorage)->~T(); + } + } + + shared_task get_return_object() noexcept; + + template< + typename VALUE, + typename = std::enable_if_t>> + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible_v) + { + new (&m_valueStorage) T(std::forward(value)); + } + + T& result() + { + this->rethrow_if_unhandled_exception(); + return *reinterpret_cast(&m_valueStorage); + } + + private: + + // Not using std::aligned_storage here due to bug in MSVC 2015 Update 2 + // that means it doesn't work for types with alignof(T) > 8. + // See MS-Connect bug #2658635. + alignas(T) char m_valueStorage[sizeof(T)]; + + }; + + template<> + class shared_task_promise : public shared_task_promise_base + { + public: + + shared_task_promise() noexcept = default; + + shared_task get_return_object() noexcept; + + void return_void() noexcept + {} + + void result() + { + this->rethrow_if_unhandled_exception(); + } + + }; + + template + class shared_task_promise : public shared_task_promise_base + { + public: + + shared_task_promise() noexcept = default; + + shared_task get_return_object() noexcept; + + void return_value(T& value) noexcept + { + m_value = std::addressof(value); + } + + T& result() + { + this->rethrow_if_unhandled_exception(); + return *m_value; + } + + private: + + T* m_value; + + }; + } + + template + class [[nodiscard]] shared_task + { + public: + + using promise_type = detail::shared_task_promise; + + using value_type = T; + + private: + + struct awaitable_base + { + cppcoro::coroutine_handle m_coroutine; + detail::shared_task_waiter m_waiter; + + awaitable_base(cppcoro::coroutine_handle coroutine) noexcept + : m_coroutine(coroutine) + {} + + bool await_ready() const noexcept + { + return !m_coroutine || m_coroutine.promise().is_ready(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept + { + m_waiter.m_continuation = awaiter; + return m_coroutine.promise().try_await(&m_waiter, m_coroutine); + } + }; + + public: + + shared_task() noexcept + : m_coroutine(nullptr) + {} + + explicit shared_task(cppcoro::coroutine_handle coroutine) + : m_coroutine(coroutine) + { + // Don't increment the ref-count here since it has already been + // initialised to 2 (one for shared_task and one for coroutine) + // in the shared_task_promise constructor. + } + + shared_task(shared_task&& other) noexcept + : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + shared_task(const shared_task& other) noexcept + : m_coroutine(other.m_coroutine) + { + if (m_coroutine) + { + m_coroutine.promise().add_ref(); + } + } + + ~shared_task() + { + destroy(); + } + + shared_task& operator=(shared_task&& other) noexcept + { + if (&other != this) + { + destroy(); + + m_coroutine = other.m_coroutine; + other.m_coroutine = nullptr; + } + + return *this; + } + + shared_task& operator=(const shared_task& other) noexcept + { + if (m_coroutine != other.m_coroutine) + { + destroy(); + + m_coroutine = other.m_coroutine; + + if (m_coroutine) + { + m_coroutine.promise().add_ref(); + } + } + + return *this; + } + + void swap(shared_task& other) noexcept + { + std::swap(m_coroutine, other.m_coroutine); + } + + /// \brief + /// Query if the task result is complete. + /// + /// Awaiting a task that is ready will not block. + bool is_ready() const noexcept + { + return !m_coroutine || m_coroutine.promise().is_ready(); + } + + auto operator co_await() const noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + decltype(auto) await_resume() + { + if (!this->m_coroutine) + { + throw broken_promise{}; + } + + return this->m_coroutine.promise().result(); + } + }; + + return awaitable{ m_coroutine }; + } + + /// \brief + /// Returns an awaitable that will await completion of the task without + /// attempting to retrieve the result. + auto when_ready() const noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + void await_resume() const noexcept {} + }; + + return awaitable{ m_coroutine }; + } + + private: + + template + friend bool operator==(const shared_task&, const shared_task&) noexcept; + + void destroy() noexcept + { + if (m_coroutine) + { + if (!m_coroutine.promise().try_detach()) + { + m_coroutine.destroy(); + } + } + } + + cppcoro::coroutine_handle m_coroutine; + + }; + + template + bool operator==(const shared_task& lhs, const shared_task& rhs) noexcept + { + return lhs.m_coroutine == rhs.m_coroutine; + } + + template + bool operator!=(const shared_task& lhs, const shared_task& rhs) noexcept + { + return !(lhs == rhs); + } + + template + void swap(shared_task& a, shared_task& b) noexcept + { + a.swap(b); + } + + namespace detail + { + template + shared_task shared_task_promise::get_return_object() noexcept + { + return shared_task{ + cppcoro::coroutine_handle::from_promise(*this) + }; + } + + template + shared_task shared_task_promise::get_return_object() noexcept + { + return shared_task{ + cppcoro::coroutine_handle::from_promise(*this) + }; + } + + inline shared_task shared_task_promise::get_return_object() noexcept + { + return shared_task{ + cppcoro::coroutine_handle::from_promise(*this) + }; + } + } + + template + auto make_shared_task(AWAITABLE awaitable) + -> shared_task::await_result_t>> + { + co_return co_await static_cast(awaitable); + } +} + +#endif diff --git a/include/cppcoro/single_consumer_async_auto_reset_event.hpp b/include/cppcoro/single_consumer_async_auto_reset_event.hpp new file mode 100644 index 0000000..56d8a9e --- /dev/null +++ b/include/cppcoro/single_consumer_async_auto_reset_event.hpp @@ -0,0 +1,101 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SINGLE_CONSUMER_ASYNC_AUTO_RESET_EVENT_HPP_INCLUDED +#define CPPCORO_SINGLE_CONSUMER_ASYNC_AUTO_RESET_EVENT_HPP_INCLUDED + +#include +#include +#include +#include + +namespace cppcoro +{ + class single_consumer_async_auto_reset_event + { + public: + + single_consumer_async_auto_reset_event(bool initiallySet = false) noexcept + : m_state(initiallySet ? this : nullptr) + {} + + void set() noexcept + { + void* oldValue = m_state.exchange(this, std::memory_order_release); + if (oldValue != nullptr && oldValue != this) + { + // There was a waiting coroutine that we now need to resume. + auto handle = *static_cast*>(oldValue); + + // We also need to transition the state back to 'not set' before + // resuming the coroutine. This operation needs to be 'acquire' + // so that it synchronises with other calls to .set() that execute + // concurrently with this call and execute the above m_state.exchange(this) + // operation with 'release' semantics. + // This needs to be an exchange() instead of a store() so that it can have + // 'acquire' semantics. + (void)m_state.exchange(nullptr, std::memory_order_acquire); + + // Finally, resume the waiting coroutine. + handle.resume(); + } + } + + auto operator co_await() const noexcept + { + class awaiter + { + public: + + awaiter(const single_consumer_async_auto_reset_event& event) noexcept + : m_event(event) + {} + + bool await_ready() const noexcept { return false; } + + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + + void* oldValue = nullptr; + if (!m_event.m_state.compare_exchange_strong( + oldValue, + &m_awaitingCoroutine, + std::memory_order_release, + std::memory_order_relaxed)) + { + // This will only fail if the event was already 'set' + // In which case we can just reset back to 'not set' + // Need to use exchange() rather than store() here so we can make this + // operation an 'acquire' operation so that we get visibility of all + // writes prior to all preceding calls to .set(). + assert(oldValue == &m_event); + (void)m_event.m_state.exchange(nullptr, std::memory_order_acquire); + return false; + } + + return true; + } + + void await_resume() noexcept {} + + private: + const single_consumer_async_auto_reset_event& m_event; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + }; + + return awaiter{ *this }; + } + + private: + + // nullptr - not set, no waiter + // this - set + // other - not set, pointer is address of a coroutine_handle<> to resume. + mutable std::atomic m_state; + + }; +} + +#endif diff --git a/include/cppcoro/single_consumer_event.hpp b/include/cppcoro/single_consumer_event.hpp new file mode 100644 index 0000000..983c5ea --- /dev/null +++ b/include/cppcoro/single_consumer_event.hpp @@ -0,0 +1,128 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED +#define CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED + +#include +#include + +namespace cppcoro +{ + /// \brief + /// A manual-reset event that supports only a single awaiting + /// coroutine at a time. + /// + /// You can co_await the event to suspend the current coroutine until + /// some thread calls set(). If the event is already set then the + /// coroutine will not be suspended and will continue execution. + /// If the event was not yet set then the coroutine will be resumed + /// on the thread that calls set() within the call to set(). + /// + /// Callers must ensure that only one coroutine is executing a + /// co_await statement at any point in time. + class single_consumer_event + { + public: + + /// \brief + /// Construct a new event, initialising to either 'set' or 'not set' state. + /// + /// \param initiallySet + /// If true then initialises the event to the 'set' state. + /// Otherwise, initialised the event to the 'not set' state. + single_consumer_event(bool initiallySet = false) noexcept + : m_state(initiallySet ? state::set : state::not_set) + {} + + /// Query if this event has been set. + bool is_set() const noexcept + { + return m_state.load(std::memory_order_acquire) == state::set; + } + + /// \brief + /// Transition this event to the 'set' state if it is not already set. + /// + /// If there was a coroutine awaiting the event then it will be resumed + /// inside this call. + void set() + { + const state oldState = m_state.exchange(state::set, std::memory_order_acq_rel); + if (oldState == state::not_set_consumer_waiting) + { + m_awaiter.resume(); + } + } + + /// \brief + /// Transition this event to the 'non set' state if it was in the set state. + void reset() noexcept + { + state oldState = state::set; + m_state.compare_exchange_strong(oldState, state::not_set, std::memory_order_relaxed); + } + + /// \brief + /// Wait until the event becomes set. + /// + /// If the event is already set then the awaiting coroutine will not be suspended + /// and will continue execution. If the event was not yet set then the coroutine + /// will be suspended and will be later resumed inside a subsequent call to set() + /// on the thread that calls set(). + auto operator co_await() noexcept + { + class awaiter + { + public: + + awaiter(single_consumer_event& event) : m_event(event) {} + + bool await_ready() const noexcept + { + return m_event.is_set(); + } + + bool await_suspend(cppcoro::coroutine_handle<> awaiter) + { + m_event.m_awaiter = awaiter; + + state oldState = state::not_set; + return m_event.m_state.compare_exchange_strong( + oldState, + state::not_set_consumer_waiting, + std::memory_order_release, + std::memory_order_acquire); + } + + void await_resume() noexcept {} + + private: + + single_consumer_event& m_event; + + }; + + return awaiter{ *this }; + } + + private: + + enum class state + { + not_set, + not_set_consumer_waiting, + set + }; + + // TODO: Merge these two fields into a single std::atomic + // by encoding 'not_set' as 0 (nullptr), 'set' as 1 and + // 'not_set_consumer_waiting' as a coroutine handle pointer. + std::atomic m_state; + cppcoro::coroutine_handle<> m_awaiter; + + }; +} + +#endif diff --git a/include/cppcoro/single_producer_sequencer.hpp b/include/cppcoro/single_producer_sequencer.hpp new file mode 100644 index 0000000..77e8d58 --- /dev/null +++ b/include/cppcoro/single_producer_sequencer.hpp @@ -0,0 +1,246 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SINGLE_PRODUCER_SEQUENCER_HPP_INCLUDED +#define CPPCORO_SINGLE_PRODUCER_SEQUENCER_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + template + class single_producer_sequencer_claim_one_operation; + + template + class single_producer_sequencer_claim_operation; + + template< + typename SEQUENCE = std::size_t, + typename TRAITS = sequence_traits> + class single_producer_sequencer + { + public: + + using size_type = typename sequence_range::size_type; + + single_producer_sequencer( + const sequence_barrier& consumerBarrier, + std::size_t bufferSize, + SEQUENCE initialSequence = TRAITS::initial_sequence) noexcept + : m_consumerBarrier(consumerBarrier) + , m_bufferSize(bufferSize) + , m_nextToClaim(initialSequence + 1) + , m_producerBarrier(initialSequence) + {} + + /// Claim a slot in the ring buffer asynchronously. + /// + /// \return + /// Returns an operation that when awaited will suspend the coroutine until + /// a slot is available for writing in the ring buffer. The result of the + /// co_await expression will be the sequence number of the slot. + /// The caller must publish() the claimed sequence number once they have written to + /// the ring-buffer. + template + [[nodiscard]] + single_producer_sequencer_claim_one_operation + claim_one(SCHEDULER& scheduler) noexcept; + + /// Claim one or more contiguous slots in the ring-buffer. + /// + /// Use this method over many calls to claim_one() when you have multiple elements to + /// enqueue. This will claim as many slots as are available up to the specified count + /// but may claim as few as one slot if only one slot is available. + /// + /// \param count + /// The maximum number of slots to claim. + /// + /// \return + /// Returns an awaitable object that when awaited returns a sequence_range that contains + /// the range of sequence numbers that were claimed. Once you have written element values + /// to all of the claimed slots you must publish() the sequence range in order to make + /// the elements available to consumers. + template + [[nodiscard]] + single_producer_sequencer_claim_operation claim_up_to( + std::size_t count, SCHEDULER& scheduler) noexcept; + + /// Publish the specified sequence number. + /// + /// This also implies that all prior sequence numbers have already been published. + void publish(SEQUENCE sequence) noexcept + { + m_producerBarrier.publish(sequence); + } + + /// Publish a contiguous range of sequence numbers. + /// + /// You must have already published all prior sequence numbers. + /// + /// This is equivalent to just publishing the last sequence number in the range. + void publish(const sequence_range& sequences) noexcept + { + m_producerBarrier.publish(sequences.back()); + } + + /// Query what the last-published sequence number is. + /// + /// You can assume that all prior sequence numbers are also published. + SEQUENCE last_published() const noexcept + { + return m_producerBarrier.last_published(); + } + + /// Asynchronously wait until the specified sequence number is published. + /// + /// \param targetSequence + /// The sequence number to wait for. + /// + /// \return + /// Returns an Awaitable type that, when awaited, will suspend the awaiting coroutine until the + /// specified sequence number has been published. + /// + /// The result of the 'co_await barrier.wait_until_published(seq)' expression will be the + /// last-published sequence number, which is guaranteed to be at least 'seq' but may be some + /// subsequent sequence number if additional items were published while waiting for the + /// the requested sequence number to be published. + template + [[nodiscard]] + auto wait_until_published(SEQUENCE targetSequence, SCHEDULER& scheduler) const noexcept + { + return m_producerBarrier.wait_until_published(targetSequence, scheduler); + } + + private: + + template + friend class single_producer_sequencer_claim_operation; + + template + friend class single_producer_sequencer_claim_one_operation; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4324) // C4324: structure was padded due to alignment specifier +#endif + + const sequence_barrier& m_consumerBarrier; + const std::size_t m_bufferSize; + + alignas(CPPCORO_CPU_CACHE_LINE) + SEQUENCE m_nextToClaim; + + sequence_barrier m_producerBarrier; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + }; + + template + class single_producer_sequencer_claim_one_operation + { + public: + + single_producer_sequencer_claim_one_operation( + single_producer_sequencer& sequencer, + SCHEDULER& scheduler) noexcept + : m_consumerWaitOperation( + sequencer.m_consumerBarrier, + static_cast(sequencer.m_nextToClaim - sequencer.m_bufferSize), + scheduler) + , m_sequencer(sequencer) + {} + + bool await_ready() const noexcept + { + return m_consumerWaitOperation.await_ready(); + } + + auto await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_consumerWaitOperation.await_suspend(awaitingCoroutine); + } + + SEQUENCE await_resume() const noexcept + { + return m_sequencer.m_nextToClaim++; + } + + private: + + sequence_barrier_wait_operation m_consumerWaitOperation; + single_producer_sequencer& m_sequencer; + + }; + + template + class single_producer_sequencer_claim_operation + { + public: + + explicit single_producer_sequencer_claim_operation( + single_producer_sequencer& sequencer, + std::size_t count, + SCHEDULER& scheduler) noexcept + : m_consumerWaitOperation( + sequencer.m_consumerBarrier, + static_cast(sequencer.m_nextToClaim - sequencer.m_bufferSize), + scheduler) + , m_sequencer(sequencer) + , m_count(count) + {} + + bool await_ready() const noexcept + { + return m_consumerWaitOperation.await_ready(); + } + + auto await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + return m_consumerWaitOperation.await_suspend(awaitingCoroutine); + } + + sequence_range await_resume() noexcept + { + const SEQUENCE lastAvailableSequence = + static_cast(m_consumerWaitOperation.await_resume() + m_sequencer.m_bufferSize); + const SEQUENCE begin = m_sequencer.m_nextToClaim; + const std::size_t availableCount = static_cast(lastAvailableSequence - begin) + 1; + const std::size_t countToClaim = std::min(m_count, availableCount); + const SEQUENCE end = static_cast(begin + countToClaim); + m_sequencer.m_nextToClaim = end; + return sequence_range(begin, end); + } + + private: + + sequence_barrier_wait_operation m_consumerWaitOperation; + single_producer_sequencer& m_sequencer; + std::size_t m_count; + + }; + + template + template + [[nodiscard]] + single_producer_sequencer_claim_one_operation + single_producer_sequencer::claim_one(SCHEDULER& scheduler) noexcept + { + return single_producer_sequencer_claim_one_operation{ *this, scheduler }; + } + + template + template + [[nodiscard]] + single_producer_sequencer_claim_operation + single_producer_sequencer::claim_up_to(std::size_t count, SCHEDULER& scheduler) noexcept + { + return single_producer_sequencer_claim_operation(*this, count, scheduler); + } +} + +#endif diff --git a/include/cppcoro/static_thread_pool.hpp b/include/cppcoro/static_thread_pool.hpp new file mode 100644 index 0000000..07a2078 --- /dev/null +++ b/include/cppcoro/static_thread_pool.hpp @@ -0,0 +1,116 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_STATIC_THREAD_POOL_HPP_INCLUDED +#define CPPCORO_STATIC_THREAD_POOL_HPP_INCLUDED + +#include +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + class static_thread_pool + { + public: + + /// Initialise to a number of threads equal to the number of cores + /// on the current machine. + static_thread_pool(); + + /// Construct a thread pool with the specified number of threads. + /// + /// \param threadCount + /// The number of threads in the pool that will be used to execute work. + explicit static_thread_pool(std::uint32_t threadCount); + + ~static_thread_pool(); + + class schedule_operation + { + public: + + schedule_operation(static_thread_pool* tp) noexcept : m_threadPool(tp) {} + + bool await_ready() noexcept { return false; } + void await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept; + void await_resume() noexcept {} + + private: + + friend class static_thread_pool; + + static_thread_pool* m_threadPool; + cppcoro::coroutine_handle<> m_awaitingCoroutine; + schedule_operation* m_next; + + }; + + std::uint32_t thread_count() const noexcept { return m_threadCount; } + + [[nodiscard]] + schedule_operation schedule() noexcept { return schedule_operation{ this }; } + + private: + + friend class schedule_operation; + + void run_worker_thread(std::uint32_t threadIndex) noexcept; + + void shutdown(); + + void schedule_impl(schedule_operation* operation) noexcept; + + void remote_enqueue(schedule_operation* operation) noexcept; + + bool has_any_queued_work_for(std::uint32_t threadIndex) noexcept; + + bool approx_has_any_queued_work_for(std::uint32_t threadIndex) const noexcept; + + bool is_shutdown_requested() const noexcept; + + void notify_intent_to_sleep(std::uint32_t threadIndex) noexcept; + void try_clear_intent_to_sleep(std::uint32_t threadIndex) noexcept; + + schedule_operation* try_global_dequeue() noexcept; + + /// Try to steal a task from another thread. + /// + /// \return + /// A pointer to the operation that was stolen if one could be stolen + /// from another thread. Otherwise returns nullptr if none of the other + /// threads had any tasks that could be stolen. + schedule_operation* try_steal_from_other_thread(std::uint32_t thisThreadIndex) noexcept; + + void wake_one_thread() noexcept; + + class thread_state; + + static thread_local thread_state* s_currentState; + static thread_local static_thread_pool* s_currentThreadPool; + + const std::uint32_t m_threadCount; + const std::unique_ptr m_threadStates; + + std::vector m_threads; + + std::atomic m_stopRequested; + + std::mutex m_globalQueueMutex; + std::atomic m_globalQueueHead; + + //alignas(std::hardware_destructive_interference_size) + std::atomic m_globalQueueTail; + + //alignas(std::hardware_destructive_interference_size) + std::atomic m_sleepingThreadCount; + + }; +} + +#endif diff --git a/include/cppcoro/sync_wait.hpp b/include/cppcoro/sync_wait.hpp new file mode 100644 index 0000000..c59bcda --- /dev/null +++ b/include/cppcoro/sync_wait.hpp @@ -0,0 +1,50 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SYNC_WAIT_HPP_INCLUDED +#define CPPCORO_SYNC_WAIT_HPP_INCLUDED + +#include +#include +#include + +#include +#include +#include + +namespace cppcoro +{ + template + auto sync_wait(AWAITABLE&& awaitable) + -> typename cppcoro::awaitable_traits::await_result_t + { + auto task = detail::make_sync_wait_task(std::forward(awaitable)); + detail::lightweight_manual_reset_event event; + task.start(event); + event.wait(); + return task.result(); + } + template + auto sync_wait(AWAITABLE&& awaitable, io_service& srv, std::chrono::system_clock::duration step) + -> typename cppcoro::awaitable_traits::await_result_t + { + auto task = detail::make_sync_wait_task(std::forward(awaitable)); + detail::lightweight_manual_reset_event event; + task.start(event); + event.wait({ &srv, 1 }, step); + return task.result(); + } + template + auto sync_wait(AWAITABLE&& awaitable, std::span srvs, std::chrono::system_clock::duration step) + -> typename cppcoro::awaitable_traits::await_result_t + { + auto task = detail::make_sync_wait_task(std::forward(awaitable)); + detail::lightweight_manual_reset_event event; + task.start(event); + event.wait(srvs, step); + return task.result(); + } +} + +#endif diff --git a/include/cppcoro/task.hpp b/include/cppcoro/task.hpp new file mode 100644 index 0000000..dd261cb --- /dev/null +++ b/include/cppcoro/task.hpp @@ -0,0 +1,481 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_TASK_HPP_INCLUDED +#define CPPCORO_TASK_HPP_INCLUDED + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + template class task; + + namespace detail + { + class task_promise_base + { + friend struct final_awaitable; + + struct final_awaitable + { + bool await_ready() const noexcept { return false; } + +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + template + cppcoro::coroutine_handle<> await_suspend( + cppcoro::coroutine_handle coro) noexcept + { + return coro.promise().m_continuation; + } +#else + // HACK: Need to add CPPCORO_NOINLINE to await_suspend() method + // to avoid MSVC 2017.8 from spilling some local variables in + // await_suspend() onto the coroutine frame in some cases. + // Without this, some tests in async_auto_reset_event_tests.cpp + // were crashing under x86 optimised builds. + template + CPPCORO_NOINLINE + void await_suspend(cppcoro::coroutine_handle coroutine) noexcept + { + task_promise_base& promise = coroutine.promise(); + + // Use 'release' memory semantics in case we finish before the + // awaiter can suspend so that the awaiting thread sees our + // writes to the resulting value. + // Use 'acquire' memory semantics in case the caller registered + // the continuation before we finished. Ensure we see their write + // to m_continuation. + if (promise.m_state.exchange(true, std::memory_order_acq_rel)) + { + promise.m_continuation.resume(); + } + } +#endif + + void await_resume() noexcept {} + }; + + public: + + task_promise_base() noexcept +#if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + : m_state(false) +#endif + {} + + auto initial_suspend() noexcept + { + return cppcoro::suspend_always{}; + } + + auto final_suspend() noexcept + { + return final_awaitable{}; + } + +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + void set_continuation(cppcoro::coroutine_handle<> continuation) noexcept + { + m_continuation = continuation; + } +#else + bool try_set_continuation(cppcoro::coroutine_handle<> continuation) + { + m_continuation = continuation; + return !m_state.exchange(true, std::memory_order_acq_rel); + } +#endif + + private: + + cppcoro::coroutine_handle<> m_continuation; + +#if !CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + // Initially false. Set to true when either a continuation is registered + // or when the coroutine has run to completion. Whichever operation + // successfully transitions from false->true got there first. + std::atomic m_state; +#endif + + }; + + template + class task_promise final : public task_promise_base + { + public: + + task_promise() noexcept {} + + ~task_promise() + { + switch (m_resultType) + { + case result_type::value: + m_value.~T(); + break; + case result_type::exception: + m_exception.~exception_ptr(); + break; + default: + break; + } + } + + task get_return_object() noexcept; + + void unhandled_exception() noexcept + { + ::new (static_cast(std::addressof(m_exception))) std::exception_ptr( + std::current_exception()); + m_resultType = result_type::exception; + } + + template< + typename VALUE, + typename = std::enable_if_t>> + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible_v) + { + ::new (static_cast(std::addressof(m_value))) T(std::forward(value)); + m_resultType = result_type::value; + } + + T& result() & + { + if (m_resultType == result_type::exception) + { + std::rethrow_exception(m_exception); + } + + assert(m_resultType == result_type::value); + + return m_value; + } + + // HACK: Need to have co_await of task return prvalue rather than + // rvalue-reference to work around an issue with MSVC where returning + // rvalue reference of a fundamental type from await_resume() will + // cause the value to be copied to a temporary. This breaks the + // sync_wait() implementation. + // See https://github.com/lewissbaker/cppcoro/issues/40#issuecomment-326864107 + using rvalue_type = std::conditional_t< + std::is_arithmetic_v || std::is_pointer_v, + T, + T&&>; + + rvalue_type result() && + { + if (m_resultType == result_type::exception) + { + std::rethrow_exception(m_exception); + } + + assert(m_resultType == result_type::value); + + return std::move(m_value); + } + + private: + + enum class result_type { empty, value, exception }; + + result_type m_resultType = result_type::empty; + + union + { + T m_value; + std::exception_ptr m_exception; + }; + + }; + + template<> + class task_promise : public task_promise_base + { + public: + + task_promise() noexcept = default; + + task get_return_object() noexcept; + + void return_void() noexcept + {} + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void result() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + private: + + std::exception_ptr m_exception; + + }; + + template + class task_promise : public task_promise_base + { + public: + + task_promise() noexcept = default; + + task get_return_object() noexcept; + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + void return_value(T& value) noexcept + { + m_value = std::addressof(value); + } + + T& result() + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + + return *m_value; + } + + private: + + T* m_value = nullptr; + std::exception_ptr m_exception; + + }; + } + + /// \brief + /// A task represents an operation that produces a result both lazily + /// and asynchronously. + /// + /// When you call a coroutine that returns a task, the coroutine + /// simply captures any passed parameters and returns exeuction to the + /// caller. Execution of the coroutine body does not start until the + /// coroutine is first co_await'ed. + template + class [[nodiscard]] task + { + public: + + using promise_type = detail::task_promise; + + using value_type = T; + + private: + + struct awaitable_base + { + cppcoro::coroutine_handle m_coroutine; + + awaitable_base(cppcoro::coroutine_handle coroutine) noexcept + : m_coroutine(coroutine) + {} + + bool await_ready() const noexcept + { + return !m_coroutine || m_coroutine.done(); + } + +#if CPPCORO_COMPILER_SUPPORTS_SYMMETRIC_TRANSFER + cppcoro::coroutine_handle<> await_suspend( + cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_coroutine.promise().set_continuation(awaitingCoroutine); + return m_coroutine; + } +#else + bool await_suspend(cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + // NOTE: We are using the bool-returning version of await_suspend() here + // to work around a potential stack-overflow issue if a coroutine + // awaits many synchronously-completing tasks in a loop. + // + // We first start the task by calling resume() and then conditionally + // attach the continuation if it has not already completed. This allows us + // to immediately resume the awaiting coroutine without increasing + // the stack depth, avoiding the stack-overflow problem. However, it has + // the down-side of requiring a std::atomic to arbitrate the race between + // the coroutine potentially completing on another thread concurrently + // with registering the continuation on this thread. + // + // We can eliminate the use of the std::atomic once we have access to + // coroutine_handle-returning await_suspend() on both MSVC and Clang + // as this will provide ability to suspend the awaiting coroutine and + // resume another coroutine with a guaranteed tail-call to resume(). + m_coroutine.resume(); + return m_coroutine.promise().try_set_continuation(awaitingCoroutine); + } +#endif + }; + + public: + + task() noexcept + : m_coroutine(nullptr) + {} + + explicit task(cppcoro::coroutine_handle coroutine) + : m_coroutine(coroutine) + {} + + task(task&& t) noexcept + : m_coroutine(t.m_coroutine) + { + t.m_coroutine = nullptr; + } + + /// Disable copy construction/assignment. + task(const task&) = delete; + task& operator=(const task&) = delete; + + /// Frees resources used by this task. + ~task() + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + } + + task& operator=(task&& other) noexcept + { + if (std::addressof(other) != this) + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + + m_coroutine = other.m_coroutine; + other.m_coroutine = nullptr; + } + + return *this; + } + + /// \brief + /// Query if the task result is complete. + /// + /// Awaiting a task that is ready is guaranteed not to block/suspend. + bool is_ready() const noexcept + { + return !m_coroutine || m_coroutine.done(); + } + + auto operator co_await() const & noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + decltype(auto) await_resume() + { + if (!this->m_coroutine) + { + throw broken_promise{}; + } + + return this->m_coroutine.promise().result(); + } + }; + + return awaitable{ m_coroutine }; + } + + auto operator co_await() const && noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + decltype(auto) await_resume() + { + if (!this->m_coroutine) + { + throw broken_promise{}; + } + + return std::move(this->m_coroutine.promise()).result(); + } + }; + + return awaitable{ m_coroutine }; + } + + /// \brief + /// Returns an awaitable that will await completion of the task without + /// attempting to retrieve the result. + auto when_ready() const noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + void await_resume() const noexcept {} + }; + + return awaitable{ m_coroutine }; + } + + private: + + cppcoro::coroutine_handle m_coroutine; + + }; + + namespace detail + { + template + task task_promise::get_return_object() noexcept + { + return task{ cppcoro::coroutine_handle::from_promise(*this) }; + } + + inline task task_promise::get_return_object() noexcept + { + return task{ cppcoro::coroutine_handle::from_promise(*this) }; + } + + template + task task_promise::get_return_object() noexcept + { + return task{ cppcoro::coroutine_handle::from_promise(*this) }; + } + } + + template + auto make_task(AWAITABLE awaitable) + -> task::await_result_t>> + { + co_return co_await static_cast(awaitable); + } +} + +#endif diff --git a/include/cppcoro/when_all.hpp b/include/cppcoro/when_all.hpp new file mode 100644 index 0000000..f0c1be8 --- /dev/null +++ b/include/cppcoro/when_all.hpp @@ -0,0 +1,91 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_WHEN_ALL_HPP_INCLUDED +#define CPPCORO_WHEN_ALL_HPP_INCLUDED + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace cppcoro +{ + ////////// + // Variadic when_all() + + template< + typename... AWAITABLES, + std::enable_if_t< + std::conjunction_v>>...>, + int> = 0> + [[nodiscard]] auto when_all(AWAITABLES&&... awaitables) + { + return fmap([](auto&& taskTuple) + { + return std::apply([](auto&&... tasks) { + return std::make_tuple(static_cast(tasks).non_void_result()...); + }, static_cast(taskTuple)); + }, when_all_ready(std::forward(awaitables)...)); + } + + ////////// + // when_all() with vector of awaitable + + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits>::await_result_t, + std::enable_if_t, int> = 0> + [[nodiscard]] + auto when_all(std::vector awaitables) + { + return fmap([](auto&& taskVector) { + for (auto& task : taskVector) + { + task.result(); + } + }, when_all_ready(std::move(awaitables))); + } + + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits>::await_result_t, + std::enable_if_t, int> = 0> + [[nodiscard]] + auto when_all(std::vector awaitables) + { + using result_t = std::conditional_t< + std::is_lvalue_reference_v, + std::reference_wrapper>, + std::remove_reference_t>; + + return fmap([](auto&& taskVector) { + std::vector results; + results.reserve(taskVector.size()); + for (auto& task : taskVector) + { + if constexpr (std::is_rvalue_reference_v) + { + results.emplace_back(std::move(task).result()); + } + else + { + results.emplace_back(task.result()); + } + } + return results; + }, when_all_ready(std::move(awaitables))); + } +} + +#endif diff --git a/include/cppcoro/when_all_ready.hpp b/include/cppcoro/when_all_ready.hpp new file mode 100644 index 0000000..dccc80f --- /dev/null +++ b/include/cppcoro/when_all_ready.hpp @@ -0,0 +1,56 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_WHEN_ALL_READY_HPP_INCLUDED +#define CPPCORO_WHEN_ALL_READY_HPP_INCLUDED + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace cppcoro +{ + template< + typename... AWAITABLES, + std::enable_if_t>>...>, int> = 0> + [[nodiscard]] + CPPCORO_FORCE_INLINE auto when_all_ready(AWAITABLES&&... awaitables) + { + return detail::when_all_ready_awaitable>>::await_result_t>...>>( + std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); + } + + // TODO: Generalise this from vector to arbitrary sequence of awaitable. + + template< + typename AWAITABLE, + typename RESULT = typename awaitable_traits>::await_result_t> + [[nodiscard]] auto when_all_ready(std::vector awaitables) + { + std::vector> tasks; + + tasks.reserve(awaitables.size()); + + for (auto& awaitable : awaitables) + { + tasks.emplace_back(detail::make_when_all_task(std::move(awaitable))); + } + + return detail::when_all_ready_awaitable>>( + std::move(tasks)); + } +} + +#endif diff --git a/include/cppcoro/writable_file.hpp b/include/cppcoro/writable_file.hpp new file mode 100644 index 0000000..40ef7fe --- /dev/null +++ b/include/cppcoro/writable_file.hpp @@ -0,0 +1,71 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_WRITABLE_FILE_HPP_INCLUDED +#define CPPCORO_WRITABLE_FILE_HPP_INCLUDED + +#include +#include +#include + +namespace cppcoro +{ + class writable_file : virtual public file + { + public: + + /// Set the size of the file. + /// + /// \param fileSize + /// The new size of the file in bytes. + void set_size(std::uint64_t fileSize); + + /// Write some data to the file. + /// + /// Writes \a byteCount bytes from the file starting at \a offset + /// into the specified \a buffer. + /// + /// \param offset + /// The offset within the file to start writing from. + /// If the file has been opened using file_buffering_mode::unbuffered + /// then the offset must be a multiple of the file-system's sector size. + /// + /// \param buffer + /// The buffer containing the data to be written to the file. + /// If the file has been opened using file_buffering_mode::unbuffered + /// then the address of the start of the buffer must be a multiple of + /// the file-system's sector size. + /// + /// \param byteCount + /// The number of bytes to write to the file. + /// If the file has been opeend using file_buffering_mode::unbuffered + /// then the byteCount must be a multiple of the file-system's sector size. + /// + /// \param ct + /// An optional cancellation_token that can be used to cancel the + /// write operation before it completes. + /// + /// \return + /// An object that represents the write operation. + /// This object must be co_await'ed to start the write operation. + [[nodiscard]] + file_write_operation write( + std::uint64_t offset, + const void* buffer, + std::size_t byteCount) noexcept; + [[nodiscard]] + file_write_operation_cancellable write( + std::uint64_t offset, + const void* buffer, + std::size_t byteCount, + cancellation_token ct) noexcept; + + protected: + + using file::file; + + }; +} + +#endif diff --git a/include/cppcoro/write_only_file.hpp b/include/cppcoro/write_only_file.hpp new file mode 100644 index 0000000..1af6657 --- /dev/null +++ b/include/cppcoro/write_only_file.hpp @@ -0,0 +1,65 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_WRITE_ONLY_FILE_HPP_INCLUDED +#define CPPCORO_WRITE_ONLY_FILE_HPP_INCLUDED + +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + class write_only_file : public writable_file + { + public: + + /// Open a file for write-only access. + /// + /// \param ioContext + /// The I/O context to use when dispatching I/O completion events. + /// When asynchronous write operations on this file complete the + /// completion events will be dispatched to an I/O thread associated + /// with the I/O context. + /// + /// \param pathMode + /// Path of the file to open. + /// + /// \param openMode + /// Specifies how the file should be opened and how to handle cases + /// when the file exists or doesn't exist. + /// + /// \param shareMode + /// Specifies the access to be allowed on the file concurrently with this file access. + /// + /// \param bufferingMode + /// Specifies the modes/hints to provide to the OS that affects the behaviour + /// of its file buffering. + /// + /// \return + /// An object that can be used to write to the file. + /// + /// \throw std::system_error + /// If the file could not be opened for write. + [[nodiscard]] + static write_only_file open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode = file_open_mode::create_or_open, + file_share_mode shareMode = file_share_mode::none, + file_buffering_mode bufferingMode = file_buffering_mode::default_); + + protected: + +#if CPPCORO_OS_WINNT + write_only_file(detail::win32::safe_handle&& fileHandle) noexcept; +#endif + + }; +} + +#endif diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt new file mode 100644 index 0000000..3ea6a31 --- /dev/null +++ b/lib/CMakeLists.txt @@ -0,0 +1,178 @@ +set(includes + awaitable_traits.hpp + is_awaitable.hpp + async_auto_reset_event.hpp + async_manual_reset_event.hpp + async_generator.hpp + async_mutex.hpp + async_latch.hpp + async_scope.hpp + broken_promise.hpp + cancellation_registration.hpp + cancellation_source.hpp + cancellation_token.hpp + task.hpp + sequence_barrier.hpp + sequence_traits.hpp + single_producer_sequencer.hpp + multi_producer_sequencer.hpp + shared_task.hpp + shared_task.hpp + single_consumer_event.hpp + single_consumer_async_auto_reset_event.hpp + sync_wait.hpp + task.hpp + io_service.hpp + config.hpp + on_scope_exit.hpp + file_share_mode.hpp + file_open_mode.hpp + file_buffering_mode.hpp + file.hpp + fmap.hpp + when_all.hpp + when_all_ready.hpp + resume_on.hpp + schedule_on.hpp + generator.hpp + readable_file.hpp + recursive_generator.hpp + writable_file.hpp + read_only_file.hpp + write_only_file.hpp + read_write_file.hpp + file_read_operation.hpp + file_write_operation.hpp + static_thread_pool.hpp +) +list(TRANSFORM includes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/") + +set(netIncludes + ip_address.hpp + ip_endpoint.hpp + ipv4_address.hpp + ipv4_endpoint.hpp + ipv6_address.hpp + ipv6_endpoint.hpp + socket.hpp +) +list(TRANSFORM netIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/net/") + +set(detailIncludes + void_value.hpp + when_all_ready_awaitable.hpp + when_all_counter.hpp + when_all_task.hpp + get_awaiter.hpp + is_awaiter.hpp + any.hpp + sync_wait_task.hpp + unwrap_reference.hpp + lightweight_manual_reset_event.hpp +) +list(TRANSFORM detailIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/detail/") + +set(privateHeaders + cancellation_state.hpp + socket_helpers.hpp + auto_reset_event.hpp + spin_wait.hpp + spin_mutex.hpp +) + +set(sources + async_auto_reset_event.cpp + async_manual_reset_event.cpp + async_mutex.cpp + cancellation_state.cpp + cancellation_token.cpp + cancellation_source.cpp + cancellation_registration.cpp + lightweight_manual_reset_event.cpp + ip_address.cpp + ip_endpoint.cpp + ipv4_address.cpp + ipv4_endpoint.cpp + ipv6_address.cpp + ipv6_endpoint.cpp + static_thread_pool.cpp + auto_reset_event.cpp + spin_wait.cpp + spin_mutex.cpp +) + +if(WIN32) + set(win32DetailIncludes + win32.hpp + win32_overlapped_operation.hpp + ) + list(TRANSFORM win32DetailIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/detail/") + list(APPEND detailIncludes ${win32DetailIncludes}) + + set(win32NetIncludes + socket.hpp + socket_accept_operation.hpp + socket_connect_operation.hpp + socket_disconnect_operation.hpp + socket_recv_operation.hpp + socket_recv_from_operation.hpp + socket_send_operation.hpp + socket_send_to_operation.hpp + ) + list(TRANSFORM win32NetIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/net/") + list(APPEND netIncludes ${win32NetIncludes}) + + set(win32Sources + win32.cpp + io_service.cpp + file.cpp + readable_file.cpp + writable_file.cpp + read_only_file.cpp + write_only_file.cpp + read_write_file.cpp + file_read_operation.cpp + file_write_operation.cpp + socket_helpers.cpp + socket.cpp + socket_accept_operation.cpp + socket_connect_operation.cpp + socket_disconnect_operation.cpp + socket_send_operation.cpp + socket_send_to_operation.cpp + socket_recv_operation.cpp + socket_recv_from_operation.cpp + ) + list(APPEND sources ${win32Sources}) + + list(APPEND libraries Ws2_32 Mswsock Synchronization) + list(APPEND compile_options /EHsc) + + if("${MSVC_VERSION}" VERSION_GREATER_EQUAL 1900) + # TODO remove this when experimental/non-experimental include are fixed + list(APPEND compile_definition _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING=1) + endif() +endif() + +add_library(cppcoro + ${includes} + ${netIncludes} + ${detailIncludes} + ${privateHeaders} + ${sources} +) + +target_include_directories(cppcoro PUBLIC + $ + $) + +target_compile_definitions(cppcoro PUBLIC ${compile_definition}) +target_compile_options(cppcoro PUBLIC ${compile_options}) + +#find_package(Coroutines COMPONENTS Experimental Final REQUIRED) +#target_link_libraries(cppcoro PUBLIC std::coroutines ${libraries}) + +install(TARGETS cppcoro EXPORT cppcoroTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) diff --git a/lib/async_auto_reset_event.cpp b/lib/async_auto_reset_event.cpp new file mode 100644 index 0000000..fa0bb73 --- /dev/null +++ b/lib/async_auto_reset_event.cpp @@ -0,0 +1,285 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +#include +#include + +namespace +{ + namespace local + { + // Some helpers for manipulating the 'm_state' value. + + constexpr std::uint64_t set_increment = 1; + constexpr std::uint64_t waiter_increment = std::uint64_t(1) << 32; + + constexpr std::uint32_t get_set_count(std::uint64_t state) + { + return static_cast(state); + } + + constexpr std::uint32_t get_waiter_count(std::uint64_t state) + { + return static_cast(state >> 32); + } + + constexpr std::uint32_t get_resumable_waiter_count(std::uint64_t state) + { + return std::min(get_set_count(state), get_waiter_count(state)); + } + } +} + +cppcoro::async_auto_reset_event::async_auto_reset_event(bool initiallySet) noexcept + : m_state(initiallySet ? local::set_increment : 0) + , m_newWaiters(nullptr) + , m_waiters(nullptr) +{ +} + +cppcoro::async_auto_reset_event::~async_auto_reset_event() +{ + assert(m_newWaiters.load(std::memory_order_relaxed) == nullptr); + assert(m_waiters == nullptr); +} + +cppcoro::async_auto_reset_event_operation +cppcoro::async_auto_reset_event::operator co_await() const noexcept +{ + std::uint64_t oldState = m_state.load(std::memory_order_relaxed); + if (local::get_set_count(oldState) > local::get_waiter_count(oldState)) + { + // Try to synchronously acquire the event. + if (m_state.compare_exchange_strong( + oldState, + oldState - local::set_increment, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + // Acquired the event, return an operation object that + // won't suspend. + return async_auto_reset_event_operation{}; + } + } + + return async_auto_reset_event_operation{ *this }; +} + +void cppcoro::async_auto_reset_event::set() noexcept +{ + std::uint64_t oldState = m_state.load(std::memory_order_relaxed); + do + { + if (local::get_set_count(oldState) > local::get_waiter_count(oldState)) + { + // Already set. + return; + } + + // Increment the set-count + } while (!m_state.compare_exchange_weak( + oldState, + oldState + local::set_increment, + std::memory_order_acq_rel, + std::memory_order_acquire)); + + // Did we transition from non-zero waiters and zero set-count + // to non-zero set-count? + // If so then we acquired the lock and are responsible for resuming waiters. + if (oldState != 0 && local::get_set_count(oldState) == 0) + { + // We acquired the lock. + resume_waiters(oldState + local::set_increment); + } +} + +void cppcoro::async_auto_reset_event::reset() noexcept +{ + std::uint64_t oldState = m_state.load(std::memory_order_relaxed); + while (local::get_set_count(oldState) > local::get_waiter_count(oldState)) + { + if (m_state.compare_exchange_weak( + oldState, + oldState - local::set_increment, + std::memory_order_relaxed)) + { + // Successfully reset. + return; + } + } + + // Not set. Nothing to do. +} + +void cppcoro::async_auto_reset_event::resume_waiters( + std::uint64_t initialState) const noexcept +{ + async_auto_reset_event_operation* waitersToResumeList = nullptr; + async_auto_reset_event_operation** waitersToResumeListEnd = &waitersToResumeList; + + std::uint32_t waiterCountToResume = local::get_resumable_waiter_count(initialState); + + assert(waiterCountToResume > 0); + + do + { + // Dequeue 'waiterCountToResume' from m_waiters/m_newWaiters and + // push them onto 'waitersToResumeList'. + for (std::uint32_t i = 0; i < waiterCountToResume; ++i) + { + if (m_waiters == nullptr) + { + // We've run out of of waiters that we can consume without synchronisation + // Dequeue the list of new waiters atomically. + auto* newWaiters = m_newWaiters.exchange(nullptr, std::memory_order_acquire); + + // There should always be enough waiters in the list as + // the waiters are queued before the waiter-count is incremented. + assert(newWaiters != nullptr); + CPPCORO_ASSUME(newWaiters != nullptr); + + // Reverse order of new waiters so they are resumed in FIFO. + // This ensures fairness. + // + // The alternative would be to not reverse the list and instead + // resume waiters in the reverse order they were queued in. + // This might result in better cache locality (most recently + // suspended coroutine might still be in cache). + // It should still provide a bounded wait time as well since we + // are guaranteed to process all waiters in this list before + // looking at any waiters newly queued after this point. + // Something to consider. + do + { + auto* next = newWaiters->m_next; + newWaiters->m_next = m_waiters; + m_waiters = newWaiters; + newWaiters = next; + } while (newWaiters != nullptr); + } + + assert(m_waiters != nullptr); + + // Pop the next waiter off the list + auto* waiterToResume = m_waiters; + m_waiters = m_waiters->m_next; + + // Push it onto the end of the list of waiters to resume + waiterToResume->m_next = nullptr; + *waitersToResumeListEnd = waiterToResume; + waitersToResumeListEnd = &waiterToResume->m_next; + } + + // We've now removed 'waiterCountToResume' waiters from the list + // so we can now decrement both the waiter and set count. + // + // However, there might have been more waiters or more calls to + // set() since we last checked so we need to go around again if + // there are still waiters that are ready to resume after decrementing + // both the 'waiter count' and 'set count' by 'waiterCountToResume'. + const std::uint64_t delta = + std::uint64_t(waiterCountToResume) | + std::uint64_t(waiterCountToResume) << 32; + + // Needs to be 'release' as we're releasing the lock and anyone that + // subsequently acquires the lock needs to see our prior writes to + // m_waiters. + // Needs to be 'acquire' in the case that new waiters were added so + // that we see their prior writes to 'm_newWaiters'. + const std::uint64_t newState = + m_state.fetch_sub(delta, std::memory_order_acq_rel) - delta; + + waiterCountToResume = local::get_resumable_waiter_count(newState); + } while (waiterCountToResume > 0); + + // Now resume all of the waiters we've dequeued. + // There should be at least one. + assert(waitersToResumeList != nullptr); + CPPCORO_ASSUME(waitersToResumeList != nullptr); + + do + { + auto* const waiter = waitersToResumeList; + + // Read 'next' before resuming since resuming the waiter is + // likely to destroy the waiter object. + auto* const next = waitersToResumeList->m_next; + + // Decrement reference count and see if we decremented the last + // reference and if so then we are responsible for resuming. + // If not, then await_suspend() is responsible for resuming by + // returning 'false' and not suspending. + if (waiter->m_refCount.fetch_sub(1, std::memory_order_release) == 1) + { + waiter->m_awaiter.resume(); + } + + waitersToResumeList = next; + } while (waitersToResumeList != nullptr); +} + +cppcoro::async_auto_reset_event_operation::async_auto_reset_event_operation() noexcept + : m_event(nullptr) +{} + +cppcoro::async_auto_reset_event_operation::async_auto_reset_event_operation( + const async_auto_reset_event& event) noexcept + : m_event(&event) + , m_refCount(2) +{} + +cppcoro::async_auto_reset_event_operation::async_auto_reset_event_operation( + const async_auto_reset_event_operation& other) noexcept + : m_event(other.m_event) + , m_refCount(2) +{} + +bool cppcoro::async_auto_reset_event_operation::await_suspend( + cppcoro::coroutine_handle<> awaiter) noexcept +{ + m_awaiter = awaiter; + + // Queue the waiter to the m_newWaiters list. + async_auto_reset_event_operation* head = m_event->m_newWaiters.load(std::memory_order_relaxed); + do + { + m_next = head; + } while (!m_event->m_newWaiters.compare_exchange_weak( + head, + this, + std::memory_order_release, + std::memory_order_relaxed)); + + // Increment the waiter count. + // Needs to be 'release' so that our prior write to m_newWaiters is + // visible to anyone that acquires the lock. + // Needs to be 'acquire' in case we acquired the lock so we can see + // others' writes to m_newWaiters and writes prior to set() calls. + const std::uint64_t oldState = + m_event->m_state.fetch_add(local::waiter_increment, std::memory_order_acq_rel); + + if (oldState != 0 && local::get_waiter_count(oldState) == 0) + { + // We transitioned from non-zero set and zero waiters to + // non-zero set and non-zero waiters, so we acquired the lock + // and thus responsibility for resuming waiters. + m_event->resume_waiters(oldState + local::waiter_increment); + } + + // Decrement the ref-count to indicate that this waiter is now safe + // to resume. We don't want it to resume while we're still accessing the + // m_event object as resuming it might cause the event object to be + // destructed. + // + // Need 'acquire' semantics here in the case that another thread has + // concurrently dequeued us and scheduled us for resumption by decrementing + // the ref-count with 'release' semantics so that we see the writes prior + // to the 'set()' call that released this waiter. + return m_refCount.fetch_sub(1, std::memory_order_acquire) != 1; +} diff --git a/lib/async_manual_reset_event.cpp b/lib/async_manual_reset_event.cpp new file mode 100644 index 0000000..f663b00 --- /dev/null +++ b/lib/async_manual_reset_event.cpp @@ -0,0 +1,99 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +#include + +cppcoro::async_manual_reset_event::async_manual_reset_event(bool initiallySet) noexcept + : m_state(initiallySet ? static_cast(this) : nullptr) +{} + +cppcoro::async_manual_reset_event::~async_manual_reset_event() +{ + // There should be no coroutines still awaiting the event. + assert( + m_state.load(std::memory_order_relaxed) == nullptr || + m_state.load(std::memory_order_relaxed) == static_cast(this)); +} + +bool cppcoro::async_manual_reset_event::is_set() const noexcept +{ + return m_state.load(std::memory_order_acquire) == static_cast(this); +} + +cppcoro::async_manual_reset_event_operation +cppcoro::async_manual_reset_event::operator co_await() const noexcept +{ + return async_manual_reset_event_operation{ *this }; +} + +void cppcoro::async_manual_reset_event::set() noexcept +{ + void* const setState = static_cast(this); + + // Needs 'release' semantics so that prior writes are visible to event awaiters + // that synchronise either via 'is_set()' or 'operator co_await()'. + // Needs 'acquire' semantics in case there are any waiters so that we see + // prior writes to the waiting coroutine's state and to the contents of + // the queued async_manual_reset_event_operation objects. + void* oldState = m_state.exchange(setState, std::memory_order_acq_rel); + if (oldState != setState) + { + auto* current = static_cast(oldState); + while (current != nullptr) + { + auto* next = current->m_next; + current->m_awaiter.resume(); + current = next; + } + } +} + +void cppcoro::async_manual_reset_event::reset() noexcept +{ + void* oldState = static_cast(this); + m_state.compare_exchange_strong(oldState, nullptr, std::memory_order_relaxed); +} + +cppcoro::async_manual_reset_event_operation::async_manual_reset_event_operation( + const async_manual_reset_event& event) noexcept + : m_event(event) +{ +} + +bool cppcoro::async_manual_reset_event_operation::await_ready() const noexcept +{ + return m_event.is_set(); +} + +bool cppcoro::async_manual_reset_event_operation::await_suspend( + cppcoro::coroutine_handle<> awaiter) noexcept +{ + m_awaiter = awaiter; + + const void* const setState = static_cast(&m_event); + + void* oldState = m_event.m_state.load(std::memory_order_acquire); + do + { + if (oldState == setState) + { + // State is now 'set' no need to suspend. + return false; + } + + m_next = static_cast(oldState); + } while (!m_event.m_state.compare_exchange_weak( + oldState, + static_cast(this), + std::memory_order_release, + std::memory_order_acquire)); + + // Successfully queued this waiter to the list. + return true; +} diff --git a/lib/async_mutex.cpp b/lib/async_mutex.cpp new file mode 100644 index 0000000..f713c93 --- /dev/null +++ b/lib/async_mutex.cpp @@ -0,0 +1,122 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +cppcoro::async_mutex::async_mutex() noexcept + : m_state(not_locked) + , m_waiters(nullptr) +{} + +cppcoro::async_mutex::~async_mutex() +{ + [[maybe_unused]] auto state = m_state.load(std::memory_order_relaxed); + assert(state == not_locked || state == locked_no_waiters); + assert(m_waiters == nullptr); +} + +bool cppcoro::async_mutex::try_lock() noexcept +{ + // Try to atomically transition from nullptr (not-locked) -> this (locked-no-waiters). + auto oldState = not_locked; + return m_state.compare_exchange_strong( + oldState, + locked_no_waiters, + std::memory_order_acquire, + std::memory_order_relaxed); +} + +cppcoro::async_mutex_lock_operation cppcoro::async_mutex::lock_async() noexcept +{ + return async_mutex_lock_operation{ *this }; +} + +cppcoro::async_mutex_scoped_lock_operation cppcoro::async_mutex::scoped_lock_async() noexcept +{ + return async_mutex_scoped_lock_operation{ *this }; +} + +void cppcoro::async_mutex::unlock() +{ + assert(m_state.load(std::memory_order_relaxed) != not_locked); + + async_mutex_lock_operation* waitersHead = m_waiters; + if (waitersHead == nullptr) + { + auto oldState = locked_no_waiters; + const bool releasedLock = m_state.compare_exchange_strong( + oldState, + not_locked, + std::memory_order_release, + std::memory_order_relaxed); + if (releasedLock) + { + return; + } + + // At least one new waiter. + // Acquire the list of new waiter operations atomically. + oldState = m_state.exchange(locked_no_waiters, std::memory_order_acquire); + + assert(oldState != locked_no_waiters && oldState != not_locked); + + // Transfer the list to m_waiters, reversing the list in the process so + // that the head of the list is the first to be resumed. + auto* next = reinterpret_cast(oldState); + do + { + auto* temp = next->m_next; + next->m_next = waitersHead; + waitersHead = next; + next = temp; + } while (next != nullptr); + } + + assert(waitersHead != nullptr); + + m_waiters = waitersHead->m_next; + + // Resume the waiter. + // This will pass the ownership of the lock on to that operation/coroutine. + waitersHead->m_awaiter.resume(); +} + +bool cppcoro::async_mutex_lock_operation::await_suspend(cppcoro::coroutine_handle<> awaiter) noexcept +{ + m_awaiter = awaiter; + + std::uintptr_t oldState = m_mutex.m_state.load(std::memory_order_acquire); + while (true) + { + if (oldState == async_mutex::not_locked) + { + if (m_mutex.m_state.compare_exchange_weak( + oldState, + async_mutex::locked_no_waiters, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + // Acquired lock, don't suspend. + return false; + } + } + else + { + // Try to push this operation onto the head of the waiter stack. + m_next = reinterpret_cast(oldState); + if (m_mutex.m_state.compare_exchange_weak( + oldState, + reinterpret_cast(this), + std::memory_order_release, + std::memory_order_relaxed)) + { + // Queued operation to waiters list, suspend now. + return true; + } + } + } +} diff --git a/lib/auto_reset_event.cpp b/lib/auto_reset_event.cpp new file mode 100644 index 0000000..25e0f3e --- /dev/null +++ b/lib/auto_reset_event.cpp @@ -0,0 +1,97 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "auto_reset_event.hpp" + +#if CPPCORO_OS_WINNT +# define WIN32_LEAN_AND_MEAN +# include +# include +#endif + +namespace cppcoro +{ +#if CPPCORO_OS_WINNT + + auto_reset_event::auto_reset_event(bool initiallySet) + : m_event(::CreateEventW(NULL, FALSE, initiallySet ? TRUE : FALSE, NULL)) + { + if (m_event.handle() == NULL) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "auto_reset_event: CreateEvent failed" + }; + } + } + + auto_reset_event::~auto_reset_event() + { + } + + void auto_reset_event::set() + { + BOOL ok =::SetEvent(m_event.handle()); + if (!ok) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "auto_reset_event: SetEvent failed" + }; + } + } + + void auto_reset_event::wait() + { + DWORD result = ::WaitForSingleObjectEx(m_event.handle(), INFINITE, FALSE); + if (result != WAIT_OBJECT_0) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "auto_reset_event: WaitForSingleObjectEx failed" + }; + } + } + +#else + + auto_reset_event::auto_reset_event(bool initiallySet) + : m_isSet(initiallySet) + {} + + auto_reset_event::~auto_reset_event() + {} + + void auto_reset_event::set() + { + std::unique_lock lock{ m_mutex }; + if (!m_isSet) + { + m_isSet = true; + m_cv.notify_one(); + } + } + + void auto_reset_event::wait() + { + std::unique_lock lock{ m_mutex }; + while (!m_isSet) + { + m_cv.wait(lock); + } + m_isSet = false; + } + +#endif +} diff --git a/lib/auto_reset_event.hpp b/lib/auto_reset_event.hpp new file mode 100644 index 0000000..480f2d2 --- /dev/null +++ b/lib/auto_reset_event.hpp @@ -0,0 +1,44 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_AUTO_RESET_EVENT_HPP_INCLUDED +#define CPPCORO_AUTO_RESET_EVENT_HPP_INCLUDED + +#include + +#if CPPCORO_OS_WINNT +# include +#else +# include +# include +#endif + +namespace cppcoro +{ + class auto_reset_event + { + public: + + auto_reset_event(bool initiallySet = false); + + ~auto_reset_event(); + + void set(); + + void wait(); + + private: + +#if CPPCORO_OS_WINNT + cppcoro::detail::win32::safe_handle m_event; +#else + std::mutex m_mutex; + std::condition_variable m_cv; + bool m_isSet; +#endif + + }; +} + +#endif diff --git a/lib/build.cake b/lib/build.cake new file mode 100644 index 0000000..a32e1c4 --- /dev/null +++ b/lib/build.cake @@ -0,0 +1,182 @@ +############################################################################### +# Copyright Lewis Baker +# Licenced under MIT license. See LICENSE.txt for details. +############################################################################### + +import cake.path + +from cake.tools import compiler, script, env, project, variant + +includes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', [ + 'awaitable_traits.hpp', + 'is_awaitable.hpp', + 'async_auto_reset_event.hpp', + 'async_manual_reset_event.hpp', + 'async_generator.hpp', + 'async_mutex.hpp', + 'async_latch.hpp', + 'async_scope.hpp', + 'broken_promise.hpp', + 'cancellation_registration.hpp', + 'cancellation_source.hpp', + 'cancellation_token.hpp', + 'task.hpp', + 'sequence_barrier.hpp', + 'sequence_traits.hpp', + 'single_producer_sequencer.hpp', + 'multi_producer_sequencer.hpp', + 'shared_task.hpp', + 'single_consumer_event.hpp', + 'single_consumer_async_auto_reset_event.hpp', + 'sync_wait.hpp', + 'task.hpp', + 'io_service.hpp', + 'config.hpp', + 'on_scope_exit.hpp', + 'file_share_mode.hpp', + 'file_open_mode.hpp', + 'file_buffering_mode.hpp', + 'file.hpp', + 'fmap.hpp', + 'when_all.hpp', + 'when_all_ready.hpp', + 'resume_on.hpp', + 'schedule_on.hpp', + 'generator.hpp', + 'readable_file.hpp', + 'recursive_generator.hpp', + 'writable_file.hpp', + 'read_only_file.hpp', + 'write_only_file.hpp', + 'read_write_file.hpp', + 'file_read_operation.hpp', + 'file_write_operation.hpp', + 'static_thread_pool.hpp', + ]) + +netIncludes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'net', [ + 'ip_address.hpp', + 'ip_endpoint.hpp', + 'ipv4_address.hpp', + 'ipv4_endpoint.hpp', + 'ipv6_address.hpp', + 'ipv6_endpoint.hpp', + 'socket.hpp', +]) + +detailIncludes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'detail', [ + 'void_value.hpp', + 'when_all_ready_awaitable.hpp', + 'when_all_counter.hpp', + 'when_all_task.hpp', + 'get_awaiter.hpp', + 'is_awaiter.hpp', + 'any.hpp', + 'sync_wait_task.hpp', + 'unwrap_reference.hpp', + 'lightweight_manual_reset_event.hpp', + ]) + +privateHeaders = script.cwd([ + 'cancellation_state.hpp', + 'socket_helpers.hpp', + 'auto_reset_event.hpp', + 'spin_wait.hpp', + 'spin_mutex.hpp', + ]) + +sources = script.cwd([ + 'async_auto_reset_event.cpp', + 'async_manual_reset_event.cpp', + 'async_mutex.cpp', + 'cancellation_state.cpp', + 'cancellation_token.cpp', + 'cancellation_source.cpp', + 'cancellation_registration.cpp', + 'lightweight_manual_reset_event.cpp', + 'ip_address.cpp', + 'ip_endpoint.cpp', + 'ipv4_address.cpp', + 'ipv4_endpoint.cpp', + 'ipv6_address.cpp', + 'ipv6_endpoint.cpp', + 'static_thread_pool.cpp', + 'auto_reset_event.cpp', + 'spin_wait.cpp', + 'spin_mutex.cpp', + ]) + +extras = script.cwd([ + 'build.cake', + 'use.cake', + ]) + +if variant.platform == "windows": + detailIncludes.extend(cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'detail', [ + 'win32.hpp', + 'win32_overlapped_operation.hpp', + ])) + netIncludes.extend(cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'net', [ + 'socket.hpp', + 'socket_accept_operation.hpp', + 'socket_connect_operation.hpp', + 'socket_disconnect_operation.hpp', + 'socket_recv_operation.hpp', + 'socket_recv_from_operation.hpp', + 'socket_send_operation.hpp', + 'socket_send_to_operation.hpp', + ])) + sources.extend(script.cwd([ + 'win32.cpp', + 'io_service.cpp', + 'file.cpp', + 'readable_file.cpp', + 'writable_file.cpp', + 'read_only_file.cpp', + 'write_only_file.cpp', + 'read_write_file.cpp', + 'file_read_operation.cpp', + 'file_write_operation.cpp', + 'socket_helpers.cpp', + 'socket.cpp', + 'socket_accept_operation.cpp', + 'socket_connect_operation.cpp', + 'socket_disconnect_operation.cpp', + 'socket_send_operation.cpp', + 'socket_send_to_operation.cpp', + 'socket_recv_operation.cpp', + 'socket_recv_from_operation.cpp', + ])) + +buildDir = env.expand('${CPPCORO_BUILD}') + +compiler.addIncludePath(env.expand('${CPPCORO}/include')) + +objects = compiler.objects( + targetDir=env.expand('${CPPCORO_BUILD}/obj'), + sources=sources, + ) + +lib = compiler.library( + target=env.expand('${CPPCORO_LIB}/cppcoro'), + sources=objects, + ) + +vcproj = project.project( + target=env.expand('${CPPCORO_PROJECT}/cppcoro'), + items={ + 'Include': { + 'Detail': detailIncludes, + 'Net': netIncludes, + '': includes, + }, + 'Source': sources + privateHeaders, + '': extras + }, + output=lib, + ) + +script.setResult( + project=vcproj, + library=lib, + ) diff --git a/lib/cancellation_registration.cpp b/lib/cancellation_registration.cpp new file mode 100644 index 0000000..d9533bb --- /dev/null +++ b/lib/cancellation_registration.cpp @@ -0,0 +1,41 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "cancellation_state.hpp" + +#include + +cppcoro::cancellation_registration::~cancellation_registration() +{ + if (m_state != nullptr) + { + m_state->deregister_callback(this); + m_state->release_token_ref(); + } +} + +void cppcoro::cancellation_registration::register_callback(cancellation_token&& token) +{ + auto* state = token.m_state; + if (state != nullptr && state->can_be_cancelled()) + { + m_state = state; + if (state->try_register_callback(this)) + { + token.m_state = nullptr; + } + else + { + m_state = nullptr; + m_callback(); + } + } + else + { + m_state = nullptr; + } +} diff --git a/lib/cancellation_source.cpp b/lib/cancellation_source.cpp new file mode 100644 index 0000000..240cf1d --- /dev/null +++ b/lib/cancellation_source.cpp @@ -0,0 +1,97 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "cancellation_state.hpp" + +#include + +cppcoro::cancellation_source::cancellation_source() + : m_state(detail::cancellation_state::create()) +{ +} + +cppcoro::cancellation_source::cancellation_source(const cancellation_source& other) noexcept + : m_state(other.m_state) +{ + if (m_state != nullptr) + { + m_state->add_source_ref(); + } +} + +cppcoro::cancellation_source::cancellation_source(cancellation_source&& other) noexcept + : m_state(other.m_state) +{ + other.m_state = nullptr; +} + +cppcoro::cancellation_source::~cancellation_source() +{ + if (m_state != nullptr) + { + m_state->release_source_ref(); + } +} + +cppcoro::cancellation_source& cppcoro::cancellation_source::operator=(const cancellation_source& other) noexcept +{ + if (m_state != other.m_state) + { + if (m_state != nullptr) + { + m_state->release_source_ref(); + } + + m_state = other.m_state; + + if (m_state != nullptr) + { + m_state->add_source_ref(); + } + } + + return *this; +} + +cppcoro::cancellation_source& cppcoro::cancellation_source::operator=(cancellation_source&& other) noexcept +{ + if (this != &other) + { + if (m_state != nullptr) + { + m_state->release_source_ref(); + } + + m_state = other.m_state; + other.m_state = nullptr; + } + + return *this; +} + +bool cppcoro::cancellation_source::can_be_cancelled() const noexcept +{ + return m_state != nullptr; +} + +cppcoro::cancellation_token cppcoro::cancellation_source::token() const noexcept +{ + return cancellation_token(m_state); +} + +void cppcoro::cancellation_source::request_cancellation() +{ + if (m_state != nullptr) + { + m_state->request_cancellation(); + } +} + +bool cppcoro::cancellation_source::is_cancellation_requested() const noexcept +{ + return m_state != nullptr && m_state->is_cancellation_requested(); +} diff --git a/lib/cancellation_state.cpp b/lib/cancellation_state.cpp new file mode 100644 index 0000000..a81123c --- /dev/null +++ b/lib/cancellation_state.cpp @@ -0,0 +1,624 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "cancellation_state.hpp" + +#include "cppcoro/config.hpp" + +#include + +#include +#include + +namespace cppcoro +{ + namespace detail + { + struct cancellation_registration_list_chunk + { + static cancellation_registration_list_chunk* allocate(std::uint32_t entryCount); + static void free(cancellation_registration_list_chunk* chunk) noexcept; + + std::atomic m_nextChunk; + cancellation_registration_list_chunk* m_prevChunk; + std::atomic m_approximateFreeCount; + std::uint32_t m_entryCount; + std::atomic m_entries[1]; + }; + + struct cancellation_registration_list + { + static cancellation_registration_list* allocate(); + static void free(cancellation_registration_list* bucket) noexcept; + + std::atomic m_approximateTail; + cancellation_registration_list_chunk m_headChunk; + }; + + struct cancellation_registration_result + { + cancellation_registration_result( + cancellation_registration_list_chunk* chunk, + std::uint32_t entryIndex) + : m_chunk(chunk) + , m_entryIndex(entryIndex) + {} + + cancellation_registration_list_chunk* m_chunk; + std::uint32_t m_entryIndex; + }; + + struct cancellation_registration_state + { + static cancellation_registration_state* allocate(); + static void free(cancellation_registration_state* list) noexcept; + + cancellation_registration_result add_registration( + cppcoro::cancellation_registration* registration); + + std::thread::id m_notificationThreadId; + + // Store N separate lists and randomly apportion threads to a given + // list to reduce chance of contention. + std::uint32_t m_listCount; + std::atomic m_lists[1]; + }; + } +} + +cppcoro::detail::cancellation_registration_list_chunk* +cppcoro::detail::cancellation_registration_list_chunk::allocate(std::uint32_t entryCount) +{ + auto* chunk = static_cast(std::malloc( + sizeof(cancellation_registration_list_chunk) + + (entryCount - 1) * sizeof(cancellation_registration_list_chunk::m_entries[0]))); + if (chunk == nullptr) + { + throw std::bad_alloc{}; + } + + ::new (&chunk->m_nextChunk) std::atomic(nullptr); + chunk->m_prevChunk = nullptr; + ::new (&chunk->m_approximateFreeCount) std::atomic(static_cast(entryCount - 1)); + chunk->m_entryCount = entryCount; + for (std::uint32_t i = 0; i < entryCount; ++i) + { + ::new (&chunk->m_entries[i]) std::atomic(nullptr); + } + + return chunk; +} + +void cppcoro::detail::cancellation_registration_list_chunk::free( + cancellation_registration_list_chunk* chunk) noexcept +{ + std::free(chunk); +} + +cppcoro::detail::cancellation_registration_list* +cppcoro::detail::cancellation_registration_list::allocate() +{ + constexpr std::uint32_t initialChunkSize = 16; + + const std::size_t bufferSize = + sizeof(cancellation_registration_list) + + (initialChunkSize - 1) * sizeof(cancellation_registration_list_chunk::m_entries[0]); + + auto* bucket = static_cast(std::malloc(bufferSize)); + if (bucket == nullptr) + { + throw std::bad_alloc{}; + } + + ::new (&bucket->m_approximateTail) std::atomic(&bucket->m_headChunk); + ::new (&bucket->m_headChunk.m_nextChunk) std::atomic(nullptr); + bucket->m_headChunk.m_prevChunk = nullptr; + ::new (&bucket->m_headChunk.m_approximateFreeCount) + std::atomic(static_cast(initialChunkSize - 1)); + bucket->m_headChunk.m_entryCount = initialChunkSize; + for (std::uint32_t i = 0; i < initialChunkSize; ++i) + { + ::new (&bucket->m_headChunk.m_entries[i]) std::atomic(nullptr); + } + + return bucket; +} + +void cppcoro::detail::cancellation_registration_list::free(cancellation_registration_list* list) noexcept +{ + std::free(list); +} + +cppcoro::detail::cancellation_registration_state* +cppcoro::detail::cancellation_registration_state::allocate() +{ + constexpr std::uint32_t maxListCount = 16; + + auto listCount = std::thread::hardware_concurrency(); + if (listCount > maxListCount) + { + listCount = maxListCount; + } + else if (listCount == 0) + { + listCount = 1; + } + + const std::size_t bufferSize = + sizeof(cancellation_registration_state) + + (listCount - 1) * sizeof(cancellation_registration_state::m_lists[0]); + + auto* state = static_cast(std::malloc(bufferSize)); + if (state == nullptr) + { + throw std::bad_alloc{}; + } + + state->m_listCount = listCount; + for (std::uint32_t i = 0; i < listCount; ++i) + { + ::new (&state->m_lists[i]) std::atomic(nullptr); + } + + return state; +} + +void cppcoro::detail::cancellation_registration_state::free(cancellation_registration_state* state) noexcept +{ + std::free(state); +} + +cppcoro::detail::cancellation_registration_result +cppcoro::detail::cancellation_registration_state::add_registration( + cppcoro::cancellation_registration* registration) +{ + // Pick a list to add to based on the current thread to reduce the + // chance of contention with multiple threads concurrently registering + // callbacks. + const auto threadIdHashCode = std::hash{}(std::this_thread::get_id()); + auto& listPtr = m_lists[threadIdHashCode % m_listCount]; + + auto* list = listPtr.load(std::memory_order_acquire); + if (list == nullptr) + { + auto* newList = cancellation_registration_list::allocate(); + + // Pre-claim the first slot. + registration->m_chunk = &newList->m_headChunk; + registration->m_entryIndex = 0; + ::new (&newList->m_headChunk.m_entries[0]) std::atomic(registration); + + if (listPtr.compare_exchange_strong( + list, + newList, + std::memory_order_seq_cst, + std::memory_order_acquire)) + { + return cancellation_registration_result(&newList->m_headChunk, 0); + } + else + { + cancellation_registration_list::free(newList); + } + } + + while (true) + { + // Navigate to the end of the chain of chunks and work backwards looking for a free slot. + auto* const originalLastChunk = list->m_approximateTail.load(std::memory_order_acquire); + + auto* lastChunk = originalLastChunk; + for (auto* next = lastChunk->m_nextChunk.load(std::memory_order_acquire); + next != nullptr; + next = next->m_nextChunk.load(std::memory_order_acquire)) + { + lastChunk = next; + } + + // Work around false-warning raised by MSVC static analysis complaining that + // warning C28182: Dereferencing NULL pointer. 'lastChunk' contains the same NULL value as 'chunk' did. + // on statement initialising 'elementCount' below. + CPPCORO_ASSUME(lastChunk != nullptr); + + if (lastChunk != originalLastChunk) + { + // Update the cache of last chunk pointer so that subsequent + // registration requests can start there instead. + // Doesn't matter if these writes race as it will eventually + // converge to the true last chunk. + list->m_approximateTail.store(lastChunk, std::memory_order_release); + } + + for (auto* chunk = lastChunk; + chunk != nullptr; + chunk = chunk->m_prevChunk) + { + auto freeCount = chunk->m_approximateFreeCount.load(std::memory_order_relaxed); + + // If it looks like there are no free slots then decrement the count again + // to force it to re-search every so-often, just in case the count has gotten + // out-of-sync with the true free count and is reporting none free even though + // there are some (or possibly all) free slots. + if (freeCount < 1) + { + --freeCount; + chunk->m_approximateFreeCount.store(freeCount, std::memory_order_relaxed); + } + + constexpr std::int32_t forcedSearchThreshold = -10; + if (freeCount > 0 || freeCount < forcedSearchThreshold) + { + const std::uint32_t entryCount = chunk->m_entryCount; + const std::uint32_t indexMask = entryCount - 1; + const std::uint32_t startIndex = entryCount - freeCount; + + registration->m_chunk = chunk; + + for (std::uint32_t i = 0; i < entryCount; ++i) + { + const std::uint32_t entryIndex = (startIndex + i) & indexMask; + auto& entry = chunk->m_entries[entryIndex]; + + // Do a cheap initial read of the entry value to see if the + // entry is likely free. This can potentially read stale values + // and so may lead to falsely thinking it's free or falsely + // thinking it's occupied. But approximate is good enough here. + auto* entryValue = entry.load(std::memory_order_relaxed); + if (entryValue == nullptr) + { + registration->m_entryIndex = entryIndex; + + if (entry.compare_exchange_strong( + entryValue, + registration, + std::memory_order_seq_cst, + std::memory_order_relaxed)) + { + // Successfully claimed the slot. + const std::int32_t newFreeCount = freeCount < 0 ? 0 : freeCount - 1; + chunk->m_approximateFreeCount.store(newFreeCount, std::memory_order_relaxed); + return cancellation_registration_result(chunk, entryIndex); + } + } + } + + // Read through all elements of chunk with no success. + // Clear free-count back to 0. + chunk->m_approximateFreeCount.store(0, std::memory_order_relaxed); + } + } + + // We've traversed through all of the chunks and found no free slots. + // So try and allocate a new chunk and append it to the list. + + constexpr std::uint32_t maxElementCount = 1024; + + const std::uint32_t elementCount = + lastChunk->m_entryCount < maxElementCount ? + lastChunk->m_entryCount * 2 : maxElementCount; + + // May throw std::bad_alloc if out of memory. + auto* newChunk = cancellation_registration_list_chunk::allocate(elementCount); + newChunk->m_prevChunk = lastChunk; + + // Pre-allocate first slot. + registration->m_chunk = newChunk; + registration->m_entryIndex = 0; + ::new (&newChunk->m_entries[0]) std::atomic(registration); + + cancellation_registration_list_chunk* oldNext = nullptr; + if (lastChunk->m_nextChunk.compare_exchange_strong( + oldNext, + newChunk, + std::memory_order_seq_cst, + std::memory_order_relaxed)) + { + list->m_approximateTail.store(newChunk, std::memory_order_release); + return cancellation_registration_result(newChunk, 0); + } + + // Some other thread published a new chunk to the end of the list + // concurrently. Free our chunk and go around the loop again, hopefully + // allocating a slot from the chunk the other thread just allocated. + cancellation_registration_list_chunk::free(newChunk); + } +} + +cppcoro::detail::cancellation_state* cppcoro::detail::cancellation_state::create() +{ + return new cancellation_state(); +} + +cppcoro::detail::cancellation_state::~cancellation_state() +{ + assert((m_state.load(std::memory_order_relaxed) & cancellation_ref_count_mask) == 0); + + // Use relaxed memory order in reads here since we should already have visibility + // to all writes as the ref-count decrement that preceded the call to the destructor + // has acquire-release semantics. + + auto* registrationState = m_registrationState.load(std::memory_order_relaxed); + if (registrationState != nullptr) + { + for (std::uint32_t i = 0; i < registrationState->m_listCount; ++i) + { + auto* list = registrationState->m_lists[i].load(std::memory_order_relaxed); + if (list != nullptr) + { + auto* chunk = list->m_headChunk.m_nextChunk.load(std::memory_order_relaxed); + cancellation_registration_list::free(list); + + while (chunk != nullptr) + { + auto* next = chunk->m_nextChunk.load(std::memory_order_relaxed); + cancellation_registration_list_chunk::free(chunk); + chunk = next; + } + } + } + + cancellation_registration_state::free(registrationState); + } +} + +void cppcoro::detail::cancellation_state::add_token_ref() noexcept +{ + m_state.fetch_add(cancellation_token_ref_increment, std::memory_order_relaxed); +} + +void cppcoro::detail::cancellation_state::release_token_ref() noexcept +{ + const std::uint64_t oldState = m_state.fetch_sub(cancellation_token_ref_increment, std::memory_order_acq_rel); + if ((oldState & cancellation_ref_count_mask) == cancellation_token_ref_increment) + { + delete this; + } +} + +void cppcoro::detail::cancellation_state::add_source_ref() noexcept +{ + m_state.fetch_add(cancellation_source_ref_increment, std::memory_order_relaxed); +} + +void cppcoro::detail::cancellation_state::release_source_ref() noexcept +{ + const std::uint64_t oldState = m_state.fetch_sub(cancellation_source_ref_increment, std::memory_order_acq_rel); + if ((oldState & cancellation_ref_count_mask) == cancellation_source_ref_increment) + { + delete this; + } +} + +bool cppcoro::detail::cancellation_state::can_be_cancelled() const noexcept +{ + return (m_state.load(std::memory_order_acquire) & can_be_cancelled_mask) != 0; +} + +bool cppcoro::detail::cancellation_state::is_cancellation_requested() const noexcept +{ + return (m_state.load(std::memory_order_acquire) & cancellation_requested_flag) != 0; +} + +bool cppcoro::detail::cancellation_state::is_cancellation_notification_complete() const noexcept +{ + return (m_state.load(std::memory_order_acquire) & cancellation_notification_complete_flag) != 0; +} + +void cppcoro::detail::cancellation_state::request_cancellation() +{ + const auto oldState = m_state.fetch_or(cancellation_requested_flag, std::memory_order_seq_cst); + if ((oldState & cancellation_requested_flag) != 0) + { + // Some thread has already called request_cancellation(). + return; + } + + // We are the first caller of request_cancellation. + // Need to execute any registered callbacks to notify them of cancellation. + + // NOTE: We need to use sequentially-consistent operations here to ensure + // that if there is a concurrent call to try_register_callback() on another + // thread that either the other thread will read the prior write to m_state + // after they write to a registration slot or we will read their write to the + // registration slot after the prior write to m_state. + + auto* const registrationState = m_registrationState.load(std::memory_order_seq_cst); + if (registrationState != nullptr) + { + // Note that there should be no data-race in writing to this value here + // as another thread will only read it if they are trying to deregister + // a callback and that fails because we have acquired the pointer to + // the registration inside the loop below. In this case the atomic + // exchange that acquires the pointer below acts as a release-operation + // that synchronises with the failed exchange operation in deregister_callback() + // which has acquire semantics and thus will have visibility of the write to + // the m_notificationThreadId value. + registrationState->m_notificationThreadId = std::this_thread::get_id(); + + for (std::uint32_t listIndex = 0, listCount = registrationState->m_listCount; + listIndex < listCount; + ++listIndex) + { + auto* list = registrationState->m_lists[listIndex].load(std::memory_order_seq_cst); + if (list == nullptr) + { + continue; + } + + auto* chunk = &list->m_headChunk; + do + { + for (std::uint32_t entryIndex = 0, entryCount = chunk->m_entryCount; + entryIndex < entryCount; + ++entryIndex) + { + auto& entry = chunk->m_entries[entryIndex]; + + // Quick read-only operation to check if any registration + // is present. + auto* registration = entry.load(std::memory_order_seq_cst); + if (registration != nullptr) + { + // Try to acquire ownership of the registration by replacing its + // slot with nullptr atomically. This resolves the race between + // a concurrent call to deregister_callback() from the registration's + // destructor. + registration = entry.exchange(nullptr, std::memory_order_seq_cst); + if (registration != nullptr) + { + try + { + registration->m_callback(); + } + catch (...) + { + // TODO: What should behaviour of unhandled exception in a callback be here? + std::terminate(); + } + } + } + } + + chunk = chunk->m_nextChunk.load(std::memory_order_seq_cst); + } while (chunk != nullptr); + } + + m_state.fetch_add(cancellation_notification_complete_flag, std::memory_order_release); + } +} + +bool cppcoro::detail::cancellation_state::try_register_callback( + cancellation_registration* registration) +{ + if (is_cancellation_requested()) + { + return false; + } + + auto* registrationState = m_registrationState.load(std::memory_order_acquire); + if (registrationState == nullptr) + { + // Could throw std::bad_alloc + auto* newRegistrationState = cancellation_registration_state::allocate(); + + // Need to use 'sequentially consistent' on the write here to ensure that if + // we subsequently read a value from m_state at the end of this function that + // doesn't have the cancellation_requested_flag bit set that a subsequent call + // in another thread to request_cancellation() will see this write. + if (m_registrationState.compare_exchange_strong( + registrationState, + newRegistrationState, + std::memory_order_seq_cst, + std::memory_order_acquire)) + { + registrationState = newRegistrationState; + } + else + { + cancellation_registration_state::free(newRegistrationState); + } + } + + // Could throw std::bad_alloc + auto result = registrationState->add_registration(registration); + + // Need to check status again to handle the case where + // another thread calls request_cancellation() concurrently + // but doesn't see our write to the registration list. + // + // Note, we don't call IsCancellationRequested() here since that + // only provides 'acquire' memory semantics and we need 'seq_cst' + // semantics. + if ((m_state.load(std::memory_order_seq_cst) & cancellation_requested_flag) != 0) + { + // Cancellation was requested concurrently with adding the + // registration to the list. Try to remove the registration. + // If successful we return false to indicate that the callback + // has not been registered and the caller should execute the + // callback. If it fails it means that the thread that requested + // cancellation will execute our callback and we need to wait + // until it finishes before returning. + auto& entry = result.m_chunk->m_entries[result.m_entryIndex]; + + // Need to use compare_exchange here rather than just exchange since + // it may be possible that the thread calling request_cancellation() + // acquired our registration and executed the callback, freeing up + // the slot and then a third thread registers a new registration + // that gets allocated to this slot. + // + // Can use relaxed memory order here since in the case that this succeeds + // no other thread will have written to the cancellation_registration record + // so we can safely read from the record without synchronisation. + auto* oldValue = registration; + const bool deregisteredSuccessfully = + entry.compare_exchange_strong(oldValue, nullptr, std::memory_order_relaxed); + if (deregisteredSuccessfully) + { + return false; + } + + // Otherwise, the cancelling thread has taken ownership for executing + // the callback and we can just act as if the registration succeeded. + } + + return true; +} + +void cppcoro::detail::cancellation_state::deregister_callback(cancellation_registration* registration) noexcept +{ + auto* chunk = registration->m_chunk; + auto& entry = chunk->m_entries[registration->m_entryIndex]; + + // Use 'acquire' memory order on failure case so that we synchronise with the write + // to the slot inside request_cancellation() that acquired the registration such that + // we have visibility of its prior write to m_notifyingThreadId. + // + // Could use 'relaxed' memory order on success case as if this succeeds it means that + // no thread will have written to the registration object. + auto* oldValue = registration; + bool deregisteredSuccessfully = entry.compare_exchange_strong( + oldValue, + nullptr, + std::memory_order_acquire); + if (deregisteredSuccessfully) + { + // Increment free-count if it won't make it larger than entry count. + const std::int32_t oldFreeCount = chunk->m_approximateFreeCount.load(std::memory_order_relaxed); + if (oldFreeCount < static_cast(chunk->m_entryCount)) + { + const std::int32_t newFreeCount = oldFreeCount < 0 ? 1 : oldFreeCount + 1; + chunk->m_approximateFreeCount.store(newFreeCount, std::memory_order_relaxed); + } + } + else + { + // A thread executing request_cancellation() has acquired this callback and + // is executing it. Need to wait until it finishes executing before we return + // and the registration object is destructed. + // + // However, we also need to handle the case where the registration is being + // removed from within a callback which would otherwise deadlock waiting + // for the callbacks to finish executing. + + // Use relaxed memory order here as we should already have visibility + // of the write to m_registrationState from when the registration was first + // registered. + auto* registrationState = m_registrationState.load(std::memory_order_relaxed); + if (std::this_thread::get_id() != registrationState->m_notificationThreadId) + { + // TODO: More efficient busy-wait backoff strategy + while (!is_cancellation_notification_complete()) + { + std::this_thread::yield(); + } + } + } +} + +cppcoro::detail::cancellation_state::cancellation_state() noexcept + : m_state(cancellation_source_ref_increment) + , m_registrationState(nullptr) +{ +} diff --git a/lib/cancellation_state.hpp b/lib/cancellation_state.hpp new file mode 100644 index 0000000..9bdb40d --- /dev/null +++ b/lib/cancellation_state.hpp @@ -0,0 +1,108 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CANCELLATION_STATE_HPP_INCLUDED +#define CPPCORO_CANCELLATION_STATE_HPP_INCLUDED + +#include + +#include +#include +#include + +namespace cppcoro +{ + namespace detail + { + struct cancellation_registration_state; + + class cancellation_state + { + public: + + /// Allocates a new cancellation_state object. + /// + /// \throw std::bad_alloc + /// If there was insufficient memory to allocate one. + static cancellation_state* create(); + + ~cancellation_state(); + + /// Increment the reference count of cancellation_token and + /// cancellation_registration objects referencing this state. + void add_token_ref() noexcept; + + /// Decrement the reference count of cancellation_token and + /// cancellation_registration objects referencing this state. + void release_token_ref() noexcept; + + /// Increment the reference count of cancellation_source objects. + void add_source_ref() noexcept; + + /// Decrement the reference count of cancellation_souce objects. + /// + /// The cancellation_state will no longer be cancellable once the + /// cancellation_source ref count reaches zero. + void release_source_ref() noexcept; + + /// Query if the cancellation_state can have cancellation requested. + /// + /// \return + /// Returns true if there are no more references to a cancellation_source + /// object. + bool can_be_cancelled() const noexcept; + + /// Query if some thread has called request_cancellation(). + bool is_cancellation_requested() const noexcept; + + /// Flag state has having cancellation_requested and execute any + /// registered callbacks. + void request_cancellation(); + + /// Try to register the cancellation_registration as a callback to be executed + /// when cancellation is requested. + /// + /// \return + /// true if the callback was successfully registered, false if the callback was + /// not registered because cancellation had already been requested. + /// + /// \throw std::bad_alloc + /// If callback was unable to be registered due to insufficient memory. + bool try_register_callback(cancellation_registration* registration); + + /// Deregister a callback previously registered successfully in a call to try_register_callback(). + /// + /// If the callback is currently being executed on another + /// thread that is concurrently calling request_cancellation() + /// then this call will block until the callback has finished executing. + void deregister_callback(cancellation_registration* registration) noexcept; + + private: + + cancellation_state() noexcept; + + bool is_cancellation_notification_complete() const noexcept; + + static constexpr std::uint64_t cancellation_requested_flag = 1; + static constexpr std::uint64_t cancellation_notification_complete_flag = 2; + static constexpr std::uint64_t cancellation_source_ref_increment = 4; + static constexpr std::uint64_t cancellation_token_ref_increment = UINT64_C(1) << 33; + static constexpr std::uint64_t can_be_cancelled_mask = cancellation_token_ref_increment - 1; + static constexpr std::uint64_t cancellation_ref_count_mask = + ~(cancellation_requested_flag | cancellation_notification_complete_flag); + + // A value that has: + // - bit 0 - indicates whether cancellation has been requested. + // - bit 1 - indicates whether cancellation notification is complete. + // - bits 2-32 - ref-count for cancellation_source instances. + // - bits 33-63 - ref-count for cancellation_token/cancellation_registration instances. + std::atomic m_state; + + std::atomic m_registrationState; + + }; + } +} + +#endif diff --git a/lib/cancellation_token.cpp b/lib/cancellation_token.cpp new file mode 100644 index 0000000..ad13360 --- /dev/null +++ b/lib/cancellation_token.cpp @@ -0,0 +1,108 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include "cancellation_state.hpp" + +#include +#include + +cppcoro::cancellation_token::cancellation_token() noexcept + : m_state(nullptr) +{ +} + +cppcoro::cancellation_token::cancellation_token(const cancellation_token& other) noexcept + : m_state(other.m_state) +{ + if (m_state != nullptr) + { + m_state->add_token_ref(); + } +} + +cppcoro::cancellation_token::cancellation_token(cancellation_token&& other) noexcept + : m_state(other.m_state) +{ + other.m_state = nullptr; +} + +cppcoro::cancellation_token::~cancellation_token() +{ + if (m_state != nullptr) + { + m_state->release_token_ref(); + } +} + +cppcoro::cancellation_token& cppcoro::cancellation_token::operator=(const cancellation_token& other) noexcept +{ + if (other.m_state != m_state) + { + if (m_state != nullptr) + { + m_state->release_token_ref(); + } + + m_state = other.m_state; + + if (m_state != nullptr) + { + m_state->add_token_ref(); + } + } + + return *this; +} + +cppcoro::cancellation_token& cppcoro::cancellation_token::operator=(cancellation_token&& other) noexcept +{ + if (this != &other) + { + if (m_state != nullptr) + { + m_state->release_token_ref(); + } + + m_state = other.m_state; + other.m_state = nullptr; + } + + return *this; +} + +void cppcoro::cancellation_token::swap(cancellation_token& other) noexcept +{ + std::swap(m_state, other.m_state); +} + +bool cppcoro::cancellation_token::can_be_cancelled() const noexcept +{ + return m_state != nullptr && m_state->can_be_cancelled(); +} + +bool cppcoro::cancellation_token::is_cancellation_requested() const noexcept +{ + return m_state != nullptr && m_state->is_cancellation_requested(); +} + +void cppcoro::cancellation_token::throw_if_cancellation_requested() const +{ + if (is_cancellation_requested()) + { + throw operation_cancelled{}; + } +} + +cppcoro::cancellation_token::cancellation_token(detail::cancellation_state* state) noexcept + : m_state(state) +{ + if (m_state != nullptr) + { + m_state->add_token_ref(); + } +} diff --git a/lib/file.cpp b/lib/file.cpp new file mode 100644 index 0000000..0e15585 --- /dev/null +++ b/lib/file.cpp @@ -0,0 +1,168 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include +#endif + +cppcoro::file::~file() +{} + +std::uint64_t cppcoro::file::size() const +{ +#if CPPCORO_OS_WINNT + LARGE_INTEGER size; + BOOL ok = ::GetFileSizeEx(m_fileHandle.handle(), &size); + if (!ok) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error getting file size: GetFileSizeEx" + }; + } + + return size.QuadPart; +#endif +} + +cppcoro::file::file(detail::win32::safe_handle&& fileHandle) noexcept + : m_fileHandle(std::move(fileHandle)) +{ +} + +cppcoro::detail::win32::safe_handle cppcoro::file::open( + detail::win32::dword_t fileAccess, + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode, + file_share_mode shareMode, + file_buffering_mode bufferingMode) +{ + DWORD flags = FILE_FLAG_OVERLAPPED; + if ((bufferingMode & file_buffering_mode::random_access) == file_buffering_mode::random_access) + { + flags |= FILE_FLAG_RANDOM_ACCESS; + } + if ((bufferingMode & file_buffering_mode::sequential) == file_buffering_mode::sequential) + { + flags |= FILE_FLAG_SEQUENTIAL_SCAN; + } + if ((bufferingMode & file_buffering_mode::write_through) == file_buffering_mode::write_through) + { + flags |= FILE_FLAG_WRITE_THROUGH; + } + if ((bufferingMode & file_buffering_mode::temporary) == file_buffering_mode::temporary) + { + flags |= FILE_ATTRIBUTE_TEMPORARY; + } + if ((bufferingMode & file_buffering_mode::unbuffered) == file_buffering_mode::unbuffered) + { + flags |= FILE_FLAG_NO_BUFFERING; + } + + DWORD shareFlags = 0; + if ((shareMode & file_share_mode::read) == file_share_mode::read) + { + shareFlags |= FILE_SHARE_READ; + } + if ((shareMode & file_share_mode::write) == file_share_mode::write) + { + shareFlags |= FILE_SHARE_WRITE; + } + if ((shareMode & file_share_mode::delete_) == file_share_mode::delete_) + { + shareFlags |= FILE_SHARE_DELETE; + } + + DWORD creationDisposition = 0; + switch (openMode) + { + case file_open_mode::create_or_open: + creationDisposition = OPEN_ALWAYS; + break; + case file_open_mode::create_always: + creationDisposition = CREATE_ALWAYS; + break; + case file_open_mode::create_new: + creationDisposition = CREATE_NEW; + break; + case file_open_mode::open_existing: + creationDisposition = OPEN_EXISTING; + break; + case file_open_mode::truncate_existing: + creationDisposition = TRUNCATE_EXISTING; + break; + } + + // Open the file + detail::win32::safe_handle fileHandle( + ::CreateFileW( + path.wstring().c_str(), + fileAccess, + shareFlags, + nullptr, + creationDisposition, + flags, + nullptr)); + if (fileHandle.handle() == INVALID_HANDLE_VALUE) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error opening file: CreateFileW" + }; + } + + // Associate with the I/O service's completion port. + const HANDLE result = ::CreateIoCompletionPort( + fileHandle.handle(), + ioService.native_iocp_handle(), + 0, + 0); + if (result == nullptr) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error opening file: CreateIoCompletionPort" + }; + } + + // Configure I/O operations to avoid dispatching a completion event + // to the I/O service if the operation completes synchronously. + // This avoids unnecessary suspension/resuption of the awaiting coroutine. + const BOOL ok = ::SetFileCompletionNotificationModes( + fileHandle.handle(), + FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | + FILE_SKIP_SET_EVENT_ON_HANDLE); + if (!ok) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error opening file: SetFileCompletionNotificationModes" + }; + } + + return std::move(fileHandle); +} diff --git a/lib/file_read_operation.cpp b/lib/file_read_operation.cpp new file mode 100644 index 0000000..ca27489 --- /dev/null +++ b/lib/file_read_operation.cpp @@ -0,0 +1,53 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +bool cppcoro::file_read_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + const DWORD numberOfBytesToRead = + m_byteCount <= 0xFFFFFFFF ? + static_cast(m_byteCount) : DWORD(0xFFFFFFFF); + + DWORD numberOfBytesRead = 0; + BOOL ok = ::ReadFile( + m_fileHandle, + m_buffer, + numberOfBytesToRead, + &numberOfBytesRead, + operation.get_overlapped()); + const DWORD errorCode = ok ? ERROR_SUCCESS : ::GetLastError(); + if (errorCode != ERROR_IO_PENDING) + { + // Completed synchronously. + // + // We are assuming that the file-handle has been set to the + // mode where synchronous completions do not post a completion + // event to the I/O completion port and thus can return without + // suspending here. + + operation.m_errorCode = errorCode; + operation.m_numberOfBytesTransferred = numberOfBytesRead; + + return false; + } + + return true; +} + +void cppcoro::file_read_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx(m_fileHandle, operation.get_overlapped()); +} + +#endif // CPPCORO_OS_WINNT diff --git a/lib/file_write_operation.cpp b/lib/file_write_operation.cpp new file mode 100644 index 0000000..68a3ac4 --- /dev/null +++ b/lib/file_write_operation.cpp @@ -0,0 +1,53 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +bool cppcoro::file_write_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + const DWORD numberOfBytesToWrite = + m_byteCount <= 0xFFFFFFFF ? + static_cast(m_byteCount) : DWORD(0xFFFFFFFF); + + DWORD numberOfBytesWritten = 0; + BOOL ok = ::WriteFile( + m_fileHandle, + m_buffer, + numberOfBytesToWrite, + &numberOfBytesWritten, + operation.get_overlapped()); + const DWORD errorCode = ok ? ERROR_SUCCESS : ::GetLastError(); + if (errorCode != ERROR_IO_PENDING) + { + // Completed synchronously. + // + // We are assuming that the file-handle has been set to the + // mode where synchronous completions do not post a completion + // event to the I/O completion port and thus can return without + // suspending here. + + operation.m_errorCode = errorCode; + operation.m_numberOfBytesTransferred = numberOfBytesWritten; + + return false; + } + + return true; +} + +void cppcoro::file_write_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx(m_fileHandle, operation.get_overlapped()); +} + +#endif // CPPCORO_OS_WINNT diff --git a/lib/io_service.cpp b/lib/io_service.cpp new file mode 100644 index 0000000..551c280 --- /dev/null +++ b/lib/io_service.cpp @@ -0,0 +1,1020 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include +#include +#include +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +# include +# include +#endif + +namespace +{ +#if CPPCORO_OS_WINNT + cppcoro::detail::win32::safe_handle create_io_completion_port(std::uint32_t concurrencyHint) + { + HANDLE handle = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, concurrencyHint); + if (handle == NULL) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "Error creating io_service: CreateIoCompletionPort" + }; + } + + return cppcoro::detail::win32::safe_handle{ handle }; + } + + cppcoro::detail::win32::safe_handle create_auto_reset_event() + { + HANDLE eventHandle = ::CreateEventW(nullptr, FALSE, FALSE, nullptr); + if (eventHandle == NULL) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "Error creating manual reset event: CreateEventW" + }; + } + + return cppcoro::detail::win32::safe_handle{ eventHandle }; + } + + cppcoro::detail::win32::safe_handle create_waitable_timer_event() + { + const BOOL isManualReset = FALSE; + HANDLE handle = ::CreateWaitableTimerW(nullptr, isManualReset, nullptr); + if (handle == nullptr) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category() + }; + } + + return cppcoro::detail::win32::safe_handle{ handle }; + } +#endif +} + +/// \brief +/// A queue of pending timers that supports efficiently determining +/// and dequeueing the earliest-due timers in the queue. +/// +/// Implementation utilises a heap-sorted vector of entries with an +/// additional sorted linked-list that can be used as a fallback in +/// cases that there was insufficient memory to store all timer +/// entries in the vector. +/// +/// This fallback is required to guarantee that all operations on this +/// queue are noexcept.s +class cppcoro::io_service::timer_queue +{ +public: + + using time_point = std::chrono::high_resolution_clock::time_point; + + timer_queue() noexcept; + + ~timer_queue(); + + bool is_empty() const noexcept; + + time_point earliest_due_time() const noexcept; + + void enqueue_timer(cppcoro::io_service::timed_schedule_operation* timer) noexcept; + + void dequeue_due_timers( + time_point currentTime, + cppcoro::io_service::timed_schedule_operation*& timerList) noexcept; + + void remove_cancelled_timers( + cppcoro::io_service::timed_schedule_operation*& timerList) noexcept; + +private: + + struct timer_entry + { + timer_entry(cppcoro::io_service::timed_schedule_operation* timer) + : m_dueTime(timer->m_resumeTime) + , m_timer(timer) + {} + + time_point m_dueTime; + cppcoro::io_service::timed_schedule_operation* m_timer; + }; + + static bool compare_entries(const timer_entry& a, const timer_entry& b) noexcept + { + return a.m_dueTime > b.m_dueTime; + } + + // A heap-sorted list of active timer entries + // Earliest due timer is at the front of the queue + std::vector m_timerEntries; + + // Linked-list of overflow timer entries used in case there was + // insufficient memory available to grow m_timerEntries. + // List is sorted in ascending order of due-time using insertion-sort. + // This is required to support the noexcept guarantee of enqueue_timer(). + cppcoro::io_service::timed_schedule_operation* m_overflowTimers; + +}; + +cppcoro::io_service::timer_queue::timer_queue() noexcept + : m_timerEntries() + , m_overflowTimers(nullptr) +{} + +cppcoro::io_service::timer_queue::~timer_queue() +{ + assert(is_empty()); +} + +bool cppcoro::io_service::timer_queue::is_empty() const noexcept +{ + return m_timerEntries.empty() && m_overflowTimers == nullptr; +} + +cppcoro::io_service::timer_queue::time_point +cppcoro::io_service::timer_queue::earliest_due_time() const noexcept +{ + if (!m_timerEntries.empty()) + { + if (m_overflowTimers != nullptr) + { + return std::min( + m_timerEntries.front().m_dueTime, + m_overflowTimers->m_resumeTime); + } + + return m_timerEntries.front().m_dueTime; + } + else if (m_overflowTimers != nullptr) + { + return m_overflowTimers->m_resumeTime; + } + + return time_point::max(); +} + +void cppcoro::io_service::timer_queue::enqueue_timer( + cppcoro::io_service::timed_schedule_operation* timer) noexcept +{ + try + { + m_timerEntries.emplace_back(timer); + std::push_heap(m_timerEntries.begin(), m_timerEntries.end(), compare_entries); + } + catch (...) + { + const auto& newDueTime = timer->m_resumeTime; + + auto** current = &m_overflowTimers; + while ((*current) != nullptr && (*current)->m_resumeTime <= newDueTime) + { + current = &(*current)->m_next; + } + + timer->m_next = *current; + *current = timer; + } +} + +void cppcoro::io_service::timer_queue::dequeue_due_timers( + time_point currentTime, + cppcoro::io_service::timed_schedule_operation*& timerList) noexcept +{ + while (!m_timerEntries.empty() && m_timerEntries.front().m_dueTime <= currentTime) + { + auto* timer = m_timerEntries.front().m_timer; + std::pop_heap(m_timerEntries.begin(), m_timerEntries.end(), compare_entries); + m_timerEntries.pop_back(); + + timer->m_next = timerList; + timerList = timer; + } + + while (m_overflowTimers != nullptr && m_overflowTimers->m_resumeTime <= currentTime) + { + auto* timer = m_overflowTimers; + m_overflowTimers = timer->m_next; + timer->m_next = timerList; + timerList = timer; + } +} + +void cppcoro::io_service::timer_queue::remove_cancelled_timers( + cppcoro::io_service::timed_schedule_operation*& timerList) noexcept +{ + // Perform a linear scan of all timers looking for any that have + // had cancellation requested. + + const auto addTimerToList = [&](timed_schedule_operation* timer) + { + timer->m_next = timerList; + timerList = timer; + }; + + const auto isTimerCancelled = [](const timer_entry& entry) + { + return entry.m_timer->m_cancellationToken.is_cancellation_requested(); + }; + + auto firstCancelledEntry = std::find_if( + m_timerEntries.begin(), + m_timerEntries.end(), + isTimerCancelled); + if (firstCancelledEntry != m_timerEntries.end()) + { + auto nonCancelledEnd = firstCancelledEntry; + addTimerToList(nonCancelledEnd->m_timer); + + for (auto iter = firstCancelledEntry + 1; + iter != m_timerEntries.end(); + ++iter) + { + if (isTimerCancelled(*iter)) + { + addTimerToList(iter->m_timer); + } + else + { + *nonCancelledEnd++ = std::move(*iter); + } + } + + m_timerEntries.erase(nonCancelledEnd, m_timerEntries.end()); + + std::make_heap( + m_timerEntries.begin(), + m_timerEntries.end(), + compare_entries); + } + + { + timed_schedule_operation** current = &m_overflowTimers; + while ((*current) != nullptr) + { + auto* timer = (*current); + if (timer->m_cancellationToken.is_cancellation_requested()) + { + *current = timer->m_next; + addTimerToList(timer); + } + else + { + current = &timer->m_next; + } + } + } +} + +class cppcoro::io_service::timer_thread_state +{ +public: + + timer_thread_state(); + ~timer_thread_state(); + + timer_thread_state(const timer_thread_state& other) = delete; + timer_thread_state& operator=(const timer_thread_state& other) = delete; + + void request_timer_cancellation() noexcept; + + void run() noexcept; + + void wake_up_timer_thread() noexcept; + +#if CPPCORO_OS_WINNT + detail::win32::safe_handle m_wakeUpEvent; + detail::win32::safe_handle m_waitableTimerEvent; +#endif + + std::atomic m_newlyQueuedTimers; + std::atomic m_timerCancellationRequested; + std::atomic m_shutDownRequested; + + std::thread m_thread; +}; + + + +cppcoro::io_service::io_service() + : io_service(0) +{ +} + +cppcoro::io_service::io_service(std::uint32_t concurrencyHint) + : m_threadState(0) + , m_workCount(0) +#if CPPCORO_OS_WINNT + , m_iocpHandle(create_io_completion_port(concurrencyHint)) + , m_winsockInitialised(false) + , m_winsockInitialisationMutex() +#endif + , m_scheduleOperations(nullptr) + , m_timerState(nullptr) +{ +} + +cppcoro::io_service::~io_service() +{ + assert(m_scheduleOperations.load(std::memory_order_relaxed) == nullptr); + assert(m_threadState.load(std::memory_order_relaxed) < active_thread_count_increment); + + delete m_timerState.load(std::memory_order_relaxed); + +#if CPPCORO_OS_WINNT + if (m_winsockInitialised.load(std::memory_order_relaxed)) + { + // TODO: Should we be checking return-code here? + // Don't want to throw from the destructor, so perhaps just log an error? + (void)::WSACleanup(); + } +#endif +} + +cppcoro::io_service::schedule_operation cppcoro::io_service::schedule() noexcept +{ + return schedule_operation{ *this }; +} + +std::uint64_t cppcoro::io_service::process_events() +{ + std::uint64_t eventCount = 0; + if (try_enter_event_loop()) + { + auto exitLoop = on_scope_exit([&] { exit_event_loop(); }); + + constexpr bool waitForEvent = true; + while (try_process_one_event(waitForEvent)) + { + ++eventCount; + } + } + + return eventCount; +} + +std::uint64_t cppcoro::io_service::process_pending_events() +{ + std::uint64_t eventCount = 0; + if (try_enter_event_loop()) + { + auto exitLoop = on_scope_exit([&] { exit_event_loop(); }); + + constexpr bool waitForEvent = false; + while (try_process_one_event(waitForEvent)) + { + ++eventCount; + } + } + + return eventCount; +} + +std::uint64_t cppcoro::io_service::process_one_event() +{ + std::uint64_t eventCount = 0; + if (try_enter_event_loop()) + { + auto exitLoop = on_scope_exit([&] { exit_event_loop(); }); + + constexpr bool waitForEvent = true; + if (try_process_one_event(waitForEvent)) + { + ++eventCount; + } + } + + return eventCount; +} + +std::uint64_t cppcoro::io_service::process_one_pending_event() +{ + std::uint64_t eventCount = 0; + if (try_enter_event_loop()) + { + auto exitLoop = on_scope_exit([&] { exit_event_loop(); }); + + constexpr bool waitForEvent = false; + if (try_process_one_event(waitForEvent)) + { + ++eventCount; + } + } + + return eventCount; +} + +void cppcoro::io_service::stop() noexcept +{ + const auto oldState = m_threadState.fetch_or(stop_requested_flag, std::memory_order_release); + if ((oldState & stop_requested_flag) == 0) + { + for (auto activeThreadCount = oldState / active_thread_count_increment; + activeThreadCount > 0; + --activeThreadCount) + { + post_wake_up_event(); + } + } +} + +void cppcoro::io_service::reset() +{ + const auto oldState = m_threadState.fetch_and(~stop_requested_flag, std::memory_order_relaxed); + + // Check that there were no active threads running the event loop. + assert(oldState == stop_requested_flag); +} + +bool cppcoro::io_service::is_stop_requested() const noexcept +{ + return (m_threadState.load(std::memory_order_acquire) & stop_requested_flag) != 0; +} + +void cppcoro::io_service::notify_work_started() noexcept +{ + m_workCount.fetch_add(1, std::memory_order_relaxed); +} + +void cppcoro::io_service::notify_work_finished() noexcept +{ + if (m_workCount.fetch_sub(1, std::memory_order_relaxed) == 1) + { + stop(); + } +} + +cppcoro::detail::win32::handle_t cppcoro::io_service::native_iocp_handle() noexcept +{ + return m_iocpHandle.handle(); +} + +#if CPPCORO_OS_WINNT + +void cppcoro::io_service::ensure_winsock_initialised() +{ + if (!m_winsockInitialised.load(std::memory_order_acquire)) + { + std::lock_guard lock(m_winsockInitialisationMutex); + if (!m_winsockInitialised.load(std::memory_order_acquire)) + { + const WORD requestedVersion = MAKEWORD(2, 2); + WSADATA winsockData; + const int result = ::WSAStartup(requestedVersion, &winsockData); + if (result == SOCKET_ERROR) + { + const int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Error initialsing winsock: WSAStartup"); + } + + m_winsockInitialised.store(true, std::memory_order_release); + } + } +} + +#endif // CPPCORO_OS_WINNT + +void cppcoro::io_service::schedule_impl(schedule_operation* operation) noexcept +{ +#if CPPCORO_OS_WINNT + const BOOL ok = ::PostQueuedCompletionStatus( + m_iocpHandle.handle(), + 0, + reinterpret_cast(operation->m_awaiter.address()), + nullptr); + if (!ok) + { + // Failed to post to the I/O completion port. + // + // This is most-likely because the queue is currently full. + // + // We'll queue up the operation to a linked-list using a lock-free + // push and defer the dispatch to the completion port until some I/O + // thread next enters its event loop. + auto* head = m_scheduleOperations.load(std::memory_order_acquire); + do + { + operation->m_next = head; + } while (!m_scheduleOperations.compare_exchange_weak( + head, + operation, + std::memory_order_release, + std::memory_order_acquire)); + } +#endif +} + +void cppcoro::io_service::try_reschedule_overflow_operations() noexcept +{ +#if CPPCORO_OS_WINNT + auto* operation = m_scheduleOperations.exchange(nullptr, std::memory_order_acquire); + while (operation != nullptr) + { + auto* next = operation->m_next; + BOOL ok = ::PostQueuedCompletionStatus( + m_iocpHandle.handle(), + 0, + reinterpret_cast(operation->m_awaiter.address()), + nullptr); + if (!ok) + { + // Still unable to queue these operations. + // Put them back on the list of overflow operations. + auto* tail = operation; + while (tail->m_next != nullptr) + { + tail = tail->m_next; + } + + schedule_operation* head = nullptr; + while (!m_scheduleOperations.compare_exchange_weak( + head, + operation, + std::memory_order_release, + std::memory_order_relaxed)) + { + tail->m_next = head; + } + + return; + } + + operation = next; + } +#endif +} + +bool cppcoro::io_service::try_enter_event_loop() noexcept +{ + auto currentState = m_threadState.load(std::memory_order_relaxed); + do + { + if ((currentState & stop_requested_flag) != 0) + { + return false; + } + } while (!m_threadState.compare_exchange_weak( + currentState, + currentState + active_thread_count_increment, + std::memory_order_relaxed)); + + return true; +} + +void cppcoro::io_service::exit_event_loop() noexcept +{ + m_threadState.fetch_sub(active_thread_count_increment, std::memory_order_relaxed); +} + +bool cppcoro::io_service::try_process_one_event(bool waitForEvent) +{ +#if CPPCORO_OS_WINNT + if (is_stop_requested()) + { + return false; + } + + const DWORD timeout = waitForEvent ? INFINITE : 0; + + while (true) + { + // Check for any schedule_operation objects that were unable to be + // queued to the I/O completion port and try to requeue them now. + try_reschedule_overflow_operations(); + + DWORD numberOfBytesTransferred = 0; + ULONG_PTR completionKey = 0; + LPOVERLAPPED overlapped = nullptr; + BOOL ok = ::GetQueuedCompletionStatus( + m_iocpHandle.handle(), + &numberOfBytesTransferred, + &completionKey, + &overlapped, + timeout); + if (overlapped != nullptr) + { + DWORD errorCode = ok ? ERROR_SUCCESS : ::GetLastError(); + + auto* state = static_cast( + reinterpret_cast(overlapped)); + + state->m_callback( + state, + errorCode, + numberOfBytesTransferred, + completionKey); + + return true; + } + else if (ok) + { + if (completionKey != 0) + { + // This was a coroutine scheduled via a call to + // io_service::schedule(). + cppcoro::coroutine_handle<>::from_address( + reinterpret_cast(completionKey)).resume(); + return true; + } + + // Empty event is a wake-up request, typically associated with a + // request to exit the event loop. + // However, there may be spurious such events remaining in the queue + // from a previous call to stop() that has since been reset() so we + // need to check whether stop is still required. + if (is_stop_requested()) + { + return false; + } + } + else + { + const DWORD errorCode = ::GetLastError(); + if (errorCode == WAIT_TIMEOUT) + { + return false; + } + + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "Error retrieving item from io_service queue: GetQueuedCompletionStatus" + }; + } + } +#endif +} + +void cppcoro::io_service::post_wake_up_event() noexcept +{ +#if CPPCORO_OS_WINNT + // We intentionally ignore the return code here. + // + // Assume that if posting an event failed that it failed because the queue was full + // and the system is out of memory. In this case threads should find other events + // in the queue next time they check anyway and thus wake-up. + (void)::PostQueuedCompletionStatus(m_iocpHandle.handle(), 0, 0, nullptr); +#endif +} + +cppcoro::io_service::timer_thread_state* +cppcoro::io_service::ensure_timer_thread_started() +{ + auto* timerState = m_timerState.load(std::memory_order_acquire); + if (timerState == nullptr) + { + auto newTimerState = std::make_unique(); + if (m_timerState.compare_exchange_strong( + timerState, + newTimerState.get(), + std::memory_order_release, + std::memory_order_acquire)) + { + // We managed to install our timer_thread_state before some + // other thread did, don't free it here - it will be freed in + // the io_service destructor. + timerState = newTimerState.release(); + } + } + + return timerState; +} + +cppcoro::io_service::timer_thread_state::timer_thread_state() +#if CPPCORO_OS_WINNT + : m_wakeUpEvent(create_auto_reset_event()) + , m_waitableTimerEvent(create_waitable_timer_event()) +#endif + , m_newlyQueuedTimers(nullptr) + , m_timerCancellationRequested(false) + , m_shutDownRequested(false) + , m_thread([this] { this->run(); }) +{ +} + +cppcoro::io_service::timer_thread_state::~timer_thread_state() +{ + m_shutDownRequested.store(true, std::memory_order_release); + wake_up_timer_thread(); + m_thread.join(); +} + +void cppcoro::io_service::timer_thread_state::request_timer_cancellation() noexcept +{ + const bool wasTimerCancellationAlreadyRequested = + m_timerCancellationRequested.exchange(true, std::memory_order_release); + if (!wasTimerCancellationAlreadyRequested) + { + wake_up_timer_thread(); + } +} + +void cppcoro::io_service::timer_thread_state::run() noexcept +{ +#if CPPCORO_OS_WINNT + using clock = std::chrono::high_resolution_clock; + using time_point = clock::time_point; + + timer_queue timerQueue; + + const DWORD waitHandleCount = 2; + const HANDLE waitHandles[waitHandleCount] = + { + m_wakeUpEvent.handle(), + m_waitableTimerEvent.handle() + }; + + time_point lastSetWaitEventTime = time_point::max(); + + timed_schedule_operation* timersReadyToResume = nullptr; + + DWORD timeout = INFINITE; + while (!m_shutDownRequested.load(std::memory_order_relaxed)) + { + const DWORD waitResult = ::WaitForMultipleObjectsEx( + waitHandleCount, + waitHandles, + FALSE, // waitAll + timeout, + FALSE); // alertable + if (waitResult == WAIT_OBJECT_0 || waitResult == WAIT_FAILED) + { + // Wake-up event (WAIT_OBJECT_0) + // + // We are only woken up for: + // - handling timer cancellation + // - handling newly queued timers + // - shutdown + // + // We also handle WAIT_FAILED here so that we remain responsive + // to new timers and cancellation even if the OS fails to perform + // the wait operation for some reason. + + // Handle cancelled timers + if (m_timerCancellationRequested.exchange(false, std::memory_order_acquire)) + { + timerQueue.remove_cancelled_timers(timersReadyToResume); + } + + // Handle newly queued timers + auto* newTimers = m_newlyQueuedTimers.exchange(nullptr, std::memory_order_acquire); + while (newTimers != nullptr) + { + auto* timer = newTimers; + newTimers = timer->m_next; + + if (timer->m_cancellationToken.is_cancellation_requested()) + { + timer->m_next = timersReadyToResume; + timersReadyToResume = timer; + } + else + { + timerQueue.enqueue_timer(timer); + } + } + } + else if (waitResult == (WAIT_OBJECT_0 + 1)) + { + lastSetWaitEventTime = time_point::max(); + } + + if (!timerQueue.is_empty()) + { + time_point currentTime = clock::now(); + + timerQueue.dequeue_due_timers(currentTime, timersReadyToResume); + + if (!timerQueue.is_empty()) + { + auto earliestDueTime = timerQueue.earliest_due_time(); + assert(earliestDueTime > currentTime); + + // Set the waitable timer before trying to schedule any of the ready-to-run + // timers to avoid the concept of 'current time' on which we calculate the + // amount of time to wait until the next timer is ready. + if (earliestDueTime != lastSetWaitEventTime) + { + using ticks = std::chrono::duration>; + + auto timeUntilNextDueTime = earliestDueTime - currentTime; + + // Negative value indicates relative time. + LARGE_INTEGER dueTime; + dueTime.QuadPart = -std::chrono::duration_cast(timeUntilNextDueTime).count(); + + // Period of 0 indicates no repeat on the timer. + const LONG period = 0; + + // Don't wake the system from a suspended state just to + // raise the timer event. + const BOOL resumeFromSuspend = FALSE; + + const BOOL ok = ::SetWaitableTimer( + m_waitableTimerEvent.handle(), + &dueTime, + period, + nullptr, + nullptr, + resumeFromSuspend); + if (ok) + { + lastSetWaitEventTime = earliestDueTime; + timeout = INFINITE; + } + else + { + // Not sure what could cause the call to SetWaitableTimer() + // to fail here but we'll just try falling back to using + // the timeout parameter of the WaitForMultipleObjects() call. + // + // wake-up at least once every second and retry setting + // the timer at that point. + using namespace std::literals::chrono_literals; + if (timeUntilNextDueTime > 1s) + { + timeout = 1000; + } + else if (timeUntilNextDueTime > 1ms) + { + timeout = static_cast( + std::chrono::duration_cast( + timeUntilNextDueTime).count()); + } + else + { + timeout = 1; + } + } + } + } + } + + // Now schedule any ready-to-run timers. + while (timersReadyToResume != nullptr) + { + auto* timer = timersReadyToResume; + auto* nextTimer = timer->m_next; + + // Use 'release' memory order to ensure that any prior writes to + // m_next "happen before" any potential uses of that same memory + // back on the thread that is executing timed_schedule_operation::await_suspend() + // which has the synchronising 'acquire' semantics. + if (timer->m_refCount.fetch_sub(1, std::memory_order_release) == 1) + { + timer->m_scheduleOperation.m_service.schedule_impl( + &timer->m_scheduleOperation); + } + + timersReadyToResume = nextTimer; + } + } +#endif +} + +void cppcoro::io_service::timer_thread_state::wake_up_timer_thread() noexcept +{ +#if CPPCORO_OS_WINNT + (void)::SetEvent(m_wakeUpEvent.handle()); +#endif +} + +void cppcoro::io_service::schedule_operation::await_suspend( + cppcoro::coroutine_handle<> awaiter) noexcept +{ + m_awaiter = awaiter; + m_service.schedule_impl(this); +} + +cppcoro::io_service::timed_schedule_operation::timed_schedule_operation( + io_service& service, + std::chrono::high_resolution_clock::time_point resumeTime, + cppcoro::cancellation_token cancellationToken) noexcept + : m_scheduleOperation(service) + , m_resumeTime(resumeTime) + , m_cancellationToken(std::move(cancellationToken)) + , m_refCount(2) +{ +} + +cppcoro::io_service::timed_schedule_operation::timed_schedule_operation( + timed_schedule_operation&& other) noexcept + : m_scheduleOperation(std::move(other.m_scheduleOperation)) + , m_resumeTime(std::move(other.m_resumeTime)) + , m_cancellationToken(std::move(other.m_cancellationToken)) + , m_refCount(2) +{ +} + +cppcoro::io_service::timed_schedule_operation::~timed_schedule_operation() +{ +} + +bool cppcoro::io_service::timed_schedule_operation::await_ready() const noexcept +{ + return m_cancellationToken.is_cancellation_requested(); +} + +void cppcoro::io_service::timed_schedule_operation::await_suspend( + cppcoro::coroutine_handle<> awaiter) +{ + m_scheduleOperation.m_awaiter = awaiter; + + auto& service = m_scheduleOperation.m_service; + + // Ensure the timer state is initialised and the timer thread started. + auto* timerState = service.ensure_timer_thread_started(); + + if (m_cancellationToken.can_be_cancelled()) + { + m_cancellationRegistration.emplace(m_cancellationToken, [timerState] + { + timerState->request_timer_cancellation(); + }); + } + + // Queue the timer schedule to the queue of incoming new timers. + // + // We need to do a careful dance here because it could be possible + // that immediately after queueing the timer this thread could be + // context-switched out, the timer thread could pick it up and + // schedule it to be resumed, it could be resumed on an I/O thread + // and complete its work and the io_service could be destructed. + // All before we get to execute timerState.wake_up_timer_thread() + // below. To work around this race we use a reference-counter + // with initial value 2 and have both the timer thread and this + // thread decrement the count once the awaiter is ready to be + // rescheduled. Whichever thread decrements the ref-count to 0 + // is responsible for scheduling the awaiter for resumption. + + + // Not sure if we need 'acquire' semantics on this load and + // on the failure-case of the compare_exchange below. + // + // It could potentially be made 'release' if we can guarantee + // that a read-with 'acquire' semantics in the timer thread + // of the latest value will synchronise with all prior writes + // to that value that used 'release' semantics. + auto* prev = timerState->m_newlyQueuedTimers.load(std::memory_order_acquire); + do + { + m_next = prev; + } while (!timerState->m_newlyQueuedTimers.compare_exchange_weak( + prev, + this, + std::memory_order_release, + std::memory_order_acquire)); + + if (prev == nullptr) + { + timerState->wake_up_timer_thread(); + } + + // Use 'acquire' semantics here to synchronise with the 'release' + // operation performed on the timer thread to ensure that we have + // seen all potential writes to this object. Without this, it's + // possible that a write to the m_next field by the timer thread + // will race with subsequent writes to that same memory by this + // thread or whatever I/O thread resumes the coroutine. + if (m_refCount.fetch_sub(1, std::memory_order_acquire) == 1) + { + service.schedule_impl(&m_scheduleOperation); + } +} + +void cppcoro::io_service::timed_schedule_operation::await_resume() +{ + m_cancellationRegistration.reset(); + m_cancellationToken.throw_if_cancellation_requested(); +} diff --git a/lib/ip_address.cpp b/lib/ip_address.cpp new file mode 100644 index 0000000..38e490f --- /dev/null +++ b/lib/ip_address.cpp @@ -0,0 +1,27 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +std::string cppcoro::net::ip_address::to_string() const +{ + return is_ipv4() ? m_ipv4.to_string() : m_ipv6.to_string(); +} + +std::optional +cppcoro::net::ip_address::from_string(std::string_view string) noexcept +{ + if (auto ipv4 = ipv4_address::from_string(string); ipv4) + { + return *ipv4; + } + + if (auto ipv6 = ipv6_address::from_string(string); ipv6) + { + return *ipv6; + } + + return std::nullopt; +} diff --git a/lib/ip_endpoint.cpp b/lib/ip_endpoint.cpp new file mode 100644 index 0000000..b1b915e --- /dev/null +++ b/lib/ip_endpoint.cpp @@ -0,0 +1,27 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +std::string cppcoro::net::ip_endpoint::to_string() const +{ + return is_ipv4() ? m_ipv4.to_string() : m_ipv6.to_string(); +} + +std::optional +cppcoro::net::ip_endpoint::from_string(std::string_view string) noexcept +{ + if (auto ipv4 = ipv4_endpoint::from_string(string); ipv4) + { + return *ipv4; + } + + if (auto ipv6 = ipv6_endpoint::from_string(string); ipv6) + { + return *ipv6; + } + + return std::nullopt; +} diff --git a/lib/ipv4_address.cpp b/lib/ipv4_address.cpp new file mode 100644 index 0000000..686628b --- /dev/null +++ b/lib/ipv4_address.cpp @@ -0,0 +1,174 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +namespace +{ + namespace local + { + constexpr bool is_digit(char c) + { + return c >= '0' && c <= '9'; + } + + constexpr std::uint8_t digit_value(char c) + { + return static_cast(c - '0'); + } + } +} + +std::optional +cppcoro::net::ipv4_address::from_string(std::string_view string) noexcept +{ + if (string.empty()) return std::nullopt; + + if (!local::is_digit(string[0])) + { + return std::nullopt; + } + + const auto length = string.length(); + + std::uint8_t partValues[4]; + + if (string[0] == '0' && length > 1) + { + if (local::is_digit(string[1])) + { + // Octal format (not supported) + return std::nullopt; + } + else if (string[1] == 'x') + { + // Hexadecimal format (not supported) + return std::nullopt; + } + } + + // Parse the first integer. + // Could be a single 32-bit integer or first integer in a dotted decimal string. + + std::size_t pos = 0; + + { + constexpr std::uint32_t maxValue = 0xFFFFFFFFu / 10; + constexpr std::uint32_t maxDigit = 0xFFFFFFFFu % 10; + + std::uint32_t partValue = local::digit_value(string[pos]); + ++pos; + + while (pos < length && local::is_digit(string[pos])) + { + const auto digitValue = local::digit_value(string[pos]); + ++pos; + + // Check if this digit would overflow the 32-bit integer + if (partValue > maxValue || (partValue == maxValue && digitValue > maxDigit)) + { + return std::nullopt; + } + + partValue = (partValue * 10) + digitValue; + } + + if (pos == length) + { + // A single-integer string + return ipv4_address{ partValue }; + } + else if (partValue > 255) + { + // Not a valid first component of dotted decimal + return std::nullopt; + } + + partValues[0] = static_cast(partValue); + } + + for (int part = 1; part < 4; ++part) + { + if ((pos + 1) >= length || string[pos] != '.' || !local::is_digit(string[pos + 1])) + { + return std::nullopt; + } + + // Skip the '.' + ++pos; + + // Check for an octal format (not yet supported) + const bool isPartOctal = + (pos + 1) < length && + string[pos] == '0' && + local::is_digit(string[pos + 1]); + if (isPartOctal) + { + return std::nullopt; + } + + std::uint32_t partValue = local::digit_value(string[pos]); + ++pos; + if (pos < length && local::is_digit(string[pos])) + { + partValue = (partValue * 10) + local::digit_value(string[pos]); + ++pos; + if (pos < length && local::is_digit(string[pos])) + { + partValue = (partValue * 10) + local::digit_value(string[pos]); + if (partValue > 255) + { + return std::nullopt; + } + + ++pos; + } + } + + partValues[part] = static_cast(partValue); + } + + if (pos < length) + { + // Extra chars after end of a valid IPv4 string + return std::nullopt; + } + + return ipv4_address{ partValues }; +} + +std::string cppcoro::net::ipv4_address::to_string() const +{ + // Buffer is large enough to hold larges ip address + // "xxx.xxx.xxx.xxx" + char buffer[15]; + + char* c = &buffer[0]; + for (int i = 0; i < 4; ++i) + { + if (i > 0) + { + *c++ = '.'; + } + + if (m_bytes[i] >= 100) + { + *c++ = '0' + (m_bytes[i] / 100); + *c++ = '0' + (m_bytes[i] % 100) / 10; + *c++ = '0' + (m_bytes[i] % 10); + } + else if (m_bytes[i] >= 10) + { + *c++ = '0' + (m_bytes[i] / 10); + *c++ = '0' + (m_bytes[i] % 10); + } + else + { + *c++ = '0' + m_bytes[i]; + } + } + + return std::string{ &buffer[0], c }; +} diff --git a/lib/ipv4_endpoint.cpp b/lib/ipv4_endpoint.cpp new file mode 100644 index 0000000..61d6ead --- /dev/null +++ b/lib/ipv4_endpoint.cpp @@ -0,0 +1,71 @@ +/////////////////////////////////////////////////////////////////////////////// +// Kt C++ Library +// Copyright (c) 2015 Lewis Baker +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +namespace +{ + namespace local + { + bool is_digit(char c) + { + return c >= '0' && c <= '9'; + } + + std::uint8_t digit_value(char c) + { + return static_cast(c - '0'); + } + + std::optional parse_port(std::string_view string) + { + if (string.empty()) return std::nullopt; + + std::uint32_t value = 0; + for (auto c : string) + { + if (!is_digit(c)) return std::nullopt; + value = value * 10 + digit_value(c); + if (value > 0xFFFFu) return std::nullopt; + } + + return static_cast(value); + } + } +} + +std::string cppcoro::net::ipv4_endpoint::to_string() const +{ + auto s = m_address.to_string(); + s.push_back(':'); + s.append(std::to_string(m_port)); + return s; +} + +std::optional +cppcoro::net::ipv4_endpoint::from_string(std::string_view string) noexcept +{ + auto colonPos = string.find(':'); + if (colonPos == std::string_view::npos) + { + return std::nullopt; + } + + auto address = ipv4_address::from_string(string.substr(0, colonPos)); + if (!address) + { + return std::nullopt; + } + + auto port = local::parse_port(string.substr(colonPos + 1)); + if (!port) + { + return std::nullopt; + } + + return ipv4_endpoint{ *address, *port }; +} diff --git a/lib/ipv6_address.cpp b/lib/ipv6_address.cpp new file mode 100644 index 0000000..7d5216e --- /dev/null +++ b/lib/ipv6_address.cpp @@ -0,0 +1,362 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include + +namespace +{ + namespace local + { + constexpr bool is_digit(char c) + { + return c >= '0' && c <= '9'; + } + + constexpr std::uint8_t digit_value(char c) + { + return static_cast(c - '0'); + } + + std::optional try_parse_hex_digit(char c) + { + if (c >= '0' && c <= '9') + { + return static_cast(c - '0'); + } + else if (c >= 'a' && c <= 'f') + { + return static_cast(c - 'a' + 10); + } + else if (c >= 'A' && c <= 'F') + { + return static_cast(c - 'A' + 10); + } + + return std::nullopt; + } + + char hex_char(std::uint8_t value) + { + return value < 10 ? + static_cast('0' + value) : + static_cast('a' + value - 10); + } + } +} + +std::optional +cppcoro::net::ipv6_address::from_string(std::string_view string) noexcept +{ + // Longest possible valid IPv6 string is + // "xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:nnn.nnn.nnn.nnn" + constexpr std::size_t maxLength = 45; + + if (string.empty() || string.length() > maxLength) + { + return std::nullopt; + } + + const std::size_t length = string.length(); + + std::optional doubleColonPos; + + std::size_t pos = 0; + + if (length >= 2 && string[0] == ':' && string[1] == ':') + { + doubleColonPos = 0; + pos = 2; + } + + int partCount = 0; + std::uint16_t parts[8] = { 0 }; + + while (pos < length && partCount < 8) + { + std::uint8_t digits[4]; + int digitCount = 0; + auto digit = local::try_parse_hex_digit(string[pos]); + if (!digit) + { + return std::nullopt; + } + + do + { + digits[digitCount] = *digit; + ++digitCount; + ++pos; + } while (digitCount < 4 && pos < length && (digit = local::try_parse_hex_digit(string[pos]))); + + // If we're not at the end of the string then there must either be a ':' or a '.' next + // followed by the next part. + if (pos < length) + { + // Check if there's room for anything after the separator. + if ((pos + 1) == length) + { + return std::nullopt; + } + + if (string[pos] == ':') + { + ++pos; + if (string[pos] == ':') + { + if (doubleColonPos) + { + // This is a second double-colon, which is invalid. + return std::nullopt; + } + + doubleColonPos = partCount + 1; + ++pos; + } + } + else if (string[pos] == '.') + { + // Treat the current set of digits as decimal digits and parse + // the remaining three groups as dotted decimal notation. + + // Decimal notation produces two 16-bit parts. + // If we already have more than 6 parts then we'll end up + // with too many. + if (partCount > 6) + { + return std::nullopt; + } + + // Check for over-long or octal notation. + if (digitCount > 3 || (digitCount > 1 && digits[0] == 0)) + { + return std::nullopt; + } + + // Check that digits are valid decimal digits + if (digits[0] > 9 || + (digitCount > 1 && digits[1] > 9) || + (digitCount == 3 && digits[2] > 9)) + { + return std::nullopt; + } + + std::uint16_t decimalParts[4]; + + { + decimalParts[0] = digits[0]; + for (int i = 1; i < digitCount; ++i) + { + decimalParts[0] *= 10; + decimalParts[0] += digits[i]; + } + + if (decimalParts[0] > 255) + { + return std::nullopt; + } + } + + for (int decimalPart = 1; decimalPart < 4; ++decimalPart) + { + if (string[pos] != '.') + { + return std::nullopt; + } + + ++pos; + + if (pos == length || !local::is_digit(string[pos])) + { + // Expected a number after a dot. + return std::nullopt; + } + + const bool hasLeadingZero = string[pos] == '0'; + + decimalParts[decimalPart] = local::digit_value(string[pos]); + ++pos; + digitCount = 1; + while (digitCount < 3 && pos < length && local::is_digit(string[pos])) + { + decimalParts[decimalPart] *= 10; + decimalParts[decimalPart] += local::digit_value(string[pos]); + ++pos; + ++digitCount; + } + + if (decimalParts[decimalPart] > 255) + { + return std::nullopt; + } + + // Detect octal-style number (redundant leading zero) + if (digitCount > 1 && hasLeadingZero) + { + return std::nullopt; + } + } + + parts[partCount] = (decimalParts[0] << 8) + decimalParts[1]; + parts[partCount + 1] = (decimalParts[2] << 8) + decimalParts[3]; + partCount += 2; + + // Dotted decimal notation only appears at end. + // Don't parse any more of the string. + break; + } + else + { + // Invalid separator. + return std::nullopt; + } + } + + // Current part was made up of hex-digits. + std::uint16_t partValue = digits[0]; + for (int i = 1; i < digitCount; ++i) + { + partValue = partValue * 16 + digits[i]; + } + + parts[partCount] = partValue; + ++partCount; + } + + // Finished parsing the IPv6 address, we should have consumed all of the string. + if (pos < length) + { + return std::nullopt; + } + + if (partCount < 8) + { + if (!doubleColonPos) + { + return std::nullopt; + } + + const int preCount = *doubleColonPos; + + //CPPCORO_ASSUME(preCount <= partCount); + + const int postCount = partCount - preCount; + const int zeroCount = 8 - preCount - postCount; + + // Move parts after double colon down to the end. + for (int i = 0; i < postCount; ++i) + { + parts[7 - i] = parts[7 - zeroCount - i]; + } + + // Fill gap with zeroes. + for (int i = 0; i < zeroCount; ++i) + { + parts[preCount + i] = 0; + } + } + else if (doubleColonPos) + { + return std::nullopt; + } + + return ipv6_address{ parts }; +} + +std::string cppcoro::net::ipv6_address::to_string() const +{ + std::uint32_t longestZeroRunStart = 0; + std::uint32_t longestZeroRunLength = 0; + for (std::uint32_t i = 0; i < 8; ) + { + if (m_bytes[2 * i] == 0 && m_bytes[2 * i + 1] == 0) + { + const std::uint32_t zeroRunStart = i; + ++i; + while (i < 8 && m_bytes[2 * i] == 0 && m_bytes[2 * i + 1] == 0) + { + ++i; + } + + std::uint32_t zeroRunLength = i - zeroRunStart; + if (zeroRunLength > longestZeroRunLength) + { + longestZeroRunLength = zeroRunLength; + longestZeroRunStart = zeroRunStart; + } + } + else + { + ++i; + } + } + + // Longest string will be 8 x 4 digits + 7 ':' separators + char buffer[40]; + + char* c = &buffer[0]; + + auto appendPart = [&](std::uint32_t index) + { + const std::uint8_t highByte = m_bytes[index * 2]; + const std::uint8_t lowByte = m_bytes[index * 2 + 1]; + + // Don't output leading zero hex digits in the part string. + if (highByte > 0 || lowByte > 15) + { + if (highByte > 0) + { + if (highByte > 15) + { + *c++ = local::hex_char(highByte >> 4); + } + *c++ = local::hex_char(highByte & 0xF); + } + *c++ = local::hex_char(lowByte >> 4); + } + *c++ = local::hex_char(lowByte & 0xF); + }; + + if (longestZeroRunLength >= 2) + { + for (std::uint32_t i = 0; i < longestZeroRunStart; ++i) + { + if (i > 0) + { + *c++ = ':'; + } + + appendPart(i); + } + + *c++ = ':'; + *c++ = ':'; + + for (std::uint32_t i = longestZeroRunStart + longestZeroRunLength; i < 8; ++i) + { + appendPart(i); + + if (i < 7) + { + *c++ = ':'; + } + } + } + else + { + appendPart(0); + for (std::uint32_t i = 1; i < 8; ++i) + { + *c++ = ':'; + appendPart(i); + } + } + + assert((c - &buffer[0]) <= sizeof(buffer)); + + return std::string{ &buffer[0], c }; +} diff --git a/lib/ipv6_endpoint.cpp b/lib/ipv6_endpoint.cpp new file mode 100644 index 0000000..4f63a1d --- /dev/null +++ b/lib/ipv6_endpoint.cpp @@ -0,0 +1,84 @@ +/////////////////////////////////////////////////////////////////////////////// +// Kt C++ Library +// Copyright (c) 2015 Lewis Baker +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +namespace +{ + namespace local + { + bool is_digit(char c) + { + return c >= '0' && c <= '9'; + } + + std::uint8_t digit_value(char c) + { + return static_cast(c - '0'); + } + + std::optional parse_port(std::string_view string) + { + if (string.empty()) return std::nullopt; + + std::uint32_t value = 0; + for (auto c : string) + { + if (!is_digit(c)) return std::nullopt; + value = value * 10 + digit_value(c); + if (value > 0xFFFFu) return std::nullopt; + } + + return static_cast(value); + } + } +} + +std::string cppcoro::net::ipv6_endpoint::to_string() const +{ + std::string result; + result.push_back('['); + result += m_address.to_string(); + result += "]:"; + result += std::to_string(m_port); + return result; +} + +std::optional +cppcoro::net::ipv6_endpoint::from_string(std::string_view string) noexcept +{ + // Shortest valid endpoint is "[::]:0" + if (string.size() < 6) + { + return std::nullopt; + } + + if (string[0] != '[') + { + return std::nullopt; + } + + auto closeBracketPos = string.find("]:", 1); + if (closeBracketPos == std::string_view::npos) + { + return std::nullopt; + } + + auto address = ipv6_address::from_string(string.substr(1, closeBracketPos - 1)); + if (!address) + { + return std::nullopt; + } + + auto port = local::parse_port(string.substr(closeBracketPos + 2)); + if (!port) + { + return std::nullopt; + } + + return ipv6_endpoint{ *address, *port }; +} diff --git a/lib/lightweight_manual_reset_event.cpp b/lib/lightweight_manual_reset_event.cpp new file mode 100644 index 0000000..35b8f16 --- /dev/null +++ b/lib/lightweight_manual_reset_event.cpp @@ -0,0 +1,254 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +# if CPPCORO_OS_WINNT >= 0x0602 + +cppcoro::detail::lightweight_manual_reset_event::lightweight_manual_reset_event(bool initiallySet) + : m_value(initiallySet ? 1 : 0) +{} + +cppcoro::detail::lightweight_manual_reset_event::~lightweight_manual_reset_event() +{ +} + +void cppcoro::detail::lightweight_manual_reset_event::set() noexcept +{ + m_value.store(1, std::memory_order_release); + ::WakeByAddressAll(&m_value); +} + +void cppcoro::detail::lightweight_manual_reset_event::reset() noexcept +{ + m_value.store(0, std::memory_order_relaxed); +} + +void cppcoro::detail::lightweight_manual_reset_event::wait() noexcept +{ + wait({}, std::chrono::milliseconds(0)); +} +void cppcoro::detail::lightweight_manual_reset_event::wait( + std::span srvs, std::chrono::system_clock::duration step) noexcept +{ + const DWORD stepMs = static_cast(std::chrono::duration_cast(step).count()); + DWORD delay = srvs.empty() ? INFINITE : stepMs; + + // Wait in a loop as WaitOnAddress() can have spurious wake-ups. + int value = m_value.load(std::memory_order_acquire); + BOOL ok = TRUE; + while (value == 0) + { + if (!srvs.empty()) + { + //if there was one processed event, pass 0 timeout so we get a chance to process more, quickly + //otherwise, wait the full step + uint64_t tasks = 0; + for (auto& srv: srvs) + tasks += srv.process_one_pending_event(); + + delay = tasks > 0 ? 0 : stepMs; + } + else if (!ok) + { + // Previous call to WaitOnAddress() failed for some reason. + // Put thread to sleep to avoid sitting in a busy loop if it keeps failing. + ::Sleep(1); + } + + ok = ::WaitOnAddress(&m_value, &value, sizeof(m_value), delay); + value = m_value.load(std::memory_order_acquire); + } +} + +# else + +cppcoro::detail::lightweight_manual_reset_event::lightweight_manual_reset_event(bool initiallySet) + : m_eventHandle(::CreateEventW(nullptr, TRUE, initiallySet, nullptr)) +{ + if (m_eventHandle == NULL) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category() + }; + } +} + +cppcoro::detail::lightweight_manual_reset_event::~lightweight_manual_reset_event() +{ + // Ignore failure to close the object. + // We can't do much here as we want destructor to be noexcept. + (void)::CloseHandle(m_eventHandle); +} + +void cppcoro::detail::lightweight_manual_reset_event::set() noexcept +{ + if (!::SetEvent(m_eventHandle)) + { + std::abort(); + } +} + +void cppcoro::detail::lightweight_manual_reset_event::reset() noexcept +{ + if (!::ResetEvent(m_eventHandle)) + { + std::abort(); + } +} + +void cppcoro::detail::lightweight_manual_reset_event::wait() noexcept +{ + constexpr BOOL alertable = FALSE; + DWORD waitResult = ::WaitForSingleObjectEx(m_eventHandle, INFINITE, alertable); + if (waitResult == WAIT_FAILED) + { + std::abort(); + } +} + +# endif + +#elif CPPCORO_OS_LINUX + +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + namespace local + { + // No futex() function provided by libc. + // Wrap the syscall ourselves here. + int futex( + int* UserAddress, + int FutexOperation, + int Value, + const struct timespec* timeout, + int* UserAddress2, + int Value3) + { + return syscall( + SYS_futex, + UserAddress, + FutexOperation, + Value, + timeout, + UserAddress2, + Value3); + } + } +} + +cppcoro::detail::lightweight_manual_reset_event::lightweight_manual_reset_event(bool initiallySet) + : m_value(initiallySet ? 1 : 0) +{} + +cppcoro::detail::lightweight_manual_reset_event::~lightweight_manual_reset_event() +{ +} + +void cppcoro::detail::lightweight_manual_reset_event::set() noexcept +{ + m_value.store(1, std::memory_order_release); + + constexpr int numberOfWaitersToWakeUp = INT_MAX; + + [[maybe_unused]] int numberOfWaitersWokenUp = local::futex( + reinterpret_cast(&m_value), + FUTEX_WAKE_PRIVATE, + numberOfWaitersToWakeUp, + nullptr, + nullptr, + 0); + + // There are no errors expected here unless this class (or the caller) + // has done something wrong. + assert(numberOfWaitersWokenUp != -1); +} + +void cppcoro::detail::lightweight_manual_reset_event::reset() noexcept +{ + m_value.store(0, std::memory_order_relaxed); +} + +void cppcoro::detail::lightweight_manual_reset_event::wait() noexcept +{ + // Wait in a loop as futex() can have spurious wake-ups. + int oldValue = m_value.load(std::memory_order_acquire); + while (oldValue == 0) + { + int result = local::futex( + reinterpret_cast(&m_value), + FUTEX_WAIT_PRIVATE, + oldValue, + nullptr, + nullptr, + 0); + if (result == -1) + { + if (errno == EAGAIN) + { + // The state was changed from zero before we could wait. + // Must have been changed to 1. + return; + } + + // Other errors we'll treat as transient and just read the + // value and go around the loop again. + } + + oldValue = m_value.load(std::memory_order_acquire); + } +} + +#else + +cppcoro::detail::lightweight_manual_reset_event::lightweight_manual_reset_event(bool initiallySet) + : m_isSet(initiallySet) +{ +} + +cppcoro::detail::lightweight_manual_reset_event::~lightweight_manual_reset_event() +{ +} + +void cppcoro::detail::lightweight_manual_reset_event::set() noexcept +{ + std::lock_guard lock(m_mutex); + m_isSet = true; + m_cv.notify_all(); +} + +void cppcoro::detail::lightweight_manual_reset_event::reset() noexcept +{ + std::lock_guard lock(m_mutex); + m_isSet = false; +} + +void cppcoro::detail::lightweight_manual_reset_event::wait() noexcept +{ + std::unique_lock lock(m_mutex); + m_cv.wait(lock, [this] { return m_isSet; }); +} + +#endif diff --git a/lib/read_only_file.cpp b/lib/read_only_file.cpp new file mode 100644 index 0000000..b278365 --- /dev/null +++ b/lib/read_only_file.cpp @@ -0,0 +1,36 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +cppcoro::read_only_file cppcoro::read_only_file::open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_share_mode shareMode, + file_buffering_mode bufferingMode) +{ + return read_only_file(file::open( + GENERIC_READ, + ioService, + path, + file_open_mode::open_existing, + shareMode, + bufferingMode)); +} + +cppcoro::read_only_file::read_only_file( + detail::win32::safe_handle&& fileHandle) noexcept + : file(std::move(fileHandle)) + , readable_file(detail::win32::safe_handle{}) +{ +} + +#endif diff --git a/lib/read_write_file.cpp b/lib/read_write_file.cpp new file mode 100644 index 0000000..13838ea --- /dev/null +++ b/lib/read_write_file.cpp @@ -0,0 +1,38 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +cppcoro::read_write_file cppcoro::read_write_file::open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode, + file_share_mode shareMode, + file_buffering_mode bufferingMode) +{ + return read_write_file(file::open( + GENERIC_READ | GENERIC_WRITE, + ioService, + path, + openMode, + shareMode, + bufferingMode)); +} + +cppcoro::read_write_file::read_write_file( + detail::win32::safe_handle&& fileHandle) noexcept + : file(std::move(fileHandle)) + , readable_file(detail::win32::safe_handle{}) + , writable_file(detail::win32::safe_handle{}) +{ +} + +#endif diff --git a/lib/readable_file.cpp b/lib/readable_file.cpp new file mode 100644 index 0000000..12b90f2 --- /dev/null +++ b/lib/readable_file.cpp @@ -0,0 +1,36 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT + +cppcoro::file_read_operation cppcoro::readable_file::read( + std::uint64_t offset, + void* buffer, + std::size_t byteCount) const noexcept +{ + return file_read_operation( + m_fileHandle.handle(), + offset, + buffer, + byteCount); +} + +cppcoro::file_read_operation_cancellable cppcoro::readable_file::read( + std::uint64_t offset, + void* buffer, + std::size_t byteCount, + cancellation_token ct) const noexcept +{ + return file_read_operation_cancellable( + m_fileHandle.handle(), + offset, + buffer, + byteCount, + std::move(ct)); +} + +#endif diff --git a/lib/socket.cpp b/lib/socket.cpp new file mode 100644 index 0000000..fd438a1 --- /dev/null +++ b/lib/socket.cpp @@ -0,0 +1,493 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include "socket_helpers.hpp" + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +namespace +{ + namespace local + { + std::tuple create_socket( + int addressFamily, + int socketType, + int protocol, + HANDLE ioCompletionPort) + { + // Enumerate available protocol providers for the specified socket type. + + WSAPROTOCOL_INFOW stackInfos[4]; + std::unique_ptr heapInfos; + WSAPROTOCOL_INFOW* selectedProtocolInfo = nullptr; + + { + INT protocols[] = { protocol, 0 }; + DWORD bufferSize = sizeof(stackInfos); + WSAPROTOCOL_INFOW* infos = stackInfos; + + int protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize); + if (protocolCount == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + if (errorCode == WSAENOBUFS) + { + DWORD requiredElementCount = bufferSize / sizeof(WSAPROTOCOL_INFOW); + heapInfos = std::make_unique(requiredElementCount); + bufferSize = requiredElementCount * sizeof(WSAPROTOCOL_INFOW); + infos = heapInfos.get(); + protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize); + if (protocolCount == SOCKET_ERROR) + { + errorCode = ::WSAGetLastError(); + } + } + + if (protocolCount == SOCKET_ERROR) + { + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket: WSAEnumProtocolsW"); + } + } + + if (protocolCount == 0) + { + throw std::system_error( + std::make_error_code(std::errc::protocol_not_supported)); + } + + for (int i = 0; i < protocolCount; ++i) + { + auto& info = infos[i]; + if (info.iAddressFamily == addressFamily && info.iProtocol == protocol && info.iSocketType == socketType) + { + selectedProtocolInfo = &info; + break; + } + } + + if (selectedProtocolInfo == nullptr) + { + throw std::system_error( + std::make_error_code(std::errc::address_family_not_supported)); + } + } + + // WSA_FLAG_NO_HANDLE_INHERIT for SDKs earlier than Windows 7. + constexpr DWORD flagNoInherit = 0x80; + + const DWORD flags = WSA_FLAG_OVERLAPPED | flagNoInherit; + + const SOCKET socketHandle = ::WSASocketW( + addressFamily, socketType, protocol, selectedProtocolInfo, 0, flags); + if (socketHandle == INVALID_SOCKET) + { + const int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket: WSASocketW"); + } + + auto closeSocketOnFailure = cppcoro::on_scope_failure([&] + { + ::closesocket(socketHandle); + }); + + // This is needed on operating systems earlier than Windows 7 to prevent + // socket handles from being inherited. On Windows 7 or later this is + // redundant as the WSA_FLAG_NO_HANDLE_INHERIT flag passed to creation + // above causes the socket to be atomically created with this flag cleared. + if (!::SetHandleInformation((HANDLE)socketHandle, HANDLE_FLAG_INHERIT, 0)) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket: SetHandleInformation"); + } + + // Associate the socket with the I/O completion port. + { + const HANDLE result = ::CreateIoCompletionPort( + (HANDLE)socketHandle, + ioCompletionPort, + ULONG_PTR(0), + DWORD(0)); + if (result == nullptr) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error( + static_cast(errorCode), + std::system_category(), + "Error creating socket: CreateIoCompletionPort"); + } + } + + const bool skipCompletionPortOnSuccess = + (selectedProtocolInfo->dwServiceFlags1 & XP1_IFS_HANDLES) != 0; + + { + UCHAR completionModeFlags = FILE_SKIP_SET_EVENT_ON_HANDLE; + if (skipCompletionPortOnSuccess) + { + completionModeFlags |= FILE_SKIP_COMPLETION_PORT_ON_SUCCESS; + } + + const BOOL ok = ::SetFileCompletionNotificationModes( + (HANDLE)socketHandle, + completionModeFlags); + if (!ok) + { + const DWORD errorCode = ::GetLastError(); + throw std::system_error( + static_cast(errorCode), + std::system_category(), + "Error creating socket: SetFileCompletionNotificationModes"); + } + } + + if (socketType == SOCK_STREAM) + { + // Turn off linger so that the destructor doesn't block while closing + // the socket or silently continue to flush remaining data in the + // background after ::closesocket() is called, which could fail and + // we'd never know about it. + // We expect clients to call Disconnect() or use CloseSend() to cleanly + // shut-down connections instead. + BOOL value = TRUE; + const int result = ::setsockopt(socketHandle, + SOL_SOCKET, + SO_DONTLINGER, + reinterpret_cast(&value), + sizeof(value)); + if (result == SOCKET_ERROR) + { + const int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket: setsockopt(SO_DONTLINGER)"); + } + } + + return std::make_tuple(socketHandle, skipCompletionPortOnSuccess); + } + } +} + +cppcoro::net::socket cppcoro::net::socket::create_tcpv4(io_service& ioSvc) +{ + ioSvc.ensure_winsock_initialised(); + + auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( + AF_INET, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle()); + + socket result(socketHandle, skipCompletionPortOnSuccess); + result.m_localEndPoint = ipv4_endpoint(); + result.m_remoteEndPoint = ipv4_endpoint(); + return result; +} + +cppcoro::net::socket cppcoro::net::socket::create_tcpv6(io_service& ioSvc) +{ + ioSvc.ensure_winsock_initialised(); + + auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( + AF_INET6, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle()); + + socket result(socketHandle, skipCompletionPortOnSuccess); + result.m_localEndPoint = ipv6_endpoint(); + result.m_remoteEndPoint = ipv6_endpoint(); + return result; +} + +cppcoro::net::socket cppcoro::net::socket::create_udpv4(io_service& ioSvc) +{ + ioSvc.ensure_winsock_initialised(); + + auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( + AF_INET, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle()); + + socket result(socketHandle, skipCompletionPortOnSuccess); + result.m_localEndPoint = ipv4_endpoint(); + result.m_remoteEndPoint = ipv4_endpoint(); + return result; +} + +cppcoro::net::socket cppcoro::net::socket::create_udpv6(io_service& ioSvc) +{ + ioSvc.ensure_winsock_initialised(); + + auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( + AF_INET6, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle()); + + socket result(socketHandle, skipCompletionPortOnSuccess); + result.m_localEndPoint = ipv6_endpoint(); + result.m_remoteEndPoint = ipv6_endpoint(); + return result; +} + +cppcoro::net::socket::socket(socket&& other) noexcept + : m_handle(std::exchange(other.m_handle, INVALID_SOCKET)) + , m_skipCompletionOnSuccess(other.m_skipCompletionOnSuccess) + , m_localEndPoint(std::move(other.m_localEndPoint)) + , m_remoteEndPoint(std::move(other.m_remoteEndPoint)) +{} + +cppcoro::net::socket::~socket() +{ + if (m_handle != INVALID_SOCKET) + { + ::closesocket(m_handle); + } +} + +cppcoro::net::socket& +cppcoro::net::socket::operator=(socket&& other) noexcept +{ + auto handle = std::exchange(other.m_handle, INVALID_SOCKET); + if (m_handle != INVALID_SOCKET) + { + ::closesocket(m_handle); + } + + m_handle = handle; + m_skipCompletionOnSuccess = other.m_skipCompletionOnSuccess; + m_localEndPoint = other.m_localEndPoint; + m_remoteEndPoint = other.m_remoteEndPoint; + + return *this; +} + +void cppcoro::net::socket::bind(const ip_endpoint& localEndPoint) +{ + SOCKADDR_STORAGE sockaddrStorage = { 0 }; + SOCKADDR* sockaddr = reinterpret_cast(&sockaddrStorage); + if (localEndPoint.is_ipv4()) + { + SOCKADDR_IN& ipv4Sockaddr = *reinterpret_cast(sockaddr); + ipv4Sockaddr.sin_family = AF_INET; + std::memcpy(&ipv4Sockaddr.sin_addr, localEndPoint.to_ipv4().address().bytes(), 4); + ipv4Sockaddr.sin_port = localEndPoint.to_ipv4().port(); + } + else + { + SOCKADDR_IN6& ipv6Sockaddr = *reinterpret_cast(sockaddr); + ipv6Sockaddr.sin6_family = AF_INET6; + std::memcpy(&ipv6Sockaddr.sin6_addr, localEndPoint.to_ipv6().address().bytes(), 16); + ipv6Sockaddr.sin6_port = localEndPoint.to_ipv6().port(); + } + + int result = ::bind(m_handle, sockaddr, sizeof(sockaddrStorage)); + if (result != 0) + { + // WSANOTINITIALISED: WSAStartup not called + // WSAENETDOWN: network subsystem failed + // WSAEACCES: access denied + // WSAEADDRINUSE: port in use + // WSAEADDRNOTAVAIL: address is not an address that can be bound to + // WSAEFAULT: invalid pointer passed to bind() + // WSAEINPROGRESS: a callback is in progress + // WSAEINVAL: socket already bound + // WSAENOBUFS: system failed to allocate memory + // WSAENOTSOCK: socket was not a valid socket. + int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Error binding to endpoint: bind()"); + } + + int sockaddrLen = sizeof(sockaddrStorage); + result = ::getsockname(m_handle, sockaddr, &sockaddrLen); + if (result == 0) + { + m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(*sockaddr); + } + else + { + m_localEndPoint = localEndPoint; + } +} + +void cppcoro::net::socket::listen() +{ + int result = ::listen(m_handle, SOMAXCONN); + if (result != 0) + { + int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Failed to start listening on bound endpoint: listen"); + } +} + +void cppcoro::net::socket::listen(std::uint32_t backlog) +{ + if (backlog > 0x7FFFFFFF) + { + backlog = 0x7FFFFFFF; + } + + int result = ::listen(m_handle, (int)backlog); + if (result != 0) + { + // WSANOTINITIALISED: WSAStartup not called + // WSAENETDOWN: network subsystem failed + // WSAEADDRINUSE: port in use + // WSAEINPROGRESS: a callback is in progress + // WSAEINVAL: socket not yet bound + // WSAEISCONN: socket already connected + // WSAEMFILE: no more socket descriptors available + // WSAENOBUFS: system failed to allocate memory + // WSAENOTSOCK: socket was not a valid socket. + // WSAEOPNOTSUPP: The socket does not support listening + + int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "Failed to start listening on bound endpoint: listen"); + } +} + +cppcoro::net::socket_accept_operation +cppcoro::net::socket::accept(socket& acceptingSocket) noexcept +{ + return socket_accept_operation{ *this, acceptingSocket }; +} + +cppcoro::net::socket_accept_operation_cancellable +cppcoro::net::socket::accept(socket& acceptingSocket, cancellation_token ct) noexcept +{ + return socket_accept_operation_cancellable{ *this, acceptingSocket, std::move(ct) }; +} + +cppcoro::net::socket_connect_operation +cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint) noexcept +{ + return socket_connect_operation{ *this, remoteEndPoint }; +} + +cppcoro::net::socket_connect_operation_cancellable +cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint, cancellation_token ct) noexcept +{ + return socket_connect_operation_cancellable{ *this, remoteEndPoint, std::move(ct) }; +} + +cppcoro::net::socket_disconnect_operation +cppcoro::net::socket::disconnect() noexcept +{ + return socket_disconnect_operation(*this); +} + +cppcoro::net::socket_disconnect_operation_cancellable +cppcoro::net::socket::disconnect(cancellation_token ct) noexcept +{ + return socket_disconnect_operation_cancellable{ *this, std::move(ct) }; +} + +cppcoro::net::socket_send_operation +cppcoro::net::socket::send(const void* buffer, std::size_t byteCount) noexcept +{ + return socket_send_operation{ *this, buffer, byteCount }; +} + +cppcoro::net::socket_send_operation_cancellable +cppcoro::net::socket::send(const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept +{ + return socket_send_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +} + +cppcoro::net::socket_recv_operation +cppcoro::net::socket::recv(void* buffer, std::size_t byteCount) noexcept +{ + return socket_recv_operation{ *this, buffer, byteCount }; +} + +cppcoro::net::socket_recv_operation_cancellable +cppcoro::net::socket::recv(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept +{ + return socket_recv_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +} + +cppcoro::net::socket_recv_from_operation +cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount) noexcept +{ + return socket_recv_from_operation{ *this, buffer, byteCount }; +} + +cppcoro::net::socket_recv_from_operation_cancellable +cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept +{ + return socket_recv_from_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +} + +cppcoro::net::socket_send_to_operation +cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount) noexcept +{ + return socket_send_to_operation{ *this, destination, buffer, byteCount }; +} + +cppcoro::net::socket_send_to_operation_cancellable +cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept +{ + return socket_send_to_operation_cancellable{ *this, destination, buffer, byteCount, std::move(ct) }; +} + +void cppcoro::net::socket::close_send() +{ + int result = ::shutdown(m_handle, SD_SEND); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "failed to close socket send stream: shutdown(SD_SEND)"); + } +} + +void cppcoro::net::socket::close_recv() +{ + int result = ::shutdown(m_handle, SD_RECEIVE); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + throw std::system_error( + errorCode, + std::system_category(), + "failed to close socket receive stream: shutdown(SD_RECEIVE)"); + } +} + +cppcoro::net::socket::socket( + cppcoro::detail::win32::socket_t handle, + bool skipCompletionOnSuccess) noexcept + : m_handle(handle) + , m_skipCompletionOnSuccess(skipCompletionOnSuccess) +{ +} + +#endif diff --git a/lib/socket_accept_operation.cpp b/lib/socket_accept_operation.cpp new file mode 100644 index 0000000..b65ab14 --- /dev/null +++ b/lib/socket_accept_operation.cpp @@ -0,0 +1,129 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include "socket_helpers.hpp" + +#include + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +// TODO: Eliminate duplication of implementation between socket_accept_operation +// and socket_accept_operation_cancellable. + +bool cppcoro::net::socket_accept_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + static_assert( + (sizeof(m_addressBuffer) / 2) >= (16 + sizeof(SOCKADDR_IN)) && + (sizeof(m_addressBuffer) / 2) >= (16 + sizeof(SOCKADDR_IN6)), + "AcceptEx requires address buffer to be at least 16 bytes more than largest address."); + + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_listeningSocket.skip_completion_on_success(); + + DWORD bytesReceived = 0; + BOOL ok = ::AcceptEx( + m_listeningSocket.native_handle(), + m_acceptingSocket.native_handle(), + m_addressBuffer, + 0, + sizeof(m_addressBuffer) / 2, + sizeof(m_addressBuffer) / 2, + &bytesReceived, + operation.get_overlapped()); + if (!ok) + { + int errorCode = ::WSAGetLastError(); + if (errorCode != ERROR_IO_PENDING) + { + operation.m_errorCode = static_cast(errorCode); + return false; + } + } + else if (skipCompletionOnSuccess) + { + operation.m_errorCode = ERROR_SUCCESS; + return false; + } + + return true; +} + +void cppcoro::net::socket_accept_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_listeningSocket.native_handle()), + operation.get_overlapped()); +} + +void cppcoro::net::socket_accept_operation_impl::get_result( + cppcoro::detail::win32_overlapped_operation_base& operation) +{ + if (operation.m_errorCode != ERROR_SUCCESS) + { + throw std::system_error{ + static_cast(operation.m_errorCode), + std::system_category(), + "Accepting a connection failed: AcceptEx" + }; + } + + sockaddr* localSockaddr = nullptr; + sockaddr* remoteSockaddr = nullptr; + + INT localSockaddrLength; + INT remoteSockaddrLength; + + ::GetAcceptExSockaddrs( + m_addressBuffer, + 0, + sizeof(m_addressBuffer) / 2, + sizeof(m_addressBuffer) / 2, + &localSockaddr, + &localSockaddrLength, + &remoteSockaddr, + &remoteSockaddrLength); + + m_acceptingSocket.m_localEndPoint = + detail::sockaddr_to_ip_endpoint(*localSockaddr); + + m_acceptingSocket.m_remoteEndPoint = + detail::sockaddr_to_ip_endpoint(*remoteSockaddr); + + { + // Need to set SO_UPDATE_ACCEPT_CONTEXT after the accept completes + // to ensure that ::shutdown() and ::setsockopt() calls work on the + // accepted socket. + SOCKET listenSocket = m_listeningSocket.native_handle(); + const int result = ::setsockopt( + m_acceptingSocket.native_handle(), + SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, + (const char*)&listenSocket, + sizeof(SOCKET)); + if (result == SOCKET_ERROR) + { + const int errorCode = ::WSAGetLastError(); + throw std::system_error{ + errorCode, + std::system_category(), + "Socket accept operation failed: setsockopt(SO_UPDATE_ACCEPT_CONTEXT)" + }; + } + } +} + +#endif diff --git a/lib/socket_connect_operation.cpp b/lib/socket_connect_operation.cpp new file mode 100644 index 0000000..3e6036f --- /dev/null +++ b/lib/socket_connect_operation.cpp @@ -0,0 +1,178 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include + +#include "socket_helpers.hpp" + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +bool cppcoro::net::socket_connect_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + // Lookup the address of the ConnectEx function pointer for this socket. + LPFN_CONNECTEX connectExPtr; + { + GUID connectExGuid = WSAID_CONNECTEX; + DWORD byteCount = 0; + int result = ::WSAIoctl( + m_socket.native_handle(), + SIO_GET_EXTENSION_FUNCTION_POINTER, + static_cast(&connectExGuid), + sizeof(connectExGuid), + static_cast(&connectExPtr), + sizeof(connectExPtr), + &byteCount, + nullptr, + nullptr); + if (result == SOCKET_ERROR) + { + operation.m_errorCode = ::WSAGetLastError(); + return false; + } + } + + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + SOCKADDR_STORAGE remoteSockaddrStorage; + const int sockaddrNameLength = cppcoro::net::detail::ip_endpoint_to_sockaddr( + m_remoteEndPoint, + std::ref(remoteSockaddrStorage)); + + DWORD bytesSent = 0; + const BOOL ok = connectExPtr( + m_socket.native_handle(), + reinterpret_cast(&remoteSockaddrStorage), + sockaddrNameLength, + nullptr, // send buffer + 0, // size of send buffer + &bytesSent, + operation.get_overlapped()); + if (!ok) + { + const int errorCode = ::WSAGetLastError(); + if (errorCode != ERROR_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Successfully completed synchronously and no completion event + // will be posted to an I/O thread so we can return without suspending. + operation.m_errorCode = ERROR_SUCCESS; + return false; + } + + return true; +} + +void cppcoro::net::socket_connect_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +void cppcoro::net::socket_connect_operation_impl::get_result( + cppcoro::detail::win32_overlapped_operation_base& operation) +{ + if (operation.m_errorCode != ERROR_SUCCESS) + { + if (operation.m_errorCode == ERROR_OPERATION_ABORTED) + { + throw operation_cancelled{}; + } + + throw std::system_error{ + static_cast(operation.m_errorCode), + std::system_category(), + "Connect operation failed: ConnectEx" + }; + } + + // We need to call setsockopt() to update the socket state with information + // about the connection now that it has been successfully connected. + { + const int result = ::setsockopt( + m_socket.native_handle(), + SOL_SOCKET, + SO_UPDATE_CONNECT_CONTEXT, + nullptr, + 0); + if (result == SOCKET_ERROR) + { + // This shouldn't fail, but just in case it does we fall back to + // setting the remote address as specified in the call to Connect(). + // + // Don't really want to throw an exception here since the connection + // has actually been established. + m_socket.m_remoteEndPoint = m_remoteEndPoint; + return; + } + } + + { + SOCKADDR_STORAGE localSockaddr; + int nameLength = sizeof(localSockaddr); + const int result = ::getsockname( + m_socket.native_handle(), + reinterpret_cast(&localSockaddr), + &nameLength); + if (result == 0) + { + m_socket.m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&localSockaddr)); + } + else + { + // Failed to get the updated local-end-point + // Just leave m_localEndPoint set to whatever bind() left it as. + // + // TODO: Should we be throwing an exception here instead? + } + } + + { + SOCKADDR_STORAGE remoteSockaddr; + int nameLength = sizeof(remoteSockaddr); + const int result = ::getpeername( + m_socket.native_handle(), + reinterpret_cast(&remoteSockaddr), + &nameLength); + if (result == 0) + { + m_socket.m_remoteEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&remoteSockaddr)); + } + else + { + // Failed to get the actual remote end-point so just fall back to + // remembering the actual end-point that was passed to connect(). + // + // TODO: Should we be throwing an exception here instead? + m_socket.m_remoteEndPoint = m_remoteEndPoint; + } + } +} + +#endif diff --git a/lib/socket_disconnect_operation.cpp b/lib/socket_disconnect_operation.cpp new file mode 100644 index 0000000..fc87481 --- /dev/null +++ b/lib/socket_disconnect_operation.cpp @@ -0,0 +1,107 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include "socket_helpers.hpp" + +#include + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +bool cppcoro::net::socket_disconnect_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + // Lookup the address of the DisconnectEx function pointer for this socket. + LPFN_DISCONNECTEX disconnectExPtr; + { + GUID disconnectExGuid = WSAID_DISCONNECTEX; + DWORD byteCount = 0; + const int result = ::WSAIoctl( + m_socket.native_handle(), + SIO_GET_EXTENSION_FUNCTION_POINTER, + static_cast(&disconnectExGuid), + sizeof(disconnectExGuid), + static_cast(&disconnectExPtr), + sizeof(disconnectExPtr), + &byteCount, + nullptr, + nullptr); + if (result == SOCKET_ERROR) + { + operation.m_errorCode = static_cast(::WSAGetLastError()); + return false; + } + } + + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + // Need to add TF_REUSE_SOCKET to these flags if we want to allow reusing + // a socket for subsequent connections once the disconnect operation + // completes. + const DWORD flags = 0; + + const BOOL ok = disconnectExPtr( + m_socket.native_handle(), + operation.get_overlapped(), + flags, + 0); + if (!ok) + { + const int errorCode = ::WSAGetLastError(); + if (errorCode != ERROR_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Successfully completed synchronously and no completion event + // will be posted to an I/O thread so we can return without suspending. + operation.m_errorCode = ERROR_SUCCESS; + return false; + } + + return true; +} + +void cppcoro::net::socket_disconnect_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +void cppcoro::net::socket_disconnect_operation_impl::get_result( + cppcoro::detail::win32_overlapped_operation_base& operation) +{ + if (operation.m_errorCode != ERROR_SUCCESS) + { + if (operation.m_errorCode == ERROR_OPERATION_ABORTED) + { + throw operation_cancelled{}; + } + + throw std::system_error{ + static_cast(operation.m_errorCode), + std::system_category(), + "Disconnect operation failed: DisconnectEx" + }; + } +} + +#endif diff --git a/lib/socket_helpers.cpp b/lib/socket_helpers.cpp new file mode 100644 index 0000000..f00bc09 --- /dev/null +++ b/lib/socket_helpers.cpp @@ -0,0 +1,85 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "socket_helpers.hpp" + +#include + +#if CPPCORO_OS_WINNT +#include +#include + +#include +#include +#include +#include + + +cppcoro::net::ip_endpoint +cppcoro::net::detail::sockaddr_to_ip_endpoint(const sockaddr& address) noexcept +{ + if (address.sa_family == AF_INET) + { + SOCKADDR_IN ipv4Address; + std::memcpy(&ipv4Address, &address, sizeof(ipv4Address)); + + std::uint8_t addressBytes[4]; + std::memcpy(addressBytes, &ipv4Address.sin_addr, 4); + + return ipv4_endpoint{ + ipv4_address{ addressBytes }, + ntohs(ipv4Address.sin_port) + }; + } + else + { + assert(address.sa_family == AF_INET6); + + SOCKADDR_IN6 ipv6Address; + std::memcpy(&ipv6Address, &address, sizeof(ipv6Address)); + + return ipv6_endpoint{ + ipv6_address{ ipv6Address.sin6_addr.u.Byte }, + ntohs(ipv6Address.sin6_port) + }; + } +} + +int cppcoro::net::detail::ip_endpoint_to_sockaddr( + const ip_endpoint& endPoint, + std::reference_wrapper address) noexcept +{ + if (endPoint.is_ipv4()) + { + const auto& ipv4EndPoint = endPoint.to_ipv4(); + + SOCKADDR_IN ipv4Address; + ipv4Address.sin_family = AF_INET; + std::memcpy(&ipv4Address.sin_addr, ipv4EndPoint.address().bytes(), 4); + ipv4Address.sin_port = htons(ipv4EndPoint.port()); + std::memset(&ipv4Address.sin_zero, 0, sizeof(ipv4Address.sin_zero)); + + std::memcpy(&address.get(), &ipv4Address, sizeof(ipv4Address)); + + return sizeof(SOCKADDR_IN); + } + else + { + const auto& ipv6EndPoint = endPoint.to_ipv6(); + + SOCKADDR_IN6 ipv6Address; + ipv6Address.sin6_family = AF_INET6; + std::memcpy(&ipv6Address.sin6_addr, ipv6EndPoint.address().bytes(), 16); + ipv6Address.sin6_port = htons(ipv6EndPoint.port()); + ipv6Address.sin6_flowinfo = 0; + ipv6Address.sin6_scope_struct = SCOPEID_UNSPECIFIED_INIT; + + std::memcpy(&address.get(), &ipv6Address, sizeof(ipv6Address)); + + return sizeof(SOCKADDR_IN6); + } +} + +#endif // CPPCORO_OS_WINNT diff --git a/lib/socket_helpers.hpp b/lib/socket_helpers.hpp new file mode 100644 index 0000000..2083f3a --- /dev/null +++ b/lib/socket_helpers.hpp @@ -0,0 +1,47 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED +#define CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED + +#include + +#if CPPCORO_OS_WINNT +# include +struct sockaddr; +struct sockaddr_storage; +#endif + +namespace cppcoro +{ + namespace net + { + class ip_endpoint; + + namespace detail + { +#if CPPCORO_OS_WINNT + /// Convert a sockaddr to an IP endpoint. + ip_endpoint sockaddr_to_ip_endpoint(const sockaddr& address) noexcept; + + /// Converts an ip_endpoint to a sockaddr structure. + /// + /// \param endPoint + /// The IP endpoint to convert to a sockaddr structure. + /// + /// \param address + /// The sockaddr structure to populate. + /// + /// \return + /// The length of the sockaddr structure that was populated. + int ip_endpoint_to_sockaddr( + const ip_endpoint& endPoint, + std::reference_wrapper address) noexcept; + +#endif + } + } +} + +#endif diff --git a/lib/socket_recv_from_operation.cpp b/lib/socket_recv_from_operation.cpp new file mode 100644 index 0000000..8e994ec --- /dev/null +++ b/lib/socket_recv_from_operation.cpp @@ -0,0 +1,96 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#if CPPCORO_OS_WINNT +# include "socket_helpers.hpp" + +# include +# include +# include +# include + +bool cppcoro::net::socket_recv_from_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + static_assert( + sizeof(m_sourceSockaddrStorage) >= sizeof(SOCKADDR_IN) && + sizeof(m_sourceSockaddrStorage) >= sizeof(SOCKADDR_IN6)); + static_assert( + sockaddrStorageAlignment >= alignof(SOCKADDR_IN) && + sockaddrStorageAlignment >= alignof(SOCKADDR_IN6)); + + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread, resume the coroutine and then destroy the + // socket before we get a chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + m_sourceSockaddrLength = sizeof(m_sourceSockaddrStorage); + + DWORD numberOfBytesReceived = 0; + DWORD flags = 0; + int result = ::WSARecvFrom( + m_socket.native_handle(), + reinterpret_cast(&m_buffer), + 1, // buffer count + &numberOfBytesReceived, + &flags, + reinterpret_cast(&m_sourceSockaddrStorage), + &m_sourceSockaddrLength, + operation.get_overlapped(), + nullptr); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + if (errorCode != WSA_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + operation.m_numberOfBytesTransferred = numberOfBytesReceived; + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Completed synchronously, no completion event will be posted to the IOCP. + operation.m_errorCode = ERROR_SUCCESS; + operation.m_numberOfBytesTransferred = numberOfBytesReceived; + return false; + } + + // Operation will complete asynchronously. + return true; +} + +void cppcoro::net::socket_recv_from_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +std::tuple +cppcoro::net::socket_recv_from_operation_impl::get_result( + cppcoro::detail::win32_overlapped_operation_base& operation) +{ + if (operation.m_errorCode != ERROR_SUCCESS) + { + throw std::system_error( + static_cast(operation.m_errorCode), + std::system_category(), + "Error receiving message on socket: WSARecvFrom"); + } + + return std::make_tuple( + static_cast(operation.m_numberOfBytesTransferred), + detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&m_sourceSockaddrStorage))); +} + +#endif diff --git a/lib/socket_recv_operation.cpp b/lib/socket_recv_operation.cpp new file mode 100644 index 0000000..9930e9b --- /dev/null +++ b/lib/socket_recv_operation.cpp @@ -0,0 +1,66 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +bool cppcoro::net::socket_recv_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + DWORD numberOfBytesReceived = 0; + DWORD flags = 0; + int result = ::WSARecv( + m_socket.native_handle(), + reinterpret_cast(&m_buffer), + 1, // buffer count + &numberOfBytesReceived, + &flags, + operation.get_overlapped(), + nullptr); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + if (errorCode != WSA_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + operation.m_numberOfBytesTransferred = numberOfBytesReceived; + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Completed synchronously, no completion event will be posted to the IOCP. + operation.m_errorCode = ERROR_SUCCESS; + operation.m_numberOfBytesTransferred = numberOfBytesReceived; + return false; + } + + // Operation will complete asynchronously. + return true; +} + + +void cppcoro::net::socket_recv_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +#endif diff --git a/lib/socket_send_operation.cpp b/lib/socket_send_operation.cpp new file mode 100644 index 0000000..e5217e1 --- /dev/null +++ b/lib/socket_send_operation.cpp @@ -0,0 +1,64 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#if CPPCORO_OS_WINNT +# include +# include +# include +# include + +bool cppcoro::net::socket_send_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + DWORD numberOfBytesSent = 0; + int result = ::WSASend( + m_socket.native_handle(), + reinterpret_cast(&m_buffer), + 1, // buffer count + &numberOfBytesSent, + 0, // flags + operation.get_overlapped(), + nullptr); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + if (errorCode != WSA_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + operation.m_numberOfBytesTransferred = numberOfBytesSent; + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Completed synchronously, no completion event will be posted to the IOCP. + operation.m_errorCode = ERROR_SUCCESS; + operation.m_numberOfBytesTransferred = numberOfBytesSent; + return false; + } + + // Operation will complete asynchronously. + return true; +} + +void cppcoro::net::socket_send_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +#endif diff --git a/lib/socket_send_to_operation.cpp b/lib/socket_send_to_operation.cpp new file mode 100644 index 0000000..80db248 --- /dev/null +++ b/lib/socket_send_to_operation.cpp @@ -0,0 +1,72 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#if CPPCORO_OS_WINNT +# include "socket_helpers.hpp" + +# include +# include +# include +# include + +bool cppcoro::net::socket_send_to_operation_impl::try_start( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + // Need to read this flag before starting the operation, otherwise + // it may be possible that the operation will complete immediately + // on another thread and then destroy the socket before we get a + // chance to read it. + const bool skipCompletionOnSuccess = m_socket.skip_completion_on_success(); + + SOCKADDR_STORAGE destinationAddress; + const int destinationLength = detail::ip_endpoint_to_sockaddr( + m_destination, std::ref(destinationAddress)); + + DWORD numberOfBytesSent = 0; + int result = ::WSASendTo( + m_socket.native_handle(), + reinterpret_cast(&m_buffer), + 1, // buffer count + &numberOfBytesSent, + 0, // flags + reinterpret_cast(&destinationAddress), + destinationLength, + operation.get_overlapped(), + nullptr); + if (result == SOCKET_ERROR) + { + int errorCode = ::WSAGetLastError(); + if (errorCode != WSA_IO_PENDING) + { + // Failed synchronously. + operation.m_errorCode = static_cast(errorCode); + operation.m_numberOfBytesTransferred = numberOfBytesSent; + return false; + } + } + else if (skipCompletionOnSuccess) + { + // Completed synchronously, no completion event will be posted to the IOCP. + operation.m_errorCode = ERROR_SUCCESS; + operation.m_numberOfBytesTransferred = numberOfBytesSent; + return false; + } + + // Operation will complete asynchronously. + return true; +} + +void cppcoro::net::socket_send_to_operation_impl::cancel( + cppcoro::detail::win32_overlapped_operation_base& operation) noexcept +{ + (void)::CancelIoEx( + reinterpret_cast(m_socket.native_handle()), + operation.get_overlapped()); +} + +#endif diff --git a/lib/spin_mutex.cpp b/lib/spin_mutex.cpp new file mode 100644 index 0000000..da0594f --- /dev/null +++ b/lib/spin_mutex.cpp @@ -0,0 +1,37 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "spin_mutex.hpp" +#include "spin_wait.hpp" + +namespace cppcoro +{ + spin_mutex::spin_mutex() noexcept + : m_isLocked(false) + { + } + + bool spin_mutex::try_lock() noexcept + { + return !m_isLocked.exchange(true, std::memory_order_acquire); + } + + void spin_mutex::lock() noexcept + { + spin_wait wait; + while (!try_lock()) + { + while (m_isLocked.load(std::memory_order_relaxed)) + { + wait.spin_one(); + } + } + } + + void spin_mutex::unlock() noexcept + { + m_isLocked.store(false, std::memory_order_release); + } +} diff --git a/lib/spin_mutex.hpp b/lib/spin_mutex.hpp new file mode 100644 index 0000000..c2a285e --- /dev/null +++ b/lib/spin_mutex.hpp @@ -0,0 +1,47 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SPIN_MUTEX_HPP_INCLUDED +#define CPPCORO_SPIN_MUTEX_HPP_INCLUDED + +#include + +namespace cppcoro +{ + class spin_mutex + { + public: + + /// Initialise the mutex to the unlocked state. + spin_mutex() noexcept; + + /// Attempt to lock the mutex without blocking + /// + /// \return + /// true if the lock was acquired, false if the lock was already held + /// and could not be immediately acquired. + bool try_lock() noexcept; + + /// Block the current thread until the lock is acquired. + /// + /// This will busy-wait until it acquires the lock. + /// + /// This has 'acquire' memory semantics and synchronises + /// with prior calls to unlock(). + void lock() noexcept; + + /// Release the lock. + /// + /// This has 'release' memory semantics and synchronises with + /// lock() and try_lock(). + void unlock() noexcept; + + private: + + std::atomic m_isLocked; + + }; +} + +#endif diff --git a/lib/spin_wait.cpp b/lib/spin_wait.cpp new file mode 100644 index 0000000..70226e2 --- /dev/null +++ b/lib/spin_wait.cpp @@ -0,0 +1,101 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "spin_wait.hpp" + +#include +#include + +#if CPPCORO_OS_WINNT +# define WIN32_LEAN_AND_MEAN +# include +#endif + +namespace +{ + namespace local + { + constexpr std::uint32_t yield_threshold = 10; + } +} + +namespace cppcoro +{ + spin_wait::spin_wait() noexcept + { + reset(); + } + + bool spin_wait::next_spin_will_yield() const noexcept + { + return m_count >= local::yield_threshold; + } + + void spin_wait::reset() noexcept + { + static const std::uint32_t initialCount = + std::thread::hardware_concurrency() > 1 ? 0 : local::yield_threshold; + m_count = initialCount; + } + + void spin_wait::spin_one() noexcept + { +#if CPPCORO_OS_WINNT + // Spin strategy taken from .NET System.SpinWait class. + // I assume the Microsoft developers knew what they're doing. + if (!next_spin_will_yield()) + { + // CPU-level pause + // Allow other hyper-threads to run while we busy-wait. + + // Make each busy-spin exponentially longer + const std::uint32_t loopCount = 2u << m_count; + for (std::uint32_t i = 0; i < loopCount; ++i) + { + ::YieldProcessor(); + ::YieldProcessor(); + } + } + else + { + // We've already spun a number of iterations. + // + const auto yieldCount = m_count - local::yield_threshold; + if (yieldCount % 20 == 19) + { + // Yield remainder of time slice to another thread and + // don't schedule this thread for a little while. + ::SleepEx(1, FALSE); + } + else if (yieldCount % 5 == 4) + { + // Yield remainder of time slice to another thread + // that is ready to run (possibly from another processor?). + ::SleepEx(0, FALSE); + } + else + { + // Yield to another thread that is ready to run on the + // current processor. + ::SwitchToThread(); + } + } +#else + if (next_spin_will_yield()) + { + std::this_thread::yield(); + } +#endif + + ++m_count; + if (m_count == 0) + { + // Don't wrap around to zero as this would go back to + // busy-waiting. + m_count = local::yield_threshold; + } + } +} + diff --git a/lib/spin_wait.hpp b/lib/spin_wait.hpp new file mode 100644 index 0000000..c202d2c --- /dev/null +++ b/lib/spin_wait.hpp @@ -0,0 +1,31 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_SPIN_WAIT_HPP_INCLUDED +#define CPPCORO_SPIN_WAIT_HPP_INCLUDED + +#include + +namespace cppcoro +{ + class spin_wait + { + public: + + spin_wait() noexcept; + + bool next_spin_will_yield() const noexcept; + + void spin_one() noexcept; + + void reset() noexcept; + + private: + + std::uint32_t m_count; + + }; +} + +#endif diff --git a/lib/static_thread_pool.cpp b/lib/static_thread_pool.cpp new file mode 100644 index 0000000..4b919a5 --- /dev/null +++ b/lib/static_thread_pool.cpp @@ -0,0 +1,754 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "auto_reset_event.hpp" +#include "spin_mutex.hpp" +#include "spin_wait.hpp" + +#include +#include +#include +#include + +namespace +{ + namespace local + { + // Keep each thread's local queue under 1MB + constexpr std::size_t max_local_queue_size = 1024 * 1024 / sizeof(void*); + constexpr std::size_t initial_local_queue_size = 256; + } +} + +namespace cppcoro +{ + thread_local static_thread_pool::thread_state* static_thread_pool::s_currentState = nullptr; + thread_local static_thread_pool* static_thread_pool::s_currentThreadPool = nullptr; + + class static_thread_pool::thread_state + { + public: + + explicit thread_state() + : m_localQueue( + std::make_unique[]>( + local::initial_local_queue_size)) + , m_mask(local::initial_local_queue_size - 1) + , m_head(0) + , m_tail(0) + , m_isSleeping(false) + { + } + + bool try_wake_up() + { + if (m_isSleeping.load(std::memory_order_seq_cst)) + { + if (m_isSleeping.exchange(false, std::memory_order_seq_cst)) + { + try + { + m_wakeUpEvent.set(); + } + catch (...) + { + // TODO: What do we do here? + } + return true; + } + } + + return false; + } + + void notify_intent_to_sleep() noexcept + { + m_isSleeping.store(true, std::memory_order_relaxed); + } + + void sleep_until_woken() noexcept + { + try + { + m_wakeUpEvent.wait(); + } + catch (...) + { + using namespace std::chrono_literals; + std::this_thread::sleep_for(1ms); + } + } + + bool approx_has_any_queued_work() const noexcept + { + return difference( + m_head.load(std::memory_order_relaxed), + m_tail.load(std::memory_order_relaxed)) > 0; + } + + bool has_any_queued_work() noexcept + { + std::scoped_lock lock{ m_remoteMutex }; + auto tail = m_tail.load(std::memory_order_relaxed); + auto head = m_head.load(std::memory_order_seq_cst); + return difference(head, tail) > 0; + } + + bool try_local_enqueue(schedule_operation*& operation) noexcept + { + // Head is only ever written-to by the current thread so we + // are safe to use relaxed memory order when reading it. + auto head = m_head.load(std::memory_order_relaxed); + + // It is possible this method may be running concurrently with + // try_remote_steal() which may have just speculatively incremented m_tail + // trying to steal the last item in the queue but has not yet read the + // queue item. So we need to make sure we don't write to the last available + // space (at slot m_tail - 1) as this may still contain a pointer to an + // operation that has not yet been executed. + // + // Note that it's ok to read stale values from m_tail since new values + // won't ever decrease the number of available slots by more than 1. + // Reading a stale value can just mean that sometimes the queue appears + // empty when it may actually have slots free. + // + // Here m_mask is equal to buffersize - 1 so we can only write to a slot + // if the number of items consumed in the queue (head - tail) is less than + // the mask. + auto tail = m_tail.load(std::memory_order_relaxed); + if (difference(head, tail) < static_cast(m_mask)) + { + // There is space left in the local buffer. + m_localQueue[head & m_mask].store(operation, std::memory_order_relaxed); + m_head.store(head + 1, std::memory_order_seq_cst); + return true; + } + + if (m_mask == local::max_local_queue_size) + { + // No space in the buffer and we don't want to grow + // it any further. + return false; + } + + // Allocate the new buffer before taking out the lock so that + // we ensure we hold the lock for as short a time as possible. + const size_t newSize = (m_mask + 1) * 2; + + std::unique_ptr[]> newLocalQueue{ + new (std::nothrow) std::atomic[newSize] + }; + if (!newLocalQueue) + { + // Unable to allocate more memory. + return false; + } + + if (!m_remoteMutex.try_lock()) + { + // Don't wait to acquire the lock if we can't get it immediately. + // Fail and let it be enqueued to the global queue. + // TODO: Should we have a per-thread overflow queue instead? + return false; + } + + std::scoped_lock lock{ std::adopt_lock, m_remoteMutex }; + + // We can now re-read tail, guaranteed that we are not seeing a stale version. + tail = m_tail.load(std::memory_order_relaxed); + + // Copy the existing operations. + const size_t newMask = newSize - 1; + for (size_t i = tail; i != head; ++i) + { + newLocalQueue[i & newMask].store( + m_localQueue[i & m_mask].load(std::memory_order_relaxed), + std::memory_order_relaxed); + } + + // Finally, write the new operation to the queue. + newLocalQueue[head & newMask].store(operation, std::memory_order_relaxed); + + m_head.store(head + 1, std::memory_order_relaxed); + m_localQueue = std::move(newLocalQueue); + m_mask = newMask; + return true; + } + + schedule_operation* try_local_pop() noexcept + { + // Cheap, approximate, no memory-barrier check for emptiness + auto head = m_head.load(std::memory_order_relaxed); + auto tail = m_tail.load(std::memory_order_relaxed); + if (difference(head, tail) <= 0) + { + // Empty + return nullptr; + } + + // 3 classes of interleaving of try_local_pop() and try_remote_steal() + // - local pop completes before remote steal (easy) + // - remote steal completes before local pop (easy) + // - both are executed concurrently, both see each other's writes (harder) + + // Speculatively try to acquire the head item of the work queue by + // decrementing the head cursor. This may race with a concurrent call + // to try_remote_steal() that is also trying to speculatively increment + // the tail cursor to steal from the other end of the queue. In the case + // that they both try to dequeue the last/only item in the queue then we + // need to fall back to locking to decide who wins + + auto newHead = head - 1; + m_head.store(newHead, std::memory_order_seq_cst); + + tail = m_tail.load(std::memory_order_seq_cst); + + if (difference(newHead, tail) < 0) + { + // There was a race to get the last item. + // We don't know whether the remote steal saw our write + // and decided to back off or not, so we acquire the mutex + // so that we wait until the remote steal has completed so + // we can see what decision it made. + std::lock_guard lock{ m_remoteMutex }; + + // Use relaxed since the lock guarantees visibility of the writes + // that the remote steal thread performed. + tail = m_tail.load(std::memory_order_relaxed); + + if (difference(newHead, tail) < 0) + { + // The other thread didn't see our write and stole the last item. + // We need to restore the head back to it's old value. + // We hold the mutex so can just use relaxed memory order for this. + m_head.store(head, std::memory_order_relaxed); + return nullptr; + } + } + + // We successfully acquired an item from the queue. + return m_localQueue[newHead & m_mask].load(std::memory_order_relaxed); + } + + schedule_operation* try_steal(bool* lockUnavailable = nullptr) noexcept + { + if (lockUnavailable == nullptr) + { + m_remoteMutex.lock(); + } + else if (!m_remoteMutex.try_lock()) + { + *lockUnavailable = true; + return nullptr; + } + + std::scoped_lock lock{ std::adopt_lock, m_remoteMutex }; + + auto tail = m_tail.load(std::memory_order_relaxed); + auto head = m_head.load(std::memory_order_seq_cst); + if (difference(head, tail) <= 0) + { + return nullptr; + } + + // It looks like there are items in the queue. + // We'll speculatively try to steal one by incrementing + // the tail cursor. As this may be running concurrently + // with try_local_pop() which is also speculatively trying + // to remove an item from the other end of the queue we + // need to re-read the 'head' cursor afterwards to see + // if there was a potential race to dequeue the last item. + // Use seq_cst memory order both here and in try_local_pop() + // to ensure that either we will see their write to head or + // they will see our write to tail or we will both see each + // other's writes. + m_tail.store(tail + 1, std::memory_order_seq_cst); + head = m_head.load(std::memory_order_seq_cst); + + if (difference(head, tail) > 0) + { + // There was still an item in the queue after incrementing tail. + // We managed to steal an item from the bottom of the stack. + return m_localQueue[tail & m_mask].load(std::memory_order_relaxed); + } + else + { + // Otherwise we failed to steal the last item. + // Restore the old tail position. + m_tail.store(tail, std::memory_order_seq_cst); + return nullptr; + } + } + + private: + + using offset_t = std::make_signed_t; + + static constexpr offset_t difference(size_t a, size_t b) + { + return static_cast(a - b); + } + + std::unique_ptr[]> m_localQueue; + std::size_t m_mask; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(push) +# pragma warning(disable : 4324) +#endif + + //alignas(std::hardware_destructive_interference_size) + std::atomic m_head; + + //alignas(std::hardware_destructive_interference_size) + std::atomic m_tail; + + //alignas(std::hardware_destructive_interference_size) + std::atomic m_isSleeping; + spin_mutex m_remoteMutex; + +#if CPPCORO_COMPILER_MSVC +# pragma warning(pop) +#endif + + auto_reset_event m_wakeUpEvent; + + }; + + void static_thread_pool::schedule_operation::await_suspend( + cppcoro::coroutine_handle<> awaitingCoroutine) noexcept + { + m_awaitingCoroutine = awaitingCoroutine; + m_threadPool->schedule_impl(this); + } + + static_thread_pool::static_thread_pool() + : static_thread_pool(std::thread::hardware_concurrency()) + { + } + + static_thread_pool::static_thread_pool(std::uint32_t threadCount) + : m_threadCount(threadCount > 0 ? threadCount : 1) + , m_threadStates(std::make_unique(m_threadCount)) + , m_stopRequested(false) + , m_globalQueueHead(nullptr) + , m_globalQueueTail(nullptr) + , m_sleepingThreadCount(0) + { + m_threads.reserve(threadCount); + try + { + for (std::uint32_t i = 0; i < m_threadCount; ++i) + { + m_threads.emplace_back([this, i] { this->run_worker_thread(i); }); + } + } + catch (...) + { + try + { + shutdown(); + } + catch (...) + { + std::terminate(); + } + + throw; + } + } + + static_thread_pool::~static_thread_pool() + { + shutdown(); + } + + void static_thread_pool::run_worker_thread(std::uint32_t threadIndex) noexcept + { + auto& localState = m_threadStates[threadIndex]; + s_currentState = &localState; + s_currentThreadPool = this; + + auto tryGetRemote = [&]() + { + // Try to get some new work first from the global queue + // then if that queue is empty then try to steal from + // the local queues of other worker threads. + // We try to get new work from the global queue first + // before stealing as stealing from other threads has + // the side-effect of those threads running out of work + // sooner and then having to steal work which increases + // contention. + auto* op = try_global_dequeue(); + if (op == nullptr) + { + op = try_steal_from_other_thread(threadIndex); + } + return op; + }; + + while (true) + { + // Process operations from the local queue. + schedule_operation* op; + + while (true) + { + op = localState.try_local_pop(); + if (op == nullptr) + { + op = tryGetRemote(); + if (op == nullptr) + { + break; + } + } + + op->m_awaitingCoroutine.resume(); + } + + // No more operations in the local queue or remote queue. + // + // We spin for a little while waiting for new items + // to be enqueued. This avoids the expensive operation + // of putting the thread to sleep and waking it up again + // in the case that an external thread is queueing new work + + cppcoro::spin_wait spinWait; + while (true) + { + for (int i = 0; i < 30; ++i) + { + if (is_shutdown_requested()) + { + return; + } + + spinWait.spin_one(); + + if (approx_has_any_queued_work_for(threadIndex)) + { + op = tryGetRemote(); + if (op != nullptr) + { + // Now that we've executed some work we can + // return to normal processing since this work + // might have queued some more work to the local + // queue which we should process first. + goto normal_processing; + } + } + } + + // We didn't find any work after spinning for a while, let's + // put ourselves to sleep and wait to be woken up. + + // First, let other threads know we're going to sleep. + notify_intent_to_sleep(threadIndex); + + // As notifying the other threads that we're sleeping may have + // raced with other threads enqueueing more work, we need to + // re-check whether there is any more work to be done so that + // we don't get into a situation where we go to sleep and another + // thread has enqueued some work and doesn't know to wake us up. + + if (has_any_queued_work_for(threadIndex)) + { + op = tryGetRemote(); + if (op != nullptr) + { + // Try to clear the intent to sleep so that some other thread + // that subsequently enqueues some work won't mistakenly try + // to wake this threadup when we are already running as there + // might have been some other thread that it could have woken + // up instead which could have resulted in increased parallelism. + // + // However, it's possible that some other thread may have already + // tried to wake us up, in which case the auto_reset_event used to + // wake up this thread may already be in the 'set' state. Leaving + // it in this state won't really hurt. It'll just mean we might get + // a spurious wake-up next time we try to go to sleep. + try_clear_intent_to_sleep(threadIndex); + + goto normal_processing; + } + } + + if (is_shutdown_requested()) + { + return; + } + + localState.sleep_until_woken(); + } + + normal_processing: + assert(op != nullptr); + op->m_awaitingCoroutine.resume(); + } + } + + void static_thread_pool::shutdown() + { + m_stopRequested.store(true, std::memory_order_relaxed); + + for (std::uint32_t i = 0; i < m_threads.size(); ++i) + { + auto& threadState = m_threadStates[i]; + + // We should not be shutting down the thread pool if there is any + // outstanding work in the queue. It is up to the application to + // ensure all enqueued work has completed first. + assert(!threadState.has_any_queued_work()); + + threadState.try_wake_up(); + } + + for (auto& t : m_threads) + { + t.join(); + } + } + + void static_thread_pool::schedule_impl(schedule_operation* operation) noexcept + { + if (s_currentThreadPool != this || + !s_currentState->try_local_enqueue(operation)) + { + remote_enqueue(operation); + } + + wake_one_thread(); + } + + void static_thread_pool::remote_enqueue(schedule_operation* operation) noexcept + { + auto* tail = m_globalQueueTail.load(std::memory_order_relaxed); + do + { + operation->m_next = tail; + } while (!m_globalQueueTail.compare_exchange_weak( + tail, + operation, + std::memory_order_seq_cst, + std::memory_order_relaxed)); + } + + bool static_thread_pool::has_any_queued_work_for(std::uint32_t threadIndex) noexcept + { + if (m_globalQueueTail.load(std::memory_order_seq_cst) != nullptr) + { + return true; + } + + if (m_globalQueueHead.load(std::memory_order_seq_cst) != nullptr) + { + return true; + } + + for (std::uint32_t i = 0; i < m_threadCount; ++i) + { + if (i == threadIndex) continue; + if (m_threadStates[i].has_any_queued_work()) + { + return true; + } + } + + return false; + } + + bool static_thread_pool::approx_has_any_queued_work_for(std::uint32_t threadIndex) const noexcept + { + // Cheap, approximate, read-only implementation that checks whether any work has + // been queued in the system somewhere. We try to avoid writes here so that we + // don't bounce cache-lines around between threads/cores unnecessarily when + // multiple threads are all spinning waiting for work. + + if (m_globalQueueTail.load(std::memory_order_relaxed) != nullptr) + { + return true; + } + + if (m_globalQueueHead.load(std::memory_order_relaxed) != nullptr) + { + return true; + } + + for (std::uint32_t i = 0; i < m_threadCount; ++i) + { + if (i == threadIndex) continue; + if (m_threadStates[i].approx_has_any_queued_work()) + { + return true; + } + } + + return false; + } + + bool static_thread_pool::is_shutdown_requested() const noexcept + { + return m_stopRequested.load(std::memory_order_relaxed); + } + + void static_thread_pool::notify_intent_to_sleep(std::uint32_t threadIndex) noexcept + { + // First mark the thread as asleep + m_threadStates[threadIndex].notify_intent_to_sleep(); + + // Then publish the fact that a thread is asleep by incrementing the count + // of threads that are asleep. + m_sleepingThreadCount.fetch_add(1, std::memory_order_seq_cst); + } + + void static_thread_pool::try_clear_intent_to_sleep(std::uint32_t threadIndex) noexcept + { + // First try to claim that we are waking up one of the threads. + std::uint32_t oldSleepingCount = m_sleepingThreadCount.load(std::memory_order_relaxed); + do + { + if (oldSleepingCount == 0) + { + // No more sleeping threads. + // Someone must have woken us up. + return; + } + } while (!m_sleepingThreadCount.compare_exchange_weak( + oldSleepingCount, + oldSleepingCount - 1, + std::memory_order_acquire, + std::memory_order_relaxed)); + + // Then preferentially try to wake up our thread. + // If some other thread has already requested that this thread wake up + // then we will wake up another thread - the one that should have been woken + // up by the thread that woke this thread up. + if (!m_threadStates[threadIndex].try_wake_up()) + { + for (std::uint32_t i = 0; i < m_threadCount; ++i) + { + if (i == threadIndex) continue; + if (m_threadStates[i].try_wake_up()) + { + return; + } + } + } + } + + static_thread_pool::schedule_operation* + static_thread_pool::try_global_dequeue() noexcept + { + std::scoped_lock lock{ m_globalQueueMutex }; + + auto* head = m_globalQueueHead.load(std::memory_order_relaxed); + if (head == nullptr) + { + // Use seq-cst memory order so that when we check for an item in the + // global queue after signalling an intent to sleep that either we + // will see their enqueue or they will see our signal to sleep and + // wake us up. + if (m_globalQueueTail.load(std::memory_order_seq_cst) == nullptr) + { + return nullptr; + } + + // Acquire the entire set of queued operations in a single operation. + auto* tail = m_globalQueueTail.exchange(nullptr, std::memory_order_acquire); + if (tail == nullptr) + { + return nullptr; + } + + // Reverse the list + do + { + auto* next = std::exchange(tail->m_next, head); + head = std::exchange(tail, next); + } while (tail != nullptr); + } + + m_globalQueueHead = head->m_next; + + return head; + } + + static_thread_pool::schedule_operation* + static_thread_pool::try_steal_from_other_thread(std::uint32_t thisThreadIndex) noexcept + { + // Try first with non-blocking steal attempts. + + bool anyLocksUnavailable = false; + for (std::uint32_t otherThreadIndex = 0; otherThreadIndex < m_threadCount; ++otherThreadIndex) + { + if (otherThreadIndex == thisThreadIndex) continue; + auto& otherThreadState = m_threadStates[otherThreadIndex]; + auto* op = otherThreadState.try_steal(&anyLocksUnavailable); + if (op != nullptr) + { + return op; + } + } + + if (anyLocksUnavailable) + { + // We didn't check all of the other threads for work to steal yet. + // Try again, this time waiting to acquire the locks. + for (std::uint32_t otherThreadIndex = 0; otherThreadIndex < m_threadCount; ++otherThreadIndex) + { + if (otherThreadIndex == thisThreadIndex) continue; + auto& otherThreadState = m_threadStates[otherThreadIndex]; + auto* op = otherThreadState.try_steal(); + if (op != nullptr) + { + return op; + } + } + } + + return nullptr; + } + + void static_thread_pool::wake_one_thread() noexcept + { + // First try to claim responsibility for waking up one thread. + // This first read must be seq_cst to ensure that either we have + // visibility of another thread going to sleep or they have + // visibility of our prior enqueue of an item. + std::uint32_t oldSleepingCount = m_sleepingThreadCount.load(std::memory_order_seq_cst); + do + { + if (oldSleepingCount == 0) + { + // No sleeping threads. + // Someone must have woken us up. + return; + } + } while (!m_sleepingThreadCount.compare_exchange_weak( + oldSleepingCount, + oldSleepingCount - 1, + std::memory_order_acquire, + std::memory_order_relaxed)); + + // Now that we have claimed responsibility for waking a thread up + // we need to find a sleeping thread and wake it up. We should be + // guaranteed of finding a thread to wake-up here, but not necessarily + // in a single pass due to threads potentially waking themselves up + // in try_clear_intent_to_sleep(). + while (true) + { + for (std::uint32_t i = 0; i < m_threadCount; ++i) + { + if (m_threadStates[i].try_wake_up()) + { + return; + } + } + } + } +} diff --git a/lib/use.cake b/lib/use.cake new file mode 100644 index 0000000..750331a --- /dev/null +++ b/lib/use.cake @@ -0,0 +1,20 @@ +############################################################################### +# Copyright (c) Lewis Baker +# Licenced under MIT license. See LICENSE.txt for details. +############################################################################### + +import cake.path + +from cake.tools import script, env, compiler, variant + +compiler.addIncludePath(env.expand('${CPPCORO}/include')) + +buildScript = script.get(script.cwd('build.cake')) +compiler.addLibrary(buildScript.getResult('library')) + +if variant.platform == "windows": + compiler.addLibrary("Synchronization") + compiler.addLibrary("kernel32") + compiler.addLibrary("WS2_32") + compiler.addLibrary("Mswsock") + diff --git a/lib/win32.cpp b/lib/win32.cpp new file mode 100644 index 0000000..b7b497b --- /dev/null +++ b/lib/win32.cpp @@ -0,0 +1,20 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif +#include + +void cppcoro::detail::win32::safe_handle::close() noexcept +{ + if (m_handle != nullptr && m_handle != INVALID_HANDLE_VALUE) + { + ::CloseHandle(m_handle); + m_handle = nullptr; + } +} diff --git a/lib/writable_file.cpp b/lib/writable_file.cpp new file mode 100644 index 0000000..3ba63e7 --- /dev/null +++ b/lib/writable_file.cpp @@ -0,0 +1,75 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +void cppcoro::writable_file::set_size( + std::uint64_t fileSize) +{ + LARGE_INTEGER position; + position.QuadPart = fileSize; + + BOOL ok = ::SetFilePointerEx(m_fileHandle.handle(), position, nullptr, FILE_BEGIN); + if (!ok) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error setting file size: SetFilePointerEx" + }; + } + + ok = ::SetEndOfFile(m_fileHandle.handle()); + if (!ok) + { + DWORD errorCode = ::GetLastError(); + throw std::system_error + { + static_cast(errorCode), + std::system_category(), + "error setting file size: SetEndOfFile" + }; + } +} + +cppcoro::file_write_operation cppcoro::writable_file::write( + std::uint64_t offset, + const void* buffer, + std::size_t byteCount) noexcept +{ + return file_write_operation{ + m_fileHandle.handle(), + offset, + buffer, + byteCount + }; +} + +cppcoro::file_write_operation_cancellable cppcoro::writable_file::write( + std::uint64_t offset, + const void* buffer, + std::size_t byteCount, + cancellation_token ct) noexcept +{ + return file_write_operation_cancellable{ + m_fileHandle.handle(), + offset, + buffer, + byteCount, + std::move(ct) + }; +} + +#endif diff --git a/lib/write_only_file.cpp b/lib/write_only_file.cpp new file mode 100644 index 0000000..0ed46fc --- /dev/null +++ b/lib/write_only_file.cpp @@ -0,0 +1,37 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#if CPPCORO_OS_WINNT +# ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +# endif +# include + +cppcoro::write_only_file cppcoro::write_only_file::open( + io_service& ioService, + const cppcoro::filesystem::path& path, + file_open_mode openMode, + file_share_mode shareMode, + file_buffering_mode bufferingMode) +{ + return write_only_file(file::open( + GENERIC_WRITE, + ioService, + path, + openMode, + shareMode, + bufferingMode)); +} + +cppcoro::write_only_file::write_only_file( + detail::win32::safe_handle&& fileHandle) noexcept + : file(std::move(fileHandle)) + , writable_file(detail::win32::safe_handle{}) +{ +} + +#endif diff --git a/release/lib/vs2022/Debug/cppcoro.lib b/release/lib/vs2022/Debug/cppcoro.lib new file mode 100644 index 0000000..290833f Binary files /dev/null and b/release/lib/vs2022/Debug/cppcoro.lib differ diff --git a/release/lib/vs2022/Debug/cppcoro.pdb b/release/lib/vs2022/Debug/cppcoro.pdb new file mode 100644 index 0000000..233cd6f Binary files /dev/null and b/release/lib/vs2022/Debug/cppcoro.pdb differ diff --git a/release/lib/vs2022/Release/cppcoro.lib b/release/lib/vs2022/Release/cppcoro.lib new file mode 100644 index 0000000..1e54fa0 Binary files /dev/null and b/release/lib/vs2022/Release/cppcoro.lib differ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..f5afecd --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,62 @@ +add_library(doctest::doctest INTERFACE IMPORTED) +target_include_directories(doctest::doctest INTERFACE doctest) + +include(${CMAKE_CURRENT_LIST_DIR}/doctest/doctest.cmake) + +find_package(Threads REQUIRED) + +add_library(tests-main STATIC + main.cpp + counted.cpp +) +target_link_libraries(tests-main PUBLIC cppcoro doctest::doctest Threads::Threads) + +set(tests + generator_tests.cpp + recursive_generator_tests.cpp + async_generator_tests.cpp + async_auto_reset_event_tests.cpp + async_manual_reset_event_tests.cpp + async_mutex_tests.cpp + async_latch_tests.cpp + cancellation_token_tests.cpp + task_tests.cpp + sequence_barrier_tests.cpp + shared_task_tests.cpp + sync_wait_tests.cpp + single_consumer_async_auto_reset_event_tests.cpp + single_producer_sequencer_tests.cpp + multi_producer_sequencer_tests.cpp + when_all_tests.cpp + when_all_ready_tests.cpp + ip_address_tests.cpp + ip_endpoint_tests.cpp + ipv4_address_tests.cpp + ipv4_endpoint_tests.cpp + ipv6_address_tests.cpp + ipv6_endpoint_tests.cpp + static_thread_pool_tests.cpp +) + +if(WIN32) + list(APPEND tests + scheduling_operator_tests.cpp + io_service_tests.cpp + file_tests.cpp + socket_tests.cpp + ) +else() + # let more time for some tests + set(async_auto_reset_event_tests_TIMEOUT 60) +endif() + +foreach(test ${tests}) + get_filename_component(test_name ${test} NAME_WE) + add_executable(${test_name} ${test}) + target_link_libraries(${test_name} PRIVATE tests-main) + string(REPLACE "_" " " test_prefix ${test_name}) + if (NOT DEFINED ${test_name}_TIMEOUT) + set(${test_name}_TIMEOUT 30) + endif() + doctest_discover_tests(${test_name} TEST_PREFIX ${test_prefix}- PROPERTIES TIMEOUT ${${test_name}_TIMEOUT}) +endforeach() diff --git a/test/async_auto_reset_event_tests.cpp b/test/async_auto_reset_event_tests.cpp new file mode 100644 index 0000000..32cdb42 --- /dev/null +++ b/test/async_auto_reset_event_tests.cpp @@ -0,0 +1,140 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("async_auto_reset_event"); + +TEST_CASE("single waiter") +{ + cppcoro::async_auto_reset_event event; + + bool started = false; + bool finished = false; + auto run = [&]() -> cppcoro::task<> + { + started = true; + co_await event; + finished = true; + }; + + auto check = [&]() -> cppcoro::task<> + { + CHECK(started); + CHECK(!finished); + + event.set(); + + CHECK(finished); + + co_return; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready(run(), check())); +} + +TEST_CASE("multiple waiters") +{ + cppcoro::async_auto_reset_event event; + + + auto run = [&](bool& flag) -> cppcoro::task<> + { + co_await event; + flag = true; + }; + + bool completed1 = false; + bool completed2 = false; + + auto check = [&]() -> cppcoro::task<> + { + CHECK(!completed1); + CHECK(!completed2); + + event.set(); + + CHECK(completed1); + CHECK(!completed2); + + event.set(); + + CHECK(completed2); + + co_return; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + run(completed1), + run(completed2), + check())); +} + +TEST_CASE("multi-threaded") +{ + cppcoro::static_thread_pool tp{ 3 }; + + auto run = [&]() -> cppcoro::task<> + { + cppcoro::async_auto_reset_event event; + + int value = 0; + + auto startWaiter = [&]() -> cppcoro::task<> + { + co_await tp.schedule(); + co_await event; + ++value; + event.set(); + }; + + auto startSignaller = [&]() -> cppcoro::task<> + { + co_await tp.schedule(); + value = 5; + event.set(); + }; + + std::vector> tasks; + + tasks.emplace_back(startSignaller()); + + for (int i = 0; i < 1000; ++i) + { + tasks.emplace_back(startWaiter()); + } + + co_await cppcoro::when_all(std::move(tasks)); + + // NOTE: Can't use CHECK() here because it's not thread-safe + assert(value == 1005); + }; + + std::vector> tasks; + + for (int i = 0; i < 1000; ++i) + { + tasks.emplace_back(run()); + } + + cppcoro::sync_wait(cppcoro::when_all(std::move(tasks))); +} + +TEST_SUITE_END(); diff --git a/test/async_generator_tests.cpp b/test/async_generator_tests.cpp new file mode 100644 index 0000000..7a575f1 --- /dev/null +++ b/test/async_generator_tests.cpp @@ -0,0 +1,330 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("async_generator"); + +TEST_CASE("default-constructed async_generator is an empty sequence") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + // Iterating over default-constructed async_generator just + // gives an empty sequence. + cppcoro::async_generator g; + auto it = co_await g.begin(); + CHECK(it == g.end()); + }()); +} + +TEST_CASE("async_generator doesn't start if begin() not called") +{ + bool startedExecution = false; + { + auto gen = [&]() -> cppcoro::async_generator + { + startedExecution = true; + co_yield 1; + }(); + CHECK(!startedExecution); + } + CHECK(!startedExecution); +} + +TEST_CASE("enumerate sequence of 1 value") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + bool startedExecution = false; + auto makeGenerator = [&]() -> cppcoro::async_generator + { + startedExecution = true; + co_yield 1; + }; + + auto gen = makeGenerator(); + + CHECK(!startedExecution); + + auto it = co_await gen.begin(); + + CHECK(startedExecution); + CHECK(it != gen.end()); + CHECK(*it == 1u); + CHECK(co_await ++it == gen.end()); + }()); +} + +TEST_CASE("enumerate sequence of multiple values") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + bool startedExecution = false; + auto makeGenerator = [&]() -> cppcoro::async_generator + { + startedExecution = true; + co_yield 1; + co_yield 2; + co_yield 3; + }; + + auto gen = makeGenerator(); + + CHECK(!startedExecution); + + auto it = co_await gen.begin(); + + CHECK(startedExecution); + + CHECK(it != gen.end()); + CHECK(*it == 1u); + + CHECK(co_await ++it != gen.end()); + CHECK(*it == 2u); + + CHECK(co_await ++it != gen.end()); + CHECK(*it == 3u); + + CHECK(co_await ++it == gen.end()); + }()); +} + +namespace +{ + class set_to_true_on_destruction + { + public: + + set_to_true_on_destruction(bool* value) + : m_value(value) + {} + + set_to_true_on_destruction(set_to_true_on_destruction&& other) + : m_value(other.m_value) + { + other.m_value = nullptr; + } + + ~set_to_true_on_destruction() + { + if (m_value != nullptr) + { + *m_value = true; + } + } + + set_to_true_on_destruction(const set_to_true_on_destruction&) = delete; + set_to_true_on_destruction& operator=(const set_to_true_on_destruction&) = delete; + + private: + + bool* m_value; + }; +} + +TEST_CASE("destructors of values in scope are called when async_generator destructed early") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + bool aDestructed = false; + bool bDestructed = false; + + auto makeGenerator = [&](set_to_true_on_destruction a) -> cppcoro::async_generator + { + set_to_true_on_destruction b(&bDestructed); + co_yield 1; + co_yield 2; + }; + + { + auto gen = makeGenerator(&aDestructed); + + CHECK(!aDestructed); + CHECK(!bDestructed); + + auto it = co_await gen.begin(); + CHECK(!aDestructed); + CHECK(!bDestructed); + CHECK(*it == 1u); + } + + CHECK(aDestructed); + CHECK(bDestructed); + }()); +} + +TEST_CASE("async producer with async consumer" + * doctest::description{ + "This test tries to cover the different state-transition code-paths\n" + "- consumer resuming producer and producer completing asynchronously\n" + "- producer resuming consumer and consumer requesting next value synchronously\n" + "- producer resuming consumer and consumer requesting next value asynchronously" }) +{ +#if defined(_MSC_VER) && _MSC_FULL_VER <= 191025224 && defined(CPPCORO_RELEASE_OPTIMISED) + FAST_WARN_UNARY_FALSE("MSVC has a known codegen bug under optimised builds, skipping"); + return; +#endif + + cppcoro::single_consumer_event p1; + cppcoro::single_consumer_event p2; + cppcoro::single_consumer_event p3; + cppcoro::single_consumer_event c1; + + auto produce = [&]() -> cppcoro::async_generator + { + co_await p1; + co_yield 1; + co_await p2; + co_yield 2; + co_await p3; + }; + + bool consumerFinished = false; + + auto consume = [&]() -> cppcoro::task<> + { + auto generator = produce(); + auto it = co_await generator.begin(); + CHECK(*it == 1u); + (void)co_await ++it; + CHECK(*it == 2u); + co_await c1; + (void)co_await ++it; + CHECK(it == generator.end()); + consumerFinished = true; + }; + + auto unblock = [&]() -> cppcoro::task<> + { + p1.set(); + p2.set(); + c1.set(); + CHECK(!consumerFinished); + p3.set(); + CHECK(consumerFinished); + co_return; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready(consume(), unblock())); +} + +TEST_CASE("exception thrown before first yield is rethrown from begin operation") +{ + class TestException {}; + auto gen = [](bool shouldThrow) -> cppcoro::async_generator + { + if (shouldThrow) + { + throw TestException(); + } + co_yield 1; + }(true); + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + CHECK_THROWS_AS(co_await gen.begin(), const TestException&); + }()); +} + +TEST_CASE("exception thrown after first yield is rethrown from increment operator") +{ + class TestException {}; + auto gen = [](bool shouldThrow) -> cppcoro::async_generator + { + co_yield 1; + if (shouldThrow) + { + throw TestException(); + } + }(true); + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + auto it = co_await gen.begin(); + CHECK(*it == 1u); + CHECK_THROWS_AS(co_await ++it, const TestException&); + CHECK(it == gen.end()); + }()); +} + +TEST_CASE("large number of synchronous completions doesn't result in stack-overflow") +{ + + auto makeSequence = [](cppcoro::single_consumer_event& event) -> cppcoro::async_generator + { + for (std::uint32_t i = 0; i < 1'000'000u; ++i) + { + if (i == 500'000u) co_await event; + co_yield i; + } + }; + + auto consumer = [](cppcoro::async_generator sequence) -> cppcoro::task<> + { + std::uint32_t expected = 0; + for (auto iter = co_await sequence.begin(); iter != sequence.end(); co_await ++iter) + { + std::uint32_t i = *iter; + CHECK(i == expected++); + } + + CHECK(expected == 1'000'000u); + }; + + auto unblocker = [](cppcoro::single_consumer_event& event) -> cppcoro::task<> + { + // Should have processed the first 500'000 elements synchronously with consumer driving + // iteraction before producer suspends and thus consumer suspends. + // Then we resume producer in call to set() below and it continues processing remaining + // 500'000 elements, this time with producer driving the interaction. + + event.set(); + + co_return; + }; + + cppcoro::single_consumer_event event; + + cppcoro::sync_wait( + cppcoro::when_all_ready( + consumer(makeSequence(event)), + unblocker(event))); +} + +TEST_CASE("fmap") +{ + using cppcoro::async_generator; + using cppcoro::fmap; + + auto iota = [](int count) -> async_generator + { + for (int i = 0; i < count; ++i) + { + co_yield i; + } + }; + + auto squares = iota(5) | fmap([](auto x) { return x * x; }); + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + auto it = co_await squares.begin(); + CHECK(*it == 0); + CHECK(*co_await ++it == 1); + CHECK(*co_await ++it == 4); + CHECK(*co_await ++it == 9); + CHECK(*co_await ++it == 16); + CHECK(co_await ++it == squares.end()); + }()); +} + +TEST_SUITE_END(); diff --git a/test/async_latch_tests.cpp b/test/async_latch_tests.cpp new file mode 100644 index 0000000..9a89207 --- /dev/null +++ b/test/async_latch_tests.cpp @@ -0,0 +1,113 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("async_latch"); + +using namespace cppcoro; + +TEST_CASE("latch constructed with zero count is initially ready") +{ + async_latch latch(0); + CHECK(latch.is_ready()); +} + +TEST_CASE("latch constructed with negative count is initially ready") +{ + async_latch latch(-3); + CHECK(latch.is_ready()); +} + +TEST_CASE("count_down and is_ready") +{ + async_latch latch(3); + CHECK(!latch.is_ready()); + latch.count_down(); + CHECK(!latch.is_ready()); + latch.count_down(); + CHECK(!latch.is_ready()); + latch.count_down(); + CHECK(latch.is_ready()); +} + +TEST_CASE("count_down by n") +{ + async_latch latch(5); + latch.count_down(3); + CHECK(!latch.is_ready()); + latch.count_down(2); + CHECK(latch.is_ready()); +} + +TEST_CASE("single awaiter") +{ + async_latch latch(2); + bool after = false; + sync_wait(when_all_ready( + [&]() -> task<> + { + co_await latch; + after = true; + }(), + [&]() -> task<> + { + CHECK(!after); + latch.count_down(); + CHECK(!after); + latch.count_down(); + CHECK(after); + co_return; + }() + )); +} + +TEST_CASE("multiple awaiters") +{ + async_latch latch(2); + bool after1 = false; + bool after2 = false; + bool after3 = false; + sync_wait(when_all_ready( + [&]() -> task<> + { + co_await latch; + after1 = true; + }(), + [&]() -> task<> + { + co_await latch; + after2 = true; + }(), + [&]() -> task<> + { + co_await latch; + after3 = true; + }(), + [&]() -> task<> + { + CHECK(!after1); + CHECK(!after2); + CHECK(!after3); + latch.count_down(); + CHECK(!after1); + CHECK(!after2); + CHECK(!after3); + latch.count_down(); + CHECK(after1); + CHECK(after2); + CHECK(after3); + co_return; + }())); +} + +TEST_SUITE_END(); diff --git a/test/async_manual_reset_event_tests.cpp b/test/async_manual_reset_event_tests.cpp new file mode 100644 index 0000000..1673204 --- /dev/null +++ b/test/async_manual_reset_event_tests.cpp @@ -0,0 +1,96 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("async_manual_reset_event"); + +TEST_CASE("default constructor initially not set") +{ + cppcoro::async_manual_reset_event event; + CHECK(!event.is_set()); +} + +TEST_CASE("construct event initially set") +{ + cppcoro::async_manual_reset_event event{ true }; + CHECK(event.is_set()); +} + +TEST_CASE("set and reset") +{ + cppcoro::async_manual_reset_event event; + CHECK(!event.is_set()); + event.set(); + CHECK(event.is_set()); + event.set(); + CHECK(event.is_set()); + event.reset(); + CHECK(!event.is_set()); + event.reset(); + CHECK(!event.is_set()); + event.set(); + CHECK(event.is_set()); +} + +TEST_CASE("await not set event") +{ + cppcoro::async_manual_reset_event event; + + auto createWaiter = [&](bool& flag) -> cppcoro::task<> + { + co_await event; + flag = true; + }; + + bool completed1 = false; + bool completed2 = false; + + auto check = [&]() -> cppcoro::task<> + { + CHECK(!completed1); + CHECK(!completed2); + + event.reset(); + + CHECK(!completed1); + CHECK(!completed2); + + event.set(); + + CHECK(completed1); + CHECK(completed2); + + co_return; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + createWaiter(completed1), + createWaiter(completed2), + check())); +} + +TEST_CASE("awaiting already set event doesn't suspend") +{ + cppcoro::async_manual_reset_event event{ true }; + + auto createWaiter = [&]() -> cppcoro::task<> + { + co_await event; + }; + + // Should complete without blocking. + cppcoro::sync_wait(cppcoro::when_all_ready( + createWaiter(), + createWaiter())); +} + +TEST_SUITE_END(); diff --git a/test/async_mutex_tests.cpp b/test/async_mutex_tests.cpp new file mode 100644 index 0000000..0837e52 --- /dev/null +++ b/test/async_mutex_tests.cpp @@ -0,0 +1,90 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("async_mutex"); + +TEST_CASE("try_lock") +{ + cppcoro::async_mutex mutex; + + CHECK(mutex.try_lock()); + + CHECK_FALSE(mutex.try_lock()); + + mutex.unlock(); + + CHECK(mutex.try_lock()); +} + +#if 0 +TEST_CASE("multiple lockers") +{ + int value = 0; + cppcoro::async_mutex mutex; + cppcoro::single_consumer_event a; + cppcoro::single_consumer_event b; + cppcoro::single_consumer_event c; + cppcoro::single_consumer_event d; + + auto f = [&](cppcoro::single_consumer_event& e) -> cppcoro::task<> + { + auto lock = co_await mutex.scoped_lock_async(); + co_await e; + ++value; + }; + + auto check = [&]() -> cppcoro::task<> + { + CHECK(value == 0); + + a.set(); + + CHECK(value == 1); + + auto check2 = [&]() -> cppcoro::task<> + { + b.set(); + + CHECK(value == 2); + + c.set(); + + CHECK(value == 3); + + d.set(); + + CHECK(value == 4); + + co_return; + }; + + // Now that we've queued some waiters and released one waiter this will + // have acquired the list of pending waiters in the local cache. + // We'll now queue up another one before releasing any more waiters + // to test the code-path that looks at the newly queued waiter list + // when the cache of waiters is exhausted. + (void)co_await cppcoro::when_all_ready(f(d), check2()); + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + f(a), + f(b), + f(c), + check())); + + CHECK(value == 4); +} +#endif + +TEST_SUITE_END(); diff --git a/test/build.cake b/test/build.cake new file mode 100644 index 0000000..a9d0f57 --- /dev/null +++ b/test/build.cake @@ -0,0 +1,93 @@ +############################################################################### +# Copyright (c) Lewis Baker +# Licenced under MIT license. See LICENSE.txt for details. +############################################################################### + +import cake.path + +from cake.tools import script, env, compiler, project, variant, test + +script.include([ + env.expand('${CPPCORO}/lib/use.cake'), +]) + +headers = script.cwd([ + "counted.hpp", + "io_service_fixture.hpp", + ]) + +sources = script.cwd([ + 'main.cpp', + 'counted.cpp', + 'generator_tests.cpp', + 'recursive_generator_tests.cpp', + 'async_generator_tests.cpp', + 'async_auto_reset_event_tests.cpp', + 'async_manual_reset_event_tests.cpp', + 'async_mutex_tests.cpp', + 'async_latch_tests.cpp', + 'cancellation_token_tests.cpp', + 'task_tests.cpp', + 'sequence_barrier_tests.cpp', + 'shared_task_tests.cpp', + 'sync_wait_tests.cpp', + 'single_consumer_async_auto_reset_event_tests.cpp', + 'single_producer_sequencer_tests.cpp', + 'multi_producer_sequencer_tests.cpp', + 'when_all_tests.cpp', + 'when_all_ready_tests.cpp', + 'ip_address_tests.cpp', + 'ip_endpoint_tests.cpp', + 'ipv4_address_tests.cpp', + 'ipv4_endpoint_tests.cpp', + 'ipv6_address_tests.cpp', + 'ipv6_endpoint_tests.cpp', + 'static_thread_pool_tests.cpp', + ]) + +if variant.platform == 'windows': + sources += script.cwd([ + 'scheduling_operator_tests.cpp', + 'io_service_tests.cpp', + 'file_tests.cpp', + 'socket_tests.cpp', + ]) + +extras = script.cwd([ + 'build.cake', +]) + +intermediateBuildDir = cake.path.join(env.expand('${CPPCORO_BUILD}'), 'test', 'obj') + +compiler.addDefine('CPPCORO_RELEASE_' + variant.release.upper()) + +objects = compiler.objects( + targetDir=intermediateBuildDir, + sources=sources, +) + +testExe = compiler.program( + target=env.expand('${CPPCORO_BUILD}/test/run'), + sources=objects, +) + +test.alwaysRun = True +testResult = test.run( + program=testExe, + results=env.expand('${CPPCORO_BUILD}/test/run.results'), + ) +script.addTarget('testresult', testResult) + +vcproj = project.project( + target=env.expand('${CPPCORO_PROJECT}/cppcoro_tests'), + items={ + 'Source': sources + headers, + '': extras, + }, + output=testExe, +) + +script.setResult( + project=vcproj, + test=testExe, +) diff --git a/test/cancellation_token_tests.cpp b/test/cancellation_token_tests.cpp new file mode 100644 index 0000000..0684d76 --- /dev/null +++ b/test/cancellation_token_tests.cpp @@ -0,0 +1,342 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("cancellation_token tests"); + +TEST_CASE("default cancellation_token is not cancellable") +{ + cppcoro::cancellation_token t; + CHECK(!t.is_cancellation_requested()); + CHECK(!t.can_be_cancelled()); +} + +TEST_CASE("calling request_cancellation on cancellation_source updates cancellation_token") +{ + cppcoro::cancellation_source s; + cppcoro::cancellation_token t = s.token(); + CHECK(t.can_be_cancelled()); + CHECK(!t.is_cancellation_requested()); + s.request_cancellation(); + CHECK(t.is_cancellation_requested()); + CHECK(t.can_be_cancelled()); +} + +TEST_CASE("cancellation_token can't be cancelled when last cancellation_source destructed") +{ + cppcoro::cancellation_token t; + { + cppcoro::cancellation_source s; + t = s.token(); + CHECK(t.can_be_cancelled()); + } + + CHECK(!t.can_be_cancelled()); +} + +TEST_CASE("cancelation_token can be cancelled when last cancellation_source destructed if cancellation already requested") +{ + cppcoro::cancellation_token t; + { + cppcoro::cancellation_source s; + t = s.token(); + CHECK(t.can_be_cancelled()); + s.request_cancellation(); + } + + CHECK(t.can_be_cancelled()); + CHECK(t.is_cancellation_requested()); +} + +TEST_CASE("cancellation_registration when cancellation not yet requested") +{ + cppcoro::cancellation_source s; + + bool callbackExecuted = false; + { + cppcoro::cancellation_registration callbackRegistration( + s.token(), + [&] { callbackExecuted = true; }); + } + + CHECK(!callbackExecuted); + + { + cppcoro::cancellation_registration callbackRegistration( + s.token(), + [&] { callbackExecuted = true; }); + + CHECK(!callbackExecuted); + + s.request_cancellation(); + + CHECK(callbackExecuted); + } +} + +TEST_CASE("throw_if_cancellation_requested") +{ + cppcoro::cancellation_source s; + cppcoro::cancellation_token t = s.token(); + + CHECK_NOTHROW(t.throw_if_cancellation_requested()); + + s.request_cancellation(); + + CHECK_THROWS_AS(t.throw_if_cancellation_requested(), const cppcoro::operation_cancelled&); +} + +TEST_CASE("cancellation_registration called immediately when cancellation already requested") +{ + cppcoro::cancellation_source s; + s.request_cancellation(); + + bool executed = false; + cppcoro::cancellation_registration r{ s.token(), [&] { executed = true; } }; + CHECK(executed); +} + +TEST_CASE("register many callbacks" + * doctest::description{ + "this checks the code-path that allocates the next chunk of entries " + "in the internal data-structres, which occurs on 17th callback" }) +{ + cppcoro::cancellation_source s; + auto t = s.token(); + + int callbackExecutionCount = 0; + auto callback = [&] { ++callbackExecutionCount; }; + + // Allocate enough to require a second chunk to be allocated. + cppcoro::cancellation_registration r1{ t, callback }; + cppcoro::cancellation_registration r2{ t, callback }; + cppcoro::cancellation_registration r3{ t, callback }; + cppcoro::cancellation_registration r4{ t, callback }; + cppcoro::cancellation_registration r5{ t, callback }; + cppcoro::cancellation_registration r6{ t, callback }; + cppcoro::cancellation_registration r7{ t, callback }; + cppcoro::cancellation_registration r8{ t, callback }; + cppcoro::cancellation_registration r9{ t, callback }; + cppcoro::cancellation_registration r10{ t, callback }; + cppcoro::cancellation_registration r11{ t, callback }; + cppcoro::cancellation_registration r12{ t, callback }; + cppcoro::cancellation_registration r13{ t, callback }; + cppcoro::cancellation_registration r14{ t, callback }; + cppcoro::cancellation_registration r15{ t, callback }; + cppcoro::cancellation_registration r16{ t, callback }; + cppcoro::cancellation_registration r17{ t, callback }; + cppcoro::cancellation_registration r18{ t, callback }; + + s.request_cancellation(); + + CHECK(callbackExecutionCount == 18); +} + +TEST_CASE("concurrent registration and cancellation") +{ + // Just check this runs and terminates without crashing. + for (int i = 0; i < 100; ++i) + { + cppcoro::cancellation_source source; + + std::thread waiter1{ [token = source.token()] + { + std::atomic cancelled = false; + while (!cancelled) + { + cppcoro::cancellation_registration registration{ token, [&] + { + cancelled = true; + } }; + + cppcoro::cancellation_registration reg0{ token, [] {} }; + cppcoro::cancellation_registration reg1{ token, [] {} }; + cppcoro::cancellation_registration reg2{ token, [] {} }; + cppcoro::cancellation_registration reg3{ token, [] {} }; + cppcoro::cancellation_registration reg4{ token, [] {} }; + cppcoro::cancellation_registration reg5{ token, [] {} }; + cppcoro::cancellation_registration reg6{ token, [] {} }; + cppcoro::cancellation_registration reg7{ token, [] {} }; + cppcoro::cancellation_registration reg8{ token, [] {} }; + cppcoro::cancellation_registration reg9{ token, [] {} }; + cppcoro::cancellation_registration reg10{ token, [] {} }; + cppcoro::cancellation_registration reg11{ token, [] {} }; + cppcoro::cancellation_registration reg12{ token, [] {} }; + cppcoro::cancellation_registration reg13{ token, [] {} }; + cppcoro::cancellation_registration reg14{ token, [] {} }; + cppcoro::cancellation_registration reg15{ token, [] {} }; + cppcoro::cancellation_registration reg17{ token, [] {} }; + + std::this_thread::yield(); + } + } }; + + std::thread waiter2{ [token = source.token()] + { + std::atomic cancelled = false; + while (!cancelled) + { + cppcoro::cancellation_registration registration{ token, [&] + { + cancelled = true; + } }; + + cppcoro::cancellation_registration reg0{ token, [] {} }; + cppcoro::cancellation_registration reg1{ token, [] {} }; + cppcoro::cancellation_registration reg2{ token, [] {} }; + cppcoro::cancellation_registration reg3{ token, [] {} }; + cppcoro::cancellation_registration reg4{ token, [] {} }; + cppcoro::cancellation_registration reg5{ token, [] {} }; + cppcoro::cancellation_registration reg6{ token, [] {} }; + cppcoro::cancellation_registration reg7{ token, [] {} }; + cppcoro::cancellation_registration reg8{ token, [] {} }; + cppcoro::cancellation_registration reg9{ token, [] {} }; + cppcoro::cancellation_registration reg10{ token, [] {} }; + cppcoro::cancellation_registration reg11{ token, [] {} }; + cppcoro::cancellation_registration reg12{ token, [] {} }; + cppcoro::cancellation_registration reg13{ token, [] {} }; + cppcoro::cancellation_registration reg14{ token, [] {} }; + cppcoro::cancellation_registration reg15{ token, [] {} }; + cppcoro::cancellation_registration reg16{ token, [] {} }; + + std::this_thread::yield(); + } + } }; + + std::thread waiter3{ [token = source.token()] + { + std::atomic cancelled = false; + while (!cancelled) + { + cppcoro::cancellation_registration registration{ token, [&] + { + cancelled = true; + } }; + + cppcoro::cancellation_registration reg0{ token, [] {} }; + cppcoro::cancellation_registration reg1{ token, [] {} }; + cppcoro::cancellation_registration reg2{ token, [] {} }; + cppcoro::cancellation_registration reg3{ token, [] {} }; + cppcoro::cancellation_registration reg4{ token, [] {} }; + cppcoro::cancellation_registration reg5{ token, [] {} }; + cppcoro::cancellation_registration reg6{ token, [] {} }; + cppcoro::cancellation_registration reg7{ token, [] {} }; + cppcoro::cancellation_registration reg8{ token, [] {} }; + cppcoro::cancellation_registration reg9{ token, [] {} }; + cppcoro::cancellation_registration reg10{ token, [] {} }; + cppcoro::cancellation_registration reg11{ token, [] {} }; + cppcoro::cancellation_registration reg12{ token, [] {} }; + cppcoro::cancellation_registration reg13{ token, [] {} }; + cppcoro::cancellation_registration reg14{ token, [] {} }; + cppcoro::cancellation_registration reg15{ token, [] {} }; + cppcoro::cancellation_registration reg16{ token, [] {} }; + + std::this_thread::yield(); + } + } }; + + std::thread canceller{ [&source] + { + source.request_cancellation(); + } }; + + canceller.join(); + waiter1.join(); + waiter2.join(); + waiter3.join(); + } +} + +TEST_CASE("cancellation registration single-threaded performance") +{ + struct batch + { + batch(cppcoro::cancellation_token t) + : r0(t, [] {}) + , r1(t, [] {}) + , r2(t, [] {}) + , r3(t, [] {}) + , r4(t, [] {}) + , r5(t, [] {}) + , r6(t, [] {}) + , r7(t, [] {}) + , r8(t, [] {}) + , r9(t, [] {}) + {} + + cppcoro::cancellation_registration r0; + cppcoro::cancellation_registration r1; + cppcoro::cancellation_registration r2; + cppcoro::cancellation_registration r3; + cppcoro::cancellation_registration r4; + cppcoro::cancellation_registration r5; + cppcoro::cancellation_registration r6; + cppcoro::cancellation_registration r7; + cppcoro::cancellation_registration r8; + cppcoro::cancellation_registration r9; + }; + + cppcoro::cancellation_source s; + + constexpr int iterationCount = 100'000; + + auto start = std::chrono::high_resolution_clock::now(); + + for (int i = 0; i < iterationCount; ++i) + { + cppcoro::cancellation_registration r{ s.token(), [] {} }; + } + + auto end = std::chrono::high_resolution_clock::now(); + + auto time1 = end - start; + + start = end; + + for (int i = 0; i < iterationCount; ++i) + { + batch b{ s.token() }; + } + + end = std::chrono::high_resolution_clock::now(); + + auto time2 = end - start; + + start = end; + + for (int i = 0; i < iterationCount; ++i) + { + batch b0{ s.token() }; + batch b1{ s.token() }; + batch b2{ s.token() }; + batch b3{ s.token() }; + batch b4{ s.token() }; + } + + end = std::chrono::high_resolution_clock::now(); + + auto time3 = end - start; + + auto report = [](const char* label, auto time, std::uint64_t count) + { + auto us = std::chrono::duration_cast(time).count(); + MESSAGE(label << " took " << us << "us (" << (1000.0 * us / count) << " ns/item)"); + }; + + report("Individual", time1, iterationCount); + report("Batch10", time2, 10 * iterationCount); + report("Batch50", time3, 50 * iterationCount); +} + +TEST_SUITE_END(); diff --git a/test/counted.cpp b/test/counted.cpp new file mode 100644 index 0000000..5867fdc --- /dev/null +++ b/test/counted.cpp @@ -0,0 +1,11 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include "counted.hpp" + +int counted::default_construction_count; +int counted::copy_construction_count; +int counted::move_construction_count; +int counted::destruction_count; diff --git a/test/counted.hpp b/test/counted.hpp new file mode 100644 index 0000000..380c12e --- /dev/null +++ b/test/counted.hpp @@ -0,0 +1,42 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_TESTS_COUNTED_HPP_INCLUDED +#define CPPCORO_TESTS_COUNTED_HPP_INCLUDED + +struct counted +{ + static int default_construction_count; + static int copy_construction_count; + static int move_construction_count; + static int destruction_count; + + int id; + + static void reset_counts() + { + default_construction_count = 0; + copy_construction_count = 0; + move_construction_count = 0; + destruction_count = 0; + } + + static int construction_count() + { + return default_construction_count + copy_construction_count + move_construction_count; + } + + static int active_count() + { + return construction_count() - destruction_count; + } + + counted() : id(default_construction_count++) {} + counted(const counted& other) : id(other.id) { ++copy_construction_count; } + counted(counted&& other) : id(other.id) { ++move_construction_count; other.id = -1; } + ~counted() { ++destruction_count; } + +}; + +#endif diff --git a/test/doctest/cppcoro_doctest.h b/test/doctest/cppcoro_doctest.h new file mode 100644 index 0000000..b083149 --- /dev/null +++ b/test/doctest/cppcoro_doctest.h @@ -0,0 +1,12 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Andreas Buhr +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_CPPCORO_DOCTEST_H_INCLUDED +#define CPPCORO_CPPCORO_DOCTEST_H_INCLUDED + +#define DOCTEST_CONFIG_USE_STD_HEADERS +#include "doctest.h" + +#endif + diff --git a/test/doctest/doctest.cmake b/test/doctest/doctest.cmake new file mode 100644 index 0000000..1376801 --- /dev/null +++ b/test/doctest/doctest.cmake @@ -0,0 +1,175 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +doctest +----- + +This module defines a function to help use the doctest test framework. + +The :command:`doctest_discover_tests` discovers tests by asking the compiled test +executable to enumerate its tests. This does not require CMake to be re-run +when tests change. However, it may not work in a cross-compiling environment, +and setting test properties is less convenient. + +This command is intended to replace use of :command:`add_test` to register +tests, and will create a separate CTest test for each doctest test case. Note +that this is in some cases less efficient, as common set-up and tear-down logic +cannot be shared by multiple test cases executing in the same instance. +However, it provides more fine-grained pass/fail information to CTest, which is +usually considered as more beneficial. By default, the CTest test name is the +same as the doctest name; see also ``TEST_PREFIX`` and ``TEST_SUFFIX``. + +.. command:: doctest_discover_tests + + Automatically add tests with CTest by querying the compiled test executable + for available tests:: + + doctest_discover_tests(target + [TEST_SPEC arg1...] + [EXTRA_ARGS arg1...] + [WORKING_DIRECTORY dir] + [TEST_PREFIX prefix] + [TEST_SUFFIX suffix] + [PROPERTIES name1 value1...] + [TEST_LIST var] + ) + + ``doctest_discover_tests`` sets up a post-build command on the test executable + that generates the list of tests by parsing the output from running the test + with the ``--list-test-cases`` argument. This ensures that the full + list of tests is obtained. Since test discovery occurs at build time, it is + not necessary to re-run CMake when the list of tests changes. + However, it requires that :prop_tgt:`CROSSCOMPILING_EMULATOR` is properly set + in order to function in a cross-compiling environment. + + Additionally, setting properties on tests is somewhat less convenient, since + the tests are not available at CMake time. Additional test properties may be + assigned to the set of tests as a whole using the ``PROPERTIES`` option. If + more fine-grained test control is needed, custom content may be provided + through an external CTest script using the :prop_dir:`TEST_INCLUDE_FILES` + directory property. The set of discovered tests is made accessible to such a + script via the ``_TESTS`` variable. + + The options are: + + ``target`` + Specifies the doctest executable, which must be a known CMake executable + target. CMake will substitute the location of the built executable when + running the test. + + ``TEST_SPEC arg1...`` + Specifies test cases, wildcarded test cases, tags and tag expressions to + pass to the doctest executable with the ``--list-test-cases`` argument. + + ``EXTRA_ARGS arg1...`` + Any extra arguments to pass on the command line to each test case. + + ``WORKING_DIRECTORY dir`` + Specifies the directory in which to run the discovered test cases. If this + option is not provided, the current binary directory is used. + + ``TEST_PREFIX prefix`` + Specifies a ``prefix`` to be prepended to the name of each discovered test + case. This can be useful when the same test executable is being used in + multiple calls to ``doctest_discover_tests()`` but with different + ``TEST_SPEC`` or ``EXTRA_ARGS``. + + ``TEST_SUFFIX suffix`` + Similar to ``TEST_PREFIX`` except the ``suffix`` is appended to the name of + every discovered test case. Both ``TEST_PREFIX`` and ``TEST_SUFFIX`` may + be specified. + + ``PROPERTIES name1 value1...`` + Specifies additional properties to be set on all tests discovered by this + invocation of ``doctest_discover_tests``. + + ``TEST_LIST var`` + Make the list of tests available in the variable ``var``, rather than the + default ``_TESTS``. This can be useful when the same test + executable is being used in multiple calls to ``doctest_discover_tests()``. + Note that this variable is only available in CTest. + +#]=======================================================================] + +#------------------------------------------------------------------------------ +function(doctest_discover_tests TARGET) + cmake_parse_arguments( + "" + "" + "TEST_PREFIX;TEST_SUFFIX;WORKING_DIRECTORY;TEST_LIST" + "TEST_SPEC;EXTRA_ARGS;PROPERTIES" + ${ARGN} + ) + + if(NOT _WORKING_DIRECTORY) + set(_WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") + endif() + if(NOT _TEST_LIST) + set(_TEST_LIST ${TARGET}_TESTS) + endif() + + ## Generate a unique name based on the extra arguments + string(SHA1 args_hash "${_TEST_SPEC} ${_EXTRA_ARGS}") + string(SUBSTRING ${args_hash} 0 7 args_hash) + + # Define rule to generate test list for aforementioned test executable + set(ctest_include_file "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_include-${args_hash}.cmake") + set(ctest_tests_file "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_tests-${args_hash}.cmake") + get_property(crosscompiling_emulator + TARGET ${TARGET} + PROPERTY CROSSCOMPILING_EMULATOR + ) + add_custom_command( + TARGET ${TARGET} POST_BUILD + BYPRODUCTS "${ctest_tests_file}" + COMMAND "${CMAKE_COMMAND}" + -D "TEST_TARGET=${TARGET}" + -D "TEST_EXECUTABLE=$" + -D "TEST_EXECUTOR=${crosscompiling_emulator}" + -D "TEST_WORKING_DIR=${_WORKING_DIRECTORY}" + -D "TEST_SPEC=${_TEST_SPEC}" + -D "TEST_EXTRA_ARGS=${_EXTRA_ARGS}" + -D "TEST_PROPERTIES=${_PROPERTIES}" + -D "TEST_PREFIX=${_TEST_PREFIX}" + -D "TEST_SUFFIX=${_TEST_SUFFIX}" + -D "TEST_LIST=${_TEST_LIST}" + -D "CTEST_FILE=${ctest_tests_file}" + -P "${_DOCTEST_DISCOVER_TESTS_SCRIPT}" + VERBATIM + ) + + file(WRITE "${ctest_include_file}" + "if(EXISTS \"${ctest_tests_file}\")\n" + " include(\"${ctest_tests_file}\")\n" + "else()\n" + " add_test(${TARGET}_NOT_BUILT-${args_hash} ${TARGET}_NOT_BUILT-${args_hash})\n" + "endif()\n" + ) + + if(NOT CMAKE_VERSION VERSION_LESS 3.10) + # Add discovered tests to directory TEST_INCLUDE_FILES + set_property(DIRECTORY + APPEND PROPERTY TEST_INCLUDE_FILES "${ctest_include_file}" + ) + else() + # Add discovered tests as directory TEST_INCLUDE_FILE if possible + get_property(test_include_file_set DIRECTORY PROPERTY TEST_INCLUDE_FILE SET) + if(NOT ${test_include_file_set}) + set_property(DIRECTORY + PROPERTY TEST_INCLUDE_FILE "${ctest_include_file}" + ) + else() + message(FATAL_ERROR + "Cannot set more than one TEST_INCLUDE_FILE" + ) + endif() + endif() + +endfunction() + +############################################################################### + +set(_DOCTEST_DISCOVER_TESTS_SCRIPT + ${CMAKE_CURRENT_LIST_DIR}/doctestAddTests.cmake +) diff --git a/test/doctest/doctest.h b/test/doctest/doctest.h new file mode 100644 index 0000000..9444698 --- /dev/null +++ b/test/doctest/doctest.h @@ -0,0 +1,6205 @@ +// ====================================================================== lgtm [cpp/missing-header-guard] +// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! == +// ====================================================================== +// +// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD +// +// Copyright (c) 2016-2019 Viktor Kirilov +// +// Distributed under the MIT Software License +// See accompanying file LICENSE.txt or copy at +// https://opensource.org/licenses/MIT +// +// The documentation can be found at the library's page: +// https://github.com/onqtam/doctest/blob/master/doc/markdown/readme.md +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= +// +// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt +// +// The concept of subcases (sections in Catch) and expression decomposition are from there. +// Some parts of the code are taken directly: +// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> +// - the Approx() helper class for floating point comparison +// - colors in the console +// - breaking into a debugger +// - signal / SEH handling +// - timer +// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) +// +// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= + +#ifndef DOCTEST_LIBRARY_INCLUDED +#define DOCTEST_LIBRARY_INCLUDED + +// ================================================================================================= +// == VERSION ====================================================================================== +// ================================================================================================= + +#define DOCTEST_VERSION_MAJOR 2 +#define DOCTEST_VERSION_MINOR 4 +#define DOCTEST_VERSION_PATCH 0 +#define DOCTEST_VERSION_STR "2.4.0" + +#define DOCTEST_VERSION \ + (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) + +// ================================================================================================= +// == COMPILER VERSION ============================================================================= +// ================================================================================================= + +// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect + +#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) + +// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... +#if defined(_MSC_VER) && defined(_MSC_FULL_VER) +#if _MSC_VER == _MSC_FULL_VER / 10000 +#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) +#else // MSVC +#define DOCTEST_MSVC \ + DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) +#endif // MSVC +#endif // MSVC +#if defined(__clang__) && defined(__clang_minor__) +#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ + !defined(__INTEL_COMPILER) +#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#endif // GCC + +#ifndef DOCTEST_MSVC +#define DOCTEST_MSVC 0 +#endif // DOCTEST_MSVC +#ifndef DOCTEST_CLANG +#define DOCTEST_CLANG 0 +#endif // DOCTEST_CLANG +#ifndef DOCTEST_GCC +#define DOCTEST_GCC 0 +#endif // DOCTEST_GCC + +// ================================================================================================= +// == COMPILER WARNINGS HELPERS ==================================================================== +// ================================================================================================= + +#if DOCTEST_CLANG +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) +#else // DOCTEST_CLANG +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_CLANG + +#if DOCTEST_GCC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") +#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) +#else // DOCTEST_GCC +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH +#define DOCTEST_GCC_SUPPRESS_WARNING(w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_GCC + +#if DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH DOCTEST_MSVC_SUPPRESS_WARNING(w) +#else // DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_MSVC + +// ================================================================================================= +// == COMPILER WARNINGS ============================================================================ +// ================================================================================================= + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") +DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") +DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration +DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression +DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated +DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant +DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding +DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe +// static analysis +DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' +DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable +DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... +DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtr... +DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' + +// 4548 - expression before comma has no effect; expected expression with side - effect +// 4265 - class has virtual functions, but destructor is not virtual +// 4986 - exception specification does not match previous declaration +// 4350 - behavior change: 'member1' called instead of 'member2' +// 4668 - 'x' is not defined as a preprocessor macro, replacing with '0' for '#if/#elif' +// 4365 - conversion from 'int' to 'unsigned long', signed/unsigned mismatch +// 4774 - format string expected in argument 'x' is not a string literal +// 4820 - padding in structs + +// only 4 should be disabled globally: +// - 4514 # unreferenced inline function has been removed +// - 4571 # SEH related +// - 4710 # function not inlined +// - 4711 # function 'x' selected for automatic inline expansion + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + DOCTEST_MSVC_SUPPRESS_WARNING(4548) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4265) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4986) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4350) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4668) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4365) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4774) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4623) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5039) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5105) + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP + +// ================================================================================================= +// == FEATURE DETECTION ============================================================================ +// ================================================================================================= + +// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support +// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx +// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html +// MSVC version table: +// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering +// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) +// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) +// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) +// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) +// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) +// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) +// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) +// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) + +#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) +#define DOCTEST_CONFIG_WINDOWS_SEH +#endif // MSVC +#if defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) +#undef DOCTEST_CONFIG_WINDOWS_SEH +#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH + +#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ + !defined(__EMSCRIPTEN__) +#define DOCTEST_CONFIG_POSIX_SIGNALS +#endif // _WIN32 +#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) +#undef DOCTEST_CONFIG_POSIX_SIGNALS +#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // no exceptions +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) +#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) +#define DOCTEST_CONFIG_IMPLEMENT +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#if defined(_WIN32) || defined(__CYGWIN__) +#if DOCTEST_MSVC +#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) +#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) +#else // MSVC +#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) +#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) +#endif // MSVC +#else // _WIN32 +#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) +#define DOCTEST_SYMBOL_IMPORT +#endif // _WIN32 + +#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#ifdef DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT +#else // DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT +#endif // DOCTEST_CONFIG_IMPLEMENT +#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#define DOCTEST_INTERFACE +#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL + +#define DOCTEST_EMPTY + +#if DOCTEST_MSVC +#define DOCTEST_NOINLINE __declspec(noinline) +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#else // MSVC +#define DOCTEST_NOINLINE __attribute__((noinline)) +#define DOCTEST_UNUSED __attribute__((unused)) +#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) +#endif // MSVC + +#ifndef DOCTEST_NORETURN +#define DOCTEST_NORETURN [[noreturn]] +#endif // DOCTEST_NORETURN + +#ifndef DOCTEST_NOEXCEPT +#define DOCTEST_NOEXCEPT noexcept +#endif // DOCTEST_NOEXCEPT + +// ================================================================================================= +// == FEATURE DETECTION END ======================================================================== +// ================================================================================================= + +// internal macros for string concatenation and anonymous variable name generation +#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 +#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) +#ifdef __COUNTER__ // not standard and may be missing for some compilers +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) +#else // __COUNTER__ +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) +#endif // __COUNTER__ + +#define DOCTEST_TOSTR(x) #x + +#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x& +#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x +#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE + +// not using __APPLE__ because... this is how Catch does it +#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED +#define DOCTEST_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +#define DOCTEST_PLATFORM_IPHONE +#elif defined(_WIN32) +#define DOCTEST_PLATFORM_WINDOWS +#else // DOCTEST_PLATFORM +#define DOCTEST_PLATFORM_LINUX +#endif // DOCTEST_PLATFORM + +#define DOCTEST_GLOBAL_NO_WARNINGS(var) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-variable") \ + static int var DOCTEST_UNUSED // NOLINT(fuchsia-statically-constructed-objects,cert-err58-cpp) +#define DOCTEST_GLOBAL_NO_WARNINGS_END() DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#ifndef DOCTEST_BREAK_INTO_DEBUGGER +// should probably take a look at https://github.com/scottt/debugbreak +#ifdef DOCTEST_PLATFORM_MAC +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) +#elif DOCTEST_MSVC +#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() +#elif defined(__MINGW32__) +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") +extern "C" __declspec(dllimport) void __stdcall DebugBreak(); +DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() +#else // linux +#define DOCTEST_BREAK_INTO_DEBUGGER() ((void)0) +#endif // linux +#endif // DOCTEST_BREAK_INTO_DEBUGGER + +// this is kept here for backwards compatibility since the config option was changed +#ifdef DOCTEST_CONFIG_USE_IOSFWD +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif // DOCTEST_CONFIG_USE_IOSFWD + +#ifdef DOCTEST_CONFIG_USE_STD_HEADERS +#include +#include +#include +#else // DOCTEST_CONFIG_USE_STD_HEADERS + +#if DOCTEST_CLANG +// to detect if libc++ is being used with clang (the _LIBCPP_VERSION identifier) +#include +#endif // clang + +#ifdef _LIBCPP_VERSION +#define DOCTEST_STD_NAMESPACE_BEGIN _LIBCPP_BEGIN_NAMESPACE_STD +#define DOCTEST_STD_NAMESPACE_END _LIBCPP_END_NAMESPACE_STD +#else // _LIBCPP_VERSION +#define DOCTEST_STD_NAMESPACE_BEGIN namespace std { +#define DOCTEST_STD_NAMESPACE_END } +#endif // _LIBCPP_VERSION + +// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) + +DOCTEST_STD_NAMESPACE_BEGIN // NOLINT (cert-dcl58-cpp) +typedef decltype(nullptr) nullptr_t; +template +struct char_traits; +template <> +struct char_traits; +template +class basic_ostream; +typedef basic_ostream> ostream; +template +class tuple; +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +template +class allocator; +template +class basic_string; +using string = basic_string, allocator>; +#endif // VS 2019 +DOCTEST_STD_NAMESPACE_END + +DOCTEST_MSVC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_USE_STD_HEADERS + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#include +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + +namespace doctest { + +DOCTEST_INTERFACE extern bool is_running_in_test; + +// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length +// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: +// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) +// - if small - capacity left before going on the heap - using the lowest 5 bits +// - if small - 2 bits are left unused - the second and third highest ones +// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) +// and the "is small" bit remains "0" ("as well as the capacity left") so its OK +// Idea taken from this lecture about the string implementation of facebook/folly - fbstring +// https://www.youtube.com/watch?v=kPR8h4-qZdk +// TODO: +// - optimizations - like not deleting memory unnecessarily in operator= and etc. +// - resize/reserve/clear +// - substr +// - replace +// - back/front +// - iterator stuff +// - find & friends +// - push_back/pop_back +// - assign/insert/erase +// - relational operators as free functions - taking const char* as one of the params +class DOCTEST_INTERFACE String +{ + static const unsigned len = 24; //!OCLINT avoid private static members + static const unsigned last = len - 1; //!OCLINT avoid private static members + + struct view // len should be more than sizeof(view) - because of the final byte for flags + { + char* ptr; + unsigned size; + unsigned capacity; + }; + + union + { + char buf[len]; + view data; + }; + + bool isOnStack() const { return (buf[last] & 128) == 0; } + void setOnHeap(); + void setLast(unsigned in = last); + + void copy(const String& other); + +public: + String(); + ~String(); + + // cppcheck-suppress noExplicitConstructor + String(const char* in); + String(const char* in, unsigned in_size); + + String(const String& other); + String& operator=(const String& other); + + String& operator+=(const String& other); + String operator+(const String& other) const; + + String(String&& other); + String& operator=(String&& other); + + char operator[](unsigned i) const; + char& operator[](unsigned i); + + // the only functions I'm willing to leave in the interface - available for inlining + const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT + char* c_str() { + if(isOnStack()) + return reinterpret_cast(buf); + return data.ptr; + } + + unsigned size() const; + unsigned capacity() const; + + int compare(const char* other, bool no_case = false) const; + int compare(const String& other, bool no_case = false) const; +}; + +DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); + +DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); + +namespace Color { + enum Enum + { + None = 0, + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White + }; + + DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); +} // namespace Color + +namespace assertType { + enum Enum + { + // macro traits + + is_warn = 1, + is_check = 2 * is_warn, + is_require = 2 * is_check, + + is_normal = 2 * is_require, + is_throws = 2 * is_normal, + is_throws_as = 2 * is_throws, + is_throws_with = 2 * is_throws_as, + is_nothrow = 2 * is_throws_with, + + is_false = 2 * is_nothrow, + is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types + + is_eq = 2 * is_unary, + is_ne = 2 * is_eq, + + is_lt = 2 * is_ne, + is_gt = 2 * is_lt, + + is_ge = 2 * is_gt, + is_le = 2 * is_ge, + + // macro types + + DT_WARN = is_normal | is_warn, + DT_CHECK = is_normal | is_check, + DT_REQUIRE = is_normal | is_require, + + DT_WARN_FALSE = is_normal | is_false | is_warn, + DT_CHECK_FALSE = is_normal | is_false | is_check, + DT_REQUIRE_FALSE = is_normal | is_false | is_require, + + DT_WARN_THROWS = is_throws | is_warn, + DT_CHECK_THROWS = is_throws | is_check, + DT_REQUIRE_THROWS = is_throws | is_require, + + DT_WARN_THROWS_AS = is_throws_as | is_warn, + DT_CHECK_THROWS_AS = is_throws_as | is_check, + DT_REQUIRE_THROWS_AS = is_throws_as | is_require, + + DT_WARN_THROWS_WITH = is_throws_with | is_warn, + DT_CHECK_THROWS_WITH = is_throws_with | is_check, + DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, + + DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, + DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, + DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, + + DT_WARN_NOTHROW = is_nothrow | is_warn, + DT_CHECK_NOTHROW = is_nothrow | is_check, + DT_REQUIRE_NOTHROW = is_nothrow | is_require, + + DT_WARN_EQ = is_normal | is_eq | is_warn, + DT_CHECK_EQ = is_normal | is_eq | is_check, + DT_REQUIRE_EQ = is_normal | is_eq | is_require, + + DT_WARN_NE = is_normal | is_ne | is_warn, + DT_CHECK_NE = is_normal | is_ne | is_check, + DT_REQUIRE_NE = is_normal | is_ne | is_require, + + DT_WARN_GT = is_normal | is_gt | is_warn, + DT_CHECK_GT = is_normal | is_gt | is_check, + DT_REQUIRE_GT = is_normal | is_gt | is_require, + + DT_WARN_LT = is_normal | is_lt | is_warn, + DT_CHECK_LT = is_normal | is_lt | is_check, + DT_REQUIRE_LT = is_normal | is_lt | is_require, + + DT_WARN_GE = is_normal | is_ge | is_warn, + DT_CHECK_GE = is_normal | is_ge | is_check, + DT_REQUIRE_GE = is_normal | is_ge | is_require, + + DT_WARN_LE = is_normal | is_le | is_warn, + DT_CHECK_LE = is_normal | is_le | is_check, + DT_REQUIRE_LE = is_normal | is_le | is_require, + + DT_WARN_UNARY = is_normal | is_unary | is_warn, + DT_CHECK_UNARY = is_normal | is_unary | is_check, + DT_REQUIRE_UNARY = is_normal | is_unary | is_require, + + DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, + DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, + DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, + }; +} // namespace assertType + +DOCTEST_INTERFACE const char* assertString(assertType::Enum at); +DOCTEST_INTERFACE const char* failureString(assertType::Enum at); +DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); + +struct DOCTEST_INTERFACE TestCaseData +{ + String m_file; // the file in which the test was registered + unsigned m_line; // the line where the test was registered + const char* m_name; // name of the test case + const char* m_test_suite; // the test suite in which the test was added + const char* m_description; + bool m_skip; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; +}; + +struct DOCTEST_INTERFACE AssertData +{ + // common - for all asserts + const TestCaseData* m_test_case; + assertType::Enum m_at; + const char* m_file; + int m_line; + const char* m_expr; + bool m_failed; + + // exception-related - for all asserts + bool m_threw; + String m_exception; + + // for normal asserts + String m_decomp; + + // for specific exception-related asserts + bool m_threw_as; + const char* m_exception_type; + const char* m_exception_string; +}; + +struct DOCTEST_INTERFACE MessageData +{ + String m_string; + const char* m_file; + int m_line; + assertType::Enum m_severity; +}; + +struct DOCTEST_INTERFACE SubcaseSignature +{ + String m_name; + const char* m_file; + int m_line; + + bool operator<(const SubcaseSignature& other) const; +}; + +struct DOCTEST_INTERFACE IContextScope +{ + IContextScope(); + virtual ~IContextScope(); + virtual void stringify(std::ostream*) const = 0; +}; + +struct ContextOptions //!OCLINT too many fields +{ + std::ostream* cout; // stdout stream - std::cout by default + std::ostream* cerr; // stderr stream - std::cerr by default + String binary_name; // the test binary name + + // == parameters from the command line + String out; // output filename + String order_by; // how tests should be ordered + unsigned rand_seed; // the seed for rand ordering + + unsigned first; // the first (matching) test to be executed + unsigned last; // the last (matching) test to be executed + + int abort_after; // stop tests after this many failed assertions + int subcase_filter_levels; // apply the subcase filters for the first N levels + + bool success; // include successful assertions in output + bool case_sensitive; // if filtering should be case sensitive + bool exit; // if the program should be exited after the tests are ran/whatever + bool duration; // print the time duration of each test case + bool no_throw; // to skip exceptions-related assertion macros + bool no_exitcode; // if the framework should return 0 as the exitcode + bool no_run; // to not run the tests at all (can be done with an "*" exclude) + bool no_version; // to not print the version of the framework + bool no_colors; // if output to the console should be colorized + bool force_colors; // forces the use of colors even when a tty cannot be detected + bool no_breaks; // to not break into the debugger + bool no_skip; // don't skip test cases which are marked to be skipped + bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): + bool no_path_in_filenames; // if the path to files should be removed from the output + bool no_line_numbers; // if source code line numbers should be omitted from the output + bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! + bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! + + bool help; // to print the help + bool version; // to print the version + bool count; // if only the count of matching tests is to be retrieved + bool list_test_cases; // to list all tests matching the filters + bool list_test_suites; // to list all suites matching the filters + bool list_reporters; // lists all registered reporters +}; + +namespace detail { +#if defined(DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING) || defined(DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS) + template + struct enable_if + {}; + + template + struct enable_if + { typedef TYPE type; }; +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING) || DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format off + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + + template struct remove_const { typedef T type; }; + template struct remove_const { typedef T type; }; + // clang-format on + + template + struct deferred_false + // cppcheck-suppress unusedStructMember + { static const bool value = false; }; + + namespace has_insertion_operator_impl { + std::ostream &os(); + template + DOCTEST_REF_WRAP(T) val(); + + template + struct check { + static constexpr auto value = false; + }; + + template + struct check(), void())> { + static constexpr auto value = true; + }; + } // namespace has_insertion_operator_impl + + template + using has_insertion_operator = has_insertion_operator_impl::check; + + DOCTEST_INTERFACE void my_memcpy(void* dest, const void* src, unsigned num); + + DOCTEST_INTERFACE std::ostream* getTlsOss(); // returns a thread-local ostringstream + DOCTEST_INTERFACE String getTlsOssResult(); + + template + struct StringMakerBase + { + template + static String convert(const DOCTEST_REF_WRAP(T)) { + return "{?}"; + } + }; + + template <> + struct StringMakerBase + { + template + static String convert(const DOCTEST_REF_WRAP(T) in) { + *getTlsOss() << in; + return getTlsOssResult(); + } + }; + + DOCTEST_INTERFACE String rawMemoryToString(const void* object, unsigned size); + + template + String rawMemoryToString(const DOCTEST_REF_WRAP(T) object) { + return rawMemoryToString(&object, sizeof(object)); + } + + template + const char* type_to_string() { + return "<>"; + } +} // namespace detail + +template +struct StringMaker : public detail::StringMakerBase::value> +{}; + +template +struct StringMaker +{ + template + static String convert(U* p) { + if(p) + return detail::rawMemoryToString(p); + return "NULL"; + } +}; + +template +struct StringMaker +{ + static String convert(R C::*p) { + if(p) + return detail::rawMemoryToString(p); + return "NULL"; + } +}; + +template +String toString(const DOCTEST_REF_WRAP(T) value) { + return StringMaker::convert(value); +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(char* in); +DOCTEST_INTERFACE String toString(const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(bool in); +DOCTEST_INTERFACE String toString(float in); +DOCTEST_INTERFACE String toString(double in); +DOCTEST_INTERFACE String toString(double long in); + +DOCTEST_INTERFACE String toString(char in); +DOCTEST_INTERFACE String toString(char signed in); +DOCTEST_INTERFACE String toString(char unsigned in); +DOCTEST_INTERFACE String toString(int short in); +DOCTEST_INTERFACE String toString(int short unsigned in); +DOCTEST_INTERFACE String toString(int in); +DOCTEST_INTERFACE String toString(int unsigned in); +DOCTEST_INTERFACE String toString(int long in); +DOCTEST_INTERFACE String toString(int long unsigned in); +DOCTEST_INTERFACE String toString(int long long in); +DOCTEST_INTERFACE String toString(int long long unsigned in); +DOCTEST_INTERFACE String toString(std::nullptr_t in); + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +DOCTEST_INTERFACE String toString(const std::string& in); +#endif // VS 2019 + +class DOCTEST_INTERFACE Approx +{ +public: + explicit Approx(double value); + + Approx operator()(double value) const; + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + explicit Approx(const T& value, + typename detail::enable_if::value>::type* = + static_cast(nullptr)) { + *this = Approx(static_cast(value)); + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& epsilon(double newEpsilon); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename detail::enable_if::value, Approx&>::type epsilon( + const T& newEpsilon) { + m_epsilon = static_cast(newEpsilon); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& scale(double newScale); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename detail::enable_if::value, Approx&>::type scale( + const T& newScale) { + m_scale = static_cast(newScale); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format off + DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator<=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator< (const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); + + DOCTEST_INTERFACE friend String toString(const Approx& in); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_APPROX_PREFIX \ + template friend typename detail::enable_if::value, bool>::type + + DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(double(lhs), rhs); } + DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } + DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) && lhs != rhs; } +#undef DOCTEST_APPROX_PREFIX +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format on + +private: + double m_epsilon; + double m_scale; + double m_value; +}; + +DOCTEST_INTERFACE String toString(const Approx& in); + +DOCTEST_INTERFACE const ContextOptions* getContextOptions(); + +#if !defined(DOCTEST_CONFIG_DISABLE) + +namespace detail { + // clang-format off +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + template struct decay_array { typedef T type; }; + template struct decay_array { typedef T* type; }; + template struct decay_array { typedef T* type; }; + + template struct not_char_pointer { enum { value = 1 }; }; + template<> struct not_char_pointer { enum { value = 0 }; }; + template<> struct not_char_pointer { enum { value = 0 }; }; + + template struct can_use_op : public not_char_pointer::type> {}; +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + + struct DOCTEST_INTERFACE TestFailureException + { + }; + + DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_INTERFACE void throwException(); + + struct DOCTEST_INTERFACE Subcase + { + SubcaseSignature m_signature; + bool m_entered = false; + + Subcase(const String& name, const char* file, int line); + ~Subcase(); + + operator bool() const; + }; + + template + String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, + const DOCTEST_REF_WRAP(R) rhs) { + return toString(lhs) + op + toString(rhs); + } + +#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ + template \ + DOCTEST_NOINLINE Result operator op(const DOCTEST_REF_WRAP(R) rhs) { \ + bool res = op_macro(lhs, rhs); \ + if(m_at & assertType::is_false) \ + res = !res; \ + if(!res || doctest::getContextOptions()->success) \ + return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ + return Result(res); \ + } + + // more checks could be added - like in Catch: + // https://github.com/catchorg/Catch2/pull/1480/files + // https://github.com/catchorg/Catch2/pull/1481/files +#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ + template \ + rt& operator op(const R&) { \ + static_assert(deferred_false::value, \ + "Expression Too Complex Please Rewrite As Binary Comparison!"); \ + return *this; \ + } + + struct DOCTEST_INTERFACE Result + { + bool m_passed; + String m_decomp; + + Result(bool passed, const String& decomposition = String()); + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Result, &) + DOCTEST_FORBIT_EXPRESSION(Result, ^) + DOCTEST_FORBIT_EXPRESSION(Result, |) + DOCTEST_FORBIT_EXPRESSION(Result, &&) + DOCTEST_FORBIT_EXPRESSION(Result, ||) + DOCTEST_FORBIT_EXPRESSION(Result, ==) + DOCTEST_FORBIT_EXPRESSION(Result, !=) + DOCTEST_FORBIT_EXPRESSION(Result, <) + DOCTEST_FORBIT_EXPRESSION(Result, >) + DOCTEST_FORBIT_EXPRESSION(Result, <=) + DOCTEST_FORBIT_EXPRESSION(Result, >=) + DOCTEST_FORBIT_EXPRESSION(Result, =) + DOCTEST_FORBIT_EXPRESSION(Result, +=) + DOCTEST_FORBIT_EXPRESSION(Result, -=) + DOCTEST_FORBIT_EXPRESSION(Result, *=) + DOCTEST_FORBIT_EXPRESSION(Result, /=) + DOCTEST_FORBIT_EXPRESSION(Result, %=) + DOCTEST_FORBIT_EXPRESSION(Result, <<=) + DOCTEST_FORBIT_EXPRESSION(Result, >>=) + DOCTEST_FORBIT_EXPRESSION(Result, &=) + DOCTEST_FORBIT_EXPRESSION(Result, ^=) + DOCTEST_FORBIT_EXPRESSION(Result, |=) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_GCC_SUPPRESS_WARNING_PUSH + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH + // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 + DOCTEST_MSVC_SUPPRESS_WARNING(4388) // signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch + //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + // clang-format off +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE bool +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE typename enable_if::value || can_use_op::value, bool>::type + inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } + inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } + inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } + inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } + inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } + inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ + const DOCTEST_REF_WRAP(R) rhs) { \ + return lhs op rhs; \ + } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) + +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) l == r +#define DOCTEST_CMP_NE(l, r) l != r +#define DOCTEST_CMP_GT(l, r) l > r +#define DOCTEST_CMP_LT(l, r) l < r +#define DOCTEST_CMP_GE(l, r) l >= r +#define DOCTEST_CMP_LE(l, r) l <= r +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) eq(l, r) +#define DOCTEST_CMP_NE(l, r) ne(l, r) +#define DOCTEST_CMP_GT(l, r) gt(l, r) +#define DOCTEST_CMP_LT(l, r) lt(l, r) +#define DOCTEST_CMP_GE(l, r) ge(l, r) +#define DOCTEST_CMP_LE(l, r) le(l, r) +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + + template + // cppcheck-suppress copyCtorAndEqOperator + struct Expression_lhs + { + L lhs; + assertType::Enum m_at; + + explicit Expression_lhs(L in, assertType::Enum at) + : lhs(in) + , m_at(at) {} + + DOCTEST_NOINLINE operator Result() { + bool res = !!lhs; + if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + res = !res; + + if(!res || getContextOptions()->success) + return Result(res, toString(lhs)); + return Result(res); + } + + // clang-format off + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional + // clang-format on + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) + // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the + // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + struct DOCTEST_INTERFACE ExpressionDecomposer + { + assertType::Enum m_at; + + ExpressionDecomposer(assertType::Enum at); + + // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) + // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... + // https://github.com/catchorg/Catch2/issues/870 + // https://github.com/catchorg/Catch2/issues/565 + template + Expression_lhs operator<<(const DOCTEST_REF_WRAP(L) operand) { + return Expression_lhs(operand, m_at); + } + }; + + struct DOCTEST_INTERFACE TestSuite + { + const char* m_test_suite; + const char* m_description; + bool m_skip; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; + + TestSuite& operator*(const char* in); + + template + TestSuite& operator*(const T& in) { + in.fill(*this); + return *this; + } + }; + + typedef void (*funcType)(); + + struct DOCTEST_INTERFACE TestCase : public TestCaseData + { + funcType m_test; // a function pointer to the test case + + const char* m_type; // for templated test cases - gets appended to the real name + int m_template_id; // an ID used to distinguish between the different versions of a templated test case + String m_full_name; // contains the name (only for templated test cases!) + the template type + + TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const char* type = "", int template_id = -1); + + TestCase(const TestCase& other); + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& operator=(const TestCase& other); + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& operator*(const char* in); + + template + TestCase& operator*(const T& in) { + in.fill(*this); + return *this; + } + + bool operator<(const TestCase& other) const; + }; + + // forward declarations of functions used by the macros + DOCTEST_INTERFACE int regTest(const TestCase& tc); + DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); + DOCTEST_INTERFACE bool isDebuggerActive(); + + template + int instantiationHelper(const T&) { return 0; } + + namespace binaryAssertComparison { + enum Enum + { + eq = 0, + ne, + gt, + lt, + ge, + le + }; + } // namespace binaryAssertComparison + + // clang-format off + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; + +#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; + // clang-format on + + DOCTEST_BINARY_RELATIONAL_OP(0, eq) + DOCTEST_BINARY_RELATIONAL_OP(1, ne) + DOCTEST_BINARY_RELATIONAL_OP(2, gt) + DOCTEST_BINARY_RELATIONAL_OP(3, lt) + DOCTEST_BINARY_RELATIONAL_OP(4, ge) + DOCTEST_BINARY_RELATIONAL_OP(5, le) + + struct DOCTEST_INTERFACE ResultBuilder : public AssertData + { + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type = "", const char* exception_string = ""); + + void setResult(const Result& res); + + template + DOCTEST_NOINLINE void binary_assert(const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + m_failed = !RelationalComparator()(lhs, rhs); + if(m_failed || getContextOptions()->success) + m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); + } + + template + DOCTEST_NOINLINE void unary_assert(const DOCTEST_REF_WRAP(L) val) { + m_failed = !val; + + if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + m_failed = !m_failed; + + if(m_failed || getContextOptions()->success) + m_decomp = toString(val); + } + + void translateException(); + + bool log(); + void react() const; + }; + + namespace assertAction { + enum Enum + { + nothing = 0, + dbgbreak = 1, + shouldthrow = 2 + }; + } // namespace assertAction + + DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); + + DOCTEST_INTERFACE void decomp_assert(assertType::Enum at, const char* file, int line, + const char* expr, Result result); + +#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ + do { \ + if(!is_running_in_test) { \ + if(failed) { \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + rb.m_decomp = decomp; \ + failed_out_of_a_testing_context(rb); \ + if(isDebuggerActive() && !getContextOptions()->no_breaks) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(checkIfShouldThrow(at)) \ + throwException(); \ + } \ + return; \ + } \ + } while(false) + +#define DOCTEST_ASSERT_IN_TESTS(decomp) \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + if(rb.m_failed || getContextOptions()->success) \ + rb.m_decomp = decomp; \ + if(rb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(rb.m_failed && checkIfShouldThrow(at)) \ + throwException() + + template + DOCTEST_NOINLINE void binary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + bool failed = !RelationalComparator()(lhs, rhs); + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + } + + template + DOCTEST_NOINLINE void unary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) val) { + bool failed = !val; + + if(at & assertType::is_false) //!OCLINT bitwise operator in conditional + failed = !failed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(toString(val)); + DOCTEST_ASSERT_IN_TESTS(toString(val)); + } + + struct DOCTEST_INTERFACE IExceptionTranslator + { + IExceptionTranslator(); + virtual ~IExceptionTranslator(); + virtual bool translate(String&) const = 0; + }; + + template + class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class + { + public: + explicit ExceptionTranslator(String (*translateFunction)(T)) + : m_translateFunction(translateFunction) {} + + bool translate(String& res) const override { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { + throw; // lgtm [cpp/rethrow-no-exception] + // cppcheck-suppress catchExceptionByValue + } catch(T ex) { // NOLINT + res = m_translateFunction(ex); //!OCLINT parameter reassignment + return true; + } catch(...) {} //!OCLINT - empty catch statement +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + ((void)res); // to silence -Wunused-parameter + return false; + } + + private: + String (*m_translateFunction)(T); + }; + + DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); + + template + struct StringStreamBase + { + template + static void convert(std::ostream* s, const T& in) { + *s << toString(in); + } + + // always treat char* as a string in this context - no matter + // if DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING is defined + static void convert(std::ostream* s, const char* in) { *s << String(in); } + }; + + template <> + struct StringStreamBase + { + template + static void convert(std::ostream* s, const T& in) { + *s << in; + } + }; + + template + struct StringStream : public StringStreamBase::value> + {}; + + template + void toStream(std::ostream* s, const T& value) { + StringStream::convert(s, value); + } + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + DOCTEST_INTERFACE void toStream(std::ostream* s, char* in); + DOCTEST_INTERFACE void toStream(std::ostream* s, const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + DOCTEST_INTERFACE void toStream(std::ostream* s, bool in); + DOCTEST_INTERFACE void toStream(std::ostream* s, float in); + DOCTEST_INTERFACE void toStream(std::ostream* s, double in); + DOCTEST_INTERFACE void toStream(std::ostream* s, double long in); + + DOCTEST_INTERFACE void toStream(std::ostream* s, char in); + DOCTEST_INTERFACE void toStream(std::ostream* s, char signed in); + DOCTEST_INTERFACE void toStream(std::ostream* s, char unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int short in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int short unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long long in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long long unsigned in); + + // ContextScope base class used to allow implementing methods of ContextScope + // that don't depend on the template parameter in doctest.cpp. + class DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + protected: + ContextScopeBase(); + + void destroy(); + }; + + template class ContextScope : public ContextScopeBase + { + const L &lambda_; + + public: + explicit ContextScope(const L &lambda) : lambda_(lambda) {} + + ContextScope(ContextScope &&other) : lambda_(other.lambda_) {} + + void stringify(std::ostream* s) const override { lambda_(s); } + + ~ContextScope() override { destroy(); } + }; + + struct DOCTEST_INTERFACE MessageBuilder : public MessageData + { + std::ostream* m_stream; + + MessageBuilder(const char* file, int line, assertType::Enum severity); + MessageBuilder() = delete; + ~MessageBuilder(); + + template + MessageBuilder& operator<<(const T& in) { + toStream(m_stream, in); + return *this; + } + + bool log(); + void react(); + }; + + template + ContextScope MakeContextScope(const L &lambda) { + return ContextScope(lambda); + } +} // namespace detail + +#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ + struct name \ + { \ + type data; \ + name(type in = def) \ + : data(in) {} \ + void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + } + +DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); +DOCTEST_DEFINE_DECORATOR(description, const char*, ""); +DOCTEST_DEFINE_DECORATOR(skip, bool, true); +DOCTEST_DEFINE_DECORATOR(timeout, double, 0); +DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); + +template +int registerExceptionTranslator(String (*translateFunction)(T)) { + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") + static detail::ExceptionTranslator exceptionTranslator(translateFunction); + DOCTEST_CLANG_SUPPRESS_WARNING_POP + detail::registerExceptionTranslatorImpl(&exceptionTranslator); + return 0; +} + +} // namespace doctest + +// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro +// introduces an anonymous namespace in which getCurrentTestSuite gets overridden +namespace doctest_detail_test_suite_ns { +DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +#else // DOCTEST_CONFIG_DISABLE +template +int registerExceptionTranslator(String (*)(T)) { + return 0; +} +#endif // DOCTEST_CONFIG_DISABLE + +namespace detail { + typedef void (*assert_handler)(const AssertData&); + struct ContextState; +} // namespace detail + +class DOCTEST_INTERFACE Context +{ + detail::ContextState* p; + + void parseArgs(int argc, const char* const* argv, bool withDefaults = false); + +public: + explicit Context(int argc = 0, const char* const* argv = nullptr); + + ~Context(); + + void applyCommandLine(int argc, const char* const* argv); + + void addFilter(const char* filter, const char* value); + void clearFilters(); + void setOption(const char* option, int value); + void setOption(const char* option, const char* value); + + bool shouldExit(); + + void setAsDefaultForAssertsOutOfTestCases(); + + void setAssertHandler(detail::assert_handler ah); + + int run(); +}; + +namespace TestCaseFailureReason { + enum Enum + { + None = 0, + AssertFailure = 1, // an assertion has failed in the test case + Exception = 2, // test case threw an exception + Crash = 4, // a crash... + TooManyFailedAsserts = 8, // the abort-after option + Timeout = 16, // see the timeout decorator + ShouldHaveFailedButDidnt = 32, // see the should_fail decorator + ShouldHaveFailedAndDid = 64, // see the should_fail decorator + DidntFailExactlyNumTimes = 128, // see the expected_failures decorator + FailedExactlyNumTimes = 256, // see the expected_failures decorator + CouldHaveFailedAndDid = 512 // see the may_fail decorator + }; +} // namespace TestCaseFailureReason + +struct DOCTEST_INTERFACE CurrentTestCaseStats +{ + int numAssertsCurrentTest; + int numAssertsFailedCurrentTest; + double seconds; + int failure_flags; // use TestCaseFailureReason::Enum +}; + +struct DOCTEST_INTERFACE TestCaseException +{ + String error_string; + bool is_crash; +}; + +struct DOCTEST_INTERFACE TestRunStats +{ + unsigned numTestCases; + unsigned numTestCasesPassingFilters; + unsigned numTestSuitesPassingFilters; + unsigned numTestCasesFailed; + int numAsserts; + int numAssertsFailed; +}; + +struct QueryData +{ + const TestRunStats* run_stats = nullptr; + const TestCaseData** data = nullptr; + unsigned num_data = 0; +}; + +struct DOCTEST_INTERFACE IReporter +{ + // The constructor has to accept "const ContextOptions&" as a single argument + // which has most of the options for the run + a pointer to the stdout stream + // Reporter(const ContextOptions& in) + + // called when a query should be reported (listing test cases, printing the version, etc.) + virtual void report_query(const QueryData&) = 0; + + // called when the whole test run starts + virtual void test_run_start() = 0; + // called when the whole test run ends (caching a pointer to the input doesn't make sense here) + virtual void test_run_end(const TestRunStats&) = 0; + + // called when a test case is started (safe to cache a pointer to the input) + virtual void test_case_start(const TestCaseData&) = 0; + // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) + virtual void test_case_reenter(const TestCaseData&) = 0; + // called when a test case has ended + virtual void test_case_end(const CurrentTestCaseStats&) = 0; + + // called when an exception is thrown from the test case (or it crashes) + virtual void test_case_exception(const TestCaseException&) = 0; + + // called whenever a subcase is entered (don't cache pointers to the input) + virtual void subcase_start(const SubcaseSignature&) = 0; + // called whenever a subcase is exited (don't cache pointers to the input) + virtual void subcase_end() = 0; + + // called for each assert (don't cache pointers to the input) + virtual void log_assert(const AssertData&) = 0; + // called for each message (don't cache pointers to the input) + virtual void log_message(const MessageData&) = 0; + + // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator + // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) + virtual void test_case_skipped(const TestCaseData&) = 0; + + // doctest will not be managing the lifetimes of reporters given to it but this would still be nice to have + virtual ~IReporter(); + + // can obtain all currently active contexts and stringify them if one wishes to do so + static int get_num_active_contexts(); + static const IContextScope* const* get_active_contexts(); + + // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown + static int get_num_stringified_contexts(); + static const String* get_stringified_contexts(); +}; + +namespace detail { + typedef IReporter* (*reporterCreatorFunc)(const ContextOptions&); + + DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); + + template + IReporter* reporterCreator(const ContextOptions& o) { + return new Reporter(o); + } +} // namespace detail + +template +int registerReporter(const char* name, int priority, bool isReporter) { + detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); + return 0; +} +} // namespace doctest + +// if registering is not disabled +#if !defined(DOCTEST_CONFIG_DISABLE) + +// common code in asserts - for convenience +#define DOCTEST_ASSERT_LOG_AND_REACT(b) \ + if(b.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + b.react() + +#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) x; +#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) \ + try { \ + x; \ + } catch(...) { _DOCTEST_RB.translateException(); } +#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ + static_cast(__VA_ARGS__); \ + DOCTEST_GCC_SUPPRESS_WARNING_POP +#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) __VA_ARGS__; +#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS + +// registers the test by initializing a dummy var with a function +#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ + global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::regTest( \ + doctest::detail::TestCase( \ + f, __FILE__, __LINE__, \ + doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ + decorators); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ + namespace { \ + struct der : public base \ + { \ + void f(); \ + }; \ + static void func() { \ + der v; \ + v.f(); \ + } \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ + } \ + inline DOCTEST_NOINLINE void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ + static void f(); \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ + static void f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ + static doctest::detail::funcType proxy() { return f; } \ + DOCTEST_REGISTER_FUNCTION(inline const, proxy(), decorators) \ + static void f() + +// for registering tests +#define DOCTEST_TEST_CASE(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + +// for registering tests in classes - requires C++17 for inline variables! +#if __cplusplus >= 201703L || (DOCTEST_MSVC >= DOCTEST_COMPILER(19, 12, 0) && _MSVC_LANG >= 201703L) +#define DOCTEST_TEST_CASE_CLASS(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_PROXY_), \ + decorators) +#else // DOCTEST_TEST_CASE_CLASS +#define DOCTEST_TEST_CASE_CLASS(...) \ + TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER +#endif // DOCTEST_TEST_CASE_CLASS + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), c, \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_IMPL(...) \ + template <> \ + inline const char* type_to_string<__VA_ARGS__>() { \ + return "<" #__VA_ARGS__ ">"; \ + } +#define DOCTEST_TYPE_TO_STRING(...) \ + namespace doctest { namespace detail { \ + DOCTEST_TYPE_TO_STRING_IMPL(__VA_ARGS__) \ + } \ + } \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ + template \ + static void func(); \ + namespace { \ + template \ + struct iter; \ + template \ + struct iter> \ + { \ + iter(const char* file, unsigned line, int index) { \ + doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ + doctest_detail_test_suite_ns::getCurrentTestSuite(), \ + doctest::detail::type_to_string(), \ + int(line) * 1000 + index) \ + * dec); \ + iter>(file, line, index + 1); \ + } \ + }; \ + template <> \ + struct iter> \ + { \ + iter(const char*, unsigned, int) {} \ + }; \ + } \ + template \ + static void func() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)) + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY)) = \ + doctest::detail::instantiationHelper(DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0));\ + DOCTEST_GLOBAL_NO_WARNINGS_END() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ + template \ + static void anon() + +#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) + +// for subcases +#define DOCTEST_SUBCASE(name) \ + if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ + doctest::detail::Subcase(name, __FILE__, __LINE__)) + +// for grouping tests in test suites by using code blocks +#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ + namespace ns_name { namespace doctest_detail_test_suite_ns { \ + static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() { \ + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ + static doctest::detail::TestSuite data; \ + static bool inited = false; \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + if(!inited) { \ + data* decorators; \ + inited = true; \ + } \ + return data; \ + } \ + } \ + } \ + namespace ns_name + +#define DOCTEST_TEST_SUITE(decorators) \ + DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUITE_)) + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * ""); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for registering exception translators +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ + inline doctest::String translatorName(signature); \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)) = \ + doctest::registerExceptionTranslator(translatorName); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + doctest::String translatorName(signature) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_), \ + signature) + +// for registering reporters +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ + doctest::registerReporter(name, priority, true); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for registering listeners +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ + doctest::registerReporter(name, priority, false); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for logging +#define DOCTEST_INFO(expression) \ + DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), \ + DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), expression) + +#define DOCTEST_INFO_IMPL(lambda_name, mb_name, s_name, expression) \ + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4626) \ + auto lambda_name = [&](std::ostream* s_name) { \ + doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ + mb_name.m_stream = s_name; \ + mb_name << expression; \ + }; \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP \ + auto DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope(lambda_name) + +#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := " << x) + +#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, x) \ + do { \ + doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ + mb << x; \ + DOCTEST_ASSERT_LOG_AND_REACT(mb); \ + } while(false) + +// clang-format off +#define DOCTEST_ADD_MESSAGE_AT(file, line, x) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), x) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, x) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), x) +#define DOCTEST_ADD_FAIL_AT(file, line, x) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), x) +// clang-format on + +#define DOCTEST_MESSAGE(x) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, x) +#define DOCTEST_FAIL_CHECK(x) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, x) +#define DOCTEST_FAIL(x) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, x) + +#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.setResult( \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB) \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + do { \ + DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ + } while(false) + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +// necessary for _MESSAGE +#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::decomp_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) +#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) +#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) +#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) + +// clang-format off +#define DOCTEST_WARN_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } while(false) +#define DOCTEST_CHECK_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } while(false) +#define DOCTEST_REQUIRE_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } while(false) +#define DOCTEST_WARN_FALSE_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } while(false) +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } while(false) +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, msg) do { DOCTEST_INFO(msg); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } while(false) +// clang-format on + +#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ + do { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #expr, #__VA_ARGS__, message); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(const doctest::detail::remove_const< \ + doctest::detail::remove_reference<__VA_ARGS__>::type>::type&) { \ + _DOCTEST_RB.translateException(); \ + _DOCTEST_RB.m_threw_as = true; \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } \ + } while(false) + +#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ + do { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, expr_str, "", __VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } \ + } while(false) + +#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +// clang-format off +#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") +#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") + +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) + +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_WARN_THROWS(expr); } while(false) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_CHECK_THROWS(expr); } while(false) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_REQUIRE_THROWS(expr); } while(false) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_WARN_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_CHECK_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, msg) do { DOCTEST_INFO(msg); DOCTEST_WARN_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, msg) do { DOCTEST_INFO(msg); DOCTEST_CHECK_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, msg) do { DOCTEST_INFO(msg); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) do { DOCTEST_INFO(msg); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_WARN_NOTHROW(expr); } while(false) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_CHECK_NOTHROW(expr); } while(false) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, msg) do { DOCTEST_INFO(msg); DOCTEST_REQUIRE_NOTHROW(expr); } while(false) +// clang-format on + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY( \ + _DOCTEST_RB.binary_assert( \ + __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.unary_assert(__VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ + doctest::detail::binary_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ + #__VA_ARGS__, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) +#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) +#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) +#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) +#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) +#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) +#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) +#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) +#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) +#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) +#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) +#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) +#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) +#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) +#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) +#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) +#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) + +#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS + +#undef DOCTEST_WARN_THROWS +#undef DOCTEST_CHECK_THROWS +#undef DOCTEST_REQUIRE_THROWS +#undef DOCTEST_WARN_THROWS_AS +#undef DOCTEST_CHECK_THROWS_AS +#undef DOCTEST_REQUIRE_THROWS_AS +#undef DOCTEST_WARN_THROWS_WITH +#undef DOCTEST_CHECK_THROWS_WITH +#undef DOCTEST_REQUIRE_THROWS_WITH +#undef DOCTEST_WARN_THROWS_WITH_AS +#undef DOCTEST_CHECK_THROWS_WITH_AS +#undef DOCTEST_REQUIRE_THROWS_WITH_AS +#undef DOCTEST_WARN_NOTHROW +#undef DOCTEST_CHECK_NOTHROW +#undef DOCTEST_REQUIRE_NOTHROW + +#undef DOCTEST_WARN_THROWS_MESSAGE +#undef DOCTEST_CHECK_THROWS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_MESSAGE +#undef DOCTEST_WARN_THROWS_AS_MESSAGE +#undef DOCTEST_CHECK_THROWS_AS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_AS_MESSAGE +#undef DOCTEST_WARN_THROWS_WITH_MESSAGE +#undef DOCTEST_CHECK_THROWS_WITH_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_WITH_MESSAGE +#undef DOCTEST_WARN_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_WARN_NOTHROW_MESSAGE +#undef DOCTEST_CHECK_NOTHROW_MESSAGE +#undef DOCTEST_REQUIRE_NOTHROW_MESSAGE + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#define DOCTEST_WARN_THROWS(...) ((void)0) +#define DOCTEST_CHECK_THROWS(...) ((void)0) +#define DOCTEST_REQUIRE_THROWS(...) ((void)0) +#define DOCTEST_WARN_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_WARN_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_WARN_NOTHROW(...) ((void)0) +#define DOCTEST_CHECK_NOTHROW(...) ((void)0) +#define DOCTEST_REQUIRE_NOTHROW(...) ((void)0) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, msg) ((void)0) + +#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#undef DOCTEST_REQUIRE +#undef DOCTEST_REQUIRE_FALSE +#undef DOCTEST_REQUIRE_MESSAGE +#undef DOCTEST_REQUIRE_FALSE_MESSAGE +#undef DOCTEST_REQUIRE_EQ +#undef DOCTEST_REQUIRE_NE +#undef DOCTEST_REQUIRE_GT +#undef DOCTEST_REQUIRE_LT +#undef DOCTEST_REQUIRE_GE +#undef DOCTEST_REQUIRE_LE +#undef DOCTEST_REQUIRE_UNARY +#undef DOCTEST_REQUIRE_UNARY_FALSE + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// ================================================================================================= +// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == +// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == +// ================================================================================================= +#else // DOCTEST_CONFIG_DISABLE + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ + namespace { \ + template \ + struct der : public base \ + { void f(); }; \ + } \ + template \ + inline void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ + template \ + static inline void f() + +// for registering tests +#define DOCTEST_TEST_CASE(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for registering tests in classes +#define DOCTEST_TEST_CASE_CLASS(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), x, \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING(...) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) +#define DOCTEST_TYPE_TO_STRING_IMPL(...) + +// for typed tests +#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ + template \ + inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ + template \ + inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for subcases +#define DOCTEST_SUBCASE(name) + +// for a testsuite block +#define DOCTEST_TEST_SUITE(name) namespace + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(name) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + template \ + static inline doctest::String DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)(signature) + +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) + +#define DOCTEST_INFO(x) ((void)0) +#define DOCTEST_CAPTURE(x) ((void)0) +#define DOCTEST_ADD_MESSAGE_AT(file, line, x) ((void)0) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, x) ((void)0) +#define DOCTEST_ADD_FAIL_AT(file, line, x) ((void)0) +#define DOCTEST_MESSAGE(x) ((void)0) +#define DOCTEST_FAIL_CHECK(x) ((void)0) +#define DOCTEST_FAIL(x) ((void)0) + +#define DOCTEST_WARN(...) ((void)0) +#define DOCTEST_CHECK(...) ((void)0) +#define DOCTEST_REQUIRE(...) ((void)0) +#define DOCTEST_WARN_FALSE(...) ((void)0) +#define DOCTEST_CHECK_FALSE(...) ((void)0) +#define DOCTEST_REQUIRE_FALSE(...) ((void)0) + +#define DOCTEST_WARN_MESSAGE(cond, msg) ((void)0) +#define DOCTEST_CHECK_MESSAGE(cond, msg) ((void)0) +#define DOCTEST_REQUIRE_MESSAGE(cond, msg) ((void)0) +#define DOCTEST_WARN_FALSE_MESSAGE(cond, msg) ((void)0) +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, msg) ((void)0) +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, msg) ((void)0) + +#define DOCTEST_WARN_THROWS(...) ((void)0) +#define DOCTEST_CHECK_THROWS(...) ((void)0) +#define DOCTEST_REQUIRE_THROWS(...) ((void)0) +#define DOCTEST_WARN_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) ((void)0) +#define DOCTEST_WARN_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) ((void)0) +#define DOCTEST_WARN_NOTHROW(...) ((void)0) +#define DOCTEST_CHECK_NOTHROW(...) ((void)0) +#define DOCTEST_REQUIRE_NOTHROW(...) ((void)0) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, msg) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, msg) ((void)0) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, msg) ((void)0) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, msg) ((void)0) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, msg) ((void)0) + +#define DOCTEST_WARN_EQ(...) ((void)0) +#define DOCTEST_CHECK_EQ(...) ((void)0) +#define DOCTEST_REQUIRE_EQ(...) ((void)0) +#define DOCTEST_WARN_NE(...) ((void)0) +#define DOCTEST_CHECK_NE(...) ((void)0) +#define DOCTEST_REQUIRE_NE(...) ((void)0) +#define DOCTEST_WARN_GT(...) ((void)0) +#define DOCTEST_CHECK_GT(...) ((void)0) +#define DOCTEST_REQUIRE_GT(...) ((void)0) +#define DOCTEST_WARN_LT(...) ((void)0) +#define DOCTEST_CHECK_LT(...) ((void)0) +#define DOCTEST_REQUIRE_LT(...) ((void)0) +#define DOCTEST_WARN_GE(...) ((void)0) +#define DOCTEST_CHECK_GE(...) ((void)0) +#define DOCTEST_REQUIRE_GE(...) ((void)0) +#define DOCTEST_WARN_LE(...) ((void)0) +#define DOCTEST_CHECK_LE(...) ((void)0) +#define DOCTEST_REQUIRE_LE(...) ((void)0) + +#define DOCTEST_WARN_UNARY(...) ((void)0) +#define DOCTEST_CHECK_UNARY(...) ((void)0) +#define DOCTEST_REQUIRE_UNARY(...) ((void)0) +#define DOCTEST_WARN_UNARY_FALSE(...) ((void)0) +#define DOCTEST_CHECK_UNARY_FALSE(...) ((void)0) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) ((void)0) + +#endif // DOCTEST_CONFIG_DISABLE + +// clang-format off +// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS +#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ +#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ +#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ +#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE +#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE +#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE +#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT +#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT +#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT +#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT +#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT +#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT +#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE +#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE +#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE +#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE +#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE +#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE + +#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY +#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY +#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY +#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE +#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE +#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE DOCTEST_TEST_CASE_TEMPLATE_INVOKE +// clang-format on + +// BDD style macros +// clang-format off +#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) +#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) +#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) +#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) + +#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) +#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) +#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) +#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) +#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) +// clang-format on + +// == SHORT VERSIONS OF THE MACROS +#if !defined(DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES) + +#define TEST_CASE DOCTEST_TEST_CASE +#define TEST_CASE_CLASS DOCTEST_TEST_CASE_CLASS +#define TEST_CASE_FIXTURE DOCTEST_TEST_CASE_FIXTURE +#define TYPE_TO_STRING DOCTEST_TYPE_TO_STRING +#define TEST_CASE_TEMPLATE DOCTEST_TEST_CASE_TEMPLATE +#define TEST_CASE_TEMPLATE_DEFINE DOCTEST_TEST_CASE_TEMPLATE_DEFINE +#define TEST_CASE_TEMPLATE_INVOKE DOCTEST_TEST_CASE_TEMPLATE_INVOKE +#define TEST_CASE_TEMPLATE_APPLY DOCTEST_TEST_CASE_TEMPLATE_APPLY +#define SUBCASE DOCTEST_SUBCASE +#define TEST_SUITE DOCTEST_TEST_SUITE +#define TEST_SUITE_BEGIN DOCTEST_TEST_SUITE_BEGIN +#define TEST_SUITE_END DOCTEST_TEST_SUITE_END +#define REGISTER_EXCEPTION_TRANSLATOR DOCTEST_REGISTER_EXCEPTION_TRANSLATOR +#define REGISTER_REPORTER DOCTEST_REGISTER_REPORTER +#define REGISTER_LISTENER DOCTEST_REGISTER_LISTENER +#define INFO DOCTEST_INFO +#define CAPTURE DOCTEST_CAPTURE +#define ADD_MESSAGE_AT DOCTEST_ADD_MESSAGE_AT +#define ADD_FAIL_CHECK_AT DOCTEST_ADD_FAIL_CHECK_AT +#define ADD_FAIL_AT DOCTEST_ADD_FAIL_AT +#define MESSAGE DOCTEST_MESSAGE +#define FAIL_CHECK DOCTEST_FAIL_CHECK +#define FAIL DOCTEST_FAIL +#define TO_LVALUE DOCTEST_TO_LVALUE + +#define WARN DOCTEST_WARN +#define WARN_FALSE DOCTEST_WARN_FALSE +#define WARN_THROWS DOCTEST_WARN_THROWS +#define WARN_THROWS_AS DOCTEST_WARN_THROWS_AS +#define WARN_THROWS_WITH DOCTEST_WARN_THROWS_WITH +#define WARN_THROWS_WITH_AS DOCTEST_WARN_THROWS_WITH_AS +#define WARN_NOTHROW DOCTEST_WARN_NOTHROW +#define CHECK DOCTEST_CHECK +#define CHECK_FALSE DOCTEST_CHECK_FALSE +#define CHECK_THROWS DOCTEST_CHECK_THROWS +#define CHECK_THROWS_AS DOCTEST_CHECK_THROWS_AS +#define CHECK_THROWS_WITH DOCTEST_CHECK_THROWS_WITH +#define CHECK_THROWS_WITH_AS DOCTEST_CHECK_THROWS_WITH_AS +#define CHECK_NOTHROW DOCTEST_CHECK_NOTHROW +#define REQUIRE DOCTEST_REQUIRE +#define REQUIRE_FALSE DOCTEST_REQUIRE_FALSE +#define REQUIRE_THROWS DOCTEST_REQUIRE_THROWS +#define REQUIRE_THROWS_AS DOCTEST_REQUIRE_THROWS_AS +#define REQUIRE_THROWS_WITH DOCTEST_REQUIRE_THROWS_WITH +#define REQUIRE_THROWS_WITH_AS DOCTEST_REQUIRE_THROWS_WITH_AS +#define REQUIRE_NOTHROW DOCTEST_REQUIRE_NOTHROW + +#define WARN_MESSAGE DOCTEST_WARN_MESSAGE +#define WARN_FALSE_MESSAGE DOCTEST_WARN_FALSE_MESSAGE +#define WARN_THROWS_MESSAGE DOCTEST_WARN_THROWS_MESSAGE +#define WARN_THROWS_AS_MESSAGE DOCTEST_WARN_THROWS_AS_MESSAGE +#define WARN_THROWS_WITH_MESSAGE DOCTEST_WARN_THROWS_WITH_MESSAGE +#define WARN_THROWS_WITH_AS_MESSAGE DOCTEST_WARN_THROWS_WITH_AS_MESSAGE +#define WARN_NOTHROW_MESSAGE DOCTEST_WARN_NOTHROW_MESSAGE +#define CHECK_MESSAGE DOCTEST_CHECK_MESSAGE +#define CHECK_FALSE_MESSAGE DOCTEST_CHECK_FALSE_MESSAGE +#define CHECK_THROWS_MESSAGE DOCTEST_CHECK_THROWS_MESSAGE +#define CHECK_THROWS_AS_MESSAGE DOCTEST_CHECK_THROWS_AS_MESSAGE +#define CHECK_THROWS_WITH_MESSAGE DOCTEST_CHECK_THROWS_WITH_MESSAGE +#define CHECK_THROWS_WITH_AS_MESSAGE DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE +#define CHECK_NOTHROW_MESSAGE DOCTEST_CHECK_NOTHROW_MESSAGE +#define REQUIRE_MESSAGE DOCTEST_REQUIRE_MESSAGE +#define REQUIRE_FALSE_MESSAGE DOCTEST_REQUIRE_FALSE_MESSAGE +#define REQUIRE_THROWS_MESSAGE DOCTEST_REQUIRE_THROWS_MESSAGE +#define REQUIRE_THROWS_AS_MESSAGE DOCTEST_REQUIRE_THROWS_AS_MESSAGE +#define REQUIRE_THROWS_WITH_MESSAGE DOCTEST_REQUIRE_THROWS_WITH_MESSAGE +#define REQUIRE_THROWS_WITH_AS_MESSAGE DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE +#define REQUIRE_NOTHROW_MESSAGE DOCTEST_REQUIRE_NOTHROW_MESSAGE + +#define SCENARIO DOCTEST_SCENARIO +#define SCENARIO_CLASS DOCTEST_SCENARIO_CLASS +#define SCENARIO_TEMPLATE DOCTEST_SCENARIO_TEMPLATE +#define SCENARIO_TEMPLATE_DEFINE DOCTEST_SCENARIO_TEMPLATE_DEFINE +#define GIVEN DOCTEST_GIVEN +#define WHEN DOCTEST_WHEN +#define AND_WHEN DOCTEST_AND_WHEN +#define THEN DOCTEST_THEN +#define AND_THEN DOCTEST_AND_THEN + +#define WARN_EQ DOCTEST_WARN_EQ +#define CHECK_EQ DOCTEST_CHECK_EQ +#define REQUIRE_EQ DOCTEST_REQUIRE_EQ +#define WARN_NE DOCTEST_WARN_NE +#define CHECK_NE DOCTEST_CHECK_NE +#define REQUIRE_NE DOCTEST_REQUIRE_NE +#define WARN_GT DOCTEST_WARN_GT +#define CHECK_GT DOCTEST_CHECK_GT +#define REQUIRE_GT DOCTEST_REQUIRE_GT +#define WARN_LT DOCTEST_WARN_LT +#define CHECK_LT DOCTEST_CHECK_LT +#define REQUIRE_LT DOCTEST_REQUIRE_LT +#define WARN_GE DOCTEST_WARN_GE +#define CHECK_GE DOCTEST_CHECK_GE +#define REQUIRE_GE DOCTEST_REQUIRE_GE +#define WARN_LE DOCTEST_WARN_LE +#define CHECK_LE DOCTEST_CHECK_LE +#define REQUIRE_LE DOCTEST_REQUIRE_LE +#define WARN_UNARY DOCTEST_WARN_UNARY +#define CHECK_UNARY DOCTEST_CHECK_UNARY +#define REQUIRE_UNARY DOCTEST_REQUIRE_UNARY +#define WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE +#define CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE +#define REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE + +// KEPT FOR BACKWARDS COMPATIBILITY +#define FAST_WARN_EQ DOCTEST_FAST_WARN_EQ +#define FAST_CHECK_EQ DOCTEST_FAST_CHECK_EQ +#define FAST_REQUIRE_EQ DOCTEST_FAST_REQUIRE_EQ +#define FAST_WARN_NE DOCTEST_FAST_WARN_NE +#define FAST_CHECK_NE DOCTEST_FAST_CHECK_NE +#define FAST_REQUIRE_NE DOCTEST_FAST_REQUIRE_NE +#define FAST_WARN_GT DOCTEST_FAST_WARN_GT +#define FAST_CHECK_GT DOCTEST_FAST_CHECK_GT +#define FAST_REQUIRE_GT DOCTEST_FAST_REQUIRE_GT +#define FAST_WARN_LT DOCTEST_FAST_WARN_LT +#define FAST_CHECK_LT DOCTEST_FAST_CHECK_LT +#define FAST_REQUIRE_LT DOCTEST_FAST_REQUIRE_LT +#define FAST_WARN_GE DOCTEST_FAST_WARN_GE +#define FAST_CHECK_GE DOCTEST_FAST_CHECK_GE +#define FAST_REQUIRE_GE DOCTEST_FAST_REQUIRE_GE +#define FAST_WARN_LE DOCTEST_FAST_WARN_LE +#define FAST_CHECK_LE DOCTEST_FAST_CHECK_LE +#define FAST_REQUIRE_LE DOCTEST_FAST_REQUIRE_LE + +#define FAST_WARN_UNARY DOCTEST_FAST_WARN_UNARY +#define FAST_CHECK_UNARY DOCTEST_FAST_CHECK_UNARY +#define FAST_REQUIRE_UNARY DOCTEST_FAST_REQUIRE_UNARY +#define FAST_WARN_UNARY_FALSE DOCTEST_FAST_WARN_UNARY_FALSE +#define FAST_CHECK_UNARY_FALSE DOCTEST_FAST_CHECK_UNARY_FALSE +#define FAST_REQUIRE_UNARY_FALSE DOCTEST_FAST_REQUIRE_UNARY_FALSE + +#define TEST_CASE_TEMPLATE_INSTANTIATE DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE + +#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#if !defined(DOCTEST_CONFIG_DISABLE) + +// this is here to clear the 'current test suite' for the current translation unit - at the top +DOCTEST_TEST_SUITE_END(); + +// add stringification for primitive/fundamental types +namespace doctest { namespace detail { + DOCTEST_TYPE_TO_STRING_IMPL(bool) + DOCTEST_TYPE_TO_STRING_IMPL(float) + DOCTEST_TYPE_TO_STRING_IMPL(double) + DOCTEST_TYPE_TO_STRING_IMPL(long double) + DOCTEST_TYPE_TO_STRING_IMPL(char) + DOCTEST_TYPE_TO_STRING_IMPL(signed char) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned char) +#if !DOCTEST_MSVC || defined(_NATIVE_WCHAR_T_DEFINED) + DOCTEST_TYPE_TO_STRING_IMPL(wchar_t) +#endif // not MSVC or wchar_t support enabled + DOCTEST_TYPE_TO_STRING_IMPL(short int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned short int) + DOCTEST_TYPE_TO_STRING_IMPL(int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned int) + DOCTEST_TYPE_TO_STRING_IMPL(long int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned long int) + DOCTEST_TYPE_TO_STRING_IMPL(long long int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned long long int) +}} // namespace doctest::detail + +#endif // DOCTEST_CONFIG_DISABLE + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_LIBRARY_INCLUDED + +#ifndef DOCTEST_SINGLE_HEADER +#define DOCTEST_SINGLE_HEADER +#endif // DOCTEST_SINGLE_HEADER + +#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) + +#ifndef DOCTEST_SINGLE_HEADER +#include "doctest_fwd.h" +#endif // DOCTEST_SINGLE_HEADER + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") + +#ifndef DOCTEST_LIBRARY_IMPLEMENTATION +#define DOCTEST_LIBRARY_IMPLEMENTATION + +DOCTEST_CLANG_SUPPRESS_WARNING_POP + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") +DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration +DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data +DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression +DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated +DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant +DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled +DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified +DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal +DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch +DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding in structs +DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe +DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C +DOCTEST_MSVC_SUPPRESS_WARNING(5045) // Spectre mitigation stuff +DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) +// static analysis +DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' +DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable +DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... +DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtor... +DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN + +// required includes - will go only in one translation unit! +#include +#include +#include +// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/onqtam/doctest/pull/37 +#ifdef __BORLANDC__ +#include +#endif // __BORLANDC__ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef DOCTEST_CONFIG_POSIX_SIGNALS +#include +#endif // DOCTEST_CONFIG_POSIX_SIGNALS +#include +#include +#include + +#ifdef DOCTEST_PLATFORM_MAC +#include +#include +#include +#endif // DOCTEST_PLATFORM_MAC + +#ifdef DOCTEST_PLATFORM_WINDOWS + +// defines for a leaner windows.h +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +// not sure what AfxWin.h is for - here I do what Catch does +#ifdef __AFXDLL +#include +#else +#if defined(__MINGW32__) || defined(__MINGW64__) +#include +#else // MINGW +#include +#endif // MINGW +#endif +#include + +#else // DOCTEST_PLATFORM_WINDOWS + +#include +#include + +#endif // DOCTEST_PLATFORM_WINDOWS + +// this is a fix for https://github.com/onqtam/doctest/issues/348 +// https://mail.gnome.org/archives/xml/2012-January/msg00000.html +#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) +#define STDOUT_FILENO fileno(stdout) +#endif // HAVE_UNISTD_H + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END + +// counts the number of elements in a C array +#define DOCTEST_COUNTOF(x) (sizeof(x) / sizeof(x[0])) + +#ifdef DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled +#else // DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled +#endif // DOCTEST_CONFIG_DISABLE + +#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX +#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" +#endif + +#ifndef DOCTEST_THREAD_LOCAL +#define DOCTEST_THREAD_LOCAL thread_local +#endif + +#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS +#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX +#else +#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" +#endif + +namespace doctest { + +bool is_running_in_test = false; + +namespace { + using namespace detail; + // case insensitive strcmp + int stricmp(const char* a, const char* b) { + for(;; a++, b++) { + const int d = tolower(*a) - tolower(*b); + if(d != 0 || !*a) + return d; + } + } + + template + String fpToString(T value, int precision) { + std::ostringstream oss; + oss << std::setprecision(precision) << std::fixed << value; + std::string d = oss.str(); + size_t i = d.find_last_not_of('0'); + if(i != std::string::npos && i != d.size() - 1) { + if(d[i] == '.') + i++; + d = d.substr(0, i + 1); + } + return d.c_str(); + } + + struct Endianness + { + enum Arch + { + Big, + Little + }; + + static Arch which() { + int x = 1; + // casting any data pointer to char* is allowed + auto ptr = reinterpret_cast(&x); + if(*ptr) + return Little; + return Big; + } + }; +} // namespace + +namespace detail { + void my_memcpy(void* dest, const void* src, unsigned num) { memcpy(dest, src, num); } + + String rawMemoryToString(const void* object, unsigned size) { + // Reverse order for little endian architectures + int i = 0, end = static_cast(size), inc = 1; + if(Endianness::which() == Endianness::Little) { + i = end - 1; + end = inc = -1; + } + + unsigned const char* bytes = static_cast(object); + std::ostringstream oss; + oss << "0x" << std::setfill('0') << std::hex; + for(; i != end; i += inc) + oss << std::setw(2) << static_cast(bytes[i]); + return oss.str().c_str(); + } + + DOCTEST_THREAD_LOCAL std::ostringstream g_oss; // NOLINT(cert-err58-cpp) + + std::ostream* getTlsOss() { + g_oss.clear(); // there shouldn't be anything worth clearing in the flags + g_oss.str(""); // the slow way of resetting a string stream + //g_oss.seekp(0); // optimal reset - as seen here: https://stackoverflow.com/a/624291/3162383 + return &g_oss; + } + + String getTlsOssResult() { + //g_oss << std::ends; // needed - as shown here: https://stackoverflow.com/a/624291/3162383 + return g_oss.str().c_str(); + } + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace timer_large_integer +{ + +#if defined(DOCTEST_PLATFORM_WINDOWS) + typedef ULONGLONG type; +#else // DOCTEST_PLATFORM_WINDOWS + using namespace std; + typedef uint64_t type; +#endif // DOCTEST_PLATFORM_WINDOWS +} + +typedef timer_large_integer::type ticks_t; + +#ifdef DOCTEST_CONFIG_GETCURRENTTICKS + ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } +#elif defined(DOCTEST_PLATFORM_WINDOWS) + ticks_t getCurrentTicks() { + static LARGE_INTEGER hz = {0}, hzo = {0}; + if(!hz.QuadPart) { + QueryPerformanceFrequency(&hz); + QueryPerformanceCounter(&hzo); + } + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; + } +#else // DOCTEST_PLATFORM_WINDOWS + ticks_t getCurrentTicks() { + timeval t; + gettimeofday(&t, nullptr); + return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); + } +#endif // DOCTEST_PLATFORM_WINDOWS + + struct Timer + { + void start() { m_ticks = getCurrentTicks(); } + unsigned int getElapsedMicroseconds() const { + return static_cast(getCurrentTicks() - m_ticks); + } + //unsigned int getElapsedMilliseconds() const { + // return static_cast(getElapsedMicroseconds() / 1000); + //} + double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } + + private: + ticks_t m_ticks = 0; + }; + + // this holds both parameters from the command line and runtime data for tests + struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats + { + std::atomic numAssertsCurrentTest_atomic; + std::atomic numAssertsFailedCurrentTest_atomic; + + std::vector> filters = decltype(filters)(9); // 9 different filters + + std::vector reporters_currently_used; + + const TestCase* currentTest = nullptr; + + assert_handler ah = nullptr; + + Timer timer; + + std::vector stringifiedContexts; // logging from INFO() due to an exception + + // stuff for subcases + std::vector subcasesStack; + std::set subcasesPassed; + int subcasesCurrentMaxLevel; + bool should_reenter; + std::atomic shouldLogCurrentException; + + void resetRunData() { + numTestCases = 0; + numTestCasesPassingFilters = 0; + numTestSuitesPassingFilters = 0; + numTestCasesFailed = 0; + numAsserts = 0; + numAssertsFailed = 0; + numAssertsCurrentTest = 0; + numAssertsFailedCurrentTest = 0; + } + + void finalizeTestCaseData() { + seconds = timer.getElapsedSeconds(); + + // update the non-atomic counters + numAsserts += numAssertsCurrentTest_atomic; + numAssertsFailed += numAssertsFailedCurrentTest_atomic; + numAssertsCurrentTest = numAssertsCurrentTest_atomic; + numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; + + if(numAssertsFailedCurrentTest) + failure_flags |= TestCaseFailureReason::AssertFailure; + + if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && + Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) + failure_flags |= TestCaseFailureReason::Timeout; + + if(currentTest->m_should_fail) { + if(failure_flags) { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; + } else { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; + } + } else if(failure_flags && currentTest->m_may_fail) { + failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; + } else if(currentTest->m_expected_failures > 0) { + if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { + failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; + } else { + failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; + } + } + + bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); + + // if any subcase has failed - the whole test case has failed + if(failure_flags && !ok_to_fail) + numTestCasesFailed++; + } + }; + + ContextState* g_cs = nullptr; + + // used to avoid locks for the debug output + // TODO: figure out if this is indeed necessary/correct - seems like either there still + // could be a race or that there wouldn't be a race even if using the context directly + DOCTEST_THREAD_LOCAL bool g_no_colors; + +#endif // DOCTEST_CONFIG_DISABLE +} // namespace detail + +void String::setOnHeap() { *reinterpret_cast(&buf[last]) = 128; } +void String::setLast(unsigned in) { buf[last] = char(in); } + +void String::copy(const String& other) { + using namespace std; + if(other.isOnStack()) { + memcpy(buf, other.buf, len); + } else { + setOnHeap(); + data.size = other.data.size; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + memcpy(data.ptr, other.data.ptr, data.size + 1); + } +} + +String::String() { + buf[0] = '\0'; + setLast(); +} + +String::~String() { + if(!isOnStack()) + delete[] data.ptr; +} + +String::String(const char* in) + : String(in, strlen(in)) {} + +String::String(const char* in, unsigned in_size) { + using namespace std; + if(in_size <= last) { + memcpy(buf, in, in_size + 1); + setLast(last - in_size); + } else { + setOnHeap(); + data.size = in_size; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + memcpy(data.ptr, in, in_size + 1); + } +} + +String::String(const String& other) { copy(other); } + +String& String::operator=(const String& other) { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + + copy(other); + } + + return *this; +} + +String& String::operator+=(const String& other) { + const unsigned my_old_size = size(); + const unsigned other_size = other.size(); + const unsigned total_size = my_old_size + other_size; + using namespace std; + if(isOnStack()) { + if(total_size < len) { + // append to the current stack space + memcpy(buf + my_old_size, other.c_str(), other_size + 1); + setLast(last - total_size); + } else { + // alloc new chunk + char* temp = new char[total_size + 1]; + // copy current data to new location before writing in the union + memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed + // update data in union + setOnHeap(); + data.size = total_size; + data.capacity = data.size + 1; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } else { + if(data.capacity > total_size) { + // append to the current heap block + data.size = total_size; + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } else { + // resize + data.capacity *= 2; + if(data.capacity <= total_size) + data.capacity = total_size + 1; + // alloc new chunk + char* temp = new char[data.capacity]; + // copy current data to new location before releasing it + memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed + // release old chunk + delete[] data.ptr; + // update the rest of the union members + data.size = total_size; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } + + return *this; +} + +String String::operator+(const String& other) const { return String(*this) += other; } + +String::String(String&& other) { + using namespace std; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); +} + +String& String::operator=(String&& other) { + using namespace std; + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); + } + return *this; +} + +char String::operator[](unsigned i) const { + return const_cast(this)->operator[](i); // NOLINT +} + +char& String::operator[](unsigned i) { + if(isOnStack()) + return reinterpret_cast(buf)[i]; + return data.ptr[i]; +} + +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") +unsigned String::size() const { + if(isOnStack()) + return last - (unsigned(buf[last]) & 31); // using "last" would work only if "len" is 32 + return data.size; +} +DOCTEST_GCC_SUPPRESS_WARNING_POP + +unsigned String::capacity() const { + if(isOnStack()) + return len; + return data.capacity; +} + +int String::compare(const char* other, bool no_case) const { + if(no_case) + return doctest::stricmp(c_str(), other); + return std::strcmp(c_str(), other); +} + +int String::compare(const String& other, bool no_case) const { + return compare(other.c_str(), no_case); +} + +// clang-format off +bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } +bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } +bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } +bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } +bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } +bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) > 0 : true; } +// clang-format on + +std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } + +namespace { + void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) +} // namespace + +namespace Color { + std::ostream& operator<<(std::ostream& s, Color::Enum code) { + color_to_stream(s, code); + return s; + } +} // namespace Color + +// clang-format off +const char* assertString(assertType::Enum at) { + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4062) // enum 'x' in switch of enum 'y' is not handled + switch(at) { //!OCLINT missing default in switch statements + case assertType::DT_WARN : return "WARN"; + case assertType::DT_CHECK : return "CHECK"; + case assertType::DT_REQUIRE : return "REQUIRE"; + + case assertType::DT_WARN_FALSE : return "WARN_FALSE"; + case assertType::DT_CHECK_FALSE : return "CHECK_FALSE"; + case assertType::DT_REQUIRE_FALSE : return "REQUIRE_FALSE"; + + case assertType::DT_WARN_THROWS : return "WARN_THROWS"; + case assertType::DT_CHECK_THROWS : return "CHECK_THROWS"; + case assertType::DT_REQUIRE_THROWS : return "REQUIRE_THROWS"; + + case assertType::DT_WARN_THROWS_AS : return "WARN_THROWS_AS"; + case assertType::DT_CHECK_THROWS_AS : return "CHECK_THROWS_AS"; + case assertType::DT_REQUIRE_THROWS_AS : return "REQUIRE_THROWS_AS"; + + case assertType::DT_WARN_THROWS_WITH : return "WARN_THROWS_WITH"; + case assertType::DT_CHECK_THROWS_WITH : return "CHECK_THROWS_WITH"; + case assertType::DT_REQUIRE_THROWS_WITH : return "REQUIRE_THROWS_WITH"; + + case assertType::DT_WARN_THROWS_WITH_AS : return "WARN_THROWS_WITH_AS"; + case assertType::DT_CHECK_THROWS_WITH_AS : return "CHECK_THROWS_WITH_AS"; + case assertType::DT_REQUIRE_THROWS_WITH_AS : return "REQUIRE_THROWS_WITH_AS"; + + case assertType::DT_WARN_NOTHROW : return "WARN_NOTHROW"; + case assertType::DT_CHECK_NOTHROW : return "CHECK_NOTHROW"; + case assertType::DT_REQUIRE_NOTHROW : return "REQUIRE_NOTHROW"; + + case assertType::DT_WARN_EQ : return "WARN_EQ"; + case assertType::DT_CHECK_EQ : return "CHECK_EQ"; + case assertType::DT_REQUIRE_EQ : return "REQUIRE_EQ"; + case assertType::DT_WARN_NE : return "WARN_NE"; + case assertType::DT_CHECK_NE : return "CHECK_NE"; + case assertType::DT_REQUIRE_NE : return "REQUIRE_NE"; + case assertType::DT_WARN_GT : return "WARN_GT"; + case assertType::DT_CHECK_GT : return "CHECK_GT"; + case assertType::DT_REQUIRE_GT : return "REQUIRE_GT"; + case assertType::DT_WARN_LT : return "WARN_LT"; + case assertType::DT_CHECK_LT : return "CHECK_LT"; + case assertType::DT_REQUIRE_LT : return "REQUIRE_LT"; + case assertType::DT_WARN_GE : return "WARN_GE"; + case assertType::DT_CHECK_GE : return "CHECK_GE"; + case assertType::DT_REQUIRE_GE : return "REQUIRE_GE"; + case assertType::DT_WARN_LE : return "WARN_LE"; + case assertType::DT_CHECK_LE : return "CHECK_LE"; + case assertType::DT_REQUIRE_LE : return "REQUIRE_LE"; + + case assertType::DT_WARN_UNARY : return "WARN_UNARY"; + case assertType::DT_CHECK_UNARY : return "CHECK_UNARY"; + case assertType::DT_REQUIRE_UNARY : return "REQUIRE_UNARY"; + case assertType::DT_WARN_UNARY_FALSE : return "WARN_UNARY_FALSE"; + case assertType::DT_CHECK_UNARY_FALSE : return "CHECK_UNARY_FALSE"; + case assertType::DT_REQUIRE_UNARY_FALSE : return "REQUIRE_UNARY_FALSE"; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + return ""; +} +// clang-format on + +const char* failureString(assertType::Enum at) { + if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional + return "WARNING"; + if(at & assertType::is_check) //!OCLINT bitwise operator in conditional + return "ERROR"; + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return "FATAL ERROR"; + return ""; +} + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +// depending on the current options this will remove the path of filenames +const char* skipPathFromFilename(const char* file) { + if(getContextOptions()->no_path_in_filenames) { + auto back = std::strrchr(file, '\\'); + auto forward = std::strrchr(file, '/'); + if(back || forward) { + if(back > forward) + forward = back; + return forward + 1; + } + } + return file; +} +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +bool SubcaseSignature::operator<(const SubcaseSignature& other) const { + if(m_line != other.m_line) + return m_line < other.m_line; + if(std::strcmp(m_file, other.m_file) != 0) + return std::strcmp(m_file, other.m_file) < 0; + return m_name.compare(other.m_name) < 0; +} + +IContextScope::IContextScope() = default; +IContextScope::~IContextScope() = default; + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(char* in) { return toString(static_cast(in)); } +String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(bool in) { return in ? "true" : "false"; } +String toString(float in) { return fpToString(in, 5) + "f"; } +String toString(double in) { return fpToString(in, 10); } +String toString(double long in) { return fpToString(in, 15); } + +#define DOCTEST_TO_STRING_OVERLOAD(type, fmt) \ + String toString(type in) { \ + char buf[64]; \ + std::sprintf(buf, fmt, in); \ + return buf; \ + } + +DOCTEST_TO_STRING_OVERLOAD(char, "%d") +DOCTEST_TO_STRING_OVERLOAD(char signed, "%d") +DOCTEST_TO_STRING_OVERLOAD(char unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int short, "%d") +DOCTEST_TO_STRING_OVERLOAD(int short unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int, "%d") +DOCTEST_TO_STRING_OVERLOAD(unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int long, "%ld") +DOCTEST_TO_STRING_OVERLOAD(int long unsigned, "%lu") +DOCTEST_TO_STRING_OVERLOAD(int long long, "%lld") +DOCTEST_TO_STRING_OVERLOAD(int long long unsigned, "%llu") + +String toString(std::nullptr_t) { return "NULL"; } + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +String toString(const std::string& in) { return in.c_str(); } +#endif // VS 2019 + +Approx::Approx(double value) + : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) + , m_scale(1.0) + , m_value(value) {} + +Approx Approx::operator()(double value) const { + Approx approx(value); + approx.epsilon(m_epsilon); + approx.scale(m_scale); + return approx; +} + +Approx& Approx::epsilon(double newEpsilon) { + m_epsilon = newEpsilon; + return *this; +} +Approx& Approx::scale(double newScale) { + m_scale = newScale; + return *this; +} + +bool operator==(double lhs, const Approx& rhs) { + // Thanks to Richard Harris for his help refining this formula + return std::fabs(lhs - rhs.m_value) < + rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); +} +bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } +bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } +bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } +bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } +bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } +bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } +bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } +bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } +bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } +bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } +bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } + +String toString(const Approx& in) { + return String("Approx( ") + doctest::toString(in.m_value) + " )"; +} +const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } + +} // namespace doctest + +#ifdef DOCTEST_CONFIG_DISABLE +namespace doctest { +Context::Context(int, const char* const*) {} +Context::~Context() = default; +void Context::applyCommandLine(int, const char* const*) {} +void Context::addFilter(const char*, const char*) {} +void Context::clearFilters() {} +void Context::setOption(const char*, int) {} +void Context::setOption(const char*, const char*) {} +bool Context::shouldExit() { return false; } +void Context::setAsDefaultForAssertsOutOfTestCases() {} +void Context::setAssertHandler(detail::assert_handler) {} +int Context::run() { return 0; } + +IReporter::~IReporter() = default; + +int IReporter::get_num_active_contexts() { return 0; } +const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } +int IReporter::get_num_stringified_contexts() { return 0; } +const String* IReporter::get_stringified_contexts() { return nullptr; } + +int registerReporter(const char*, int, IReporter*) { return 0; } + +} // namespace doctest +#else // DOCTEST_CONFIG_DISABLE + +#if !defined(DOCTEST_CONFIG_COLORS_NONE) +#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_CONFIG_COLORS_WINDOWS +#else // linux +#define DOCTEST_CONFIG_COLORS_ANSI +#endif // platform +#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI +#endif // DOCTEST_CONFIG_COLORS_NONE + +namespace doctest_detail_test_suite_ns { +// holds the current test suite +doctest::detail::TestSuite& getCurrentTestSuite() { + static doctest::detail::TestSuite data; + return data; +} +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +namespace { + // the int (priority) is part of the key for automatic sorting - sadly one can register a + // reporter with a duplicate name and a different priority but hopefully that won't happen often :| + typedef std::map, reporterCreatorFunc> reporterMap; + + reporterMap& getReporters() { + static reporterMap data; + return data; + } + reporterMap& getListeners() { + static reporterMap data; + return data; + } +} // namespace +namespace detail { +#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ + for(auto& curr_rep : g_cs->reporters_currently_used) \ + curr_rep->function(__VA_ARGS__) + + bool checkIfShouldThrow(assertType::Enum at) { + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return true; + + if((at & assertType::is_check) //!OCLINT bitwise operator in conditional + && getContextOptions()->abort_after > 0 && + (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= + getContextOptions()->abort_after) + return true; + + return false; + } + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN void throwException() { + g_cs->shouldLogCurrentException = false; + throw TestFailureException(); + } // NOLINT(cert-err60-cpp) +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + void throwException() {} +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +} // namespace detail + +namespace { + using namespace detail; + // matching of a string against a wildcard mask (case sensitivity configurable) taken from + // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing + int wildcmp(const char* str, const char* wild, bool caseSensitive) { + const char* cp = str; + const char* mp = wild; + + while((*str) && (*wild != '*')) { + if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && + (*wild != '?')) { + return 0; + } + wild++; + str++; + } + + while(*str) { + if(*wild == '*') { + if(!*++wild) { + return 1; + } + mp = wild; + cp = str + 1; + } else if((caseSensitive ? (*wild == *str) : (tolower(*wild) == tolower(*str))) || + (*wild == '?')) { + wild++; + str++; + } else { + wild = mp; //!OCLINT parameter reassignment + str = cp++; //!OCLINT parameter reassignment + } + } + + while(*wild == '*') { + wild++; + } + return !*wild; + } + + //// C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html + //unsigned hashStr(unsigned const char* str) { + // unsigned long hash = 5381; + // char c; + // while((c = *str++)) + // hash = ((hash << 5) + hash) + c; // hash * 33 + c + // return hash; + //} + + // checks if the name matches any of the filters (and can be configured what to do when empty) + bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, + bool caseSensitive) { + if(filters.empty() && matchEmpty) + return true; + for(auto& curr : filters) + if(wildcmp(name, curr.c_str(), caseSensitive)) + return true; + return false; + } +} // namespace +namespace detail { + + Subcase::Subcase(const String& name, const char* file, int line) + : m_signature({name, file, line}) { + ContextState* s = g_cs; + + // check subcase filters + if(s->subcasesStack.size() < size_t(s->subcase_filter_levels)) { + if(!matchesAny(m_signature.m_name.c_str(), s->filters[6], true, s->case_sensitive)) + return; + if(matchesAny(m_signature.m_name.c_str(), s->filters[7], false, s->case_sensitive)) + return; + } + + // if a Subcase on the same level has already been entered + if(s->subcasesStack.size() < size_t(s->subcasesCurrentMaxLevel)) { + s->should_reenter = true; + return; + } + + // push the current signature to the stack so we can check if the + // current stack + the current new subcase have been traversed + s->subcasesStack.push_back(m_signature); + if(s->subcasesPassed.count(s->subcasesStack) != 0) { + // pop - revert to previous stack since we've already passed this + s->subcasesStack.pop_back(); + return; + } + + s->subcasesCurrentMaxLevel = s->subcasesStack.size(); + m_entered = true; + + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + Subcase::~Subcase() { + if(m_entered) { + // only mark the subcase stack as passed if no subcases have been skipped + if(g_cs->should_reenter == false) + g_cs->subcasesPassed.insert(g_cs->subcasesStack); + g_cs->subcasesStack.pop_back(); + +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L + if(std::uncaught_exceptions() > 0 +#else + if(std::uncaught_exception() +#endif + && g_cs->shouldLogCurrentException) { + DOCTEST_ITERATE_THROUGH_REPORTERS( + test_case_exception, {"exception thrown in subcase - will translate later " + "when the whole test case has been exited (cannot " + "translate while there is an active exception)", + false}); + g_cs->shouldLogCurrentException = false; + } + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + Subcase::operator bool() const { return m_entered; } + + Result::Result(bool passed, const String& decomposition) + : m_passed(passed) + , m_decomp(decomposition) {} + + ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) + : m_at(at) {} + + TestSuite& TestSuite::operator*(const char* in) { + m_test_suite = in; + // clear state + m_description = nullptr; + m_skip = false; + m_may_fail = false; + m_should_fail = false; + m_expected_failures = 0; + m_timeout = 0; + return *this; + } + + TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const char* type, int template_id) { + m_file = file; + m_line = line; + m_name = nullptr; // will be later overridden in operator* + m_test_suite = test_suite.m_test_suite; + m_description = test_suite.m_description; + m_skip = test_suite.m_skip; + m_may_fail = test_suite.m_may_fail; + m_should_fail = test_suite.m_should_fail; + m_expected_failures = test_suite.m_expected_failures; + m_timeout = test_suite.m_timeout; + + m_test = test; + m_type = type; + m_template_id = template_id; + } + + TestCase::TestCase(const TestCase& other) + : TestCaseData() { + *this = other; + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + DOCTEST_MSVC_SUPPRESS_WARNING(26437) // Do not slice + TestCase& TestCase::operator=(const TestCase& other) { + static_cast(*this) = static_cast(other); + + m_test = other.m_test; + m_type = other.m_type; + m_template_id = other.m_template_id; + m_full_name = other.m_full_name; + + if(m_template_id != -1) + m_name = m_full_name.c_str(); + return *this; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& TestCase::operator*(const char* in) { + m_name = in; + // make a new name with an appended type for templated test case + if(m_template_id != -1) { + m_full_name = String(m_name) + m_type; + // redirect the name to point to the newly constructed full name + m_name = m_full_name.c_str(); + } + return *this; + } + + bool TestCase::operator<(const TestCase& other) const { + if(m_line != other.m_line) + return m_line < other.m_line; + const int file_cmp = m_file.compare(other.m_file); + if(file_cmp != 0) + return file_cmp < 0; + return m_template_id < other.m_template_id; + } +} // namespace detail +namespace { + using namespace detail; + // for sorting tests by file/line + bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { + // this is needed because MSVC gives different case for drive letters + // for __FILE__ when evaluated in a header and a source file + const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); + if(res != 0) + return res < 0; + if(lhs->m_line != rhs->m_line) + return lhs->m_line < rhs->m_line; + return lhs->m_template_id < rhs->m_template_id; + } + + // for sorting tests by suite/file/line + bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_test_suite, rhs->m_test_suite); + if(res != 0) + return res < 0; + return fileOrderComparator(lhs, rhs); + } + + // for sorting tests by name/suite/file/line + bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_name, rhs->m_name); + if(res != 0) + return res < 0; + return suiteOrderComparator(lhs, rhs); + } + + // all the registered tests + std::set& getRegisteredTests() { + static std::set data; + return data; + } + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + HANDLE g_stdoutHandle; + WORD g_origFgAttrs; + WORD g_origBgAttrs; + bool g_attrsInitted = false; + + int colors_init() { + if(!g_attrsInitted) { + g_stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); + g_attrsInitted = true; + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(g_stdoutHandle, &csbiInfo); + g_origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + g_origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_BLUE | FOREGROUND_INTENSITY); + } + return 0; + } + + int dumy_init_console_colors = colors_init(); +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + void color_to_stream(std::ostream& s, Color::Enum code) { + ((void)s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS + ((void)code); // for DOCTEST_CONFIG_COLORS_NONE +#ifdef DOCTEST_CONFIG_COLORS_ANSI + if(g_no_colors || + (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) + return; + + auto col = ""; + // clang-format off + switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement + case Color::Red: col = "[0;31m"; break; + case Color::Green: col = "[0;32m"; break; + case Color::Blue: col = "[0;34m"; break; + case Color::Cyan: col = "[0;36m"; break; + case Color::Yellow: col = "[0;33m"; break; + case Color::Grey: col = "[1;30m"; break; + case Color::LightGrey: col = "[0;37m"; break; + case Color::BrightRed: col = "[1;31m"; break; + case Color::BrightGreen: col = "[1;32m"; break; + case Color::BrightWhite: col = "[1;37m"; break; + case Color::Bright: // invalid + case Color::None: + case Color::White: + default: col = "[0m"; + } + // clang-format on + s << "\033" << col; +#endif // DOCTEST_CONFIG_COLORS_ANSI + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + if(g_no_colors || + (isatty(fileno(stdout)) == false && getContextOptions()->force_colors == false)) + return; + +#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(g_stdoutHandle, x | g_origBgAttrs) + + // clang-format off + switch (code) { + case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; + case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; + case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; + case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; + case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; + case Color::Grey: DOCTEST_SET_ATTR(0); break; + case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; + case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; + case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; + case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::None: + case Color::Bright: // invalid + default: DOCTEST_SET_ATTR(g_origFgAttrs); + } + // clang-format on +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + + std::vector& getExceptionTranslators() { + static std::vector data; + return data; + } + + String translateActiveException() { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + String res; + auto& translators = getExceptionTranslators(); + for(auto& curr : translators) + if(curr->translate(res)) + return res; + // clang-format off + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") + try { + throw; + } catch(std::exception& ex) { + return ex.what(); + } catch(std::string& msg) { + return msg.c_str(); + } catch(const char* msg) { + return msg; + } catch(...) { + return "unknown exception"; + } + DOCTEST_GCC_SUPPRESS_WARNING_POP +// clang-format on +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + return ""; +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } +} // namespace + +namespace detail { + // used by the macros for registering tests + int regTest(const TestCase& tc) { + getRegisteredTests().insert(tc); + return 0; + } + + // sets the current test suite + int setTestSuite(const TestSuite& ts) { + doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; + return 0; + } + +#ifdef DOCTEST_IS_DEBUGGER_ACTIVE + bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } +#else // DOCTEST_IS_DEBUGGER_ACTIVE +#ifdef DOCTEST_PLATFORM_MAC + // The following function is taken directly from the following technical note: + // https://developer.apple.com/library/archive/qa/qa1361/_index.html + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive() { + int mib[4]; + kinfo_proc info; + size_t size; + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + info.kp_proc.p_flag = 0; + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + // Call sysctl. + size = sizeof(info); + if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { + std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; + return false; + } + // We're being debugged if the P_TRACED flag is set. + return ((info.kp_proc.p_flag & P_TRACED) != 0); + } +#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) + bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } +#else + bool isDebuggerActive() { return false; } +#endif // Platform +#endif // DOCTEST_IS_DEBUGGER_ACTIVE + + void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { + if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == + getExceptionTranslators().end()) + getExceptionTranslators().push_back(et); + } + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + void toStream(std::ostream* s, char* in) { *s << in; } + void toStream(std::ostream* s, const char* in) { *s << in; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + void toStream(std::ostream* s, bool in) { *s << std::boolalpha << in << std::noboolalpha; } + void toStream(std::ostream* s, float in) { *s << in; } + void toStream(std::ostream* s, double in) { *s << in; } + void toStream(std::ostream* s, double long in) { *s << in; } + + void toStream(std::ostream* s, char in) { *s << in; } + void toStream(std::ostream* s, char signed in) { *s << in; } + void toStream(std::ostream* s, char unsigned in) { *s << in; } + void toStream(std::ostream* s, int short in) { *s << in; } + void toStream(std::ostream* s, int short unsigned in) { *s << in; } + void toStream(std::ostream* s, int in) { *s << in; } + void toStream(std::ostream* s, int unsigned in) { *s << in; } + void toStream(std::ostream* s, int long in) { *s << in; } + void toStream(std::ostream* s, int long unsigned in) { *s << in; } + void toStream(std::ostream* s, int long long in) { *s << in; } + void toStream(std::ostream* s, int long long unsigned in) { *s << in; } + + DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() + + ContextScopeBase::ContextScopeBase() { + g_infoContexts.push_back(this); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + // destroy cannot be inlined into the destructor because that would mean calling stringify after + // ContextScope has been destroyed (base class destructors run after derived class destructors). + // Instead, ContextScope calls this method directly from its destructor. + void ContextScopeBase::destroy() { +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L + if(std::uncaught_exceptions() > 0) { +#else + if(std::uncaught_exception()) { +#endif + std::ostringstream s; + this->stringify(&s); + g_cs->stringifiedContexts.push_back(s.str().c_str()); + } + g_infoContexts.pop_back(); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} // namespace detail +namespace { + using namespace detail; + +#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) + struct FatalConditionHandler + { + void reset() {} + }; +#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + + void reportFatal(const std::string&); + +#ifdef DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + DWORD id; + const char* name; + }; + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + SignalDefs signalDefs[] = { + {EXCEPTION_ILLEGAL_INSTRUCTION, "SIGILL - Illegal instruction signal"}, + {EXCEPTION_STACK_OVERFLOW, "SIGSEGV - Stack overflow"}, + {EXCEPTION_ACCESS_VIOLATION, "SIGSEGV - Segmentation violation signal"}, + {EXCEPTION_INT_DIVIDE_BY_ZERO, "Divide by zero error"}, + }; + + struct FatalConditionHandler + { + static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { + for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { + reportFatal(signalDefs[i].name); + break; + } + } + // If its not an exception we care about, pass it along. + // This stops us from eating debugger breaks etc. + return EXCEPTION_CONTINUE_SEARCH; + } + + FatalConditionHandler() { + isSet = true; + // 32k seems enough for doctest to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + // Register an unhandled exception filter + previousTop = SetUnhandledExceptionFilter(handleException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + } + + static void reset() { + if(isSet) { + // Unregister handler and restore the old guarantee + SetUnhandledExceptionFilter(previousTop); + SetThreadStackGuarantee(&guaranteeSize); + previousTop = nullptr; + isSet = false; + } + } + + ~FatalConditionHandler() { reset(); } + + private: + static bool isSet; + static ULONG guaranteeSize; + static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; + }; + + bool FatalConditionHandler::isSet = false; + ULONG FatalConditionHandler::guaranteeSize = 0; + LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; + +#else // DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + int id; + const char* name; + }; + SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, + {SIGILL, "SIGILL - Illegal instruction signal"}, + {SIGFPE, "SIGFPE - Floating point error signal"}, + {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, + {SIGTERM, "SIGTERM - Termination request signal"}, + {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; + + struct FatalConditionHandler + { + static bool isSet; + static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; + static stack_t oldSigStack; + static char altStackMem[4 * SIGSTKSZ]; + + static void handleSignal(int sig) { + const char* name = ""; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + SignalDefs& def = signalDefs[i]; + if(sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise(sig); + } + + FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = sizeof(altStackMem); + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = {}; + sa.sa_handler = handleSignal; // NOLINT + sa.sa_flags = SA_ONSTACK; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + ~FatalConditionHandler() { reset(); } + static void reset() { + if(isSet) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + }; + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + char FatalConditionHandler::altStackMem[] = {}; + +#endif // DOCTEST_PLATFORM_WINDOWS +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + +} // namespace + +namespace { + using namespace detail; + +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) +#else + // TODO: integration with XCode and other IDEs +#define DOCTEST_OUTPUT_DEBUG_STRING(text) // NOLINT(clang-diagnostic-unused-macros) +#endif // Platform + + void addAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsCurrentTest_atomic++; + } + + void addFailedAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsFailedCurrentTest_atomic++; + } + +#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) + void reportFatal(const std::string& message) { + g_cs->failure_flags |= TestCaseFailureReason::Crash; + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); + + while(g_cs->subcasesStack.size()) { + g_cs->subcasesStack.pop_back(); + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + + g_cs->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH +} // namespace +namespace detail { + + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const char* exception_string) { + m_test_case = g_cs->currentTest; + m_at = at; + m_file = file; + m_line = line; + m_expr = expr; + m_failed = true; + m_threw = false; + m_threw_as = false; + m_exception_type = exception_type; + m_exception_string = exception_string; +#if DOCTEST_MSVC + if(m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC + ++m_expr; +#endif // MSVC + } + + void ResultBuilder::setResult(const Result& res) { + m_decomp = res.m_decomp; + m_failed = !res.m_passed; + } + + void ResultBuilder::translateException() { + m_threw = true; + m_exception = translateActiveException(); + } + + bool ResultBuilder::log() { + if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw; + } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT + m_failed = !m_threw_as || (m_exception != m_exception_string); + } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw_as; + } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + m_failed = m_exception != m_exception_string; + } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + m_failed = m_threw; + } + + if(m_exception.size()) + m_exception = String("\"") + m_exception + "\""; + + if(is_running_in_test) { + addAssert(m_at); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); + + if(m_failed) + addFailedAssert(m_at); + } else if(m_failed) { + failed_out_of_a_testing_context(*this); + } + + return m_failed && isDebuggerActive() && + !getContextOptions()->no_breaks; // break into debugger + } + + void ResultBuilder::react() const { + if(m_failed && checkIfShouldThrow(m_at)) + throwException(); + } + + void failed_out_of_a_testing_context(const AssertData& ad) { + if(g_cs->ah) + g_cs->ah(ad); + else + std::abort(); + } + + void decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, + Result result) { + bool failed = !result.m_passed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); + DOCTEST_ASSERT_IN_TESTS(result.m_decomp); + } + + MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { + m_stream = getTlsOss(); + m_file = file; + m_line = line; + m_severity = severity; + } + + IExceptionTranslator::IExceptionTranslator() = default; + IExceptionTranslator::~IExceptionTranslator() = default; + + bool MessageBuilder::log() { + m_string = getTlsOssResult(); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); + + const bool isWarn = m_severity & assertType::is_warn; + + // warn is just a message in this context so we don't treat it as an assert + if(!isWarn) { + addAssert(m_severity); + addFailedAssert(m_severity); + } + + return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn; // break + } + + void MessageBuilder::react() { + if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional + throwException(); + } + + MessageBuilder::~MessageBuilder() = default; +} // namespace detail +namespace { + using namespace detail; + + template + DOCTEST_NORETURN void throw_exception(Ex const& e) { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + throw e; +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + std::cerr << "doctest will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; + std::terminate(); +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } + +#ifndef DOCTEST_INTERNAL_ERROR +#define DOCTEST_INTERNAL_ERROR(msg) \ + throw_exception(std::logic_error( \ + __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) +#endif // DOCTEST_INTERNAL_ERROR + + // clang-format off + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; + ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + + XmlWriter( std::ostream& os = std::cout ); + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, const char* attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + std::stringstream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + //XmlWriter& writeComment( std::string const& text ); + + //void writeStylesheetRef( std::string const& url ); + + //XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + private: + + void writeDeclaration(); + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + +using uchar = unsigned char; + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: https://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: https://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? + if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + writeDeclaration(); + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { + if( !name.empty() && attribute && attribute[0] != '\0' ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + //XmlWriter& XmlWriter::writeComment( std::string const& text ) { + // ensureTagClosed(); + // m_os << m_indent << ""; + // m_needsNewline = true; + // return *this; + //} + + //void XmlWriter::writeStylesheetRef( std::string const& url ) { + // m_os << "\n"; + //} + + //XmlWriter& XmlWriter::writeBlankLine() { + // ensureTagClosed(); + // m_os << '\n'; + // return *this; + //} + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } + +// ================================================================================================= +// End of copy-pasted code from Catch +// ================================================================================================= + + // clang-format on + + struct XmlReporter : public IReporter + { + XmlWriter xml; + std::mutex mutex; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + XmlReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + std::stringstream ss; + for(int i = 0; i < num_contexts; ++i) { + contexts[i]->stringify(&ss); + xml.scopedElement("Info").writeText(ss.str()); + ss.str(""); + } + } + } + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + void test_case_start_impl(const TestCaseData& in) { + bool open_ts_tag = false; + if(tc != nullptr) { // we have already opened a test suite + if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { + xml.endElement(); + open_ts_tag = true; + } + } + else { + open_ts_tag = true; // first test case ==> first test suite + } + + if(open_ts_tag) { + xml.startElement("TestSuite"); + xml.writeAttribute("name", in.m_test_suite); + } + + tc = ∈ + xml.startElement("TestCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) + .writeAttribute("line", line(in.m_line)) + .writeAttribute("description", in.m_description); + + if(Approx(in.m_timeout) != 0) + xml.writeAttribute("timeout", in.m_timeout); + if(in.m_may_fail) + xml.writeAttribute("may_fail", true); + if(in.m_should_fail) + xml.writeAttribute("should_fail", true); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + test_run_start(); + if(opt.list_reporters) { + for(auto& curr : getListeners()) + xml.scopedElement("Listener") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + for(auto& curr : getReporters()) + xml.scopedElement("Reporter") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + } else if(opt.count || opt.list_test_cases) { + for(unsigned i = 0; i < in.num_data; ++i) { + xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) + .writeAttribute("testsuite", in.data[i]->m_test_suite) + .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) + .writeAttribute("line", line(in.data[i]->m_line)); + } + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + } else if(opt.list_test_suites) { + for(unsigned i = 0; i < in.num_data; ++i) + xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + xml.scopedElement("OverallResultsTestSuites") + .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); + } + xml.endElement(); + } + + void test_run_start() override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + + xml.startElement("doctest").writeAttribute("binary", binary_name); + if(opt.no_version == false) + xml.writeAttribute("version", DOCTEST_VERSION_STR); + + // only the consequential ones (TODO: filters) + xml.scopedElement("Options") + .writeAttribute("order_by", opt.order_by.c_str()) + .writeAttribute("rand_seed", opt.rand_seed) + .writeAttribute("first", opt.first) + .writeAttribute("last", opt.last) + .writeAttribute("abort_after", opt.abort_after) + .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) + .writeAttribute("case_sensitive", opt.case_sensitive) + .writeAttribute("no_throw", opt.no_throw) + .writeAttribute("no_skip", opt.no_skip); + } + + void test_run_end(const TestRunStats& p) override { + if(tc) // the TestSuite tag - only if there has been at least 1 test case + xml.endElement(); + + xml.scopedElement("OverallResultsAsserts") + .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) + .writeAttribute("failures", p.numAssertsFailed); + + xml.startElement("OverallResultsTestCases") + .writeAttribute("successes", + p.numTestCasesPassingFilters - p.numTestCasesFailed) + .writeAttribute("failures", p.numTestCasesFailed); + if(opt.no_skipped_summary == false) + xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); + xml.endElement(); + + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + test_case_start_impl(in); + xml.ensureTagClosed(); + } + + void test_case_reenter(const TestCaseData&) override {} + + void test_case_end(const CurrentTestCaseStats& st) override { + xml.startElement("OverallResultsAsserts") + .writeAttribute("successes", + st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) + .writeAttribute("failures", st.numAssertsFailedCurrentTest); + if(opt.duration) + xml.writeAttribute("duration", st.seconds); + if(tc->m_expected_failures) + xml.writeAttribute("expected_failures", tc->m_expected_failures); + xml.endElement(); + + xml.endElement(); + } + + void test_case_exception(const TestCaseException& e) override { + std::lock_guard lock(mutex); + + xml.scopedElement("Exception") + .writeAttribute("crash", e.is_crash) + .writeText(e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + std::lock_guard lock(mutex); + + xml.startElement("SubCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file)) + .writeAttribute("line", line(in.m_line)); + xml.ensureTagClosed(); + } + + void subcase_end() override { xml.endElement(); } + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed && !opt.success) + return; + + std::lock_guard lock(mutex); + + xml.startElement("Expression") + .writeAttribute("success", !rb.m_failed) + .writeAttribute("type", assertString(rb.m_at)) + .writeAttribute("filename", skipPathFromFilename(rb.m_file)) + .writeAttribute("line", line(rb.m_line)); + + xml.scopedElement("Original").writeText(rb.m_expr); + + if(rb.m_threw) + xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); + + if(rb.m_at & assertType::is_throws_as) + xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); + if(rb.m_at & assertType::is_throws_with) + xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string); + if((rb.m_at & assertType::is_normal) && !rb.m_threw) + xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void log_message(const MessageData& mb) override { + std::lock_guard lock(mutex); + + xml.startElement("Message") + .writeAttribute("type", failureString(mb.m_severity)) + .writeAttribute("filename", skipPathFromFilename(mb.m_file)) + .writeAttribute("line", line(mb.m_line)); + + xml.scopedElement("Text").writeText(mb.m_string.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void test_case_skipped(const TestCaseData& in) override { + if(opt.no_skipped_summary == false) { + test_case_start_impl(in); + xml.writeAttribute("skipped", "true"); + xml.endElement(); + } + } + }; + + DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); + + void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { + if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == + 0) //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " + << Color::None; + + if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "threw as expected!" : "did NOT throw at all!") << "\n"; + } else if((rb.m_at & assertType::is_throws_as) && + (rb.m_at & assertType::is_throws_with)) { //!OCLINT + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string << "\", " << rb.m_exception_type << " ) " << Color::None; + if(rb.m_threw) { + if(!rb.m_failed) { + s << "threw as expected!\n"; + } else { + s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; + } + } else { + s << "did NOT throw at all!\n"; + } + } else if(rb.m_at & + assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " + << rb.m_exception_type << " ) " << Color::None + << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & + assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string << "\" ) " << Color::None + << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan + << rb.m_exception << "\n"; + } else { + s << (rb.m_threw ? "THREW exception: " : + (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")); + if(rb.m_threw) + s << rb.m_exception << "\n"; + else + s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; + } + } + + // TODO: + // - log_contexts() + // - log_message() + // - respond to queries + // - honor remaining options + // - more attributes in tags + struct JUnitReporter : public IReporter + { + XmlWriter xml; + std::mutex mutex; + Timer timer; + std::vector deepestSubcaseStackNames; + + struct JUnitTestCaseData + { +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") // gmtime + static std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + + std::tm* timeInfo; + timeInfo = std::gmtime(&rawtime); + + char timeStamp[timeStampSize]; + const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; + + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); + return std::string(timeStamp); + } +DOCTEST_CLANG_SUPPRESS_WARNING_POP + + struct JUnitTestMessage + { + JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) + : message(_message), type(_type), details(_details) {} + + JUnitTestMessage(const std::string& _message, const std::string& _details) + : message(_message), type(), details(_details) {} + + std::string message, type, details; + }; + + struct JUnitTestCase + { + JUnitTestCase(const std::string& _classname, const std::string& _name) + : classname(_classname), name(_name), time(0), failures() {} + + std::string classname, name; + double time; + std::vector failures, errors; + }; + + void add(const std::string& classname, const std::string& name) { + testcases.emplace_back(classname, name); + } + + void appendSubcaseNamesToLastTestcase(std::vector nameStack) { + for(auto& curr: nameStack) + if(curr.size()) + testcases.back().name += std::string("/") + curr.c_str(); + } + + void addTime(double time) { + if(time < 1e-4) + time = 0; + testcases.back().time = time; + totalSeconds += time; + } + + void addFailure(const std::string& message, const std::string& type, const std::string& details) { + testcases.back().failures.emplace_back(message, type, details); + ++totalFailures; + } + + void addError(const std::string& message, const std::string& details) { + testcases.back().errors.emplace_back(message, details); + ++totalErrors; + } + + std::vector testcases; + double totalSeconds = 0; + int totalErrors = 0, totalFailures = 0; + }; + + JUnitTestCaseData testCaseData; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + JUnitReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData&) override {} + + void test_run_start() override {} + + void test_run_end(const TestRunStats& p) override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + xml.startElement("testsuites"); + xml.startElement("testsuite").writeAttribute("name", binary_name) + .writeAttribute("errors", testCaseData.totalErrors) + .writeAttribute("failures", testCaseData.totalFailures) + .writeAttribute("tests", p.numAsserts); + if(opt.no_time_in_output == false) { + xml.writeAttribute("time", testCaseData.totalSeconds); + xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); + } + if(opt.no_version == false) + xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); + + for(const auto& testCase : testCaseData.testcases) { + xml.startElement("testcase") + .writeAttribute("classname", testCase.classname) + .writeAttribute("name", testCase.name); + if(opt.no_time_in_output == false) + xml.writeAttribute("time", testCase.time); + // This is not ideal, but it should be enough to mimic gtest's junit output. + xml.writeAttribute("status", "run"); + + for(const auto& failure : testCase.failures) { + xml.scopedElement("failure") + .writeAttribute("message", failure.message) + .writeAttribute("type", failure.type) + .writeText(failure.details, false); + } + + for(const auto& error : testCase.errors) { + xml.scopedElement("error") + .writeAttribute("message", error.message) + .writeText(error.details); + } + + xml.endElement(); + } + xml.endElement(); + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + timer.start(); + } + + void test_case_reenter(const TestCaseData& in) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + + timer.start(); + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + } + + void test_case_end(const CurrentTestCaseStats&) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + } + + void test_case_exception(const TestCaseException& e) override { + std::lock_guard lock(mutex); + testCaseData.addError("exception", e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + std::lock_guard lock(mutex); + deepestSubcaseStackNames.push_back(in.m_name); + } + + void subcase_end() override {} + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed) // report only failures & ignore the `success` option + return; + + std::lock_guard lock(mutex); + + std::ostringstream os; + os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + fulltext_log_assert_to_stream(os, rb); + testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); + } + + void log_message(const MessageData&) override {} + + void test_case_skipped(const TestCaseData&) override {} + }; + + DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); + + struct Whitespace + { + int nrSpaces; + explicit Whitespace(int nr) + : nrSpaces(nr) {} + }; + + std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { + if(ws.nrSpaces != 0) + out << std::setw(ws.nrSpaces) << ' '; + return out; + } + + struct ConsoleReporter : public IReporter + { + std::ostream& s; + bool hasLoggedCurrentTestStart; + std::vector subcasesStack; + size_t currentSubcaseLevel; + std::mutex mutex; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc; + + ConsoleReporter(const ContextOptions& co) + : s(*co.cout) + , opt(co) {} + + ConsoleReporter(const ContextOptions& co, std::ostream& ostr) + : s(ostr) + , opt(co) {} + + // ========================================================================================= + // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE + // ========================================================================================= + + void separator_to_stream() { + s << Color::Yellow + << "===============================================================================" + "\n"; + } + + const char* getSuccessOrFailString(bool success, assertType::Enum at, + const char* success_str) { + if(success) + return success_str; + return failureString(at); + } + + Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { + return success ? Color::BrightGreen : + (at & assertType::is_warn) ? Color::Yellow : Color::Red; + } + + void successOrFailColoredStringToStream(bool success, assertType::Enum at, + const char* success_str = "SUCCESS") { + s << getSuccessOrFailColor(success, at) + << getSuccessOrFailString(success, at, success_str) << ": "; + } + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << Color::None << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << "\n"; + } + } + + s << "\n"; + } + + // this was requested to be made virtual so users could override it + virtual void file_line_to_stream(const char* file, int line, + const char* tail = "") { + s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") + << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option + << (opt.gnu_file_line ? ":" : "):") << tail; + } + + void logTestStart() { + if(hasLoggedCurrentTestStart) + return; + + separator_to_stream(); + file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); + if(tc->m_description) + s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; + if(tc->m_test_suite && tc->m_test_suite[0] != '\0') + s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; + if(strncmp(tc->m_name, " Scenario:", 11) != 0) + s << Color::Yellow << "TEST CASE: "; + s << Color::None << tc->m_name << "\n"; + + for(size_t i = 0; i < currentSubcaseLevel; ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + + if(currentSubcaseLevel != subcasesStack.size()) { + s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; + for(size_t i = 0; i < subcasesStack.size(); ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + } + + s << "\n"; + + hasLoggedCurrentTestStart = true; + } + + void printVersion() { + if(opt.no_version == false) + s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" + << DOCTEST_VERSION_STR << "\"\n"; + } + + void printIntro() { + printVersion(); + s << Color::Cyan << "[doctest] " << Color::None + << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + } + + void printHelp() { + int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); + printVersion(); + // clang-format off + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filters use wildcards for matching strings\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "something passes a filter if any of the strings in a filter matches\n"; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; +#endif + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "Query flags - the program quits after them. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " + << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " + << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " + << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " + << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " + << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " + << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; + // ================================================================================== << 79 + s << Color::Cyan << "[doctest] " << Color::None; + s << "The available / options/filters are:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " + << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " + << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " + << Whitespace(sizePrefixDisplay*1) << "output filename\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " + << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; + s << Whitespace(sizePrefixDisplay*3) << " - by [file/suite/name/rand]\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " + << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " + << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " + << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " + << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " + << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " + << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " + << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " + << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " + << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " + << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " + << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " + << Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " + << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " + << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " + << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " + << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " + << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " + << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " + << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; + // ================================================================================== << 79 + // clang-format on + + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "for more information visit the project documentation\n\n"; + } + + void printRegisteredReporters() { + printVersion(); + auto printReporters = [this] (const reporterMap& reporters, const char* type) { + if(reporters.size()) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; + for(auto& curr : reporters) + s << "priority: " << std::setw(5) << curr.first.first + << " name: " << curr.first.second << "\n"; + } + }; + printReporters(getListeners(), "listeners"); + printReporters(getReporters(), "reporters"); + } + + void list_query_results() { + separator_to_stream(); + if(opt.count || opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << "\n"; + } + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + if(opt.version) { + printVersion(); + } else if(opt.help) { + printHelp(); + } else if(opt.list_reporters) { + printRegisteredReporters(); + } else if(opt.count || opt.list_test_cases) { + if(opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "listing all test case names\n"; + separator_to_stream(); + } + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_name << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; + separator_to_stream(); + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_test_suite << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << "\n"; + } + } + + void test_run_start() override { printIntro(); } + + void test_run_end(const TestRunStats& p) override { + separator_to_stream(); + s << std::dec; + + const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; + s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(6) + << p.numTestCasesPassingFilters << " | " + << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : + Color::Green) + << std::setw(6) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" + << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) + << std::setw(6) << p.numTestCasesFailed << " failed" << Color::None << " | "; + if(opt.no_skipped_summary == false) { + const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; + s << (numSkipped == 0 ? Color::None : Color::Yellow) << std::setw(6) << numSkipped + << " skipped" << Color::None; + } + s << "\n"; + s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(6) + << p.numAsserts << " | " + << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) + << std::setw(6) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None + << " | " << (p.numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(6) + << p.numAssertsFailed << " failed" << Color::None << " |\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "Status: " << (p.numTestCasesFailed > 0 ? Color::Red : Color::Green) + << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; + } + + void test_case_start(const TestCaseData& in) override { + hasLoggedCurrentTestStart = false; + tc = ∈ + subcasesStack.clear(); + currentSubcaseLevel = 0; + } + + void test_case_reenter(const TestCaseData&) override { + subcasesStack.clear(); + } + + void test_case_end(const CurrentTestCaseStats& st) override { + // log the preamble of the test case only if there is something + // else to print - something other than that an assert has failed + if(opt.duration || + (st.failure_flags && st.failure_flags != TestCaseFailureReason::AssertFailure)) + logTestStart(); + + if(opt.duration) + s << Color::None << std::setprecision(6) << std::fixed << st.seconds + << " s: " << tc->m_name << "\n"; + + if(st.failure_flags & TestCaseFailureReason::Timeout) + s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) + << std::fixed << tc->m_timeout << "!\n"; + + if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { + s << Color::Red << "Should have failed but didn't! Marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { + s << Color::Yellow << "Failed as expected so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { + s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { + s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures + << " times so marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { + s << Color::Yellow << "Failed exactly " << tc->m_expected_failures + << " times as expected so marking it as not failed!\n"; + } + if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { + s << Color::Red << "Aborting - too many failed asserts!\n"; + } + s << Color::None; // lgtm [cpp/useless-expression] + } + + void test_case_exception(const TestCaseException& e) override { + logTestStart(); + + file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); + successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : + assertType::is_check); + s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") + << Color::Cyan << e.error_string << "\n"; + + int num_stringified_contexts = get_num_stringified_contexts(); + if(num_stringified_contexts) { + auto stringified_contexts = get_stringified_contexts(); + s << Color::None << " logged: "; + for(int i = num_stringified_contexts; i > 0; --i) { + s << (i == num_stringified_contexts ? "" : " ") + << stringified_contexts[i - 1] << "\n"; + } + } + s << "\n" << Color::None; + } + + void subcase_start(const SubcaseSignature& subc) override { + std::lock_guard lock(mutex); + subcasesStack.push_back(subc); + ++currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void subcase_end() override { + std::lock_guard lock(mutex); + --currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed && !opt.success) + return; + + std::lock_guard lock(mutex); + + logTestStart(); + + file_line_to_stream(rb.m_file, rb.m_line, " "); + successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); + + fulltext_log_assert_to_stream(s, rb); + + log_contexts(); + } + + void log_message(const MessageData& mb) override { + std::lock_guard lock(mutex); + + logTestStart(); + + file_line_to_stream(mb.m_file, mb.m_line, " "); + s << getSuccessOrFailColor(false, mb.m_severity) + << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, + "MESSAGE") << ": "; + s << Color::None << mb.m_string << "\n"; + log_contexts(); + } + + void test_case_skipped(const TestCaseData&) override {} + }; + + DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); + +#ifdef DOCTEST_PLATFORM_WINDOWS + struct DebugOutputWindowReporter : public ConsoleReporter + { + DOCTEST_THREAD_LOCAL static std::ostringstream oss; + + DebugOutputWindowReporter(const ContextOptions& co) + : ConsoleReporter(co, oss) {} + +#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ + void func(type arg) override { \ + bool with_col = g_no_colors; \ + g_no_colors = false; \ + ConsoleReporter::func(arg); \ + DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ + oss.str(""); \ + g_no_colors = with_col; \ + } + + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) + }; + + DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; +#endif // DOCTEST_PLATFORM_WINDOWS + + // the implementation of parseOption() + bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { + // going from the end to the beginning and stopping on the first occurrence from the end + for(int i = argc; i > 0; --i) { + auto index = i - 1; + auto temp = std::strstr(argv[index], pattern); + if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue + // eliminate matches in which the chars before the option are not '-' + bool noBadCharsFound = true; + auto curr = argv[index]; + while(curr != temp) { + if(*curr++ != '-') { + noBadCharsFound = false; + break; + } + } + if(noBadCharsFound && argv[index][0] == '-') { + if(value) { + // parsing the value of an option + temp += strlen(pattern); + const unsigned len = strlen(temp); + if(len) { + *value = temp; + return true; + } + } else { + // just a flag - no value + return true; + } + } + } + } + return false; + } + + // parses an option and returns the string after the '=' character + bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = nullptr, + const String& defaultVal = String()) { + if(value) + *value = defaultVal; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + // offset (normally 3 for "dt-") to skip prefix + if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) + return true; +#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + return parseOptionImpl(argc, argv, pattern, value); + } + + // locates a flag on the command line + bool parseFlag(int argc, const char* const* argv, const char* pattern) { + return parseOption(argc, argv, pattern); + } + + // parses a comma separated list of words after a pattern in one of the arguments in argv + bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, + std::vector& res) { + String filtersString; + if(parseOption(argc, argv, pattern, &filtersString)) { + // tokenize with "," as a separator + // cppcheck-suppress strtokCalled + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + auto pch = std::strtok(filtersString.c_str(), ","); // modifies the string + while(pch != nullptr) { + if(strlen(pch)) + res.push_back(pch); + // uses the strtok() internal state to go to the next token + // cppcheck-suppress strtokCalled + pch = std::strtok(nullptr, ","); + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + return true; + } + return false; + } + + enum optionType + { + option_bool, + option_int + }; + + // parses an int/bool option from the command line + bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, + int& res) { + String parsedValue; + if(!parseOption(argc, argv, pattern, &parsedValue)) + return false; + + if(type == 0) { + // boolean + const char positive[][5] = {"1", "true", "on", "yes"}; // 5 - strlen("true") + 1 + const char negative[][6] = {"0", "false", "off", "no"}; // 6 - strlen("false") + 1 + + // if the value matches any of the positive/negative possibilities + for(unsigned i = 0; i < 4; i++) { + if(parsedValue.compare(positive[i], true) == 0) { + res = 1; //!OCLINT parameter reassignment + return true; + } + if(parsedValue.compare(negative[i], true) == 0) { + res = 0; //!OCLINT parameter reassignment + return true; + } + } + } else { + // integer + // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... + int theInt = std::atoi(parsedValue.c_str()); // NOLINT + if(theInt != 0) { + res = theInt; //!OCLINT parameter reassignment + return true; + } + } + return false; + } +} // namespace + +Context::Context(int argc, const char* const* argv) + : p(new detail::ContextState) { + parseArgs(argc, argv, true); + if(argc) + p->binary_name = argv[0]; +} + +Context::~Context() { + if(g_cs == p) + g_cs = nullptr; + delete p; +} + +void Context::applyCommandLine(int argc, const char* const* argv) { + parseArgs(argc, argv); + if(argc) + p->binary_name = argv[0]; +} + +// parses args +void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { + using namespace detail; + + // clang-format off + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); + // clang-format on + + int intRes = 0; + String strRes; + +#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ + p->var = !!intRes; \ + else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ + p->var = true; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ + p->var = intRes; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ + if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ + parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ + withDefaults) \ + p->var = strRes + + // clang-format off + DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); + DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); + DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); + + DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); + DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); + + DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); + DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); + + DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); + // clang-format on + + if(withDefaults) { + p->help = false; + p->version = false; + p->count = false; + p->list_test_cases = false; + p->list_test_suites = false; + p->list_reporters = false; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { + p->help = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { + p->version = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { + p->count = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { + p->list_test_cases = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { + p->list_test_suites = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { + p->list_reporters = true; + p->exit = true; + } +} + +// allows the user to add procedurally to the filters from the command line +void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } + +// allows the user to clear all filters from the command line +void Context::clearFilters() { + for(auto& curr : p->filters) + curr.clear(); +} + +// allows the user to override procedurally the int/bool options from the command line +void Context::setOption(const char* option, int value) { + setOption(option, toString(value).c_str()); +} + +// allows the user to override procedurally the string options from the command line +void Context::setOption(const char* option, const char* value) { + auto argv = String("-") + option + "=" + value; + auto lvalue = argv.c_str(); + parseArgs(1, &lvalue); +} + +// users should query this in their main() and exit the program if true +bool Context::shouldExit() { return p->exit; } + +void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } + +void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } + +// the main function that does all the filtering and test running +int Context::run() { + using namespace detail; + + // save the old context state in case such was setup - for using asserts out of a testing context + auto old_cs = g_cs; + // this is the current contest + g_cs = p; + is_running_in_test = true; + + g_no_colors = p->no_colors; + p->resetRunData(); + + // stdout by default + p->cout = &std::cout; + p->cerr = &std::cerr; + + // or to a file if specified + std::fstream fstr; + if(p->out.size()) { + fstr.open(p->out.c_str(), std::fstream::out); + p->cout = &fstr; + } + + auto cleanup_and_return = [&]() { + if(fstr.is_open()) + fstr.close(); + + // restore context + g_cs = old_cs; + is_running_in_test = false; + + // we have to free the reporters which were allocated when the run started + for(auto& curr : p->reporters_currently_used) + delete curr; + p->reporters_currently_used.clear(); + + if(p->numTestCasesFailed && !p->no_exitcode) + return EXIT_FAILURE; + return EXIT_SUCCESS; + }; + + // setup default reporter if none is given through the command line + if(p->filters[8].empty()) + p->filters[8].push_back("console"); + + // check to see if any of the registered reporters has been selected + for(auto& curr : getReporters()) { + if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) + p->reporters_currently_used.push_back(curr.second(*g_cs)); + } + + // TODO: check if there is nothing in reporters_currently_used + + // prepend all listeners + for(auto& curr : getListeners()) + p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); + +#ifdef DOCTEST_PLATFORM_WINDOWS + if(isDebuggerActive()) + p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); +#endif // DOCTEST_PLATFORM_WINDOWS + + // handle version, help and no_run + if(p->no_run || p->version || p->help || p->list_reporters) { + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); + + return cleanup_and_return(); + } + + std::vector testArray; + for(auto& curr : getRegisteredTests()) + testArray.push_back(&curr); + p->numTestCases = testArray.size(); + + // sort the collected records + if(!testArray.empty()) { + if(p->order_by.compare("file", true) == 0) { + std::sort(testArray.begin(), testArray.end(), fileOrderComparator); + } else if(p->order_by.compare("suite", true) == 0) { + std::sort(testArray.begin(), testArray.end(), suiteOrderComparator); + } else if(p->order_by.compare("name", true) == 0) { + std::sort(testArray.begin(), testArray.end(), nameOrderComparator); + } else if(p->order_by.compare("rand", true) == 0) { + std::srand(p->rand_seed); + + // random_shuffle implementation + const auto first = &testArray[0]; + for(size_t i = testArray.size() - 1; i > 0; --i) { + int idxToSwap = std::rand() % (i + 1); // NOLINT + + const auto temp = first[i]; + + first[i] = first[idxToSwap]; + first[idxToSwap] = temp; + } + } + } + + std::set testSuitesPassingFilt; + + bool query_mode = p->count || p->list_test_cases || p->list_test_suites; + std::vector queryResults; + + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); + + // invoke the registered functions if they match the filter criteria (or just count them) + for(auto& curr : testArray) { + const auto& tc = *curr; + + bool skip_me = false; + if(tc.m_skip && !p->no_skip) + skip_me = true; + + if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) + skip_me = true; + + if(!skip_me) + p->numTestCasesPassingFilters++; + + // skip the test if it is not in the execution range + if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || + (p->first > p->numTestCasesPassingFilters)) + skip_me = true; + + if(skip_me) { + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); + continue; + } + + // do not execute the test if we are to only count the number of filter passing tests + if(p->count) + continue; + + // print the name of the test and don't execute it + if(p->list_test_cases) { + queryResults.push_back(&tc); + continue; + } + + // print the name of the test suite if not done already and don't execute it + if(p->list_test_suites) { + if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { + queryResults.push_back(&tc); + testSuitesPassingFilt.insert(tc.m_test_suite); + p->numTestSuitesPassingFilters++; + } + continue; + } + + // execute the test if it passes all the filtering + { + p->currentTest = &tc; + + p->failure_flags = TestCaseFailureReason::None; + p->seconds = 0; + + // reset atomic counters + p->numAssertsFailedCurrentTest_atomic = 0; + p->numAssertsCurrentTest_atomic = 0; + + p->subcasesPassed.clear(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); + + p->timer.start(); + + bool run_test = true; + + do { + // reset some of the fields for subcases (except for the set of fully passed ones) + p->should_reenter = false; + p->subcasesCurrentMaxLevel = 0; + p->subcasesStack.clear(); + + p->shouldLogCurrentException = true; + + // reset stuff for logging with INFO() + p->stringifiedContexts.clear(); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + FatalConditionHandler fatalConditionHandler; // Handle signals + // execute the test + tc.m_test(); + fatalConditionHandler.reset(); +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + } catch(const TestFailureException&) { + p->failure_flags |= TestCaseFailureReason::AssertFailure; + } catch(...) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, + {translateActiveException(), false}); + p->failure_flags |= TestCaseFailureReason::Exception; + } +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + + // exit this loop if enough assertions have failed - even if there are more subcases + if(p->abort_after > 0 && + p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { + run_test = false; + p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; + } + + if(p->should_reenter && run_test) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); + if(!p->should_reenter) + run_test = false; + } while(run_test); + + p->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + p->currentTest = nullptr; + + // stop executing tests if enough assertions have failed + if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) + break; + } + } + + if(!query_mode) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } else { + QueryData qdata; + qdata.run_stats = g_cs; + qdata.data = queryResults.data(); + qdata.num_data = unsigned(queryResults.size()); + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); + } + + // see these issues on the reasoning for this: + // - https://github.com/onqtam/doctest/issues/143#issuecomment-414418903 + // - https://github.com/onqtam/doctest/issues/126 + auto DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS = []() DOCTEST_NOINLINE + { std::cout << std::string(); }; + DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS(); + + return cleanup_and_return(); +} + +IReporter::~IReporter() = default; + +int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } +const IContextScope* const* IReporter::get_active_contexts() { + return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; +} + +int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } +const String* IReporter::get_stringified_contexts() { + return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; +} + +namespace detail { + void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { + if(isReporter) + getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + else + getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + } +} // namespace detail + +} // namespace doctest + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 +int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_LIBRARY_IMPLEMENTATION +#endif // DOCTEST_CONFIG_IMPLEMENT diff --git a/test/doctest/doctestAddTests.cmake b/test/doctest/doctestAddTests.cmake new file mode 100644 index 0000000..98ee4a2 --- /dev/null +++ b/test/doctest/doctestAddTests.cmake @@ -0,0 +1,81 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +set(prefix "${TEST_PREFIX}") +set(suffix "${TEST_SUFFIX}") +set(spec ${TEST_SPEC}) +set(extra_args ${TEST_EXTRA_ARGS}) +set(properties ${TEST_PROPERTIES}) +set(script) +set(suite) +set(tests) + +function(add_command NAME) + set(_args "") + foreach(_arg ${ARGN}) + if(_arg MATCHES "[^-./:a-zA-Z0-9_]") + set(_args "${_args} [==[${_arg}]==]") # form a bracket_argument + else() + set(_args "${_args} ${_arg}") + endif() + endforeach() + set(script "${script}${NAME}(${_args})\n" PARENT_SCOPE) +endfunction() + +# Run test executable to get list of available tests +if(NOT EXISTS "${TEST_EXECUTABLE}") + message(FATAL_ERROR + "Specified test executable '${TEST_EXECUTABLE}' does not exist" + ) +endif() + +if("${spec}" MATCHES .) + set(spec "--test-case=${spec}") +endif() + +execute_process( + COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" ${spec} --list-test-cases + OUTPUT_VARIABLE output + RESULT_VARIABLE result +) +if(NOT ${result} EQUAL 0) + message(FATAL_ERROR + "Error running test executable '${TEST_EXECUTABLE}':\n" + " Result: ${result}\n" + " Output: ${output}\n" + ) +endif() + +string(REPLACE "\n" ";" output "${output}") + +# Parse output +foreach(line ${output}) + if("${line}" STREQUAL "===============================================================================" OR "${line}" MATCHES [==[^\[doctest\] ]==]) + continue() + endif() + set(test ${line}) + # use escape commas to handle properly test cases with commas inside the name + string(REPLACE "," "\\," test_name ${test}) + # ...and add to script + add_command(add_test + "${prefix}${test}${suffix}" + ${TEST_EXECUTOR} + "${TEST_EXECUTABLE}" + "--test-case=${test_name}" + ${extra_args} + ) + add_command(set_tests_properties + "${prefix}${test}${suffix}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${properties} + ) + list(APPEND tests "${prefix}${test}${suffix}") +endforeach() + +# Create a list of all discovered tests, which users may use to e.g. set +# properties on the tests +add_command(set ${TEST_LIST} ${tests}) + +# Write CTest script +file(WRITE "${CTEST_FILE}" "${script}") diff --git a/test/file_tests.cpp b/test/file_tests.cpp new file mode 100644 index 0000000..575bb95 --- /dev/null +++ b/test/file_tests.cpp @@ -0,0 +1,215 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "io_service_fixture.hpp" + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("file"); + +namespace fs = cppcoro::filesystem; + +namespace +{ + class temp_dir_fixture + { + public: + + temp_dir_fixture() + { + auto tempDir = fs::temp_directory_path(); + + std::random_device random; + for (int attempt = 1;; ++attempt) + { + m_path = tempDir / std::to_string(random()); + try + { + fs::create_directories(m_path); + return; + } + catch (const fs::filesystem_error&) + { + if (attempt == 10) + { + throw; + } + } + } + } + + ~temp_dir_fixture() + { + fs::remove_all(m_path); + } + + const cppcoro::filesystem::path& temp_dir() + { + return m_path; + } + + private: + + cppcoro::filesystem::path m_path; + + }; + + class temp_dir_with_io_service_fixture : + public io_service_fixture, + public temp_dir_fixture + { + }; +} + +TEST_CASE_FIXTURE(temp_dir_fixture, "write a file") +{ + auto filePath = temp_dir() / "foo"; + + cppcoro::io_service ioService; + + auto write = [&](cppcoro::io_service& ioService) -> cppcoro::task<> + { + std::printf(" starting write\n"); std::fflush(stdout); + + auto f = cppcoro::write_only_file::open(ioService, filePath); + + CHECK(f.size() == 0); + + char buffer[1024]; + char c = 'a'; + for (int i = 0; i < sizeof(buffer); ++i, c = (c == 'z' ? 'a' : c + 1)) + { + buffer[i] = c; + } + + for (int chunk = 0; chunk < 10; ++chunk) + { + co_await f.write(chunk * sizeof(buffer), buffer, sizeof(buffer)); + } + }; + + auto read = [&](cppcoro::io_service& io) -> cppcoro::task<> + { + std::printf(" starting read\n"); std::fflush(stdout); + + auto f = cppcoro::read_only_file::open(io, filePath); + + const auto fileSize = f.size(); + + CHECK(fileSize == 10240); + + char buffer[20]; + + for (std::uint64_t i = 0; i < fileSize;) + { + auto bytesRead = co_await f.read(i, buffer, 20); + for (size_t j = 0; j < bytesRead; ++j, ++i) + { + CHECK(buffer[j] == ('a' + ((i % 1024) % 26))); + } + } + }; + + cppcoro::sync_wait(cppcoro::when_all( + [&]() -> cppcoro::task + { + auto stopOnExit = cppcoro::on_scope_exit([&] { ioService.stop(); }); + co_await write(ioService); + co_await read(ioService); + co_return 0; + }(), + [&]() -> cppcoro::task + { + ioService.process_events(); + co_return 0; + }())); +} + +TEST_CASE_FIXTURE(temp_dir_with_io_service_fixture, "read write file") +{ + auto run = [&]() -> cppcoro::task<> + { + cppcoro::io_work_scope ioScope{ io_service() }; + auto f = cppcoro::read_write_file::open(io_service(), temp_dir() / "foo.txt"); + + char buffer1[100]; + std::memset(buffer1, 0xAB, sizeof(buffer1)); + + co_await f.write(0, buffer1, sizeof(buffer1)); + + char buffer2[50]; + std::memset(buffer2, 0xCC, sizeof(buffer2)); + + co_await f.read(0, buffer2, 50); + CHECK(std::memcmp(buffer1, buffer2, 50) == 0); + + co_await f.read(50, buffer2, 50); + CHECK(std::memcmp(buffer1 + 50, buffer2, 50) == 0); + }; + + cppcoro::sync_wait(run()); +} + +TEST_CASE_FIXTURE(temp_dir_with_io_service_fixture, "cancel read") +{ + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + cppcoro::io_work_scope ioScope{ io_service() }; + auto f = cppcoro::read_write_file::open(io_service(), temp_dir() / "foo.txt"); + + f.set_size(20 * 1024 * 1024); + + cppcoro::cancellation_source canceller; + + try + { + (void)co_await cppcoro::when_all( + [&]() -> cppcoro::task + { + const auto fileSize = f.size(); + const std::size_t bufferSize = 64 * 1024; + auto buffer = std::make_unique(bufferSize); + std::uint64_t offset = 0; + while (offset < fileSize) + { + auto bytesRead = co_await f.read(offset, buffer.get(), bufferSize, canceller.token()); + offset += bytesRead; + } + WARN("should have been cancelled"); + co_return 0; + }(), + [&]() -> cppcoro::task + { + using namespace std::chrono_literals; + + co_await io_service().schedule_after(1ms); + canceller.request_cancellation(); + co_return 0; + }()); + WARN("Expected exception to be thrown"); + } + catch (const cppcoro::operation_cancelled&) + { + } + }()); +} + +TEST_SUITE_END(); diff --git a/test/generator_tests.cpp b/test/generator_tests.cpp new file mode 100644 index 0000000..92427df --- /dev/null +++ b/test/generator_tests.cpp @@ -0,0 +1,412 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include + +#include +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("generator"); + +using cppcoro::generator; + +TEST_CASE("default-constructed generator is empty sequence") +{ + generator ints; + CHECK(ints.begin() == ints.end()); +} + +TEST_CASE("generator of arithmetic type returns by copy") +{ + auto f = []() -> generator + { + co_yield 1.0f; + co_yield 2.0f; + }; + + auto gen = f(); + auto iter = gen.begin(); + // TODO: Should this really be required? + //static_assert(std::is_same::value, "operator* should return float by value"); + CHECK(*iter == 1.0f); + ++iter; + CHECK(*iter == 2.0f); + ++iter; + CHECK(iter == gen.end()); +} + +TEST_CASE("generator of reference returns by reference") +{ + auto f = [](float& value) -> generator + { + co_yield value; + }; + + float value = 1.0f; + for (auto& x : f(value)) + { + CHECK(&x == &value); + x += 1.0f; + } + + CHECK(value == 2.0f); +} + +TEST_CASE("generator of const type") +{ + auto fib = []() -> generator + { + std::uint64_t a = 0, b = 1; + while (true) + { + co_yield b; + b += std::exchange(a, b); + } + }; + + std::uint64_t count = 0; + for (auto i : fib()) + { + if (i > 1'000'000) { + break; + } + ++count; + } + + // 30th fib number is 832'040 + CHECK(count == 30); +} + +TEST_CASE("value-category of fmap() matches reference type") +{ + using cppcoro::fmap; + + auto checkIsRvalue = [](auto&& x) { + static_assert(std::is_rvalue_reference_v); + static_assert(!std::is_const_v>); + CHECK(x == 123); + return x; + }; + auto checkIsLvalue = [](auto&& x) { + static_assert(std::is_lvalue_reference_v); + static_assert(!std::is_const_v>); + CHECK(x == 123); + return x; + }; + auto checkIsConstLvalue = [](auto&& x) { + static_assert(std::is_lvalue_reference_v); + static_assert(std::is_const_v>); + CHECK(x == 123); + return x; + }; + auto checkIsConstRvalue = [](auto&& x) { + static_assert(std::is_rvalue_reference_v); + static_assert(std::is_const_v>); + CHECK(x == 123); + return x; + }; + + auto consume = [](auto&& range) { + for (auto&& x : range) { + (void)x; + } + }; + + consume([]() -> generator { co_yield 123; }() | fmap(checkIsLvalue)); + consume([]() -> generator { co_yield 123; }() | fmap(checkIsConstLvalue)); + consume([]() -> generator { co_yield 123; }() | fmap(checkIsLvalue)); + consume([]() -> generator { co_yield 123; }() | fmap(checkIsConstLvalue)); + consume([]() -> generator { co_yield 123; }() | fmap(checkIsRvalue)); + consume([]() -> generator { co_yield 123; }() | fmap(checkIsConstRvalue)); +} + +TEST_CASE("generator doesn't start until its called") +{ + bool reachedA = false; + bool reachedB = false; + bool reachedC = false; + auto f = [&]() -> generator + { + reachedA = true; + co_yield 1; + reachedB = true; + co_yield 2; + reachedC = true; + }; + + auto gen = f(); + CHECK(!reachedA); + auto iter = gen.begin(); + CHECK(reachedA); + CHECK(!reachedB); + CHECK(*iter == 1); + ++iter; + CHECK(reachedB); + CHECK(!reachedC); + CHECK(*iter == 2); + ++iter; + CHECK(reachedC); + CHECK(iter == gen.end()); +} + +TEST_CASE("destroying generator before completion destructs objects on stack") +{ + bool destructed = false; + bool completed = false; + auto f = [&]() -> generator + { + auto onExit = cppcoro::on_scope_exit([&] + { + destructed = true; + }); + + co_yield 1; + co_yield 2; + completed = true; + }; + + { + auto g = f(); + auto it = g.begin(); + auto itEnd = g.end(); + CHECK(it != itEnd); + CHECK(*it == 1u); + CHECK(!destructed); + } + + CHECK(!completed); + CHECK(destructed); +} + +TEST_CASE("generator throwing before yielding first element rethrows out of begin()") +{ + class X {}; + + auto g = []() -> cppcoro::generator + { + throw X{}; + co_return; + }(); + + try + { + g.begin(); + FAIL("should have thrown"); + } + catch (const X&) + { + } +} + +TEST_CASE("generator throwing after first element rethrows out of operator++") +{ + class X {}; + + auto g = []() -> cppcoro::generator + { + co_yield 1; + throw X{}; + }(); + + auto iter = g.begin(); + REQUIRE(iter != g.end()); + try + { + ++iter; + FAIL("should have thrown"); + } + catch (const X&) + { + } +} + +namespace +{ + template + auto concat(FIRST&& first, SECOND&& second) + { + using value_type = std::remove_reference_t; + return [](FIRST first, SECOND second) -> cppcoro::generator + { + for (auto&& x : first) co_yield x; + for (auto&& y : second) co_yield y; + }(std::forward(first), std::forward(second)); + } +} + +TEST_CASE("safe capture of r-value reference args") +{ + using namespace std::string_literals; + + // Check that we can capture l-values by reference and that temporary + // values are moved into the coroutine frame. + std::string byRef = "bar"; + auto g = concat("foo"s, concat(byRef, std::vector{ 'b', 'a', 'z' })); + + byRef = "buzz"; + + std::string s; + for (char c : g) + { + s += c; + } + + CHECK(s == "foobuzzbaz"); +} + +namespace +{ + cppcoro::generator range(int start, int end) + { + for (; start < end; ++start) + { + co_yield start; + } + } +} + +TEST_CASE("fmap operator") +{ + cppcoro::generator gen = range(0, 5) + | cppcoro::fmap([](int x) { return x * 3; }); + + auto it = gen.begin(); + CHECK(*it == 0); + CHECK(*++it == 3); + CHECK(*++it == 6); + CHECK(*++it == 9); + CHECK(*++it == 12); + CHECK(++it == gen.end()); +} + +namespace +{ + template + cppcoro::generator low_pass(Range rng) + { + auto it = std::begin(rng); + const auto itEnd = std::end(rng); + + const double invCount = 1.0 / window; + double sum = 0; + + using iter_cat = + typename std::iterator_traits::iterator_category; + + if constexpr (std::is_base_of_v) + { + for (std::size_t count = 0; it != itEnd && count < window; ++it) + { + sum += *it; + ++count; + co_yield sum / count; + } + + for (; it != itEnd; ++it) + { + sum -= *(it - window); + sum += *it; + co_yield sum * invCount; + } + } + else if constexpr (std::is_base_of_v) + { + auto windowStart = it; + for (std::size_t count = 0; it != itEnd && count < window; ++it) + { + sum += *it; + ++count; + co_yield sum / count; + } + + for (; it != itEnd; ++it, ++windowStart) + { + sum -= *windowStart; + sum += *it; + co_yield sum * invCount; + } + } + else + { + // Just assume an input iterator + double buffer[window]; + + for (std::size_t count = 0; it != itEnd && count < window; ++it) + { + buffer[count] = *it; + sum += buffer[count]; + ++count; + co_yield sum / count; + } + + for (std::size_t pos = 0; it != itEnd; ++it, pos = (pos + 1 == window) ? 0 : (pos + 1)) + { + sum -= std::exchange(buffer[pos], *it); + sum += buffer[pos]; + co_yield sum * invCount; + } + } + } +} + +// HACK: Disable this test as it's causing heap corruption errors under MSVC 2017 Update 5 x86 debug builds. +// Still needs investigation of root cause. +TEST_CASE("low_pass" * doctest::skip{ true }) +{ + // With random-access iterator + { + auto gen = low_pass<4>(std::vector{ 10, 13, 10, 15, 18, 9, 11, 15 }); + auto it = gen.begin(); + CHECK(*it == 10.0); + CHECK(*++it == 11.5); + CHECK(*++it == 11.0); + CHECK(*++it == 12.0); + CHECK(*++it == 14.0); + CHECK(*++it == 13.0); + CHECK(*++it == 13.25); + CHECK(*++it == 13.25); + CHECK(++it == gen.end()); + } + + // With forward-iterator + { + auto gen = low_pass<4>(std::forward_list{ 10, 13, 10, 15, 18, 9, 11, 15 }); + auto it = gen.begin(); + CHECK(*it == 10.0); + CHECK(*++it == 11.5); + CHECK(*++it == 11.0); + CHECK(*++it == 12.0); + CHECK(*++it == 14.0); + CHECK(*++it == 13.0); + CHECK(*++it == 13.25); + CHECK(*++it == 13.25); + CHECK(++it == gen.end()); + } + + // With input-iterator + { + auto gen = low_pass<3>(range(10, 20)); + auto it = gen.begin(); + CHECK(*it == 10.0); + CHECK(*++it == 10.5); + CHECK(*++it == 11.0); + CHECK(*++it == 12.0); + CHECK(*++it == 13.0); + CHECK(*++it == 14.0); + CHECK(*++it == 15.0); + CHECK(*++it == 16.0); + CHECK(*++it == 17.0); + CHECK(*++it == 18.0); + CHECK(++it == gen.end()); + } +} + +TEST_SUITE_END(); diff --git a/test/io_service_fixture.hpp b/test/io_service_fixture.hpp new file mode 100644 index 0000000..9552ebe --- /dev/null +++ b/test/io_service_fixture.hpp @@ -0,0 +1,71 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// +#ifndef CPPCORO_TESTS_IO_SERVICE_FIXTURE_HPP_INCLUDED +#define CPPCORO_TESTS_IO_SERVICE_FIXTURE_HPP_INCLUDED + +#include + +#include +#include + +/// \brief +/// Test fixture that creates an io_service and starts up a background thread +/// to process I/O completion events. +/// +/// Thread and io_service are shutdown on destruction. +struct io_service_fixture +{ +public: + + io_service_fixture(std::uint32_t threadCount = 1) + : m_ioService() + { + m_ioThreads.reserve(threadCount); + try + { + for (std::uint32_t i = 0; i < threadCount; ++i) + { + m_ioThreads.emplace_back([this] { m_ioService.process_events(); }); + } + } + catch (...) + { + stop(); + throw; + } + } + + ~io_service_fixture() + { + stop(); + } + + cppcoro::io_service& io_service() { return m_ioService; } + +private: + + void stop() + { + m_ioService.stop(); + for (auto& thread : m_ioThreads) + { + thread.join(); + } + } + + cppcoro::io_service m_ioService; + std::vector m_ioThreads; + +}; + +template +struct io_service_fixture_with_threads : io_service_fixture +{ + io_service_fixture_with_threads() + : io_service_fixture(thread_count) + {} +}; + +#endif diff --git a/test/io_service_tests.cpp b/test/io_service_tests.cpp new file mode 100644 index 0000000..0e64f18 --- /dev/null +++ b/test/io_service_tests.cpp @@ -0,0 +1,230 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "io_service_fixture.hpp" + +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("io_service"); + +TEST_CASE("default construct") +{ + cppcoro::io_service service; + CHECK_FALSE(service.is_stop_requested()); +} + +TEST_CASE("construct with concurrency hint") +{ + cppcoro::io_service service{ 3 }; + CHECK_FALSE(service.is_stop_requested()); +} + +TEST_CASE("process_one_pending_event returns immediately when no events") +{ + cppcoro::io_service service; + CHECK(service.process_one_pending_event() == 0); + CHECK(service.process_pending_events() == 0); +} + +TEST_CASE("schedule coroutine") +{ + cppcoro::io_service service; + + bool reachedPointA = false; + bool reachedPointB = false; + auto startTask = [&](cppcoro::io_service& ioService) -> cppcoro::task<> + { + reachedPointA = true; + co_await ioService.schedule(); + reachedPointB = true; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + startTask(service), + [&]() -> cppcoro::task<> + { + CHECK(reachedPointA); + CHECK_FALSE(reachedPointB); + + service.process_pending_events(); + + CHECK(reachedPointB); + + co_return; + }())); +} + +TEST_CASE_FIXTURE(io_service_fixture_with_threads<2>, "multiple I/O threads servicing events") +{ + std::atomic completedCount = 0; + + auto runOnIoThread = [&]() -> cppcoro::task<> + { + co_await io_service().schedule(); + ++completedCount; + }; + + std::vector> tasks; + { + for (int i = 0; i < 1000; ++i) + { + tasks.emplace_back(runOnIoThread()); + } + } + + cppcoro::sync_wait(cppcoro::when_all(std::move(tasks))); + + CHECK(completedCount == 1000); +} + +TEST_CASE("Multiple concurrent timers") +{ + cppcoro::io_service ioService; + + auto startTimer = [&](std::chrono::milliseconds duration) + -> cppcoro::task + { + auto start = std::chrono::high_resolution_clock::now(); + + co_await ioService.schedule_after(duration); + + auto end = std::chrono::high_resolution_clock::now(); + + co_return end - start; + }; + + auto test = [&]() -> cppcoro::task<> + { + using namespace std::chrono; + using namespace std::chrono_literals; + + auto[time1, time2, time3] = co_await cppcoro::when_all( + startTimer(100ms), + startTimer(120ms), + startTimer(50ms)); + + MESSAGE("Waiting 100ms took " << duration_cast(time1).count() << "us"); + MESSAGE("Waiting 120ms took " << duration_cast(time2).count() << "us"); + MESSAGE("Waiting 50ms took " << duration_cast(time3).count() << "us"); + + CHECK(time1 >= 100ms); + CHECK(time2 >= 120ms); + CHECK(time3 >= 50ms); + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto stopIoOnExit = cppcoro::on_scope_exit([&] { ioService.stop(); }); + co_await test(); + }(), + [&]() -> cppcoro::task<> + { + ioService.process_events(); + co_return; + }())); +} + +TEST_CASE("Timer cancellation" + * doctest::timeout{ 5.0 }) +{ + using namespace std::literals::chrono_literals; + + cppcoro::io_service ioService; + + auto longWait = [&](cppcoro::cancellation_token ct) -> cppcoro::task<> + { + co_await ioService.schedule_after(20'000ms, ct); + }; + + auto cancelAfter = [&](cppcoro::cancellation_source source, auto duration) -> cppcoro::task<> + { + co_await ioService.schedule_after(duration); + source.request_cancellation(); + }; + + auto test = [&]() -> cppcoro::task<> + { + cppcoro::cancellation_source source; + co_await cppcoro::when_all_ready( + [&](cppcoro::cancellation_token ct) -> cppcoro::task<> + { + CHECK_THROWS_AS(co_await longWait(std::move(ct)), const cppcoro::operation_cancelled&); + }(source.token()), + cancelAfter(source, 1ms)); + }; + + auto testTwice = [&]() -> cppcoro::task<> + { + co_await test(); + co_await test(); + }; + + auto stopIoServiceAfter = [&](cppcoro::task<> task) -> cppcoro::task<> + { + co_await task.when_ready(); + ioService.stop(); + co_return co_await task.when_ready(); + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + stopIoServiceAfter(testTwice()), + [&]() -> cppcoro::task<> + { + ioService.process_events(); + co_return; + }())); +} + +TEST_CASE_FIXTURE(io_service_fixture_with_threads<1>, "Many concurrent timers") +{ + auto startTimer = [&]() -> cppcoro::task<> + { + using namespace std::literals::chrono_literals; + co_await io_service().schedule_after(50ms); + }; + + constexpr std::uint32_t taskCount = 10'000; + + auto runManyTimers = [&]() -> cppcoro::task<> + { + std::vector> tasks; + + tasks.reserve(taskCount); + + for (std::uint32_t i = 0; i < taskCount; ++i) + { + tasks.emplace_back(startTimer()); + } + + co_await cppcoro::when_all(std::move(tasks)); + }; + + auto start = std::chrono::high_resolution_clock::now(); + + cppcoro::sync_wait(runManyTimers()); + + auto end = std::chrono::high_resolution_clock::now(); + + MESSAGE( + "Waiting for " << taskCount << " x 50ms timers took " + << std::chrono::duration_cast(end - start).count() + << "ms"); +} + +TEST_SUITE_END(); diff --git a/test/ip_address_tests.cpp b/test/ip_address_tests.cpp new file mode 100644 index 0000000..a09ae32 --- /dev/null +++ b/test/ip_address_tests.cpp @@ -0,0 +1,45 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("ip_address"); + +using cppcoro::net::ip_address; +using cppcoro::net::ipv4_address; +using cppcoro::net::ipv6_address; + +TEST_CASE("default constructor") +{ + ip_address x; + CHECK(x.is_ipv4()); + CHECK(x.to_ipv4() == ipv4_address{}); +} + +TEST_CASE("to_string") +{ + ip_address a = ipv6_address{ 0xAABBCCDD00112233, 0x0102030405060708 }; + ip_address b = ipv4_address{ 192, 168, 0, 1 }; + + CHECK(a.to_string() == "aabb:ccdd:11:2233:102:304:506:708"); + CHECK(b.to_string() == "192.168.0.1"); +} + +TEST_CASE("from_string") +{ + CHECK(ip_address::from_string("") == std::nullopt); + CHECK(ip_address::from_string("foo") == std::nullopt); + CHECK(ip_address::from_string(" 192.168.0.1") == std::nullopt); + CHECK(ip_address::from_string("192.168.0.1asdf") == std::nullopt); + + CHECK(ip_address::from_string("192.168.0.1") == ipv4_address(192, 168, 0, 1)); + CHECK(ip_address::from_string("::192.168.0.1") == ipv6_address(0, 0, 0, 0, 0, 0, 0xc0a8, 0x1)); + CHECK(ip_address::from_string("aabb:ccdd:11:2233:102:304:506:708") == + ipv6_address{ 0xAABBCCDD00112233, 0x0102030405060708 }); +} + +TEST_SUITE_END(); diff --git a/test/ip_endpoint_tests.cpp b/test/ip_endpoint_tests.cpp new file mode 100644 index 0000000..6eafaeb --- /dev/null +++ b/test/ip_endpoint_tests.cpp @@ -0,0 +1,56 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("ip_endpoint"); + +using namespace cppcoro::net; + +namespace +{ + constexpr bool isMsvc15_5X86Optimised = +#if CPPCORO_COMPILER_MSVC && CPPCORO_CPU_X86 && _MSC_VER == 1912 && defined(CPPCORO_RELEASE_OPTIMISED) + true; +#else + false; +#endif +} + +// BUG: Skip this test under MSVC 15.5 x86 optimised builds due to a compiler bug +// that generates bad code. +// See https://developercommunity.visualstudio.com/content/problem/177151/bad-code-generation-under-x86-optimised-for-stdopt.html +TEST_CASE("to_string" * doctest::skip{ isMsvc15_5X86Optimised }) +{ + ip_endpoint a = ipv4_endpoint{ ipv4_address{ 192, 168, 2, 254 }, 80 }; + ip_endpoint b = ipv6_endpoint{ + *ipv6_address::from_string("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + 22 }; + + CHECK(a.to_string() == "192.168.2.254:80"); + CHECK(b.to_string() == "[2001:db8:85a3::8a2e:370:7334]:22"); +} + +TEST_CASE("from_string" * doctest::skip{ isMsvc15_5X86Optimised }) +{ + CHECK(ip_endpoint::from_string("") == std::nullopt); + CHECK(ip_endpoint::from_string("[foo]:123") == std::nullopt); + CHECK(ip_endpoint::from_string("[123]:1000") == std::nullopt); + CHECK(ip_endpoint::from_string("[10.11.12.13]:1000") == std::nullopt); + + CHECK(ip_endpoint::from_string("192.168.2.254:80") == + ipv4_endpoint{ + ipv4_address{ 192, 168, 2, 254 }, 80 }); + CHECK(ip_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:443") == + ipv6_endpoint{ + ipv6_address{ 0x2001, 0xdb8, 0x85a3, 0x0, 0x0, 0x8a2e, 0x370, 0x7334 }, + 443 }); +} + +TEST_SUITE_END(); + diff --git a/test/ipv4_address_tests.cpp b/test/ipv4_address_tests.cpp new file mode 100644 index 0000000..d241657 --- /dev/null +++ b/test/ipv4_address_tests.cpp @@ -0,0 +1,90 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "doctest/cppcoro_doctest.h" + + +TEST_SUITE_BEGIN("ipv4_address"); + +using cppcoro::net::ipv4_address; + +TEST_CASE("DefaultConstructToZeroes") +{ + CHECK(ipv4_address{}.to_integer() == 0u); +} + +TEST_CASE("to_integer() is BigEndian") +{ + ipv4_address address{ 10, 11, 12, 13 }; + CHECK(address.to_integer() == 0x0A0B0C0Du); +} + +TEST_CASE("is_loopback()") +{ + CHECK(ipv4_address{ 127, 0, 0, 1 }.is_loopback()); + CHECK(ipv4_address{ 127, 0, 0, 50 }.is_loopback()); + CHECK(ipv4_address{ 127, 5, 10, 15 }.is_loopback()); + CHECK(!ipv4_address{ 10, 11, 12, 13 }.is_loopback()); +} + +TEST_CASE("bytes()") +{ + ipv4_address ip{ 19, 63, 129, 200 }; + CHECK(ip.bytes()[0] == 19); + CHECK(ip.bytes()[1] == 63); + CHECK(ip.bytes()[2] == 129); + CHECK(ip.bytes()[3] == 200); +} + +TEST_CASE("to_string()") +{ + CHECK(ipv4_address(0, 0, 0, 0).to_string() == "0.0.0.0"); + CHECK(ipv4_address(10, 125, 255, 7).to_string() == "10.125.255.7"); + CHECK(ipv4_address(123, 234, 101, 255).to_string() == "123.234.101.255"); +} + +TEST_CASE("from_string") +{ + // Check for some invalid strings. + CHECK(ipv4_address::from_string("") == std::nullopt); + CHECK(ipv4_address::from_string("asdf") == std::nullopt); + CHECK(ipv4_address::from_string(" 123.34.56.8") == std::nullopt); + CHECK(ipv4_address::from_string("123.34.56.8 ") == std::nullopt); + CHECK(ipv4_address::from_string("123.") == std::nullopt); + CHECK(ipv4_address::from_string("123.1") == std::nullopt); + CHECK(ipv4_address::from_string("123.12") == std::nullopt); + CHECK(ipv4_address::from_string("123.12.") == std::nullopt); + CHECK(ipv4_address::from_string("123.12.4") == std::nullopt); + CHECK(ipv4_address::from_string("123.12.45") == std::nullopt); + CHECK(ipv4_address::from_string("123.12.45.") == std::nullopt); + + // Overflow of individual parts + CHECK(ipv4_address::from_string("456.12.45.30") == std::nullopt); + CHECK(ipv4_address::from_string("45.256.45.30") == std::nullopt); + CHECK(ipv4_address::from_string("45.25.677.30") == std::nullopt); + CHECK(ipv4_address::from_string("123.12.45.301") == std::nullopt); + + // Can't parse octal yet. + CHECK(ipv4_address::from_string("00") == std::nullopt); + CHECK(ipv4_address::from_string("012345") == std::nullopt); + CHECK(ipv4_address::from_string("045.25.67.30") == std::nullopt); + CHECK(ipv4_address::from_string("45.025.67.30") == std::nullopt); + CHECK(ipv4_address::from_string("45.25.067.30") == std::nullopt); + CHECK(ipv4_address::from_string("45.25.67.030") == std::nullopt); + + // Parse single integer format + CHECK(ipv4_address::from_string("0") == ipv4_address(0)); + CHECK(ipv4_address::from_string("1") == ipv4_address(0, 0, 0, 1)); + CHECK(ipv4_address::from_string("255") == ipv4_address(0, 0, 0, 255)); + CHECK(ipv4_address::from_string("43534243") == ipv4_address(43534243)); + + // Parse dotted decimal format + CHECK(ipv4_address::from_string("45.25.67.30") == ipv4_address(45, 25, 67, 30)); + CHECK(ipv4_address::from_string("0.0.0.0") == ipv4_address(0, 0, 0, 0)); + CHECK(ipv4_address::from_string("1.2.3.4") == ipv4_address(1, 2, 3, 4)); +} +TEST_SUITE_END(); diff --git a/test/ipv4_endpoint_tests.cpp b/test/ipv4_endpoint_tests.cpp new file mode 100644 index 0000000..7d817bb --- /dev/null +++ b/test/ipv4_endpoint_tests.cpp @@ -0,0 +1,33 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("ip_endpoint"); + +using namespace cppcoro::net; + +TEST_CASE("to_string") +{ + CHECK(ipv4_endpoint{ ipv4_address{ 192, 168, 2, 254 }, 80 }.to_string() == "192.168.2.254:80"); +} + +TEST_CASE("from_string") +{ + CHECK(ipv4_endpoint::from_string("") == std::nullopt); + CHECK(ipv4_endpoint::from_string(" ") == std::nullopt); + CHECK(ipv4_endpoint::from_string("100") == std::nullopt); + CHECK(ipv4_endpoint::from_string("100.10.200.20") == std::nullopt); + CHECK(ipv4_endpoint::from_string("100.10.200.20:") == std::nullopt); + CHECK(ipv4_endpoint::from_string("100.10.200.20::80") == std::nullopt); + CHECK(ipv4_endpoint::from_string("100.10.200.20 80") == std::nullopt); + + CHECK(ipv4_endpoint::from_string("192.168.2.254:80") == + ipv4_endpoint{ ipv4_address{ 192, 168, 2, 254 }, 80 }); +} + +TEST_SUITE_END(); diff --git a/test/ipv6_address_tests.cpp b/test/ipv6_address_tests.cpp new file mode 100644 index 0000000..b8fd132 --- /dev/null +++ b/test/ipv6_address_tests.cpp @@ -0,0 +1,154 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include "doctest/cppcoro_doctest.h" + + +TEST_SUITE_BEGIN("ipv6_address"); + +using cppcoro::net::ipv6_address; + +TEST_CASE("default constructor") +{ + ipv6_address zero; + for (std::uint8_t i = 0; i < 16; ++i) + { + CHECK(zero.bytes()[i] == 0); + } + + CHECK(zero == ipv6_address::unspecified()); +} + +TEST_CASE("to_string") +{ + CHECK(ipv6_address(0, 0).to_string() == "::"); + CHECK(ipv6_address::loopback().to_string() == "::1"); + + CHECK( + ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0F10).to_string() == + "102:304:506:708:90a:b0c:d0e:f10"); + CHECK( + ipv6_address(0x0001001001001000, 0x0).to_string() == + "1:10:100:1000::"); + CHECK( + ipv6_address(0x0002030405060708, 0x090A0B0C0D0E0F10).to_string() == + "2:304:506:708:90a:b0c:d0e:f10"); + CHECK( + ipv6_address(0x0000030405060708, 0x090A0B0C0D0E0F10).to_string() == + "0:304:506:708:90a:b0c:d0e:f10"); + CHECK( + ipv6_address(0x0000000005060708, 0x090A0B0C0D0E0F10).to_string() == + "::506:708:90a:b0c:d0e:f10"); + CHECK( + ipv6_address(0x0102030400000000, 0x00000B0C0D0E0F10).to_string() == + "102:304::b0c:d0e:f10"); + CHECK( + ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0000).to_string() == + "102:304:506:708:90a:b0c:d0e:0"); + CHECK( + ipv6_address(0x0102030405060708, 0x090A0B0C00000000).to_string() == + "102:304:506:708:90a:b0c::"); + + // Check that it contracts the first of multiple equal-length zero runs. + CHECK( + ipv6_address(0x0102030400000000, 0x090A0B0C00000000).to_string() == + "102:304::90a:b0c:0:0"); +} + +TEST_CASE("from_string") +{ + CHECK(ipv6_address::from_string("") == std::nullopt); + CHECK(ipv6_address::from_string("123") == std::nullopt); + CHECK(ipv6_address::from_string("foo") == std::nullopt); + CHECK(ipv6_address::from_string(":1234") == std::nullopt); + CHECK(ipv6_address::from_string("0102:0304:0506:0708:090a:0b0c:0d0e:0f10 ") == std::nullopt); + CHECK( + ipv6_address::from_string(" 0102:0304:0506:0708:090a:0b0c:0d0e:0f10") == + std::nullopt); + CHECK( + ipv6_address::from_string("0102:0304:0506:0708:090a:0b0c:0d0e:0f10:") == + std::nullopt); + CHECK( + ipv6_address::from_string("0102:0304:0506:0708:090a:0b0c:0d0e") == + std::nullopt); + CHECK( + ipv6_address::from_string("01022:0304:0506:0708:090a:0b0c:0d0e:0f10") == + std::nullopt); + CHECK( + ipv6_address::from_string("0102:0304:0506:192.168.0.1:0b0c:0d0e:0f10") == + std::nullopt); + CHECK(ipv6_address::from_string("::") == ipv6_address(0, 0)); + CHECK(ipv6_address::from_string("::1") == ipv6_address::loopback()); + CHECK(ipv6_address::from_string("::01") == ipv6_address::loopback()); + CHECK(ipv6_address::from_string("::001") == ipv6_address::loopback()); + CHECK(ipv6_address::from_string("::0001") == ipv6_address::loopback()); + CHECK( + ipv6_address::from_string("0102:0304:0506:0708:090a:0b0c:0d0e:0f10") == + ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0F10)); + CHECK( + ipv6_address::from_string("0002:0304:0506:0708:090a:0b0c:0d0e:0f10") == + ipv6_address(0x0002030405060708, 0x090A0B0C0D0E0F10)); + CHECK( + ipv6_address::from_string("0000:0304:0506:0708:090a:0b0c:0d0e:0f10") == + ipv6_address(0x0000030405060708, 0x090A0B0C0D0E0F10)); + CHECK( + ipv6_address::from_string("::0506:0708:090a:0b0c:0d0e:0f10") == + ipv6_address(0x0000000005060708, 0x090A0B0C0D0E0F10)); + CHECK( + ipv6_address::from_string("0102:0304::0b0c:0d0e:0f10") == + ipv6_address(0x0102030400000000, 0x00000B0C0D0E0F10)); + CHECK( + ipv6_address::from_string("0102:0304:0506:0708:090a:0b0c::") == + ipv6_address(0x0102030405060708, 0x090A0B0C00000000)); + CHECK( + ipv6_address::from_string("2001:db8:85a3:8d3:1319:8a2e:370:7348") == + ipv6_address(0x20010db885a308d3, 0x13198a2e03707348)); +} + +TEST_CASE("from_string IPv4 interop format") +{ + CHECK( + ipv6_address::from_string("::ffff:192.168.0.1") == + ipv6_address(0x0, 0xffffc0a80001)); + CHECK( + ipv6_address::from_string("0102:0304::128.69.32.17") == + ipv6_address(0x0102030400000000, 0x0000000080452011)); + CHECK( + ipv6_address::from_string("0102:0304::128.69.32.17") == + ipv6_address(0x0102030400000000, 0x0000000080452011)); + + // Hexadecimal chars in dotted decimal part + CHECK(ipv6_address::from_string("64:ff9b::12f.100.30.1") == std::nullopt); + CHECK(ipv6_address::from_string("64:ff9b::123.10a.30.1") == std::nullopt); + CHECK(ipv6_address::from_string("64:ff9b::123.100.3d.1") == std::nullopt); + CHECK(ipv6_address::from_string("64:ff9b::12f.100.30.f4") == std::nullopt); + + // Overflow of individual parts of dotted decimal notation + CHECK(ipv6_address::from_string("::ffff:456.12.45.30") == std::nullopt); + CHECK(ipv6_address::from_string("::ffff:45.256.45.30") == std::nullopt); + CHECK(ipv6_address::from_string("::ffff:45.25.677.30") == std::nullopt); + CHECK(ipv6_address::from_string("::ffff:123.12.45.301") == std::nullopt); +} + +TEST_CASE("operator<") +{ + ipv6_address a(0x0, 0x1); + ipv6_address b(0xff00000000000011, 0xee00000000000022); + ipv6_address c(0xee00000000000022, 0xee00000000000022); + ipv6_address d(0xee00000000000022, 0xff00000000000011); + + CHECK(a <= a); + CHECK(a < b); + CHECK(a < c); + CHECK(a < d); + CHECK(b >= b); + CHECK(b > c); + CHECK(b > d); + CHECK(c < d); +} + +TEST_SUITE_END(); diff --git a/test/ipv6_endpoint_tests.cpp b/test/ipv6_endpoint_tests.cpp new file mode 100644 index 0000000..b72350f --- /dev/null +++ b/test/ipv6_endpoint_tests.cpp @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("ipv6_endpoint"); + +using namespace cppcoro::net; + +namespace +{ + constexpr bool isMsvc15_5X86Optimised = +#if CPPCORO_COMPILER_MSVC && CPPCORO_CPU_X86 && _MSC_VER == 1912 && defined(CPPCORO_RELEASE_OPTIMISED) + true; +#else + false; +#endif +} + +// BUG: MSVC 15.5 x86 optimised builds generates bad code +TEST_CASE("to_string" * doctest::skip{ isMsvc15_5X86Optimised }) +{ + CHECK(ipv6_endpoint{ ipv6_address{ 0x20010db885a30000, 0x00008a2e03707334 }, 80 }.to_string() == + "[2001:db8:85a3::8a2e:370:7334]:80"); +} + +// BUG: MSVC 15.5 x86 optimised builds generates bad code +TEST_CASE("from_string" * doctest::skip{ isMsvc15_5X86Optimised }) +{ + CHECK(ipv6_endpoint::from_string("") == std::nullopt); + CHECK(ipv6_endpoint::from_string(" ") == std::nullopt); + CHECK(ipv6_endpoint::from_string("asdf") == std::nullopt); + CHECK(ipv6_endpoint::from_string("100:100") == std::nullopt); + CHECK(ipv6_endpoint::from_string("100.10.200.20:100") == std::nullopt); + CHECK(ipv6_endpoint::from_string("2001:0db8:85a3:0000:0000:8a2e:0370:7334") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334] :123") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:65536") == std::nullopt); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:6553600") == std::nullopt); + + CHECK(ipv6_endpoint::from_string("[::]:0") == ipv6_endpoint{}); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:80") == + ipv6_endpoint{ ipv6_address{ 0x20010db885a30000, 0x00008a2e03707334 }, 80 }); + CHECK(ipv6_endpoint::from_string("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:65535") == + ipv6_endpoint{ ipv6_address{ 0x20010db885a30000, 0x00008a2e03707334 }, 65535 }); +} + +TEST_SUITE_END(); diff --git a/test/main.cpp b/test/main.cpp new file mode 100644 index 0000000..2956c4d --- /dev/null +++ b/test/main.cpp @@ -0,0 +1,7 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/cppcoro_doctest.h" diff --git a/test/multi_producer_sequencer_tests.cpp b/test/multi_producer_sequencer_tests.cpp new file mode 100644 index 0000000..0513ed6 --- /dev/null +++ b/test/multi_producer_sequencer_tests.cpp @@ -0,0 +1,205 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +DOCTEST_TEST_SUITE_BEGIN("multi_producer_sequencer"); + +using namespace cppcoro; + +namespace +{ + task<> one_at_a_time_producer( + static_thread_pool& tp, + multi_producer_sequencer& sequencer, + std::uint64_t buffer[], + std::uint64_t iterationCount) + { + if (iterationCount == 0) co_return; + + co_await tp.schedule(); + + const std::size_t bufferSize = sequencer.buffer_size(); + const std::size_t mask = bufferSize - 1; + + std::uint64_t i = 0; + while (i < iterationCount) + { + auto seq = co_await sequencer.claim_one(tp); + buffer[seq & mask] = ++i; + sequencer.publish(seq); + } + + auto finalSeq = co_await sequencer.claim_one(tp); + buffer[finalSeq & mask] = 0; + sequencer.publish(finalSeq); + } + + task<> batch_producer( + static_thread_pool& tp, + multi_producer_sequencer& sequencer, + std::uint64_t buffer[], + std::uint64_t iterationCount, + std::size_t maxBatchSize) + { + const std::size_t bufferSize = sequencer.buffer_size(); + + std::uint64_t i = 0; + while (i < iterationCount) + { + const std::size_t batchSize = static_cast( + std::min(maxBatchSize, iterationCount - i)); + auto sequences = co_await sequencer.claim_up_to(batchSize, tp); + for (auto seq : sequences) + { + buffer[seq % bufferSize] = ++i; + } + sequencer.publish(sequences); + } + + auto finalSeq = co_await sequencer.claim_one(tp); + buffer[finalSeq % bufferSize] = 0; + sequencer.publish(finalSeq); + } + + task consumer( + static_thread_pool& tp, + const multi_producer_sequencer& sequencer, + sequence_barrier& readBarrier, + const std::uint64_t buffer[], + std::uint32_t producerCount) + { + co_await tp.schedule(); + + const std::size_t mask = sequencer.buffer_size() - 1; + + std::uint64_t sum = 0; + + std::uint32_t endCount = 0; + std::size_t nextToRead = 0; + do + { + std::size_t available = co_await sequencer.wait_until_published(nextToRead, nextToRead - 1, tp); + do + { + const auto& value = buffer[nextToRead & mask]; + sum += value; + + // Zero value is sentinel that indicates the end of one of the streams. + const bool isEndOfStream = value == 0; + endCount += isEndOfStream ? 1 : 0; + } while (nextToRead++ != available); + + // Notify that we've finished processing up to 'available'. + readBarrier.publish(available); + } while (endCount < producerCount); + + co_return sum; + } +} + +DOCTEST_TEST_CASE("two producers (batch) / single consumer") +{ + static_thread_pool tp{ 3 }; + + // Allow time for threads to start up. + using namespace std::chrono_literals; + std::this_thread::sleep_for(1ms); + + constexpr std::size_t batchSize = 10; + constexpr std::size_t bufferSize = 16384; + + sequence_barrier readBarrier; + multi_producer_sequencer sequencer(readBarrier, bufferSize); + + constexpr std::uint64_t iterationCount = 1'000'000; + + std::uint64_t buffer[bufferSize]; + + auto startTime = std::chrono::high_resolution_clock::now(); + + constexpr std::uint32_t producerCount = 2; + auto result = std::get<0>(sync_wait(when_all( + consumer(tp, sequencer, readBarrier, buffer, producerCount), + batch_producer(tp, sequencer, buffer, iterationCount, batchSize), + batch_producer(tp, sequencer, buffer, iterationCount, batchSize)))); + + auto endTime = std::chrono::high_resolution_clock::now(); + + auto totalTimeInNs = std::chrono::duration_cast(endTime - startTime).count(); + + MESSAGE( + "Producers = " << producerCount + << ", BatchSize = " << batchSize + << ", MessagesPerProducer = " << iterationCount + << ", TotalTime = " << totalTimeInNs/1000 << "us" + << ", TimePerMessage = " << totalTimeInNs/double(iterationCount * producerCount) << "ns" + << ", MessagesPerSecond = " << 1'000'000'000 * (producerCount * iterationCount) / totalTimeInNs); + + constexpr std::uint64_t expectedResult = + producerCount * std::uint64_t(iterationCount) * std::uint64_t(iterationCount + 1) / 2; + + CHECK(result == expectedResult); +} + +DOCTEST_TEST_CASE("two producers (single) / single consumer") +{ + static_thread_pool tp{ 3 }; + + // Allow time for threads to start up. + using namespace std::chrono_literals; + std::this_thread::sleep_for(1ms); + + constexpr std::size_t bufferSize = 16384; + + sequence_barrier readBarrier; + multi_producer_sequencer sequencer(readBarrier, bufferSize); + + constexpr std::uint64_t iterationCount = 1'000'000; + + std::uint64_t buffer[bufferSize]; + + auto startTime = std::chrono::high_resolution_clock::now(); + + constexpr std::uint32_t producerCount = 2; + auto result = std::get<0>(sync_wait(when_all( + consumer(tp, sequencer, readBarrier, buffer, producerCount), + one_at_a_time_producer(tp, sequencer, buffer, iterationCount), + one_at_a_time_producer(tp, sequencer, buffer, iterationCount)))); + + auto endTime = std::chrono::high_resolution_clock::now(); + + auto totalTimeInNs = std::chrono::duration_cast(endTime - startTime).count(); + + MESSAGE( + "Producers = " << producerCount + << ", NoBatch" + << ", MessagesPerProducer = " << iterationCount + << ", TotalTime = " << totalTimeInNs / 1000 << "us" + << ", TimePerMessage = " << totalTimeInNs / double(iterationCount * producerCount) << "ns" + << ", MessagesPerSecond = " << 1'000'000'000 * (producerCount * iterationCount) / totalTimeInNs); + + constexpr std::uint64_t expectedResult = + producerCount * std::uint64_t(iterationCount) * std::uint64_t(iterationCount + 1) / 2; + + CHECK(result == expectedResult); +} + +DOCTEST_TEST_SUITE_END(); diff --git a/test/recursive_generator_tests.cpp b/test/recursive_generator_tests.cpp new file mode 100644 index 0000000..14fce3c --- /dev/null +++ b/test/recursive_generator_tests.cpp @@ -0,0 +1,424 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include + +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("recursive_generator"); + +using cppcoro::recursive_generator; + +TEST_CASE("default constructed recursive_generator is empty") +{ + recursive_generator ints; + CHECK(ints.begin() == ints.end()); +} + +TEST_CASE("non-recursive use of recursive_generator") +{ + auto f = []() -> recursive_generator + { + co_yield 1.0f; + co_yield 2.0f; + }; + + auto gen = f(); + auto iter = gen.begin(); + CHECK(*iter == 1.0f); + ++iter; + CHECK(*iter == 2.0f); + ++iter; + CHECK(iter == gen.end()); +} + +TEST_CASE("throw before first yield") +{ + class MyException : public std::exception {}; + + auto f = []() -> recursive_generator + { + throw MyException{}; + co_return; + }; + + auto gen = f(); + try + { + auto iter = gen.begin(); + CHECK(false); + } + catch (MyException) + { + CHECK(true); + } +} + +TEST_CASE("throw after first yield") +{ + class MyException : public std::exception {}; + + auto f = []() -> recursive_generator + { + co_yield 1; + throw MyException{}; + }; + + auto gen = f(); + auto iter = gen.begin(); + CHECK(*iter == 1u); + try + { + ++iter; + CHECK(false); + } + catch (MyException) + { + CHECK(true); + } +} + +TEST_CASE("generator doesn't start executing until begin is called") +{ + bool reachedA = false; + bool reachedB = false; + bool reachedC = false; + auto f = [&]() -> recursive_generator + { + reachedA = true; + co_yield 1; + reachedB = true; + co_yield 2; + reachedC = true; + }; + + auto gen = f(); + CHECK(!reachedA); + auto iter = gen.begin(); + CHECK(reachedA); + CHECK(!reachedB); + CHECK(*iter == 1u); + ++iter; + CHECK(reachedB); + CHECK(!reachedC); + CHECK(*iter == 2u); + ++iter; + CHECK(reachedC); + CHECK(iter == gen.end()); +} + +TEST_CASE("destroying generator before completion destructs objects on stack") +{ + bool destructed = false; + bool completed = false; + auto f = [&]() -> recursive_generator + { + auto onExit = cppcoro::on_scope_exit([&] + { + destructed = true; + }); + + co_yield 1; + co_yield 2; + completed = true; + }; + + { + auto g = f(); + auto it = g.begin(); + auto itEnd = g.end(); + CHECK(*it == 1u); + CHECK(!destructed); + } + + CHECK(!completed); + CHECK(destructed); +} + +TEST_CASE("simple recursive yield") +{ + auto f = [](int n, auto& f) -> recursive_generator + { + co_yield n; + if (n > 0) + { + co_yield f(n - 1, f); + co_yield n; + } + }; + + auto f2 = [&f](int n) + { + return f(n, f); + }; + + { + auto gen = f2(1); + auto iter = gen.begin(); + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 0u); + ++iter; + CHECK(*iter == 1u); + ++iter; + CHECK(iter == gen.end()); + } + + { + auto gen = f2(2); + auto iter = gen.begin(); + CHECK(*iter == 2u); + ++iter; + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 0u); + ++iter; + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 2u); + ++iter; + CHECK(iter == gen.end()); + } +} + +TEST_CASE("nested yield that yields nothing") +{ + auto f = []() -> recursive_generator + { + co_return; + }; + + auto g = [&f]() -> recursive_generator + { + co_yield 1; + co_yield f(); + co_yield 2; + }; + + auto gen = g(); + auto iter = gen.begin(); + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 2u); + ++iter; + CHECK(iter == gen.end()); +} + +TEST_CASE("exception thrown from recursive call can be caught by caller") +{ + class SomeException : public std::exception {}; + + auto f = [](std::uint32_t depth, auto&& f) -> recursive_generator + { + if (depth == 1u) + { + throw SomeException{}; + } + + co_yield 1; + + try + { + co_yield f(1, f); + } + catch (SomeException) + { + } + + co_yield 2; + }; + + auto gen = f(0, f); + auto iter = gen.begin(); + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 2u); + ++iter; + CHECK(iter == gen.end()); +} + +TEST_CASE("exceptions thrown from nested call can be caught by caller") +{ + class SomeException : public std::exception {}; + + auto f = [](std::uint32_t depth, auto&& f) -> recursive_generator + { + if (depth == 4u) + { + throw SomeException{}; + } + else if (depth == 3u) + { + co_yield 3; + + try + { + co_yield f(4, f); + } + catch (SomeException) + { + } + + co_yield 33; + + throw SomeException{}; + } + else if (depth == 2u) + { + bool caught = false; + try + { + co_yield f(3, f); + } + catch (SomeException) + { + caught = true; + } + + if (caught) + { + co_yield 2; + } + } + else + { + co_yield 1; + co_yield f(2, f); + co_yield f(3, f); + } + }; + + auto gen = f(1, f); + auto iter = gen.begin(); + CHECK(*iter == 1u); + ++iter; + CHECK(*iter == 3u); + ++iter; + CHECK(*iter == 33u); + ++iter; + CHECK(*iter == 2u); + ++iter; + CHECK(*iter == 3u); + ++iter; + CHECK(*iter == 33u); + try + { + ++iter; + CHECK(false); + } + catch (SomeException) + { + } + + CHECK(iter == gen.end()); +} + +namespace +{ + recursive_generator iterate_range(std::uint32_t begin, std::uint32_t end) + { + if ((end - begin) <= 10u) + { + for (std::uint32_t i = begin; i < end; ++i) + { + co_yield i; + } + } + else + { + std::uint32_t mid = begin + (end - begin) / 2; + co_yield iterate_range(begin, mid); + co_yield iterate_range(mid, end); + } + } +} + +TEST_CASE("recursive iteration performance") +{ + const std::uint32_t count = 100000; + + auto start = std::chrono::high_resolution_clock::now(); + + std::uint64_t sum = 0; + for (auto i : iterate_range(0, count)) + { + sum += i; + } + + auto end = std::chrono::high_resolution_clock::now(); + + CHECK(sum == (std::uint64_t(count) * (count - 1)) / 2); + + const auto timeTakenUs = std::chrono::duration_cast(end - start).count(); + MESSAGE("Range iteration of " << count << "elements took " << timeTakenUs << "us"); +} + +TEST_CASE("usage in standard algorithms") +{ + { + auto a = iterate_range(5, 30); + auto b = iterate_range(5, 30); + CHECK(std::equal(a.begin(), a.end(), b.begin(), b.end())); + } + + { + auto a = iterate_range(5, 30); + auto b = iterate_range(5, 300); + CHECK(!std::equal(a.begin(), a.end(), b.begin(), b.end())); + } +} + +namespace +{ + recursive_generator range(int start, int end) + { + while (start < end) + { + co_yield start++; + } + } + + recursive_generator range_chunks(int start, int end, int runLength, int stride) + { + while (start < end) + { + co_yield range(start, std::min(end, start + runLength)); + start += stride; + } + } +} + +TEST_CASE("fmap operator") +{ + // 0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 20, 21, 22, 23, 24 + cppcoro::generator gen = range_chunks(0, 30, 5, 10) + | cppcoro::fmap([](int x) { return x * 3; }); + + auto it = gen.begin(); + CHECK(*it == 0); + CHECK(*++it == 3); + CHECK(*++it == 6); + CHECK(*++it == 9); + CHECK(*++it == 12); + CHECK(*++it == 30); + CHECK(*++it == 33); + CHECK(*++it == 36); + CHECK(*++it == 39); + CHECK(*++it == 42); + CHECK(*++it == 60); + CHECK(*++it == 63); + CHECK(*++it == 66); + CHECK(*++it == 69); + CHECK(*++it == 72); + CHECK(++it == gen.end()); +} + +TEST_SUITE_END(); diff --git a/test/scheduling_operator_tests.cpp b/test/scheduling_operator_tests.cpp new file mode 100644 index 0000000..1e3f061 --- /dev/null +++ b/test/scheduling_operator_tests.cpp @@ -0,0 +1,290 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include + +#include "io_service_fixture.hpp" + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("schedule/resume_on"); + +TEST_CASE_FIXTURE(io_service_fixture, "schedule_on task<> function") +{ + auto mainThreadId = std::this_thread::get_id(); + + std::thread::id ioThreadId; + + auto start = [&]() -> cppcoro::task<> + { + ioThreadId = std::this_thread::get_id(); + CHECK(ioThreadId != mainThreadId); + co_return; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + CHECK(std::this_thread::get_id() == mainThreadId); + + co_await schedule_on(io_service(), start()); + + // TODO: Uncomment this check once the implementation of task + // guarantees that the continuation will resume on the same thread + // that the task completed on. Currently it's possible to resume on + // the thread that launched the task if it completes on another thread + // before the current thread could attach the continuation after it + // suspended. See cppcoro issue #79. + // + // The long-term solution here is to use the symmetric-transfer capability + // to avoid the use of atomics and races, but we're still waiting for MSVC to + // implement this (doesn't seem to be implemented as of VS 2017.8 Preview 5) + //CHECK(std::this_thread::get_id() == ioThreadId); + }()); +} + +TEST_CASE_FIXTURE(io_service_fixture, "schedule_on async_generator<> function") +{ + auto mainThreadId = std::this_thread::get_id(); + + std::thread::id ioThreadId; + + auto makeSequence = [&]() -> cppcoro::async_generator + { + ioThreadId = std::this_thread::get_id(); + CHECK(ioThreadId != mainThreadId); + + co_yield 1; + + CHECK(std::this_thread::get_id() == ioThreadId); + + co_yield 2; + + CHECK(std::this_thread::get_id() == ioThreadId); + + co_yield 3; + + CHECK(std::this_thread::get_id() == ioThreadId); + + co_return; + }; + + cppcoro::io_service otherIoService; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + CHECK(std::this_thread::get_id() == mainThreadId); + + auto seq = schedule_on(io_service(), makeSequence()); + + int expected = 1; + for (auto iter = co_await seq.begin(); iter != seq.end(); co_await ++iter) + { + int value = *iter; + CHECK(value == expected++); + + // Transfer exection back to main thread before + // awaiting next item in the loop to chck that + // the generator is resumed on io_service() thread. + co_await otherIoService.schedule(); + } + + otherIoService.stop(); + }(), + [&]() -> cppcoro::task<> + { + otherIoService.process_events(); + co_return; + }())); +} + +TEST_CASE_FIXTURE(io_service_fixture, "resume_on task<> function") +{ + auto mainThreadId = std::this_thread::get_id(); + + auto start = [&]() -> cppcoro::task<> + { + CHECK(std::this_thread::get_id() == mainThreadId); + co_return; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + CHECK(std::this_thread::get_id() == mainThreadId); + + co_await resume_on(io_service(), start()); + + // NOTE: This check could potentially spuriously fail with the current + // implementation of task. See cppcoro issue #79. + CHECK(std::this_thread::get_id() != mainThreadId); + }()); +} + +constexpr bool isMsvc15_4X86Optimised = +#if defined(_MSC_VER) && _MSC_VER == 1911 && defined(_M_IX86) && !defined(_DEBUG) + true; +#else + false; +#endif + +// Disable under MSVC 15.4 X86 Optimised due to presumed compiler bug that causes +// an access violation. Seems to be fixed under MSVC 15.5. +TEST_CASE_FIXTURE(io_service_fixture, "resume_on async_generator<> function" + * doctest::skip{ isMsvc15_4X86Optimised }) +{ + auto mainThreadId = std::this_thread::get_id(); + + std::thread::id ioThreadId; + + auto makeSequence = [&]() -> cppcoro::async_generator + { + co_await io_service().schedule(); + + ioThreadId = std::this_thread::get_id(); + + CHECK(ioThreadId != mainThreadId); + + co_yield 1; + + co_yield 2; + + co_await io_service().schedule(); + + co_yield 3; + + co_await io_service().schedule(); + + co_return; + }; + + cppcoro::io_service otherIoService; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto stopOnExit = cppcoro::on_scope_exit([&] { otherIoService.stop(); }); + + CHECK(std::this_thread::get_id() == mainThreadId); + + auto seq = resume_on(otherIoService, makeSequence()); + + int expected = 1; + for (auto iter = co_await seq.begin(); iter != seq.end(); co_await ++iter) + { + int value = *iter; + // Every time we receive a value it should be on our requested + // scheduler (ie. main thread) + CHECK(std::this_thread::get_id() == mainThreadId); + CHECK(value == expected++); + + // Occasionally transfer execution to a different thread before + // awaiting next element. + if (value == 2) + { + co_await io_service().schedule(); + } + } + + otherIoService.stop(); + }(), + [&]() -> cppcoro::task<> + { + otherIoService.process_events(); + co_return; + }())); +} + +TEST_CASE_FIXTURE(io_service_fixture, "schedule_on task<> pipe syntax") +{ + auto mainThreadId = std::this_thread::get_id(); + + auto makeTask = [&]() -> cppcoro::task + { + CHECK(std::this_thread::get_id() != mainThreadId); + co_return 123; + }; + + auto triple = [&](int x) + { + CHECK(std::this_thread::get_id() != mainThreadId); + return x * 3; + }; + + CHECK(cppcoro::sync_wait(makeTask() | schedule_on(io_service())) == 123); + + // Shouldn't matter where in sequence schedule_on() appears since it applies + // at the start of the pipeline (ie. before first task starts). + CHECK(cppcoro::sync_wait(makeTask() | schedule_on(io_service()) | cppcoro::fmap(triple)) == 369); + CHECK(cppcoro::sync_wait(makeTask() | cppcoro::fmap(triple) | schedule_on(io_service())) == 369); +} + +TEST_CASE_FIXTURE(io_service_fixture, "resume_on task<> pipe syntax") +{ + auto mainThreadId = std::this_thread::get_id(); + + auto makeTask = [&]() -> cppcoro::task + { + CHECK(std::this_thread::get_id() == mainThreadId); + co_return 123; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + cppcoro::task t = makeTask() | cppcoro::resume_on(io_service()); + CHECK(co_await t == 123); + CHECK(std::this_thread::get_id() != mainThreadId); + }()); +} + +TEST_CASE_FIXTURE(io_service_fixture, "resume_on task<> pipe syntax multiple uses") +{ + auto mainThreadId = std::this_thread::get_id(); + + auto makeTask = [&]() -> cppcoro::task + { + CHECK(std::this_thread::get_id() == mainThreadId); + co_return 123; + }; + + auto triple = [&](int x) + { + CHECK(std::this_thread::get_id() != mainThreadId); + return x * 3; + }; + + cppcoro::io_service otherIoService; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto stopOnExit = cppcoro::on_scope_exit([&] { otherIoService.stop(); }); + + CHECK(std::this_thread::get_id() == mainThreadId); + + cppcoro::task t = + makeTask() + | cppcoro::resume_on(io_service()) + | cppcoro::fmap(triple) + | cppcoro::resume_on(otherIoService); + + CHECK(co_await t == 369); + + CHECK(std::this_thread::get_id() == mainThreadId); + }(), + [&]() -> cppcoro::task<> + { + otherIoService.process_events(); + co_return; + }())); +} + +TEST_SUITE_END(); diff --git a/test/sequence_barrier_tests.cpp b/test/sequence_barrier_tests.cpp new file mode 100644 index 0000000..3097d51 --- /dev/null +++ b/test/sequence_barrier_tests.cpp @@ -0,0 +1,213 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "doctest/cppcoro_doctest.h" + +DOCTEST_TEST_SUITE_BEGIN("sequence_barrier"); + +using namespace cppcoro; + +DOCTEST_TEST_CASE("default construction") +{ + sequence_barrier barrier; + CHECK(barrier.last_published() == sequence_traits::initial_sequence); + barrier.publish(3); + CHECK(barrier.last_published() == 3); +} + +DOCTEST_TEST_CASE("constructing with initial sequence number") +{ + sequence_barrier barrier{ 100 }; + CHECK(barrier.last_published() == 100); +} + +DOCTEST_TEST_CASE("wait_until_published single-threaded") +{ + inline_scheduler scheduler; + + sequence_barrier barrier; + bool reachedA = false; + bool reachedB = false; + bool reachedC = false; + bool reachedD = false; + bool reachedE = false; + bool reachedF = false; + sync_wait(when_all( + [&]() -> task<> + { + CHECK(co_await barrier.wait_until_published(0, scheduler) == 0); + reachedA = true; + CHECK(co_await barrier.wait_until_published(1, scheduler) == 1); + reachedB = true; + CHECK(co_await barrier.wait_until_published(3, scheduler) == 3); + reachedC = true; + CHECK(co_await barrier.wait_until_published(4, scheduler) == 10); + reachedD = true; + co_await barrier.wait_until_published(5, scheduler); + reachedE = true; + co_await barrier.wait_until_published(10, scheduler); + reachedF = true; + }(), + [&]() -> task<> + { + CHECK(!reachedA); + barrier.publish(0); + CHECK(reachedA); + CHECK(!reachedB); + barrier.publish(1); + CHECK(reachedB); + CHECK(!reachedC); + barrier.publish(2); + CHECK(!reachedC); + barrier.publish(3); + CHECK(reachedC); + CHECK(!reachedD); + barrier.publish(10); + CHECK(reachedD); + CHECK(reachedE); + CHECK(reachedF); + co_return; + }())); + CHECK(reachedF); +} + +DOCTEST_TEST_CASE("wait_until_published multiple awaiters") +{ + inline_scheduler scheduler; + + sequence_barrier barrier; + bool reachedA = false; + bool reachedB = false; + bool reachedC = false; + bool reachedD = false; + bool reachedE = false; + sync_wait(when_all( + [&]() -> task<> + { + CHECK(co_await barrier.wait_until_published(0, scheduler) == 0); + reachedA = true; + CHECK(co_await barrier.wait_until_published(1, scheduler) == 1); + reachedB = true; + CHECK(co_await barrier.wait_until_published(3, scheduler) == 3); + reachedC = true; + }(), + [&]() -> task<> + { + CHECK(co_await barrier.wait_until_published(0, scheduler) == 0); + reachedD = true; + CHECK(co_await barrier.wait_until_published(3, scheduler) == 3); + reachedE = true; + }(), + [&]() -> task<> + { + CHECK(!reachedA); + CHECK(!reachedD); + barrier.publish(0); + CHECK(reachedA); + CHECK(reachedD); + CHECK(!reachedB); + CHECK(!reachedE); + barrier.publish(1); + CHECK(reachedB); + CHECK(!reachedC); + CHECK(!reachedE); + barrier.publish(2); + CHECK(!reachedC); + CHECK(!reachedE); + barrier.publish(3); + CHECK(reachedC); + CHECK(reachedE); + co_return; + }())); + CHECK(reachedC); + CHECK(reachedE); +} + +DOCTEST_TEST_CASE("multi-threaded usage single consumer") +{ + static_thread_pool tp{ 2 }; + + sequence_barrier writeBarrier; + sequence_barrier readBarrier; + + constexpr std::size_t iterationCount = 1'000'000; + + constexpr std::size_t bufferSize = 256; + std::uint64_t buffer[bufferSize]; + + auto[result, dummy] = sync_wait(when_all( + [&]() -> task + { + // Consumer + std::uint64_t sum = 0; + + bool reachedEnd = false; + std::size_t nextToRead = 0; + do + { + std::size_t available = co_await writeBarrier.wait_until_published(nextToRead, tp); + do + { + sum += buffer[nextToRead % bufferSize]; + } while (nextToRead++ != available); + + // Zero value is sentinel that indicates the end of the stream. + reachedEnd = buffer[available % bufferSize] == 0; + + // Notify that we've finished processing up to 'available'. + readBarrier.publish(available); + } while (!reachedEnd); + + co_return sum; + }(), + [&]() -> task<> + { + // Producer + std::size_t available = readBarrier.last_published() + bufferSize; + for (std::size_t nextToWrite = 0; nextToWrite <= iterationCount; ++nextToWrite) + { + if (sequence_traits::precedes(available, nextToWrite)) + { + available = co_await readBarrier.wait_until_published(nextToWrite - bufferSize, tp) + bufferSize; + } + + if (nextToWrite == iterationCount) + { + // Write sentinel (zero) as last element. + buffer[nextToWrite % bufferSize] = 0; + } + else + { + // Write value + buffer[nextToWrite % bufferSize] = nextToWrite + 1; + } + + // Notify consumer that we've published a new value. + writeBarrier.publish(nextToWrite); + } + }())); + + // Suppress unused variable warning. + (void)dummy; + + constexpr std::uint64_t expectedResult = + std::uint64_t(iterationCount) * std::uint64_t(iterationCount + 1) / 2; + + CHECK(result == expectedResult); +} + +DOCTEST_TEST_SUITE_END(); diff --git a/test/shared_task_tests.cpp b/test/shared_task_tests.cpp new file mode 100644 index 0000000..d76cbaa --- /dev/null +++ b/test/shared_task_tests.cpp @@ -0,0 +1,248 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include + +#include "counted.hpp" + +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("shared_task"); + +TEST_CASE("awaiting default-constructed task throws broken_promise") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + CHECK_THROWS_AS(co_await cppcoro::shared_task<>{}, const cppcoro::broken_promise&); + }()); +} + +TEST_CASE("coroutine doesn't start executing until awaited") +{ + bool startedExecuting = false; + auto f = [&]() -> cppcoro::shared_task<> + { + startedExecuting = true; + co_return; + }; + + auto t = f(); + + CHECK(!t.is_ready()); + CHECK(!startedExecuting); + + cppcoro::sync_wait([](cppcoro::shared_task<> t) -> cppcoro::task<> + { + co_await t; + }(t)); + + CHECK(t.is_ready()); + CHECK(startedExecuting); +} + +TEST_CASE("result is destroyed when last reference is destroyed") +{ + counted::reset_counts(); + + { + auto t = []() -> cppcoro::shared_task + { + co_return counted{}; + }(); + + CHECK(counted::active_count() == 0); + + cppcoro::sync_wait(t); + + CHECK(counted::active_count() == 1); + } + + CHECK(counted::active_count() == 0); +} + +TEST_CASE("multiple awaiters") +{ + cppcoro::single_consumer_event event; + bool startedExecution = false; + auto produce = [&]() -> cppcoro::shared_task + { + startedExecution = true; + co_await event; + co_return 1; + }; + + auto consume = [](cppcoro::shared_task t) -> cppcoro::task<> + { + int result = co_await t; + CHECK(result == 1); + }; + + auto sharedTask = produce(); + + cppcoro::sync_wait(cppcoro::when_all_ready( + consume(sharedTask), + consume(sharedTask), + consume(sharedTask), + [&]() -> cppcoro::task<> + { + event.set(); + CHECK(sharedTask.is_ready()); + co_return; + }())); + + CHECK(sharedTask.is_ready()); +} + +TEST_CASE("waiting on shared_task in loop doesn't cause stack-overflow") +{ + // This test checks that awaiting a shared_task that completes + // synchronously doesn't recursively resume the awaiter inside the + // call to start executing the task. If it were to do this then we'd + // expect that this test would result in failure due to stack-overflow. + + auto completesSynchronously = []() -> cppcoro::shared_task + { + co_return 1; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + int result = 0; + for (int i = 0; i < 1'000'000; ++i) + { + result += co_await completesSynchronously(); + } + CHECK(result == 1'000'000); + }()); +} + +TEST_CASE("make_shared_task") +{ + bool startedExecution = false; + + auto f = [&]() -> cppcoro::task + { + startedExecution = false; + co_return "test"; + }; + + auto t = f(); + + cppcoro::shared_task sharedT = + cppcoro::make_shared_task(std::move(t)); + + CHECK(!sharedT.is_ready()); + CHECK(!startedExecution); + + auto consume = [](cppcoro::shared_task t) -> cppcoro::task<> + { + auto x = co_await std::move(t); + CHECK(x == "test"); + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + consume(sharedT), + consume(sharedT))); +} + +TEST_CASE("make_shared_task of void" + * doctest::description{ "Tests that workaround for 'co_return ' bug is operational if required" }) +{ + bool startedExecution = false; + + auto f = [&]() -> cppcoro::task<> + { + startedExecution = true; + co_return; + }; + + auto t = f(); + + cppcoro::shared_task<> sharedT = cppcoro::make_shared_task(std::move(t)); + + CHECK(!sharedT.is_ready()); + CHECK(!startedExecution); + + auto consume = [](cppcoro::shared_task<> t) -> cppcoro::task<> + { + co_await t; + }; + + auto c1 = consume(sharedT); + cppcoro::sync_wait(c1); + + CHECK(startedExecution); + + auto c2 = consume(sharedT); + cppcoro::sync_wait(c2); + + CHECK(c1.is_ready()); + CHECK(c2.is_ready()); +} + +TEST_CASE("shared_task fmap operator") +{ + cppcoro::single_consumer_event event; + int value = 0; + + auto setNumber = [&]() -> cppcoro::shared_task<> + { + co_await event; + value = 123; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto numericStringTask = + setNumber() + | cppcoro::fmap([&]() { return std::to_string(value); }); + + CHECK(co_await numericStringTask == "123"); + }(), + [&]() -> cppcoro::task<> + { + CHECK(value == 0); + event.set(); + CHECK(value == 123); + co_return; + }())); +} + +TEST_CASE("shared_task fmap operator") +{ + cppcoro::single_consumer_event event; + + auto getNumber = [&]() -> cppcoro::shared_task + { + co_await event; + co_return 123; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto numericStringTask = + getNumber() + | cppcoro::fmap([](int x) { return std::to_string(x); }); + + CHECK(co_await numericStringTask == "123"); + }(), + [&]() -> cppcoro::task<> + { + event.set(); + co_return; + }())); +} + +TEST_SUITE_END(); diff --git a/test/single_consumer_async_auto_reset_event_tests.cpp b/test/single_consumer_async_auto_reset_event_tests.cpp new file mode 100644 index 0000000..1ff9eea --- /dev/null +++ b/test/single_consumer_async_auto_reset_event_tests.cpp @@ -0,0 +1,93 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("single_consumer_async_auto_reset_event"); + +TEST_CASE("single waiter") +{ + cppcoro::single_consumer_async_auto_reset_event event; + + bool started = false; + bool finished = false; + auto run = [&]() -> cppcoro::task<> + { + started = true; + co_await event; + finished = true; + }; + + auto check = [&]() -> cppcoro::task<> + { + CHECK(started); + CHECK(!finished); + + event.set(); + + CHECK(finished); + + co_return; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready(run(), check())); +} + +TEST_CASE("multi-threaded") +{ + cppcoro::static_thread_pool tp; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + cppcoro::single_consumer_async_auto_reset_event valueChangedEvent; + + std::atomic value; + + auto consumer = [&]() -> cppcoro::task + { + while (value.load(std::memory_order_relaxed) < 10'000) + { + co_await valueChangedEvent; + } + + co_return 0; + }; + + auto modifier = [&](int count) -> cppcoro::task + { + co_await tp.schedule(); + for (int i = 0; i < count; ++i) + { + value.fetch_add(1, std::memory_order_relaxed); + valueChangedEvent.set(); + } + co_return 0; + }; + + for (int i = 0; i < 1000; ++i) + { + value.store(0, std::memory_order_relaxed); + + // Really just checking that we don't deadlock here due to a missed wake-up. + (void)co_await cppcoro::when_all(consumer(), modifier(5'000), modifier(5'000)); + } + }()); +} + +TEST_SUITE_END(); diff --git a/test/single_producer_sequencer_tests.cpp b/test/single_producer_sequencer_tests.cpp new file mode 100644 index 0000000..b401147 --- /dev/null +++ b/test/single_producer_sequencer_tests.cpp @@ -0,0 +1,95 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include "doctest/cppcoro_doctest.h" + +DOCTEST_TEST_SUITE_BEGIN("single_producer_sequencer"); + +using namespace cppcoro; + +DOCTEST_TEST_CASE("multi-threaded usage single consumer") +{ + static_thread_pool tp{ 2 }; + + constexpr std::size_t bufferSize = 256; + + sequence_barrier readBarrier; + single_producer_sequencer sequencer(readBarrier, bufferSize); + + constexpr std::size_t iterationCount = 1'000'000; + + std::uint64_t buffer[bufferSize]; + + auto[result, dummy] = sync_wait(when_all( + [&]() -> task + { + // Consumer + std::uint64_t sum = 0; + + bool reachedEnd = false; + std::size_t nextToRead = 0; + do + { + const std::size_t available = co_await sequencer.wait_until_published(nextToRead, tp); + do + { + sum += buffer[nextToRead % bufferSize]; + } while (nextToRead++ != available); + + // Zero value is sentinel that indicates the end of the stream. + reachedEnd = buffer[available % bufferSize] == 0; + + // Notify that we've finished processing up to 'available'. + readBarrier.publish(available); + } while (!reachedEnd); + + co_return sum; + }(), + [&]() -> task<> + { + // Producer + constexpr std::size_t maxBatchSize = 10; + + std::size_t i = 0; + while (i < iterationCount) + { + const std::size_t batchSize = std::min(maxBatchSize, iterationCount - i); + auto sequences = co_await sequencer.claim_up_to(batchSize, tp); + for (auto seq : sequences) + { + buffer[seq % bufferSize] = ++i; + } + sequencer.publish(sequences.back()); + } + + auto finalSeq = co_await sequencer.claim_one(tp); + buffer[finalSeq % bufferSize] = 0; + sequencer.publish(finalSeq); + }())); + + // Suppress unused variable warning. + (void)dummy; + + constexpr std::uint64_t expectedResult = + std::uint64_t(iterationCount) * std::uint64_t(iterationCount + 1) / 2; + + CHECK(result == expectedResult); +} + +DOCTEST_TEST_SUITE_END(); diff --git a/test/socket_tests.cpp b/test/socket_tests.cpp new file mode 100644 index 0000000..9d17ac2 --- /dev/null +++ b/test/socket_tests.cpp @@ -0,0 +1,474 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +using namespace cppcoro; +using namespace cppcoro::net; + +TEST_SUITE_BEGIN("socket"); + +TEST_CASE("create TCP/IPv4") +{ + io_service ioSvc; + auto socket = socket::create_tcpv4(ioSvc); +} + +TEST_CASE("create TCP/IPv6") +{ + io_service ioSvc; + auto socket = socket::create_tcpv6(ioSvc); +} + +TEST_CASE("create UDP/IPv4") +{ + io_service ioSvc; + auto socket = socket::create_udpv4(ioSvc); +} + +TEST_CASE("create UDP/IPv6") +{ + io_service ioSvc; + auto socket = socket::create_udpv6(ioSvc); +} + +TEST_CASE("TCP/IPv4 connect/disconnect") +{ + io_service ioSvc; + + ip_endpoint serverAddress; + + task serverTask; + + auto server = [&](socket listeningSocket) -> task + { + auto s = socket::create_tcpv4(ioSvc); + co_await listeningSocket.accept(s); + co_await s.disconnect(); + co_return 0; + }; + + { + auto serverSocket = socket::create_tcpv4(ioSvc); + serverSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); + serverSocket.listen(3); + serverAddress = serverSocket.local_endpoint(); + serverTask = server(std::move(serverSocket)); + } + + auto client = [&]() -> task + { + auto s = socket::create_tcpv4(ioSvc); + s.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); + co_await s.connect(serverAddress); + co_await s.disconnect(); + co_return 0; + }; + + task clientTask = client(); + + (void)sync_wait(when_all( + [&]() -> task + { + auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); + (void)co_await when_all(std::move(serverTask), std::move(clientTask)); + co_return 0; + }(), + [&]() -> task + { + ioSvc.process_events(); + co_return 0; + }())); +} + +TEST_CASE("send/recv TCP/IPv4") +{ + io_service ioSvc; + + auto listeningSocket = socket::create_tcpv4(ioSvc); + + listeningSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); + listeningSocket.listen(3); + + auto echoServer = [&]() -> task + { + auto acceptingSocket = socket::create_tcpv4(ioSvc); + + co_await listeningSocket.accept(acceptingSocket); + + std::uint8_t buffer[64]; + std::size_t bytesReceived; + do + { + bytesReceived = co_await acceptingSocket.recv(buffer, sizeof(buffer)); + if (bytesReceived > 0) + { + std::size_t bytesSent = 0; + do + { + bytesSent += co_await acceptingSocket.send( + buffer + bytesSent, + bytesReceived - bytesSent); + } while (bytesSent < bytesReceived); + } + } while (bytesReceived > 0); + + acceptingSocket.close_send(); + + co_await acceptingSocket.disconnect(); + + co_return 0; + }; + + auto echoClient = [&]() -> task + { + auto connectingSocket = socket::create_tcpv4(ioSvc); + + connectingSocket.bind(ipv4_endpoint{}); + + co_await connectingSocket.connect(listeningSocket.local_endpoint()); + + auto receive = [&]() -> task + { + std::uint8_t buffer[100]; + std::uint64_t totalBytesReceived = 0; + std::size_t bytesReceived; + do + { + bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); + for (std::size_t i = 0; i < bytesReceived; ++i) + { + std::uint64_t byteIndex = totalBytesReceived + i; + std::uint8_t expectedByte = 'a' + (byteIndex % 26); + CHECK(buffer[i] == expectedByte); + } + + totalBytesReceived += bytesReceived; + } while (bytesReceived > 0); + + CHECK(totalBytesReceived == 1000); + + co_return 0; + }; + + auto send = [&]() -> task + { + std::uint8_t buffer[100]; + for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) + { + for (std::size_t j = 0; j < sizeof(buffer); ++j) + { + buffer[j] = 'a' + ((i + j) % 26); + } + + std::size_t bytesSent = 0; + do + { + bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); + } while (bytesSent < sizeof(buffer)); + } + + connectingSocket.close_send(); + + co_return 0; + }; + + co_await when_all(send(), receive()); + + co_await connectingSocket.disconnect(); + + co_return 0; + }; + + (void)sync_wait(when_all( + [&]() -> task + { + auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); + (void)co_await when_all(echoClient(), echoServer()); + co_return 0; + }(), + [&]() -> task + { + ioSvc.process_events(); + co_return 0; + }())); +} + +#if !CPPCORO_COMPILER_MSVC || CPPCORO_COMPILER_MSVC >= 192000000 || !CPPCORO_CPU_X86 +// HACK: Don't compile this function under MSVC x86. +// It results in an ICE under VS 2017.15 and earlier. + +TEST_CASE("send/recv TCP/IPv4 many connections") +{ + io_service ioSvc; + + auto listeningSocket = socket::create_tcpv4(ioSvc); + + listeningSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); + listeningSocket.listen(20); + + cancellation_source canceller; + + auto handleConnection = [](socket s) -> task + { + std::uint8_t buffer[64]; + std::size_t bytesReceived; + do + { + bytesReceived = co_await s.recv(buffer, sizeof(buffer)); + if (bytesReceived > 0) + { + std::size_t bytesSent = 0; + do + { + bytesSent += co_await s.send( + buffer + bytesSent, + bytesReceived - bytesSent); + } while (bytesSent < bytesReceived); + } + } while (bytesReceived > 0); + + s.close_send(); + + co_await s.disconnect(); + }; + + auto echoServer = [&](cancellation_token ct) -> task<> + { + async_scope connectionScope; + + std::exception_ptr ex; + try + { + while (true) { + auto acceptingSocket = socket::create_tcpv4(ioSvc); + co_await listeningSocket.accept(acceptingSocket, ct); + connectionScope.spawn( + handleConnection(std::move(acceptingSocket))); + } + } + catch (const cppcoro::operation_cancelled&) + { + } + catch (...) + { + ex = std::current_exception(); + } + + co_await connectionScope.join(); + + if (ex) + { + std::rethrow_exception(ex); + } + }; + + auto echoClient = [&]() -> task<> + { + auto connectingSocket = socket::create_tcpv4(ioSvc); + + connectingSocket.bind(ipv4_endpoint{}); + + co_await connectingSocket.connect(listeningSocket.local_endpoint()); + + auto receive = [&]() -> task<> + { + std::uint8_t buffer[100]; + std::uint64_t totalBytesReceived = 0; + std::size_t bytesReceived; + do + { + bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); + for (std::size_t i = 0; i < bytesReceived; ++i) + { + std::uint64_t byteIndex = totalBytesReceived + i; + std::uint8_t expectedByte = 'a' + (byteIndex % 26); + CHECK(buffer[i] == expectedByte); + } + + totalBytesReceived += bytesReceived; + } while (bytesReceived > 0); + + CHECK(totalBytesReceived == 1000); + }; + + auto send = [&]() -> task<> + { + std::uint8_t buffer[100]; + for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) + { + for (std::size_t j = 0; j < sizeof(buffer); ++j) + { + buffer[j] = 'a' + ((i + j) % 26); + } + + std::size_t bytesSent = 0; + do + { + bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); + } while (bytesSent < sizeof(buffer)); + } + + connectingSocket.close_send(); + }; + + co_await when_all(send(), receive()); + + co_await connectingSocket.disconnect(); + }; + + auto manyEchoClients = [&](int count) -> task + { + auto shutdownServerOnExit = on_scope_exit([&] + { + canceller.request_cancellation(); + }); + + std::vector> clientTasks; + clientTasks.reserve(count); + + for (int i = 0; i < count; ++i) + { + clientTasks.emplace_back(echoClient()); + } + + co_await when_all(std::move(clientTasks)); + }; + + (void)sync_wait(when_all( + [&]() -> task<> + { + auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); + (void)co_await when_all( + manyEchoClients(20), + echoServer(canceller.token())); + }(), + [&]() -> task<> + { + ioSvc.process_events(); + co_return; + }())); +} + +#endif + +TEST_CASE("udp send_to/recv_from") +{ + io_service ioSvc; + + auto server = [&](socket serverSocket) -> task + { + std::uint8_t buffer[100]; + + auto[bytesReceived, remoteEndPoint] = co_await serverSocket.recv_from(buffer, 100); + CHECK(bytesReceived == 50); + + // Send an ACK response. + { + const std::uint8_t response[1] = { 0 }; + co_await serverSocket.send_to(remoteEndPoint, &response, 1); + } + + // Second message received won't fit within buffer. + try + { + std::tie(bytesReceived, remoteEndPoint) = co_await serverSocket.recv_from(buffer, 100); + FAIL("Should have thrown"); + } + catch (const std::system_error&) + { + // TODO: Map this situation to some kind of error_condition value. + // The win32 ERROR_MORE_DATA error code doesn't seem to map to any of the standard std::errc values. + // + // CHECK(ex.code() == ???); + // + // Possibly also need to switch to returning a std::error_code directly rather than + // throwing a std::system_error for this case. + } + + // Send an NACK response. + { + const std::uint8_t response[1] = { 1 }; + co_await serverSocket.send_to(remoteEndPoint, response, 1); + } + + co_return 0; + }; + + ip_endpoint serverAddress; + + task serverTask; + + { + auto serverSocket = socket::create_udpv4(ioSvc); + serverSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); + serverAddress = serverSocket.local_endpoint(); + serverTask = server(std::move(serverSocket)); + } + + auto client = [&]() -> task + { + auto socket = socket::create_udpv4(ioSvc); + + // don't need to bind(), should be implicitly bound on first send_to(). + + // Send first message of 50 bytes + { + std::uint8_t buffer[50] = { 0 }; + co_await socket.send_to(serverAddress, buffer, 50); + } + + // Receive ACK message + { + std::uint8_t buffer[1]; + auto[bytesReceived, ackAddress] = co_await socket.recv_from(buffer, 1); + CHECK(bytesReceived == 1); + CHECK(buffer[0] == 0); + CHECK(ackAddress == serverAddress); + } + + // Send second message of 128 bytes + { + std::uint8_t buffer[128] = { 0 }; + co_await socket.send_to(serverAddress, buffer, 128); + } + + // Receive NACK message + { + std::uint8_t buffer[1]; + auto[bytesReceived, ackAddress] = co_await socket.recv_from(buffer, 1); + CHECK(bytesReceived == 1); + CHECK(buffer[0] == 1); + CHECK(ackAddress == serverAddress); + } + + co_return 0; + }; + + (void)sync_wait(when_all( + [&]() -> task + { + auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); + (void)co_await when_all(std::move(serverTask), client()); + co_return 0; + }(), + [&]() -> task + { + ioSvc.process_events(); + co_return 0; + }())); +} + +TEST_SUITE_END(); diff --git a/test/static_thread_pool_tests.cpp b/test/static_thread_pool_tests.cpp new file mode 100644 index 0000000..fc99ca4 --- /dev/null +++ b/test/static_thread_pool_tests.cpp @@ -0,0 +1,290 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("static_thread_pool"); + +TEST_CASE("construct/destruct") +{ + cppcoro::static_thread_pool threadPool; + CHECK(threadPool.thread_count() == std::thread::hardware_concurrency()); +} + +TEST_CASE("construct/destruct to specific thread count") +{ + cppcoro::static_thread_pool threadPool{ 5 }; + CHECK(threadPool.thread_count() == 5); +} + +TEST_CASE("run one task") +{ + cppcoro::static_thread_pool threadPool{ 2 }; + + auto initiatingThreadId = std::this_thread::get_id(); + + cppcoro::sync_wait([&]() -> cppcoro::task + { + co_await threadPool.schedule(); + if (std::this_thread::get_id() == initiatingThreadId) + { + FAIL("schedule() did not switch threads"); + } + }()); +} + +TEST_CASE("launch many tasks remotely") +{ + cppcoro::static_thread_pool threadPool; + + auto makeTask = [&]() -> cppcoro::task<> + { + co_await threadPool.schedule(); + }; + + std::vector> tasks; + for (std::uint32_t i = 0; i < 100; ++i) + { + tasks.push_back(makeTask()); + } + + cppcoro::sync_wait(cppcoro::when_all(std::move(tasks))); +} + +cppcoro::task sum_of_squares( + std::uint32_t start, + std::uint32_t end, + cppcoro::static_thread_pool& tp) +{ + co_await tp.schedule(); + + auto count = end - start; + if (count > 1000) + { + auto half = start + count / 2; + auto[a, b] = co_await cppcoro::when_all( + sum_of_squares(start, half, tp), + sum_of_squares(half, end, tp)); + co_return a + b; + } + else + { + std::uint64_t sum = 0; + for (std::uint64_t x = start; x < end; ++x) + { + sum += x * x; + } + co_return sum; + } +} + +TEST_CASE("launch sub-task with many sub-tasks") +{ + using namespace std::chrono_literals; + + constexpr std::uint64_t limit = 1'000'000'000; + + cppcoro::static_thread_pool tp; + + // Wait for the thread-pool thread to start up. + std::this_thread::sleep_for(1ms); + + auto start = std::chrono::high_resolution_clock::now(); + + auto result = cppcoro::sync_wait(sum_of_squares(0, limit , tp)); + + auto end = std::chrono::high_resolution_clock::now(); + + std::uint64_t sum = 0; + for (std::uint64_t i = 0; i < limit; ++i) + { + sum += i * i; + } + + auto end2 = std::chrono::high_resolution_clock::now(); + + auto toNs = [](auto time) + { + return std::chrono::duration_cast(time).count(); + }; + + std::cout + << "multi-threaded version took " << toNs(end - start) << "ns\n" + << "single-threaded version took " << toNs(end2 - end) << "ns" << std::endl; + + CHECK(result == sum); +} + +struct fork_join_operation +{ + std::atomic m_count; + cppcoro::coroutine_handle<> m_coro; + + fork_join_operation() : m_count(1) {} + + void begin_work() noexcept + { + m_count.fetch_add(1, std::memory_order_relaxed); + } + + void end_work() noexcept + { + if (m_count.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + m_coro.resume(); + } + } + + bool await_ready() noexcept { return m_count.load(std::memory_order_acquire) == 1; } + + bool await_suspend(cppcoro::coroutine_handle<> coro) noexcept + { + m_coro = coro; + return m_count.fetch_sub(1, std::memory_order_acq_rel) != 1; + } + + void await_resume() noexcept {}; +}; + +template +cppcoro::task for_each_async(SCHEDULER& scheduler, RANGE& range, FUNC func) +{ + using reference_type = decltype(*range.begin()); + + // TODO: Use awaiter_t here instead. This currently assumes that + // result of scheduler.schedule() doesn't have an operator co_await(). + using schedule_operation = decltype(scheduler.schedule()); + + struct work_operation + { + fork_join_operation& m_forkJoin; + FUNC& m_func; + reference_type m_value; + schedule_operation m_scheduleOp; + + work_operation(fork_join_operation& forkJoin, SCHEDULER& scheduler, FUNC& func, reference_type&& value) + : m_forkJoin(forkJoin) + , m_func(func) + , m_value(static_cast(value)) + , m_scheduleOp(scheduler.schedule()) + { + } + + bool await_ready() noexcept { return false; } + + CPPCORO_NOINLINE + void await_suspend(cppcoro::coroutine_handle<> coro) noexcept + { + fork_join_operation& forkJoin = m_forkJoin; + FUNC& func = m_func; + reference_type value = static_cast(m_value); + + static_assert(std::is_same_v); + + forkJoin.begin_work(); + + // Schedule the next iteration of the loop to run + m_scheduleOp.await_suspend(coro); + + func(static_cast(value)); + + forkJoin.end_work(); + } + + void await_resume() noexcept {} + }; + + co_await scheduler.schedule(); + + fork_join_operation forkJoin; + + for (auto&& x : range) + { + co_await work_operation{ + forkJoin, + scheduler, + func, + static_cast(x) + }; + } + + co_await forkJoin; +} + +std::uint64_t collatz_distance(std::uint64_t number) +{ + std::uint64_t count = 0; + while (number > 1) + { + if (number % 2 == 0) number /= 2; + else number = number * 3 + 1; + ++count; + } + return count; +} + +TEST_CASE("for_each_async") +{ + cppcoro::static_thread_pool tp; + + { + std::vector values(1'000'000); + std::iota(values.begin(), values.end(), 1); + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + auto start = std::chrono::high_resolution_clock::now(); + + co_await for_each_async(tp, values, [](std::uint64_t& value) + { + value = collatz_distance(value); + }); + + auto end = std::chrono::high_resolution_clock::now(); + + std::cout << "for_each_async of " << values.size() + << " took " << std::chrono::duration_cast(end - start).count() + << "us" << std::endl; + + for (std::size_t i = 0; i < 1'000'000; ++i) + { + CHECK(values[i] == collatz_distance(i + 1)); + } + }()); + } + + { + std::vector values(1'000'000); + std::iota(values.begin(), values.end(), 1); + + auto start = std::chrono::high_resolution_clock::now(); + + for (auto&& x : values) + { + x = collatz_distance(x); + } + + auto end = std::chrono::high_resolution_clock::now(); + + std::cout << "single-threaded for loop of " << values.size() + << " took " << std::chrono::duration_cast(end - start).count() + << "us" << std::endl; + } + +} + +TEST_SUITE_END(); diff --git a/test/sync_wait_tests.cpp b/test/sync_wait_tests.cpp new file mode 100644 index 0000000..37cc193 --- /dev/null +++ b/test/sync_wait_tests.cpp @@ -0,0 +1,76 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("sync_wait"); + +static_assert(std::is_same< + decltype(cppcoro::sync_wait(std::declval>())), + std::string&&>::value); +static_assert(std::is_same< + decltype(cppcoro::sync_wait(std::declval&>())), + std::string&>::value); + +TEST_CASE("sync_wait(task)") +{ + auto makeTask = []() -> cppcoro::task + { + co_return "foo"; + }; + + auto task = makeTask(); + CHECK(cppcoro::sync_wait(task) == "foo"); + + CHECK(cppcoro::sync_wait(makeTask()) == "foo"); +} + +TEST_CASE("sync_wait(shared_task)") +{ + auto makeTask = []() -> cppcoro::shared_task + { + co_return "foo"; + }; + + auto task = makeTask(); + + CHECK(cppcoro::sync_wait(task) == "foo"); + CHECK(cppcoro::sync_wait(makeTask()) == "foo"); +} + +TEST_CASE("multiple threads") +{ + // We are creating a new task and starting it inside the sync_wait(). + // The task will reschedule itself for resumption on a thread-pool thread + // which will sometimes complete before this thread calls event.wait() + // inside sync_wait(). Thus we're roughly testing the thread-safety of + // sync_wait(). + cppcoro::static_thread_pool tp{ 1 }; + + int value = 0; + auto createLazyTask = [&]() -> cppcoro::task + { + co_await tp.schedule(); + co_return value++; + }; + + for (int i = 0; i < 10'000; ++i) + { + CHECK(cppcoro::sync_wait(createLazyTask()) == i); + } +} + +TEST_SUITE_END(); diff --git a/test/task_tests.cpp b/test/task_tests.cpp new file mode 100644 index 0000000..68a0c0c --- /dev/null +++ b/test/task_tests.cpp @@ -0,0 +1,349 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "counted.hpp" + +#include +#include +#include + +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("task"); + +TEST_CASE("task doesn't start until awaited") +{ + bool started = false; + auto func = [&]() -> cppcoro::task<> + { + started = true; + co_return; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + auto t = func(); + CHECK(!started); + + co_await t; + + CHECK(started); + }()); +} + +TEST_CASE("awaiting default-constructed task throws broken_promise") +{ + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + cppcoro::task<> t; + CHECK_THROWS_AS(co_await t, const cppcoro::broken_promise&); + }()); +} + +TEST_CASE("awaiting task that completes asynchronously") +{ + bool reachedBeforeEvent = false; + bool reachedAfterEvent = false; + cppcoro::single_consumer_event event; + auto f = [&]() -> cppcoro::task<> + { + reachedBeforeEvent = true; + co_await event; + reachedAfterEvent = true; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + auto t = f(); + + CHECK(!reachedBeforeEvent); + + (void)co_await cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + co_await t; + CHECK(reachedBeforeEvent); + CHECK(reachedAfterEvent); + }(), + [&]() -> cppcoro::task<> + { + CHECK(reachedBeforeEvent); + CHECK(!reachedAfterEvent); + event.set(); + CHECK(reachedAfterEvent); + co_return; + }()); + }()); +} + +TEST_CASE("destroying task that was never awaited destroys captured args") +{ + counted::reset_counts(); + + auto f = [](counted c) -> cppcoro::task + { + co_return c; + }; + + CHECK(counted::active_count() == 0); + + { + auto t = f(counted{}); + CHECK(counted::active_count() == 1); + } + + CHECK(counted::active_count() == 0); +} + +TEST_CASE("task destructor destroys result") +{ + counted::reset_counts(); + + auto f = []() -> cppcoro::task + { + co_return counted{}; + }; + + { + auto t = f(); + CHECK(counted::active_count() == 0); + + auto& result = cppcoro::sync_wait(t); + + CHECK(counted::active_count() == 1); + CHECK(result.id == 0); + } + + CHECK(counted::active_count() == 0); +} + +TEST_CASE("task of reference type") +{ + int value = 3; + auto f = [&]() -> cppcoro::task + { + co_return value; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + SUBCASE("awaiting rvalue task") + { + decltype(auto) result = co_await f(); + static_assert( + std::is_same::value, + "co_await r-value reference of task should result in an int&"); + CHECK(&result == &value); + } + + SUBCASE("awaiting lvalue task") + { + auto t = f(); + decltype(auto) result = co_await t; + static_assert( + std::is_same::value, + "co_await l-value reference of task should result in an int&"); + CHECK(&result == &value); + } + }()); +} + +TEST_CASE("passing parameter by value to task coroutine calls move-constructor exactly once") +{ + counted::reset_counts(); + + auto f = [](counted arg) -> cppcoro::task<> + { + co_return; + }; + + counted c; + + CHECK(counted::active_count() == 1); + CHECK(counted::default_construction_count == 1); + CHECK(counted::copy_construction_count == 0); + CHECK(counted::move_construction_count == 0); + CHECK(counted::destruction_count == 0); + + { + auto t = f(c); + + // Should have called copy-constructor to pass a copy of 'c' into f by value. + CHECK(counted::copy_construction_count == 1); + + // Inside f it should have move-constructed parameter into coroutine frame variable + //WARN_MESSAGE(counted::move_construction_count == 1, + // "Known bug in MSVC 2017.1, not critical if it performs multiple moves"); + + // Active counts should be the instance 'c' and the instance captured in coroutine frame of 't'. + CHECK(counted::active_count() == 2); + } + + CHECK(counted::active_count() == 1); +} + +TEST_CASE("task fmap pipe operator") +{ + using cppcoro::fmap; + + cppcoro::single_consumer_event event; + + auto f = [&]() -> cppcoro::task<> + { + co_await event; + co_return; + }; + + auto t = f() | fmap([] { return 123; }); + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + CHECK(co_await t == 123); + }(), + [&]() -> cppcoro::task<> + { + event.set(); + co_return; + }())); +} + +TEST_CASE("task fmap pipe operator") +{ + using cppcoro::task; + using cppcoro::fmap; + using cppcoro::sync_wait; + using cppcoro::make_task; + + auto one = [&]() -> task + { + co_return 1; + }; + + SUBCASE("r-value fmap / r-value lambda") + { + auto t = one() + | fmap([delta = 1](auto i) { return i + delta; }); + CHECK(sync_wait(t) == 2); + } + + SUBCASE("r-value fmap / l-value lambda") + { + using namespace std::string_literals; + + auto t = [&] + { + auto f = [prefix = "pfx"s](int x) + { + return prefix + std::to_string(x); + }; + + // Want to make sure that the resulting awaitable has taken + // a copy of the lambda passed to fmap(). + return one() | fmap(f); + }(); + + CHECK(sync_wait(t) == "pfx1"); + } + + SUBCASE("l-value fmap / r-value lambda") + { + using namespace std::string_literals; + + auto t = [&] + { + auto addprefix = fmap([prefix = "a really really long prefix that prevents small string optimisation"s](int x) + { + return prefix + std::to_string(x); + }); + + // Want to make sure that the resulting awaitable has taken + // a copy of the lambda passed to fmap(). + return one() | addprefix; + }(); + + CHECK(sync_wait(t) == "a really really long prefix that prevents small string optimisation1"); + } + + SUBCASE("l-value fmap / l-value lambda") + { + using namespace std::string_literals; + + task t; + + { + auto lambda = [prefix = "a really really long prefix that prevents small string optimisation"s](int x) + { + return prefix + std::to_string(x); + }; + + auto addprefix = fmap(lambda); + + // Want to make sure that the resulting task has taken + // a copy of the lambda passed to fmap(). + t = make_task(one() | addprefix); + } + + CHECK(!t.is_ready()); + + CHECK(sync_wait(t) == "a really really long prefix that prevents small string optimisation1"); + } +} + +TEST_CASE("chained fmap pipe operations") +{ + using namespace std::string_literals; + using cppcoro::task; + using cppcoro::sync_wait; + + auto prepend = [](std::string s) + { + using cppcoro::fmap; + return fmap([s = std::move(s)](const std::string& value) { return s + value; }); + }; + + auto append = [](std::string s) + { + using cppcoro::fmap; + return fmap([s = std::move(s)](const std::string& value){ return value + s; }); + }; + + auto asyncString = [](std::string s) -> task + { + co_return std::move(s); + }; + + auto t = asyncString("base"s) | prepend("pre_"s) | append("_post"s); + + CHECK(sync_wait(t) == "pre_base_post"); +} + +TEST_CASE("lots of synchronous completions doesn't result in stack-overflow") +{ + auto completesSynchronously = []() -> cppcoro::task + { + co_return 1; + }; + + auto run = [&]() -> cppcoro::task<> + { + int sum = 0; + for (int i = 0; i < 1'000'000; ++i) + { + sum += co_await completesSynchronously(); + } + CHECK(sum == 1'000'000); + }; + + cppcoro::sync_wait(run()); +} + +TEST_SUITE_END(); diff --git a/test/when_all_ready_tests.cpp b/test/when_all_ready_tests.cpp new file mode 100644 index 0000000..ff7ae97 --- /dev/null +++ b/test/when_all_ready_tests.cpp @@ -0,0 +1,265 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include + +#include "counted.hpp" + +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("when_all_ready"); + +template class TASK, typename T> +TASK when_event_set_return(cppcoro::async_manual_reset_event& event, T value) +{ + co_await event; + co_return std::move(value); +} + +TEST_CASE("when_all_ready() with no args") +{ + [[maybe_unused]] std::tuple<> result = cppcoro::sync_wait(cppcoro::when_all_ready()); +} + +TEST_CASE("when_all_ready() with one task") +{ + bool started = false; + auto f = [&](cppcoro::async_manual_reset_event& event) -> cppcoro::task<> + { + started = true; + co_await event; + }; + + cppcoro::async_manual_reset_event event; + auto whenAllAwaitable = cppcoro::when_all_ready(f(event)); + CHECK(!started); + + bool finished = false; + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto&[t] = co_await whenAllAwaitable; + finished = true; + t.result(); + }(), + [&]() -> cppcoro::task<> + { + CHECK(started); + CHECK(!finished); + event.set(); + CHECK(finished); + co_return; + }())); +} + +TEST_CASE("when_all_ready() with multiple task") +{ + auto makeTask = [&](bool& started, cppcoro::async_manual_reset_event& event, int result) -> cppcoro::task + { + started = true; + co_await event; + co_return result; + }; + + cppcoro::async_manual_reset_event event1; + cppcoro::async_manual_reset_event event2; + bool started1 = false; + bool started2 = false; + auto whenAllAwaitable = cppcoro::when_all_ready( + makeTask(started1, event1, 1), + makeTask(started2, event2, 2)); + CHECK(!started1); + CHECK(!started2); + + bool whenAllAwaitableFinished = false; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto[t1, t2] = co_await std::move(whenAllAwaitable); + whenAllAwaitableFinished = true; + CHECK(t1.result() == 1); + CHECK(t2.result() == 2); + }(), + [&]() -> cppcoro::task<> + { + CHECK(started1); + CHECK(started2); + + event2.set(); + + CHECK(!whenAllAwaitableFinished); + + event1.set(); + + CHECK(whenAllAwaitableFinished); + + co_return; + }())); +} + +TEST_CASE("when_all_ready() with all task types") +{ + cppcoro::async_manual_reset_event event; + auto t0 = when_event_set_return(event, 1); + auto t1 = when_event_set_return(event, 2); + + auto allTask = cppcoro::when_all_ready(std::move(t0), t1); + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto [r0, r1] = co_await std::move(allTask); + + CHECK(r0.result() == 1); + CHECK(r1.result() == 2); + }(), + [&]() -> cppcoro::task<> + { + event.set(); + co_return; + }())); +} + +TEST_CASE("when_all_ready() with std::vector>") +{ + cppcoro::async_manual_reset_event event; + + std::uint32_t startedCount = 0; + std::uint32_t finishedCount = 0; + + auto makeTask = [&]() -> cppcoro::task<> + { + ++startedCount; + co_await event; + ++finishedCount; + }; + + std::vector> tasks; + for (std::uint32_t i = 0; i < 10; ++i) + { + tasks.emplace_back(makeTask()); + } + + auto allTask = cppcoro::when_all_ready(std::move(tasks)); + + // Shouldn't have started any tasks yet. + CHECK(startedCount == 0u); + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto resultTasks = co_await std::move(allTask); + CHECK(resultTasks.size() == 10u); + + for (auto& t : resultTasks) + { + CHECK_NOTHROW(t.result()); + } + }(), + [&]() -> cppcoro::task<> + { + CHECK(startedCount == 10u); + CHECK(finishedCount == 0u); + + event.set(); + + CHECK(finishedCount == 10u); + + co_return; + }())); +} + +TEST_CASE("when_all_ready() with std::vector>") +{ + cppcoro::async_manual_reset_event event; + + std::uint32_t startedCount = 0; + std::uint32_t finishedCount = 0; + + auto makeTask = [&]() -> cppcoro::shared_task<> + { + ++startedCount; + co_await event; + ++finishedCount; + }; + + std::vector> tasks; + for (std::uint32_t i = 0; i < 10; ++i) + { + tasks.emplace_back(makeTask()); + } + + auto allTask = cppcoro::when_all_ready(std::move(tasks)); + + // Shouldn't have started any tasks yet. + CHECK(startedCount == 0u); + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto resultTasks = co_await std::move(allTask); + CHECK(resultTasks.size() == 10u); + + for (auto& t : resultTasks) + { + CHECK_NOTHROW(t.result()); + } + }(), + [&]() -> cppcoro::task<> + { + CHECK(startedCount == 10u); + CHECK(finishedCount == 0u); + + event.set(); + + CHECK(finishedCount == 10u); + + co_return; + }())); +} + +TEST_CASE("when_all_ready() doesn't rethrow exceptions") +{ + auto makeTask = [](bool throwException) -> cppcoro::task + { + if (throwException) + { + throw std::exception{}; + } + else + { + co_return 123; + } + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + try + { + auto[t0, t1] = co_await cppcoro::when_all_ready(makeTask(true), makeTask(false)); + + // You can obtain the exceptions by re-awaiting the returned tasks. + CHECK_THROWS_AS(t0.result(), const std::exception&); + CHECK(t1.result() == 123); + } + catch (...) + { + FAIL("Shouldn't throw"); + } + }()); +} + +TEST_SUITE_END(); diff --git a/test/when_all_tests.cpp b/test/when_all_tests.cpp new file mode 100644 index 0000000..721bcc1 --- /dev/null +++ b/test/when_all_tests.cpp @@ -0,0 +1,427 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "counted.hpp" + +#include +#include +#include + +#include +#include "doctest/cppcoro_doctest.h" + +TEST_SUITE_BEGIN("when_all"); + +namespace +{ + template class TASK, typename T> + TASK when_event_set_return(cppcoro::async_manual_reset_event& event, T value) + { + co_await event; + co_return std::move(value); + } +} + +TEST_CASE("when_all() with no args completes immediately") +{ + [[maybe_unused]] std::tuple<> result = cppcoro::sync_wait(cppcoro::when_all()); +} + +TEST_CASE("when_all() with one arg") +{ + bool started = false; + bool finished = false; + auto f = [&](cppcoro::async_manual_reset_event& event) -> cppcoro::task + { + started = true; + co_await event; + finished = true; + co_return "foo"; + }; + + cppcoro::async_manual_reset_event event; + + auto whenAllTask = cppcoro::when_all(f(event)); + CHECK(!started); + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + auto[s] = co_await whenAllTask; + CHECK(s == "foo"); + }(), + [&]() -> cppcoro::task<> + { + CHECK(started); + CHECK(!finished); + event.set(); + CHECK(finished); + co_return; + }())); +} + +TEST_CASE("when_all() with awaitables") +{ + cppcoro::sync_wait([]() -> cppcoro::task<> + { + auto makeTask = [](int x) -> cppcoro::task + { + co_return x; + }; + + cppcoro::async_manual_reset_event event; + event.set(); + + cppcoro::async_mutex mutex; + + auto[eventResult, mutexLock, number] = co_await cppcoro::when_all( + std::ref(event), + mutex.scoped_lock_async(), + makeTask(123) | cppcoro::fmap([](int x) { return x + 1; })); + + (void)eventResult; + (void)mutexLock; + CHECK(number == 124); + CHECK(!mutex.try_lock()); + }()); +} + +TEST_CASE("when_all() with all task types") +{ + counted::reset_counts(); + + auto run = [](cppcoro::async_manual_reset_event& event) -> cppcoro::task<> + { + using namespace std::string_literals; + + auto[a, b] = co_await cppcoro::when_all( + when_event_set_return(event, "foo"s), + when_event_set_return(event, counted{})); + + CHECK(a == "foo"); + CHECK(b.id == 0); + CHECK(counted::active_count() == 1); + }; + + cppcoro::async_manual_reset_event event; + + cppcoro::sync_wait(cppcoro::when_all_ready( + run(event), + [&]() -> cppcoro::task<> + { + event.set(); + co_return; + }())); +} + +TEST_CASE("when_all() throws if any task throws") +{ + struct X {}; + struct Y {}; + + int startedCount = 0; + auto makeTask = [&](int value) -> cppcoro::task + { + ++startedCount; + if (value == 0) throw X{}; + else if (value == 1) throw Y{}; + else co_return value; + }; + + cppcoro::sync_wait([&]() -> cppcoro::task<> + { + try + { + // This could either throw X or Y exception. + // The exact exception that is thrown is not defined if multiple tasks throw an exception. + // TODO: Consider throwing some kind of aggregate_exception that collects all of the exceptions together. + (void)co_await cppcoro::when_all(makeTask(0), makeTask(1), makeTask(2)); + } + catch (const X&) + { + } + catch (const Y&) + { + } + }()); +} + +TEST_CASE("when_all() with task") +{ + int voidTaskCount = 0; + auto makeVoidTask = [&]() -> cppcoro::task<> + { + ++voidTaskCount; + co_return; + }; + + auto makeIntTask = [](int x) -> cppcoro::task + { + co_return x; + }; + + // Single void task in when_all() + auto[x] = cppcoro::sync_wait(cppcoro::when_all(makeVoidTask())); + (void)x; + CHECK(voidTaskCount == 1); + + // Multiple void tasks in when_all() + auto[a, b] = cppcoro::sync_wait(cppcoro::when_all( + makeVoidTask(), + makeVoidTask())); + (void)a; + (void)b; + CHECK(voidTaskCount == 3); + + // Mixing void and non-void tasks in when_all() + auto[v1, i, v2] = cppcoro::sync_wait(cppcoro::when_all( + makeVoidTask(), + makeIntTask(123), + makeVoidTask())); + (void)v1; + (void)v2; + CHECK(voidTaskCount == 5); + + CHECK(i == 123); +} + +TEST_CASE("when_all() with vector>") +{ + int startedCount = 0; + auto makeTask = [&](cppcoro::async_manual_reset_event& event) -> cppcoro::task<> + { + ++startedCount; + co_await event; + }; + + cppcoro::async_manual_reset_event event1; + cppcoro::async_manual_reset_event event2; + + bool finished = false; + + auto run = [&]() -> cppcoro::task<> + { + std::vector> tasks; + tasks.push_back(makeTask(event1)); + tasks.push_back(makeTask(event2)); + tasks.push_back(makeTask(event1)); + + auto allTask = cppcoro::when_all(std::move(tasks)); + + CHECK(startedCount == 0); + + co_await allTask; + + finished = true; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + run(), + [&]() -> cppcoro::task<> + { + CHECK(startedCount == 3); + CHECK(!finished); + + event1.set(); + + CHECK(!finished); + + event2.set(); + + CHECK(finished); + co_return; + }())); +} + +TEST_CASE("when_all() with vector>") +{ + int startedCount = 0; + auto makeTask = [&](cppcoro::async_manual_reset_event& event) -> cppcoro::shared_task<> + { + ++startedCount; + co_await event; + }; + + cppcoro::async_manual_reset_event event1; + cppcoro::async_manual_reset_event event2; + + bool finished = false; + + auto run = [&]() -> cppcoro::task<> + { + std::vector> tasks; + tasks.push_back(makeTask(event1)); + tasks.push_back(makeTask(event2)); + tasks.push_back(makeTask(event1)); + + auto allTask = cppcoro::when_all(std::move(tasks)); + + CHECK(startedCount == 0); + + co_await allTask; + + finished = true; + }; + + cppcoro::sync_wait(cppcoro::when_all_ready( + run(), + [&]() -> cppcoro::task<> + { + CHECK(startedCount == 3); + CHECK(!finished); + + event1.set(); + + CHECK(!finished); + + event2.set(); + + CHECK(finished); + + co_return; + }())); +} + +namespace +{ + template class TASK> + void check_when_all_vector_of_task_value() + { + cppcoro::async_manual_reset_event event1; + cppcoro::async_manual_reset_event event2; + + bool whenAllCompleted = false; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + std::vector> tasks; + + tasks.emplace_back(when_event_set_return(event1, 1)); + tasks.emplace_back(when_event_set_return(event2, 2)); + + auto whenAllTask = cppcoro::when_all(std::move(tasks)); + + auto values = co_await whenAllTask; + REQUIRE(values.size() == 2); + CHECK(values[0] == 1); + CHECK(values[1] == 2); + + whenAllCompleted = true; + }(), + [&]() -> cppcoro::task<> + { + CHECK(!whenAllCompleted); + event2.set(); + CHECK(!whenAllCompleted); + event1.set(); + CHECK(whenAllCompleted); + co_return; + }())); + } +} + +#if defined(CPPCORO_RELEASE_OPTIMISED) +constexpr bool isOptimised = true; +#else +constexpr bool isOptimised = false; +#endif + +// Disable test on MSVC x86 optimised due to bad codegen bug in +// `co_await whenAllTask` expression under MSVC 15.7 (Preview 2) and earlier. +TEST_CASE("when_all() with vector>" +* doctest::skip(CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC <= 191426316 && CPPCORO_CPU_X86 && isOptimised)) +{ + check_when_all_vector_of_task_value(); +} + +// Disable test on MSVC x64 optimised due to bad codegen bug in +// 'co_await whenAllTask' expression. +// Issue reported to MS on 19/11/2017. +TEST_CASE("when_all() with vector>" +* doctest::skip(CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC <= 191225805 && + isOptimised && CPPCORO_CPU_X64)) +{ + check_when_all_vector_of_task_value(); +} + +namespace +{ + template class TASK> + void check_when_all_vector_of_task_reference() + { + cppcoro::async_manual_reset_event event1; + cppcoro::async_manual_reset_event event2; + + int value1 = 1; + int value2 = 2; + + auto makeTask = [](cppcoro::async_manual_reset_event& event, int& value) -> TASK + { + co_await event; + co_return value; + }; + + bool whenAllComplete = false; + + cppcoro::sync_wait(cppcoro::when_all_ready( + [&]() -> cppcoro::task<> + { + std::vector> tasks; + tasks.emplace_back(makeTask(event1, value1)); + tasks.emplace_back(makeTask(event2, value2)); + + auto whenAllTask = cppcoro::when_all(std::move(tasks)); + + std::vector> values = co_await whenAllTask; + REQUIRE(values.size() == 2); + CHECK(&values[0].get() == &value1); + CHECK(&values[1].get() == &value2); + + whenAllComplete = true; + }(), + [&]() -> cppcoro::task<> + { + CHECK(!whenAllComplete); + event2.set(); + CHECK(!whenAllComplete); + event1.set(); + CHECK(whenAllComplete); + co_return; + }())); + } +} + +// Disable test on MSVC x64 optimised due to bad codegen bug in +// 'co_await whenAllTask' expression. +// Issue reported to MS on 19/11/2017. +TEST_CASE("when_all() with vector>" + * doctest::skip(CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC <= 191225805 && + isOptimised && CPPCORO_CPU_X64)) +{ + check_when_all_vector_of_task_reference(); +} + +// Disable test on MSVC x64 optimised due to bad codegen bug in +// 'co_await whenAllTask' expression. +// Issue reported to MS on 19/11/2017. +TEST_CASE("when_all() with vector>" + * doctest::skip(CPPCORO_COMPILER_MSVC && CPPCORO_COMPILER_MSVC <= 191225805 && + isOptimised && CPPCORO_CPU_X64)) +{ + check_when_all_vector_of_task_reference(); +} + +TEST_SUITE_END();