aboutsummaryrefslogtreecommitdiff
path: root/src/netbase.cpp
diff options
context:
space:
mode:
authorVasil Dimov <vd@FreeBSD.org>2021-03-05 17:01:59 +0100
committerVasil Dimov <vd@FreeBSD.org>2021-03-16 13:58:23 +0100
commit82d360b5a88d9057b6c09b61cd69e426c7a2412d (patch)
tree8a60f973c2155fa870c30e7919e869cdaee36d0c /src/netbase.cpp
parentb5861100f85fef77b00f55dcdf01ffb4a2a112d8 (diff)
downloadbitcoin-82d360b5a88d9057b6c09b61cd69e426c7a2412d.tar.xz
net: change ConnectSocketDirectly() to take a Sock argument
Change `ConnectSocketDirectly()` to take a `Sock` argument instead of a bare `SOCKET`. With this, use the `Sock`'s (possibly mocked) methods `Connect()`, `Wait()` and `GetSockOpt()` instead of calling the OS functions directly.
Diffstat (limited to 'src/netbase.cpp')
-rw-r--r--src/netbase.cpp59
1 files changed, 23 insertions, 36 deletions
diff --git a/src/netbase.cpp b/src/netbase.cpp
index ac2392ebed..462fa719bc 100644
--- a/src/netbase.cpp
+++ b/src/netbase.cpp
@@ -537,12 +537,12 @@ static void LogConnectFailure(bool manual_connection, const char* fmt, const Arg
}
}
-bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, int nTimeout, bool manual_connection)
+bool ConnectSocketDirectly(const CService &addrConnect, const Sock& sock, int nTimeout, bool manual_connection)
{
// Create a sockaddr from the specified service.
struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr);
- if (hSocket == INVALID_SOCKET) {
+ if (sock.Get() == INVALID_SOCKET) {
LogPrintf("Cannot connect to %s: invalid socket\n", addrConnect.ToString());
return false;
}
@@ -552,8 +552,7 @@ bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, i
}
// Connect to the addrConnect service on the hSocket socket.
- if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
- {
+ if (sock.Connect(reinterpret_cast<struct sockaddr*>(&sockaddr), len) == SOCKET_ERROR) {
int nErr = WSAGetLastError();
// WSAEINVAL is here because some legacy version of winsock uses it
if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL)
@@ -561,46 +560,34 @@ bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, i
// Connection didn't actually fail, but is being established
// asynchronously. Thus, use async I/O api (select/poll)
// synchronously to check for successful connection with a timeout.
-#ifdef USE_POLL
- struct pollfd pollfd = {};
- pollfd.fd = hSocket;
- pollfd.events = POLLIN | POLLOUT;
- int nRet = poll(&pollfd, 1, nTimeout);
-#else
- struct timeval timeout = MillisToTimeval(nTimeout);
- fd_set fdset;
- FD_ZERO(&fdset);
- FD_SET(hSocket, &fdset);
- int nRet = select(hSocket + 1, nullptr, &fdset, nullptr, &timeout);
-#endif
- // Upon successful completion, both select and poll return the total
- // number of file descriptors that have been selected. A value of 0
- // indicates that the call timed out and no file descriptors have
- // been selected.
- if (nRet == 0)
- {
- LogPrint(BCLog::NET, "connection to %s timeout\n", addrConnect.ToString());
+ const Sock::Event requested = Sock::RECV | Sock::SEND;
+ Sock::Event occurred;
+ if (!sock.Wait(std::chrono::milliseconds{nTimeout}, requested, &occurred)) {
+ LogPrintf("wait for connect to %s failed: %s\n",
+ addrConnect.ToString(),
+ NetworkErrorString(WSAGetLastError()));
return false;
- }
- if (nRet == SOCKET_ERROR)
- {
- LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError()));
+ } else if (occurred == 0) {
+ LogPrint(BCLog::NET, "connection attempt to %s timed out\n", addrConnect.ToString());
return false;
}
- // Even if the select/poll was successful, the connect might not
+ // Even if the wait was successful, the connect might not
// have been successful. The reason for this failure is hidden away
// in the SO_ERROR for the socket in modern systems. We read it into
- // nRet here.
- socklen_t nRetSize = sizeof(nRet);
- if (getsockopt(hSocket, SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&nRet, &nRetSize) == SOCKET_ERROR)
- {
+ // sockerr here.
+ int sockerr;
+ socklen_t sockerr_len = sizeof(sockerr);
+ if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&sockerr, &sockerr_len) ==
+ SOCKET_ERROR) {
LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError()));
return false;
}
- if (nRet != 0)
- {
- LogConnectFailure(manual_connection, "connect() to %s failed after select(): %s", addrConnect.ToString(), NetworkErrorString(nRet));
+ if (sockerr != 0) {
+ LogConnectFailure(manual_connection,
+ "connect() to %s failed after wait: %s",
+ addrConnect.ToString(),
+ NetworkErrorString(sockerr));
return false;
}
}
@@ -668,7 +655,7 @@ bool IsProxy(const CNetAddr &addr) {
bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed)
{
// first connect to proxy server
- if (!ConnectSocketDirectly(proxy.proxy, sock.Get(), nTimeout, true)) {
+ if (!ConnectSocketDirectly(proxy.proxy, sock, nTimeout, true)) {
outProxyConnectionFailed = true;
return false;
}