Files
jeanlemotan bbeaa887cd First
2024-07-02 18:13:47 +02:00

494 lines
14 KiB
C++

///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket.hpp>
#include <cppcoro/net/socket_accept_operation.hpp>
#include <cppcoro/net/socket_connect_operation.hpp>
#include <cppcoro/net/socket_disconnect_operation.hpp>
#include <cppcoro/net/socket_recv_operation.hpp>
#include <cppcoro/net/socket_send_operation.hpp>
#include <cppcoro/io_service.hpp>
#include <cppcoro/on_scope_exit.hpp>
#include "socket_helpers.hpp"
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
namespace
{
namespace local
{
std::tuple<SOCKET, bool> create_socket(
int addressFamily,
int socketType,
int protocol,
HANDLE ioCompletionPort)
{
// Enumerate available protocol providers for the specified socket type.
WSAPROTOCOL_INFOW stackInfos[4];
std::unique_ptr<WSAPROTOCOL_INFOW[]> heapInfos;
WSAPROTOCOL_INFOW* selectedProtocolInfo = nullptr;
{
INT protocols[] = { protocol, 0 };
DWORD bufferSize = sizeof(stackInfos);
WSAPROTOCOL_INFOW* infos = stackInfos;
int protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize);
if (protocolCount == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode == WSAENOBUFS)
{
DWORD requiredElementCount = bufferSize / sizeof(WSAPROTOCOL_INFOW);
heapInfos = std::make_unique<WSAPROTOCOL_INFOW[]>(requiredElementCount);
bufferSize = requiredElementCount * sizeof(WSAPROTOCOL_INFOW);
infos = heapInfos.get();
protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize);
if (protocolCount == SOCKET_ERROR)
{
errorCode = ::WSAGetLastError();
}
}
if (protocolCount == SOCKET_ERROR)
{
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: WSAEnumProtocolsW");
}
}
if (protocolCount == 0)
{
throw std::system_error(
std::make_error_code(std::errc::protocol_not_supported));
}
for (int i = 0; i < protocolCount; ++i)
{
auto& info = infos[i];
if (info.iAddressFamily == addressFamily && info.iProtocol == protocol && info.iSocketType == socketType)
{
selectedProtocolInfo = &info;
break;
}
}
if (selectedProtocolInfo == nullptr)
{
throw std::system_error(
std::make_error_code(std::errc::address_family_not_supported));
}
}
// WSA_FLAG_NO_HANDLE_INHERIT for SDKs earlier than Windows 7.
constexpr DWORD flagNoInherit = 0x80;
const DWORD flags = WSA_FLAG_OVERLAPPED | flagNoInherit;
const SOCKET socketHandle = ::WSASocketW(
addressFamily, socketType, protocol, selectedProtocolInfo, 0, flags);
if (socketHandle == INVALID_SOCKET)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: WSASocketW");
}
auto closeSocketOnFailure = cppcoro::on_scope_failure([&]
{
::closesocket(socketHandle);
});
// This is needed on operating systems earlier than Windows 7 to prevent
// socket handles from being inherited. On Windows 7 or later this is
// redundant as the WSA_FLAG_NO_HANDLE_INHERIT flag passed to creation
// above causes the socket to be atomically created with this flag cleared.
if (!::SetHandleInformation((HANDLE)socketHandle, HANDLE_FLAG_INHERIT, 0))
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: SetHandleInformation");
}
// Associate the socket with the I/O completion port.
{
const HANDLE result = ::CreateIoCompletionPort(
(HANDLE)socketHandle,
ioCompletionPort,
ULONG_PTR(0),
DWORD(0));
if (result == nullptr)
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
static_cast<int>(errorCode),
std::system_category(),
"Error creating socket: CreateIoCompletionPort");
}
}
const bool skipCompletionPortOnSuccess =
(selectedProtocolInfo->dwServiceFlags1 & XP1_IFS_HANDLES) != 0;
{
UCHAR completionModeFlags = FILE_SKIP_SET_EVENT_ON_HANDLE;
if (skipCompletionPortOnSuccess)
{
completionModeFlags |= FILE_SKIP_COMPLETION_PORT_ON_SUCCESS;
}
const BOOL ok = ::SetFileCompletionNotificationModes(
(HANDLE)socketHandle,
completionModeFlags);
if (!ok)
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
static_cast<int>(errorCode),
std::system_category(),
"Error creating socket: SetFileCompletionNotificationModes");
}
}
if (socketType == SOCK_STREAM)
{
// Turn off linger so that the destructor doesn't block while closing
// the socket or silently continue to flush remaining data in the
// background after ::closesocket() is called, which could fail and
// we'd never know about it.
// We expect clients to call Disconnect() or use CloseSend() to cleanly
// shut-down connections instead.
BOOL value = TRUE;
const int result = ::setsockopt(socketHandle,
SOL_SOCKET,
SO_DONTLINGER,
reinterpret_cast<const char*>(&value),
sizeof(value));
if (result == SOCKET_ERROR)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: setsockopt(SO_DONTLINGER)");
}
}
return std::make_tuple(socketHandle, skipCompletionPortOnSuccess);
}
}
}
cppcoro::net::socket cppcoro::net::socket::create_tcpv4(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv4_endpoint();
result.m_remoteEndPoint = ipv4_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_tcpv6(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET6, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv6_endpoint();
result.m_remoteEndPoint = ipv6_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_udpv4(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv4_endpoint();
result.m_remoteEndPoint = ipv4_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_udpv6(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET6, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv6_endpoint();
result.m_remoteEndPoint = ipv6_endpoint();
return result;
}
cppcoro::net::socket::socket(socket&& other) noexcept
: m_handle(std::exchange(other.m_handle, INVALID_SOCKET))
, m_skipCompletionOnSuccess(other.m_skipCompletionOnSuccess)
, m_localEndPoint(std::move(other.m_localEndPoint))
, m_remoteEndPoint(std::move(other.m_remoteEndPoint))
{}
cppcoro::net::socket::~socket()
{
if (m_handle != INVALID_SOCKET)
{
::closesocket(m_handle);
}
}
cppcoro::net::socket&
cppcoro::net::socket::operator=(socket&& other) noexcept
{
auto handle = std::exchange(other.m_handle, INVALID_SOCKET);
if (m_handle != INVALID_SOCKET)
{
::closesocket(m_handle);
}
m_handle = handle;
m_skipCompletionOnSuccess = other.m_skipCompletionOnSuccess;
m_localEndPoint = other.m_localEndPoint;
m_remoteEndPoint = other.m_remoteEndPoint;
return *this;
}
void cppcoro::net::socket::bind(const ip_endpoint& localEndPoint)
{
SOCKADDR_STORAGE sockaddrStorage = { 0 };
SOCKADDR* sockaddr = reinterpret_cast<SOCKADDR*>(&sockaddrStorage);
if (localEndPoint.is_ipv4())
{
SOCKADDR_IN& ipv4Sockaddr = *reinterpret_cast<SOCKADDR_IN*>(sockaddr);
ipv4Sockaddr.sin_family = AF_INET;
std::memcpy(&ipv4Sockaddr.sin_addr, localEndPoint.to_ipv4().address().bytes(), 4);
ipv4Sockaddr.sin_port = localEndPoint.to_ipv4().port();
}
else
{
SOCKADDR_IN6& ipv6Sockaddr = *reinterpret_cast<SOCKADDR_IN6*>(sockaddr);
ipv6Sockaddr.sin6_family = AF_INET6;
std::memcpy(&ipv6Sockaddr.sin6_addr, localEndPoint.to_ipv6().address().bytes(), 16);
ipv6Sockaddr.sin6_port = localEndPoint.to_ipv6().port();
}
int result = ::bind(m_handle, sockaddr, sizeof(sockaddrStorage));
if (result != 0)
{
// WSANOTINITIALISED: WSAStartup not called
// WSAENETDOWN: network subsystem failed
// WSAEACCES: access denied
// WSAEADDRINUSE: port in use
// WSAEADDRNOTAVAIL: address is not an address that can be bound to
// WSAEFAULT: invalid pointer passed to bind()
// WSAEINPROGRESS: a callback is in progress
// WSAEINVAL: socket already bound
// WSAENOBUFS: system failed to allocate memory
// WSAENOTSOCK: socket was not a valid socket.
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error binding to endpoint: bind()");
}
int sockaddrLen = sizeof(sockaddrStorage);
result = ::getsockname(m_handle, sockaddr, &sockaddrLen);
if (result == 0)
{
m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(*sockaddr);
}
else
{
m_localEndPoint = localEndPoint;
}
}
void cppcoro::net::socket::listen()
{
int result = ::listen(m_handle, SOMAXCONN);
if (result != 0)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Failed to start listening on bound endpoint: listen");
}
}
void cppcoro::net::socket::listen(std::uint32_t backlog)
{
if (backlog > 0x7FFFFFFF)
{
backlog = 0x7FFFFFFF;
}
int result = ::listen(m_handle, (int)backlog);
if (result != 0)
{
// WSANOTINITIALISED: WSAStartup not called
// WSAENETDOWN: network subsystem failed
// WSAEADDRINUSE: port in use
// WSAEINPROGRESS: a callback is in progress
// WSAEINVAL: socket not yet bound
// WSAEISCONN: socket already connected
// WSAEMFILE: no more socket descriptors available
// WSAENOBUFS: system failed to allocate memory
// WSAENOTSOCK: socket was not a valid socket.
// WSAEOPNOTSUPP: The socket does not support listening
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Failed to start listening on bound endpoint: listen");
}
}
cppcoro::net::socket_accept_operation
cppcoro::net::socket::accept(socket& acceptingSocket) noexcept
{
return socket_accept_operation{ *this, acceptingSocket };
}
cppcoro::net::socket_accept_operation_cancellable
cppcoro::net::socket::accept(socket& acceptingSocket, cancellation_token ct) noexcept
{
return socket_accept_operation_cancellable{ *this, acceptingSocket, std::move(ct) };
}
cppcoro::net::socket_connect_operation
cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint) noexcept
{
return socket_connect_operation{ *this, remoteEndPoint };
}
cppcoro::net::socket_connect_operation_cancellable
cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint, cancellation_token ct) noexcept
{
return socket_connect_operation_cancellable{ *this, remoteEndPoint, std::move(ct) };
}
cppcoro::net::socket_disconnect_operation
cppcoro::net::socket::disconnect() noexcept
{
return socket_disconnect_operation(*this);
}
cppcoro::net::socket_disconnect_operation_cancellable
cppcoro::net::socket::disconnect(cancellation_token ct) noexcept
{
return socket_disconnect_operation_cancellable{ *this, std::move(ct) };
}
cppcoro::net::socket_send_operation
cppcoro::net::socket::send(const void* buffer, std::size_t byteCount) noexcept
{
return socket_send_operation{ *this, buffer, byteCount };
}
cppcoro::net::socket_send_operation_cancellable
cppcoro::net::socket::send(const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept
{
return socket_send_operation_cancellable{ *this, buffer, byteCount, std::move(ct) };
}
cppcoro::net::socket_recv_operation
cppcoro::net::socket::recv(void* buffer, std::size_t byteCount) noexcept
{
return socket_recv_operation{ *this, buffer, byteCount };
}
cppcoro::net::socket_recv_operation_cancellable
cppcoro::net::socket::recv(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept
{
return socket_recv_operation_cancellable{ *this, buffer, byteCount, std::move(ct) };
}
cppcoro::net::socket_recv_from_operation
cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount) noexcept
{
return socket_recv_from_operation{ *this, buffer, byteCount };
}
cppcoro::net::socket_recv_from_operation_cancellable
cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept
{
return socket_recv_from_operation_cancellable{ *this, buffer, byteCount, std::move(ct) };
}
cppcoro::net::socket_send_to_operation
cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount) noexcept
{
return socket_send_to_operation{ *this, destination, buffer, byteCount };
}
cppcoro::net::socket_send_to_operation_cancellable
cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept
{
return socket_send_to_operation_cancellable{ *this, destination, buffer, byteCount, std::move(ct) };
}
void cppcoro::net::socket::close_send()
{
int result = ::shutdown(m_handle, SD_SEND);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"failed to close socket send stream: shutdown(SD_SEND)");
}
}
void cppcoro::net::socket::close_recv()
{
int result = ::shutdown(m_handle, SD_RECEIVE);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"failed to close socket receive stream: shutdown(SD_RECEIVE)");
}
}
cppcoro::net::socket::socket(
cppcoro::detail::win32::socket_t handle,
bool skipCompletionOnSuccess) noexcept
: m_handle(handle)
, m_skipCompletionOnSuccess(skipCompletionOnSuccess)
{
}
#endif