aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGavin Andresen <gavinandresen@gmail.com>2011-09-26 06:06:16 -0700
committerGavin Andresen <gavinandresen@gmail.com>2011-09-26 06:06:16 -0700
commit17e2c24645a10354849dec917b31f364e9056d58 (patch)
tree0fc4eda5307bbc9dc99eb3c8fec40b303644cd1c /src
parentf7f2a36925bb560363f691fc3ca3dec83830dd15 (diff)
parent806704c237890527ca2a7bab4c97550431eebea0 (diff)
downloadbitcoin-17e2c24645a10354849dec917b31f364e9056d58.tar.xz
Merge pull request #517 from gavinandresen/DoSprevention
Denial-of-service prevention
Diffstat (limited to 'src')
-rw-r--r--src/init.cpp2
-rw-r--r--src/main.cpp75
-rw-r--r--src/main.h8
-rw-r--r--src/net.cpp54
-rw-r--r--src/net.h28
-rw-r--r--src/test/DoS_tests.cpp68
-rw-r--r--src/test/test_bitcoin.cpp2
-rw-r--r--src/util.cpp9
-rw-r--r--src/util.h1
9 files changed, 215 insertions, 32 deletions
diff --git a/src/init.cpp b/src/init.cpp
index d2045fd84b..a1835c9038 100644
--- a/src/init.cpp
+++ b/src/init.cpp
@@ -179,6 +179,8 @@ bool AppInit2(int argc, char* argv[])
" -addnode=<ip> \t " + _("Add a node to connect to\n") +
" -connect=<ip> \t\t " + _("Connect only to the specified node\n") +
" -nolisten \t " + _("Don't accept connections from outside\n") +
+ " -banscore=<n> \t " + _("Threshold for disconnecting misbehaving peers (default: 100)\n") +
+ " -bantime=<n> \t " + _("Number of seconds to keep misbehaving peers from reconnecting (default: 86400)\n") +
#ifdef USE_UPNP
#if USE_UPNP
" -noupnp \t " + _("Don't attempt to use UPnP to map the listening port\n") +
diff --git a/src/main.cpp b/src/main.cpp
index e732ddcf5d..434c8e848d 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -297,24 +297,24 @@ bool CTransaction::CheckTransaction() const
{
// Basic checks that don't depend on any context
if (vin.empty())
- return error("CTransaction::CheckTransaction() : vin empty");
+ return DoS(10, error("CTransaction::CheckTransaction() : vin empty"));
if (vout.empty())
- return error("CTransaction::CheckTransaction() : vout empty");
+ return DoS(10, error("CTransaction::CheckTransaction() : vout empty"));
// Size limits
if (::GetSerializeSize(*this, SER_NETWORK) > MAX_BLOCK_SIZE)
- return error("CTransaction::CheckTransaction() : size limits failed");
+ return DoS(100, error("CTransaction::CheckTransaction() : size limits failed"));
// Check for negative or overflow output values
int64 nValueOut = 0;
BOOST_FOREACH(const CTxOut& txout, vout)
{
if (txout.nValue < 0)
- return error("CTransaction::CheckTransaction() : txout.nValue negative");
+ return DoS(100, error("CTransaction::CheckTransaction() : txout.nValue negative"));
if (txout.nValue > MAX_MONEY)
- return error("CTransaction::CheckTransaction() : txout.nValue too high");
+ return DoS(100, error("CTransaction::CheckTransaction() : txout.nValue too high"));
nValueOut += txout.nValue;
if (!MoneyRange(nValueOut))
- return error("CTransaction::CheckTransaction() : txout total out of range");
+ return DoS(100, error("CTransaction::CheckTransaction() : txout total out of range"));
}
// Check for duplicate inputs
@@ -329,13 +329,13 @@ bool CTransaction::CheckTransaction() const
if (IsCoinBase())
{
if (vin[0].scriptSig.size() < 2 || vin[0].scriptSig.size() > 100)
- return error("CTransaction::CheckTransaction() : coinbase script size");
+ return DoS(100, error("CTransaction::CheckTransaction() : coinbase script size"));
}
else
{
BOOST_FOREACH(const CTxIn& txin, vin)
if (txin.prevout.IsNull())
- return error("CTransaction::CheckTransaction() : prevout is null");
+ return DoS(10, error("CTransaction::CheckTransaction() : prevout is null"));
}
return true;
@@ -351,7 +351,7 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi
// Coinbase is only valid in a block, not as a loose transaction
if (IsCoinBase())
- return error("AcceptToMemoryPool() : coinbase as individual tx");
+ return DoS(100, error("AcceptToMemoryPool() : coinbase as individual tx"));
// To help v0.1.5 clients who would see it as a negative number
if ((int64)nLockTime > INT_MAX)
@@ -364,7 +364,7 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi
// 34 bytes because a TxOut is:
// 20-byte address + 8 byte bitcoin amount + 5 bytes of ops + 1 byte script length
if (GetSigOpCount() > nSize / 34 || nSize < 100)
- return error("AcceptToMemoryPool() : nonstandard transaction");
+ return DoS(10, error("AcceptToMemoryPool() : transaction with out-of-bounds SigOpCount"));
// Rather not work on nonstandard transactions (unless -testnet)
if (!fTestNet && !IsStandard())
@@ -855,26 +855,28 @@ bool CTransaction::ConnectInputs(CTxDB& txdb, map<uint256, CTxIndex>& mapTestPoo
}
if (prevout.n >= txPrev.vout.size() || prevout.n >= txindex.vSpent.size())
- return error("ConnectInputs() : %s prevout.n out of range %d %d %d prev tx %s\n%s", GetHash().ToString().substr(0,10).c_str(), prevout.n, txPrev.vout.size(), txindex.vSpent.size(), prevout.hash.ToString().substr(0,10).c_str(), txPrev.ToString().c_str());
+ return DoS(100, error("ConnectInputs() : %s prevout.n out of range %d %d %d prev tx %s\n%s", GetHash().ToString().substr(0,10).c_str(), prevout.n, txPrev.vout.size(), txindex.vSpent.size(), prevout.hash.ToString().substr(0,10).c_str(), txPrev.ToString().c_str()));
// If prev is coinbase, check that it's matured
if (txPrev.IsCoinBase())
for (CBlockIndex* pindex = pindexBlock; pindex && pindexBlock->nHeight - pindex->nHeight < COINBASE_MATURITY; pindex = pindex->pprev)
if (pindex->nBlockPos == txindex.pos.nBlockPos && pindex->nFile == txindex.pos.nFile)
- return error("ConnectInputs() : tried to spend coinbase at depth %d", pindexBlock->nHeight - pindex->nHeight);
+ return DoS(10, error("ConnectInputs() : tried to spend coinbase at depth %d", pindexBlock->nHeight - pindex->nHeight));
// Verify signature
if (!VerifySignature(txPrev, *this, i))
- return error("ConnectInputs() : %s VerifySignature failed", GetHash().ToString().substr(0,10).c_str());
+ return DoS(100,error("ConnectInputs() : %s VerifySignature failed", GetHash().ToString().substr(0,10).c_str()));
- // Check for conflicts
+ // Check for conflicts (double-spend)
+ // This doesn't trigger the DoS code on purpose; if it did, it would make it easier
+ // for an attacker to attempt to split the network.
if (!txindex.vSpent[prevout.n].IsNull())
return fMiner ? false : error("ConnectInputs() : %s prev tx already used at %s", GetHash().ToString().substr(0,10).c_str(), txindex.vSpent[prevout.n].ToString().c_str());
// Check for negative or overflow input values
nValueIn += txPrev.vout[prevout.n].nValue;
if (!MoneyRange(txPrev.vout[prevout.n].nValue) || !MoneyRange(nValueIn))
- return error("ConnectInputs() : txin values out of range");
+ return DoS(100, error("ConnectInputs() : txin values out of range"));
// Mark outpoints as spent
txindex.vSpent[prevout.n] = posThisTx;
@@ -887,17 +889,17 @@ bool CTransaction::ConnectInputs(CTxDB& txdb, map<uint256, CTxIndex>& mapTestPoo
}
if (nValueIn < GetValueOut())
- return error("ConnectInputs() : %s value in < value out", GetHash().ToString().substr(0,10).c_str());
+ return DoS(100, error("ConnectInputs() : %s value in < value out", GetHash().ToString().substr(0,10).c_str()));
// Tally transaction fees
int64 nTxFee = nValueIn - GetValueOut();
if (nTxFee < 0)
- return error("ConnectInputs() : %s nTxFee < 0", GetHash().ToString().substr(0,10).c_str());
+ return DoS(100, error("ConnectInputs() : %s nTxFee < 0", GetHash().ToString().substr(0,10).c_str()));
if (nTxFee < nMinFee)
return false;
nFees += nTxFee;
if (!MoneyRange(nFees))
- return error("ConnectInputs() : nFees out of range");
+ return DoS(100, error("ConnectInputs() : nFees out of range"));
}
if (fBlock)
@@ -1240,11 +1242,11 @@ bool CBlock::CheckBlock() const
// Size limits
if (vtx.empty() || vtx.size() > MAX_BLOCK_SIZE || ::GetSerializeSize(*this, SER_NETWORK) > MAX_BLOCK_SIZE)
- return error("CheckBlock() : size limits failed");
+ return DoS(100, error("CheckBlock() : size limits failed"));
// Check proof of work matches claimed amount
if (!CheckProofOfWork(GetHash(), nBits))
- return error("CheckBlock() : proof of work failed");
+ return DoS(50, error("CheckBlock() : proof of work failed"));
// Check timestamp
if (GetBlockTime() > GetAdjustedTime() + 2 * 60 * 60)
@@ -1252,23 +1254,23 @@ bool CBlock::CheckBlock() const
// First transaction must be coinbase, the rest must not be
if (vtx.empty() || !vtx[0].IsCoinBase())
- return error("CheckBlock() : first tx is not coinbase");
+ return DoS(100, error("CheckBlock() : first tx is not coinbase"));
for (int i = 1; i < vtx.size(); i++)
if (vtx[i].IsCoinBase())
- return error("CheckBlock() : more than one coinbase");
+ return DoS(100, error("CheckBlock() : more than one coinbase"));
// Check transactions
BOOST_FOREACH(const CTransaction& tx, vtx)
if (!tx.CheckTransaction())
- return error("CheckBlock() : CheckTransaction failed");
+ return DoS(tx.nDoS, error("CheckBlock() : CheckTransaction failed"));
// Check that it's not full of nonstandard transactions
if (GetSigOpCount() > MAX_BLOCK_SIGOPS)
- return error("CheckBlock() : too many nonstandard transactions");
+ return DoS(100, error("CheckBlock() : out-of-bounds SigOpCount"));
// Check merkleroot
if (hashMerkleRoot != BuildMerkleTree())
- return error("CheckBlock() : hashMerkleRoot mismatch");
+ return DoS(100, error("CheckBlock() : hashMerkleRoot mismatch"));
return true;
}
@@ -1283,13 +1285,13 @@ bool CBlock::AcceptBlock()
// Get prev block index
map<uint256, CBlockIndex*>::iterator mi = mapBlockIndex.find(hashPrevBlock);
if (mi == mapBlockIndex.end())
- return error("AcceptBlock() : prev block not found");
+ return DoS(10, error("AcceptBlock() : prev block not found"));
CBlockIndex* pindexPrev = (*mi).second;
int nHeight = pindexPrev->nHeight+1;
// Check proof of work
if (nBits != GetNextWorkRequired(pindexPrev))
- return error("AcceptBlock() : incorrect proof of work");
+ return DoS(100, error("AcceptBlock() : incorrect proof of work"));
// Check timestamp against prev
if (GetBlockTime() <= pindexPrev->GetMedianTimePast())
@@ -1298,7 +1300,7 @@ bool CBlock::AcceptBlock()
// Check that all transactions are finalized
BOOST_FOREACH(const CTransaction& tx, vtx)
if (!tx.IsFinal(nHeight, GetBlockTime()))
- return error("AcceptBlock() : contains a non-final transaction");
+ return DoS(10, error("AcceptBlock() : contains a non-final transaction"));
// Check that the block chain matches the known block chain up to a checkpoint
if (!fTestNet)
@@ -1311,7 +1313,7 @@ bool CBlock::AcceptBlock()
(nHeight == 118000 && hash != uint256("0x000000000000774a7f8a7a12dc906ddb9e17e75d684f15e00f8767f9e8f36553")) ||
(nHeight == 134444 && hash != uint256("0x00000000000005b12ffd4cd315cd34ffd4a594f430ac814c91184a0d42d2b0fe")) ||
(nHeight == 140700 && hash != uint256("0x000000000000033b512028abb90e1626d8b346fd0ed598ac0a3c371138dce2bd")))
- return error("AcceptBlock() : rejected by checkpoint lockin at %d", nHeight);
+ return DoS(100, error("AcceptBlock() : rejected by checkpoint lockin at %d", nHeight));
// Write block to history file
if (!CheckDiskSpace(::GetSerializeSize(*this, SER_DISK)))
@@ -1769,7 +1771,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
{
// Each connection can only send one version message
if (pfrom->nVersion != 0)
+ {
+ pfrom->Misbehaving(1);
return false;
+ }
int64 nTime;
CAddress addrMe;
@@ -1857,6 +1862,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
else if (pfrom->nVersion == 0)
{
// Must have a version message before anything else
+ pfrom->Misbehaving(1);
return false;
}
@@ -1878,7 +1884,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
if (pfrom->nVersion < 31402 && mapAddresses.size() > 1000)
return true;
if (vAddr.size() > 1000)
+ {
+ pfrom->Misbehaving(20);
return error("message addr size() = %d", vAddr.size());
+ }
// Store the new addresses
CAddrDB addrDB;
@@ -1936,7 +1945,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
vector<CInv> vInv;
vRecv >> vInv;
if (vInv.size() > 50000)
+ {
+ pfrom->Misbehaving(20);
return error("message inv size() = %d", vInv.size());
+ }
CTxDB txdb("r");
BOOST_FOREACH(const CInv& inv, vInv)
@@ -1965,7 +1977,10 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
vector<CInv> vInv;
vRecv >> vInv;
if (vInv.size() > 50000)
+ {
+ pfrom->Misbehaving(20);
return error("message getdata size() = %d", vInv.size());
+ }
BOOST_FOREACH(const CInv& inv, vInv)
{
@@ -2137,6 +2152,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
printf("storing orphan tx %s\n", inv.hash.ToString().substr(0,10).c_str());
AddOrphanTx(vMsg);
}
+ if (tx.nDoS) pfrom->Misbehaving(tx.nDoS);
}
@@ -2153,6 +2169,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
if (ProcessBlock(pfrom, &block))
mapAlreadyAskedFor.erase(inv);
+ if (block.nDoS) pfrom->Misbehaving(block.nDoS);
}
diff --git a/src/main.h b/src/main.h
index a8deb2b92b..1106bb9785 100644
--- a/src/main.h
+++ b/src/main.h
@@ -400,6 +400,9 @@ public:
std::vector<CTxOut> vout;
unsigned int nLockTime;
+ // Denial-of-service detection:
+ mutable int nDoS;
+ bool DoS(int nDoSIn, bool fIn) const { nDoS += nDoSIn; return fIn; }
CTransaction()
{
@@ -421,6 +424,7 @@ public:
vin.clear();
vout.clear();
nLockTime = 0;
+ nDoS = 0; // Denial-of-service prevention
}
bool IsNull() const
@@ -787,6 +791,9 @@ public:
// memory only
mutable std::vector<uint256> vMerkleTree;
+ // Denial-of-service detection:
+ mutable int nDoS;
+ bool DoS(int nDoSIn, bool fIn) const { nDoS += nDoSIn; return fIn; }
CBlock()
{
@@ -820,6 +827,7 @@ public:
nNonce = 0;
vtx.clear();
vMerkleTree.clear();
+ nDoS = 0;
}
bool IsNull() const
diff --git a/src/net.cpp b/src/net.cpp
index 2e257a6efc..1792bf78a0 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -726,6 +726,52 @@ void CNode::Cleanup()
}
+std::map<unsigned int, int64> CNode::setBanned;
+CCriticalSection CNode::cs_setBanned;
+
+void CNode::ClearBanned()
+{
+ setBanned.clear();
+}
+
+bool CNode::IsBanned(unsigned int ip)
+{
+ bool fResult = false;
+ CRITICAL_BLOCK(cs_setBanned)
+ {
+ std::map<unsigned int, int64>::iterator i = setBanned.find(ip);
+ if (i != setBanned.end())
+ {
+ int64 t = (*i).second;
+ if (GetTime() < t)
+ fResult = true;
+ }
+ }
+ return fResult;
+}
+
+bool CNode::Misbehaving(int howmuch)
+{
+ if (addr.IsLocal())
+ {
+ printf("Warning: local node %s misbehaving\n", addr.ToString().c_str());
+ return false;
+ }
+
+ nMisbehavior += howmuch;
+ if (nMisbehavior >= GetArg("-banscore", 100))
+ {
+ int64 banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban
+ CRITICAL_BLOCK(cs_setBanned)
+ if (setBanned[addr.ip] < banTime)
+ setBanned[addr.ip] = banTime;
+ CloseSocketDisconnect();
+ printf("Disconnected %s for misbehavior (score=%d)\n", addr.ToString().c_str(), nMisbehavior);
+ return true;
+ }
+ return false;
+}
+
@@ -896,6 +942,11 @@ void ThreadSocketHandler2(void* parg)
{
closesocket(hSocket);
}
+ else if (CNode::IsBanned(addr.ip))
+ {
+ printf("connetion from %s dropped (banned)\n", addr.ToString().c_str());
+ closesocket(hSocket);
+ }
else
{
printf("accepted connection %s\n", addr.ToString().c_str());
@@ -1454,7 +1505,8 @@ bool OpenNetworkConnection(const CAddress& addrConnect)
//
if (fShutdown)
return false;
- if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() || FindNode(addrConnect.ip))
+ if (addrConnect.ip == addrLocalHost.ip || !addrConnect.IsIPv4() ||
+ FindNode(addrConnect.ip) || CNode::IsBanned(addrConnect.ip))
return false;
vnThreadsRunning[1]--;
diff --git a/src/net.h b/src/net.h
index 0026e402c2..5b3568fcaf 100644
--- a/src/net.h
+++ b/src/net.h
@@ -124,6 +124,13 @@ public:
bool fDisconnect;
protected:
int nRefCount;
+
+ // Denial-of-service detection/prevention
+ // Key is ip address, value is banned-until-time
+ static std::map<unsigned int, int64> setBanned;
+ static CCriticalSection cs_setBanned;
+ int nMisbehavior;
+
public:
int64 nReleaseTime;
std::map<uint256, CRequestTracker> mapRequests;
@@ -148,7 +155,6 @@ public:
// publish and subscription
std::vector<char> vfSubscribe;
-
CNode(SOCKET hSocketIn, CAddress addrIn, bool fInboundIn=false)
{
nServices = 0;
@@ -185,6 +191,7 @@ public:
nStartingHeight = -1;
fGetAddr = false;
vfSubscribe.assign(256, false);
+ nMisbehavior = 0;
// Be shy and don't send version until we hear
if (!fInbound)
@@ -568,6 +575,25 @@ public:
void CancelSubscribe(unsigned int nChannel);
void CloseSocketDisconnect();
void Cleanup();
+
+
+ // Denial-of-service detection/prevention
+ // The idea is to detect peers that are behaving
+ // badly and disconnect/ban them, but do it in a
+ // one-coding-mistake-won't-shatter-the-entire-network
+ // way.
+ // IMPORTANT: There should be nothing I can give a
+ // node that it will forward on that will make that
+ // node's peers drop it. If there is, an attacker
+ // can isolate a node and/or try to split the network.
+ // Dropping a node for sending stuff that is invalid
+ // now but might be valid in a later version is also
+ // dangerous, because it can cause a network split
+ // between nodes running old code and nodes running
+ // new code.
+ static void ClearBanned(); // needed for unit testing
+ static bool IsBanned(unsigned int ip);
+ bool Misbehaving(int howmuch); // 1 == a little, 100 == a lot
};
diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp
new file mode 100644
index 0000000000..e60bb742dd
--- /dev/null
+++ b/src/test/DoS_tests.cpp
@@ -0,0 +1,68 @@
+//
+// Unit tests for denial-of-service detection/prevention code
+//
+#include <boost/test/unit_test.hpp>
+#include <boost/foreach.hpp>
+
+#include "../main.h"
+#include "../net.h"
+#include "../util.h"
+
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(DoS_tests)
+
+BOOST_AUTO_TEST_CASE(DoS_banning)
+{
+ CNode::ClearBanned();
+ CAddress addr1(0xa0b0c001);
+ CNode dummyNode1(INVALID_SOCKET, addr1, true);
+ dummyNode1.Misbehaving(100); // Should get banned
+ BOOST_CHECK(CNode::IsBanned(addr1.ip));
+ BOOST_CHECK(!CNode::IsBanned(addr1.ip|0x0000ff00)); // Different ip, not banned
+
+ CAddress addr2(0xa0b0c002);
+ CNode dummyNode2(INVALID_SOCKET, addr2, true);
+ dummyNode2.Misbehaving(50);
+ BOOST_CHECK(!CNode::IsBanned(addr2.ip)); // 2 not banned yet...
+ BOOST_CHECK(CNode::IsBanned(addr1.ip)); // ... but 1 still should be
+ dummyNode2.Misbehaving(50);
+ BOOST_CHECK(CNode::IsBanned(addr2.ip));
+}
+
+BOOST_AUTO_TEST_CASE(DoS_banscore)
+{
+ CNode::ClearBanned();
+ mapArgs["-banscore"] = "111"; // because 11 is my favorite number
+ CAddress addr1(0xa0b0c001);
+ CNode dummyNode1(INVALID_SOCKET, addr1, true);
+ dummyNode1.Misbehaving(100);
+ BOOST_CHECK(!CNode::IsBanned(addr1.ip));
+ dummyNode1.Misbehaving(10);
+ BOOST_CHECK(!CNode::IsBanned(addr1.ip));
+ dummyNode1.Misbehaving(1);
+ BOOST_CHECK(CNode::IsBanned(addr1.ip));
+ mapArgs["-banscore"] = "100";
+}
+
+BOOST_AUTO_TEST_CASE(DoS_bantime)
+{
+ CNode::ClearBanned();
+ int64 nStartTime = GetTime();
+ SetMockTime(nStartTime); // Overrides future calls to GetTime()
+
+ CAddress addr(0xa0b0c001);
+ CNode dummyNode(INVALID_SOCKET, addr, true);
+
+ dummyNode.Misbehaving(100);
+ BOOST_CHECK(CNode::IsBanned(addr.ip));
+
+ SetMockTime(nStartTime+60*60);
+ BOOST_CHECK(CNode::IsBanned(addr.ip));
+
+ SetMockTime(nStartTime+60*60*24+1);
+ BOOST_CHECK(!CNode::IsBanned(addr.ip));
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp
index 0230bb6eca..c6f6d94b1e 100644
--- a/src/test/test_bitcoin.cpp
+++ b/src/test/test_bitcoin.cpp
@@ -8,7 +8,7 @@
#include "uint256_tests.cpp"
#include "script_tests.cpp"
#include "transaction_tests.cpp"
-
+#include "DoS_tests.cpp"
CWallet* pwalletMain;
diff --git a/src/util.cpp b/src/util.cpp
index 03b3d73e62..80095fe771 100644
--- a/src/util.cpp
+++ b/src/util.cpp
@@ -813,11 +813,20 @@ void ShrinkDebugFile()
// - Median of other nodes's clocks
// - The user (asking the user to fix the system clock if the first two disagree)
//
+static int64 nMockTime = 0; // For unit testing
+
int64 GetTime()
{
+ if (nMockTime) return nMockTime;
+
return time(NULL);
}
+void SetMockTime(int64 nMockTimeIn)
+{
+ nMockTime = nMockTimeIn;
+}
+
static int64 nTimeOffset = 0;
int64 GetAdjustedTime()
diff --git a/src/util.h b/src/util.h
index fabcaf9303..dd5c41135e 100644
--- a/src/util.h
+++ b/src/util.h
@@ -197,6 +197,7 @@ void ShrinkDebugFile();
int GetRandInt(int nMax);
uint64 GetRand(uint64 nMax);
int64 GetTime();
+void SetMockTime(int64 nMockTimeIn);
int64 GetAdjustedTime();
void AddTimeData(unsigned int ip, int64 nTime);
std::string FormatFullVersion();