diff options
author | Vasil Dimov <vd@FreeBSD.org> | 2020-12-23 16:40:11 +0100 |
---|---|---|
committer | Vasil Dimov <vd@FreeBSD.org> | 2021-02-10 13:30:08 +0100 |
commit | ba9d73268f9585d4b9254adcf54708f88222798b (patch) | |
tree | 7817023a5deab6c4bcb0de90b748e49dc026f1ab /src | |
parent | dec9b5e850c6aad989e814aea5b630b36f55d580 (diff) |
net: add RAII socket and use it instead of bare SOCKET
Introduce a class to manage the lifetime of a socket - when the object
that contains the socket goes out of scope, the underlying socket will
be closed.
In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()`
methods that can be overridden by unit tests to mock the socket
operations.
The `Wait()` method also hides the
`#ifdef USE_POLL poll() #else select() #endif` technique from higher
level code.
Diffstat (limited to 'src')
-rw-r--r-- | src/net.cpp | 47 | ||||
-rw-r--r-- | src/netbase.cpp | 31 | ||||
-rw-r--r-- | src/netbase.h | 17 | ||||
-rw-r--r-- | src/util/sock.cpp | 85 | ||||
-rw-r--r-- | src/util/sock.h | 100 | ||||
-rw-r--r-- | src/util/time.cpp | 5 | ||||
-rw-r--r-- | src/util/time.h | 6 |
7 files changed, 250 insertions, 41 deletions
diff --git a/src/net.cpp b/src/net.cpp index 38aaeff121..2a3669b90e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -429,24 +429,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo // Connect bool connected = false; - SOCKET hSocket = INVALID_SOCKET; + std::unique_ptr<Sock> sock; proxyType proxy; if (addrConnect.IsValid()) { bool proxyConnectionFailed = false; if (GetProxy(addrConnect.GetNetwork(), proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } - connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), + sock->Get(), nConnectTimeout, proxyConnectionFailed); } else { // no proxy needed (none set for target network) - hSocket = CreateSocket(addrConnect); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(addrConnect); + if (!sock) { return nullptr; } - connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL); + connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout, + conn_type == ConnectionType::MANUAL); } if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -454,26 +456,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addrman.Attempt(addrConnect, fCountFailure); } } else if (pszDest && GetNameProxy(proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } std::string host; int port = default_port; SplitHostPort(std::string(pszDest), port, host); bool proxyConnectionFailed; - connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout, + proxyConnectionFailed); } if (!connected) { - CloseSocket(hSocket); return nullptr; } // Add node NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); - CAddress addr_bind = GetBindAddress(hSocket); - CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); + CAddress addr_bind = GetBindAddress(sock->Get()); + CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); pnode->AddRef(); // We're making a new connection, harvest entropy from the time (and our peer count) @@ -2177,9 +2179,8 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - SOCKET hListenSocket = CreateSocket(addrBind); - if (hListenSocket == INVALID_SOCKET) - { + std::unique_ptr<Sock> sock = CreateSock(addrBind); + if (!sock) { strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); return false; @@ -2187,21 +2188,21 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // and enable it by default or not. Try to enable it, if possible. if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); #endif } - if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); if (nErr == WSAEADDRINUSE) @@ -2209,21 +2210,19 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, else strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr)); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } LogPrintf("Bound to %s\n", addrBind.ToString()); // Listen for incoming connections - if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR) + if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR) { strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } - vhListenSocket.push_back(ListenSocket(hListenSocket, permissions)); + vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); return true; } diff --git a/src/netbase.cpp b/src/netbase.cpp index 3a3407f901..93a04ab5b4 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -15,7 +15,9 @@ #include <atomic> #include <cstdint> +#include <functional> #include <limits> +#include <memory> #ifndef WIN32 #include <fcntl.h> @@ -559,34 +561,28 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials return true; } -/** - * Try to create a socket file descriptor with specific properties in the - * communications domain (address family) of the specified service. - * - * For details on the desired properties, see the inline comments in the source - * code. - */ -SOCKET CreateSocket(const CService &addrConnect) +std::unique_ptr<Sock> CreateSockTCP(const CService& address_family) { // Create a sockaddr from the specified service. struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { - LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString()); - return INVALID_SOCKET; + if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { + LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToString()); + return nullptr; } // Create a TCP socket in the address family of the specified service. SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); - if (hSocket == INVALID_SOCKET) - return INVALID_SOCKET; + if (hSocket == INVALID_SOCKET) { + return nullptr; + } // Ensure that waiting for I/O on this socket won't result in undefined // behavior. if (!IsSelectableSocket(hSocket)) { CloseSocket(hSocket); LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); - return INVALID_SOCKET; + return nullptr; } #ifdef SO_NOSIGPIPE @@ -602,11 +598,14 @@ SOCKET CreateSocket(const CService &addrConnect) // Set the non-blocking option on the socket. if (!SetSocketNonBlocking(hSocket, true)) { CloseSocket(hSocket); - LogPrintf("CreateSocket: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); + LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); + return nullptr; } - return hSocket; + return std::make_unique<Sock>(hSocket); } +std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP; + template<typename... Args> static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { std::string error_message = tfm::format(fmt, args...); diff --git a/src/netbase.h b/src/netbase.h index 38d33e475b..d906888235 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -12,7 +12,10 @@ #include <compat.h> #include <netaddress.h> #include <serialize.h> +#include <util/sock.h> +#include <functional> +#include <memory> #include <stdint.h> #include <string> #include <vector> @@ -51,7 +54,19 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo bool Lookup(const std::string& name, std::vector<CService>& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions); CService LookupNumeric(const std::string& name, int portDefault = 0); bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet); -SOCKET CreateSocket(const CService &addrConnect); + +/** + * Create a TCP socket in the given address family. + * @param[in] address_family The socket is created in the same address family as this address. + * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure + */ +std::unique_ptr<Sock> CreateSockTCP(const CService& address_family); + +/** + * Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests. + */ +extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock; + bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); /** Disable or enable blocking-mode for a socket */ diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 35eca4afb1..4c65b5b680 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -6,12 +6,97 @@ #include <logging.h> #include <tinyformat.h> #include <util/sock.h> +#include <util/system.h> +#include <util/time.h> #include <codecvt> #include <cwchar> #include <locale> #include <string> +#ifdef USE_POLL +#include <poll.h> +#endif + +Sock::Sock() : m_socket(INVALID_SOCKET) {} + +Sock::Sock(SOCKET s) : m_socket(s) {} + +Sock::Sock(Sock&& other) +{ + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; +} + +Sock::~Sock() { Reset(); } + +Sock& Sock::operator=(Sock&& other) +{ + Reset(); + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; + return *this; +} + +SOCKET Sock::Get() const { return m_socket; } + +SOCKET Sock::Release() +{ + const SOCKET s = m_socket; + m_socket = INVALID_SOCKET; + return s; +} + +void Sock::Reset() { CloseSocket(m_socket); } + +ssize_t Sock::Send(const void* data, size_t len, int flags) const +{ + return send(m_socket, static_cast<const char*>(data), len, flags); +} + +ssize_t Sock::Recv(void* buf, size_t len, int flags) const +{ + return recv(m_socket, static_cast<char*>(buf), len, flags); +} + +bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const +{ +#ifdef USE_POLL + pollfd fd; + fd.fd = m_socket; + fd.events = 0; + if (requested & RECV) { + fd.events |= POLLIN; + } + if (requested & SEND) { + fd.events |= POLLOUT; + } + + return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR; +#else + if (!IsSelectableSocket(m_socket)) { + return false; + } + + fd_set fdset_recv; + fd_set fdset_send; + FD_ZERO(&fdset_recv); + FD_ZERO(&fdset_send); + + if (requested & RECV) { + FD_SET(m_socket, &fdset_recv); + } + + if (requested & SEND) { + FD_SET(m_socket, &fdset_send); + } + + timeval timeout_struct = MillisToTimeval(timeout); + + return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR; +#endif /* USE_POLL */ +} + #ifdef WIN32 std::string NetworkErrorString(int err) { diff --git a/src/util/sock.h b/src/util/sock.h index 0d48235043..26fe60f18f 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -7,8 +7,108 @@ #include <compat.h> +#include <chrono> #include <string> +/** + * RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it + * contains a socket and closes it automatically when it goes out of scope. + */ +class Sock +{ +public: + /** + * Default constructor, creates an empty object that does nothing when destroyed. + */ + Sock(); + + /** + * Take ownership of an existent socket. + */ + explicit Sock(SOCKET s); + + /** + * Copy constructor, disabled because closing the same socket twice is undesirable. + */ + Sock(const Sock&) = delete; + + /** + * Move constructor, grab the socket from another object and close ours (if set). + */ + Sock(Sock&& other); + + /** + * Destructor, close the socket or do nothing if empty. + */ + virtual ~Sock(); + + /** + * Copy assignment operator, disabled because closing the same socket twice is undesirable. + */ + Sock& operator=(const Sock&) = delete; + + /** + * Move assignment operator, grab the socket from another object and close ours (if set). + */ + virtual Sock& operator=(Sock&& other); + + /** + * Get the value of the contained socket. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Get() const; + + /** + * Get the value of the contained socket and drop ownership. It will not be closed by the + * destructor after this call. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Release(); + + /** + * Close if non-empty. + */ + virtual void Reset(); + + /** + * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Send(const void* data, size_t len, int flags) const; + + /** + * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Recv(void* buf, size_t len, int flags) const; + + using Event = uint8_t; + + /** + * If passed to `Wait()`, then it will wait for readiness to read from the socket. + */ + static constexpr Event RECV = 0b01; + + /** + * If passed to `Wait()`, then it will wait for readiness to send to the socket. + */ + static constexpr Event SEND = 0b10; + + /** + * Wait for readiness for input (recv) or output (send). + * @param[in] timeout Wait this much for at least one of the requested events to occur. + * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`. + * @return true on success and false otherwise + */ + virtual bool Wait(std::chrono::milliseconds timeout, Event requested) const; + +private: + /** + * Contained socket. `INVALID_SOCKET` designates the object is empty. + */ + SOCKET m_socket; +}; + /** Return readable error string for a network error code */ std::string NetworkErrorString(int err); diff --git a/src/util/time.cpp b/src/util/time.cpp index 4da041e5a5..4aed9f60b0 100644 --- a/src/util/time.cpp +++ b/src/util/time.cpp @@ -123,3 +123,8 @@ struct timeval MillisToTimeval(int64_t nTimeout) timeout.tv_usec = (nTimeout % 1000) * 1000; return timeout; } + +struct timeval MillisToTimeval(std::chrono::milliseconds ms) +{ + return MillisToTimeval(count_milliseconds(ms)); +} diff --git a/src/util/time.h b/src/util/time.h index 2c0e3d83f6..03b75b5be5 100644 --- a/src/util/time.h +++ b/src/util/time.h @@ -27,6 +27,7 @@ void UninterruptibleSleep(const std::chrono::microseconds& n); * interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI) */ inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); } +inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); } inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); } /** @@ -64,4 +65,9 @@ int64_t ParseISO8601DateTime(const std::string& str); */ struct timeval MillisToTimeval(int64_t nTimeout); +/** + * Convert milliseconds to a struct timeval for e.g. select. + */ +struct timeval MillisToTimeval(std::chrono::milliseconds ms); + #endif // BITCOIN_UTIL_TIME_H |