aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/i2p.cpp10
-rw-r--r--src/i2p.h1
-rw-r--r--src/net.cpp10
-rw-r--r--src/netbase.cpp4
-rw-r--r--src/test/fuzz/util/net.cpp5
-rw-r--r--src/test/sock_tests.cpp30
-rw-r--r--src/test/util/net.h11
-rw-r--r--src/util/sock.cpp9
-rw-r--r--src/util/sock.h37
9 files changed, 62 insertions, 55 deletions
diff --git a/src/i2p.cpp b/src/i2p.cpp
index f03e375adf..5a3dde54ce 100644
--- a/src/i2p.cpp
+++ b/src/i2p.cpp
@@ -119,7 +119,6 @@ Session::Session(const fs::path& private_key_file,
: m_private_key_file{private_key_file},
m_control_host{control_host},
m_interrupt{interrupt},
- m_control_sock{std::make_unique<Sock>(INVALID_SOCKET)},
m_transient{false}
{
}
@@ -127,7 +126,6 @@ Session::Session(const fs::path& private_key_file,
Session::Session(const CService& control_host, CThreadInterrupt* interrupt)
: m_control_host{control_host},
m_interrupt{interrupt},
- m_control_sock{std::make_unique<Sock>(INVALID_SOCKET)},
m_transient{true}
{
}
@@ -315,7 +313,7 @@ void Session::CheckControlSock()
LOCK(m_mutex);
std::string errmsg;
- if (!m_control_sock->IsConnected(errmsg)) {
+ if (m_control_sock && !m_control_sock->IsConnected(errmsg)) {
Log("Control socket error: %s", errmsg);
Disconnect();
}
@@ -364,7 +362,7 @@ Binary Session::MyDestination() const
void Session::CreateIfNotCreatedAlready()
{
std::string errmsg;
- if (m_control_sock->IsConnected(errmsg)) {
+ if (m_control_sock && m_control_sock->IsConnected(errmsg)) {
return;
}
@@ -437,14 +435,14 @@ std::unique_ptr<Sock> Session::StreamAccept()
void Session::Disconnect()
{
- if (m_control_sock->Get() != INVALID_SOCKET) {
+ if (m_control_sock) {
if (m_session_id.empty()) {
Log("Destroying incomplete SAM session");
} else {
Log("Destroying SAM session %s", m_session_id);
}
+ m_control_sock.reset();
}
- m_control_sock = std::make_unique<Sock>(INVALID_SOCKET);
m_session_id.clear();
}
} // namespace sam
diff --git a/src/i2p.h b/src/i2p.h
index c9c99292d9..cb9da64816 100644
--- a/src/i2p.h
+++ b/src/i2p.h
@@ -261,6 +261,7 @@ private:
* ("SESSION CREATE"). With the established session id we later open
* other connections to the SAM service to accept incoming I2P
* connections and make outgoing ones.
+ * If not connected then this unique_ptr will be empty.
* See https://geti2p.net/en/docs/api/samv3
*/
std::unique_ptr<Sock> m_control_sock GUARDED_BY(m_mutex);
diff --git a/src/net.cpp b/src/net.cpp
index 6b2ef5f43d..13f4430424 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -429,12 +429,10 @@ static CAddress GetBindAddress(const Sock& sock)
CAddress addr_bind;
struct sockaddr_storage sockaddr_bind;
socklen_t sockaddr_bind_len = sizeof(sockaddr_bind);
- if (sock.Get() != INVALID_SOCKET) {
- if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) {
- addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind);
- } else {
- LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n");
- }
+ if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) {
+ addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind);
+ } else {
+ LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n");
}
return addr_bind;
}
diff --git a/src/netbase.cpp b/src/netbase.cpp
index a8419217f4..ca1a80d72f 100644
--- a/src/netbase.cpp
+++ b/src/netbase.cpp
@@ -514,10 +514,6 @@ bool ConnectSocketDirectly(const CService &addrConnect, const Sock& sock, int nT
// Create a sockaddr from the specified service.
struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr);
- if (sock.Get() == INVALID_SOCKET) {
- LogPrintf("Cannot connect to %s: invalid socket\n", addrConnect.ToStringAddrPort());
- return false;
- }
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToStringAddrPort());
return false;
diff --git a/src/test/fuzz/util/net.cpp b/src/test/fuzz/util/net.cpp
index 1545e11065..d23e997719 100644
--- a/src/test/fuzz/util/net.cpp
+++ b/src/test/fuzz/util/net.cpp
@@ -77,9 +77,10 @@ template CNetAddr::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) n
template CAddress::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) noexcept;
FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider)
- : m_fuzzed_data_provider{fuzzed_data_provider}, m_selectable{fuzzed_data_provider.ConsumeBool()}
+ : Sock{fuzzed_data_provider.ConsumeIntegralInRange<SOCKET>(INVALID_SOCKET - 1, INVALID_SOCKET)},
+ m_fuzzed_data_provider{fuzzed_data_provider},
+ m_selectable{fuzzed_data_provider.ConsumeBool()}
{
- m_socket = fuzzed_data_provider.ConsumeIntegralInRange<SOCKET>(INVALID_SOCKET - 1, INVALID_SOCKET);
}
FuzzedSock::~FuzzedSock()
diff --git a/src/test/sock_tests.cpp b/src/test/sock_tests.cpp
index 26ee724bf8..5dd73dc101 100644
--- a/src/test/sock_tests.cpp
+++ b/src/test/sock_tests.cpp
@@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(constructor_and_destructor)
{
const SOCKET s = CreateSocket();
Sock* sock = new Sock(s);
- BOOST_CHECK_EQUAL(sock->Get(), s);
+ BOOST_CHECK(*sock == s);
BOOST_CHECK(!SocketIsClosed(s));
delete sock;
BOOST_CHECK(SocketIsClosed(s));
@@ -51,22 +51,34 @@ BOOST_AUTO_TEST_CASE(move_constructor)
Sock* sock2 = new Sock(std::move(*sock1));
delete sock1;
BOOST_CHECK(!SocketIsClosed(s));
- BOOST_CHECK_EQUAL(sock2->Get(), s);
+ BOOST_CHECK(*sock2 == s);
delete sock2;
BOOST_CHECK(SocketIsClosed(s));
}
BOOST_AUTO_TEST_CASE(move_assignment)
{
- const SOCKET s = CreateSocket();
- Sock* sock1 = new Sock(s);
- Sock* sock2 = new Sock();
+ const SOCKET s1 = CreateSocket();
+ const SOCKET s2 = CreateSocket();
+ Sock* sock1 = new Sock(s1);
+ Sock* sock2 = new Sock(s2);
+
+ BOOST_CHECK(!SocketIsClosed(s1));
+ BOOST_CHECK(!SocketIsClosed(s2));
+
*sock2 = std::move(*sock1);
+ BOOST_CHECK(!SocketIsClosed(s1));
+ BOOST_CHECK(SocketIsClosed(s2));
+ BOOST_CHECK(*sock2 == s1);
+
delete sock1;
- BOOST_CHECK(!SocketIsClosed(s));
- BOOST_CHECK_EQUAL(sock2->Get(), s);
+ BOOST_CHECK(!SocketIsClosed(s1));
+ BOOST_CHECK(SocketIsClosed(s2));
+ BOOST_CHECK(*sock2 == s1);
+
delete sock2;
- BOOST_CHECK(SocketIsClosed(s));
+ BOOST_CHECK(SocketIsClosed(s1));
+ BOOST_CHECK(SocketIsClosed(s2));
}
#ifndef WIN32 // Windows does not have socketpair(2).
@@ -98,7 +110,7 @@ BOOST_AUTO_TEST_CASE(send_and_receive)
SendAndRecvMessage(*sock0, *sock1);
Sock* sock0moved = new Sock(std::move(*sock0));
- Sock* sock1moved = new Sock();
+ Sock* sock1moved = new Sock(INVALID_SOCKET);
*sock1moved = std::move(*sock1);
delete sock0;
diff --git a/src/test/util/net.h b/src/test/util/net.h
index 0d41cf550e..497292542b 100644
--- a/src/test/util/net.h
+++ b/src/test/util/net.h
@@ -108,10 +108,10 @@ constexpr auto ALL_NETWORKS = std::array{
class StaticContentsSock : public Sock
{
public:
- explicit StaticContentsSock(const std::string& contents) : m_contents{contents}
+ explicit StaticContentsSock(const std::string& contents)
+ : Sock{INVALID_SOCKET},
+ m_contents{contents}
{
- // Just a dummy number that is not INVALID_SOCKET.
- m_socket = INVALID_SOCKET - 1;
}
~StaticContentsSock() override { m_socket = INVALID_SOCKET; }
@@ -194,6 +194,11 @@ public:
return true;
}
+ bool IsConnected(std::string&) const override
+ {
+ return true;
+ }
+
private:
const std::string m_contents;
mutable size_t m_consumed{0};
diff --git a/src/util/sock.cpp b/src/util/sock.cpp
index fd64cae404..d16dc56aa3 100644
--- a/src/util/sock.cpp
+++ b/src/util/sock.cpp
@@ -24,8 +24,6 @@ static inline bool IOErrorIsPermanent(int err)
return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
}
-Sock::Sock() : m_socket(INVALID_SOCKET) {}
-
Sock::Sock(SOCKET s) : m_socket(s) {}
Sock::Sock(Sock&& other)
@@ -44,8 +42,6 @@ Sock& Sock::operator=(Sock&& other)
return *this;
}
-SOCKET Sock::Get() const { return m_socket; }
-
ssize_t Sock::Send(const void* data, size_t len, int flags) const
{
return send(m_socket, static_cast<const char*>(data), len, flags);
@@ -411,6 +407,11 @@ void Sock::Close()
m_socket = INVALID_SOCKET;
}
+bool Sock::operator==(SOCKET s) const
+{
+ return m_socket == s;
+};
+
std::string NetworkErrorString(int err)
{
#if defined(WIN32)
diff --git a/src/util/sock.h b/src/util/sock.h
index 6bac2dfd34..d78e01929b 100644
--- a/src/util/sock.h
+++ b/src/util/sock.h
@@ -21,16 +21,12 @@
static constexpr auto MAX_WAIT_FOR_IO = 1s;
/**
- * RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it
- * contains a socket and closes it automatically when it goes out of scope.
+ * RAII helper class that manages a socket and closes it automatically when it goes out of scope.
*/
class Sock
{
public:
- /**
- * Default constructor, creates an empty object that does nothing when destroyed.
- */
- Sock();
+ Sock() = delete;
/**
* Take ownership of an existent socket.
@@ -63,43 +59,37 @@ public:
virtual Sock& operator=(Sock&& other);
/**
- * Get the value of the contained socket.
- * @return socket or INVALID_SOCKET if empty
- */
- [[nodiscard]] virtual SOCKET Get() const;
-
- /**
- * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this
+ * send(2) wrapper. Equivalent to `send(m_socket, data, len, flags);`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual ssize_t Send(const void* data, size_t len, int flags) const;
/**
- * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this
+ * recv(2) wrapper. Equivalent to `recv(m_socket, buf, len, flags);`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual ssize_t Recv(void* buf, size_t len, int flags) const;
/**
- * connect(2) wrapper. Equivalent to `connect(this->Get(), addr, addrlen)`. Code that uses this
+ * connect(2) wrapper. Equivalent to `connect(m_socket, addr, addrlen)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const;
/**
- * bind(2) wrapper. Equivalent to `bind(this->Get(), addr, addr_len)`. Code that uses this
+ * bind(2) wrapper. Equivalent to `bind(m_socket, addr, addr_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int Bind(const sockaddr* addr, socklen_t addr_len) const;
/**
- * listen(2) wrapper. Equivalent to `listen(this->Get(), backlog)`. Code that uses this
+ * listen(2) wrapper. Equivalent to `listen(m_socket, backlog)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int Listen(int backlog) const;
/**
- * accept(2) wrapper. Equivalent to `std::make_unique<Sock>(accept(this->Get(), addr, addr_len))`.
+ * accept(2) wrapper. Equivalent to `std::make_unique<Sock>(accept(m_socket, addr, addr_len))`.
* Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock
* implementation.
* The returned unique_ptr is empty if `accept()` failed in which case errno will be set.
@@ -108,7 +98,7 @@ public:
/**
* getsockopt(2) wrapper. Equivalent to
- * `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this
+ * `getsockopt(m_socket, level, opt_name, opt_val, opt_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int GetSockOpt(int level,
@@ -118,7 +108,7 @@ public:
/**
* setsockopt(2) wrapper. Equivalent to
- * `setsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this
+ * `setsockopt(m_socket, level, opt_name, opt_val, opt_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int SetSockOpt(int level,
@@ -128,7 +118,7 @@ public:
/**
* getsockname(2) wrapper. Equivalent to
- * `getsockname(this->Get(), name, name_len)`. Code that uses this
+ * `getsockname(m_socket, name, name_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int GetSockName(sockaddr* name, socklen_t* name_len) const;
@@ -266,6 +256,11 @@ public:
*/
[[nodiscard]] virtual bool IsConnected(std::string& errmsg) const;
+ /**
+ * Check if the internal socket is equal to `s`. Use only in tests.
+ */
+ bool operator==(SOCKET s) const;
+
protected:
/**
* Contained socket. `INVALID_SOCKET` designates the object is empty.