diff options
-rw-r--r-- | src/net.cpp | 24 | ||||
-rw-r--r-- | src/net.h | 9 | ||||
-rw-r--r-- | src/test/DoS_tests.cpp | 10 | ||||
-rw-r--r-- | src/test/net_tests.cpp | 5 |
4 files changed, 25 insertions, 23 deletions
diff --git a/src/net.cpp b/src/net.cpp index eb312ef1e4..8bc8ecc436 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -89,9 +89,6 @@ std::string strSubVersion; limitedmap<uint256, int64_t> mapAlreadyAskedFor(MAX_INV_SZ); -NodeId nLastNodeId = 0; -CCriticalSection cs_nLastNodeId; - static CSemaphore *semOutbound = NULL; boost::condition_variable messageHandlerCondition; @@ -404,7 +401,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addrman.Attempt(addrConnect, fCountFailure); // Add node - CNode* pnode = new CNode(hSocket, addrConnect, pszDest ? pszDest : "", false); + CNode* pnode = new CNode(GetNewNodeId(), hSocket, addrConnect, pszDest ? pszDest : "", false); GetNodeSignals().InitializeNode(pnode->GetId(), pnode); pnode->AddRef(); @@ -1038,7 +1035,7 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { } } - CNode* pnode = new CNode(hSocket, addr, "", true); + CNode* pnode = new CNode(GetNewNodeId(), hSocket, addr, "", true); GetNodeSignals().InitializeNode(pnode->GetId(), pnode); pnode->AddRef(); pnode->fWhitelisted = whitelisted; @@ -2030,6 +2027,7 @@ CConnman::CConnman() { setBannedIsDirty = false; fAddressesInitialized = false; + nLastNodeId = 0; } bool StartNode(CConnman& connman, boost::thread_group& threadGroup, CScheduler& scheduler, std::string& strNodeError) @@ -2041,9 +2039,13 @@ bool StartNode(CConnman& connman, boost::thread_group& threadGroup, CScheduler& return ret; } -bool CConnman::Start(boost::thread_group& threadGroup, CScheduler& scheduler, std::string& strNodeError) +NodeId CConnman::GetNewNodeId() { + return nLastNodeId.fetch_add(1, std::memory_order_relaxed); +} +bool CConnman::Start(boost::thread_group& threadGroup, CScheduler& scheduler, std::string& strNodeError) +{ uiInterface.InitMessage(_("Loading addresses...")); // Load addresses from peers.dat int64_t nStart = GetTimeMillis(); @@ -2089,7 +2091,7 @@ bool CConnman::Start(boost::thread_group& threadGroup, CScheduler& scheduler, st if (pnodeLocalHost == NULL) { CNetAddr local; LookupHost("127.0.0.1", local, false); - pnodeLocalHost = new CNode(INVALID_SOCKET, CAddress(CService(local, 0), nLocalServices)); + pnodeLocalHost = new CNode(GetNewNodeId(), INVALID_SOCKET, CAddress(CService(local, 0), nLocalServices)); GetNodeSignals().InitializeNode(pnodeLocalHost->GetId(), pnodeLocalHost); } @@ -2478,7 +2480,7 @@ void CNode::Fuzz(int nChance) unsigned int ReceiveFloodSize() { return 1000*GetArg("-maxreceivebuffer", DEFAULT_MAXRECEIVEBUFFER); } unsigned int SendBufferSize() { return 1000*GetArg("-maxsendbuffer", DEFAULT_MAXSENDBUFFER); } -CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNameIn, bool fInboundIn) : +CNode::CNode(NodeId idIn, SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNameIn, bool fInboundIn) : ssSend(SER_NETWORK, INIT_PROTO_VERSION), addr(addrIn), nKeyedNetGroup(CalculateKeyedNetGroup(addrIn)), @@ -2531,16 +2533,12 @@ CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNa minFeeFilter = 0; lastSentFeeFilter = 0; nextSendTimeFeeFilter = 0; + id = idIn; BOOST_FOREACH(const std::string &msg, getAllNetMessageTypes()) mapRecvBytesPerMsgCmd[msg] = 0; mapRecvBytesPerMsgCmd[NET_MESSAGE_COMMAND_OTHER] = 0; - { - LOCK(cs_nLastNodeId); - id = nLastNodeId++; - } - if (fLogIPs) LogPrint("net", "Added connection to %s peer=%d\n", addrName, id); else @@ -196,6 +196,9 @@ private: bool IsWhitelistedRange(const CNetAddr &addr); void DeleteNode(CNode* pnode); + + NodeId GetNewNodeId(); + //!check is the banlist has unwritten changes bool BannedSetIsDirty(); //!set the "dirty" flag for the banlist @@ -223,6 +226,7 @@ private: CCriticalSection cs_vAddedNodes; std::vector<CNode*> vNodes; mutable CCriticalSection cs_vNodes; + std::atomic<NodeId> nLastNodeId; }; extern std::unique_ptr<CConnman> g_connman; void MapPort(bool fUseUPnP); @@ -300,9 +304,6 @@ extern int nMaxConnections; extern limitedmap<uint256, int64_t> mapAlreadyAskedFor; -extern NodeId nLastNodeId; -extern CCriticalSection cs_nLastNodeId; - /** Subversion as sent to the P2P network in `version` messages */ extern std::string strSubVersion; @@ -501,7 +502,7 @@ public: CAmount lastSentFeeFilter; int64_t nextSendTimeFeeFilter; - CNode(SOCKET hSocketIn, const CAddress &addrIn, const std::string &addrNameIn = "", bool fInboundIn = false); + CNode(NodeId id, SOCKET hSocketIn, const CAddress &addrIn, const std::string &addrNameIn = "", bool fInboundIn = false); ~CNode(); private: diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index 412f94f40d..652a8e2cea 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -40,13 +40,15 @@ CService ip(uint32_t i) return CService(CNetAddr(s), Params().GetDefaultPort()); } +static NodeId id = 0; + BOOST_FIXTURE_TEST_SUITE(DoS_tests, TestingSetup) BOOST_AUTO_TEST_CASE(DoS_banning) { connman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(INVALID_SOCKET, addr1, "", true); + CNode dummyNode1(id++, INVALID_SOCKET, addr1, "", true); GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); // Should get banned @@ -55,7 +57,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) BOOST_CHECK(!connman->IsBanned(ip(0xa0b0c001|0x0000ff00))); // Different IP, not banned CAddress addr2(ip(0xa0b0c002), NODE_NONE); - CNode dummyNode2(INVALID_SOCKET, addr2, "", true); + CNode dummyNode2(id++, INVALID_SOCKET, addr2, "", true); GetNodeSignals().InitializeNode(dummyNode2.GetId(), &dummyNode2); dummyNode2.nVersion = 1; Misbehaving(dummyNode2.GetId(), 50); @@ -72,7 +74,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) connman->ClearBanned(); mapArgs["-banscore"] = "111"; // because 11 is my favorite number CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(INVALID_SOCKET, addr1, "", true); + CNode dummyNode1(id++, INVALID_SOCKET, addr1, "", true); GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); @@ -94,7 +96,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) SetMockTime(nStartTime); // Overrides future calls to GetTime() CAddress addr(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode(INVALID_SOCKET, addr, "", true); + CNode dummyNode(id++, INVALID_SOCKET, addr, "", true); GetNodeSignals().InitializeNode(dummyNode.GetId(), &dummyNode); dummyNode.nVersion = 1; diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 267d1b55e1..00fb757167 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -153,6 +153,7 @@ BOOST_AUTO_TEST_CASE(caddrdb_read_corrupted) BOOST_AUTO_TEST_CASE(cnode_simple_test) { SOCKET hSocket = INVALID_SOCKET; + NodeId id = 0; in_addr ipv4Addr; ipv4Addr.s_addr = 0xa0b0c001; @@ -162,12 +163,12 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) bool fInboundIn = false; // Test that fFeeler is false by default. - CNode* pnode1 = new CNode(hSocket, addr, pszDest, fInboundIn); + CNode* pnode1 = new CNode(id++, hSocket, addr, pszDest, fInboundIn); BOOST_CHECK(pnode1->fInbound == false); BOOST_CHECK(pnode1->fFeeler == false); fInboundIn = true; - CNode* pnode2 = new CNode(hSocket, addr, pszDest, fInboundIn); + CNode* pnode2 = new CNode(id++, hSocket, addr, pszDest, fInboundIn); BOOST_CHECK(pnode2->fInbound == true); BOOST_CHECK(pnode2->fFeeler == false); } |