diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Makefile.am | 2 | ||||
-rw-r--r-- | src/Makefile.test.include | 1 | ||||
-rw-r--r-- | src/net.cpp | 48 | ||||
-rw-r--r-- | src/netbase.cpp | 152 | ||||
-rw-r--r-- | src/netbase.h | 27 | ||||
-rw-r--r-- | src/test/sock_tests.cpp | 149 | ||||
-rw-r--r-- | src/torcontrol.cpp | 1 | ||||
-rw-r--r-- | src/util/sock.cpp | 149 | ||||
-rw-r--r-- | src/util/sock.h | 118 | ||||
-rw-r--r-- | src/util/time.cpp | 14 | ||||
-rw-r--r-- | src/util/time.h | 13 |
11 files changed, 527 insertions, 147 deletions
diff --git a/src/Makefile.am b/src/Makefile.am index 20f1316302..bc661fccbb 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -243,6 +243,7 @@ BITCOIN_CORE_H = \ util/rbf.h \ util/ref.h \ util/settings.h \ + util/sock.h \ util/spanparsing.h \ util/string.h \ util/system.h \ @@ -559,6 +560,7 @@ libbitcoin_util_a_SOURCES = \ util/fees.cpp \ util/getuniquepath.cpp \ util/hasher.cpp \ + util/sock.cpp \ util/system.cpp \ util/message.cpp \ util/moneystr.cpp \ diff --git a/src/Makefile.test.include b/src/Makefile.test.include index 77cba466ba..e817bb2ee2 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -124,6 +124,7 @@ BITCOIN_TESTS =\ test/sighash_tests.cpp \ test/sigopcount_tests.cpp \ test/skiplist_tests.cpp \ + test/sock_tests.cpp \ test/streams_tests.cpp \ test/sync_tests.cpp \ test/system_tests.cpp \ diff --git a/src/net.cpp b/src/net.cpp index d004aace88..5fa405a690 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -20,6 +20,7 @@ #include <protocol.h> #include <random.h> #include <scheduler.h> +#include <util/sock.h> #include <util/strencodings.h> #include <util/translation.h> @@ -429,24 +430,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo // Connect bool connected = false; - SOCKET hSocket = INVALID_SOCKET; + std::unique_ptr<Sock> sock; proxyType proxy; if (addrConnect.IsValid()) { bool proxyConnectionFailed = false; if (GetProxy(addrConnect.GetNetwork(), proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } - connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), + *sock, nConnectTimeout, proxyConnectionFailed); } else { // no proxy needed (none set for target network) - hSocket = CreateSocket(addrConnect); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(addrConnect); + if (!sock) { return nullptr; } - connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL); + connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout, + conn_type == ConnectionType::MANUAL); } if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -454,26 +457,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addrman.Attempt(addrConnect, fCountFailure); } } else if (pszDest && GetNameProxy(proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } std::string host; int port = default_port; SplitHostPort(std::string(pszDest), port, host); bool proxyConnectionFailed; - connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, host, port, *sock, nConnectTimeout, + proxyConnectionFailed); } if (!connected) { - CloseSocket(hSocket); return nullptr; } // Add node NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); - CAddress addr_bind = GetBindAddress(hSocket); - CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); + CAddress addr_bind = GetBindAddress(sock->Get()); + CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); pnode->AddRef(); // We're making a new connection, harvest entropy from the time (and our peer count) @@ -2188,9 +2191,8 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - SOCKET hListenSocket = CreateSocket(addrBind); - if (hListenSocket == INVALID_SOCKET) - { + std::unique_ptr<Sock> sock = CreateSock(addrBind); + if (!sock) { strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); return false; @@ -2198,21 +2200,21 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // and enable it by default or not. Try to enable it, if possible. if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); #endif } - if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); if (nErr == WSAEADDRINUSE) @@ -2220,21 +2222,19 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, else strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr)); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } LogPrintf("Bound to %s\n", addrBind.ToString()); // Listen for incoming connections - if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR) + if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR) { strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } - vhListenSocket.push_back(ListenSocket(hListenSocket, permissions)); + vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); return true; } diff --git a/src/netbase.cpp b/src/netbase.cpp index 264029d8a2..24188f83c6 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -7,13 +7,17 @@ #include <sync.h> #include <tinyformat.h> +#include <util/sock.h> #include <util/strencodings.h> #include <util/string.h> #include <util/system.h> +#include <util/time.h> #include <atomic> #include <cstdint> +#include <functional> #include <limits> +#include <memory> #ifndef WIN32 #include <fcntl.h> @@ -271,14 +275,6 @@ CService LookupNumeric(const std::string& name, int portDefault) return addr; } -struct timeval MillisToTimeval(int64_t nTimeout) -{ - struct timeval timeout; - timeout.tv_sec = nTimeout / 1000; - timeout.tv_usec = (nTimeout % 1000) * 1000; - return timeout; -} - /** SOCKS version */ enum SOCKSVersion: uint8_t { SOCKS4 = 0x04, @@ -336,8 +332,7 @@ enum class IntrRecvError { * @param data The buffer where the read bytes should be stored. * @param len The number of bytes to read into the specified buffer. * @param timeout The total timeout in milliseconds for this read. - * @param hSocket The socket (has to be in non-blocking mode) from which to read - * bytes. + * @param sock The socket (has to be in non-blocking mode) from which to read bytes. * * @returns An IntrRecvError indicating the resulting status of this read. * IntrRecvError::OK only if all of the specified number of bytes were @@ -347,7 +342,7 @@ enum class IntrRecvError { * Sockets can be made non-blocking with SetSocketNonBlocking(const * SOCKET&, bool). */ -static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const SOCKET& hSocket) +static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& sock) { int64_t curTime = GetTimeMillis(); int64_t endTime = curTime + timeout; @@ -355,7 +350,7 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c // (in millis) to break off in case of an interruption. const int64_t maxWait = 1000; while (len > 0 && curTime < endTime) { - ssize_t ret = recv(hSocket, (char*)data, len, 0); // Optimistically try the recv first + ssize_t ret = sock.Recv(data, len, 0); // Optimistically try the recv first if (ret > 0) { len -= ret; data += ret; @@ -364,25 +359,10 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c } else { // Other error or blocking int nErr = WSAGetLastError(); if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { - if (!IsSelectableSocket(hSocket)) { - return IntrRecvError::NetworkError; - } // Only wait at most maxWait milliseconds at a time, unless // we're approaching the end of the specified total timeout int timeout_ms = std::min(endTime - curTime, maxWait); -#ifdef USE_POLL - struct pollfd pollfd = {}; - pollfd.fd = hSocket; - pollfd.events = POLLIN; - int nRet = poll(&pollfd, 1, timeout_ms); -#else - struct timeval tval = MillisToTimeval(timeout_ms); - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(hSocket, &fdset); - int nRet = select(hSocket + 1, &fdset, nullptr, nullptr, &tval); -#endif - if (nRet == SOCKET_ERROR) { + if (!sock.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) { return IntrRecvError::NetworkError; } } else { @@ -436,7 +416,7 @@ static std::string Socks5ErrorString(uint8_t err) * @param port The destination port. * @param auth The credentials with which to authenticate with the specified * SOCKS5 proxy. - * @param hSocket The SOCKS5 proxy socket. + * @param sock The SOCKS5 proxy socket. * * @returns Whether or not the operation succeeded. * @@ -446,7 +426,7 @@ static std::string Socks5ErrorString(uint8_t err) * @see <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC1928: SOCKS Protocol * Version 5</a> */ -static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, const SOCKET& hSocket) +static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* auth, const Sock& sock) { IntrRecvError recvr; LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest); @@ -464,12 +444,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5Init.push_back(0x01); // 1 method identifier follows... vSocks5Init.push_back(SOCKS5Method::NOAUTH); } - ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); + ssize_t ret = sock.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5Init.size()) { return error("Error sending to proxy"); } uint8_t pchRet1[2]; - if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { LogPrintf("Socks5() connect to %s:%d failed: InterruptibleRecv() timeout or other failure\n", strDest, port); return false; } @@ -486,13 +466,13 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end()); vAuth.push_back(auth->password.size()); vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end()); - ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL); + ret = sock.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vAuth.size()) { return error("Error sending authentication to proxy"); } LogPrint(BCLog::PROXY, "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password); uint8_t pchRetA[2]; - if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { return error("Error reading proxy authentication response"); } if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) { @@ -512,12 +492,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end()); vSocks5.push_back((port >> 8) & 0xFF); vSocks5.push_back((port >> 0) & 0xFF); - ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); + ret = sock.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5.size()) { return error("Error sending to proxy"); } uint8_t pchRet2[4]; - if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { if (recvr == IntrRecvError::Timeout) { /* If a timeout happens here, this effectively means we timed out while connecting * to the remote node. This is very common for Tor, so do not print an @@ -541,16 +521,16 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials uint8_t pchRet3[256]; switch (pchRet2[3]) { - case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, hSocket); break; - case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, hSocket); break; + case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, sock); break; + case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, sock); break; case SOCKS5Atyp::DOMAINNAME: { - recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket); + recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, sock); if (recvr != IntrRecvError::OK) { return error("Error reading from proxy"); } int nRecv = pchRet3[0]; - recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket); + recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, sock); break; } default: return error("Error: malformed proxy response"); @@ -558,41 +538,35 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials if (recvr != IntrRecvError::OK) { return error("Error reading from proxy"); } - if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { return error("Error reading from proxy"); } LogPrint(BCLog::NET, "SOCKS5 connected %s\n", strDest); return true; } -/** - * Try to create a socket file descriptor with specific properties in the - * communications domain (address family) of the specified service. - * - * For details on the desired properties, see the inline comments in the source - * code. - */ -SOCKET CreateSocket(const CService &addrConnect) +std::unique_ptr<Sock> CreateSockTCP(const CService& address_family) { // Create a sockaddr from the specified service. struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { - LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString()); - return INVALID_SOCKET; + if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { + LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToString()); + return nullptr; } // Create a TCP socket in the address family of the specified service. SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); - if (hSocket == INVALID_SOCKET) - return INVALID_SOCKET; + if (hSocket == INVALID_SOCKET) { + return nullptr; + } // Ensure that waiting for I/O on this socket won't result in undefined // behavior. if (!IsSelectableSocket(hSocket)) { CloseSocket(hSocket); LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); - return INVALID_SOCKET; + return nullptr; } #ifdef SO_NOSIGPIPE @@ -608,11 +582,14 @@ SOCKET CreateSocket(const CService &addrConnect) // Set the non-blocking option on the socket. if (!SetSocketNonBlocking(hSocket, true)) { CloseSocket(hSocket); - LogPrintf("CreateSocket: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); + LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); + return nullptr; } - return hSocket; + return std::make_unique<Sock>(hSocket); } +std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP; + template<typename... Args> static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { std::string error_message = tfm::format(fmt, args...); @@ -786,7 +763,7 @@ bool IsProxy(const CNetAddr &addr) { * @param proxy The SOCKS5 proxy. * @param strDest The destination service to which to connect. * @param port The destination port. - * @param hSocket The socket on which to connect to the SOCKS5 proxy. + * @param sock The socket on which to connect to the SOCKS5 proxy. * @param nTimeout Wait this many milliseconds for the connection to the SOCKS5 * proxy to be established. * @param[out] outProxyConnectionFailed Whether or not the connection to the @@ -794,10 +771,10 @@ bool IsProxy(const CNetAddr &addr) { * * @returns Whether or not the operation succeeded. */ -bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocket, int nTimeout, bool& outProxyConnectionFailed) +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed) { // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout, true)) { + if (!ConnectSocketDirectly(proxy.proxy, sock.Get(), nTimeout, true)) { outProxyConnectionFailed = true; return false; } @@ -806,11 +783,11 @@ bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int ProxyCredentials random_auth; static std::atomic_int counter(0); random_auth.username = random_auth.password = strprintf("%i", counter++); - if (!Socks5(strDest, (uint16_t)port, &random_auth, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, &random_auth, sock)) { return false; } } else { - if (!Socks5(strDest, (uint16_t)port, 0, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, 0, sock)) { return false; } } @@ -869,57 +846,6 @@ bool LookupSubNet(const std::string& strSubnet, CSubNet& ret) return false; } -#ifdef WIN32 -std::string NetworkErrorString(int err) -{ - wchar_t buf[256]; - buf[0] = 0; - if(FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, - nullptr, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - buf, ARRAYSIZE(buf), nullptr)) - { - return strprintf("%s (%d)", std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>,wchar_t>().to_bytes(buf), err); - } - else - { - return strprintf("Unknown error (%d)", err); - } -} -#else -std::string NetworkErrorString(int err) -{ - char buf[256]; - buf[0] = 0; - /* Too bad there are two incompatible implementations of the - * thread-safe strerror. */ - const char *s; -#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ - s = strerror_r(err, buf, sizeof(buf)); -#else /* POSIX variant always returns message in buffer */ - s = buf; - if (strerror_r(err, buf, sizeof(buf))) - buf[0] = 0; -#endif - return strprintf("%s (%d)", s, err); -} -#endif - -bool CloseSocket(SOCKET& hSocket) -{ - if (hSocket == INVALID_SOCKET) - return false; -#ifdef WIN32 - int ret = closesocket(hSocket); -#else - int ret = close(hSocket); -#endif - if (ret) { - LogPrintf("Socket close failed: %d. Error: %s\n", hSocket, NetworkErrorString(WSAGetLastError())); - } - hSocket = INVALID_SOCKET; - return ret != SOCKET_ERROR; -} - bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking) { if (fNonBlocking) { diff --git a/src/netbase.h b/src/netbase.h index ac4cd97673..afc373ef49 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -12,7 +12,10 @@ #include <compat.h> #include <netaddress.h> #include <serialize.h> +#include <util/sock.h> +#include <functional> +#include <memory> #include <stdint.h> #include <string> #include <vector> @@ -51,21 +54,25 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo bool Lookup(const std::string& name, std::vector<CService>& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions); CService LookupNumeric(const std::string& name, int portDefault = 0); bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet); -SOCKET CreateSocket(const CService &addrConnect); + +/** + * Create a TCP socket in the given address family. + * @param[in] address_family The socket is created in the same address family as this address. + * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure + */ +std::unique_ptr<Sock> CreateSockTCP(const CService& address_family); + +/** + * Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests. + */ +extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock; + bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); -bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); -/** Return readable error string for a network error code */ -std::string NetworkErrorString(int err); -/** Close socket and set hSocket to INVALID_SOCKET */ -bool CloseSocket(SOCKET& hSocket); +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed); /** Disable or enable blocking-mode for a socket */ bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking); /** Set the TCP_NODELAY flag on a socket */ bool SetSocketNoDelay(const SOCKET& hSocket); -/** - * Convert milliseconds to a struct timeval for e.g. select. - */ -struct timeval MillisToTimeval(int64_t nTimeout); void InterruptSocks5(bool interrupt); #endif // BITCOIN_NETBASE_H diff --git a/src/test/sock_tests.cpp b/src/test/sock_tests.cpp new file mode 100644 index 0000000000..cc0e6e7057 --- /dev/null +++ b/src/test/sock_tests.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2021-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <compat.h> +#include <test/util/setup_common.h> +#include <util/sock.h> +#include <util/system.h> + +#include <boost/test/unit_test.hpp> + +#include <thread> + +using namespace std::chrono_literals; + +BOOST_FIXTURE_TEST_SUITE(sock_tests, BasicTestingSetup) + +static bool SocketIsClosed(const SOCKET& s) +{ + // Notice that if another thread is running and creates its own socket after `s` has been + // closed, it may be assigned the same file descriptor number. In this case, our test will + // wrongly pretend that the socket is not closed. + int type; + socklen_t len = sizeof(type); + return getsockopt(s, SOL_SOCKET, SO_TYPE, (sockopt_arg_type)&type, &len) == SOCKET_ERROR; +} + +static SOCKET CreateSocket() +{ + const SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + BOOST_REQUIRE(s != static_cast<SOCKET>(SOCKET_ERROR)); + return s; +} + +BOOST_AUTO_TEST_CASE(constructor_and_destructor) +{ + const SOCKET s = CreateSocket(); + Sock* sock = new Sock(s); + BOOST_CHECK_EQUAL(sock->Get(), s); + BOOST_CHECK(!SocketIsClosed(s)); + delete sock; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(move_constructor) +{ + const SOCKET s = CreateSocket(); + Sock* sock1 = new Sock(s); + Sock* sock2 = new Sock(std::move(*sock1)); + delete sock1; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_CHECK_EQUAL(sock2->Get(), s); + delete sock2; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(move_assignment) +{ + const SOCKET s = CreateSocket(); + Sock* sock1 = new Sock(s); + Sock* sock2 = new Sock(); + *sock2 = std::move(*sock1); + delete sock1; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_CHECK_EQUAL(sock2->Get(), s); + delete sock2; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(release) +{ + SOCKET s = CreateSocket(); + Sock* sock = new Sock(s); + BOOST_CHECK_EQUAL(sock->Release(), s); + delete sock; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_REQUIRE(CloseSocket(s)); +} + +BOOST_AUTO_TEST_CASE(reset) +{ + const SOCKET s = CreateSocket(); + Sock sock(s); + sock.Reset(); + BOOST_CHECK(SocketIsClosed(s)); +} + +#ifndef WIN32 // Windows does not have socketpair(2). + +static void CreateSocketPair(int s[2]) +{ + BOOST_REQUIRE_EQUAL(socketpair(AF_UNIX, SOCK_STREAM, 0, s), 0); +} + +static void SendAndRecvMessage(const Sock& sender, const Sock& receiver) +{ + const char* msg = "abcd"; + constexpr size_t msg_len = 4; + char recv_buf[10]; + + BOOST_CHECK_EQUAL(sender.Send(msg, msg_len, 0), msg_len); + BOOST_CHECK_EQUAL(receiver.Recv(recv_buf, sizeof(recv_buf), 0), msg_len); + BOOST_CHECK_EQUAL(strncmp(msg, recv_buf, msg_len), 0); +} + +BOOST_AUTO_TEST_CASE(send_and_receive) +{ + int s[2]; + CreateSocketPair(s); + + Sock* sock0 = new Sock(s[0]); + Sock* sock1 = new Sock(s[1]); + + SendAndRecvMessage(*sock0, *sock1); + + Sock* sock0moved = new Sock(std::move(*sock0)); + Sock* sock1moved = new Sock(); + *sock1moved = std::move(*sock1); + + delete sock0; + delete sock1; + + SendAndRecvMessage(*sock1moved, *sock0moved); + + delete sock0moved; + delete sock1moved; + + BOOST_CHECK(SocketIsClosed(s[0])); + BOOST_CHECK(SocketIsClosed(s[1])); +} + +BOOST_AUTO_TEST_CASE(wait) +{ + int s[2]; + CreateSocketPair(s); + + Sock sock0(s[0]); + Sock sock1(s[1]); + + std::thread waiter([&sock0]() { sock0.Wait(24h, Sock::RECV); }); + + BOOST_REQUIRE_EQUAL(sock1.Send("a", 1, 0), 1); + + waiter.join(); +} + +#endif /* WIN32 */ + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/torcontrol.cpp b/src/torcontrol.cpp index 908ad35e1b..605c77fc3a 100644 --- a/src/torcontrol.cpp +++ b/src/torcontrol.cpp @@ -14,6 +14,7 @@ #include <netbase.h> #include <util/strencodings.h> #include <util/system.h> +#include <util/time.h> #include <deque> #include <functional> diff --git a/src/util/sock.cpp b/src/util/sock.cpp new file mode 100644 index 0000000000..4c65b5b680 --- /dev/null +++ b/src/util/sock.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2020-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <compat.h> +#include <logging.h> +#include <tinyformat.h> +#include <util/sock.h> +#include <util/system.h> +#include <util/time.h> + +#include <codecvt> +#include <cwchar> +#include <locale> +#include <string> + +#ifdef USE_POLL +#include <poll.h> +#endif + +Sock::Sock() : m_socket(INVALID_SOCKET) {} + +Sock::Sock(SOCKET s) : m_socket(s) {} + +Sock::Sock(Sock&& other) +{ + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; +} + +Sock::~Sock() { Reset(); } + +Sock& Sock::operator=(Sock&& other) +{ + Reset(); + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; + return *this; +} + +SOCKET Sock::Get() const { return m_socket; } + +SOCKET Sock::Release() +{ + const SOCKET s = m_socket; + m_socket = INVALID_SOCKET; + return s; +} + +void Sock::Reset() { CloseSocket(m_socket); } + +ssize_t Sock::Send(const void* data, size_t len, int flags) const +{ + return send(m_socket, static_cast<const char*>(data), len, flags); +} + +ssize_t Sock::Recv(void* buf, size_t len, int flags) const +{ + return recv(m_socket, static_cast<char*>(buf), len, flags); +} + +bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const +{ +#ifdef USE_POLL + pollfd fd; + fd.fd = m_socket; + fd.events = 0; + if (requested & RECV) { + fd.events |= POLLIN; + } + if (requested & SEND) { + fd.events |= POLLOUT; + } + + return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR; +#else + if (!IsSelectableSocket(m_socket)) { + return false; + } + + fd_set fdset_recv; + fd_set fdset_send; + FD_ZERO(&fdset_recv); + FD_ZERO(&fdset_send); + + if (requested & RECV) { + FD_SET(m_socket, &fdset_recv); + } + + if (requested & SEND) { + FD_SET(m_socket, &fdset_send); + } + + timeval timeout_struct = MillisToTimeval(timeout); + + return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR; +#endif /* USE_POLL */ +} + +#ifdef WIN32 +std::string NetworkErrorString(int err) +{ + wchar_t buf[256]; + buf[0] = 0; + if(FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, + nullptr, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + buf, ARRAYSIZE(buf), nullptr)) + { + return strprintf("%s (%d)", std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>,wchar_t>().to_bytes(buf), err); + } + else + { + return strprintf("Unknown error (%d)", err); + } +} +#else +std::string NetworkErrorString(int err) +{ + char buf[256]; + buf[0] = 0; + /* Too bad there are two incompatible implementations of the + * thread-safe strerror. */ + const char *s; +#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ + s = strerror_r(err, buf, sizeof(buf)); +#else /* POSIX variant always returns message in buffer */ + s = buf; + if (strerror_r(err, buf, sizeof(buf))) + buf[0] = 0; +#endif + return strprintf("%s (%d)", s, err); +} +#endif + +bool CloseSocket(SOCKET& hSocket) +{ + if (hSocket == INVALID_SOCKET) + return false; +#ifdef WIN32 + int ret = closesocket(hSocket); +#else + int ret = close(hSocket); +#endif + if (ret) { + LogPrintf("Socket close failed: %d. Error: %s\n", hSocket, NetworkErrorString(WSAGetLastError())); + } + hSocket = INVALID_SOCKET; + return ret != SOCKET_ERROR; +} diff --git a/src/util/sock.h b/src/util/sock.h new file mode 100644 index 0000000000..26fe60f18f --- /dev/null +++ b/src/util/sock.h @@ -0,0 +1,118 @@ +// Copyright (c) 2020-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_UTIL_SOCK_H +#define BITCOIN_UTIL_SOCK_H + +#include <compat.h> + +#include <chrono> +#include <string> + +/** + * RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it + * contains a socket and closes it automatically when it goes out of scope. + */ +class Sock +{ +public: + /** + * Default constructor, creates an empty object that does nothing when destroyed. + */ + Sock(); + + /** + * Take ownership of an existent socket. + */ + explicit Sock(SOCKET s); + + /** + * Copy constructor, disabled because closing the same socket twice is undesirable. + */ + Sock(const Sock&) = delete; + + /** + * Move constructor, grab the socket from another object and close ours (if set). + */ + Sock(Sock&& other); + + /** + * Destructor, close the socket or do nothing if empty. + */ + virtual ~Sock(); + + /** + * Copy assignment operator, disabled because closing the same socket twice is undesirable. + */ + Sock& operator=(const Sock&) = delete; + + /** + * Move assignment operator, grab the socket from another object and close ours (if set). + */ + virtual Sock& operator=(Sock&& other); + + /** + * Get the value of the contained socket. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Get() const; + + /** + * Get the value of the contained socket and drop ownership. It will not be closed by the + * destructor after this call. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Release(); + + /** + * Close if non-empty. + */ + virtual void Reset(); + + /** + * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Send(const void* data, size_t len, int flags) const; + + /** + * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Recv(void* buf, size_t len, int flags) const; + + using Event = uint8_t; + + /** + * If passed to `Wait()`, then it will wait for readiness to read from the socket. + */ + static constexpr Event RECV = 0b01; + + /** + * If passed to `Wait()`, then it will wait for readiness to send to the socket. + */ + static constexpr Event SEND = 0b10; + + /** + * Wait for readiness for input (recv) or output (send). + * @param[in] timeout Wait this much for at least one of the requested events to occur. + * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`. + * @return true on success and false otherwise + */ + virtual bool Wait(std::chrono::milliseconds timeout, Event requested) const; + +private: + /** + * Contained socket. `INVALID_SOCKET` designates the object is empty. + */ + SOCKET m_socket; +}; + +/** Return readable error string for a network error code */ +std::string NetworkErrorString(int err); + +/** Close socket and set hSocket to INVALID_SOCKET */ +bool CloseSocket(SOCKET& hSocket); + +#endif // BITCOIN_UTIL_SOCK_H diff --git a/src/util/time.cpp b/src/util/time.cpp index d130e4e4d4..295806c54a 100644 --- a/src/util/time.cpp +++ b/src/util/time.cpp @@ -7,6 +7,7 @@ #include <config/bitcoin-config.h> #endif +#include <compat.h> #include <util/time.h> #include <util/check.h> @@ -117,3 +118,16 @@ int64_t ParseISO8601DateTime(const std::string& str) return 0; return (ptime - epoch).total_seconds(); } + +struct timeval MillisToTimeval(int64_t nTimeout) +{ + struct timeval timeout; + timeout.tv_sec = nTimeout / 1000; + timeout.tv_usec = (nTimeout % 1000) * 1000; + return timeout; +} + +struct timeval MillisToTimeval(std::chrono::milliseconds ms) +{ + return MillisToTimeval(count_milliseconds(ms)); +} diff --git a/src/util/time.h b/src/util/time.h index c69f604dc6..03b75b5be5 100644 --- a/src/util/time.h +++ b/src/util/time.h @@ -6,6 +6,8 @@ #ifndef BITCOIN_UTIL_TIME_H #define BITCOIN_UTIL_TIME_H +#include <compat.h> + #include <chrono> #include <stdint.h> #include <string> @@ -25,6 +27,7 @@ void UninterruptibleSleep(const std::chrono::microseconds& n); * interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI) */ inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); } +inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); } inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); } /** @@ -57,4 +60,14 @@ std::string FormatISO8601DateTime(int64_t nTime); std::string FormatISO8601Date(int64_t nTime); int64_t ParseISO8601DateTime(const std::string& str); +/** + * Convert milliseconds to a struct timeval for e.g. select. + */ +struct timeval MillisToTimeval(int64_t nTimeout); + +/** + * Convert milliseconds to a struct timeval for e.g. select. + */ +struct timeval MillisToTimeval(std::chrono::milliseconds ms); + #endif // BITCOIN_UTIL_TIME_H |