aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/net.cpp38
-rw-r--r--src/net.h13
-rw-r--r--src/test/fuzz/util.cpp16
-rw-r--r--src/test/fuzz/util.h2
-rw-r--r--src/test/util/net.h18
-rw-r--r--src/util/sock.cpp27
-rw-r--r--src/util/sock.h9
7 files changed, 93 insertions, 30 deletions
diff --git a/src/net.cpp b/src/net.cpp
index 019e77fd7a..89a4aee5d9 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -1099,10 +1099,10 @@ bool CConnman::AttemptToEvictConnection()
void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr);
- SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len);
+ auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len);
CAddress addr;
- if (hSocket == INVALID_SOCKET) {
+ if (!sock) {
const int nErr = WSAGetLastError();
if (nErr != WSAEWOULDBLOCK) {
LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr));
@@ -1116,15 +1116,15 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE};
}
- const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE};
+ const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE};
NetPermissionFlags permissionFlags = NetPermissionFlags::None;
hListenSocket.AddSocketPermissionFlags(permissionFlags);
- CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr);
+ CreateNodeFromAcceptedSocket(std::move(sock), permissionFlags, addr_bind, addr);
}
-void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
+void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
NetPermissionFlags permissionFlags,
const CAddress& addr_bind,
const CAddress& addr)
@@ -1150,27 +1150,24 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
if (!fNetworkActive) {
LogPrint(BCLog::NET, "connection from %s dropped: not accepting new connections\n", addr.ToString());
- CloseSocket(hSocket);
return;
}
- if (!IsSelectableSocket(hSocket))
+ if (!IsSelectableSocket(sock->Get()))
{
LogPrintf("connection from %s dropped: non-selectable socket\n", addr.ToString());
- CloseSocket(hSocket);
return;
}
// According to the internet TCP_NODELAY is not carried into accepted sockets
// on all platforms. Set it again here just to be sure.
- SetSocketNoDelay(hSocket);
+ SetSocketNoDelay(sock->Get());
// Don't accept connections from banned peers.
bool banned = m_banman && m_banman->IsBanned(addr);
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && banned)
{
LogPrint(BCLog::NET, "connection from %s dropped (banned)\n", addr.ToString());
- CloseSocket(hSocket);
return;
}
@@ -1179,7 +1176,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && nInbound + 1 >= nMaxInbound && discouraged)
{
LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString());
- CloseSocket(hSocket);
return;
}
@@ -1188,7 +1184,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
if (!AttemptToEvictConnection()) {
// No connection to evict, disconnect the new connection
LogPrint(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n");
- CloseSocket(hSocket);
return;
}
}
@@ -1202,7 +1197,7 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
}
const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end();
- CNode* pnode = new CNode(id, nodeServices, hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
+ CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
pnode->AddRef();
pnode->m_permissionFlags = permissionFlags;
pnode->m_prefer_evict = discouraged;
@@ -1364,7 +1359,7 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
std::set<SOCKET>& error_set)
{
for (const ListenSocket& hListenSocket : vhListenSocket) {
- recv_set.insert(hListenSocket.socket);
+ recv_set.insert(hListenSocket.sock->Get());
}
for (CNode* pnode : nodes) {
@@ -1646,7 +1641,7 @@ void CConnman::SocketHandlerListening(const std::set<SOCKET>& recv_set)
if (interruptNet) {
return;
}
- if (recv_set.count(listen_socket.socket) > 0) {
+ if (recv_set.count(listen_socket.sock->Get()) > 0) {
AcceptConnection(listen_socket);
}
}
@@ -2335,7 +2330,7 @@ void CConnman::ThreadI2PAcceptIncoming()
continue;
}
- CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::None,
+ CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None,
CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE});
}
}
@@ -2397,7 +2392,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
return false;
}
- vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
+ vhListenSocket.emplace_back(std::move(sock), permissions);
return true;
}
@@ -2706,15 +2701,6 @@ void CConnman::StopNodes()
DeleteNode(pnode);
}
- // Close listening sockets.
- for (ListenSocket& hListenSocket : vhListenSocket) {
- if (hListenSocket.socket != INVALID_SOCKET) {
- if (!CloseSocket(hListenSocket.socket)) {
- LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
- }
- }
- }
-
for (CNode* pnode : m_nodes_disconnected) {
DeleteNode(pnode);
}
diff --git a/src/net.h b/src/net.h
index 977e6502ce..80fc93a5d0 100644
--- a/src/net.h
+++ b/src/net.h
@@ -25,6 +25,7 @@
#include <threadinterrupt.h>
#include <uint256.h>
#include <util/check.h>
+#include <util/sock.h>
#include <atomic>
#include <condition_variable>
@@ -947,9 +948,13 @@ public:
private:
struct ListenSocket {
public:
- SOCKET socket;
+ std::shared_ptr<Sock> sock;
inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); }
- ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {}
+ ListenSocket(std::shared_ptr<Sock> sock_, NetPermissionFlags permissions_)
+ : sock{sock_}, m_permissions{permissions_}
+ {
+ }
+
private:
NetPermissionFlags m_permissions;
};
@@ -969,12 +974,12 @@ private:
/**
* Create a `CNode` object from a socket that has just been accepted and add the node to
* the `m_nodes` member.
- * @param[in] hSocket Connected socket to communicate with the peer.
+ * @param[in] sock Connected socket to communicate with the peer.
* @param[in] permissionFlags The peer's permissions.
* @param[in] addr_bind The address and port at our side of the connection.
* @param[in] addr The address and port at the peer's side of the connection.
*/
- void CreateNodeFromAcceptedSocket(SOCKET hSocket,
+ void CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
NetPermissionFlags permissionFlags,
const CAddress& addr_bind,
const CAddress& addr);
diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp
index 5520eee758..f89b597eed 100644
--- a/src/test/fuzz/util.cpp
+++ b/src/test/fuzz/util.cpp
@@ -13,6 +13,8 @@
#include <util/time.h>
#include <version.h>
+#include <memory>
+
FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider)
: m_fuzzed_data_provider{fuzzed_data_provider}
{
@@ -158,6 +160,20 @@ int FuzzedSock::Connect(const sockaddr*, socklen_t) const
return 0;
}
+std::unique_ptr<Sock> FuzzedSock::Accept(sockaddr* addr, socklen_t* addr_len) const
+{
+ constexpr std::array accept_errnos{
+ ECONNABORTED,
+ EINTR,
+ ENOMEM,
+ };
+ if (m_fuzzed_data_provider.ConsumeBool()) {
+ SetFuzzedErrNo(m_fuzzed_data_provider, accept_errnos);
+ return std::unique_ptr<FuzzedSock>();
+ }
+ return std::make_unique<FuzzedSock>(m_fuzzed_data_provider);
+}
+
int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
{
constexpr std::array getsockopt_errnos{
diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h
index f4f8e9e70d..fd7f40c01d 100644
--- a/src/test/fuzz/util.h
+++ b/src/test/fuzz/util.h
@@ -401,6 +401,8 @@ public:
int Connect(const sockaddr*, socklen_t) const override;
+ std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override;
+
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override;
bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;
diff --git a/src/test/util/net.h b/src/test/util/net.h
index 006e876c1a..20c45058a1 100644
--- a/src/test/util/net.h
+++ b/src/test/util/net.h
@@ -13,6 +13,7 @@
#include <array>
#include <cassert>
#include <cstring>
+#include <memory>
#include <string>
struct ConnmanTestMsg : public CConnman {
@@ -126,6 +127,23 @@ public:
int Connect(const sockaddr*, socklen_t) const override { return 0; }
+ std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override
+ {
+ if (addr != nullptr) {
+ // Pretend all connections come from 5.5.5.5:6789
+ memset(addr, 0x00, *addr_len);
+ const socklen_t write_len = static_cast<socklen_t>(sizeof(sockaddr_in));
+ if (*addr_len >= write_len) {
+ *addr_len = write_len;
+ sockaddr_in* addr_in = reinterpret_cast<sockaddr_in*>(addr);
+ addr_in->sin_family = AF_INET;
+ memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr));
+ addr_in->sin_port = htons(6789);
+ }
+ }
+ return std::make_unique<StaticContentsSock>("");
+ };
+
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override
{
std::memset(opt_val, 0x0, *opt_len);
diff --git a/src/util/sock.cpp b/src/util/sock.cpp
index 1a4d67a65e..2029d70a37 100644
--- a/src/util/sock.cpp
+++ b/src/util/sock.cpp
@@ -10,6 +10,7 @@
#include <util/system.h>
#include <util/time.h>
+#include <memory>
#include <stdexcept>
#include <string>
@@ -73,6 +74,32 @@ int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
return connect(m_socket, addr, addr_len);
}
+std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
+{
+#ifdef WIN32
+ static constexpr auto ERR = INVALID_SOCKET;
+#else
+ static constexpr auto ERR = SOCKET_ERROR;
+#endif
+
+ std::unique_ptr<Sock> sock;
+
+ const auto socket = accept(m_socket, addr, addr_len);
+ if (socket != ERR) {
+ try {
+ sock = std::make_unique<Sock>(socket);
+ } catch (const std::exception&) {
+#ifdef WIN32
+ closesocket(socket);
+#else
+ close(socket);
+#endif
+ }
+ }
+
+ return sock;
+}
+
int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
{
return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
diff --git a/src/util/sock.h b/src/util/sock.h
index 59cc8c0b1d..7510482857 100644
--- a/src/util/sock.h
+++ b/src/util/sock.h
@@ -10,6 +10,7 @@
#include <util/time.h>
#include <chrono>
+#include <memory>
#include <string>
/**
@@ -97,6 +98,14 @@ public:
[[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const;
/**
+ * accept(2) wrapper. Equivalent to `std::make_unique<Sock>(accept(this->Get(), addr, addr_len))`.
+ * Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock
+ * implementation.
+ * The returned unique_ptr is empty if `accept()` failed in which case errno will be set.
+ */
+ [[nodiscard]] virtual std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const;
+
+ /**
* getsockopt(2) wrapper. Equivalent to
* `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.