diff options
author | laanwj <126646+laanwj@users.noreply.github.com> | 2022-06-16 19:55:03 +0200 |
---|---|---|
committer | laanwj <126646+laanwj@users.noreply.github.com> | 2022-06-16 20:05:03 +0200 |
commit | 0ea92cad5274f3939f09d6890da31a21b8481282 (patch) | |
tree | 085f42448cb1ed6c375a3c20a14352b6ed5d8fd5 /src | |
parent | 489b5876698f9bb2d93b1b1d62d514148b31effd (diff) | |
parent | 6e68ccbefea6509c61fc4405a391a517c6057bb0 (diff) |
Merge bitcoin/bitcoin#24356: refactor: replace CConnman::SocketEvents() with mockable Sock::WaitMany()
6e68ccbefea6509c61fc4405a391a517c6057bb0 net: use Sock::WaitMany() instead of CConnman::SocketEvents() (Vasil Dimov)
ae263460bab9e6aa112dc99790c8ef06a56ec838 net: introduce Sock::WaitMany() (Vasil Dimov)
cc74459768063a923fb6220a4f420eaf211aee7b net: also wait for exceptional events in Sock::Wait() (Vasil Dimov)
Pull request description:
_This is a piece of #21878, chopped off to ease review._
`Sock::Wait()` waits for IO events on one socket. Introduce a similar `virtual` method `WaitMany()` that waits simultaneously for IO events on more than one socket.
Use `WaitMany()` instead of `CConnman::SocketEvents()` (and ditch the latter). Given that the former is a `virtual` method, it can be mocked by unit and fuzz tests. This will help to make bigger parts of `CConnman` testable (unit and fuzz).
ACKs for top commit:
laanwj:
Code review ACK 6e68ccbefea6509c61fc4405a391a517c6057bb0
jonatack:
re-ACK 6e68ccbefea6509c61fc4405a391a517c6057bb0 per `git range-diff e18fd47 6747729 6e68ccb`, and verified rebase to master and debug build
Tree-SHA512: 917fb6ad880d64d3af1ebb301c06fbd01afd8ff043f49e4055a088ebed6affb7ffe1dcf59292d822f10de5f323b6d52d557cb081dd7434634995f9148efcf08f
Diffstat (limited to 'src')
-rw-r--r-- | src/i2p.cpp | 4 | ||||
-rw-r--r-- | src/net.cpp | 178 | ||||
-rw-r--r-- | src/net.h | 38 | ||||
-rw-r--r-- | src/test/fuzz/util.cpp | 9 | ||||
-rw-r--r-- | src/test/fuzz/util.h | 2 | ||||
-rw-r--r-- | src/test/util/net.h | 9 | ||||
-rw-r--r-- | src/util/sock.cpp | 112 | ||||
-rw-r--r-- | src/util/sock.h | 71 |
8 files changed, 199 insertions, 224 deletions
diff --git a/src/i2p.cpp b/src/i2p.cpp index eccb048bb3..caff8c1e69 100644 --- a/src/i2p.cpp +++ b/src/i2p.cpp @@ -150,8 +150,8 @@ bool Session::Accept(Connection& conn) throw std::runtime_error("wait on socket failed"); } - if ((occurred & Sock::RECV) == 0) { - // Timeout, no incoming connections within MAX_WAIT_FOR_IO. + if (occurred == 0) { + // Timeout, no incoming connections or errors within MAX_WAIT_FOR_IO. continue; } diff --git a/src/net.cpp b/src/net.cpp index d0c95dafb4..d42f130af7 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1404,13 +1404,12 @@ bool CConnman::InactivityCheck(const CNode& node) const return false; } -bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes, - std::set<SOCKET>& recv_set, - std::set<SOCKET>& send_set, - std::set<SOCKET>& error_set) +Sock::EventsPerSock CConnman::GenerateWaitSockets(Span<CNode* const> nodes) { + Sock::EventsPerSock events_per_sock; + for (const ListenSocket& hListenSocket : vhListenSocket) { - recv_set.insert(hListenSocket.sock->Get()); + events_per_sock.emplace(hListenSocket.sock, Sock::Events{Sock::RECV}); } for (CNode* pnode : nodes) { @@ -1437,172 +1436,49 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes, continue; } - error_set.insert(pnode->m_sock->Get()); + Sock::Event requested{0}; if (select_send) { - send_set.insert(pnode->m_sock->Get()); - continue; - } - if (select_recv) { - recv_set.insert(pnode->m_sock->Get()); - } - } - - return !recv_set.empty() || !send_set.empty() || !error_set.empty(); -} - -#ifdef USE_POLL -void CConnman::SocketEvents(const std::vector<CNode*>& nodes, - std::set<SOCKET>& recv_set, - std::set<SOCKET>& send_set, - std::set<SOCKET>& error_set) -{ - std::set<SOCKET> recv_select_set, send_select_set, error_select_set; - if (!GenerateSelectSet(nodes, recv_select_set, send_select_set, error_select_set)) { - interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); - return; - } - - std::unordered_map<SOCKET, struct pollfd> pollfds; - for (SOCKET socket_id : recv_select_set) { - pollfds[socket_id].fd = socket_id; - pollfds[socket_id].events |= POLLIN; - } - - for (SOCKET socket_id : send_select_set) { - pollfds[socket_id].fd = socket_id; - pollfds[socket_id].events |= POLLOUT; - } - - for (SOCKET socket_id : error_select_set) { - pollfds[socket_id].fd = socket_id; - // These flags are ignored, but we set them for clarity - pollfds[socket_id].events |= POLLERR|POLLHUP; - } - - std::vector<struct pollfd> vpollfds; - vpollfds.reserve(pollfds.size()); - for (auto it : pollfds) { - vpollfds.push_back(std::move(it.second)); - } - - if (poll(vpollfds.data(), vpollfds.size(), SELECT_TIMEOUT_MILLISECONDS) < 0) return; - - if (interruptNet) return; - - for (struct pollfd pollfd_entry : vpollfds) { - if (pollfd_entry.revents & POLLIN) recv_set.insert(pollfd_entry.fd); - if (pollfd_entry.revents & POLLOUT) send_set.insert(pollfd_entry.fd); - if (pollfd_entry.revents & (POLLERR|POLLHUP)) error_set.insert(pollfd_entry.fd); - } -} -#else -void CConnman::SocketEvents(const std::vector<CNode*>& nodes, - std::set<SOCKET>& recv_set, - std::set<SOCKET>& send_set, - std::set<SOCKET>& error_set) -{ - std::set<SOCKET> recv_select_set, send_select_set, error_select_set; - if (!GenerateSelectSet(nodes, recv_select_set, send_select_set, error_select_set)) { - interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); - return; - } - - // - // Find which sockets have data to receive - // - struct timeval timeout; - timeout.tv_sec = 0; - timeout.tv_usec = SELECT_TIMEOUT_MILLISECONDS * 1000; // frequency to poll pnode->vSend - - fd_set fdsetRecv; - fd_set fdsetSend; - fd_set fdsetError; - FD_ZERO(&fdsetRecv); - FD_ZERO(&fdsetSend); - FD_ZERO(&fdsetError); - SOCKET hSocketMax = 0; - - for (SOCKET hSocket : recv_select_set) { - FD_SET(hSocket, &fdsetRecv); - hSocketMax = std::max(hSocketMax, hSocket); - } - - for (SOCKET hSocket : send_select_set) { - FD_SET(hSocket, &fdsetSend); - hSocketMax = std::max(hSocketMax, hSocket); - } - - for (SOCKET hSocket : error_select_set) { - FD_SET(hSocket, &fdsetError); - hSocketMax = std::max(hSocketMax, hSocket); - } - - int nSelect = select(hSocketMax + 1, &fdsetRecv, &fdsetSend, &fdsetError, &timeout); - - if (interruptNet) - return; - - if (nSelect == SOCKET_ERROR) - { - int nErr = WSAGetLastError(); - LogPrintf("socket select error %s\n", NetworkErrorString(nErr)); - for (unsigned int i = 0; i <= hSocketMax; i++) - FD_SET(i, &fdsetRecv); - FD_ZERO(&fdsetSend); - FD_ZERO(&fdsetError); - if (!interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS))) - return; - } - - for (SOCKET hSocket : recv_select_set) { - if (FD_ISSET(hSocket, &fdsetRecv)) { - recv_set.insert(hSocket); + requested = Sock::SEND; + } else if (select_recv) { + requested = Sock::RECV; } - } - for (SOCKET hSocket : send_select_set) { - if (FD_ISSET(hSocket, &fdsetSend)) { - send_set.insert(hSocket); - } + events_per_sock.emplace(pnode->m_sock, Sock::Events{requested}); } - for (SOCKET hSocket : error_select_set) { - if (FD_ISSET(hSocket, &fdsetError)) { - error_set.insert(hSocket); - } - } + return events_per_sock; } -#endif void CConnman::SocketHandler() { AssertLockNotHeld(m_total_bytes_sent_mutex); - std::set<SOCKET> recv_set; - std::set<SOCKET> send_set; - std::set<SOCKET> error_set; + Sock::EventsPerSock events_per_sock; { const NodesSnapshot snap{*this, /*shuffle=*/false}; + const auto timeout = std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS); + // Check for the readiness of the already connected sockets and the // listening sockets in one call ("readiness" as in poll(2) or // select(2)). If none are ready, wait for a short while and return // empty sets. - SocketEvents(snap.Nodes(), recv_set, send_set, error_set); + events_per_sock = GenerateWaitSockets(snap.Nodes()); + if (events_per_sock.empty() || !events_per_sock.begin()->first->WaitMany(timeout, events_per_sock)) { + interruptNet.sleep_for(timeout); + } // Service (send/receive) each of the already connected nodes. - SocketHandlerConnected(snap.Nodes(), recv_set, send_set, error_set); + SocketHandlerConnected(snap.Nodes(), events_per_sock); } // Accept new connections from listening sockets. - SocketHandlerListening(recv_set); + SocketHandlerListening(events_per_sock); } void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes, - const std::set<SOCKET>& recv_set, - const std::set<SOCKET>& send_set, - const std::set<SOCKET>& error_set) + const Sock::EventsPerSock& events_per_sock) { AssertLockNotHeld(m_total_bytes_sent_mutex); @@ -1621,9 +1497,12 @@ void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes, if (!pnode->m_sock) { continue; } - recvSet = recv_set.count(pnode->m_sock->Get()) > 0; - sendSet = send_set.count(pnode->m_sock->Get()) > 0; - errorSet = error_set.count(pnode->m_sock->Get()) > 0; + const auto it = events_per_sock.find(pnode->m_sock); + if (it != events_per_sock.end()) { + recvSet = it->second.occurred & Sock::RECV; + sendSet = it->second.occurred & Sock::SEND; + errorSet = it->second.occurred & Sock::ERR; + } } if (recvSet || errorSet) { @@ -1693,13 +1572,14 @@ void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes, } } -void CConnman::SocketHandlerListening(const std::set<SOCKET>& recv_set) +void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) { for (const ListenSocket& listen_socket : vhListenSocket) { if (interruptNet) { return; } - if (recv_set.count(listen_socket.sock->Get()) > 0) { + const auto it = events_per_sock.find(listen_socket.sock); + if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { AcceptConnection(listen_socket); } } @@ -980,28 +980,9 @@ private: /** * Generate a collection of sockets to check for IO readiness. * @param[in] nodes Select from these nodes' sockets. - * @param[out] recv_set Sockets to check for read readiness. - * @param[out] send_set Sockets to check for write readiness. - * @param[out] error_set Sockets to check for errors. - * @return true if at least one socket is to be checked (the returned set is not empty) + * @return sockets to check for readiness */ - bool GenerateSelectSet(const std::vector<CNode*>& nodes, - std::set<SOCKET>& recv_set, - std::set<SOCKET>& send_set, - std::set<SOCKET>& error_set); - - /** - * Check which sockets are ready for IO. - * @param[in] nodes Select from these nodes' sockets. - * @param[out] recv_set Sockets which are ready for read. - * @param[out] send_set Sockets which are ready for write. - * @param[out] error_set Sockets which have errors. - * This calls `GenerateSelectSet()` to gather a list of sockets to check. - */ - void SocketEvents(const std::vector<CNode*>& nodes, - std::set<SOCKET>& recv_set, - std::set<SOCKET>& send_set, - std::set<SOCKET>& error_set); + Sock::EventsPerSock GenerateWaitSockets(Span<CNode* const> nodes); /** * Check connected and listening sockets for IO readiness and process them accordingly. @@ -1010,23 +991,18 @@ private: /** * Do the read/write for connected sockets that are ready for IO. - * @param[in] nodes Nodes to process. The socket of each node is checked against - * `recv_set`, `send_set` and `error_set`. - * @param[in] recv_set Sockets that are ready for read. - * @param[in] send_set Sockets that are ready for send. - * @param[in] error_set Sockets that have an exceptional condition (error). + * @param[in] nodes Nodes to process. The socket of each node is checked against `what`. + * @param[in] events_per_sock Sockets that are ready for IO. */ void SocketHandlerConnected(const std::vector<CNode*>& nodes, - const std::set<SOCKET>& recv_set, - const std::set<SOCKET>& send_set, - const std::set<SOCKET>& error_set) + const Sock::EventsPerSock& events_per_sock) EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); /** * Accept incoming connections, one from each read-ready listening socket. - * @param[in] recv_set Sockets that are ready for read. + * @param[in] events_per_sock Sockets that are ready for IO. */ - void SocketHandlerListening(const std::set<SOCKET>& recv_set); + void SocketHandlerListening(const Sock::EventsPerSock& events_per_sock); void ThreadSocketHandler() EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); void ThreadDNSAddressSeed() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_nodes_mutex); diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp index 033c6e18d5..883698aff1 100644 --- a/src/test/fuzz/util.cpp +++ b/src/test/fuzz/util.cpp @@ -223,6 +223,15 @@ bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* return true; } +bool FuzzedSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ + for (auto& [sock, events] : events_per_sock) { + (void)sock; + events.occurred = m_fuzzed_data_provider.ConsumeBool() ? events.requested : 0; + } + return true; +} + bool FuzzedSock::IsConnected(std::string& errmsg) const { if (m_fuzzed_data_provider.ConsumeBool()) { diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 3fc6fa1cd5..66d00b1767 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -71,6 +71,8 @@ public: bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; + bool IsConnected(std::string& errmsg) const override; }; diff --git a/src/test/util/net.h b/src/test/util/net.h index e980fe4967..37d278645a 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -162,6 +162,15 @@ public: return true; } + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override + { + for (auto& [sock, events] : events_per_sock) { + (void)sock; + events.occurred = events.requested; + } + return true; + } + private: const std::string m_contents; mutable size_t m_consumed; diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 3579af4458..7d5069423a 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -113,63 +113,103 @@ int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { -#ifdef USE_POLL - pollfd fd; - fd.fd = m_socket; - fd.events = 0; - if (requested & RECV) { - fd.events |= POLLIN; - } - if (requested & SEND) { - fd.events |= POLLOUT; - } + // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want + // `this` to be destroyed when the `shared_ptr` goes out of scope at the + // end of this function. Create it with a custom noop deleter. + std::shared_ptr<const Sock> shared{this, [](const Sock*) {}}; + + EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})}; - if (poll(&fd, 1, count_milliseconds(timeout)) == SOCKET_ERROR) { + if (!WaitMany(timeout, events_per_sock)) { return false; } if (occurred != nullptr) { - *occurred = 0; - if (fd.revents & POLLIN) { - *occurred |= RECV; + *occurred = events_per_sock.begin()->second.occurred; + } + + return true; +} + +bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ +#ifdef USE_POLL + std::vector<pollfd> pfds; + for (const auto& [sock, events] : events_per_sock) { + pfds.emplace_back(); + auto& pfd = pfds.back(); + pfd.fd = sock->m_socket; + if (events.requested & RECV) { + pfd.events |= POLLIN; } - if (fd.revents & POLLOUT) { - *occurred |= SEND; + if (events.requested & SEND) { + pfd.events |= POLLOUT; } } - return true; -#else - if (!IsSelectableSocket(m_socket)) { + if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) { return false; } - fd_set fdset_recv; - fd_set fdset_send; - FD_ZERO(&fdset_recv); - FD_ZERO(&fdset_send); - - if (requested & RECV) { - FD_SET(m_socket, &fdset_recv); + assert(pfds.size() == events_per_sock.size()); + size_t i{0}; + for (auto& [sock, events] : events_per_sock) { + assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd)); + events.occurred = 0; + if (pfds[i].revents & POLLIN) { + events.occurred |= RECV; + } + if (pfds[i].revents & POLLOUT) { + events.occurred |= SEND; + } + if (pfds[i].revents & (POLLERR | POLLHUP)) { + events.occurred |= ERR; + } + ++i; } - if (requested & SEND) { - FD_SET(m_socket, &fdset_send); + return true; +#else + fd_set recv; + fd_set send; + fd_set err; + FD_ZERO(&recv); + FD_ZERO(&send); + FD_ZERO(&err); + SOCKET socket_max{0}; + + for (const auto& [sock, events] : events_per_sock) { + const auto& s = sock->m_socket; + if (!IsSelectableSocket(s)) { + return false; + } + if (events.requested & RECV) { + FD_SET(s, &recv); + } + if (events.requested & SEND) { + FD_SET(s, &send); + } + FD_SET(s, &err); + socket_max = std::max(socket_max, s); } - timeval timeout_struct = MillisToTimeval(timeout); + timeval tv = MillisToTimeval(timeout); - if (select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) == SOCKET_ERROR) { + if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) { return false; } - if (occurred != nullptr) { - *occurred = 0; - if (FD_ISSET(m_socket, &fdset_recv)) { - *occurred |= RECV; + for (auto& [sock, events] : events_per_sock) { + const auto& s = sock->m_socket; + events.occurred = 0; + if (FD_ISSET(s, &recv)) { + events.occurred |= RECV; + } + if (FD_ISSET(s, &send)) { + events.occurred |= SEND; } - if (FD_ISSET(m_socket, &fdset_send)) { - *occurred |= SEND; + if (FD_ISSET(s, &err)) { + events.occurred |= ERR; } } diff --git a/src/util/sock.h b/src/util/sock.h index dd2913a66c..3245820995 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -12,6 +12,7 @@ #include <chrono> #include <memory> #include <string> +#include <unordered_map> /** * Maximum time to wait for I/O readiness. @@ -130,26 +131,84 @@ public: /** * If passed to `Wait()`, then it will wait for readiness to read from the socket. */ - static constexpr Event RECV = 0b01; + static constexpr Event RECV = 0b001; /** * If passed to `Wait()`, then it will wait for readiness to send to the socket. */ - static constexpr Event SEND = 0b10; + static constexpr Event SEND = 0b010; + + /** + * Ignored if passed to `Wait()`, but could be set in the occurred events if an + * exceptional condition has occurred on the socket or if it has been disconnected. + */ + static constexpr Event ERR = 0b100; /** * Wait for readiness for input (recv) or output (send). * @param[in] timeout Wait this much for at least one of the requested events to occur. * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`. - * @param[out] occurred If not nullptr and `true` is returned, then upon return this - * indicates which of the requested events occurred. A timeout is indicated by return - * value of `true` and `occurred` being set to 0. - * @return true on success and false otherwise + * @param[out] occurred If not nullptr and the function returns `true`, then this + * indicates which of the requested events occurred (`ERR` will be added, even if + * not requested, if an exceptional event occurs on the socket). + * A timeout is indicated by return value of `true` and `occurred` being set to 0. + * @return true on success (or timeout, if `occurred` of 0 is returned), false otherwise */ [[nodiscard]] virtual bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const; + /** + * Auxiliary requested/occurred events to wait for in `WaitMany()`. + */ + struct Events { + explicit Events(Event req) : requested{req}, occurred{0} {} + Event requested; + Event occurred; + }; + + struct HashSharedPtrSock { + size_t operator()(const std::shared_ptr<const Sock>& s) const + { + return s ? s->m_socket : std::numeric_limits<SOCKET>::max(); + } + }; + + struct EqualSharedPtrSock { + bool operator()(const std::shared_ptr<const Sock>& lhs, + const std::shared_ptr<const Sock>& rhs) const + { + if (lhs && rhs) { + return lhs->m_socket == rhs->m_socket; + } + if (!lhs && !rhs) { + return true; + } + return false; + } + }; + + /** + * On which socket to wait for what events in `WaitMany()`. + * The `shared_ptr` is copied into the map to ensure that the `Sock` object + * is not destroyed (its destructor would close the underlying socket). + * If this happens shortly before or after we call `poll(2)` and a new + * socket gets created under the same file descriptor number then the report + * from `WaitMany()` will be bogus. + */ + using EventsPerSock = std::unordered_map<std::shared_ptr<const Sock>, Events, HashSharedPtrSock, EqualSharedPtrSock>; + + /** + * Same as `Wait()`, but wait on many sockets within the same timeout. + * @param[in] timeout Wait this long for at least one of the requested events to occur. + * @param[in,out] events_per_sock Wait for the requested events on these sockets and set + * `occurred` for the events that actually occurred. + * @return true on success (or timeout, if all `what[].occurred` are returned as 0), + * false otherwise + */ + [[nodiscard]] virtual bool WaitMany(std::chrono::milliseconds timeout, + EventsPerSock& events_per_sock) const; + /* Higher level, convenience, methods. These may throw. */ /** |