diff options
-rw-r--r-- | src/net.cpp | 4 | ||||
-rw-r--r-- | src/net.h | 4 | ||||
-rw-r--r-- | src/net_processing.cpp | 33 | ||||
-rw-r--r-- | src/net_processing.h | 5 | ||||
-rw-r--r-- | src/test/DoS_tests.cpp | 20 |
5 files changed, 39 insertions, 27 deletions
diff --git a/src/net.cpp b/src/net.cpp index a66679cd85..bbfada4301 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1856,7 +1856,7 @@ void CConnman::ThreadMessageHandler() TRY_LOCK(pnode->cs_vRecvMsg, lockRecv); if (lockRecv) { - if (!GetNodeSignals().ProcessMessages(pnode, *this)) + if (!GetNodeSignals().ProcessMessages(pnode, *this, flagInterruptMsgProc)) pnode->CloseSocketDisconnect(); if (pnode->nSendSize < GetSendBufferSize()) @@ -1875,7 +1875,7 @@ void CConnman::ThreadMessageHandler() { TRY_LOCK(pnode->cs_vSend, lockSend); if (lockSend) - GetNodeSignals().SendMessages(pnode, *this); + GetNodeSignals().SendMessages(pnode, *this, flagInterruptMsgProc); } if (flagInterruptMsgProc) return; @@ -460,8 +460,8 @@ struct CombinerAll // Signals for message handling struct CNodeSignals { - boost::signals2::signal<bool (CNode*, CConnman&), CombinerAll> ProcessMessages; - boost::signals2::signal<bool (CNode*, CConnman&), CombinerAll> SendMessages; + boost::signals2::signal<bool (CNode*, CConnman&, std::atomic<bool>&), CombinerAll> ProcessMessages; + boost::signals2::signal<bool (CNode*, CConnman&, std::atomic<bool>&), CombinerAll> SendMessages; boost::signals2::signal<void (CNode*, CConnman&)> InitializeNode; boost::signals2::signal<void (NodeId, bool&)> FinalizeNode; }; diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 380decbcb7..f3a04080d1 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -886,7 +886,7 @@ static void RelayAddress(const CAddress& addr, bool fReachable, CConnman& connma connman.ForEachNodeThen(std::move(sortfunc), std::move(pushfunc)); } -void static ProcessGetData(CNode* pfrom, const Consensus::Params& consensusParams, CConnman& connman) +void static ProcessGetData(CNode* pfrom, const Consensus::Params& consensusParams, CConnman& connman, std::atomic<bool>& interruptMsgProc) { std::deque<CInv>::iterator it = pfrom->vRecvGetData.begin(); unsigned int nMaxSendBufferSize = connman.GetSendBufferSize(); @@ -901,7 +901,9 @@ void static ProcessGetData(CNode* pfrom, const Consensus::Params& consensusParam const CInv &inv = *it; { - boost::this_thread::interruption_point(); + if (interruptMsgProc) + return; + it++; if (inv.type == MSG_BLOCK || inv.type == MSG_FILTERED_BLOCK || inv.type == MSG_CMPCT_BLOCK || inv.type == MSG_WITNESS_BLOCK) @@ -1055,7 +1057,7 @@ uint32_t GetFetchFlags(CNode* pfrom, CBlockIndex* pprev, const Consensus::Params return nFetchFlags; } -bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, CConnman& connman) +bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, CConnman& connman, std::atomic<bool>& interruptMsgProc) { unsigned int nMaxSendBufferSize = connman.GetSendBufferSize(); @@ -1295,7 +1297,8 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, int64_t nSince = nNow - 10 * 60; BOOST_FOREACH(CAddress& addr, vAddr) { - boost::this_thread::interruption_point(); + if (interruptMsgProc) + return true; if ((addr.nServices & REQUIRED_SERVICES) != REQUIRED_SERVICES) continue; @@ -1377,7 +1380,8 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, { CInv &inv = vInv[nInv]; - boost::this_thread::interruption_point(); + if (interruptMsgProc) + return true; bool fAlreadyHave = AlreadyHave(inv); LogPrint("net", "got inv: %s %s peer=%d\n", inv.ToString(), fAlreadyHave ? "have" : "new", pfrom->id); @@ -1439,7 +1443,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, LogPrint("net", "received getdata for: %s peer=%d\n", vInv[0].ToString(), pfrom->id); pfrom->vRecvGetData.insert(pfrom->vRecvGetData.end(), vInv.begin(), vInv.end()); - ProcessGetData(pfrom, chainparams.GetConsensus(), connman); + ProcessGetData(pfrom, chainparams.GetConsensus(), connman, interruptMsgProc); } @@ -1513,7 +1517,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, inv.type = State(pfrom->GetId())->fWantsCmpctWitness ? MSG_WITNESS_BLOCK : MSG_BLOCK; inv.hash = req.blockhash; pfrom->vRecvGetData.push_back(inv); - ProcessGetData(pfrom, chainparams.GetConsensus(), connman); + ProcessGetData(pfrom, chainparams.GetConsensus(), connman, interruptMsgProc); return true; } @@ -1925,10 +1929,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, } // cs_main if (fProcessBLOCKTXN) - return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, nTimeReceived, chainparams, connman); + return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, nTimeReceived, chainparams, connman, interruptMsgProc); if (fRevertToHeaderProcessing) - return ProcessMessage(pfrom, NetMsgType::HEADERS, vHeadersMsg, nTimeReceived, chainparams, connman); + return ProcessMessage(pfrom, NetMsgType::HEADERS, vHeadersMsg, nTimeReceived, chainparams, connman, interruptMsgProc); if (fBlockReconstructed) { // If we got here, we were able to optimistically reconstruct a @@ -2441,7 +2445,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, } // requires LOCK(cs_vRecvMsg) -bool ProcessMessages(CNode* pfrom, CConnman& connman) +bool ProcessMessages(CNode* pfrom, CConnman& connman, std::atomic<bool>& interruptMsgProc) { const CChainParams& chainparams = Params(); unsigned int nMaxSendBufferSize = connman.GetSendBufferSize(); @@ -2459,7 +2463,7 @@ bool ProcessMessages(CNode* pfrom, CConnman& connman) bool fOk = true; if (!pfrom->vRecvGetData.empty()) - ProcessGetData(pfrom, chainparams.GetConsensus(), connman); + ProcessGetData(pfrom, chainparams.GetConsensus(), connman, interruptMsgProc); // this maintains the order of responses if (!pfrom->vRecvGetData.empty()) return fOk; @@ -2520,8 +2524,9 @@ bool ProcessMessages(CNode* pfrom, CConnman& connman) bool fRet = false; try { - fRet = ProcessMessage(pfrom, strCommand, vRecv, msg.nTime, chainparams, connman); - boost::this_thread::interruption_point(); + fRet = ProcessMessage(pfrom, strCommand, vRecv, msg.nTime, chainparams, connman, interruptMsgProc); + if (interruptMsgProc) + return true; } catch (const std::ios_base::failure& e) { @@ -2585,7 +2590,7 @@ public: } }; -bool SendMessages(CNode* pto, CConnman& connman) +bool SendMessages(CNode* pto, CConnman& connman, std::atomic<bool>& interruptMsgProc) { const Consensus::Params& consensusParams = Params().GetConsensus(); { diff --git a/src/net_processing.h b/src/net_processing.h index 130433cc7c..9e76cad505 100644 --- a/src/net_processing.h +++ b/src/net_processing.h @@ -39,13 +39,14 @@ bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats); void Misbehaving(NodeId nodeid, int howmuch); /** Process protocol messages received from a given node */ -bool ProcessMessages(CNode* pfrom, CConnman& connman); +bool ProcessMessages(CNode* pfrom, CConnman& connman, std::atomic<bool>& interrupt); /** * Send queued protocol messages to be sent to a give node. * * @param[in] pto The node which we are sending messages to. * @param[in] connman The connection manager for that node. + * @param[in] interrupt Interrupt condition for processing threads */ -bool SendMessages(CNode* pto, CConnman& connman); +bool SendMessages(CNode* pto, CConnman& connman, std::atomic<bool>& interrupt); #endif // BITCOIN_NET_PROCESSING_H diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index d90dcaeb02..35edf60a2b 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -47,6 +47,8 @@ BOOST_FIXTURE_TEST_SUITE(DoS_tests, TestingSetup) BOOST_AUTO_TEST_CASE(DoS_banning) { + std::atomic<bool> interruptDummy(false); + connman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 0, 0, "", true); @@ -54,7 +56,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); // Should get banned - SendMessages(&dummyNode1, *connman); + SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr1)); BOOST_CHECK(!connman->IsBanned(ip(0xa0b0c001|0x0000ff00))); // Different IP, not banned @@ -64,16 +66,18 @@ BOOST_AUTO_TEST_CASE(DoS_banning) GetNodeSignals().InitializeNode(&dummyNode2, *connman); dummyNode2.nVersion = 1; Misbehaving(dummyNode2.GetId(), 50); - SendMessages(&dummyNode2, *connman); + SendMessages(&dummyNode2, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr2)); // 2 not banned yet... BOOST_CHECK(connman->IsBanned(addr1)); // ... but 1 still should be Misbehaving(dummyNode2.GetId(), 50); - SendMessages(&dummyNode2, *connman); + SendMessages(&dummyNode2, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr2)); } BOOST_AUTO_TEST_CASE(DoS_banscore) { + std::atomic<bool> interruptDummy(false); + connman->ClearBanned(); ForceSetArg("-banscore", "111"); // because 11 is my favorite number CAddress addr1(ip(0xa0b0c001), NODE_NONE); @@ -82,19 +86,21 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); - SendMessages(&dummyNode1, *connman); + SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr1)); Misbehaving(dummyNode1.GetId(), 10); - SendMessages(&dummyNode1, *connman); + SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr1)); Misbehaving(dummyNode1.GetId(), 1); - SendMessages(&dummyNode1, *connman); + SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr1)); ForceSetArg("-banscore", std::to_string(DEFAULT_BANSCORE_THRESHOLD)); } BOOST_AUTO_TEST_CASE(DoS_bantime) { + std::atomic<bool> interruptDummy(false); + connman->ClearBanned(); int64_t nStartTime = GetTime(); SetMockTime(nStartTime); // Overrides future calls to GetTime() @@ -106,7 +112,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) dummyNode.nVersion = 1; Misbehaving(dummyNode.GetId(), 100); - SendMessages(&dummyNode, *connman); + SendMessages(&dummyNode, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr)); SetMockTime(nStartTime+60*60); |