diff options
author | Matthew Zipkin <pinheadmz@gmail.com> | 2023-05-26 12:26:43 -0400 |
---|---|---|
committer | Matthew Zipkin <pinheadmz@gmail.com> | 2024-03-01 13:13:07 -0500 |
commit | bae86c8d318d06818aa75a9ebe3db864197f0bc6 (patch) | |
tree | 893a4fde71f37fcfab090b216a8f097ce9308fb8 | |
parent | adb3a3e51de205cc69b1a58647c65c04fa6c6362 (diff) |
netbase: refactor CreateSock() to accept sa_family_t
Also implement CService::GetSAFamily() to provide sa_family_t
-rw-r--r-- | src/compat/compat.h | 7 | ||||
-rw-r--r-- | src/i2p.cpp | 2 | ||||
-rw-r--r-- | src/net.cpp | 9 | ||||
-rw-r--r-- | src/netaddress.cpp | 13 | ||||
-rw-r--r-- | src/netaddress.h | 5 | ||||
-rw-r--r-- | src/netbase.cpp | 15 | ||||
-rw-r--r-- | src/netbase.h | 10 | ||||
-rw-r--r-- | src/test/fuzz/fuzz.cpp | 2 | ||||
-rw-r--r-- | src/test/i2p_tests.cpp | 8 |
9 files changed, 46 insertions, 25 deletions
diff --git a/src/compat/compat.h b/src/compat/compat.h index 9ff9a335f8..366c648ae7 100644 --- a/src/compat/compat.h +++ b/src/compat/compat.h @@ -32,6 +32,13 @@ #include <unistd.h> // IWYU pragma: export #endif +// Windows does not have `sa_family_t` - it defines `sockaddr::sa_family` as `u_short`. +// Thus define `sa_family_t` on Windows too so that the rest of the code can use `sa_family_t`. +// See https://learn.microsoft.com/en-us/windows/win32/api/winsock/ns-winsock-sockaddr#syntax +#ifdef WIN32 +typedef u_short sa_family_t; +#endif + // We map Linux / BSD error functions and codes, to the equivalent // Windows definitions, and use the WSA* names throughout our code. // Note that glibc defines EWOULDBLOCK as EAGAIN (see errno.h). diff --git a/src/i2p.cpp b/src/i2p.cpp index 4b79a6826b..2fce946e7d 100644 --- a/src/i2p.cpp +++ b/src/i2p.cpp @@ -326,7 +326,7 @@ Session::Reply Session::SendRequestAndGetReply(const Sock& sock, std::unique_ptr<Sock> Session::Hello() const { - auto sock = CreateSock(m_control_host); + auto sock = CreateSock(m_control_host.GetSAFamily()); if (!sock) { throw std::runtime_error("Cannot create socket"); diff --git a/src/net.cpp b/src/net.cpp index 7c82f01d75..d40c7109c2 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -483,15 +483,16 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addr_bind = CAddress{conn.me, NODE_NONE}; } } else if (use_proxy) { - sock = CreateSock(proxy.proxy); + sock = CreateSock(proxy.proxy.GetSAFamily()); if (!sock) { return nullptr; } + LogPrintLevel(BCLog::PROXY, BCLog::Level::Debug, "Using proxy: %s to connect to %s:%s\n", proxy.proxy.ToStringAddrPort(), addrConnect.ToStringAddr(), addrConnect.GetPort()); connected = ConnectThroughProxy(proxy, addrConnect.ToStringAddr(), addrConnect.GetPort(), *sock, nConnectTimeout, proxyConnectionFailed); } else { // no proxy needed (none set for target network) - sock = CreateSock(addrConnect); + sock = CreateSock(addrConnect.GetSAFamily()); if (!sock) { return nullptr; } @@ -504,7 +505,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addrman.Attempt(addrConnect, fCountFailure); } } else if (pszDest && GetNameProxy(proxy)) { - sock = CreateSock(proxy.proxy); + sock = CreateSock(proxy.proxy.GetSAFamily()); if (!sock) { return nullptr; } @@ -2993,7 +2994,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - std::unique_ptr<Sock> sock = CreateSock(addrBind); + std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily()); if (!sock) { strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); diff --git a/src/netaddress.cpp b/src/netaddress.cpp index 7530334db1..74ab6dd8d8 100644 --- a/src/netaddress.cpp +++ b/src/netaddress.cpp @@ -818,6 +818,19 @@ bool CService::SetSockAddr(const struct sockaddr *paddr) } } +sa_family_t CService::GetSAFamily() const +{ + switch (m_net) { + case NET_IPV4: + return AF_INET; + case NET_IPV6: + case NET_CJDNS: + return AF_INET6; + default: + return AF_UNSPEC; + } +} + uint16_t CService::GetPort() const { return port; diff --git a/src/netaddress.h b/src/netaddress.h index c697b7e0a3..c63bd4b4e5 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -540,6 +540,11 @@ public: uint16_t GetPort() const; bool GetSockAddr(struct sockaddr* paddr, socklen_t* addrlen) const; bool SetSockAddr(const struct sockaddr* paddr); + /** + * Get the address family + * @returns AF_UNSPEC if unspecified + */ + [[nodiscard]] sa_family_t GetSAFamily() const; friend bool operator==(const CService& a, const CService& b); friend bool operator!=(const CService& a, const CService& b) { return !(a == b); } friend bool operator<(const CService& a, const CService& b); diff --git a/src/netbase.cpp b/src/netbase.cpp index 9fbd9f7dea..973d888722 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -444,18 +444,13 @@ bool Socks5(const std::string& strDest, uint16_t port, const ProxyCredentials* a } } -std::unique_ptr<Sock> CreateSockTCP(const CService& address_family) +std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family) { - // Create a sockaddr from the specified service. - struct sockaddr_storage sockaddr; - socklen_t len = sizeof(sockaddr); - if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { - LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToStringAddrPort()); - return nullptr; - } + // Not IPv4 or IPv6 + if (address_family == AF_UNSPEC) return nullptr; // Create a TCP socket in the address family of the specified service. - SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); + SOCKET hSocket = socket(address_family, SOCK_STREAM, IPPROTO_TCP); if (hSocket == INVALID_SOCKET) { return nullptr; } @@ -493,7 +488,7 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family) return sock; } -std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP; +std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock = CreateSockOS; template<typename... Args> static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { diff --git a/src/netbase.h b/src/netbase.h index 1bd95ba0d9..6f2d86f153 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -229,16 +229,16 @@ CService LookupNumeric(const std::string& name, uint16_t portDefault = 0, DNSLoo CSubNet LookupSubNet(const std::string& subnet_str); /** - * Create a TCP socket in the given address family. - * @param[in] address_family The socket is created in the same address family as this address. + * Create a socket in the given address family. + * @param[in] address_family to use for the socket. * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure */ -std::unique_ptr<Sock> CreateSockTCP(const CService& address_family); +std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family); /** - * Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests. + * Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests. */ -extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock; +extern std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock; /** * Try to connect to the specified service on the specified socket. diff --git a/src/test/fuzz/fuzz.cpp b/src/test/fuzz/fuzz.cpp index 6de480ff15..d1fb31644d 100644 --- a/src/test/fuzz/fuzz.cpp +++ b/src/test/fuzz/fuzz.cpp @@ -81,7 +81,7 @@ static const TypeTestOneInput* g_test_one_input{nullptr}; void initialize() { // Terminate immediately if a fuzzing harness ever tries to create a TCP socket. - CreateSock = [](const CService&) -> std::unique_ptr<Sock> { std::terminate(); }; + CreateSock = [](const sa_family_t&) -> std::unique_ptr<Sock> { std::terminate(); }; // Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup. g_dns_lookup = [](const std::string& name, bool allow_lookup) { diff --git a/src/test/i2p_tests.cpp b/src/test/i2p_tests.cpp index f80f07d190..ef1d0b659c 100644 --- a/src/test/i2p_tests.cpp +++ b/src/test/i2p_tests.cpp @@ -38,7 +38,7 @@ public: private: const BCLog::Level m_prev_log_level; - const std::function<std::unique_ptr<Sock>(const CService&)> m_create_sock_orig; + const std::function<std::unique_ptr<Sock>(const sa_family_t&)> m_create_sock_orig; }; BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup) @@ -46,7 +46,7 @@ BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup) BOOST_AUTO_TEST_CASE(unlimited_recv) { // Mock CreateSock() to create MockSock. - CreateSock = [](const CService&) { + CreateSock = [](const sa_family_t&) { return std::make_unique<StaticContentsSock>(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a')); }; @@ -66,7 +66,7 @@ BOOST_AUTO_TEST_CASE(unlimited_recv) BOOST_AUTO_TEST_CASE(listen_ok_accept_fail) { size_t num_sockets{0}; - CreateSock = [&num_sockets](const CService&) { + CreateSock = [&num_sockets](const sa_family_t&) { // clang-format off ++num_sockets; // First socket is the control socket for creating the session. @@ -130,7 +130,7 @@ BOOST_AUTO_TEST_CASE(damaged_private_key) { const auto CreateSockOrig = CreateSock; - CreateSock = [](const CService&) { + CreateSock = [](const sa_family_t&) { return std::make_unique<StaticContentsSock>("HELLO REPLY RESULT=OK VERSION=3.1\n" "SESSION STATUS RESULT=OK DESTINATION=\n"); }; |