From c6b4bfb4b3507f1a62290869d7435b0f54032104 Mon Sep 17 00:00:00 2001 From: practicalswift Date: Mon, 29 Jun 2020 09:44:12 +0000 Subject: net: Make DNS lookup code testable --- src/netbase.cpp | 116 +++++++++++++++++++++++++++++--------------------------- src/netbase.h | 21 +++++++--- 2 files changed, 76 insertions(+), 61 deletions(-) (limited to 'src') diff --git a/src/netbase.cpp b/src/netbase.cpp index 88c36ed86c..b95bb05e71 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -42,6 +42,50 @@ bool fNameLookup = DEFAULT_NAME_LOOKUP; int g_socks5_recv_timeout = 20 * 1000; static std::atomic interruptSocks5Recv(false); +std::vector WrappedGetAddrInfo(const std::string& name, bool allow_lookup) +{ + addrinfo ai_hint{}; + // We want a TCP port, which is a streaming socket type + ai_hint.ai_socktype = SOCK_STREAM; + ai_hint.ai_protocol = IPPROTO_TCP; + // We don't care which address family (IPv4 or IPv6) is returned + ai_hint.ai_family = AF_UNSPEC; + // If we allow lookups of hostnames, use the AI_ADDRCONFIG flag to only + // return addresses whose family we have an address configured for. + // + // If we don't allow lookups, then use the AI_NUMERICHOST flag for + // getaddrinfo to only decode numerical network addresses and suppress + // hostname lookups. + ai_hint.ai_flags = allow_lookup ? AI_ADDRCONFIG : AI_NUMERICHOST; + + addrinfo* ai_res{nullptr}; + const int n_err{getaddrinfo(name.c_str(), nullptr, &ai_hint, &ai_res)}; + if (n_err != 0) { + return {}; + } + + // Traverse the linked list starting with ai_trav. + addrinfo* ai_trav{ai_res}; + std::vector resolved_addresses; + while (ai_trav != nullptr) { + if (ai_trav->ai_family == AF_INET) { + assert(ai_trav->ai_addrlen >= sizeof(sockaddr_in)); + resolved_addresses.emplace_back(reinterpret_cast(ai_trav->ai_addr)->sin_addr); + } + if (ai_trav->ai_family == AF_INET6) { + assert(ai_trav->ai_addrlen >= sizeof(sockaddr_in6)); + const sockaddr_in6* s6{reinterpret_cast(ai_trav->ai_addr)}; + resolved_addresses.emplace_back(s6->sin6_addr, s6->sin6_scope_id); + } + ai_trav = ai_trav->ai_next; + } + freeaddrinfo(ai_res); + + return resolved_addresses; +} + +DNSLookupFn g_dns_lookup{WrappedGetAddrInfo}; + enum Network ParseNetwork(const std::string& net_in) { std::string net = ToLower(net_in); if (net == "ipv4") return NET_IPV4; @@ -87,7 +131,7 @@ std::vector GetNetworkNames(bool append_unroutable) return names; } -bool static LookupIntern(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup) +static bool LookupIntern(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup, DNSLookupFn dns_lookup_function) { vIP.clear(); @@ -109,54 +153,16 @@ bool static LookupIntern(const std::string& name, std::vector& vIP, un } } - struct addrinfo aiHint; - memset(&aiHint, 0, sizeof(struct addrinfo)); - - // We want a TCP port, which is a streaming socket type - aiHint.ai_socktype = SOCK_STREAM; - aiHint.ai_protocol = IPPROTO_TCP; - // We don't care which address family (IPv4 or IPv6) is returned - aiHint.ai_family = AF_UNSPEC; - // If we allow lookups of hostnames, use the AI_ADDRCONFIG flag to only - // return addresses whose family we have an address configured for. - // - // If we don't allow lookups, then use the AI_NUMERICHOST flag for - // getaddrinfo to only decode numerical network addresses and suppress - // hostname lookups. - aiHint.ai_flags = fAllowLookup ? AI_ADDRCONFIG : AI_NUMERICHOST; - struct addrinfo *aiRes = nullptr; - int nErr = getaddrinfo(name.c_str(), nullptr, &aiHint, &aiRes); - if (nErr) - return false; - - // Traverse the linked list starting with aiTrav, add all non-internal - // IPv4,v6 addresses to vIP while respecting nMaxSolutions. - struct addrinfo *aiTrav = aiRes; - while (aiTrav != nullptr && (nMaxSolutions == 0 || vIP.size() < nMaxSolutions)) - { - CNetAddr resolved; - if (aiTrav->ai_family == AF_INET) - { - assert(aiTrav->ai_addrlen >= sizeof(sockaddr_in)); - resolved = CNetAddr(((struct sockaddr_in*)(aiTrav->ai_addr))->sin_addr); - } - - if (aiTrav->ai_family == AF_INET6) - { - assert(aiTrav->ai_addrlen >= sizeof(sockaddr_in6)); - struct sockaddr_in6* s6 = (struct sockaddr_in6*) aiTrav->ai_addr; - resolved = CNetAddr(s6->sin6_addr, s6->sin6_scope_id); + for (const CNetAddr& resolved : dns_lookup_function(name, fAllowLookup)) { + if (nMaxSolutions > 0 && vIP.size() >= nMaxSolutions) { + break; } /* Never allow resolving to an internal address. Consider any such result invalid */ if (!resolved.IsInternal()) { vIP.push_back(resolved); } - - aiTrav = aiTrav->ai_next; } - freeaddrinfo(aiRes); - return (vIP.size() > 0); } @@ -175,7 +181,7 @@ bool static LookupIntern(const std::string& name, std::vector& vIP, un * @see Lookup(const char *, std::vector&, int, bool, unsigned int) * for additional parameter descriptions. */ -bool LookupHost(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup) +bool LookupHost(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup, DNSLookupFn dns_lookup_function) { if (!ValidAsCString(name)) { return false; @@ -187,7 +193,7 @@ bool LookupHost(const std::string& name, std::vector& vIP, unsigned in strHost = strHost.substr(1, strHost.size() - 2); } - return LookupIntern(strHost, vIP, nMaxSolutions, fAllowLookup); + return LookupIntern(strHost, vIP, nMaxSolutions, fAllowLookup, dns_lookup_function); } /** @@ -196,13 +202,13 @@ bool LookupHost(const std::string& name, std::vector& vIP, unsigned in * @see LookupHost(const std::string&, std::vector&, unsigned int, bool) for * additional parameter descriptions. */ -bool LookupHost(const std::string& name, CNetAddr& addr, bool fAllowLookup) +bool LookupHost(const std::string& name, CNetAddr& addr, bool fAllowLookup, DNSLookupFn dns_lookup_function) { if (!ValidAsCString(name)) { return false; } std::vector vIP; - LookupHost(name, vIP, 1, fAllowLookup); + LookupHost(name, vIP, 1, fAllowLookup, dns_lookup_function); if(vIP.empty()) return false; addr = vIP.front(); @@ -229,7 +235,7 @@ bool LookupHost(const std::string& name, CNetAddr& addr, bool fAllowLookup) * @returns Whether or not the service string successfully resolved to any * resulting services. */ -bool Lookup(const std::string& name, std::vector& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions) +bool Lookup(const std::string& name, std::vector& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions, DNSLookupFn dns_lookup_function) { if (name.empty() || !ValidAsCString(name)) { return false; @@ -239,7 +245,7 @@ bool Lookup(const std::string& name, std::vector& vAddr, int portDefau SplitHostPort(name, port, hostname); std::vector vIP; - bool fRet = LookupIntern(hostname, vIP, nMaxSolutions, fAllowLookup); + bool fRet = LookupIntern(hostname, vIP, nMaxSolutions, fAllowLookup, dns_lookup_function); if (!fRet) return false; vAddr.resize(vIP.size()); @@ -254,13 +260,13 @@ bool Lookup(const std::string& name, std::vector& vAddr, int portDefau * @see Lookup(const char *, std::vector&, int, bool, unsigned int) * for additional parameter descriptions. */ -bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllowLookup) +bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllowLookup, DNSLookupFn dns_lookup_function) { if (!ValidAsCString(name)) { return false; } std::vector vService; - bool fRet = Lookup(name, vService, portDefault, fAllowLookup, 1); + bool fRet = Lookup(name, vService, portDefault, fAllowLookup, 1, dns_lookup_function); if (!fRet) return false; addr = vService[0]; @@ -277,7 +283,7 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo * @see Lookup(const char *, CService&, int, bool) for additional parameter * descriptions. */ -CService LookupNumeric(const std::string& name, int portDefault) +CService LookupNumeric(const std::string& name, int portDefault, DNSLookupFn dns_lookup_function) { if (!ValidAsCString(name)) { return {}; @@ -285,7 +291,7 @@ CService LookupNumeric(const std::string& name, int portDefault) CService addr; // "1.2:345" will fail to resolve the ip, but will still set the port. // If the ip fails to resolve, re-init the result. - if(!Lookup(name, addr, portDefault, false)) + if(!Lookup(name, addr, portDefault, false, dns_lookup_function)) addr = CService(); return addr; } @@ -811,7 +817,7 @@ bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int * * @returns Whether the operation succeeded or not. */ -bool LookupSubNet(const std::string& strSubnet, CSubNet& ret) +bool LookupSubNet(const std::string& strSubnet, CSubNet& ret, DNSLookupFn dns_lookup_function) { if (!ValidAsCString(strSubnet)) { return false; @@ -822,7 +828,7 @@ bool LookupSubNet(const std::string& strSubnet, CSubNet& ret) std::string strAddress = strSubnet.substr(0, slash); // TODO: Use LookupHost(const std::string&, CNetAddr&, bool) instead to just get // one CNetAddr. - if (LookupHost(strAddress, vIP, 1, false)) + if (LookupHost(strAddress, vIP, 1, false, dns_lookup_function)) { CNetAddr network = vIP[0]; if (slash != strSubnet.npos) @@ -837,7 +843,7 @@ bool LookupSubNet(const std::string& strSubnet, CSubNet& ret) else // If not a valid number, try full netmask syntax { // Never allow lookup for netmask - if (LookupHost(strNetmask, vIP, 1, false)) { + if (LookupHost(strNetmask, vIP, 1, false, dns_lookup_function)) { ret = CSubNet(network, vIP[0]); return ret.IsValid(); } diff --git a/src/netbase.h b/src/netbase.h index 751f7eb3f0..227da1a63b 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -64,6 +64,11 @@ struct ProxyCredentials std::string password; }; +/** + * Wrapper for getaddrinfo(3). Do not use directly: call Lookup/LookupHost/LookupNumeric/LookupSubNet. + */ +std::vector WrappedGetAddrInfo(const std::string& name, bool allow_lookup); + enum Network ParseNetwork(const std::string& net); std::string GetNetworkName(enum Network net); /** Return a vector of publicly routable Network names; optionally append NET_UNROUTABLE. */ @@ -74,12 +79,16 @@ bool IsProxy(const CNetAddr &addr); bool SetNameProxy(const proxyType &addrProxy); bool HaveNameProxy(); bool GetNameProxy(proxyType &nameProxyOut); -bool LookupHost(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup); -bool LookupHost(const std::string& name, CNetAddr& addr, bool fAllowLookup); -bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllowLookup); -bool Lookup(const std::string& name, std::vector& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions); -CService LookupNumeric(const std::string& name, int portDefault = 0); -bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet); + +using DNSLookupFn = std::function(const std::string&, bool)>; +extern DNSLookupFn g_dns_lookup; + +bool LookupHost(const std::string& name, std::vector& vIP, unsigned int nMaxSolutions, bool fAllowLookup, DNSLookupFn dns_lookup_function = g_dns_lookup); +bool LookupHost(const std::string& name, CNetAddr& addr, bool fAllowLookup, DNSLookupFn dns_lookup_function = g_dns_lookup); +bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllowLookup, DNSLookupFn dns_lookup_function = g_dns_lookup); +bool Lookup(const std::string& name, std::vector& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions, DNSLookupFn dns_lookup_function = g_dns_lookup); +CService LookupNumeric(const std::string& name, int portDefault = 0, DNSLookupFn dns_lookup_function = g_dns_lookup); +bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet, DNSLookupFn dns_lookup_function = g_dns_lookup); /** * Create a TCP socket in the given address family. -- cgit v1.2.3