aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAva Chow <github@achow101.com>2024-06-20 13:44:56 -0400
committerAva Chow <github@achow101.com>2024-06-20 13:44:56 -0400
commita961ad1bebc54912b88d072abf22ab7d3cf46bf1 (patch)
tree54fa13dcf159d7fef65ef2005e9d4374f3c6418c
parent21656e99b5f489c881f9fe90b28edc4aac870ab0 (diff)
parent1245d1388b003c46092937def7041917aecec8de (diff)
Merge bitcoin/bitcoin#30202: netbase: extend CreateSock() to support creating arbitrary sockets
1245d1388b003c46092937def7041917aecec8de netbase: extend CreateSock() to support creating arbitrary sockets (Vasil Dimov) Pull request description: Allow the callers of `CreateSock()` to pass all 3 arguments to the `socket(2)` syscall. This makes it possible to create sockets of any domain/type/protocol. In addition to extending arguments, some extra safety checks were put in place. The need for this came up during the discussion in https://github.com/bitcoin/bitcoin/pull/30043#discussion_r1618837102 ACKs for top commit: achow101: ACK 1245d1388b003c46092937def7041917aecec8de tdb3: re ACK 1245d1388b003c46092937def7041917aecec8de theStack: re-ACK 1245d1388b003c46092937def7041917aecec8de Tree-SHA512: cc86b56121293ac98959aed0ed77812d20702ed7029b5a043586f46e74295779c5354bb0d5f9e80be6c29e535df980d34c1dbf609064fb7ea3e5ca0f0ed54d6b
-rw-r--r--src/net.cpp2
-rw-r--r--src/netbase.cpp34
-rw-r--r--src/netbase.h10
-rw-r--r--src/test/fuzz/fuzz.cpp5
-rw-r--r--src/test/fuzz/i2p.cpp2
-rw-r--r--src/test/i2p_tests.cpp13
6 files changed, 33 insertions, 33 deletions
diff --git a/src/net.cpp b/src/net.cpp
index de974f39cb..990c58ee3d 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -3029,7 +3029,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
return false;
}
- std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily());
+ std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
if (!sock) {
strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original);
diff --git a/src/netbase.cpp b/src/netbase.cpp
index ff46061d3d..fcbdb43e2a 100644
--- a/src/netbase.cpp
+++ b/src/netbase.cpp
@@ -487,24 +487,23 @@ bool Socks5(const std::string& strDest, uint16_t port, const ProxyCredentials* a
}
}
-std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family)
+std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol)
{
// Not IPv4, IPv6 or UNIX
- if (address_family == AF_UNSPEC) return nullptr;
-
- int protocol{IPPROTO_TCP};
-#if HAVE_SOCKADDR_UN
- if (address_family == AF_UNIX) protocol = 0;
-#endif
+ if (domain == AF_UNSPEC) return nullptr;
// Create a socket in the specified address family.
- SOCKET hSocket = socket(address_family, SOCK_STREAM, protocol);
+ SOCKET hSocket = socket(domain, type, protocol);
if (hSocket == INVALID_SOCKET) {
return nullptr;
}
auto sock = std::make_unique<Sock>(hSocket);
+ if (domain != AF_INET && domain != AF_INET6 && domain != AF_UNIX) {
+ return sock;
+ }
+
// Ensure that waiting for I/O on this socket won't result in undefined
// behavior.
if (!sock->IsSelectable()) {
@@ -529,18 +528,21 @@ std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family)
}
#if HAVE_SOCKADDR_UN
- if (address_family == AF_UNIX) return sock;
+ if (domain == AF_UNIX) return sock;
#endif
- // Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
- const int on{1};
- if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
- LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
+ if (protocol == IPPROTO_TCP) {
+ // Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
+ const int on{1};
+ if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
+ LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
+ }
}
+
return sock;
}
-std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock = CreateSockOS;
+std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock = CreateSockOS;
template<typename... Args>
static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) {
@@ -609,7 +611,7 @@ static bool ConnectToSocket(const Sock& sock, struct sockaddr* sockaddr, socklen
std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection)
{
- auto sock = CreateSock(dest.GetSAFamily());
+ auto sock = CreateSock(dest.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
if (!sock) {
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort());
return {};
@@ -637,7 +639,7 @@ std::unique_ptr<Sock> Proxy::Connect() const
if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true);
#if HAVE_SOCKADDR_UN
- auto sock = CreateSock(AF_UNIX);
+ auto sock = CreateSock(AF_UNIX, SOCK_STREAM, 0);
if (!sock) {
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path);
return {};
diff --git a/src/netbase.h b/src/netbase.h
index 321c288f67..8ef6c28996 100644
--- a/src/netbase.h
+++ b/src/netbase.h
@@ -262,16 +262,18 @@ CService LookupNumeric(const std::string& name, uint16_t portDefault = 0, DNSLoo
CSubNet LookupSubNet(const std::string& subnet_str);
/**
- * Create a TCP or UNIX socket in the given address family.
- * @param[in] address_family to use for the socket.
+ * Create a real socket from the operating system.
+ * @param[in] domain Communications domain, first argument to the socket(2) syscall.
+ * @param[in] type Type of the socket, second argument to the socket(2) syscall.
+ * @param[in] protocol The particular protocol to be used with the socket, third argument to the socket(2) syscall.
* @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure
*/
-std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family);
+std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol);
/**
* Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests.
*/
-extern std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock;
+extern std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock;
/**
* Create a socket and try to connect to the specified service.
diff --git a/src/test/fuzz/fuzz.cpp b/src/test/fuzz/fuzz.cpp
index 9a54a44bd3..c1c9945a04 100644
--- a/src/test/fuzz/fuzz.cpp
+++ b/src/test/fuzz/fuzz.cpp
@@ -101,8 +101,9 @@ void ResetCoverageCounters() {}
void initialize()
{
- // Terminate immediately if a fuzzing harness ever tries to create a TCP socket.
- CreateSock = [](const sa_family_t&) -> std::unique_ptr<Sock> { std::terminate(); };
+ // Terminate immediately if a fuzzing harness ever tries to create a socket.
+ // Individual tests can override this by pointing CreateSock to a mocked alternative.
+ CreateSock = [](int, int, int) -> std::unique_ptr<Sock> { std::terminate(); };
// Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup.
g_dns_lookup = [](const std::string& name, bool allow_lookup) {
diff --git a/src/test/fuzz/i2p.cpp b/src/test/fuzz/i2p.cpp
index 3af5bed30a..51517187a0 100644
--- a/src/test/fuzz/i2p.cpp
+++ b/src/test/fuzz/i2p.cpp
@@ -27,7 +27,7 @@ FUZZ_TARGET(i2p, .init = initialize_i2p)
// Mock CreateSock() to create FuzzedSock.
auto CreateSockOrig = CreateSock;
- CreateSock = [&fuzzed_data_provider](const sa_family_t&) {
+ CreateSock = [&fuzzed_data_provider](int, int, int) {
return std::make_unique<FuzzedSock>(fuzzed_data_provider);
};
diff --git a/src/test/i2p_tests.cpp b/src/test/i2p_tests.cpp
index d7249d88f4..0512c6134f 100644
--- a/src/test/i2p_tests.cpp
+++ b/src/test/i2p_tests.cpp
@@ -39,15 +39,14 @@ public:
private:
const BCLog::Level m_prev_log_level;
- const std::function<std::unique_ptr<Sock>(const sa_family_t&)> m_create_sock_orig;
+ const decltype(CreateSock) m_create_sock_orig;
};
BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup)
BOOST_AUTO_TEST_CASE(unlimited_recv)
{
- // Mock CreateSock() to create MockSock.
- CreateSock = [](const sa_family_t&) {
+ CreateSock = [](int, int, int) {
return std::make_unique<StaticContentsSock>(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a'));
};
@@ -69,7 +68,7 @@ BOOST_AUTO_TEST_CASE(unlimited_recv)
BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
{
size_t num_sockets{0};
- CreateSock = [&num_sockets](const sa_family_t&) {
+ CreateSock = [&num_sockets](int, int, int) {
// clang-format off
++num_sockets;
// First socket is the control socket for creating the session.
@@ -133,9 +132,7 @@ BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
BOOST_AUTO_TEST_CASE(damaged_private_key)
{
- const auto CreateSockOrig = CreateSock;
-
- CreateSock = [](const sa_family_t&) {
+ CreateSock = [](int, int, int) {
return std::make_unique<StaticContentsSock>("HELLO REPLY RESULT=OK VERSION=3.1\n"
"SESSION STATUS RESULT=OK DESTINATION=\n");
};
@@ -172,8 +169,6 @@ BOOST_AUTO_TEST_CASE(damaged_private_key)
BOOST_CHECK(!session.Connect(CService{}, conn, proxy_error));
}
}
-
- CreateSock = CreateSockOrig;
}
BOOST_AUTO_TEST_SUITE_END()