aboutsummaryrefslogtreecommitdiff
path: root/src/util
diff options
context:
space:
mode:
authorVasil Dimov <vd@FreeBSD.org>2020-12-23 16:40:11 +0100
committerVasil Dimov <vd@FreeBSD.org>2021-02-10 13:30:08 +0100
commitba9d73268f9585d4b9254adcf54708f88222798b (patch)
tree7817023a5deab6c4bcb0de90b748e49dc026f1ab /src/util
parentdec9b5e850c6aad989e814aea5b630b36f55d580 (diff)
downloadbitcoin-ba9d73268f9585d4b9254adcf54708f88222798b.tar.xz
net: add RAII socket and use it instead of bare SOCKET
Introduce a class to manage the lifetime of a socket - when the object that contains the socket goes out of scope, the underlying socket will be closed. In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()` methods that can be overridden by unit tests to mock the socket operations. The `Wait()` method also hides the `#ifdef USE_POLL poll() #else select() #endif` technique from higher level code.
Diffstat (limited to 'src/util')
-rw-r--r--src/util/sock.cpp85
-rw-r--r--src/util/sock.h100
-rw-r--r--src/util/time.cpp5
-rw-r--r--src/util/time.h6
4 files changed, 196 insertions, 0 deletions
diff --git a/src/util/sock.cpp b/src/util/sock.cpp
index 35eca4afb1..4c65b5b680 100644
--- a/src/util/sock.cpp
+++ b/src/util/sock.cpp
@@ -6,12 +6,97 @@
#include <logging.h>
#include <tinyformat.h>
#include <util/sock.h>
+#include <util/system.h>
+#include <util/time.h>
#include <codecvt>
#include <cwchar>
#include <locale>
#include <string>
+#ifdef USE_POLL
+#include <poll.h>
+#endif
+
+Sock::Sock() : m_socket(INVALID_SOCKET) {}
+
+Sock::Sock(SOCKET s) : m_socket(s) {}
+
+Sock::Sock(Sock&& other)
+{
+ m_socket = other.m_socket;
+ other.m_socket = INVALID_SOCKET;
+}
+
+Sock::~Sock() { Reset(); }
+
+Sock& Sock::operator=(Sock&& other)
+{
+ Reset();
+ m_socket = other.m_socket;
+ other.m_socket = INVALID_SOCKET;
+ return *this;
+}
+
+SOCKET Sock::Get() const { return m_socket; }
+
+SOCKET Sock::Release()
+{
+ const SOCKET s = m_socket;
+ m_socket = INVALID_SOCKET;
+ return s;
+}
+
+void Sock::Reset() { CloseSocket(m_socket); }
+
+ssize_t Sock::Send(const void* data, size_t len, int flags) const
+{
+ return send(m_socket, static_cast<const char*>(data), len, flags);
+}
+
+ssize_t Sock::Recv(void* buf, size_t len, int flags) const
+{
+ return recv(m_socket, static_cast<char*>(buf), len, flags);
+}
+
+bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const
+{
+#ifdef USE_POLL
+ pollfd fd;
+ fd.fd = m_socket;
+ fd.events = 0;
+ if (requested & RECV) {
+ fd.events |= POLLIN;
+ }
+ if (requested & SEND) {
+ fd.events |= POLLOUT;
+ }
+
+ return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR;
+#else
+ if (!IsSelectableSocket(m_socket)) {
+ return false;
+ }
+
+ fd_set fdset_recv;
+ fd_set fdset_send;
+ FD_ZERO(&fdset_recv);
+ FD_ZERO(&fdset_send);
+
+ if (requested & RECV) {
+ FD_SET(m_socket, &fdset_recv);
+ }
+
+ if (requested & SEND) {
+ FD_SET(m_socket, &fdset_send);
+ }
+
+ timeval timeout_struct = MillisToTimeval(timeout);
+
+ return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR;
+#endif /* USE_POLL */
+}
+
#ifdef WIN32
std::string NetworkErrorString(int err)
{
diff --git a/src/util/sock.h b/src/util/sock.h
index 0d48235043..26fe60f18f 100644
--- a/src/util/sock.h
+++ b/src/util/sock.h
@@ -7,8 +7,108 @@
#include <compat.h>
+#include <chrono>
#include <string>
+/**
+ * RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it
+ * contains a socket and closes it automatically when it goes out of scope.
+ */
+class Sock
+{
+public:
+ /**
+ * Default constructor, creates an empty object that does nothing when destroyed.
+ */
+ Sock();
+
+ /**
+ * Take ownership of an existent socket.
+ */
+ explicit Sock(SOCKET s);
+
+ /**
+ * Copy constructor, disabled because closing the same socket twice is undesirable.
+ */
+ Sock(const Sock&) = delete;
+
+ /**
+ * Move constructor, grab the socket from another object and close ours (if set).
+ */
+ Sock(Sock&& other);
+
+ /**
+ * Destructor, close the socket or do nothing if empty.
+ */
+ virtual ~Sock();
+
+ /**
+ * Copy assignment operator, disabled because closing the same socket twice is undesirable.
+ */
+ Sock& operator=(const Sock&) = delete;
+
+ /**
+ * Move assignment operator, grab the socket from another object and close ours (if set).
+ */
+ virtual Sock& operator=(Sock&& other);
+
+ /**
+ * Get the value of the contained socket.
+ * @return socket or INVALID_SOCKET if empty
+ */
+ virtual SOCKET Get() const;
+
+ /**
+ * Get the value of the contained socket and drop ownership. It will not be closed by the
+ * destructor after this call.
+ * @return socket or INVALID_SOCKET if empty
+ */
+ virtual SOCKET Release();
+
+ /**
+ * Close if non-empty.
+ */
+ virtual void Reset();
+
+ /**
+ * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this
+ * wrapper can be unit-tested if this method is overridden by a mock Sock implementation.
+ */
+ virtual ssize_t Send(const void* data, size_t len, int flags) const;
+
+ /**
+ * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this
+ * wrapper can be unit-tested if this method is overridden by a mock Sock implementation.
+ */
+ virtual ssize_t Recv(void* buf, size_t len, int flags) const;
+
+ using Event = uint8_t;
+
+ /**
+ * If passed to `Wait()`, then it will wait for readiness to read from the socket.
+ */
+ static constexpr Event RECV = 0b01;
+
+ /**
+ * If passed to `Wait()`, then it will wait for readiness to send to the socket.
+ */
+ static constexpr Event SEND = 0b10;
+
+ /**
+ * Wait for readiness for input (recv) or output (send).
+ * @param[in] timeout Wait this much for at least one of the requested events to occur.
+ * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`.
+ * @return true on success and false otherwise
+ */
+ virtual bool Wait(std::chrono::milliseconds timeout, Event requested) const;
+
+private:
+ /**
+ * Contained socket. `INVALID_SOCKET` designates the object is empty.
+ */
+ SOCKET m_socket;
+};
+
/** Return readable error string for a network error code */
std::string NetworkErrorString(int err);
diff --git a/src/util/time.cpp b/src/util/time.cpp
index 4da041e5a5..4aed9f60b0 100644
--- a/src/util/time.cpp
+++ b/src/util/time.cpp
@@ -123,3 +123,8 @@ struct timeval MillisToTimeval(int64_t nTimeout)
timeout.tv_usec = (nTimeout % 1000) * 1000;
return timeout;
}
+
+struct timeval MillisToTimeval(std::chrono::milliseconds ms)
+{
+ return MillisToTimeval(count_milliseconds(ms));
+}
diff --git a/src/util/time.h b/src/util/time.h
index 2c0e3d83f6..03b75b5be5 100644
--- a/src/util/time.h
+++ b/src/util/time.h
@@ -27,6 +27,7 @@ void UninterruptibleSleep(const std::chrono::microseconds& n);
* interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI)
*/
inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); }
+inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); }
inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); }
/**
@@ -64,4 +65,9 @@ int64_t ParseISO8601DateTime(const std::string& str);
*/
struct timeval MillisToTimeval(int64_t nTimeout);
+/**
+ * Convert milliseconds to a struct timeval for e.g. select.
+ */
+struct timeval MillisToTimeval(std::chrono::milliseconds ms);
+
#endif // BITCOIN_UTIL_TIME_H