/////////////////////////////////////////////////////////////////////////////// // Copyright (c) Lewis Baker // Licenced under MIT license. See LICENSE.txt for details. /////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include #include #include #include #include #include "doctest/cppcoro_doctest.h" using namespace cppcoro; using namespace cppcoro::net; TEST_SUITE_BEGIN("socket"); TEST_CASE("create TCP/IPv4") { io_service ioSvc; auto socket = socket::create_tcpv4(ioSvc); } TEST_CASE("create TCP/IPv6") { io_service ioSvc; auto socket = socket::create_tcpv6(ioSvc); } TEST_CASE("create UDP/IPv4") { io_service ioSvc; auto socket = socket::create_udpv4(ioSvc); } TEST_CASE("create UDP/IPv6") { io_service ioSvc; auto socket = socket::create_udpv6(ioSvc); } TEST_CASE("TCP/IPv4 connect/disconnect") { io_service ioSvc; ip_endpoint serverAddress; task serverTask; auto server = [&](socket listeningSocket) -> task { auto s = socket::create_tcpv4(ioSvc); co_await listeningSocket.accept(s); co_await s.disconnect(); co_return 0; }; { auto serverSocket = socket::create_tcpv4(ioSvc); serverSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); serverSocket.listen(3); serverAddress = serverSocket.local_endpoint(); serverTask = server(std::move(serverSocket)); } auto client = [&]() -> task { auto s = socket::create_tcpv4(ioSvc); s.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); co_await s.connect(serverAddress); co_await s.disconnect(); co_return 0; }; task clientTask = client(); (void)sync_wait(when_all( [&]() -> task { auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); (void)co_await when_all(std::move(serverTask), std::move(clientTask)); co_return 0; }(), [&]() -> task { ioSvc.process_events(); co_return 0; }())); } TEST_CASE("send/recv TCP/IPv4") { io_service ioSvc; auto listeningSocket = socket::create_tcpv4(ioSvc); listeningSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); listeningSocket.listen(3); auto echoServer = [&]() -> task { auto acceptingSocket = socket::create_tcpv4(ioSvc); co_await listeningSocket.accept(acceptingSocket); std::uint8_t buffer[64]; std::size_t bytesReceived; do { bytesReceived = co_await acceptingSocket.recv(buffer, sizeof(buffer)); if (bytesReceived > 0) { std::size_t bytesSent = 0; do { bytesSent += co_await acceptingSocket.send( buffer + bytesSent, bytesReceived - bytesSent); } while (bytesSent < bytesReceived); } } while (bytesReceived > 0); acceptingSocket.close_send(); co_await acceptingSocket.disconnect(); co_return 0; }; auto echoClient = [&]() -> task { auto connectingSocket = socket::create_tcpv4(ioSvc); connectingSocket.bind(ipv4_endpoint{}); co_await connectingSocket.connect(listeningSocket.local_endpoint()); auto receive = [&]() -> task { std::uint8_t buffer[100]; std::uint64_t totalBytesReceived = 0; std::size_t bytesReceived; do { bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); for (std::size_t i = 0; i < bytesReceived; ++i) { std::uint64_t byteIndex = totalBytesReceived + i; std::uint8_t expectedByte = 'a' + (byteIndex % 26); CHECK(buffer[i] == expectedByte); } totalBytesReceived += bytesReceived; } while (bytesReceived > 0); CHECK(totalBytesReceived == 1000); co_return 0; }; auto send = [&]() -> task { std::uint8_t buffer[100]; for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) { for (std::size_t j = 0; j < sizeof(buffer); ++j) { buffer[j] = 'a' + ((i + j) % 26); } std::size_t bytesSent = 0; do { bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); } while (bytesSent < sizeof(buffer)); } connectingSocket.close_send(); co_return 0; }; co_await when_all(send(), receive()); co_await connectingSocket.disconnect(); co_return 0; }; (void)sync_wait(when_all( [&]() -> task { auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); (void)co_await when_all(echoClient(), echoServer()); co_return 0; }(), [&]() -> task { ioSvc.process_events(); co_return 0; }())); } #if !CPPCORO_COMPILER_MSVC || CPPCORO_COMPILER_MSVC >= 192000000 || !CPPCORO_CPU_X86 // HACK: Don't compile this function under MSVC x86. // It results in an ICE under VS 2017.15 and earlier. TEST_CASE("send/recv TCP/IPv4 many connections") { io_service ioSvc; auto listeningSocket = socket::create_tcpv4(ioSvc); listeningSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); listeningSocket.listen(20); cancellation_source canceller; auto handleConnection = [](socket s) -> task { std::uint8_t buffer[64]; std::size_t bytesReceived; do { bytesReceived = co_await s.recv(buffer, sizeof(buffer)); if (bytesReceived > 0) { std::size_t bytesSent = 0; do { bytesSent += co_await s.send( buffer + bytesSent, bytesReceived - bytesSent); } while (bytesSent < bytesReceived); } } while (bytesReceived > 0); s.close_send(); co_await s.disconnect(); }; auto echoServer = [&](cancellation_token ct) -> task<> { async_scope connectionScope; std::exception_ptr ex; try { while (true) { auto acceptingSocket = socket::create_tcpv4(ioSvc); co_await listeningSocket.accept(acceptingSocket, ct); connectionScope.spawn( handleConnection(std::move(acceptingSocket))); } } catch (const cppcoro::operation_cancelled&) { } catch (...) { ex = std::current_exception(); } co_await connectionScope.join(); if (ex) { std::rethrow_exception(ex); } }; auto echoClient = [&]() -> task<> { auto connectingSocket = socket::create_tcpv4(ioSvc); connectingSocket.bind(ipv4_endpoint{}); co_await connectingSocket.connect(listeningSocket.local_endpoint()); auto receive = [&]() -> task<> { std::uint8_t buffer[100]; std::uint64_t totalBytesReceived = 0; std::size_t bytesReceived; do { bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); for (std::size_t i = 0; i < bytesReceived; ++i) { std::uint64_t byteIndex = totalBytesReceived + i; std::uint8_t expectedByte = 'a' + (byteIndex % 26); CHECK(buffer[i] == expectedByte); } totalBytesReceived += bytesReceived; } while (bytesReceived > 0); CHECK(totalBytesReceived == 1000); }; auto send = [&]() -> task<> { std::uint8_t buffer[100]; for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) { for (std::size_t j = 0; j < sizeof(buffer); ++j) { buffer[j] = 'a' + ((i + j) % 26); } std::size_t bytesSent = 0; do { bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); } while (bytesSent < sizeof(buffer)); } connectingSocket.close_send(); }; co_await when_all(send(), receive()); co_await connectingSocket.disconnect(); }; auto manyEchoClients = [&](int count) -> task { auto shutdownServerOnExit = on_scope_exit([&] { canceller.request_cancellation(); }); std::vector> clientTasks; clientTasks.reserve(count); for (int i = 0; i < count; ++i) { clientTasks.emplace_back(echoClient()); } co_await when_all(std::move(clientTasks)); }; (void)sync_wait(when_all( [&]() -> task<> { auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); (void)co_await when_all( manyEchoClients(20), echoServer(canceller.token())); }(), [&]() -> task<> { ioSvc.process_events(); co_return; }())); } #endif TEST_CASE("udp send_to/recv_from") { io_service ioSvc; auto server = [&](socket serverSocket) -> task { std::uint8_t buffer[100]; auto[bytesReceived, remoteEndPoint] = co_await serverSocket.recv_from(buffer, 100); CHECK(bytesReceived == 50); // Send an ACK response. { const std::uint8_t response[1] = { 0 }; co_await serverSocket.send_to(remoteEndPoint, &response, 1); } // Second message received won't fit within buffer. try { std::tie(bytesReceived, remoteEndPoint) = co_await serverSocket.recv_from(buffer, 100); FAIL("Should have thrown"); } catch (const std::system_error&) { // TODO: Map this situation to some kind of error_condition value. // The win32 ERROR_MORE_DATA error code doesn't seem to map to any of the standard std::errc values. // // CHECK(ex.code() == ???); // // Possibly also need to switch to returning a std::error_code directly rather than // throwing a std::system_error for this case. } // Send an NACK response. { const std::uint8_t response[1] = { 1 }; co_await serverSocket.send_to(remoteEndPoint, response, 1); } co_return 0; }; ip_endpoint serverAddress; task serverTask; { auto serverSocket = socket::create_udpv4(ioSvc); serverSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 }); serverAddress = serverSocket.local_endpoint(); serverTask = server(std::move(serverSocket)); } auto client = [&]() -> task { auto socket = socket::create_udpv4(ioSvc); // don't need to bind(), should be implicitly bound on first send_to(). // Send first message of 50 bytes { std::uint8_t buffer[50] = { 0 }; co_await socket.send_to(serverAddress, buffer, 50); } // Receive ACK message { std::uint8_t buffer[1]; auto[bytesReceived, ackAddress] = co_await socket.recv_from(buffer, 1); CHECK(bytesReceived == 1); CHECK(buffer[0] == 0); CHECK(ackAddress == serverAddress); } // Send second message of 128 bytes { std::uint8_t buffer[128] = { 0 }; co_await socket.send_to(serverAddress, buffer, 128); } // Receive NACK message { std::uint8_t buffer[1]; auto[bytesReceived, ackAddress] = co_await socket.recv_from(buffer, 1); CHECK(bytesReceived == 1); CHECK(buffer[0] == 1); CHECK(ackAddress == serverAddress); } co_return 0; }; (void)sync_wait(when_all( [&]() -> task { auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); }); (void)co_await when_all(std::move(serverTask), client()); co_return 0; }(), [&]() -> task { ioSvc.process_events(); co_return 0; }())); } TEST_SUITE_END();