diff options
-rw-r--r-- | src/test/fuzz/util.cpp | 9 | ||||
-rw-r--r-- | src/test/fuzz/util.h | 2 | ||||
-rw-r--r-- | src/test/util/net.h | 9 | ||||
-rw-r--r-- | src/util/sock.cpp | 120 | ||||
-rw-r--r-- | src/util/sock.h | 52 |
5 files changed, 147 insertions, 45 deletions
diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp index 033c6e18d5..883698aff1 100644 --- a/src/test/fuzz/util.cpp +++ b/src/test/fuzz/util.cpp @@ -223,6 +223,15 @@ bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* return true; } +bool FuzzedSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ + for (auto& [sock, events] : events_per_sock) { + (void)sock; + events.occurred = m_fuzzed_data_provider.ConsumeBool() ? events.requested : 0; + } + return true; +} + bool FuzzedSock::IsConnected(std::string& errmsg) const { if (m_fuzzed_data_provider.ConsumeBool()) { diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 580105e442..e5f7612c55 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -72,6 +72,8 @@ public: bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; + bool IsConnected(std::string& errmsg) const override; }; diff --git a/src/test/util/net.h b/src/test/util/net.h index e980fe4967..37d278645a 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -162,6 +162,15 @@ public: return true; } + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override + { + for (auto& [sock, events] : events_per_sock) { + (void)sock; + events.occurred = events.requested; + } + return true; + } + private: const std::string m_contents; mutable size_t m_consumed; diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 8b86cf74ab..7d5069423a 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -113,73 +113,103 @@ int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { -#ifdef USE_POLL - pollfd fd; - fd.fd = m_socket; - fd.events = 0; - if (requested & RECV) { - fd.events |= POLLIN; - } - if (requested & SEND) { - fd.events |= POLLOUT; - } + // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want + // `this` to be destroyed when the `shared_ptr` goes out of scope at the + // end of this function. Create it with a custom noop deleter. + std::shared_ptr<const Sock> shared{this, [](const Sock*) {}}; + + EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})}; - if (poll(&fd, 1, count_milliseconds(timeout)) == SOCKET_ERROR) { + if (!WaitMany(timeout, events_per_sock)) { return false; } if (occurred != nullptr) { - *occurred = 0; - if (fd.revents & POLLIN) { - *occurred |= RECV; - } - if (fd.revents & POLLOUT) { - *occurred |= SEND; + *occurred = events_per_sock.begin()->second.occurred; + } + + return true; +} + +bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ +#ifdef USE_POLL + std::vector<pollfd> pfds; + for (const auto& [sock, events] : events_per_sock) { + pfds.emplace_back(); + auto& pfd = pfds.back(); + pfd.fd = sock->m_socket; + if (events.requested & RECV) { + pfd.events |= POLLIN; } - if (fd.revents & (POLLERR | POLLHUP)) { - *occurred |= ERR; + if (events.requested & SEND) { + pfd.events |= POLLOUT; } } - return true; -#else - if (!IsSelectableSocket(m_socket)) { + if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) { return false; } - fd_set fdset_recv; - fd_set fdset_send; - fd_set fdset_err; - FD_ZERO(&fdset_recv); - FD_ZERO(&fdset_send); - FD_ZERO(&fdset_err); - - if (requested & RECV) { - FD_SET(m_socket, &fdset_recv); + assert(pfds.size() == events_per_sock.size()); + size_t i{0}; + for (auto& [sock, events] : events_per_sock) { + assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd)); + events.occurred = 0; + if (pfds[i].revents & POLLIN) { + events.occurred |= RECV; + } + if (pfds[i].revents & POLLOUT) { + events.occurred |= SEND; + } + if (pfds[i].revents & (POLLERR | POLLHUP)) { + events.occurred |= ERR; + } + ++i; } - if (requested & SEND) { - FD_SET(m_socket, &fdset_send); + return true; +#else + fd_set recv; + fd_set send; + fd_set err; + FD_ZERO(&recv); + FD_ZERO(&send); + FD_ZERO(&err); + SOCKET socket_max{0}; + + for (const auto& [sock, events] : events_per_sock) { + const auto& s = sock->m_socket; + if (!IsSelectableSocket(s)) { + return false; + } + if (events.requested & RECV) { + FD_SET(s, &recv); + } + if (events.requested & SEND) { + FD_SET(s, &send); + } + FD_SET(s, &err); + socket_max = std::max(socket_max, s); } - FD_SET(m_socket, &fdset_err); - - timeval timeout_struct = MillisToTimeval(timeout); + timeval tv = MillisToTimeval(timeout); - if (select(m_socket + 1, &fdset_recv, &fdset_send, &fdset_err, &timeout_struct) == SOCKET_ERROR) { + if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) { return false; } - if (occurred != nullptr) { - *occurred = 0; - if (FD_ISSET(m_socket, &fdset_recv)) { - *occurred |= RECV; + for (auto& [sock, events] : events_per_sock) { + const auto& s = sock->m_socket; + events.occurred = 0; + if (FD_ISSET(s, &recv)) { + events.occurred |= RECV; } - if (FD_ISSET(m_socket, &fdset_send)) { - *occurred |= SEND; + if (FD_ISSET(s, &send)) { + events.occurred |= SEND; } - if (FD_ISSET(m_socket, &fdset_err)) { - *occurred |= ERR; + if (FD_ISSET(s, &err)) { + events.occurred |= ERR; } } diff --git a/src/util/sock.h b/src/util/sock.h index e6419eb8ce..3245820995 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -12,6 +12,7 @@ #include <chrono> #include <memory> #include <string> +#include <unordered_map> /** * Maximum time to wait for I/O readiness. @@ -157,6 +158,57 @@ public: Event requested, Event* occurred = nullptr) const; + /** + * Auxiliary requested/occurred events to wait for in `WaitMany()`. + */ + struct Events { + explicit Events(Event req) : requested{req}, occurred{0} {} + Event requested; + Event occurred; + }; + + struct HashSharedPtrSock { + size_t operator()(const std::shared_ptr<const Sock>& s) const + { + return s ? s->m_socket : std::numeric_limits<SOCKET>::max(); + } + }; + + struct EqualSharedPtrSock { + bool operator()(const std::shared_ptr<const Sock>& lhs, + const std::shared_ptr<const Sock>& rhs) const + { + if (lhs && rhs) { + return lhs->m_socket == rhs->m_socket; + } + if (!lhs && !rhs) { + return true; + } + return false; + } + }; + + /** + * On which socket to wait for what events in `WaitMany()`. + * The `shared_ptr` is copied into the map to ensure that the `Sock` object + * is not destroyed (its destructor would close the underlying socket). + * If this happens shortly before or after we call `poll(2)` and a new + * socket gets created under the same file descriptor number then the report + * from `WaitMany()` will be bogus. + */ + using EventsPerSock = std::unordered_map<std::shared_ptr<const Sock>, Events, HashSharedPtrSock, EqualSharedPtrSock>; + + /** + * Same as `Wait()`, but wait on many sockets within the same timeout. + * @param[in] timeout Wait this long for at least one of the requested events to occur. + * @param[in,out] events_per_sock Wait for the requested events on these sockets and set + * `occurred` for the events that actually occurred. + * @return true on success (or timeout, if all `what[].occurred` are returned as 0), + * false otherwise + */ + [[nodiscard]] virtual bool WaitMany(std::chrono::milliseconds timeout, + EventsPerSock& events_per_sock) const; + /* Higher level, convenience, methods. These may throw. */ /** |