diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/init.cpp | 1 | ||||
-rw-r--r-- | src/main.cpp | 102 | ||||
-rw-r--r-- | src/main.h | 7 | ||||
-rw-r--r-- | src/net.cpp | 32 | ||||
-rw-r--r-- | src/net.h | 26 | ||||
-rw-r--r-- | src/rpcnet.cpp | 10 | ||||
-rw-r--r-- | src/test/DoS_tests.cpp | 26 | ||||
-rw-r--r-- | src/test/test_bitcoin.cpp | 2 |
8 files changed, 158 insertions, 48 deletions
diff --git a/src/init.cpp b/src/init.cpp index fc15df0594..df3cedc202 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -120,6 +120,7 @@ void Shutdown() GenerateBitcoins(false, NULL, 0); #endif StopNode(); + UnregisterNodeSignals(GetNodeSignals()); { LOCK(cs_main); #ifdef ENABLE_WALLET diff --git a/src/main.cpp b/src/main.cpp index 25201c7367..d130e9705e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -153,17 +153,66 @@ void SyncWithWallets(const uint256 &hash, const CTransaction &tx, const CBlock * // Registration of network node signals. // -int static GetHeight() +namespace { +// Maintain validation-specific state about nodes, protected by cs_main, instead +// by CNode's own locks. This simplifies asynchronous operation, where +// processing of incoming data is done after the ProcessMessage call returns, +// and we're no longer holding the node's locks. +struct CNodeState { + int nMisbehavior; + bool fShouldBan; + std::string name; + + CNodeState() { + nMisbehavior = 0; + fShouldBan = false; + } +}; + +map<NodeId, CNodeState> mapNodeState; + +// Requires cs_main. +CNodeState *State(NodeId pnode) { + map<NodeId, CNodeState>::iterator it = mapNodeState.find(pnode); + if (it == mapNodeState.end()) + return NULL; + return &it->second; +} + +int GetHeight() { LOCK(cs_main); return chainActive.Height(); } +void InitializeNode(NodeId nodeid, const CNode *pnode) { + LOCK(cs_main); + CNodeState &state = mapNodeState.insert(std::make_pair(nodeid, CNodeState())).first->second; + state.name = pnode->addrName; +} + +void FinalizeNode(NodeId nodeid) { + LOCK(cs_main); + mapNodeState.erase(nodeid); +} +} + +bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats) { + LOCK(cs_main); + CNodeState *state = State(nodeid); + if (state == NULL) + return false; + stats.nMisbehavior = state->nMisbehavior; + return true; +} + void RegisterNodeSignals(CNodeSignals& nodeSignals) { nodeSignals.GetHeight.connect(&GetHeight); nodeSignals.ProcessMessages.connect(&ProcessMessages); nodeSignals.SendMessages.connect(&SendMessages); + nodeSignals.InitializeNode.connect(&InitializeNode); + nodeSignals.FinalizeNode.connect(&FinalizeNode); } void UnregisterNodeSignals(CNodeSignals& nodeSignals) @@ -171,6 +220,8 @@ void UnregisterNodeSignals(CNodeSignals& nodeSignals) nodeSignals.GetHeight.disconnect(&GetHeight); nodeSignals.ProcessMessages.disconnect(&ProcessMessages); nodeSignals.SendMessages.disconnect(&SendMessages); + nodeSignals.InitializeNode.disconnect(&InitializeNode); + nodeSignals.FinalizeNode.disconnect(&FinalizeNode); } ////////////////////////////////////////////////////////////////////////////// @@ -2915,6 +2966,23 @@ bool static AlreadyHave(const CInv& inv) } +void Misbehaving(NodeId pnode, int howmuch) +{ + if (howmuch == 0) + return; + + CNodeState *state = State(pnode); + if (state == NULL) + return; + + state->nMisbehavior += howmuch; + if (state->nMisbehavior >= GetArg("-banscore", 100)) + { + LogPrintf("Misbehaving: %s (%d -> %d) BAN THRESHOLD EXCEEDED\n", state->name.c_str(), state->nMisbehavior-howmuch, state->nMisbehavior); + state->fShouldBan = true; + } else + LogPrintf("Misbehaving: %s (%d -> %d)\n", state->name.c_str(), state->nMisbehavior-howmuch, state->nMisbehavior); +} void static ProcessGetData(CNode* pfrom) { @@ -3048,7 +3116,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) if (pfrom->nVersion != 0) { pfrom->PushMessage("reject", strCommand, REJECT_DUPLICATE, string("Duplicate version message")); - pfrom->Misbehaving(1); + Misbehaving(pfrom->GetId(), 1); return false; } @@ -3153,7 +3221,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) else if (pfrom->nVersion == 0) { // Must have a version message before anything else - pfrom->Misbehaving(1); + Misbehaving(pfrom->GetId(), 1); return false; } @@ -3174,7 +3242,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) return true; if (vAddr.size() > 1000) { - pfrom->Misbehaving(20); + Misbehaving(pfrom->GetId(), 20); return error("message addr size() = %"PRIszu"", vAddr.size()); } @@ -3237,7 +3305,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) vRecv >> vInv; if (vInv.size() > MAX_INV_SZ) { - pfrom->Misbehaving(20); + Misbehaving(pfrom->GetId(), 20); return error("message inv size() = %"PRIszu"", vInv.size()); } @@ -3288,7 +3356,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) vRecv >> vInv; if (vInv.size() > MAX_INV_SZ) { - pfrom->Misbehaving(20); + Misbehaving(pfrom->GetId(), 20); return error("message getdata size() = %"PRIszu"", vInv.size()); } @@ -3461,7 +3529,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) pfrom->PushMessage("reject", strCommand, state.GetRejectCode(), state.GetRejectReason(), inv.hash); if (nDoS > 0) - pfrom->Misbehaving(nDoS); + Misbehaving(pfrom->GetId(), nDoS); } } @@ -3488,7 +3556,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) pfrom->PushMessage("reject", strCommand, state.GetRejectCode(), state.GetRejectReason(), inv.hash); if (nDoS > 0) - pfrom->Misbehaving(nDoS); + Misbehaving(pfrom->GetId(), nDoS); } } @@ -3631,7 +3699,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) // This isn't a Misbehaving(100) (immediate ban) because the // peer might be an older or different implementation with // a different signature key, etc. - pfrom->Misbehaving(10); + Misbehaving(pfrom->GetId(), 10); } } } @@ -3644,7 +3712,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) if (!filter.IsWithinSizeConstraints()) // There is no excuse for sending a too-large filter - pfrom->Misbehaving(100); + Misbehaving(pfrom->GetId(), 100); else { LOCK(pfrom->cs_filter); @@ -3665,13 +3733,13 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) // and thus, the maximum size any matched object can have) in a filteradd message if (vData.size() > MAX_SCRIPT_ELEMENT_SIZE) { - pfrom->Misbehaving(100); + Misbehaving(pfrom->GetId(), 100); } else { LOCK(pfrom->cs_filter); if (pfrom->pfilter) pfrom->pfilter->insert(vData); else - pfrom->Misbehaving(100); + Misbehaving(pfrom->GetId(), 100); } } @@ -3936,6 +4004,16 @@ bool SendMessages(CNode* pto, bool fSendTrickle) if (!lockMain) return true; + if (State(pto->GetId())->fShouldBan) { + if (pto->addr.IsLocal()) + LogPrintf("Warning: not banning local node %s!\n", pto->addr.ToString().c_str()); + else { + pto->fDisconnect = true; + CNode::Ban(pto->addr); + } + State(pto->GetId())->fShouldBan = false; + } + // Start block sync if (pto->fStartSync && !fImporting && !fReindex) { pto->fStartSync = false; diff --git a/src/main.h b/src/main.h index c4e1839443..c52f37cc87 100644 --- a/src/main.h +++ b/src/main.h @@ -110,6 +110,7 @@ class CTxUndo; class CScriptCheck; class CValidationState; class CWalletInterface; +struct CNodeStateStats; struct CBlockTemplate; @@ -182,6 +183,8 @@ CBlockIndex * InsertBlockIndex(uint256 hash); bool VerifySignature(const CCoins& txFrom, const CTransaction& txTo, unsigned int nIn, unsigned int flags, int nHashType); /** Abort with a message */ bool AbortNode(const std::string &msg); +/** Get statistics from node state */ +bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats); /** (try to) add transaction to memory pool **/ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransaction &tx, bool fLimitFree, @@ -194,6 +197,10 @@ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransa +struct CNodeStateStats { + int nMisbehavior; +}; + struct CDiskBlockPos { int nFile; diff --git a/src/net.cpp b/src/net.cpp index 99a56a0a5d..6ae749c657 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -80,6 +80,9 @@ CCriticalSection cs_setservAddNodeAddresses; vector<std::string> vAddedNodes; CCriticalSection cs_vAddedNodes; +NodeId nLastNodeId = 0; +CCriticalSection cs_nLastNodeId; + static CSemaphore *semOutbound = NULL; // Signals for message handling @@ -581,35 +584,21 @@ bool CNode::IsBanned(CNetAddr ip) return fResult; } -bool CNode::Misbehaving(int howmuch) -{ - if (addr.IsLocal()) +bool CNode::Ban(const CNetAddr &addr) { + int64_t banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban { - LogPrintf("Warning: Local node %s misbehaving (delta: %d)!\n", addrName.c_str(), howmuch); - return false; + LOCK(cs_setBanned); + if (setBanned[addr] < banTime) + setBanned[addr] = banTime; } - - nMisbehavior += howmuch; - if (nMisbehavior >= GetArg("-banscore", 100)) - { - int64_t banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban - LogPrintf("Misbehaving: %s (%d -> %d) DISCONNECTING\n", addr.ToString().c_str(), nMisbehavior-howmuch, nMisbehavior); - { - LOCK(cs_setBanned); - if (setBanned[addr] < banTime) - setBanned[addr] = banTime; - } - CloseSocketDisconnect(); - return true; - } else - LogPrintf("Misbehaving: %s (%d -> %d)\n", addr.ToString().c_str(), nMisbehavior-howmuch, nMisbehavior); - return false; + return true; } #undef X #define X(name) stats.name = name void CNode::copyStats(CNodeStats &stats) { + stats.nodeid = this->GetId(); X(nServices); X(nLastSend); X(nLastRecv); @@ -619,7 +608,6 @@ void CNode::copyStats(CNodeStats &stats) X(cleanSubVer); X(fInbound); X(nStartingHeight); - X(nMisbehavior); X(nSendBytes); X(nRecvBytes); stats.fSyncNode = (this == pnodeSync); @@ -57,14 +57,19 @@ void StartNode(boost::thread_group& threadGroup); bool StopNode(); void SocketSendData(CNode *pnode); +typedef int NodeId; + // Signals for message handling struct CNodeSignals { boost::signals2::signal<int ()> GetHeight; boost::signals2::signal<bool (CNode*)> ProcessMessages; boost::signals2::signal<bool (CNode*, bool)> SendMessages; + boost::signals2::signal<void (NodeId, const CNode*)> InitializeNode; + boost::signals2::signal<void (NodeId)> FinalizeNode; }; + CNodeSignals& GetNodeSignals(); @@ -109,12 +114,14 @@ extern limitedmap<CInv, int64_t> mapAlreadyAskedFor; extern std::vector<std::string> vAddedNodes; extern CCriticalSection cs_vAddedNodes; - +extern NodeId nLastNodeId; +extern CCriticalSection cs_nLastNodeId; class CNodeStats { public: + NodeId nodeid; uint64_t nServices; int64_t nLastSend; int64_t nLastRecv; @@ -124,7 +131,6 @@ public: std::string cleanSubVer; bool fInbound; int nStartingHeight; - int nMisbehavior; uint64_t nSendBytes; uint64_t nRecvBytes; bool fSyncNode; @@ -223,13 +229,13 @@ public: CCriticalSection cs_filter; CBloomFilter* pfilter; int nRefCount; + NodeId id; protected: // Denial-of-service detection/prevention // Key is IP address, value is banned-until-time static std::map<CNetAddr, int64_t> setBanned; static CCriticalSection cs_setBanned; - int nMisbehavior; // Basic fuzz-testing void Fuzz(int nChance); // modifies ssSend @@ -289,7 +295,6 @@ public: nStartingHeight = -1; fStartSync = false; fGetAddr = false; - nMisbehavior = 0; fRelayTxes = false; setInventoryKnown.max_size(SendBufferSize() / 1000); pfilter = new CBloomFilter(); @@ -298,9 +303,16 @@ public: nPingUsecTime = 0; fPingQueued = false; + { + LOCK(cs_nLastNodeId); + id = nLastNodeId++; + } + // Be shy and don't send version until we hear if (hSocket != INVALID_SOCKET && !fInbound) PushVersion(); + + GetNodeSignals().InitializeNode(GetId(), this); } ~CNode() @@ -312,6 +324,7 @@ public: } if (pfilter) delete pfilter; + GetNodeSignals().FinalizeNode(GetId()); } private: @@ -326,6 +339,9 @@ private: public: + NodeId GetId() const { + return id; + } int GetRefCount() { @@ -673,7 +689,7 @@ public: // new code. static void ClearBanned(); // needed for unit testing static bool IsBanned(CNetAddr ip); - bool Misbehaving(int howmuch); // 1 == a little, 100 == a lot + static bool Ban(const CNetAddr &ip); void copyStats(CNodeStats &stats); // Network stats diff --git a/src/rpcnet.cpp b/src/rpcnet.cpp index baa3268fb0..93811e80ed 100644 --- a/src/rpcnet.cpp +++ b/src/rpcnet.cpp @@ -2,6 +2,9 @@ // Distributed under the MIT/X11 software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include "rpcserver.h" + + +#include "main.h" #include "net.h" #include "netbase.h" #include "protocol.h" @@ -114,7 +117,8 @@ Value getpeerinfo(const Array& params, bool fHelp) BOOST_FOREACH(const CNodeStats& stats, vstats) { Object obj; - + CNodeStateStats statestats; + bool fStateStats = GetNodeStateStats(stats.nodeid, statestats); obj.push_back(Pair("addr", stats.addrName)); if (!(stats.addrLocal.empty())) obj.push_back(Pair("addrlocal", stats.addrLocal)); @@ -134,7 +138,9 @@ Value getpeerinfo(const Array& params, bool fHelp) obj.push_back(Pair("subver", stats.cleanSubVer)); obj.push_back(Pair("inbound", stats.fInbound)); obj.push_back(Pair("startingheight", stats.nStartingHeight)); - obj.push_back(Pair("banscore", stats.nMisbehavior)); + if (fStateStats) { + obj.push_back(Pair("banscore", statestats.nMisbehavior)); + } if (stats.fSyncNode) obj.push_back(Pair("syncnode", true)); diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index f0fb84bc54..fbca09b4dc 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -21,6 +21,7 @@ // Tests this internal-to-main.cpp method: extern bool AddOrphanTx(const CTransaction& tx); extern unsigned int LimitOrphanTxSize(unsigned int nMaxOrphans); +extern void Misbehaving(NodeId nodeid, int howmuch); extern std::map<uint256, CTransaction> mapOrphanTransactions; extern std::map<uint256, std::set<uint256> > mapOrphanTransactionsByPrev; @@ -38,16 +39,21 @@ BOOST_AUTO_TEST_CASE(DoS_banning) CNode::ClearBanned(); CAddress addr1(ip(0xa0b0c001)); CNode dummyNode1(INVALID_SOCKET, addr1, "", true); - dummyNode1.Misbehaving(100); // Should get banned + dummyNode1.nVersion = 1; + Misbehaving(dummyNode1.GetId(), 100); // Should get banned + SendMessages(&dummyNode1, false); BOOST_CHECK(CNode::IsBanned(addr1)); BOOST_CHECK(!CNode::IsBanned(ip(0xa0b0c001|0x0000ff00))); // Different IP, not banned CAddress addr2(ip(0xa0b0c002)); CNode dummyNode2(INVALID_SOCKET, addr2, "", true); - dummyNode2.Misbehaving(50); + dummyNode2.nVersion = 1; + Misbehaving(dummyNode2.GetId(), 50); + SendMessages(&dummyNode2, false); BOOST_CHECK(!CNode::IsBanned(addr2)); // 2 not banned yet... BOOST_CHECK(CNode::IsBanned(addr1)); // ... but 1 still should be - dummyNode2.Misbehaving(50); + Misbehaving(dummyNode2.GetId(), 50); + SendMessages(&dummyNode2, false); BOOST_CHECK(CNode::IsBanned(addr2)); } @@ -57,11 +63,15 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) mapArgs["-banscore"] = "111"; // because 11 is my favorite number CAddress addr1(ip(0xa0b0c001)); CNode dummyNode1(INVALID_SOCKET, addr1, "", true); - dummyNode1.Misbehaving(100); + dummyNode1.nVersion = 1; + Misbehaving(dummyNode1.GetId(), 100); + SendMessages(&dummyNode1, false); BOOST_CHECK(!CNode::IsBanned(addr1)); - dummyNode1.Misbehaving(10); + Misbehaving(dummyNode1.GetId(), 10); + SendMessages(&dummyNode1, false); BOOST_CHECK(!CNode::IsBanned(addr1)); - dummyNode1.Misbehaving(1); + Misbehaving(dummyNode1.GetId(), 1); + SendMessages(&dummyNode1, false); BOOST_CHECK(CNode::IsBanned(addr1)); mapArgs.erase("-banscore"); } @@ -74,8 +84,10 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) CAddress addr(ip(0xa0b0c001)); CNode dummyNode(INVALID_SOCKET, addr, "", true); + dummyNode.nVersion = 1; - dummyNode.Misbehaving(100); + Misbehaving(dummyNode.GetId(), 100); + SendMessages(&dummyNode, false); BOOST_CHECK(CNode::IsBanned(addr)); SetMockTime(nStartTime+60*60); diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp index a804ff3803..a4592fe803 100644 --- a/src/test/test_bitcoin.cpp +++ b/src/test/test_bitcoin.cpp @@ -47,11 +47,13 @@ struct TestingSetup { nScriptCheckThreads = 3; for (int i=0; i < nScriptCheckThreads-1; i++) threadGroup.create_thread(&ThreadScriptCheck); + RegisterNodeSignals(GetNodeSignals()); } ~TestingSetup() { threadGroup.interrupt_all(); threadGroup.join_all(); + UnregisterNodeSignals(GetNodeSignals()); #ifdef ENABLE_WALLET delete pwalletMain; pwalletMain = NULL; |