diff options
Diffstat (limited to 'src')
177 files changed, 2960 insertions, 1958 deletions
diff --git a/src/Makefile.am b/src/Makefile.am index 627df97cad..2b004691fd 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -85,6 +85,10 @@ if BUILD_BITCOIND bin_PROGRAMS += bitcoind endif +if BUILD_BITCOIN_NODE + bin_PROGRAMS += bitcoin-node +endif + if BUILD_BITCOIN_CLI bin_PROGRAMS += bitcoin-cli endif @@ -223,6 +227,7 @@ BITCOIN_CORE_H = \ util/message.h \ util/moneystr.h \ util/rbf.h \ + util/ref.h \ util/settings.h \ util/string.h \ util/threadnames.h \ @@ -243,6 +248,7 @@ BITCOIN_CORE_H = \ wallet/ismine.h \ wallet/load.h \ wallet/rpcwallet.h \ + wallet/salvage.h \ wallet/scriptpubkeyman.h \ wallet/wallet.h \ wallet/walletdb.h \ @@ -351,6 +357,7 @@ libbitcoin_wallet_a_SOURCES = \ wallet/load.cpp \ wallet/rpcdump.cpp \ wallet/rpcwallet.cpp \ + wallet/salvage.cpp \ wallet/scriptpubkeyman.cpp \ wallet/wallet.cpp \ wallet/walletdb.cpp \ @@ -496,7 +503,6 @@ libbitcoin_util_a_SOURCES = \ support/lockedpool.cpp \ chainparamsbase.cpp \ clientversion.cpp \ - compat/glibc_sanity_fdelt.cpp \ compat/glibc_sanity.cpp \ compat/glibcxx_sanity.cpp \ compat/strnlen.cpp \ @@ -547,22 +553,21 @@ libbitcoin_cli_a_SOURCES = \ nodist_libbitcoin_util_a_SOURCES = $(srcdir)/obj/build.h # -# bitcoind binary # -bitcoind_SOURCES = bitcoind.cpp -bitcoind_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) -bitcoind_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) -bitcoind_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +# bitcoind & bitcoin-node binaries # +bitcoin_daemon_sources = bitcoind.cpp +bitcoin_bin_cppflags = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +bitcoin_bin_cxxflags = $(AM_CXXFLAGS) $(PIE_FLAGS) +bitcoin_bin_ldflags = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) if TARGET_WINDOWS -bitcoind_SOURCES += bitcoind-res.rc +bitcoin_daemon_sources += bitcoind-res.rc endif -bitcoind_LDADD = \ - $(LIBBITCOIN_SERVER) \ +bitcoin_bin_ldadd = \ $(LIBBITCOIN_WALLET) \ $(LIBBITCOIN_COMMON) \ - $(LIBUNIVALUE) \ $(LIBBITCOIN_UTIL) \ + $(LIBUNIVALUE) \ $(LIBBITCOIN_ZMQ) \ $(LIBBITCOIN_CONSENSUS) \ $(LIBBITCOIN_CRYPTO) \ @@ -571,7 +576,19 @@ bitcoind_LDADD = \ $(LIBMEMENV) \ $(LIBSECP256K1) -bitcoind_LDADD += $(BOOST_LIBS) $(BDB_LIBS) $(MINIUPNPC_LIBS) $(EVENT_PTHREADS_LIBS) $(EVENT_LIBS) $(ZMQ_LIBS) +bitcoin_bin_ldadd += $(BOOST_LIBS) $(BDB_LIBS) $(MINIUPNPC_LIBS) $(EVENT_PTHREADS_LIBS) $(EVENT_LIBS) $(ZMQ_LIBS) + +bitcoind_SOURCES = $(bitcoin_daemon_sources) +bitcoind_CPPFLAGS = $(bitcoin_bin_cppflags) +bitcoind_CXXFLAGS = $(bitcoin_bin_cxxflags) +bitcoind_LDFLAGS = $(bitcoin_bin_ldflags) +bitcoind_LDADD = $(LIBBITCOIN_SERVER) $(bitcoin_bin_ldadd) + +bitcoin_node_SOURCES = $(bitcoin_daemon_sources) +bitcoin_node_CPPFLAGS = $(bitcoin_bin_cppflags) +bitcoin_node_CXXFLAGS = $(bitcoin_bin_cxxflags) +bitcoin_node_LDFLAGS = $(bitcoin_bin_ldflags) +bitcoin_node_LDADD = $(LIBBITCOIN_SERVER) $(bitcoin_bin_ldadd) # bitcoin-cli binary # bitcoin_cli_SOURCES = bitcoin-cli.cpp @@ -615,29 +632,14 @@ bitcoin_tx_LDADD += $(BOOST_LIBS) # bitcoin-wallet binary # bitcoin_wallet_SOURCES = bitcoin-wallet.cpp -bitcoin_wallet_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) -bitcoin_wallet_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) -bitcoin_wallet_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +bitcoin_wallet_CPPFLAGS = $(bitcoin_bin_cppflags) +bitcoin_wallet_CXXFLAGS = $(bitcoin_bin_cxxflags) +bitcoin_wallet_LDFLAGS = $(bitcoin_bin_ldflags) +bitcoin_wallet_LDADD = $(LIBBITCOIN_WALLET_TOOL) $(bitcoin_bin_ldadd) if TARGET_WINDOWS bitcoin_wallet_SOURCES += bitcoin-wallet-res.rc endif - -bitcoin_wallet_LDADD = \ - $(LIBBITCOIN_WALLET_TOOL) \ - $(LIBBITCOIN_WALLET) \ - $(LIBBITCOIN_COMMON) \ - $(LIBBITCOIN_CONSENSUS) \ - $(LIBBITCOIN_UTIL) \ - $(LIBBITCOIN_CRYPTO) \ - $(LIBBITCOIN_ZMQ) \ - $(LIBLEVELDB) \ - $(LIBLEVELDB_SSE42) \ - $(LIBMEMENV) \ - $(LIBSECP256K1) \ - $(LIBUNIVALUE) - -bitcoin_wallet_LDADD += $(BOOST_LIBS) $(BDB_LIBS) $(EVENT_PTHREADS_LIBS) $(EVENT_LIBS) $(MINIUPNPC_LIBS) $(ZMQ_LIBS) # # bitcoinconsensus library # diff --git a/src/Makefile.qt.include b/src/Makefile.qt.include index cf09eee2cb..13bfea7646 100644 --- a/src/Makefile.qt.include +++ b/src/Makefile.qt.include @@ -3,6 +3,11 @@ # file COPYING or http://www.opensource.org/licenses/mit-license.php. bin_PROGRAMS += qt/bitcoin-qt + +if BUILD_BITCOIN_GUI + bin_PROGRAMS += bitcoin-gui +endif + EXTRA_LIBRARIES += qt/libbitcoinqt.a # bitcoin qt core # @@ -294,29 +299,43 @@ QT_FORMS_H=$(join $(dir $(QT_FORMS_UI)),$(addprefix ui_, $(notdir $(QT_FORMS_UI: # Most files will depend on the forms and moc files as includes. Generate them # before anything else. $(QT_MOC): $(QT_FORMS_H) -$(qt_libbitcoinqt_a_OBJECTS) $(qt_bitcoin_qt_OBJECTS) : | $(QT_MOC) +$(qt_libbitcoinqt_a_OBJECTS) $(qt_bitcoin_qt_OBJECTS) $(bitcoin_gui_OBJECTS) : | $(QT_MOC) -# bitcoin-qt binary # -qt_bitcoin_qt_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) $(BITCOIN_QT_INCLUDES) \ +# bitcoin-qt and bitcoin-gui binaries # +bitcoin_qt_cppflags = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) $(BITCOIN_QT_INCLUDES) \ $(QT_INCLUDES) $(QR_CFLAGS) -qt_bitcoin_qt_CXXFLAGS = $(AM_CXXFLAGS) $(QT_PIE_FLAGS) +bitcoin_qt_cxxflags = $(AM_CXXFLAGS) $(QT_PIE_FLAGS) -qt_bitcoin_qt_SOURCES = qt/main.cpp +bitcoin_qt_sources = qt/main.cpp if TARGET_WINDOWS - qt_bitcoin_qt_SOURCES += $(BITCOIN_RC) + bitcoin_qt_sources += $(BITCOIN_RC) endif -qt_bitcoin_qt_LDADD = qt/libbitcoinqt.a $(LIBBITCOIN_SERVER) +bitcoin_qt_ldadd = qt/libbitcoinqt.a $(LIBBITCOIN_SERVER) if ENABLE_WALLET -qt_bitcoin_qt_LDADD += $(LIBBITCOIN_UTIL) $(LIBBITCOIN_WALLET) +bitcoin_qt_ldadd += $(LIBBITCOIN_UTIL) $(LIBBITCOIN_WALLET) endif if ENABLE_ZMQ -qt_bitcoin_qt_LDADD += $(LIBBITCOIN_ZMQ) $(ZMQ_LIBS) +bitcoin_qt_ldadd += $(LIBBITCOIN_ZMQ) $(ZMQ_LIBS) endif -qt_bitcoin_qt_LDADD += $(LIBBITCOIN_CLI) $(LIBBITCOIN_COMMON) $(LIBBITCOIN_UTIL) $(LIBBITCOIN_CONSENSUS) $(LIBBITCOIN_CRYPTO) $(LIBUNIVALUE) $(LIBLEVELDB) $(LIBLEVELDB_SSE42) $(LIBMEMENV) \ +bitcoin_qt_ldadd += $(LIBBITCOIN_CLI) $(LIBBITCOIN_COMMON) $(LIBBITCOIN_UTIL) $(LIBBITCOIN_CONSENSUS) $(LIBBITCOIN_CRYPTO) $(LIBUNIVALUE) $(LIBLEVELDB) $(LIBLEVELDB_SSE42) $(LIBMEMENV) \ $(BOOST_LIBS) $(QT_LIBS) $(QT_DBUS_LIBS) $(QR_LIBS) $(BDB_LIBS) $(MINIUPNPC_LIBS) $(LIBSECP256K1) \ $(EVENT_PTHREADS_LIBS) $(EVENT_LIBS) -qt_bitcoin_qt_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(QT_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) -qt_bitcoin_qt_LIBTOOLFLAGS = $(AM_LIBTOOLFLAGS) --tag CXX +bitcoin_qt_ldflags = $(RELDFLAGS) $(AM_LDFLAGS) $(QT_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +bitcoin_qt_libtoolflags = $(AM_LIBTOOLFLAGS) --tag CXX + +qt_bitcoin_qt_CPPFLAGS = $(bitcoin_qt_cppflags) +qt_bitcoin_qt_CXXFLAGS = $(bitcoin_qt_cxxflags) +qt_bitcoin_qt_SOURCES = $(bitcoin_qt_sources) +qt_bitcoin_qt_LDADD = $(bitcoin_qt_ldadd) +qt_bitcoin_qt_LDFLAGS = $(bitcoin_qt_ldflags) +qt_bitcoin_qt_LIBTOOLFLAGS = $(bitcoin_qt_libtoolflags) + +bitcoin_gui_CPPFLAGS = $(bitcoin_qt_cppflags) +bitcoin_gui_CXXFLAGS = $(bitcoin_qt_cxxflags) +bitcoin_gui_SOURCES = $(bitcoin_qt_sources) +bitcoin_gui_LDADD = $(bitcoin_qt_ldadd) +bitcoin_gui_LDFLAGS = $(bitcoin_qt_ldflags) +bitcoin_gui_LIBTOOLFLAGS = $(bitcoin_qt_libtoolflags) #locale/foo.ts -> locale/foo.qm QT_QM=$(QT_TS:.ts=.qm) diff --git a/src/Makefile.test.include b/src/Makefile.test.include index 3a0d4fdc15..2480cdadbb 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -31,6 +31,7 @@ FUZZ_TARGETS = \ test/fuzz/chain \ test/fuzz/checkqueue \ test/fuzz/coins_deserialize \ + test/fuzz/coins_view \ test/fuzz/cuckoocache \ test/fuzz/decode_tx \ test/fuzz/descriptor_parse \ @@ -229,6 +230,7 @@ BITCOIN_TESTS =\ test/prevector_tests.cpp \ test/raii_event_tests.cpp \ test/random_tests.cpp \ + test/ref_tests.cpp \ test/reverselock_tests.cpp \ test/rpc_tests.cpp \ test/sanity_tests.cpp \ @@ -465,6 +467,12 @@ test_fuzz_coins_deserialize_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_coins_deserialize_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_coins_deserialize_SOURCES = test/fuzz/deserialize.cpp +test_fuzz_coins_view_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_coins_view_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_coins_view_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_coins_view_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_coins_view_SOURCES = test/fuzz/coins_view.cpp + test_fuzz_cuckoocache_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_cuckoocache_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_cuckoocache_LDADD = $(FUZZ_SUITE_LD_COMMON) diff --git a/src/bench/prevector.cpp b/src/bench/prevector.cpp index 00e5d7e7a0..42b351a72d 100644 --- a/src/bench/prevector.cpp +++ b/src/bench/prevector.cpp @@ -20,9 +20,7 @@ struct nontrivial_t { int x; nontrivial_t() :x(-1) {} - ADD_SERIALIZE_METHODS - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) {READWRITE(x);} + SERIALIZE_METHODS(nontrivial_t, obj) { READWRITE(obj.x); } }; static_assert(!IS_TRIVIALLY_CONSTRUCTIBLE<nontrivial_t>::value, "expected nontrivial_t to not be trivially constructible"); diff --git a/src/bitcoin-cli.cpp b/src/bitcoin-cli.cpp index cdaabd6fab..8d85789b4e 100644 --- a/src/bitcoin-cli.cpp +++ b/src/bitcoin-cli.cpp @@ -9,6 +9,7 @@ #include <chainparamsbase.h> #include <clientversion.h> +#include <optional.h> #include <rpc/client.h> #include <rpc/protocol.h> #include <rpc/request.h> @@ -20,6 +21,7 @@ #include <functional> #include <memory> #include <stdio.h> +#include <string> #include <tuple> #include <event2/buffer.h> @@ -157,7 +159,7 @@ struct HTTPReply std::string body; }; -static const char *http_errorstring(int code) +static std::string http_errorstring(int code) { switch(code) { #if LIBEVENT_VERSION_NUMBER >= 0x02010300 @@ -250,7 +252,7 @@ public: UniValue ProcessReply(const UniValue &batch_in) override { UniValue result(UniValue::VOBJ); - std::vector<UniValue> batch = JSONRPCProcessBatchReply(batch_in, batch_in.size()); + const std::vector<UniValue> batch = JSONRPCProcessBatchReply(batch_in); // Errors in getnetworkinfo() and getblockchaininfo() are fatal, pass them on; // getwalletinfo() and getbalances() are allowed to fail if there is no wallet. if (!batch[ID_NETWORKINFO]["error"].isNull()) { @@ -304,7 +306,7 @@ public: } }; -static UniValue CallRPC(BaseRequestHandler *rh, const std::string& strMethod, const std::vector<std::string>& args) +static UniValue CallRPC(BaseRequestHandler* rh, const std::string& strMethod, const std::vector<std::string>& args, const Optional<std::string>& rpcwallet = {}) { std::string host; // In preference order, we choose the following for the port: @@ -369,14 +371,12 @@ static UniValue CallRPC(BaseRequestHandler *rh, const std::string& strMethod, co // check if we should use a special wallet endpoint std::string endpoint = "/"; - if (!gArgs.GetArgs("-rpcwallet").empty()) { - std::string walletName = gArgs.GetArg("-rpcwallet", ""); - char *encodedURI = evhttp_uriencode(walletName.data(), walletName.size(), false); + if (rpcwallet) { + char* encodedURI = evhttp_uriencode(rpcwallet->data(), rpcwallet->size(), false); if (encodedURI) { - endpoint = "/wallet/"+ std::string(encodedURI); + endpoint = "/wallet/" + std::string(encodedURI); free(encodedURI); - } - else { + } else { throw CConnectionFailed("uri-encode failed"); } } @@ -418,6 +418,65 @@ static UniValue CallRPC(BaseRequestHandler *rh, const std::string& strMethod, co return reply; } +/** + * ConnectAndCallRPC wraps CallRPC with -rpcwait and an exception handler. + * + * @param[in] rh Pointer to RequestHandler. + * @param[in] strMethod Reference to const string method to forward to CallRPC. + * @param[in] rpcwallet Reference to const optional string wallet name to forward to CallRPC. + * @returns the RPC response as a UniValue object. + * @throws a CConnectionFailed std::runtime_error if connection failed or RPC server still in warmup. + */ +static UniValue ConnectAndCallRPC(BaseRequestHandler* rh, const std::string& strMethod, const std::vector<std::string>& args, const Optional<std::string>& rpcwallet = {}) +{ + UniValue response(UniValue::VOBJ); + // Execute and handle connection failures with -rpcwait. + const bool fWait = gArgs.GetBoolArg("-rpcwait", false); + do { + try { + response = CallRPC(rh, strMethod, args, rpcwallet); + if (fWait) { + const UniValue& error = find_value(response, "error"); + if (!error.isNull() && error["code"].get_int() == RPC_IN_WARMUP) { + throw CConnectionFailed("server in warmup"); + } + } + break; // Connection succeeded, no need to retry. + } catch (const CConnectionFailed&) { + if (fWait) { + UninterruptibleSleep(std::chrono::milliseconds{1000}); + } else { + throw; + } + } + } while (fWait); + return response; +} + +/** + * GetWalletBalances calls listwallets; if more than one wallet is loaded, it then + * fetches mine.trusted balances for each loaded wallet and pushes them to `result`. + * + * @param result Reference to UniValue object the wallet names and balances are pushed to. + */ +static void GetWalletBalances(UniValue& result) +{ + std::unique_ptr<BaseRequestHandler> rh{MakeUnique<DefaultRequestHandler>()}; + const UniValue listwallets = ConnectAndCallRPC(rh.get(), "listwallets", /* args=*/{}); + if (!find_value(listwallets, "error").isNull()) return; + const UniValue& wallets = find_value(listwallets, "result"); + if (wallets.size() <= 1) return; + + UniValue balances(UniValue::VOBJ); + for (const UniValue& wallet : wallets.getValues()) { + const std::string wallet_name = wallet.get_str(); + const UniValue getbalances = ConnectAndCallRPC(rh.get(), "getbalances", /* args=*/{}, wallet_name); + const UniValue& balance = find_value(getbalances, "result")["mine"]["trusted"]; + balances.pushKV(wallet_name, balance); + } + result.pushKV("balances", balances); +} + static int CommandLineRPC(int argc, char *argv[]) { std::string strPrint; @@ -474,9 +533,8 @@ static int CommandLineRPC(int argc, char *argv[]) } std::unique_ptr<BaseRequestHandler> rh; std::string method; - if (gArgs.GetBoolArg("-getinfo", false)) { + if (gArgs.IsArgSet("-getinfo")) { rh.reset(new GetinfoRequestHandler()); - method = ""; } else { rh.reset(new DefaultRequestHandler()); if (args.size() < 1) { @@ -485,62 +543,46 @@ static int CommandLineRPC(int argc, char *argv[]) method = args[0]; args.erase(args.begin()); // Remove trailing method name from arguments vector } - - // Execute and handle connection failures with -rpcwait - const bool fWait = gArgs.GetBoolArg("-rpcwait", false); - do { - try { - const UniValue reply = CallRPC(rh.get(), method, args); - - // Parse reply - const UniValue& result = find_value(reply, "result"); - const UniValue& error = find_value(reply, "error"); - - if (!error.isNull()) { - // Error - int code = error["code"].get_int(); - if (fWait && code == RPC_IN_WARMUP) - throw CConnectionFailed("server in warmup"); - strPrint = "error: " + error.write(); - nRet = abs(code); - if (error.isObject()) - { - UniValue errCode = find_value(error, "code"); - UniValue errMsg = find_value(error, "message"); - strPrint = errCode.isNull() ? "" : "error code: "+errCode.getValStr()+"\n"; - - if (errMsg.isStr()) - strPrint += "error message:\n"+errMsg.get_str(); - - if (errCode.isNum() && errCode.get_int() == RPC_WALLET_NOT_SPECIFIED) { - strPrint += "\nTry adding \"-rpcwallet=<filename>\" option to bitcoin-cli command line."; - } - } - } else { - // Result - if (result.isNull()) - strPrint = ""; - else if (result.isStr()) - strPrint = result.get_str(); - else - strPrint = result.write(2); + Optional<std::string> wallet_name{}; + if (gArgs.IsArgSet("-rpcwallet")) wallet_name = gArgs.GetArg("-rpcwallet", ""); + const UniValue reply = ConnectAndCallRPC(rh.get(), method, args, wallet_name); + + // Parse reply + UniValue result = find_value(reply, "result"); + const UniValue& error = find_value(reply, "error"); + if (!error.isNull()) { + // Error + strPrint = "error: " + error.write(); + nRet = abs(error["code"].get_int()); + if (error.isObject()) { + const UniValue& errCode = find_value(error, "code"); + const UniValue& errMsg = find_value(error, "message"); + strPrint = errCode.isNull() ? "" : ("error code: " + errCode.getValStr() + "\n"); + + if (errMsg.isStr()) { + strPrint += ("error message:\n" + errMsg.get_str()); + } + if (errCode.isNum() && errCode.get_int() == RPC_WALLET_NOT_SPECIFIED) { + strPrint += "\nTry adding \"-rpcwallet=<filename>\" option to bitcoin-cli command line."; } - // Connection succeeded, no need to retry. - break; } - catch (const CConnectionFailed&) { - if (fWait) - UninterruptibleSleep(std::chrono::milliseconds{1000}); - else - throw; + } else { + if (gArgs.IsArgSet("-getinfo") && !gArgs.IsArgSet("-rpcwallet")) { + GetWalletBalances(result); // fetch multiwallet balances and append to result } - } while (fWait); - } - catch (const std::exception& e) { + // Result + if (result.isNull()) { + strPrint = ""; + } else if (result.isStr()) { + strPrint = result.get_str(); + } else { + strPrint = result.write(2); + } + } + } catch (const std::exception& e) { strPrint = std::string("error: ") + e.what(); nRet = EXIT_FAILURE; - } - catch (...) { + } catch (...) { PrintExceptionContinue(nullptr, "CommandLineRPC()"); throw; } diff --git a/src/bitcoin-wallet.cpp b/src/bitcoin-wallet.cpp index 7f9439788a..b420463c00 100644 --- a/src/bitcoin-wallet.cpp +++ b/src/bitcoin-wallet.cpp @@ -31,6 +31,7 @@ static void SetupWalletToolArgs() gArgs.AddArg("info", "Get wallet info", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); gArgs.AddArg("create", "Create new wallet file", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + gArgs.AddArg("salvage", "Attempt to recover private keys from a corrupt wallet", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); } static bool WalletAppInit(int argc, char* argv[]) diff --git a/src/bitcoind.cpp b/src/bitcoind.cpp index 43d3f3c5ac..b8e8717896 100644 --- a/src/bitcoind.cpp +++ b/src/bitcoind.cpp @@ -16,6 +16,7 @@ #include <noui.h> #include <shutdown.h> #include <ui_interface.h> +#include <util/ref.h> #include <util/strencodings.h> #include <util/system.h> #include <util/threadnames.h> @@ -77,6 +78,7 @@ static bool AppInit(int argc, char* argv[]) return true; } + util::Ref context{node}; try { if (!CheckDataDirOption()) { @@ -145,7 +147,7 @@ static bool AppInit(int argc, char* argv[]) // If locking the data directory failed, exit immediately return false; } - fRet = AppInitMain(node); + fRet = AppInitMain(context, node); } catch (const std::exception& e) { PrintExceptionContinue(&e, "AppInit()"); diff --git a/src/blockencodings.cpp b/src/blockencodings.cpp index 263d863cfa..a47709cd82 100644 --- a/src/blockencodings.cpp +++ b/src/blockencodings.cpp @@ -105,13 +105,12 @@ ReadStatus PartiallyDownloadedBlock::InitData(const CBlockHeaderAndShortTxIDs& c std::vector<bool> have_txn(txn_available.size()); { LOCK(pool->cs); - const std::vector<std::pair<uint256, CTxMemPool::txiter> >& vTxHashes = pool->vTxHashes; - for (size_t i = 0; i < vTxHashes.size(); i++) { - uint64_t shortid = cmpctblock.GetShortID(vTxHashes[i].first); + for (size_t i = 0; i < pool->vTxHashes.size(); i++) { + uint64_t shortid = cmpctblock.GetShortID(pool->vTxHashes[i].first); std::unordered_map<uint64_t, uint16_t>::iterator idit = shorttxids.find(shortid); if (idit != shorttxids.end()) { if (!have_txn[idit->second]) { - txn_available[idit->second] = vTxHashes[i].second->GetSharedTx(); + txn_available[idit->second] = pool->vTxHashes[i].second->GetSharedTx(); have_txn[idit->second] = true; mempool_count++; } else { diff --git a/src/blockencodings.h b/src/blockencodings.h index 377ac3a1a6..326db1b4a7 100644 --- a/src/blockencodings.h +++ b/src/blockencodings.h @@ -92,12 +92,13 @@ private: friend class PartiallyDownloadedBlock; - static const int SHORTTXIDS_LENGTH = 6; protected: std::vector<uint64_t> shorttxids; std::vector<PrefilledTransaction> prefilledtxn; public: + static constexpr int SHORTTXIDS_LENGTH = 6; + CBlockHeader header; // Dummy for deserialization @@ -125,7 +126,7 @@ class PartiallyDownloadedBlock { protected: std::vector<CTransactionRef> txn_available; size_t prefilled_count = 0, mempool_count = 0, extra_count = 0; - CTxMemPool* pool; + const CTxMemPool* pool; public: CBlockHeader header; explicit PartiallyDownloadedBlock(CTxMemPool* poolIn) : pool(poolIn) {} diff --git a/src/bloom.h b/src/bloom.h index 9173b80d66..9307257852 100644 --- a/src/bloom.h +++ b/src/bloom.h @@ -64,15 +64,7 @@ public: CBloomFilter(const unsigned int nElements, const double nFPRate, const unsigned int nTweak, unsigned char nFlagsIn); CBloomFilter() : nHashFuncs(0), nTweak(0), nFlags(0) {} - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(vData); - READWRITE(nHashFuncs); - READWRITE(nTweak); - READWRITE(nFlags); - } + SERIALIZE_METHODS(CBloomFilter, obj) { READWRITE(obj.vData, obj.nHashFuncs, obj.nTweak, obj.nFlags); } void insert(const std::vector<unsigned char>& vKey); void insert(const COutPoint& outpoint); diff --git a/src/clientversion.cpp b/src/clientversion.cpp index 5d9eaea6d0..993967a180 100644 --- a/src/clientversion.cpp +++ b/src/clientversion.cpp @@ -14,59 +14,34 @@ */ const std::string CLIENT_NAME("Satoshi"); -/** - * Client version number - */ -#define CLIENT_VERSION_SUFFIX "" - - -/** - * The following part of the code determines the CLIENT_BUILD variable. - * Several mechanisms are used for this: - * * first, if HAVE_BUILD_INFO is defined, include build.h, a file that is - * generated by the build environment, possibly containing the output - * of git-describe in a macro called BUILD_DESC - * * secondly, if this is an exported version of the code, GIT_ARCHIVE will - * be defined (automatically using the export-subst git attribute), and - * GIT_COMMIT will contain the commit id. - * * then, three options exist for determining CLIENT_BUILD: - * * if BUILD_DESC is defined, use that literally (output of git-describe) - * * if not, but GIT_COMMIT is defined, use v[maj].[min].[rev].[build]-g[commit] - * * otherwise, use v[maj].[min].[rev].[build]-unk - * finally CLIENT_VERSION_SUFFIX is added - */ -//! First, include build.h if requested #ifdef HAVE_BUILD_INFO #include <obj/build.h> +// The <obj/build.h>, which is generated by the build environment (share/genbuild.sh), +// could contain only one line of the following: +// - "#define BUILD_GIT_TAG ...", if the top commit is tagged +// - "#define BUILD_GIT_COMMIT ...", if the top commit is not tagged +// - "// No build information available", if proper git information is not available #endif -//! git will put "#define GIT_ARCHIVE 1" on the next line inside archives. $Format:%n#define GIT_ARCHIVE 1$ -#ifdef GIT_ARCHIVE -#define GIT_COMMIT_ID "$Format:%H$" -#define GIT_COMMIT_DATE "$Format:%cD$" -#endif - -#define BUILD_DESC_WITH_SUFFIX(maj, min, rev, build, suffix) \ - "v" DO_STRINGIZE(maj) "." DO_STRINGIZE(min) "." DO_STRINGIZE(rev) "." DO_STRINGIZE(build) "-" DO_STRINGIZE(suffix) - -#define BUILD_DESC_FROM_COMMIT(maj, min, rev, build, commit) \ - "v" DO_STRINGIZE(maj) "." DO_STRINGIZE(min) "." DO_STRINGIZE(rev) "." DO_STRINGIZE(build) "-g" commit +//! git will put "#define GIT_COMMIT_ID ..." on the next line inside archives. $Format:%n#define GIT_COMMIT_ID "%H"$ -#define BUILD_DESC_FROM_UNKNOWN(maj, min, rev, build) \ - "v" DO_STRINGIZE(maj) "." DO_STRINGIZE(min) "." DO_STRINGIZE(rev) "." DO_STRINGIZE(build) "-unk" - -#ifndef BUILD_DESC -#ifdef BUILD_SUFFIX -#define BUILD_DESC BUILD_DESC_WITH_SUFFIX(CLIENT_VERSION_MAJOR, CLIENT_VERSION_MINOR, CLIENT_VERSION_REVISION, CLIENT_VERSION_BUILD, BUILD_SUFFIX) -#elif defined(GIT_COMMIT_ID) -#define BUILD_DESC BUILD_DESC_FROM_COMMIT(CLIENT_VERSION_MAJOR, CLIENT_VERSION_MINOR, CLIENT_VERSION_REVISION, CLIENT_VERSION_BUILD, GIT_COMMIT_ID) +#ifdef BUILD_GIT_TAG + #define BUILD_DESC BUILD_GIT_TAG + #define BUILD_SUFFIX "" #else -#define BUILD_DESC BUILD_DESC_FROM_UNKNOWN(CLIENT_VERSION_MAJOR, CLIENT_VERSION_MINOR, CLIENT_VERSION_REVISION, CLIENT_VERSION_BUILD) -#endif + #define BUILD_DESC "v" STRINGIZE(CLIENT_VERSION_MAJOR) "." STRINGIZE(CLIENT_VERSION_MINOR) \ + "." STRINGIZE(CLIENT_VERSION_REVISION) "." STRINGIZE(CLIENT_VERSION_BUILD) + #ifdef BUILD_GIT_COMMIT + #define BUILD_SUFFIX "-" BUILD_GIT_COMMIT + #elif defined(GIT_COMMIT_ID) + #define BUILD_SUFFIX "-g" GIT_COMMIT_ID + #else + #define BUILD_SUFFIX "-unk" + #endif #endif -const std::string CLIENT_BUILD(BUILD_DESC CLIENT_VERSION_SUFFIX); +const std::string CLIENT_BUILD(BUILD_DESC BUILD_SUFFIX); static std::string FormatVersion(int nVersion) { diff --git a/src/compat/assumptions.h b/src/compat/assumptions.h index 6e7b4d3ded..4b0b224c69 100644 --- a/src/compat/assumptions.h +++ b/src/compat/assumptions.h @@ -50,6 +50,7 @@ static_assert(sizeof(double) == 8, "64-bit double assumed"); // code. static_assert(sizeof(short) == 2, "16-bit short assumed"); static_assert(sizeof(int) == 4, "32-bit int assumed"); +static_assert(sizeof(unsigned) == 4, "32-bit unsigned assumed"); // Assumption: We assume size_t to be 32-bit or 64-bit. // Example(s): size_t assumed to be at least 32-bit in ecdsa_signature_parse_der_lax(...). diff --git a/src/compat/glibc_compat.cpp b/src/compat/glibc_compat.cpp index 4de4fd7f45..d17de33e86 100644 --- a/src/compat/glibc_compat.cpp +++ b/src/compat/glibc_compat.cpp @@ -9,10 +9,6 @@ #include <cstddef> #include <cstdint> -#if defined(HAVE_SYS_SELECT_H) -#include <sys/select.h> -#endif - // Prior to GLIBC_2.14, memcpy was aliased to memmove. extern "C" void* memmove(void* a, const void* b, size_t c); extern "C" void* memcpy(void* a, const void* b, size_t c) @@ -20,15 +16,6 @@ extern "C" void* memcpy(void* a, const void* b, size_t c) return memmove(a, b, c); } -extern "C" void __chk_fail(void) __attribute__((__noreturn__)); -extern "C" FDELT_TYPE __fdelt_warn(FDELT_TYPE a) -{ - if (a >= FD_SETSIZE) - __chk_fail(); - return a / __NFDBITS; -} -extern "C" FDELT_TYPE __fdelt_chk(FDELT_TYPE) __attribute__((weak, alias("__fdelt_warn"))); - #if defined(__i386__) || defined(__arm__) extern "C" int64_t __udivmoddi4(uint64_t u, uint64_t v, uint64_t* rp); diff --git a/src/compat/glibc_sanity.cpp b/src/compat/glibc_sanity.cpp index cc74f28899..0367b9a53f 100644 --- a/src/compat/glibc_sanity.cpp +++ b/src/compat/glibc_sanity.cpp @@ -8,10 +8,6 @@ #include <cstddef> -#if defined(HAVE_SYS_SELECT_H) -bool sanity_test_fdelt(); -#endif - extern "C" void* memcpy(void* a, const void* b, size_t c); void* memcpy_int(void* a, const void* b, size_t c) { @@ -45,9 +41,5 @@ bool sanity_test_memcpy() bool glibc_sanity_test() { -#if defined(HAVE_SYS_SELECT_H) - if (!sanity_test_fdelt()) - return false; -#endif return sanity_test_memcpy<1025>(); } diff --git a/src/compat/glibc_sanity_fdelt.cpp b/src/compat/glibc_sanity_fdelt.cpp deleted file mode 100644 index 87140d0c71..0000000000 --- a/src/compat/glibc_sanity_fdelt.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2009-2019 The Bitcoin Core developers -// Distributed under the MIT software license, see the accompanying -// file COPYING or http://www.opensource.org/licenses/mit-license.php. - -#if defined(HAVE_CONFIG_H) -#include <config/bitcoin-config.h> -#endif - -#if defined(HAVE_SYS_SELECT_H) -#ifdef HAVE_CSTRING_DEPENDENT_FD_ZERO -#include <cstring> -#endif -#include <sys/select.h> - -// trigger: Call FD_SET to trigger __fdelt_chk. FORTIFY_SOURCE must be defined -// as >0 and optimizations must be set to at least -O2. -// test: Add a file descriptor to an empty fd_set. Verify that it has been -// correctly added. -bool sanity_test_fdelt() -{ - fd_set fds; - FD_ZERO(&fds); - FD_SET(0, &fds); - return FD_ISSET(0, &fds); -} -#endif diff --git a/src/core_read.cpp b/src/core_read.cpp index df78c319ee..1c0a8a096d 100644 --- a/src/core_read.cpp +++ b/src/core_read.cpp @@ -19,6 +19,7 @@ #include <boost/algorithm/string/split.hpp> #include <algorithm> +#include <string> CScript ParseScript(const std::string& s) { @@ -34,10 +35,9 @@ CScript ParseScript(const std::string& s) if (op < OP_NOP && op != OP_RESERVED) continue; - const char* name = GetOpName(static_cast<opcodetype>(op)); - if (strcmp(name, "OP_UNKNOWN") == 0) + std::string strName = GetOpName(static_cast<opcodetype>(op)); + if (strName == "OP_UNKNOWN") continue; - std::string strName(name); mapOpNames[strName] = static_cast<opcodetype>(op); // Convenience: OP_ADD and just ADD are both recognized: boost::algorithm::replace_first(strName, "OP_", ""); diff --git a/src/flatfile.h b/src/flatfile.h index 60b3503cc3..04f6373a24 100644 --- a/src/flatfile.h +++ b/src/flatfile.h @@ -16,13 +16,7 @@ struct FlatFilePos int nFile; unsigned int nPos; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(VARINT_MODE(nFile, VarIntMode::NONNEGATIVE_SIGNED)); - READWRITE(VARINT(nPos)); - } + SERIALIZE_METHODS(FlatFilePos, obj) { READWRITE(VARINT_MODE(obj.nFile, VarIntMode::NONNEGATIVE_SIGNED), VARINT(obj.nPos)); } FlatFilePos() : nFile(-1), nPos(0) {} diff --git a/src/fs.cpp b/src/fs.cpp index 066c6c10d3..e68c97b3ca 100644 --- a/src/fs.cpp +++ b/src/fs.cpp @@ -6,6 +6,9 @@ #ifndef WIN32 #include <fcntl.h> +#include <string> +#include <sys/file.h> +#include <sys/utsname.h> #else #ifndef NOMINMAX #define NOMINMAX @@ -47,20 +50,38 @@ FileLock::~FileLock() } } +static bool IsWSL() +{ + struct utsname uname_data; + return uname(&uname_data) == 0 && std::string(uname_data.version).find("Microsoft") != std::string::npos; +} + bool FileLock::TryLock() { if (fd == -1) { return false; } - struct flock lock; - lock.l_type = F_WRLCK; - lock.l_whence = SEEK_SET; - lock.l_start = 0; - lock.l_len = 0; - if (fcntl(fd, F_SETLK, &lock) == -1) { - reason = GetErrorReason(); - return false; + + // Exclusive file locking is broken on WSL using fcntl (issue #18622) + // This workaround can be removed once the bug on WSL is fixed + static const bool is_wsl = IsWSL(); + if (is_wsl) { + if (flock(fd, LOCK_EX | LOCK_NB) == -1) { + reason = GetErrorReason(); + return false; + } + } else { + struct flock lock; + lock.l_type = F_WRLCK; + lock.l_whence = SEEK_SET; + lock.l_start = 0; + lock.l_len = 0; + if (fcntl(fd, F_SETLK, &lock) == -1) { + reason = GetErrorReason(); + return false; + } } + return true; } #else diff --git a/src/httprpc.cpp b/src/httprpc.cpp index 3c3e6e5bba..f1b9997371 100644 --- a/src/httprpc.cpp +++ b/src/httprpc.cpp @@ -9,7 +9,6 @@ #include <httpserver.h> #include <rpc/protocol.h> #include <rpc/server.h> -#include <ui_interface.h> #include <util/strencodings.h> #include <util/system.h> #include <util/translation.h> @@ -151,7 +150,7 @@ static bool RPCAuthorized(const std::string& strAuth, std::string& strAuthUserna return multiUserAuthorized(strUserPass); } -static bool HTTPReq_JSONRPC(HTTPRequest* req, const std::string &) +static bool HTTPReq_JSONRPC(const util::Ref& context, HTTPRequest* req) { // JSONRPC handles only POST if (req->GetRequestMethod() != HTTPRequest::POST) { @@ -166,7 +165,7 @@ static bool HTTPReq_JSONRPC(HTTPRequest* req, const std::string &) return false; } - JSONRPCRequest jreq; + JSONRPCRequest jreq(context); jreq.peerAddr = req->GetPeer().ToString(); if (!RPCAuthorized(authHeader.second, jreq.authUser)) { LogPrintf("ThreadRPCServer incorrect password attempt from %s\n", jreq.peerAddr); @@ -249,11 +248,8 @@ static bool InitRPCAuthentication() { if (gArgs.GetArg("-rpcpassword", "") == "") { - LogPrintf("No rpcpassword set - using random cookie authentication.\n"); + LogPrintf("Using random cookie authentication.\n"); if (!GenerateAuthCookie(&strRPCUserColonPass)) { - uiInterface.ThreadSafeMessageBox( - _("Error: A fatal internal error occurred, see debug.log for details"), // Same message as AbortNode - "", CClientUIInterface::MSG_ERROR); return false; } } else { @@ -288,15 +284,16 @@ static bool InitRPCAuthentication() return true; } -bool StartHTTPRPC() +bool StartHTTPRPC(const util::Ref& context) { LogPrint(BCLog::RPC, "Starting HTTP RPC server\n"); if (!InitRPCAuthentication()) return false; - RegisterHTTPHandler("/", true, HTTPReq_JSONRPC); + auto handle_rpc = [&context](HTTPRequest* req, const std::string&) { return HTTPReq_JSONRPC(context, req); }; + RegisterHTTPHandler("/", true, handle_rpc); if (g_wallet_init_interface.HasWalletSupport()) { - RegisterHTTPHandler("/wallet/", false, HTTPReq_JSONRPC); + RegisterHTTPHandler("/wallet/", false, handle_rpc); } struct event_base* eventBase = EventBase(); assert(eventBase); diff --git a/src/httprpc.h b/src/httprpc.h index 99e4d59b8a..a6a38fc95a 100644 --- a/src/httprpc.h +++ b/src/httprpc.h @@ -5,11 +5,14 @@ #ifndef BITCOIN_HTTPRPC_H #define BITCOIN_HTTPRPC_H +namespace util { +class Ref; +} // namespace util /** Start HTTP RPC subsystem. * Precondition; HTTP and RPC has been started. */ -bool StartHTTPRPC(); +bool StartHTTPRPC(const util::Ref& context); /** Interrupt HTTP RPC subsystem. */ void InterruptHTTPRPC(); @@ -21,7 +24,7 @@ void StopHTTPRPC(); /** Start HTTP REST subsystem. * Precondition; HTTP and RPC has been started. */ -void StartREST(); +void StartREST(const util::Ref& context); /** Interrupt RPC REST subsystem. */ void InterruptREST(); diff --git a/src/httpserver.cpp b/src/httpserver.cpp index ffe246b241..5e78fd1d71 100644 --- a/src/httpserver.cpp +++ b/src/httpserver.cpp @@ -421,7 +421,7 @@ bool UpdateHTTPServerLogging(bool enable) { #endif } -static std::thread threadHTTP; +static std::thread g_thread_http; static std::vector<std::thread> g_thread_http_workers; void StartHTTPServer() @@ -429,7 +429,7 @@ void StartHTTPServer() LogPrint(BCLog::HTTP, "Starting HTTP server\n"); int rpcThreads = std::max((long)gArgs.GetArg("-rpcthreads", DEFAULT_HTTP_THREADS), 1L); LogPrintf("HTTP: starting %d worker threads\n", rpcThreads); - threadHTTP = std::thread(ThreadHTTP, eventBase); + g_thread_http = std::thread(ThreadHTTP, eventBase); for (int i = 0; i < rpcThreads; i++) { g_thread_http_workers.emplace_back(HTTPWorkQueueRun, workQueue, i); @@ -467,7 +467,7 @@ void StopHTTPServer() boundSockets.clear(); if (eventBase) { LogPrint(BCLog::HTTP, "Waiting for HTTP event thread to exit\n"); - threadHTTP.join(); + if (g_thread_http.joinable()) g_thread_http.join(); } if (eventHTTP) { evhttp_free(eventHTTP); diff --git a/src/index/blockfilterindex.cpp b/src/index/blockfilterindex.cpp index c3ce8d7af0..65a5f03a8e 100644 --- a/src/index/blockfilterindex.cpp +++ b/src/index/blockfilterindex.cpp @@ -31,6 +31,12 @@ constexpr char DB_FILTER_POS = 'P'; constexpr unsigned int MAX_FLTR_FILE_SIZE = 0x1000000; // 16 MiB /** The pre-allocation chunk size for fltr?????.dat files */ constexpr unsigned int FLTR_FILE_CHUNK_SIZE = 0x100000; // 1 MiB +/** Maximum size of the cfheaders cache + * We have a limit to prevent a bug in filling this cache + * potentially turning into an OOM. At 2000 entries, this cache + * is big enough for a 2,000,000 length block chain, which + * we should be enough until ~2047. */ +constexpr size_t CF_HEADERS_CACHE_MAX_SZ{2000}; namespace { @@ -39,14 +45,7 @@ struct DBVal { uint256 header; FlatFilePos pos; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(hash); - READWRITE(header); - READWRITE(pos); - } + SERIALIZE_METHODS(DBVal, obj) { READWRITE(obj.hash, obj.header, obj.pos); } }; struct DBHeightKey { @@ -78,17 +77,14 @@ struct DBHashKey { explicit DBHashKey(const uint256& hash_in) : hash(hash_in) {} - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { + SERIALIZE_METHODS(DBHashKey, obj) { char prefix = DB_BLOCK_HASH; READWRITE(prefix); if (prefix != DB_BLOCK_HASH) { throw std::ios_base::failure("Invalid format for block filter index DB hash key"); } - READWRITE(hash); + READWRITE(obj.hash); } }; @@ -387,13 +383,32 @@ bool BlockFilterIndex::LookupFilter(const CBlockIndex* block_index, BlockFilter& return ReadFilterFromDisk(entry.pos, filter_out); } -bool BlockFilterIndex::LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out) const +bool BlockFilterIndex::LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out) { + LOCK(m_cs_headers_cache); + + bool is_checkpoint{block_index->nHeight % CFCHECKPT_INTERVAL == 0}; + + if (is_checkpoint) { + // Try to find the block in the headers cache if this is a checkpoint height. + auto header = m_headers_cache.find(block_index->GetBlockHash()); + if (header != m_headers_cache.end()) { + header_out = header->second; + return true; + } + } + DBVal entry; if (!LookupOne(*m_db, block_index, entry)) { return false; } + if (is_checkpoint && + m_headers_cache.size() < CF_HEADERS_CACHE_MAX_SZ) { + // Add to the headers cache if this is a checkpoint height. + m_headers_cache.emplace(block_index->GetBlockHash(), entry.header); + } + header_out = entry.header; return true; } diff --git a/src/index/blockfilterindex.h b/src/index/blockfilterindex.h index 436d52515f..317f8c0e40 100644 --- a/src/index/blockfilterindex.h +++ b/src/index/blockfilterindex.h @@ -10,6 +10,14 @@ #include <flatfile.h> #include <index/base.h> +/** Interval between compact filter checkpoints. See BIP 157. */ +static constexpr int CFCHECKPT_INTERVAL = 1000; + +struct FilterHeaderHasher +{ + size_t operator()(const uint256& hash) const { return ReadLE64(hash.begin()); } +}; + /** * BlockFilterIndex is used to store and retrieve block filters, hashes, and headers for a range of * blocks by height. An index is constructed for each supported filter type with its own database @@ -30,6 +38,10 @@ private: bool ReadFilterFromDisk(const FlatFilePos& pos, BlockFilter& filter) const; size_t WriteFilterToDisk(FlatFilePos& pos, const BlockFilter& filter); + Mutex m_cs_headers_cache; + /** cache of block hash to filter header, to avoid disk access when responding to getcfcheckpt. */ + std::unordered_map<uint256, uint256, FilterHeaderHasher> m_headers_cache GUARDED_BY(m_cs_headers_cache); + protected: bool Init() override; @@ -54,7 +66,7 @@ public: bool LookupFilter(const CBlockIndex* block_index, BlockFilter& filter_out) const; /** Get a single filter header by block. */ - bool LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out) const; + bool LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out); /** Get a range of filters between two heights on a chain. */ bool LookupFilterRange(int start_height, const CBlockIndex* stop_index, diff --git a/src/index/txindex.cpp b/src/index/txindex.cpp index 5bbe6ad1df..4626395ef0 100644 --- a/src/index/txindex.cpp +++ b/src/index/txindex.cpp @@ -21,12 +21,10 @@ struct CDiskTxPos : public FlatFilePos { unsigned int nTxOffset; // after header - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITEAS(FlatFilePos, *this); - READWRITE(VARINT(nTxOffset)); + SERIALIZE_METHODS(CDiskTxPos, obj) + { + READWRITEAS(FlatFilePos, obj); + READWRITE(VARINT(obj.nTxOffset)); } CDiskTxPos(const FlatFilePos &blockIn, unsigned int nTxOffsetIn) : FlatFilePos(blockIn.nFile, blockIn.nPos), nTxOffset(nTxOffsetIn) { diff --git a/src/init.cpp b/src/init.cpp index 2653bd25a6..37e6251295 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -59,9 +59,10 @@ #include <validationinterface.h> #include <walletinitinterface.h> +#include <functional> +#include <set> #include <stdint.h> #include <stdio.h> -#include <set> #ifndef WIN32 #include <attributes.h> @@ -243,9 +244,9 @@ void Shutdown(NodeContext& node) } // FlushStateToDisk generates a ChainStateFlushed callback, which we should avoid missing - { + if (node.chainman) { LOCK(cs_main); - for (CChainState* chainstate : g_chainman.GetAll()) { + for (CChainState* chainstate : node.chainman->GetAll()) { if (chainstate->CanFlushToDisk()) { chainstate->ForceFlushStateToDisk(); } @@ -270,9 +271,9 @@ void Shutdown(NodeContext& node) // up with our current chain to avoid any strange pruning edge cases and make // next startup faster by avoiding rescan. - { + if (node.chainman) { LOCK(cs_main); - for (CChainState* chainstate : g_chainman.GetAll()) { + for (CChainState* chainstate : node.chainman->GetAll()) { if (chainstate->CanFlushToDisk()) { chainstate->ForceFlushStateToDisk(); chainstate->ResetCoinsViews(); @@ -298,7 +299,8 @@ void Shutdown(NodeContext& node) globalVerifyHandle.reset(); ECC_Stop(); node.args = nullptr; - if (node.mempool) node.mempool = nullptr; + node.mempool = nullptr; + node.chainman = nullptr; node.scheduler.reset(); try { @@ -350,13 +352,13 @@ static void registerSignalHandler(int signal, void(*handler)(int)) static boost::signals2::connection rpc_notify_block_change_connection; static void OnRPCStarted() { - rpc_notify_block_change_connection = uiInterface.NotifyBlockTip_connect(&RPCNotifyBlockChange); + rpc_notify_block_change_connection = uiInterface.NotifyBlockTip_connect(std::bind(RPCNotifyBlockChange, std::placeholders::_2)); } static void OnRPCStopped() { rpc_notify_block_change_connection.disconnect(); - RPCNotifyBlockChange(false, nullptr); + RPCNotifyBlockChange(nullptr); g_best_block_cv.notify_all(); LogPrint(BCLog::RPC, "RPC stopped.\n"); } @@ -604,9 +606,9 @@ std::string LicenseInfo() } #if HAVE_SYSTEM -static void BlockNotifyCallback(bool initialSync, const CBlockIndex *pBlockIndex) +static void BlockNotifyCallback(SynchronizationState sync_state, const CBlockIndex* pBlockIndex) { - if (initialSync || !pBlockIndex) + if (sync_state != SynchronizationState::POST_INIT || !pBlockIndex) return; std::string strCmd = gArgs.GetArg("-blocknotify", ""); @@ -622,7 +624,7 @@ static bool fHaveGenesis = false; static Mutex g_genesis_wait_mutex; static std::condition_variable g_genesis_wait_cv; -static void BlockNotifyGenesisWait(bool, const CBlockIndex *pBlockIndex) +static void BlockNotifyGenesisWait(const CBlockIndex* pBlockIndex) { if (pBlockIndex != nullptr) { { @@ -688,7 +690,7 @@ static void CleanupBlockRevFiles() } } -static void ThreadImport(std::vector<fs::path> vImportFiles) +static void ThreadImport(ChainstateManager& chainman, std::vector<fs::path> vImportFiles) { const CChainParams& chainparams = Params(); util::ThreadRename("loadblk"); @@ -740,9 +742,9 @@ static void ThreadImport(std::vector<fs::path> vImportFiles) // scan for better chains in the block chain database, that are not yet connected in the active best chain // We can't hold cs_main during ActivateBestChain even though we're accessing - // the g_chainman unique_ptrs since ABC requires us not to be holding cs_main, so retrieve + // the chainman unique_ptrs since ABC requires us not to be holding cs_main, so retrieve // the relevant pointers before the ABC call. - for (CChainState* chainstate : WITH_LOCK(::cs_main, return g_chainman.GetAll())) { + for (CChainState* chainstate : WITH_LOCK(::cs_main, return chainman.GetAll())) { BlockValidationState state; if (!chainstate->ActivateBestChain(state, chainparams, nullptr)) { LogPrintf("Failed to connect best block (%s)\n", state.ToString()); @@ -783,16 +785,16 @@ static bool InitSanityCheck() return true; } -static bool AppInitServers() +static bool AppInitServers(const util::Ref& context) { RPCServer::OnStarted(&OnRPCStarted); RPCServer::OnStopped(&OnRPCStopped); if (!InitHTTPServer()) return false; StartRPC(); - if (!StartHTTPRPC()) + if (!StartHTTPRPC(context)) return false; - if (gArgs.GetBoolArg("-rest", DEFAULT_REST_ENABLE)) StartREST(); + if (gArgs.GetBoolArg("-rest", DEFAULT_REST_ENABLE)) StartREST(context); StartHTTPServer(); return true; } @@ -972,7 +974,7 @@ bool AppInitParameterInteraction() // Warn if unrecognized section name are present in the config file. for (const auto& section : gArgs.GetUnrecognizedSections()) { - InitWarning(strprintf("%s:%i " + _("Section [%s] is not recognized.").translated, section.m_file, section.m_line, section.m_name)); + InitWarning(strprintf(Untranslated("%s:%i ") + _("Section [%s] is not recognized."), section.m_file, section.m_line, section.m_name)); } if (!fs::is_directory(GetBlocksDir())) { @@ -1035,7 +1037,7 @@ bool AppInitParameterInteraction() nMaxConnections = std::min(nFD - MIN_CORE_FILEDESCRIPTORS - MAX_ADDNODE_CONNECTIONS, nMaxConnections); if (nMaxConnections < nUserMaxConnections) - InitWarning(strprintf(_("Reducing -maxconnections from %d to %d, because of system limitations.").translated, nUserMaxConnections, nMaxConnections)); + InitWarning(strprintf(_("Reducing -maxconnections from %d to %d, because of system limitations."), nUserMaxConnections, nMaxConnections)); // ********************************************************* Step 3: parameter-to-internal-flags if (gArgs.IsArgSet("-debug")) { @@ -1046,7 +1048,7 @@ bool AppInitParameterInteraction() [](std::string cat){return cat == "0" || cat == "none";})) { for (const auto& cat : categories) { if (!LogInstance().EnableCategory(cat)) { - InitWarning(strprintf(_("Unsupported logging category %s=%s.").translated, "-debug", cat)); + InitWarning(strprintf(_("Unsupported logging category %s=%s."), "-debug", cat)); } } } @@ -1055,7 +1057,7 @@ bool AppInitParameterInteraction() // Now remove the logging categories which were explicitly excluded for (const std::string& cat : gArgs.GetArgs("-debugexclude")) { if (!LogInstance().DisableCategory(cat)) { - InitWarning(strprintf(_("Unsupported logging category %s=%s.").translated, "-debugexclude", cat)); + InitWarning(strprintf(_("Unsupported logging category %s=%s."), "-debugexclude", cat)); } } @@ -1237,7 +1239,7 @@ bool AppInitLockDataDirectory() return true; } -bool AppInitMain(NodeContext& node) +bool AppInitMain(const util::Ref& context, NodeContext& node) { const CChainParams& chainparams = Params(); // ********************************************************* Step 4a: application initialization @@ -1268,7 +1270,7 @@ bool AppInitMain(NodeContext& node) LogPrintf("Config file: %s\n", config_file_path.string()); } else if (gArgs.IsArgSet("-conf")) { // Warn if no conf file exists at path provided by user - InitWarning(strprintf(_("The specified config file %s does not exist\n").translated, config_file_path.string())); + InitWarning(strprintf(_("The specified config file %s does not exist\n"), config_file_path.string())); } else { // Not categorizing as "Warning" because it's the default behavior LogPrintf("Config file: %s (not found, skipping)\n", config_file_path.string()); @@ -1339,7 +1341,6 @@ bool AppInitMain(NodeContext& node) for (const auto& client : node.chain_clients) { client->registerRpcs(); } - g_rpc_node = &node; #if ENABLE_ZMQ RegisterZMQRPCCommands(tableRPC); #endif @@ -1352,7 +1353,7 @@ bool AppInitMain(NodeContext& node) if (gArgs.GetBoolArg("-server", false)) { uiInterface.InitMessage_connect(SetRPCWarmupStatus); - if (!AppInitServers()) + if (!AppInitServers(context)) return InitError(_("Unable to start HTTP server. See debug log for details.")); } @@ -1377,8 +1378,11 @@ bool AppInitMain(NodeContext& node) // which are all started after this, may use it from the node context. assert(!node.mempool); node.mempool = &::mempool; + assert(!node.chainman); + node.chainman = &g_chainman; + ChainstateManager& chainman = EnsureChainman(node); - node.peer_logic.reset(new PeerLogicValidation(node.connman.get(), node.banman.get(), *node.scheduler, *node.mempool)); + node.peer_logic.reset(new PeerLogicValidation(node.connman.get(), node.banman.get(), *node.scheduler, *node.chainman, *node.mempool)); RegisterValidationInterface(node.peer_logic.get()); // sanitize comments per BIP-0014, format user agent and check total size @@ -1557,7 +1561,7 @@ bool AppInitMain(NodeContext& node) const int64_t load_block_index_start_time = GetTimeMillis(); try { LOCK(cs_main); - g_chainman.InitializeChainstate(); + chainman.InitializeChainstate(); UnloadBlockIndex(); // new CBlockTreeDB tries to delete the existing file, which @@ -1578,7 +1582,7 @@ bool AppInitMain(NodeContext& node) // block file from disk. // Note that it also sets fReindex based on the disk flag! // From here on out fReindex and fReset mean something different! - if (!LoadBlockIndex(chainparams)) { + if (!chainman.LoadBlockIndex(chainparams)) { if (ShutdownRequested()) break; strLoadError = _("Error loading block database"); break; @@ -1612,7 +1616,7 @@ bool AppInitMain(NodeContext& node) bool failed_chainstate_init = false; - for (CChainState* chainstate : g_chainman.GetAll()) { + for (CChainState* chainstate : chainman.GetAll()) { LogPrintf("Initializing chainstate %s\n", chainstate->ToString()); chainstate->InitCoinsDB( /* cache_size_bytes */ nCoinDBCache, @@ -1667,7 +1671,7 @@ bool AppInitMain(NodeContext& node) bool failed_rewind{false}; // Can't hold cs_main while calling RewindBlockIndex, so retrieve the relevant // chainstates beforehand. - for (CChainState* chainstate : WITH_LOCK(::cs_main, return g_chainman.GetAll())) { + for (CChainState* chainstate : WITH_LOCK(::cs_main, return chainman.GetAll())) { if (!fReset) { // Note that RewindBlockIndex MUST run even if we're about to -reindex-chainstate. // It both disconnects blocks based on the chainstate, and drops block data in @@ -1692,7 +1696,7 @@ bool AppInitMain(NodeContext& node) try { LOCK(cs_main); - for (CChainState* chainstate : g_chainman.GetAll()) { + for (CChainState* chainstate : chainman.GetAll()) { if (!is_coinsview_empty(chainstate)) { uiInterface.InitMessage(_("Verifying blocks...").translated); if (fHavePruned && gArgs.GetArg("-checkblocks", DEFAULT_CHECKBLOCKS) > MIN_BLOCKS_TO_KEEP) { @@ -1701,7 +1705,7 @@ bool AppInitMain(NodeContext& node) } const CBlockIndex* tip = chainstate->m_chain.Tip(); - RPCNotifyBlockChange(true, tip); + RPCNotifyBlockChange(tip); if (tip && tip->nTime > GetAdjustedTime() + 2 * 60 * 60) { strLoadError = _("The block database contains a block which appears to be from the future. " "This may be due to your computer's date and time being set incorrectly. " @@ -1798,7 +1802,7 @@ bool AppInitMain(NodeContext& node) nLocalServices = ServiceFlags(nLocalServices & ~NODE_NETWORK); if (!fReindex) { LOCK(cs_main); - for (CChainState* chainstate : g_chainman.GetAll()) { + for (CChainState* chainstate : chainman.GetAll()) { uiInterface.InitMessage(_("Pruning blockstore...").translated); chainstate->PruneAndFlush(); } @@ -1826,7 +1830,7 @@ bool AppInitMain(NodeContext& node) // No locking, as this happens before any background thread is started. boost::signals2::connection block_notify_genesis_wait_connection; if (::ChainActive().Tip() == nullptr) { - block_notify_genesis_wait_connection = uiInterface.NotifyBlockTip_connect(BlockNotifyGenesisWait); + block_notify_genesis_wait_connection = uiInterface.NotifyBlockTip_connect(std::bind(BlockNotifyGenesisWait, std::placeholders::_2)); } else { fHaveGenesis = true; } @@ -1841,7 +1845,7 @@ bool AppInitMain(NodeContext& node) vImportFiles.push_back(strFile); } - threadGroup.create_thread(std::bind(&ThreadImport, vImportFiles)); + threadGroup.create_thread([=, &chainman] { ThreadImport(chainman, vImportFiles); }); // Wait for genesis block to be processed { diff --git a/src/init.h b/src/init.h index ef568b6f38..33fe96e8ea 100644 --- a/src/init.h +++ b/src/init.h @@ -14,6 +14,9 @@ struct NodeContext; namespace boost { class thread_group; } // namespace boost +namespace util { +class Ref; +} // namespace util /** Interrupt threads */ void Interrupt(NodeContext& node); @@ -51,7 +54,7 @@ bool AppInitLockDataDirectory(); * @note This should only be done after daemonization. Call Shutdown() if this function fails. * @pre Parameters should be parsed and config file should be read, AppInitLockDataDirectory should have been called. */ -bool AppInitMain(NodeContext& node); +bool AppInitMain(const util::Ref& context, NodeContext& node); /** * Register all arguments with the ArgsManager diff --git a/src/interfaces/chain.cpp b/src/interfaces/chain.cpp index e1a528d99c..d8e459a8e8 100644 --- a/src/interfaces/chain.cpp +++ b/src/interfaces/chain.cpp @@ -344,7 +344,7 @@ public: bool shutdownRequested() override { return ShutdownRequested(); } int64_t getAdjustedTime() override { return GetAdjustedTime(); } void initMessage(const std::string& message) override { ::uiInterface.InitMessage(message); } - void initWarning(const std::string& message) override { InitWarning(message); } + void initWarning(const bilingual_str& message) override { InitWarning(message); } void initError(const bilingual_str& message) override { InitError(message); } void showProgress(const std::string& title, int progress, bool resume_possible) override { diff --git a/src/interfaces/chain.h b/src/interfaces/chain.h index 77b315b195..7dfc77db7b 100644 --- a/src/interfaces/chain.h +++ b/src/interfaces/chain.h @@ -225,7 +225,7 @@ public: virtual void initMessage(const std::string& message) = 0; //! Send init warning. - virtual void initWarning(const std::string& message) = 0; + virtual void initWarning(const bilingual_str& message) = 0; //! Send init error. virtual void initError(const bilingual_str& message) = 0; diff --git a/src/interfaces/node.cpp b/src/interfaces/node.cpp index 9e603a12cd..3c94e44b53 100644 --- a/src/interfaces/node.cpp +++ b/src/interfaces/node.cpp @@ -27,6 +27,7 @@ #include <sync.h> #include <txmempool.h> #include <ui_interface.h> +#include <util/ref.h> #include <util/system.h> #include <util/translation.h> #include <validation.h> @@ -80,7 +81,7 @@ public: bool appInitMain() override { m_context.chain = MakeChain(m_context); - return AppInitMain(m_context); + return AppInitMain(m_context_ref, m_context); } void appShutdown() override { @@ -225,7 +226,7 @@ public: CFeeRate getDustRelayFee() override { return ::dustRelayFee; } UniValue executeRpc(const std::string& command, const UniValue& params, const std::string& uri) override { - JSONRPCRequest req; + JSONRPCRequest req(m_context_ref); req.params = params; req.strMethod = command; req.URI = uri; @@ -308,21 +309,22 @@ public: } std::unique_ptr<Handler> handleNotifyBlockTip(NotifyBlockTipFn fn) override { - return MakeHandler(::uiInterface.NotifyBlockTip_connect([fn](bool initial_download, const CBlockIndex* block) { - fn(initial_download, block->nHeight, block->GetBlockTime(), + return MakeHandler(::uiInterface.NotifyBlockTip_connect([fn](SynchronizationState sync_state, const CBlockIndex* block) { + fn(sync_state, block->nHeight, block->GetBlockTime(), GuessVerificationProgress(Params().TxData(), block)); })); } std::unique_ptr<Handler> handleNotifyHeaderTip(NotifyHeaderTipFn fn) override { return MakeHandler( - ::uiInterface.NotifyHeaderTip_connect([fn](bool initial_download, const CBlockIndex* block) { - fn(initial_download, block->nHeight, block->GetBlockTime(), + ::uiInterface.NotifyHeaderTip_connect([fn](SynchronizationState sync_state, const CBlockIndex* block) { + fn(sync_state, block->nHeight, block->GetBlockTime(), /* verification progress is unused when a header was received */ 0); })); } NodeContext* context() override { return &m_context; } NodeContext m_context; + util::Ref m_context_ref{m_context}; }; } // namespace diff --git a/src/interfaces/node.h b/src/interfaces/node.h index aef6b19458..45b0e18fae 100644 --- a/src/interfaces/node.h +++ b/src/interfaces/node.h @@ -27,6 +27,7 @@ class Coin; class RPCTimerInterface; class UniValue; class proxyType; +enum class SynchronizationState; enum class WalletCreationStatus; struct CNodeStateStats; struct NodeContext; @@ -249,12 +250,12 @@ public: //! Register handler for block tip messages. using NotifyBlockTipFn = - std::function<void(bool initial_download, int height, int64_t block_time, double verification_progress)>; + std::function<void(SynchronizationState, int height, int64_t block_time, double verification_progress)>; virtual std::unique_ptr<Handler> handleNotifyBlockTip(NotifyBlockTipFn fn) = 0; //! Register handler for header tip messages. using NotifyHeaderTipFn = - std::function<void(bool initial_download, int height, int64_t block_time, double verification_progress)>; + std::function<void(SynchronizationState, int height, int64_t block_time, double verification_progress)>; virtual std::unique_ptr<Handler> handleNotifyHeaderTip(NotifyHeaderTipFn fn) = 0; //! Return pointer to internal chain interface, useful for testing. diff --git a/src/interfaces/wallet.cpp b/src/interfaces/wallet.cpp index 13b034936b..349dce0247 100644 --- a/src/interfaces/wallet.cpp +++ b/src/interfaces/wallet.cpp @@ -351,14 +351,13 @@ public: } return result; } - bool tryGetBalances(WalletBalances& balances, int& num_blocks, bool force, int cached_num_blocks) override + bool tryGetBalances(WalletBalances& balances, int& num_blocks) override { TRY_LOCK(m_wallet->cs_wallet, locked_wallet); if (!locked_wallet) { return false; } num_blocks = m_wallet->GetLastBlockHeight(); - if (!force && num_blocks == cached_num_blocks) return false; balances = getBalances(); return true; } diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h index f35335c69f..421d35af15 100644 --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -202,11 +202,8 @@ public: //! Get balances. virtual WalletBalances getBalances() = 0; - //! Get balances if possible without waiting for chain and wallet locks. - virtual bool tryGetBalances(WalletBalances& balances, - int& num_blocks, - bool force, - int cached_num_blocks) = 0; + //! Get balances if possible without blocking. + virtual bool tryGetBalances(WalletBalances& balances, int& num_blocks) = 0; //! Get balance. virtual CAmount getBalance() = 0; diff --git a/src/logging.cpp b/src/logging.cpp index eb9da06d9b..fe58ae9e73 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -22,8 +22,8 @@ BCLog::Logger& LogInstance() * access the logger. When the shutdown sequence is fully audited and tested, * explicit destruction of these objects can be implemented by changing this * from a raw pointer to a std::unique_ptr. - * Since the destructor is never called, the logger and all its members must - * have a trivial destructor. + * Since the ~Logger() destructor is never called, the Logger class and all + * its subclasses must have implicitly-defined destructors. * * This method of initialization was originally introduced in * ee3374234c60aba2cc4c5cd5cac1c0aefc2d817c. @@ -41,7 +41,7 @@ static int FileWriteStr(const std::string &str, FILE *fp) bool BCLog::Logger::StartLogging() { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); assert(m_buffering); assert(m_fileout == nullptr); @@ -80,7 +80,7 @@ bool BCLog::Logger::StartLogging() void BCLog::Logger::DisconnectTestLogger() { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); m_buffering = true; if (m_fileout != nullptr) fclose(m_fileout); m_fileout = nullptr; @@ -246,7 +246,7 @@ namespace BCLog { void BCLog::Logger::LogPrintStr(const std::string& str) { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); std::string str_prefixed = LogEscapeMessage(str); if (m_log_threadnames && m_started_new_line) { diff --git a/src/logging.h b/src/logging.h index ab07010316..7e646ef67a 100644 --- a/src/logging.h +++ b/src/logging.h @@ -8,6 +8,7 @@ #include <fs.h> #include <tinyformat.h> +#include <threadsafety.h> #include <util/string.h> #include <atomic> @@ -61,10 +62,11 @@ namespace BCLog { class Logger { private: - mutable std::mutex m_cs; // Can not use Mutex from sync.h because in debug mode it would cause a deadlock when a potential deadlock was detected - FILE* m_fileout = nullptr; // GUARDED_BY(m_cs) - std::list<std::string> m_msgs_before_open; // GUARDED_BY(m_cs) - bool m_buffering{true}; //!< Buffer messages before logging can be started. GUARDED_BY(m_cs) + mutable StdMutex m_cs; // Can not use Mutex from sync.h because in debug mode it would cause a deadlock when a potential deadlock was detected + + FILE* m_fileout GUARDED_BY(m_cs) = nullptr; + std::list<std::string> m_msgs_before_open GUARDED_BY(m_cs); + bool m_buffering GUARDED_BY(m_cs) = true; //!< Buffer messages before logging can be started. /** * m_started_new_line is a state variable that will suppress printing of @@ -79,7 +81,7 @@ namespace BCLog { std::string LogTimestampStr(const std::string& str); /** Slots that connect to the print signal */ - std::list<std::function<void(const std::string&)>> m_print_callbacks /* GUARDED_BY(m_cs) */ {}; + std::list<std::function<void(const std::string&)>> m_print_callbacks GUARDED_BY(m_cs) {}; public: bool m_print_to_console = false; @@ -98,14 +100,14 @@ namespace BCLog { /** Returns whether logs will be written to any output */ bool Enabled() const { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); return m_buffering || m_print_to_console || m_print_to_file || !m_print_callbacks.empty(); } /** Connect a slot to the print signal and return the connection */ std::list<std::function<void(const std::string&)>>::iterator PushBackCallback(std::function<void(const std::string&)> fun) { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); m_print_callbacks.push_back(std::move(fun)); return --m_print_callbacks.end(); } @@ -113,7 +115,7 @@ namespace BCLog { /** Delete a connection */ void DeleteCallback(std::list<std::function<void(const std::string&)>>::iterator it) { - std::lock_guard<std::mutex> scoped_lock(m_cs); + StdLockGuard scoped_lock(m_cs); m_print_callbacks.erase(it); } diff --git a/src/merkleblock.cpp b/src/merkleblock.cpp index 4ac6219886..8072b12119 100644 --- a/src/merkleblock.cpp +++ b/src/merkleblock.cpp @@ -9,6 +9,24 @@ #include <consensus/consensus.h> +std::vector<unsigned char> BitsToBytes(const std::vector<bool>& bits) +{ + std::vector<unsigned char> ret((bits.size() + 7) / 8); + for (unsigned int p = 0; p < bits.size(); p++) { + ret[p / 8] |= bits[p] << (p % 8); + } + return ret; +} + +std::vector<bool> BytesToBits(const std::vector<unsigned char>& bytes) +{ + std::vector<bool> ret(bytes.size() * 8); + for (unsigned int p = 0; p < ret.size(); p++) { + ret[p] = (bytes[p / 8] & (1 << (p % 8))) != 0; + } + return ret; +} + CMerkleBlock::CMerkleBlock(const CBlock& block, CBloomFilter* filter, const std::set<uint256>* txids) { header = block.GetBlockHeader(); diff --git a/src/merkleblock.h b/src/merkleblock.h index e641c8aa94..b2d2828784 100644 --- a/src/merkleblock.h +++ b/src/merkleblock.h @@ -13,6 +13,10 @@ #include <vector> +// Helper functions for serialization. +std::vector<unsigned char> BitsToBytes(const std::vector<bool>& bits); +std::vector<bool> BytesToBits(const std::vector<unsigned char>& bytes); + /** Data structure that represents a partial merkle tree. * * It represents a subset of the txid's of a known block, in a way that @@ -81,27 +85,14 @@ protected: public: - /** serialization implementation */ - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(nTransactions); - READWRITE(vHash); - std::vector<unsigned char> vBytes; - if (ser_action.ForRead()) { - READWRITE(vBytes); - CPartialMerkleTree &us = *(const_cast<CPartialMerkleTree*>(this)); - us.vBits.resize(vBytes.size() * 8); - for (unsigned int p = 0; p < us.vBits.size(); p++) - us.vBits[p] = (vBytes[p / 8] & (1 << (p % 8))) != 0; - us.fBad = false; - } else { - vBytes.resize((vBits.size()+7)/8); - for (unsigned int p = 0; p < vBits.size(); p++) - vBytes[p / 8] |= vBits[p] << (p % 8); - READWRITE(vBytes); - } + SERIALIZE_METHODS(CPartialMerkleTree, obj) + { + READWRITE(obj.nTransactions, obj.vHash); + std::vector<unsigned char> bytes; + SER_WRITE(obj, bytes = BitsToBytes(obj.vBits)); + READWRITE(bytes); + SER_READ(obj, obj.vBits = BytesToBits(bytes)); + SER_READ(obj, obj.fBad = false); } /** Construct a partial merkle tree from a list of transaction ids, and a mask that selects a subset of them */ @@ -157,13 +148,7 @@ public: CMerkleBlock() {} - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(header); - READWRITE(txn); - } + SERIALIZE_METHODS(CMerkleBlock, obj) { READWRITE(obj.header, obj.txn); } private: // Combined constructor to consolidate code diff --git a/src/net.cpp b/src/net.cpp index 9950b9aea4..c9cfb67ec8 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1455,7 +1455,7 @@ void CConnman::ThreadSocketHandler() void CConnman::WakeMessageHandler() { { - std::lock_guard<std::mutex> lock(mutexMsgProc); + LOCK(mutexMsgProc); fMsgProcWake = true; } condMsgProc.notify_one(); @@ -2058,7 +2058,7 @@ void CConnman::ThreadMessageHandler() WAIT_LOCK(mutexMsgProc, lock); if (!fMoreWork) { - condMsgProc.wait_until(lock, std::chrono::steady_clock::now() + std::chrono::milliseconds(100), [this] { return fMsgProcWake; }); + condMsgProc.wait_until(lock, std::chrono::steady_clock::now() + std::chrono::milliseconds(100), [this]() EXCLUSIVE_LOCKS_REQUIRED(mutexMsgProc) { return fMsgProcWake; }); } fMsgProcWake = false; } @@ -2366,7 +2366,7 @@ static CNetCleanup instance_of_cnetcleanup; void CConnman::Interrupt() { { - std::lock_guard<std::mutex> lock(mutexMsgProc); + LOCK(mutexMsgProc); flagInterruptMsgProc = true; } condMsgProc.notify_all(); @@ -454,7 +454,7 @@ private: const uint64_t nSeed0, nSeed1; /** flag for waking the message processor. */ - bool fMsgProcWake; + bool fMsgProcWake GUARDED_BY(mutexMsgProc); std::condition_variable condMsgProc; Mutex mutexMsgProc; diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 1df1fab59d..159036a237 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -129,8 +129,8 @@ static constexpr unsigned int INVENTORY_BROADCAST_MAX = 7 * INVENTORY_BROADCAST_ static constexpr unsigned int AVG_FEEFILTER_BROADCAST_INTERVAL = 10 * 60; /** Maximum feefilter broadcast delay after significant change. */ static constexpr unsigned int MAX_FEEFILTER_CHANGE_DELAY = 5 * 60; -/** Interval between compact filter checkpoints. See BIP 157. */ -static constexpr int CFCHECKPT_INTERVAL = 1000; +/** Maximum number of cf hashes that may be requested with one getcfheaders. See BIP 157. */ +static constexpr uint32_t MAX_GETCFHEADERS_SIZE = 2000; struct COrphanTx { // When modifying, adapt the copy of this definition in tests/DoS_tests. @@ -819,7 +819,12 @@ void PeerLogicValidation::ReattemptInitialBroadcast(CScheduler& scheduler) const std::set<uint256> unbroadcast_txids = m_mempool.GetUnbroadcastTxs(); for (const uint256& txid : unbroadcast_txids) { - RelayTransaction(txid, *connman); + // Sanity check: all unbroadcast txns should exist in the mempool + if (m_mempool.exists(txid)) { + RelayTransaction(txid, *connman); + } else { + m_mempool.RemoveUnbroadcastTx(txid, true); + } } // schedule next run for 10-15 minutes in the future @@ -1150,9 +1155,10 @@ static bool BlockRequestAllowed(const CBlockIndex* pindex, const Consensus::Para (GetBlockProofEquivalentTime(*pindexBestHeader, *pindex, *pindexBestHeader, consensusParams) < STALE_RELAY_AGE_LIMIT); } -PeerLogicValidation::PeerLogicValidation(CConnman* connmanIn, BanMan* banman, CScheduler& scheduler, CTxMemPool& pool) +PeerLogicValidation::PeerLogicValidation(CConnman* connmanIn, BanMan* banman, CScheduler& scheduler, ChainstateManager& chainman, CTxMemPool& pool) : connman(connmanIn), m_banman(banman), + m_chainman(chainman), m_mempool(pool), m_stale_tip_check_time(0) { @@ -1608,6 +1614,37 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c } } +//! Determine whether or not a peer can request a transaction, and return it (or nullptr if not found or not allowed). +CTransactionRef static FindTxForGetData(CNode* peer, const uint256& txid, const std::chrono::seconds mempool_req, const std::chrono::seconds longlived_mempool_time) LOCKS_EXCLUDED(cs_main) +{ + // Check if the requested transaction is so recent that we're just + // about to announce it to the peer; if so, they certainly shouldn't + // know we already have it. + { + LOCK(peer->m_tx_relay->cs_tx_inventory); + if (peer->m_tx_relay->setInventoryTxToSend.count(txid)) return {}; + } + + { + LOCK(cs_main); + // Look up transaction in relay pool + auto mi = mapRelay.find(txid); + if (mi != mapRelay.end()) return mi->second; + } + + auto txinfo = mempool.info(txid); + if (txinfo.tx) { + // To protect privacy, do not answer getdata using the mempool when + // that TX couldn't have been INVed in reply to a MEMPOOL request, + // or when it's too recent to have expired from mapRelay. + if ((mempool_req.count() && txinfo.m_time <= mempool_req) || txinfo.m_time <= longlived_mempool_time) { + return txinfo.tx; + } + } + + return {}; +} + void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnman* connman, CTxMemPool& mempool, const std::atomic<bool>& interruptMsgProc) LOCKS_EXCLUDED(cs_main) { AssertLockNotHeld(cs_main); @@ -1622,58 +1659,31 @@ void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnm const std::chrono::seconds mempool_req = pfrom->m_tx_relay != nullptr ? pfrom->m_tx_relay->m_last_mempool_req.load() : std::chrono::seconds::min(); - { - LOCK(cs_main); - - // Process as many TX items from the front of the getdata queue as - // possible, since they're common and it's efficient to batch process - // them. - while (it != pfrom->vRecvGetData.end() && (it->type == MSG_TX || it->type == MSG_WITNESS_TX)) { - if (interruptMsgProc) - return; - // The send buffer provides backpressure. If there's no space in - // the buffer, pause processing until the next call. - if (pfrom->fPauseSend) - break; + // Process as many TX items from the front of the getdata queue as + // possible, since they're common and it's efficient to batch process + // them. + while (it != pfrom->vRecvGetData.end() && (it->type == MSG_TX || it->type == MSG_WITNESS_TX)) { + if (interruptMsgProc) return; + // The send buffer provides backpressure. If there's no space in + // the buffer, pause processing until the next call. + if (pfrom->fPauseSend) break; - const CInv &inv = *it++; + const CInv &inv = *it++; - if (pfrom->m_tx_relay == nullptr) { - // Ignore GETDATA requests for transactions from blocks-only peers. - continue; - } + if (pfrom->m_tx_relay == nullptr) { + // Ignore GETDATA requests for transactions from blocks-only peers. + continue; + } - // Send stream from relay memory - bool push = false; - auto mi = mapRelay.find(inv.hash); + CTransactionRef tx = FindTxForGetData(pfrom, inv.hash, mempool_req, longlived_mempool_time); + if (tx) { int nSendFlags = (inv.type == MSG_TX ? SERIALIZE_TRANSACTION_NO_WITNESS : 0); - if (mi != mapRelay.end()) { - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::TX, *mi->second)); - push = true; - } else { - auto txinfo = mempool.info(inv.hash); - // To protect privacy, do not answer getdata using the mempool when - // that TX couldn't have been INVed in reply to a MEMPOOL request, - // or when it's too recent to have expired from mapRelay. - if (txinfo.tx && ( - (mempool_req.count() && txinfo.m_time <= mempool_req) - || (txinfo.m_time <= longlived_mempool_time))) - { - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::TX, *txinfo.tx)); - push = true; - } - } - - if (push) { - // We interpret fulfilling a GETDATA for a transaction as a - // successful initial broadcast and remove it from our - // unbroadcast set. - mempool.RemoveUnbroadcastTx(inv.hash); - } else { - vNotFound.push_back(inv); - } + connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::TX, *tx)); + mempool.RemoveUnbroadcastTx(inv.hash); + } else { + vNotFound.push_back(inv); } - } // release cs_main + } // Only process one BLOCK item per call, since they're uncommon and can be // expensive to process. @@ -1731,7 +1741,7 @@ inline void static SendBlockTransactions(const CBlock& block, const BlockTransac connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::BLOCKTXN, resp)); } -bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, CTxMemPool& mempool, const std::vector<CBlockHeader>& headers, const CChainParams& chainparams, bool via_compact_block) +bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateManager& chainman, CTxMemPool& mempool, const std::vector<CBlockHeader>& headers, const CChainParams& chainparams, bool via_compact_block) { const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); size_t nCount = headers.size(); @@ -1791,7 +1801,7 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, CTxMemPool& m } BlockValidationState state; - if (!ProcessNewBlockHeaders(headers, state, chainparams, &pindexLast)) { + if (!chainman.ProcessNewBlockHeaders(headers, state, chainparams, &pindexLast)) { if (state.IsInvalid()) { MaybePunishNodeForBlock(pfrom->GetId(), state, via_compact_block, "invalid header received"); return false; @@ -1981,16 +1991,18 @@ void static ProcessOrphanTx(CConnman* connman, CTxMemPool& mempool, std::set<uin * @param[in] pfrom The peer that we received the request from * @param[in] chain_params Chain parameters * @param[in] filter_type The filter type the request is for. Must be basic filters. + * @param[in] start_height The start height for the request * @param[in] stop_hash The stop_hash for the request + * @param[in] max_height_diff The maximum number of items permitted to request, as specified in BIP 157 * @param[out] stop_index The CBlockIndex for the stop_hash block, if the request can be serviced. * @param[out] filter_index The filter index, if the request can be serviced. * @return True if the request can be serviced. */ static bool PrepareBlockFilterRequest(CNode* pfrom, const CChainParams& chain_params, - BlockFilterType filter_type, - const uint256& stop_hash, + BlockFilterType filter_type, uint32_t start_height, + const uint256& stop_hash, uint32_t max_height_diff, const CBlockIndex*& stop_index, - const BlockFilterIndex*& filter_index) + BlockFilterIndex*& filter_index) { const bool supported_filter_type = (filter_type == BlockFilterType::BASIC && @@ -2015,6 +2027,21 @@ static bool PrepareBlockFilterRequest(CNode* pfrom, const CChainParams& chain_pa } } + uint32_t stop_height = stop_index->nHeight; + if (start_height > stop_height) { + LogPrint(BCLog::NET, "peer %d sent invalid getcfilters/getcfheaders with " /* Continued */ + "start height %d and stop height %d\n", + pfrom->GetId(), start_height, stop_height); + pfrom->fDisconnect = true; + return false; + } + if (stop_height - start_height >= max_height_diff) { + LogPrint(BCLog::NET, "peer %d requested too many cfilters/cfheaders: %d / %d\n", + pfrom->GetId(), stop_height - start_height + 1, max_height_diff); + pfrom->fDisconnect = true; + return false; + } + filter_index = GetBlockFilterIndex(filter_type); if (!filter_index) { LogPrint(BCLog::NET, "Filter index for supported type %s not found\n", BlockFilterTypeName(filter_type)); @@ -2025,6 +2052,61 @@ static bool PrepareBlockFilterRequest(CNode* pfrom, const CChainParams& chain_pa } /** + * Handle a cfheaders request. + * + * May disconnect from the peer in the case of a bad request. + * + * @param[in] pfrom The peer that we received the request from + * @param[in] vRecv The raw message received + * @param[in] chain_params Chain parameters + * @param[in] connman Pointer to the connection manager + */ +static void ProcessGetCFHeaders(CNode* pfrom, CDataStream& vRecv, const CChainParams& chain_params, + CConnman* connman) +{ + uint8_t filter_type_ser; + uint32_t start_height; + uint256 stop_hash; + + vRecv >> filter_type_ser >> start_height >> stop_hash; + + const BlockFilterType filter_type = static_cast<BlockFilterType>(filter_type_ser); + + const CBlockIndex* stop_index; + BlockFilterIndex* filter_index; + if (!PrepareBlockFilterRequest(pfrom, chain_params, filter_type, start_height, stop_hash, + MAX_GETCFHEADERS_SIZE, stop_index, filter_index)) { + return; + } + + uint256 prev_header; + if (start_height > 0) { + const CBlockIndex* const prev_block = + stop_index->GetAncestor(static_cast<int>(start_height - 1)); + if (!filter_index->LookupFilterHeader(prev_block, prev_header)) { + LogPrint(BCLog::NET, "Failed to find block filter header in index: filter_type=%s, block_hash=%s\n", + BlockFilterTypeName(filter_type), prev_block->GetBlockHash().ToString()); + return; + } + } + + std::vector<uint256> filter_hashes; + if (!filter_index->LookupFilterHashRange(start_height, stop_index, filter_hashes)) { + LogPrint(BCLog::NET, "Failed to find block filter hashes in index: filter_type=%s, start_height=%d, stop_hash=%s\n", + BlockFilterTypeName(filter_type), start_height, stop_hash.ToString()); + return; + } + + CSerializedNetMsg msg = CNetMsgMaker(pfrom->GetSendVersion()) + .Make(NetMsgType::CFHEADERS, + filter_type_ser, + stop_index->GetBlockHash(), + prev_header, + filter_hashes); + connman->PushMessage(pfrom, std::move(msg)); +} + +/** * Handle a getcfcheckpt request. * * May disconnect from the peer in the case of a bad request. @@ -2045,8 +2127,9 @@ static void ProcessGetCFCheckPt(CNode* pfrom, CDataStream& vRecv, const CChainPa const BlockFilterType filter_type = static_cast<BlockFilterType>(filter_type_ser); const CBlockIndex* stop_index; - const BlockFilterIndex* filter_index; - if (!PrepareBlockFilterRequest(pfrom, chain_params, filter_type, stop_hash, + BlockFilterIndex* filter_index; + if (!PrepareBlockFilterRequest(pfrom, chain_params, filter_type, /*start_height=*/0, stop_hash, + /*max_height_diff=*/std::numeric_limits<uint32_t>::max(), stop_index, filter_index)) { return; } @@ -2074,7 +2157,7 @@ static void ProcessGetCFCheckPt(CNode* pfrom, CDataStream& vRecv, const CChainPa connman->PushMessage(pfrom, std::move(msg)); } -bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc) +bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, ChainstateManager& chainman, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc) { LogPrint(BCLog::NET, "received: %s (%u bytes) peer=%d\n", SanitizeString(msg_type), vRecv.size(), pfrom->GetId()); if (gArgs.IsArgSet("-dropmessagestest") && GetRand(gArgs.GetArg("-dropmessagestest", 0)) == 0) @@ -2420,6 +2503,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec uint32_t nFetchFlags = GetFetchFlags(pfrom); const auto current_time = GetTime<std::chrono::microseconds>(); + uint256* best_block{nullptr}; for (CInv &inv : vInv) { @@ -2436,17 +2520,14 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (inv.type == MSG_BLOCK) { UpdateBlockAvailability(pfrom->GetId(), inv.hash); if (!fAlreadyHave && !fImporting && !fReindex && !mapBlocksInFlight.count(inv.hash)) { - // We used to request the full block here, but since headers-announcements are now the - // primary method of announcement on the network, and since, in the case that a node - // fell back to inv we probably have a reorg which we should get the headers for first, - // we now only provide a getheaders response here. When we receive the headers, we will - // then ask for the blocks we need. - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), inv.hash)); - LogPrint(BCLog::NET, "getheaders (%d) %s to peer=%d\n", pindexBestHeader->nHeight, inv.hash.ToString(), pfrom->GetId()); + // Headers-first is the primary method of announcement on + // the network. If a node fell back to sending blocks by inv, + // it's probably for a re-org. The final block hash + // provided should be the highest, so send a getheaders and + // then fetch the blocks we need to catch up. + best_block = &inv.hash; } - } - else - { + } else { pfrom->AddInventoryKnown(inv); if (fBlocksOnly) { LogPrint(BCLog::NET, "transaction (%s) inv sent in violation of protocol, disconnecting peer=%d\n", inv.hash.ToString(), pfrom->GetId()); @@ -2457,6 +2538,12 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec } } } + + if (best_block != nullptr) { + connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), *best_block)); + LogPrint(BCLog::NET, "getheaders (%d) %s to peer=%d\n", pindexBestHeader->nHeight, best_block->ToString(), pfrom->GetId()); + } + return true; } @@ -2667,8 +2754,8 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (msg_type == NetMsgType::TX) { // Stop processing the transaction early if - // We are in blocks only mode and peer is either not whitelisted or whitelistrelay is off - // or if this peer is supposed to be a block-relay-only peer + // 1) We are in blocks only mode and peer has no relay permission + // 2) This peer is a block-relay-only peer if ((!g_relay_txes && !pfrom->HasPermission(PF_RELAY)) || (pfrom->m_tx_relay == nullptr)) { LogPrint(BCLog::NET, "transaction sent in violation of protocol peer=%d\n", pfrom->GetId()); @@ -2837,7 +2924,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec const CBlockIndex *pindex = nullptr; BlockValidationState state; - if (!ProcessNewBlockHeaders({cmpctblock.header}, state, chainparams, &pindex)) { + if (!chainman.ProcessNewBlockHeaders({cmpctblock.header}, state, chainparams, &pindex)) { if (state.IsInvalid()) { MaybePunishNodeForBlock(pfrom->GetId(), state, /*via_compact_block*/ true, "invalid header via cmpctblock"); return true; @@ -2981,7 +3068,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec } // cs_main if (fProcessBLOCKTXN) - return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, nTimeReceived, chainparams, mempool, connman, banman, interruptMsgProc); + return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, nTimeReceived, chainparams, chainman, mempool, connman, banman, interruptMsgProc); if (fRevertToHeaderProcessing) { // Headers received from HB compact block peers are permitted to be @@ -2989,7 +3076,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // the peer if the header turns out to be for an invalid block. // Note that if a peer tries to build on an invalid chain, that // will be detected and the peer will be banned. - return ProcessHeadersMessage(pfrom, connman, mempool, {cmpctblock.header}, chainparams, /*via_compact_block=*/true); + return ProcessHeadersMessage(pfrom, connman, chainman, mempool, {cmpctblock.header}, chainparams, /*via_compact_block=*/true); } if (fBlockReconstructed) { @@ -3009,7 +3096,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // we have a chain with at least nMinimumChainWork), and we ignore // compact blocks with less work than our tip, it is safe to treat // reconstructed compact blocks as having been requested. - ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); + chainman.ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); if (fNewBlock) { pfrom->nLastBlockTime = GetTime(); } else { @@ -3099,7 +3186,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // disk-space attacks), but this should be safe due to the // protections in the compact block handler -- see related comment // in compact block optimistic reconstruction handling. - ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); + chainman.ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); if (fNewBlock) { pfrom->nLastBlockTime = GetTime(); } else { @@ -3133,7 +3220,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec ReadCompactSize(vRecv); // ignore tx count; assume it is 0. } - return ProcessHeadersMessage(pfrom, connman, mempool, headers, chainparams, /*via_compact_block=*/false); + return ProcessHeadersMessage(pfrom, connman, chainman, mempool, headers, chainparams, /*via_compact_block=*/false); } if (msg_type == NetMsgType::BLOCK) @@ -3162,7 +3249,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec mapBlockSource.emplace(hash, std::make_pair(pfrom->GetId(), true)); } bool fNewBlock = false; - ProcessNewBlock(chainparams, pblock, forceProcessing, &fNewBlock); + chainman.ProcessNewBlock(chainparams, pblock, forceProcessing, &fNewBlock); if (fNewBlock) { pfrom->nLastBlockTime = GetTime(); } else { @@ -3379,6 +3466,11 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec return true; } + if (msg_type == NetMsgType::GETCFHEADERS) { + ProcessGetCFHeaders(pfrom, vRecv, chainparams, connman); + return true; + } + if (msg_type == NetMsgType::GETCFCHECKPT) { ProcessGetCFCheckPt(pfrom, vRecv, chainparams, connman); return true; @@ -3523,7 +3615,7 @@ bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& inter bool fRet = false; try { - fRet = ProcessMessage(pfrom, msg_type, vRecv, msg.m_time, chainparams, m_mempool, connman, m_banman, interruptMsgProc); + fRet = ProcessMessage(pfrom, msg_type, vRecv, msg.m_time, chainparams, m_chainman, m_mempool, connman, m_banman, interruptMsgProc); if (interruptMsgProc) return false; if (!pfrom->vRecvGetData.empty()) diff --git a/src/net_processing.h b/src/net_processing.h index 4033c85d07..ec758c7537 100644 --- a/src/net_processing.h +++ b/src/net_processing.h @@ -12,6 +12,7 @@ #include <validationinterface.h> class CTxMemPool; +class ChainstateManager; extern RecursiveMutex cs_main; extern RecursiveMutex g_cs_orphans; @@ -27,12 +28,13 @@ class PeerLogicValidation final : public CValidationInterface, public NetEventsI private: CConnman* const connman; BanMan* const m_banman; + ChainstateManager& m_chainman; CTxMemPool& m_mempool; bool CheckIfBanned(CNode* pnode) EXCLUSIVE_LOCKS_REQUIRED(cs_main); public: - PeerLogicValidation(CConnman* connman, BanMan* banman, CScheduler& scheduler, CTxMemPool& pool); + PeerLogicValidation(CConnman* connman, BanMan* banman, CScheduler& scheduler, ChainstateManager& chainman, CTxMemPool& pool); /** * Overridden from CValidationInterface. diff --git a/src/netaddress.h b/src/netaddress.h index d8f19deffe..e640c07d32 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -99,12 +99,7 @@ class CNetAddr friend bool operator!=(const CNetAddr& a, const CNetAddr& b) { return !(a == b); } friend bool operator<(const CNetAddr& a, const CNetAddr& b); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(ip); - } + SERIALIZE_METHODS(CNetAddr, obj) { READWRITE(obj.ip); } friend class CSubNet; }; @@ -136,14 +131,7 @@ class CSubNet friend bool operator!=(const CSubNet& a, const CSubNet& b) { return !(a == b); } friend bool operator<(const CSubNet& a, const CSubNet& b); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(network); - READWRITE(netmask); - READWRITE(valid); - } + SERIALIZE_METHODS(CSubNet, obj) { READWRITE(obj.network, obj.netmask, obj.valid); } }; /** A combination of a network address (CNetAddr) and a (TCP) port */ @@ -171,13 +159,7 @@ class CService : public CNetAddr CService(const struct in6_addr& ipv6Addr, unsigned short port); explicit CService(const struct sockaddr_in6& addr); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(ip); - READWRITE(WrapBigEndian(port)); - } + SERIALIZE_METHODS(CService, obj) { READWRITE(obj.ip, Using<BigEndianFormatter<2>>(obj.port)); } }; bool SanityCheckASMap(const std::vector<bool>& asmap); diff --git a/src/node/coinstats.cpp b/src/node/coinstats.cpp index ec52a08ace..e3c4c828b6 100644 --- a/src/node/coinstats.cpp +++ b/src/node/coinstats.cpp @@ -33,7 +33,7 @@ static void ApplyStats(CCoinsStats &stats, CHashWriter& ss, const uint256& hash, } //! Calculate statistics about the unspent transaction output set -bool GetUTXOStats(CCoinsView *view, CCoinsStats &stats) +bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void()>& interruption_point) { stats = CCoinsStats(); std::unique_ptr<CCoinsViewCursor> pcursor(view->Cursor()); @@ -49,6 +49,7 @@ bool GetUTXOStats(CCoinsView *view, CCoinsStats &stats) uint256 prevkey; std::map<uint32_t, Coin> outputs; while (pcursor->Valid()) { + interruption_point(); COutPoint key; Coin coin; if (pcursor->GetKey(key) && pcursor->GetValue(coin)) { diff --git a/src/node/coinstats.h b/src/node/coinstats.h index a19af0fd1b..d9cdaa3036 100644 --- a/src/node/coinstats.h +++ b/src/node/coinstats.h @@ -10,6 +10,7 @@ #include <uint256.h> #include <cstdint> +#include <functional> class CCoinsView; @@ -29,6 +30,6 @@ struct CCoinsStats }; //! Calculate statistics about the unspent transaction output set -bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats); +bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void()>& interruption_point = {}); #endif // BITCOIN_NODE_COINSTATS_H diff --git a/src/node/context.h b/src/node/context.h index 566ff170be..c45d9e6689 100644 --- a/src/node/context.h +++ b/src/node/context.h @@ -5,6 +5,7 @@ #ifndef BITCOIN_NODE_CONTEXT_H #define BITCOIN_NODE_CONTEXT_H +#include <cassert> #include <memory> #include <vector> @@ -13,6 +14,7 @@ class BanMan; class CConnman; class CScheduler; class CTxMemPool; +class ChainstateManager; class PeerLogicValidation; namespace interfaces { class Chain; @@ -33,6 +35,7 @@ struct NodeContext { std::unique_ptr<CConnman> connman; CTxMemPool* mempool{nullptr}; // Currently a raw pointer because the memory is not managed by this struct std::unique_ptr<PeerLogicValidation> peer_logic; + ChainstateManager* chainman{nullptr}; // Currently a raw pointer because the memory is not managed by this struct std::unique_ptr<BanMan> banman; ArgsManager* args{nullptr}; // Currently a raw pointer because the memory is not managed by this struct std::unique_ptr<interfaces::Chain> chain; @@ -46,4 +49,10 @@ struct NodeContext { ~NodeContext(); }; +inline ChainstateManager& EnsureChainman(const NodeContext& node) +{ + assert(node.chainman); + return *node.chainman; +} + #endif // BITCOIN_NODE_CONTEXT_H diff --git a/src/node/utxo_snapshot.h b/src/node/utxo_snapshot.h index 702a0cbe53..c8b4d60fd0 100644 --- a/src/node/utxo_snapshot.h +++ b/src/node/utxo_snapshot.h @@ -35,16 +35,7 @@ public: m_coins_count(coins_count), m_nchaintx(nchaintx) { } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) - { - READWRITE(m_base_blockhash); - READWRITE(m_coins_count); - READWRITE(m_nchaintx); - } - + SERIALIZE_METHODS(SnapshotMetadata, obj) { READWRITE(obj.m_base_blockhash, obj.m_coins_count, obj.m_nchaintx); } }; #endif // BITCOIN_NODE_UTXO_SNAPSHOT_H diff --git a/src/policy/feerate.h b/src/policy/feerate.h index c040867965..61fa80c130 100644 --- a/src/policy/feerate.h +++ b/src/policy/feerate.h @@ -48,12 +48,7 @@ public: CFeeRate& operator+=(const CFeeRate& a) { nSatoshisPerK += a.nSatoshisPerK; return *this; } std::string ToString() const; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(nSatoshisPerK); - } + SERIALIZE_METHODS(CFeeRate, obj) { READWRITE(obj.nSatoshisPerK); } }; #endif // BITCOIN_POLICY_FEERATE_H diff --git a/src/primitives/block.h b/src/primitives/block.h index 750d42efbc..fd8fc8b868 100644 --- a/src/primitives/block.h +++ b/src/primitives/block.h @@ -33,17 +33,7 @@ public: SetNull(); } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(this->nVersion); - READWRITE(hashPrevBlock); - READWRITE(hashMerkleRoot); - READWRITE(nTime); - READWRITE(nBits); - READWRITE(nNonce); - } + SERIALIZE_METHODS(CBlockHeader, obj) { READWRITE(obj.nVersion, obj.hashPrevBlock, obj.hashMerkleRoot, obj.nTime, obj.nBits, obj.nNonce); } void SetNull() { @@ -89,12 +79,10 @@ public: *(static_cast<CBlockHeader*>(this)) = header; } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITEAS(CBlockHeader, *this); - READWRITE(vtx); + SERIALIZE_METHODS(CBlock, obj) + { + READWRITEAS(CBlockHeader, obj); + READWRITE(obj.vtx); } void SetNull() @@ -131,14 +119,12 @@ struct CBlockLocator explicit CBlockLocator(const std::vector<uint256>& vHaveIn) : vHave(vHaveIn) {} - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { + SERIALIZE_METHODS(CBlockLocator, obj) + { int nVersion = s.GetVersion(); if (!(s.GetType() & SER_GETHASH)) READWRITE(nVersion); - READWRITE(vHave); + READWRITE(obj.vHave); } void SetNull() diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index 58b3e8aedc..4514db578a 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -26,13 +26,7 @@ public: COutPoint(): n(NULL_INDEX) { } COutPoint(const uint256& hashIn, uint32_t nIn): hash(hashIn), n(nIn) { } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(hash); - READWRITE(n); - } + SERIALIZE_METHODS(COutPoint, obj) { READWRITE(obj.hash, obj.n); } void SetNull() { hash.SetNull(); n = NULL_INDEX; } bool IsNull() const { return (hash.IsNull() && n == NULL_INDEX); } @@ -103,14 +97,7 @@ public: explicit CTxIn(COutPoint prevoutIn, CScript scriptSigIn=CScript(), uint32_t nSequenceIn=SEQUENCE_FINAL); CTxIn(uint256 hashPrevTx, uint32_t nOut, CScript scriptSigIn=CScript(), uint32_t nSequenceIn=SEQUENCE_FINAL); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(prevout); - READWRITE(scriptSig); - READWRITE(nSequence); - } + SERIALIZE_METHODS(CTxIn, obj) { READWRITE(obj.prevout, obj.scriptSig, obj.nSequence); } friend bool operator==(const CTxIn& a, const CTxIn& b) { @@ -143,13 +130,7 @@ public: CTxOut(const CAmount& nValueIn, CScript scriptPubKeyIn); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(nValue); - READWRITE(scriptPubKey); - } + SERIALIZE_METHODS(CTxOut, obj) { READWRITE(obj.nValue, obj.scriptPubKey); } void SetNull() { diff --git a/src/protocol.cpp b/src/protocol.cpp index 25851e786c..243111c449 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -40,6 +40,8 @@ const char *SENDCMPCT="sendcmpct"; const char *CMPCTBLOCK="cmpctblock"; const char *GETBLOCKTXN="getblocktxn"; const char *BLOCKTXN="blocktxn"; +const char *GETCFHEADERS="getcfheaders"; +const char *CFHEADERS="cfheaders"; const char *GETCFCHECKPT="getcfcheckpt"; const char *CFCHECKPT="cfcheckpt"; } // namespace NetMsgType @@ -73,6 +75,8 @@ const static std::string allNetMessageTypes[] = { NetMsgType::CMPCTBLOCK, NetMsgType::GETBLOCKTXN, NetMsgType::BLOCKTXN, + NetMsgType::GETCFHEADERS, + NetMsgType::CFHEADERS, NetMsgType::GETCFCHECKPT, NetMsgType::CFCHECKPT, }; @@ -147,24 +151,6 @@ void SetServiceFlagsIBDCache(bool state) { g_initial_block_download_completed = state; } - -CAddress::CAddress() : CService() -{ - Init(); -} - -CAddress::CAddress(CService ipIn, ServiceFlags nServicesIn) : CService(ipIn) -{ - Init(); - nServices = nServicesIn; -} - -void CAddress::Init() -{ - nServices = NODE_NONE; - nTime = 100000000; -} - CInv::CInv() { type = 0; diff --git a/src/protocol.h b/src/protocol.h index dfcb0e0660..9527dce960 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -46,16 +46,7 @@ public: std::string GetCommand() const; bool IsValid(const MessageStartChars& messageStart) const; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) - { - READWRITE(pchMessageStart); - READWRITE(pchCommand); - READWRITE(nMessageSize); - READWRITE(pchChecksum); - } + SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); } char pchMessageStart[MESSAGE_START_SIZE]; char pchCommand[COMMAND_SIZE]; @@ -74,100 +65,100 @@ namespace NetMsgType { * receiving node at the beginning of a connection. * @see https://bitcoin.org/en/developer-reference#version */ -extern const char *VERSION; +extern const char* VERSION; /** * The verack message acknowledges a previously-received version message, * informing the connecting node that it can begin to send other messages. * @see https://bitcoin.org/en/developer-reference#verack */ -extern const char *VERACK; +extern const char* VERACK; /** * The addr (IP address) message relays connection information for peers on the * network. * @see https://bitcoin.org/en/developer-reference#addr */ -extern const char *ADDR; +extern const char* ADDR; /** * The inv message (inventory message) transmits one or more inventories of * objects known to the transmitting peer. * @see https://bitcoin.org/en/developer-reference#inv */ -extern const char *INV; +extern const char* INV; /** * The getdata message requests one or more data objects from another node. * @see https://bitcoin.org/en/developer-reference#getdata */ -extern const char *GETDATA; +extern const char* GETDATA; /** * The merkleblock message is a reply to a getdata message which requested a * block using the inventory type MSG_MERKLEBLOCK. * @since protocol version 70001 as described by BIP37. * @see https://bitcoin.org/en/developer-reference#merkleblock */ -extern const char *MERKLEBLOCK; +extern const char* MERKLEBLOCK; /** * The getblocks message requests an inv message that provides block header * hashes starting from a particular point in the block chain. * @see https://bitcoin.org/en/developer-reference#getblocks */ -extern const char *GETBLOCKS; +extern const char* GETBLOCKS; /** * The getheaders message requests a headers message that provides block * headers starting from a particular point in the block chain. * @since protocol version 31800. * @see https://bitcoin.org/en/developer-reference#getheaders */ -extern const char *GETHEADERS; +extern const char* GETHEADERS; /** * The tx message transmits a single transaction. * @see https://bitcoin.org/en/developer-reference#tx */ -extern const char *TX; +extern const char* TX; /** * The headers message sends one or more block headers to a node which * previously requested certain headers with a getheaders message. * @since protocol version 31800. * @see https://bitcoin.org/en/developer-reference#headers */ -extern const char *HEADERS; +extern const char* HEADERS; /** * The block message transmits a single serialized block. * @see https://bitcoin.org/en/developer-reference#block */ -extern const char *BLOCK; +extern const char* BLOCK; /** * The getaddr message requests an addr message from the receiving node, * preferably one with lots of IP addresses of other receiving nodes. * @see https://bitcoin.org/en/developer-reference#getaddr */ -extern const char *GETADDR; +extern const char* GETADDR; /** * The mempool message requests the TXIDs of transactions that the receiving * node has verified as valid but which have not yet appeared in a block. * @since protocol version 60002. * @see https://bitcoin.org/en/developer-reference#mempool */ -extern const char *MEMPOOL; +extern const char* MEMPOOL; /** * The ping message is sent periodically to help confirm that the receiving * peer is still connected. * @see https://bitcoin.org/en/developer-reference#ping */ -extern const char *PING; +extern const char* PING; /** * The pong message replies to a ping message, proving to the pinging node that * the ponging node is still alive. * @since protocol version 60001 as described by BIP31. * @see https://bitcoin.org/en/developer-reference#pong */ -extern const char *PONG; +extern const char* PONG; /** * The notfound message is a reply to a getdata message which requested an * object the receiving node does not have available for relay. * @since protocol version 70001. * @see https://bitcoin.org/en/developer-reference#notfound */ -extern const char *NOTFOUND; +extern const char* NOTFOUND; /** * The filterload message tells the receiving peer to filter all relayed * transactions and requested merkle blocks through the provided filter. @@ -176,7 +167,7 @@ extern const char *NOTFOUND; * 70011 as described by BIP111. * @see https://bitcoin.org/en/developer-reference#filterload */ -extern const char *FILTERLOAD; +extern const char* FILTERLOAD; /** * The filteradd message tells the receiving peer to add a single element to a * previously-set bloom filter, such as a new public key. @@ -185,7 +176,7 @@ extern const char *FILTERLOAD; * 70011 as described by BIP111. * @see https://bitcoin.org/en/developer-reference#filteradd */ -extern const char *FILTERADD; +extern const char* FILTERADD; /** * The filterclear message tells the receiving peer to remove a previously-set * bloom filter. @@ -194,20 +185,20 @@ extern const char *FILTERADD; * 70011 as described by BIP111. * @see https://bitcoin.org/en/developer-reference#filterclear */ -extern const char *FILTERCLEAR; +extern const char* FILTERCLEAR; /** * Indicates that a node prefers to receive new block announcements via a * "headers" message rather than an "inv". * @since protocol version 70012 as described by BIP130. * @see https://bitcoin.org/en/developer-reference#sendheaders */ -extern const char *SENDHEADERS; +extern const char* SENDHEADERS; /** * The feefilter message tells the receiving peer not to inv us any txs * which do not meet the specified min fee rate. * @since protocol version 70013 as described by BIP133 */ -extern const char *FEEFILTER; +extern const char* FEEFILTER; /** * Contains a 1-byte bool and 8-byte LE version number. * Indicates that a node is willing to provide blocks via "cmpctblock" messages. @@ -215,43 +206,54 @@ extern const char *FEEFILTER; * "cmpctblock" message rather than an "inv", depending on message contents. * @since protocol version 70014 as described by BIP 152 */ -extern const char *SENDCMPCT; +extern const char* SENDCMPCT; /** * Contains a CBlockHeaderAndShortTxIDs object - providing a header and * list of "short txids". * @since protocol version 70014 as described by BIP 152 */ -extern const char *CMPCTBLOCK; +extern const char* CMPCTBLOCK; /** * Contains a BlockTransactionsRequest * Peer should respond with "blocktxn" message. * @since protocol version 70014 as described by BIP 152 */ -extern const char *GETBLOCKTXN; +extern const char* GETBLOCKTXN; /** * Contains a BlockTransactions. * Sent in response to a "getblocktxn" message. * @since protocol version 70014 as described by BIP 152 */ -extern const char *BLOCKTXN; +extern const char* BLOCKTXN; +/** + * getcfheaders requests a compact filter header and the filter hashes for a + * range of blocks, which can then be used to reconstruct the filter headers + * for those blocks. + * Only available with service bit NODE_COMPACT_FILTERS as described by + * BIP 157 & 158. + */ +extern const char* GETCFHEADERS; +/** + * cfheaders is a response to a getcfheaders request containing a filter header + * and a vector of filter hashes for each subsequent block in the requested range. + */ +extern const char* CFHEADERS; /** * getcfcheckpt requests evenly spaced compact filter headers, enabling * parallelized download and validation of the headers between them. * Only available with service bit NODE_COMPACT_FILTERS as described by * BIP 157 & 158. */ -extern const char *GETCFCHECKPT; +extern const char* GETCFCHECKPT; /** * cfcheckpt is a response to a getcfcheckpt request containing a vector of * evenly spaced filter headers for blocks on the requested chain. - * Only available with service bit NODE_COMPACT_FILTERS as described by - * BIP 157 & 158. */ -extern const char *CFCHECKPT; -}; +extern const char* CFCHECKPT; +}; // namespace NetMsgType /* Get a vector of all valid message types (see above) */ -const std::vector<std::string> &getAllNetMessageTypes(); +const std::vector<std::string>& getAllNetMessageTypes(); /** nServices flags */ enum ServiceFlags : uint64_t { @@ -320,7 +322,8 @@ void SetServiceFlagsIBDCache(bool status); * == GetDesirableServiceFlags(services), ie determines whether the given * set of service flags are sufficient for a peer to be "relevant". */ -static inline bool HasAllDesirableServiceFlags(ServiceFlags services) { +static inline bool HasAllDesirableServiceFlags(ServiceFlags services) +{ return !(GetDesirableServiceFlags(services) & (~services)); } @@ -328,62 +331,55 @@ static inline bool HasAllDesirableServiceFlags(ServiceFlags services) { * Checks if a peer with the given service flags may be capable of having a * robust address-storage DB. */ -static inline bool MayHaveUsefulAddressDB(ServiceFlags services) { +static inline bool MayHaveUsefulAddressDB(ServiceFlags services) +{ return (services & NODE_NETWORK) || (services & NODE_NETWORK_LIMITED); } /** A CService with information about it as peer */ class CAddress : public CService { -public: - CAddress(); - explicit CAddress(CService ipIn, ServiceFlags nServicesIn); + static constexpr uint32_t TIME_INIT{100000000}; - void Init(); - - ADD_SERIALIZE_METHODS; +public: + CAddress() : CService{} {}; + explicit CAddress(CService ipIn, ServiceFlags nServicesIn) : CService{ipIn}, nServices{nServicesIn} {}; - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) + SERIALIZE_METHODS(CAddress, obj) { - if (ser_action.ForRead()) - Init(); + SER_READ(obj, obj.nTime = TIME_INIT); int nVersion = s.GetVersion(); - if (s.GetType() & SER_DISK) + if (s.GetType() & SER_DISK) { READWRITE(nVersion); + } if ((s.GetType() & SER_DISK) || - (nVersion >= CADDR_TIME_VERSION && !(s.GetType() & SER_GETHASH))) - READWRITE(nTime); - uint64_t nServicesInt = nServices; - READWRITE(nServicesInt); - nServices = static_cast<ServiceFlags>(nServicesInt); - READWRITEAS(CService, *this); + (nVersion >= CADDR_TIME_VERSION && !(s.GetType() & SER_GETHASH))) { + READWRITE(obj.nTime); + } + READWRITE(Using<CustomUintFormatter<8>>(obj.nServices)); + READWRITEAS(CService, obj); } - // TODO: make private (improves encapsulation) -public: - ServiceFlags nServices; - + ServiceFlags nServices{NODE_NONE}; // disk and network only - unsigned int nTime; + uint32_t nTime{TIME_INIT}; }; /** getdata message type flags */ const uint32_t MSG_WITNESS_FLAG = 1 << 30; -const uint32_t MSG_TYPE_MASK = 0xffffffff >> 2; +const uint32_t MSG_TYPE_MASK = 0xffffffff >> 2; /** getdata / inv message types. * These numbers are defined by the protocol. When adding a new value, be sure * to mention it in the respective BIP. */ -enum GetDataMsg -{ +enum GetDataMsg { UNDEFINED = 0, MSG_TX = 1, MSG_BLOCK = 2, // The following can only occur in getdata. Invs always use TX or BLOCK. - MSG_FILTERED_BLOCK = 3, //!< Defined in BIP37 - MSG_CMPCT_BLOCK = 4, //!< Defined in BIP152 + MSG_FILTERED_BLOCK = 3, //!< Defined in BIP37 + MSG_CMPCT_BLOCK = 4, //!< Defined in BIP152 MSG_WITNESS_BLOCK = MSG_BLOCK | MSG_WITNESS_FLAG, //!< Defined in BIP144 MSG_WITNESS_TX = MSG_TX | MSG_WITNESS_FLAG, //!< Defined in BIP144 MSG_FILTERED_WITNESS_BLOCK = MSG_FILTERED_BLOCK | MSG_WITNESS_FLAG, @@ -396,21 +392,13 @@ public: CInv(); CInv(int typeIn, const uint256& hashIn); - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) - { - READWRITE(type); - READWRITE(hash); - } + SERIALIZE_METHODS(CInv, obj) { READWRITE(obj.type, obj.hash); } friend bool operator<(const CInv& a, const CInv& b); std::string GetCommand() const; std::string ToString() const; -public: int type; uint256 hash; }; diff --git a/src/qt/bitcoin.cpp b/src/qt/bitcoin.cpp index 8939b566f7..6fdf6322ff 100644 --- a/src/qt/bitcoin.cpp +++ b/src/qt/bitcoin.cpp @@ -34,6 +34,7 @@ #include <uint256.h> #include <util/system.h> #include <util/threadnames.h> +#include <validation.h> #include <memory> @@ -61,6 +62,7 @@ Q_IMPORT_PLUGIN(QCocoaIntegrationPlugin); // Declare meta types used for QMetaObject::invokeMethod Q_DECLARE_METATYPE(bool*) Q_DECLARE_METATYPE(CAmount) +Q_DECLARE_METATYPE(SynchronizationState) Q_DECLARE_METATYPE(uint256) static QString GetLangTerritory() @@ -435,6 +437,7 @@ int GuiMain(int argc, char* argv[]) // Register meta types used for QMetaObject::invokeMethod and Qt::QueuedConnection qRegisterMetaType<bool*>(); + qRegisterMetaType<SynchronizationState>(); #ifdef ENABLE_WALLET qRegisterMetaType<WalletModel*>(); #endif diff --git a/src/qt/bitcoingui.cpp b/src/qt/bitcoingui.cpp index 3a1fdc22a6..6192013e5f 100644 --- a/src/qt/bitcoingui.cpp +++ b/src/qt/bitcoingui.cpp @@ -37,6 +37,7 @@ #include <ui_interface.h> #include <util/system.h> #include <util/translation.h> +#include <validation.h> #include <QAction> #include <QApplication> @@ -353,6 +354,11 @@ void BitcoinGUI::createActions() showHelpMessageAction->setMenuRole(QAction::NoRole); showHelpMessageAction->setStatusTip(tr("Show the %1 help message to get a list with possible Bitcoin command-line options").arg(PACKAGE_NAME)); + m_mask_values_action = new QAction(tr("&Mask values"), this); + m_mask_values_action->setShortcut(QKeySequence(Qt::CTRL + Qt::SHIFT + Qt::Key_M)); + m_mask_values_action->setStatusTip(tr("Mask the values in the Overview tab")); + m_mask_values_action->setCheckable(true); + connect(quitAction, &QAction::triggered, qApp, QApplication::quit); connect(aboutAction, &QAction::triggered, this, &BitcoinGUI::aboutClicked); connect(aboutQtAction, &QAction::triggered, qApp, QApplication::aboutQt); @@ -415,6 +421,8 @@ void BitcoinGUI::createActions() connect(activity, &CreateWalletActivity::finished, activity, &QObject::deleteLater); activity->create(); }); + + connect(m_mask_values_action, &QAction::toggled, this, &BitcoinGUI::setPrivacy); } #endif // ENABLE_WALLET @@ -455,6 +463,8 @@ void BitcoinGUI::createMenuBar() settings->addAction(encryptWalletAction); settings->addAction(changePassphraseAction); settings->addSeparator(); + settings->addAction(m_mask_values_action); + settings->addSeparator(); } settings->addAction(optionsAction); @@ -567,7 +577,7 @@ void BitcoinGUI::setClientModel(ClientModel *_clientModel) connect(_clientModel, &ClientModel::networkActiveChanged, this, &BitcoinGUI::setNetworkActive); modalOverlay->setKnownBestHeight(_clientModel->getHeaderTipHeight(), QDateTime::fromTime_t(_clientModel->getHeaderTipTime())); - setNumBlocks(m_node.getNumBlocks(), QDateTime::fromTime_t(m_node.getLastBlockTime()), m_node.getVerificationProgress(), false); + setNumBlocks(m_node.getNumBlocks(), QDateTime::fromTime_t(m_node.getLastBlockTime()), m_node.getVerificationProgress(), false, SynchronizationState::INIT_DOWNLOAD); connect(_clientModel, &ClientModel::numBlocksChanged, this, &BitcoinGUI::setNumBlocks); // Receive and report messages from client model @@ -926,11 +936,15 @@ void BitcoinGUI::openOptionsDialogWithTab(OptionsDialog::Tab tab) dlg.exec(); } -void BitcoinGUI::setNumBlocks(int count, const QDateTime& blockDate, double nVerificationProgress, bool header) +void BitcoinGUI::setNumBlocks(int count, const QDateTime& blockDate, double nVerificationProgress, bool header, SynchronizationState sync_state) { // Disabling macOS App Nap on initial sync, disk and reindex operations. #ifdef Q_OS_MAC - (m_node.isInitialBlockDownload() || m_node.getReindex() || m_node.getImporting()) ? m_app_nap_inhibitor->disableAppNap() : m_app_nap_inhibitor->enableAppNap(); + if (sync_state == SynchronizationState::POST_INIT) { + m_app_nap_inhibitor->enableAppNap(); + } else { + m_app_nap_inhibitor->disableAppNap(); + } #endif if (modalOverlay) @@ -1246,7 +1260,7 @@ void BitcoinGUI::setEncryptionStatus(int status) labelWalletEncryptionIcon->setToolTip(tr("Wallet is <b>encrypted</b> and currently <b>unlocked</b>")); encryptWalletAction->setChecked(true); changePassphraseAction->setEnabled(true); - encryptWalletAction->setEnabled(false); // TODO: decrypt currently not supported + encryptWalletAction->setEnabled(false); break; case WalletModel::Locked: labelWalletEncryptionIcon->show(); @@ -1254,7 +1268,7 @@ void BitcoinGUI::setEncryptionStatus(int status) labelWalletEncryptionIcon->setToolTip(tr("Wallet is <b>encrypted</b> and currently <b>locked</b>")); encryptWalletAction->setChecked(true); changePassphraseAction->setEnabled(true); - encryptWalletAction->setEnabled(false); // TODO: decrypt currently not supported + encryptWalletAction->setEnabled(false); break; } } @@ -1409,6 +1423,12 @@ void BitcoinGUI::unsubscribeFromCoreSignals() m_handler_question->disconnect(); } +bool BitcoinGUI::isPrivacyModeActivated() const +{ + assert(m_mask_values_action); + return m_mask_values_action->isChecked(); +} + UnitDisplayStatusBarControl::UnitDisplayStatusBarControl(const PlatformStyle *platformStyle) : optionsModel(nullptr), menu(nullptr) diff --git a/src/qt/bitcoingui.h b/src/qt/bitcoingui.h index 6733585f68..c0198dd168 100644 --- a/src/qt/bitcoingui.h +++ b/src/qt/bitcoingui.h @@ -38,6 +38,7 @@ class WalletFrame; class WalletModel; class HelpMessageDialog; class ModalOverlay; +enum class SynchronizationState; namespace interfaces { class Handler; @@ -98,6 +99,8 @@ public: /** Disconnect core signals from GUI client */ void unsubscribeFromCoreSignals(); + bool isPrivacyModeActivated() const; + protected: void changeEvent(QEvent *e) override; void closeEvent(QCloseEvent *event) override; @@ -154,6 +157,7 @@ private: QAction* m_close_wallet_action{nullptr}; QAction* m_wallet_selector_label_action = nullptr; QAction* m_wallet_selector_action = nullptr; + QAction* m_mask_values_action{nullptr}; QLabel *m_wallet_selector_label = nullptr; QComboBox* m_wallet_selector = nullptr; @@ -206,6 +210,7 @@ Q_SIGNALS: void receivedURI(const QString &uri); /** Signal raised when RPC console shown */ void consoleShown(RPCConsole* console); + void setPrivacy(bool privacy); public Q_SLOTS: /** Set number of connections shown in the UI */ @@ -213,7 +218,7 @@ public Q_SLOTS: /** Set network state shown in the UI */ void setNetworkActive(bool networkActive); /** Set number of blocks and last block date shown in the UI */ - void setNumBlocks(int count, const QDateTime& blockDate, double nVerificationProgress, bool headers); + void setNumBlocks(int count, const QDateTime& blockDate, double nVerificationProgress, bool headers, SynchronizationState sync_state); /** Notify the user of an event from the core network or transaction handling code. @param[in] title the message box / notification title diff --git a/src/qt/bitcoinunits.cpp b/src/qt/bitcoinunits.cpp index d9711af123..318a6dcbfd 100644 --- a/src/qt/bitcoinunits.cpp +++ b/src/qt/bitcoinunits.cpp @@ -6,6 +6,8 @@ #include <QStringList> +#include <cassert> + BitcoinUnits::BitcoinUnits(QObject *parent): QAbstractListModel(parent), unitlist(availableUnits()) @@ -94,7 +96,7 @@ int BitcoinUnits::decimals(int unit) } } -QString BitcoinUnits::format(int unit, const CAmount& nIn, bool fPlus, SeparatorStyle separators) +QString BitcoinUnits::format(int unit, const CAmount& nIn, bool fPlus, SeparatorStyle separators, bool justify) { // Note: not using straight sprintf here because we do NOT want // localized number formatting. @@ -106,6 +108,7 @@ QString BitcoinUnits::format(int unit, const CAmount& nIn, bool fPlus, Separator qint64 n_abs = (n > 0 ? n : -n); qint64 quotient = n_abs / coin; QString quotient_str = QString::number(quotient); + if (justify) quotient_str = quotient_str.rightJustified(16 - num_decimals, ' '); // Use SI-style thin space separators as these are locale independent and can't be // confused with the decimal marker. @@ -150,6 +153,17 @@ QString BitcoinUnits::formatHtmlWithUnit(int unit, const CAmount& amount, bool p return QString("<span style='white-space: nowrap;'>%1</span>").arg(str); } +QString BitcoinUnits::formatWithPrivacy(int unit, const CAmount& amount, SeparatorStyle separators, bool privacy) +{ + assert(amount >= 0); + QString value; + if (privacy) { + value = format(unit, 0, false, separators, true).replace('0', '#'); + } else { + value = format(unit, amount, false, separators, true); + } + return value + QString(" ") + shortName(unit); +} bool BitcoinUnits::parse(int unit, const QString &value, CAmount *val_out) { diff --git a/src/qt/bitcoinunits.h b/src/qt/bitcoinunits.h index 1ff4702117..dac5484393 100644 --- a/src/qt/bitcoinunits.h +++ b/src/qt/bitcoinunits.h @@ -72,11 +72,13 @@ public: //! Number of decimals left static int decimals(int unit); //! Format as string - static QString format(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=separatorStandard); + static QString format(int unit, const CAmount& amount, bool plussign = false, SeparatorStyle separators = separatorStandard, bool justify = false); //! Format as string (with unit) static QString formatWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=separatorStandard); //! Format as HTML string (with unit) static QString formatHtmlWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=separatorStandard); + //! Format as string (with unit) of fixed length to preserve privacy, if it is set. + static QString formatWithPrivacy(int unit, const CAmount& amount, SeparatorStyle separators, bool privacy); //! Parse string to coin amount static bool parse(int unit, const QString &value, CAmount *val_out); //! Gets title for amount column including current display unit if optionsModel reference available */ diff --git a/src/qt/clientmodel.cpp b/src/qt/clientmodel.cpp index b94fcc9865..159b0d3df3 100644 --- a/src/qt/clientmodel.cpp +++ b/src/qt/clientmodel.cpp @@ -15,6 +15,7 @@ #include <net.h> #include <netbase.h> #include <util/system.h> +#include <validation.h> #include <stdint.h> @@ -234,17 +235,8 @@ static void BannedListChanged(ClientModel *clientmodel) assert(invoked); } -static void BlockTipChanged(ClientModel *clientmodel, bool initialSync, int height, int64_t blockTime, double verificationProgress, bool fHeader) +static void BlockTipChanged(ClientModel* clientmodel, SynchronizationState sync_state, int height, int64_t blockTime, double verificationProgress, bool fHeader) { - // lock free async UI updates in case we have a new block tip - // during initial sync, only update the UI if the last update - // was > 250ms (MODEL_UPDATE_DELAY) ago - int64_t now = 0; - if (initialSync) - now = GetTimeMillis(); - - int64_t& nLastUpdateNotification = fHeader ? nLastHeaderTipUpdateNotification : nLastBlockTipUpdateNotification; - if (fHeader) { // cache best headers time and height to reduce future cs_main locks clientmodel->cachedBestHeaderHeight = height; @@ -253,17 +245,22 @@ static void BlockTipChanged(ClientModel *clientmodel, bool initialSync, int heig clientmodel->m_cached_num_blocks = height; } - // During initial sync, block notifications, and header notifications from reindexing are both throttled. - if (!initialSync || (fHeader && !clientmodel->node().getReindex()) || now - nLastUpdateNotification > MODEL_UPDATE_DELAY) { - //pass an async signal to the UI thread - bool invoked = QMetaObject::invokeMethod(clientmodel, "numBlocksChanged", Qt::QueuedConnection, - Q_ARG(int, height), - Q_ARG(QDateTime, QDateTime::fromTime_t(blockTime)), - Q_ARG(double, verificationProgress), - Q_ARG(bool, fHeader)); - assert(invoked); - nLastUpdateNotification = now; + // Throttle GUI notifications about (a) blocks during initial sync, and (b) both blocks and headers during reindex. + const bool throttle = (sync_state != SynchronizationState::POST_INIT && !fHeader) || sync_state == SynchronizationState::INIT_REINDEX; + const int64_t now = throttle ? GetTimeMillis() : 0; + int64_t& nLastUpdateNotification = fHeader ? nLastHeaderTipUpdateNotification : nLastBlockTipUpdateNotification; + if (throttle && now < nLastUpdateNotification + MODEL_UPDATE_DELAY) { + return; } + + bool invoked = QMetaObject::invokeMethod(clientmodel, "numBlocksChanged", Qt::QueuedConnection, + Q_ARG(int, height), + Q_ARG(QDateTime, QDateTime::fromTime_t(blockTime)), + Q_ARG(double, verificationProgress), + Q_ARG(bool, fHeader), + Q_ARG(SynchronizationState, sync_state)); + assert(invoked); + nLastUpdateNotification = now; } void ClientModel::subscribeToCoreSignals() diff --git a/src/qt/clientmodel.h b/src/qt/clientmodel.h index 7ac4120a8f..ace77f5972 100644 --- a/src/qt/clientmodel.h +++ b/src/qt/clientmodel.h @@ -12,10 +12,10 @@ #include <memory> class BanTableModel; +class CBlockIndex; class OptionsModel; class PeerTableModel; - -class CBlockIndex; +enum class SynchronizationState; namespace interfaces { class Handler; @@ -100,7 +100,7 @@ private: Q_SIGNALS: void numConnectionsChanged(int count); - void numBlocksChanged(int count, const QDateTime& blockDate, double nVerificationProgress, bool header); + void numBlocksChanged(int count, const QDateTime& blockDate, double nVerificationProgress, bool header, SynchronizationState sync_state); void mempoolSizeChanged(long count, size_t mempoolSizeInBytes); void networkActiveChanged(bool networkActive); void alertsChanged(const QString &warnings); diff --git a/src/qt/forms/overviewpage.ui b/src/qt/forms/overviewpage.ui index 710801ee96..4d3f90c484 100644 --- a/src/qt/forms/overviewpage.ui +++ b/src/qt/forms/overviewpage.ui @@ -6,8 +6,8 @@ <rect> <x>0</x> <y>0</y> - <width>596</width> - <height>342</height> + <width>798</width> + <height>318</height> </rect> </property> <property name="windowTitle"> @@ -118,6 +118,7 @@ <widget class="QLabel" name="labelWatchPending"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -129,7 +130,7 @@ <string>Unconfirmed transactions to watch-only addresses</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">0.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -143,6 +144,7 @@ <widget class="QLabel" name="labelUnconfirmed"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -154,7 +156,7 @@ <string>Total of transactions that have yet to be confirmed, and do not yet count toward the spendable balance</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">0.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -168,6 +170,7 @@ <widget class="QLabel" name="labelWatchImmature"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -179,7 +182,7 @@ <string>Mined balance in watch-only addresses that has not yet matured</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">0.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -226,6 +229,7 @@ <widget class="QLabel" name="labelImmature"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -237,7 +241,7 @@ <string>Mined balance that has not yet matured</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">0.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -271,6 +275,7 @@ <widget class="QLabel" name="labelTotal"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -282,7 +287,7 @@ <string>Your current total balance</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">21 000 000.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -296,6 +301,7 @@ <widget class="QLabel" name="labelWatchTotal"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -307,7 +313,7 @@ <string>Current total balance in watch-only addresses</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">21 000 000.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -338,6 +344,7 @@ <widget class="QLabel" name="labelBalance"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -349,7 +356,7 @@ <string>Your current spendable balance</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">21 000 000.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> @@ -363,6 +370,7 @@ <widget class="QLabel" name="labelWatchAvailable"> <property name="font"> <font> + <family>Monospace</family> <weight>75</weight> <bold>true</bold> </font> @@ -374,7 +382,7 @@ <string>Your current balance in watch-only addresses</string> </property> <property name="text"> - <string notr="true">0.000 000 00 BTC</string> + <string notr="true">21 000 000.00000000 BTC</string> </property> <property name="alignment"> <set>Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter</set> diff --git a/src/qt/forms/receiverequestdialog.ui b/src/qt/forms/receiverequestdialog.ui index 9f896ee3b1..f6d4723465 100644 --- a/src/qt/forms/receiverequestdialog.ui +++ b/src/qt/forms/receiverequestdialog.ui @@ -6,68 +6,233 @@ <rect> <x>0</x> <y>0</y> - <width>487</width> - <height>597</height> + <width>413</width> + <height>229</height> </rect> </property> - <layout class="QVBoxLayout" name="verticalLayout_3"> - <item> - <widget class="QRImageWidget" name="lblQRCode"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Expanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>300</width> - <height>320</height> - </size> - </property> - <property name="toolTip"> - <string>QR Code</string> + <property name="windowTitle"> + <string>Request payment to ...</string> + </property> + <layout class="QGridLayout" name="gridLayout" columnstretch="0,1"> + <property name="sizeConstraint"> + <enum>QLayout::SetFixedSize</enum> + </property> + <item row="0" column="0" colspan="2" alignment="Qt::AlignHCenter"> + <widget class="QRImageWidget" name="qr_code"> + <property name="text"> + <string notr="true">QR image</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="1" column="0" colspan="2"> + <widget class="QLabel" name="payment_header"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Payment information</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="2" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="uri_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string notr="true">URI:</string> </property> <property name="textFormat"> <enum>Qt::PlainText</enum> </property> - <property name="alignment"> - <set>Qt::AlignCenter</set> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="2" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="uri_content"> + <property name="text"> + <string notr="true">bitcoin:BC1...</string> + </property> + <property name="textFormat"> + <enum>Qt::RichText</enum> </property> <property name="wordWrap"> <bool>true</bool> </property> + <property name="textInteractionFlags"> + <set>Qt::TextSelectableByMouse</set> + </property> + </widget> + </item> + <item row="3" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="address_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Address:</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> </widget> </item> - <item> - <widget class="QTextEdit" name="outUri"> - <property name="sizePolicy"> - <sizepolicy hsizetype="Expanding" vsizetype="Expanding"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>0</width> - <height>50</height> - </size> - </property> - <property name="frameShape"> - <enum>QFrame::NoFrame</enum> - </property> - <property name="frameShadow"> - <enum>QFrame::Plain</enum> - </property> - <property name="tabChangesFocus"> + <item row="3" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="address_content"> + <property name="text"> + <string notr="true">bc1...</string> + </property> + <property name="textFormat"> + <enum>Qt::PlainText</enum> + </property> + <property name="textInteractionFlags"> + <set>Qt::TextSelectableByMouse</set> + </property> + </widget> + </item> + <item row="4" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="amount_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Amount:</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="4" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="amount_content"> + <property name="text"> + <string notr="true">0.00000000 BTC</string> + </property> + <property name="textFormat"> + <enum>Qt::PlainText</enum> + </property> + <property name="textInteractionFlags"> + <set>Qt::TextSelectableByMouse</set> + </property> + </widget> + </item> + <item row="5" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="label_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Label:</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="5" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="label_content"> + <property name="text"> + <string notr="true">label content</string> + </property> + <property name="textFormat"> + <enum>Qt::PlainText</enum> + </property> + <property name="wordWrap"> + <bool>true</bool> + </property> + <property name="textInteractionFlags"> + <set>Qt::TextSelectableByMouse</set> + </property> + </widget> + </item> + <item row="6" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="message_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Message:</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="6" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="message_content"> + <property name="text"> + <string notr="true">message content</string> + </property> + <property name="textFormat"> + <enum>Qt::PlainText</enum> + </property> + <property name="wordWrap"> + <bool>true</bool> + </property> + <property name="textInteractionFlags"> + <set>Qt::TextSelectableByMouse</set> + </property> + </widget> + </item> + <item row="7" column="0" alignment="Qt::AlignRight|Qt::AlignTop"> + <widget class="QLabel" name="wallet_tag"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="text"> + <string>Wallet:</string> + </property> + <property name="textInteractionFlags"> + <set>Qt::NoTextInteraction</set> + </property> + </widget> + </item> + <item row="7" column="1" alignment="Qt::AlignTop"> + <widget class="QLabel" name="wallet_content"> + <property name="text"> + <string notr="true">wallet name</string> + </property> + <property name="textFormat"> + <enum>Qt::PlainText</enum> + </property> + <property name="wordWrap"> <bool>true</bool> </property> <property name="textInteractionFlags"> - <set>Qt::TextSelectableByKeyboard|Qt::TextSelectableByMouse</set> + <set>Qt::TextSelectableByMouse</set> </property> </widget> </item> - <item> + <item row="8" column="0" colspan="2"> <layout class="QHBoxLayout" name="horizontalLayout"> <item> <widget class="QPushButton" name="btnCopyURI"> @@ -114,8 +279,11 @@ </item> <item> <widget class="QDialogButtonBox" name="buttonBox"> + <property name="focusPolicy"> + <enum>Qt::StrongFocus</enum> + </property> <property name="standardButtons"> - <set>QDialogButtonBox::Close</set> + <set>QDialogButtonBox::Ok</set> </property> </widget> </item> @@ -130,37 +298,27 @@ <header>qt/qrimagewidget.h</header> </customwidget> </customwidgets> + <tabstops> + <tabstop>buttonBox</tabstop> + <tabstop>btnCopyURI</tabstop> + <tabstop>btnCopyAddress</tabstop> + <tabstop>btnSaveAs</tabstop> + </tabstops> <resources/> <connections> <connection> <sender>buttonBox</sender> - <signal>rejected()</signal> - <receiver>ReceiveRequestDialog</receiver> - <slot>reject()</slot> - <hints> - <hint type="sourcelabel"> - <x>452</x> - <y>573</y> - </hint> - <hint type="destinationlabel"> - <x>243</x> - <y>298</y> - </hint> - </hints> - </connection> - <connection> - <sender>buttonBox</sender> <signal>accepted()</signal> <receiver>ReceiveRequestDialog</receiver> <slot>accept()</slot> <hints> <hint type="sourcelabel"> - <x>452</x> - <y>573</y> + <x>135</x> + <y>230</y> </hint> <hint type="destinationlabel"> - <x>243</x> - <y>298</y> + <x>135</x> + <y>126</y> </hint> </hints> </connection> diff --git a/src/qt/overviewpage.cpp b/src/qt/overviewpage.cpp index e20ec229fc..0af70f2735 100644 --- a/src/qt/overviewpage.cpp +++ b/src/qt/overviewpage.cpp @@ -16,7 +16,9 @@ #include <qt/walletmodel.h> #include <QAbstractItemDelegate> +#include <QApplication> #include <QPainter> +#include <QStatusTipEvent> #define DECORATION_SIZE 54 #define NUM_ITEMS 5 @@ -152,6 +154,21 @@ void OverviewPage::handleOutOfSyncWarningClicks() Q_EMIT outOfSyncWarningClicked(); } +void OverviewPage::setPrivacy(bool privacy) +{ + m_privacy = privacy; + if (m_balances.balance != -1) { + setBalance(m_balances); + } + + ui->listTransactions->setVisible(!m_privacy); + + const QString status_tip = m_privacy ? tr("Privacy mode activated for the Overview tab. To unmask the values, uncheck Settings->Mask values.") : ""; + setStatusTip(status_tip); + QStatusTipEvent event(status_tip); + QApplication::sendEvent(this, &event); +} + OverviewPage::~OverviewPage() { delete ui; @@ -163,25 +180,25 @@ void OverviewPage::setBalance(const interfaces::WalletBalances& balances) m_balances = balances; if (walletModel->wallet().isLegacy()) { if (walletModel->wallet().privateKeysDisabled()) { - ui->labelBalance->setText(BitcoinUnits::formatWithUnit(unit, balances.watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithUnit(unit, balances.unconfirmed_watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelImmature->setText(BitcoinUnits::formatWithUnit(unit, balances.immature_watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelTotal->setText(BitcoinUnits::formatWithUnit(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, false, BitcoinUnits::separatorAlways)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); } else { - ui->labelBalance->setText(BitcoinUnits::formatWithUnit(unit, balances.balance, false, BitcoinUnits::separatorAlways)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithUnit(unit, balances.unconfirmed_balance, false, BitcoinUnits::separatorAlways)); - ui->labelImmature->setText(BitcoinUnits::formatWithUnit(unit, balances.immature_balance, false, BitcoinUnits::separatorAlways)); - ui->labelTotal->setText(BitcoinUnits::formatWithUnit(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, false, BitcoinUnits::separatorAlways)); - ui->labelWatchAvailable->setText(BitcoinUnits::formatWithUnit(unit, balances.watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelWatchPending->setText(BitcoinUnits::formatWithUnit(unit, balances.unconfirmed_watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelWatchImmature->setText(BitcoinUnits::formatWithUnit(unit, balances.immature_watch_only_balance, false, BitcoinUnits::separatorAlways)); - ui->labelWatchTotal->setText(BitcoinUnits::formatWithUnit(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, false, BitcoinUnits::separatorAlways)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelWatchAvailable->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelWatchPending->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelWatchImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelWatchTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); } } else { - ui->labelBalance->setText(BitcoinUnits::formatWithUnit(unit, balances.balance, false, BitcoinUnits::separatorAlways)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithUnit(unit, balances.unconfirmed_balance, false, BitcoinUnits::separatorAlways)); - ui->labelImmature->setText(BitcoinUnits::formatWithUnit(unit, balances.immature_balance, false, BitcoinUnits::separatorAlways)); - ui->labelTotal->setText(BitcoinUnits::formatWithUnit(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, false, BitcoinUnits::separatorAlways)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); } // only show immature (newly mined) balance if it's non-zero, so as not to complicate things // for the non-mining users diff --git a/src/qt/overviewpage.h b/src/qt/overviewpage.h index 00ba7ad4ce..4cf673b6a6 100644 --- a/src/qt/overviewpage.h +++ b/src/qt/overviewpage.h @@ -39,6 +39,7 @@ public: public Q_SLOTS: void setBalance(const interfaces::WalletBalances& balances); + void setPrivacy(bool privacy); Q_SIGNALS: void transactionClicked(const QModelIndex &index); @@ -49,6 +50,7 @@ private: ClientModel *clientModel; WalletModel *walletModel; interfaces::WalletBalances m_balances; + bool m_privacy{false}; TxViewDelegate *txdelegate; std::unique_ptr<TransactionFilterProxy> filter; diff --git a/src/qt/receiverequestdialog.cpp b/src/qt/receiverequestdialog.cpp index 30bd5c6a5a..d385c42821 100644 --- a/src/qt/receiverequestdialog.cpp +++ b/src/qt/receiverequestdialog.cpp @@ -8,10 +8,11 @@ #include <qt/bitcoinunits.h> #include <qt/guiutil.h> #include <qt/optionsmodel.h> +#include <qt/qrimagewidget.h> #include <qt/walletmodel.h> -#include <QClipboard> -#include <QPixmap> +#include <QDialog> +#include <QString> #if defined(HAVE_CONFIG_H) #include <config/bitcoin-config.h> /* for USE_QRCODE */ @@ -23,14 +24,6 @@ ReceiveRequestDialog::ReceiveRequestDialog(QWidget *parent) : model(nullptr) { ui->setupUi(this); - -#ifndef USE_QRCODE - ui->btnSaveAs->setVisible(false); - ui->lblQRCode->setVisible(false); -#endif - - connect(ui->btnSaveAs, &QPushButton::clicked, ui->lblQRCode, &QRImageWidget::saveImage); - GUIUtil::handleCloseWindowShortcut(this); } @@ -44,7 +37,7 @@ void ReceiveRequestDialog::setModel(WalletModel *_model) this->model = _model; if (_model) - connect(_model->getOptionsModel(), &OptionsModel::displayUnitChanged, this, &ReceiveRequestDialog::update); + connect(_model->getOptionsModel(), &OptionsModel::displayUnitChanged, this, &ReceiveRequestDialog::updateDisplayUnit); // update the display unit if necessary update(); @@ -53,40 +46,55 @@ void ReceiveRequestDialog::setModel(WalletModel *_model) void ReceiveRequestDialog::setInfo(const SendCoinsRecipient &_info) { this->info = _info; - update(); -} + setWindowTitle(tr("Request payment to %1").arg(info.label.isEmpty() ? info.address : info.label)); + QString uri = GUIUtil::formatBitcoinURI(info); -void ReceiveRequestDialog::update() -{ - if(!model) - return; - QString target = info.label; - if(target.isEmpty()) - target = info.address; - setWindowTitle(tr("Request payment to %1").arg(target)); +#ifdef USE_QRCODE + if (ui->qr_code->setQR(uri, info.address)) { + connect(ui->btnSaveAs, &QPushButton::clicked, ui->qr_code, &QRImageWidget::saveImage); + } else { + ui->btnSaveAs->setEnabled(false); + } +#else + ui->btnSaveAs->hide(); + ui->qr_code->hide(); +#endif - QString uri = GUIUtil::formatBitcoinURI(info); - ui->btnSaveAs->setEnabled(false); - QString html; - html += "<html><font face='verdana, arial, helvetica, sans-serif'>"; - html += "<b>"+tr("Payment information")+"</b><br>"; - html += "<b>"+tr("URI")+"</b>: "; - html += "<a href=\""+uri+"\">" + GUIUtil::HtmlEscape(uri) + "</a><br>"; - html += "<b>"+tr("Address")+"</b>: " + GUIUtil::HtmlEscape(info.address) + "<br>"; - if(info.amount) - html += "<b>"+tr("Amount")+"</b>: " + BitcoinUnits::formatHtmlWithUnit(model->getOptionsModel()->getDisplayUnit(), info.amount) + "<br>"; - if(!info.label.isEmpty()) - html += "<b>"+tr("Label")+"</b>: " + GUIUtil::HtmlEscape(info.label) + "<br>"; - if(!info.message.isEmpty()) - html += "<b>"+tr("Message")+"</b>: " + GUIUtil::HtmlEscape(info.message) + "<br>"; - if(model->isMultiwallet()) { - html += "<b>"+tr("Wallet")+"</b>: " + GUIUtil::HtmlEscape(model->getWalletName()) + "<br>"; + ui->uri_content->setText("<a href=\"" + uri + "\">" + GUIUtil::HtmlEscape(uri) + "</a>"); + ui->address_content->setText(info.address); + + if (!info.amount) { + ui->amount_tag->hide(); + ui->amount_content->hide(); + } // Amount is set in updateDisplayUnit() slot. + updateDisplayUnit(); + + if (!info.label.isEmpty()) { + ui->label_content->setText(info.label); + } else { + ui->label_tag->hide(); + ui->label_content->hide(); } - ui->outUri->setText(html); - if (ui->lblQRCode->setQR(uri, info.address)) { - ui->btnSaveAs->setEnabled(true); + if (!info.message.isEmpty()) { + ui->message_content->setText(info.message); + } else { + ui->message_tag->hide(); + ui->message_content->hide(); } + + if (!model->getWalletName().isEmpty()) { + ui->wallet_content->setText(model->getWalletName()); + } else { + ui->wallet_tag->hide(); + ui->wallet_content->hide(); + } +} + +void ReceiveRequestDialog::updateDisplayUnit() +{ + if (!model) return; + ui->amount_content->setText(BitcoinUnits::formatWithUnit(model->getOptionsModel()->getDisplayUnit(), info.amount)); } void ReceiveRequestDialog::on_btnCopyURI_clicked() diff --git a/src/qt/receiverequestdialog.h b/src/qt/receiverequestdialog.h index 40e3d5ffa8..846478643d 100644 --- a/src/qt/receiverequestdialog.h +++ b/src/qt/receiverequestdialog.h @@ -29,8 +29,7 @@ public: private Q_SLOTS: void on_btnCopyURI_clicked(); void on_btnCopyAddress_clicked(); - - void update(); + void updateDisplayUnit(); private: Ui::ReceiveRequestDialog *ui; diff --git a/src/qt/recentrequeststablemodel.h b/src/qt/recentrequeststablemodel.h index addf5ad0ae..c0bd3461bb 100644 --- a/src/qt/recentrequeststablemodel.h +++ b/src/qt/recentrequeststablemodel.h @@ -24,19 +24,11 @@ public: QDateTime date; SendCoinsRecipient recipient; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - unsigned int nDate = date.toTime_t(); - - READWRITE(this->nVersion); - READWRITE(id); - READWRITE(nDate); - READWRITE(recipient); - - if (ser_action.ForRead()) - date = QDateTime::fromTime_t(nDate); + SERIALIZE_METHODS(RecentRequestEntry, obj) { + unsigned int date_timet; + SER_WRITE(obj, date_timet = obj.date.toTime_t()); + READWRITE(obj.nVersion, obj.id, date_timet, obj.recipient); + SER_READ(obj, obj.date = QDateTime::fromTime_t(date_timet)); } }; diff --git a/src/qt/rpcconsole.cpp b/src/qt/rpcconsole.cpp index aa49b7b44a..0f89d4e6fe 100644 --- a/src/qt/rpcconsole.cpp +++ b/src/qt/rpcconsole.cpp @@ -40,9 +40,6 @@ #include <QTime> #include <QTimer> -// TODO: add a scrollback limit, as there is currently none -// TODO: make it possible to filter out categories (esp debug messages when implemented) -// TODO: receive errors and debug messages through ClientModel const int CONSOLE_HISTORY = 50; const int INITIAL_TRAFFIC_GRAPH_MINS = 30; diff --git a/src/qt/sendcoinsrecipient.h b/src/qt/sendcoinsrecipient.h index 12279fab64..6619faf417 100644 --- a/src/qt/sendcoinsrecipient.h +++ b/src/qt/sendcoinsrecipient.h @@ -44,30 +44,21 @@ public: static const int CURRENT_VERSION = 1; int nVersion; - ADD_SERIALIZE_METHODS; + SERIALIZE_METHODS(SendCoinsRecipient, obj) + { + std::string address_str, label_str, message_str, auth_merchant_str; - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - std::string sAddress = address.toStdString(); - std::string sLabel = label.toStdString(); - std::string sMessage = message.toStdString(); - std::string sAuthenticatedMerchant = authenticatedMerchant.toStdString(); + SER_WRITE(obj, address_str = obj.address.toStdString()); + SER_WRITE(obj, label_str = obj.label.toStdString()); + SER_WRITE(obj, message_str = obj.message.toStdString()); + SER_WRITE(obj, auth_merchant_str = obj.authenticatedMerchant.toStdString()); - READWRITE(this->nVersion); - READWRITE(sAddress); - READWRITE(sLabel); - READWRITE(amount); - READWRITE(sMessage); - READWRITE(sPaymentRequest); - READWRITE(sAuthenticatedMerchant); + READWRITE(obj.nVersion, address_str, label_str, obj.amount, message_str, obj.sPaymentRequest, auth_merchant_str); - if (ser_action.ForRead()) - { - address = QString::fromStdString(sAddress); - label = QString::fromStdString(sLabel); - message = QString::fromStdString(sMessage); - authenticatedMerchant = QString::fromStdString(sAuthenticatedMerchant); - } + SER_READ(obj, obj.address = QString::fromStdString(address_str)); + SER_READ(obj, obj.label = QString::fromStdString(label_str)); + SER_READ(obj, obj.message = QString::fromStdString(message_str)); + SER_READ(obj, obj.authenticatedMerchant = QString::fromStdString(auth_merchant_str)); } }; diff --git a/src/qt/test/wallettests.cpp b/src/qt/test/wallettests.cpp index 2ee9ae0d86..8da0250e57 100644 --- a/src/qt/test/wallettests.cpp +++ b/src/qt/test/wallettests.cpp @@ -201,7 +201,7 @@ void TestGUI(interfaces::Node& node) OverviewPage overviewPage(platformStyle.get()); overviewPage.setWalletModel(&walletModel); QLabel* balanceLabel = overviewPage.findChild<QLabel*>("labelBalance"); - QString balanceText = balanceLabel->text(); + QString balanceText = balanceLabel->text().trimmed(); int unit = walletModel.getOptionsModel()->getDisplayUnit(); CAmount balance = walletModel.wallet().getBalance(); QString balanceComparison = BitcoinUnits::formatWithUnit(unit, balance, false, BitcoinUnits::separatorAlways); @@ -229,15 +229,23 @@ void TestGUI(interfaces::Node& node) for (QWidget* widget : QApplication::topLevelWidgets()) { if (widget->inherits("ReceiveRequestDialog")) { ReceiveRequestDialog* receiveRequestDialog = qobject_cast<ReceiveRequestDialog*>(widget); - QTextEdit* rlist = receiveRequestDialog->QObject::findChild<QTextEdit*>("outUri"); - QString paymentText = rlist->toPlainText(); - QStringList paymentTextList = paymentText.split('\n'); - QCOMPARE(paymentTextList.at(0), QString("Payment information")); - QVERIFY(paymentTextList.at(1).indexOf(QString("URI: bitcoin:")) != -1); - QVERIFY(paymentTextList.at(2).indexOf(QString("Address:")) != -1); - QCOMPARE(paymentTextList.at(3), QString("Amount: 0.00000001 ") + QString::fromStdString(CURRENCY_UNIT)); - QCOMPARE(paymentTextList.at(4), QString("Label: TEST_LABEL_1")); - QCOMPARE(paymentTextList.at(5), QString("Message: TEST_MESSAGE_1")); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("payment_header")->text(), QString("Payment information")); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("uri_tag")->text(), QString("URI:")); + QString uri = receiveRequestDialog->QObject::findChild<QLabel*>("uri_content")->text(); + QCOMPARE(uri.count("bitcoin:"), 2); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("address_tag")->text(), QString("Address:")); + + QCOMPARE(uri.count("amount=0.00000001"), 2); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("amount_tag")->text(), QString("Amount:")); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("amount_content")->text(), QString("0.00000001 ") + QString::fromStdString(CURRENCY_UNIT)); + + QCOMPARE(uri.count("label=TEST_LABEL_1"), 2); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("label_tag")->text(), QString("Label:")); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("label_content")->text(), QString("TEST_LABEL_1")); + + QCOMPARE(uri.count("message=TEST_MESSAGE_1"), 2); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("message_tag")->text(), QString("Message:")); + QCOMPARE(receiveRequestDialog->QObject::findChild<QLabel*>("message_content")->text(), QString("TEST_MESSAGE_1")); } } diff --git a/src/qt/trafficgraphwidget.cpp b/src/qt/trafficgraphwidget.cpp index 757648f485..6428fc4daf 100644 --- a/src/qt/trafficgraphwidget.cpp +++ b/src/qt/trafficgraphwidget.cpp @@ -7,6 +7,7 @@ #include <qt/clientmodel.h> #include <QPainter> +#include <QPainterPath> #include <QColor> #include <QTimer> diff --git a/src/qt/transactiontablemodel.cpp b/src/qt/transactiontablemodel.cpp index 18554aef1f..7a15503228 100644 --- a/src/qt/transactiontablemodel.cpp +++ b/src/qt/transactiontablemodel.cpp @@ -664,7 +664,7 @@ QVariant TransactionTableModel::headerData(int section, Qt::Orientation orientat QModelIndex TransactionTableModel::index(int row, int column, const QModelIndex &parent) const { Q_UNUSED(parent); - TransactionRecord *data = priv->index(walletModel->wallet(), walletModel->clientModel().getNumBlocks(), row); + TransactionRecord *data = priv->index(walletModel->wallet(), walletModel->getNumBlocks(), row); if(data) { return createIndex(row, column, data); diff --git a/src/qt/walletcontroller.cpp b/src/qt/walletcontroller.cpp index 7cde3ca30b..20f2ef5b5f 100644 --- a/src/qt/walletcontroller.cpp +++ b/src/qt/walletcontroller.cpp @@ -248,7 +248,7 @@ void CreateWalletActivity::finish() if (!m_error_message.original.empty()) { QMessageBox::critical(m_parent_widget, tr("Create wallet failed"), QString::fromStdString(m_error_message.translated)); } else if (!m_warning_message.empty()) { - QMessageBox::warning(m_parent_widget, tr("Create wallet warning"), QString::fromStdString(Join(m_warning_message, "\n", OpTranslated))); + QMessageBox::warning(m_parent_widget, tr("Create wallet warning"), QString::fromStdString(Join(m_warning_message, Untranslated("\n")).translated)); } if (m_wallet_model) Q_EMIT created(m_wallet_model); @@ -289,7 +289,7 @@ void OpenWalletActivity::finish() if (!m_error_message.original.empty()) { QMessageBox::critical(m_parent_widget, tr("Open wallet failed"), QString::fromStdString(m_error_message.translated)); } else if (!m_warning_message.empty()) { - QMessageBox::warning(m_parent_widget, tr("Open wallet warning"), QString::fromStdString(Join(m_warning_message, "\n", OpTranslated))); + QMessageBox::warning(m_parent_widget, tr("Open wallet warning"), QString::fromStdString(Join(m_warning_message, Untranslated("\n")).translated)); } if (m_wallet_model) Q_EMIT opened(m_wallet_model); diff --git a/src/qt/walletframe.cpp b/src/qt/walletframe.cpp index 02a9583ae9..5e68ee4f93 100644 --- a/src/qt/walletframe.cpp +++ b/src/qt/walletframe.cpp @@ -53,6 +53,7 @@ bool WalletFrame::addWallet(WalletModel *walletModel) walletView->setClientModel(clientModel); walletView->setWalletModel(walletModel); walletView->showOutOfSyncWarning(bOutOfSync); + walletView->setPrivacy(gui->isPrivacyModeActivated()); WalletView* current_wallet_view = currentWalletView(); if (current_wallet_view) { @@ -73,6 +74,7 @@ bool WalletFrame::addWallet(WalletModel *walletModel) connect(walletView, &WalletView::encryptionStatusChanged, gui, &BitcoinGUI::updateWalletStatus); connect(walletView, &WalletView::incomingTransaction, gui, &BitcoinGUI::incomingTransaction); connect(walletView, &WalletView::hdEnabledStatusChanged, gui, &BitcoinGUI::updateWalletStatus); + connect(gui, &BitcoinGUI::setPrivacy, walletView, &WalletView::setPrivacy); return true; } diff --git a/src/qt/walletmodel.cpp b/src/qt/walletmodel.cpp index 70ee7f4917..b1e61e03b3 100644 --- a/src/qt/walletmodel.cpp +++ b/src/qt/walletmodel.cpp @@ -39,14 +39,15 @@ WalletModel::WalletModel(std::unique_ptr<interfaces::Wallet> wallet, ClientModel& client_model, const PlatformStyle *platformStyle, QObject *parent) : QObject(parent), m_wallet(std::move(wallet)), - m_client_model(client_model), + m_client_model(&client_model), m_node(client_model.node()), optionsModel(client_model.getOptionsModel()), addressTableModel(nullptr), transactionTableModel(nullptr), recentRequestsTableModel(nullptr), cachedEncryptionStatus(Unencrypted), - cachedNumBlocks(0) + cachedNumBlocks(0), + timer(new QTimer(this)) { fHaveWatchOnly = m_wallet->haveWatchOnly(); addressTableModel = new AddressTableModel(this); @@ -64,11 +65,16 @@ WalletModel::~WalletModel() void WalletModel::startPollBalance() { // This timer will be fired repeatedly to update the balance - QTimer* timer = new QTimer(this); connect(timer, &QTimer::timeout, this, &WalletModel::pollBalanceChanged); timer->start(MODEL_UPDATE_DELAY); } +void WalletModel::setClientModel(ClientModel* client_model) +{ + m_client_model = client_model; + if (!m_client_model) timer->stop(); +} + void WalletModel::updateStatus() { EncryptionStatus newEncryptionStatus = getEncryptionStatus(); @@ -80,24 +86,31 @@ void WalletModel::updateStatus() void WalletModel::pollBalanceChanged() { + // Avoid recomputing wallet balances unless a TransactionChanged or + // BlockTip notification was received. + if (!fForceCheckBalanceChanged && cachedNumBlocks == m_client_model->getNumBlocks()) return; + // Try to get balances and return early if locks can't be acquired. This // avoids the GUI from getting stuck on periodical polls if the core is // holding the locks for a longer time - for example, during a wallet // rescan. interfaces::WalletBalances new_balances; int numBlocks = -1; - if (!m_wallet->tryGetBalances(new_balances, numBlocks, fForceCheckBalanceChanged, cachedNumBlocks)) { + if (!m_wallet->tryGetBalances(new_balances, numBlocks)) { return; } - fForceCheckBalanceChanged = false; + if(fForceCheckBalanceChanged || numBlocks != cachedNumBlocks) + { + fForceCheckBalanceChanged = false; - // Balance and number of transactions might have changed - cachedNumBlocks = numBlocks; + // Balance and number of transactions might have changed + cachedNumBlocks = numBlocks; - checkBalanceChanged(new_balances); - if(transactionTableModel) - transactionTableModel->updateConfirmations(); + checkBalanceChanged(new_balances); + if(transactionTableModel) + transactionTableModel->updateConfirmations(); + } } void WalletModel::checkBalanceChanged(const interfaces::WalletBalances& new_balances) @@ -304,16 +317,10 @@ WalletModel::EncryptionStatus WalletModel::getEncryptionStatus() const bool WalletModel::setWalletEncrypted(bool encrypted, const SecureString &passphrase) { - if(encrypted) - { - // Encrypt + if (encrypted) { return m_wallet->encryptWallet(passphrase); } - else - { - // Decrypt -- TODO; not supported yet - return false; - } + return false; } bool WalletModel::setWalletLocked(bool locked, const SecureString &passPhrase) diff --git a/src/qt/walletmodel.h b/src/qt/walletmodel.h index 07004b7c6b..23232ec66b 100644 --- a/src/qt/walletmodel.h +++ b/src/qt/walletmodel.h @@ -144,7 +144,8 @@ public: interfaces::Node& node() const { return m_node; } interfaces::Wallet& wallet() const { return *m_wallet; } - ClientModel& clientModel() const { return m_client_model; } + void setClientModel(ClientModel* client_model); + int getNumBlocks() const { return cachedNumBlocks; } QString getWalletName() const; QString getDisplayName() const; @@ -161,7 +162,7 @@ private: std::unique_ptr<interfaces::Handler> m_handler_show_progress; std::unique_ptr<interfaces::Handler> m_handler_watch_only_changed; std::unique_ptr<interfaces::Handler> m_handler_can_get_addrs_changed; - ClientModel& m_client_model; + ClientModel* m_client_model; interfaces::Node& m_node; bool fHaveWatchOnly; @@ -179,6 +180,7 @@ private: interfaces::WalletBalances m_cached_balances; EncryptionStatus cachedEncryptionStatus; int cachedNumBlocks; + QTimer* timer; void subscribeToCoreSignals(); void unsubscribeFromCoreSignals(); diff --git a/src/qt/walletview.cpp b/src/qt/walletview.cpp index 5d9b420df7..66fbf978be 100644 --- a/src/qt/walletview.cpp +++ b/src/qt/walletview.cpp @@ -85,6 +85,8 @@ WalletView::WalletView(const PlatformStyle *_platformStyle, QWidget *parent): connect(sendCoinsPage, &SendCoinsDialog::message, this, &WalletView::message); // Pass through messages from transactionView connect(transactionView, &TransactionView::message, this, &WalletView::message); + + connect(this, &WalletView::setPrivacy, overviewPage, &OverviewPage::setPrivacy); } WalletView::~WalletView() @@ -97,6 +99,7 @@ void WalletView::setClientModel(ClientModel *_clientModel) overviewPage->setClientModel(_clientModel); sendCoinsPage->setClientModel(_clientModel); + if (walletModel) walletModel->setClientModel(_clientModel); } void WalletView::setWalletModel(WalletModel *_walletModel) diff --git a/src/qt/walletview.h b/src/qt/walletview.h index 11f894e7f8..fd09456baa 100644 --- a/src/qt/walletview.h +++ b/src/qt/walletview.h @@ -115,6 +115,7 @@ public Q_SLOTS: void requestedSyncWarningInfo(); Q_SIGNALS: + void setPrivacy(bool privacy); void transactionClicked(); void coinsSent(); /** Fired when a message should be reported to the user */ diff --git a/src/random.cpp b/src/random.cpp index 5b8782d1ce..9c9a35709a 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -14,16 +14,14 @@ #include <wincrypt.h> #endif #include <logging.h> // for LogPrintf() +#include <randomenv.h> +#include <support/allocators/secure.h> #include <sync.h> // for Mutex #include <util/time.h> // for GetTimeMicros() #include <stdlib.h> #include <thread> -#include <randomenv.h> - -#include <support/allocators/secure.h> - #ifndef WIN32 #include <fcntl.h> #include <sys/time.h> @@ -590,16 +588,6 @@ uint64_t GetRand(uint64_t nMax) noexcept return FastRandomContext(g_mock_deterministic_tests).randrange(nMax); } -std::chrono::microseconds GetRandMicros(std::chrono::microseconds duration_max) noexcept -{ - return std::chrono::microseconds{GetRand(duration_max.count())}; -} - -std::chrono::milliseconds GetRandMillis(std::chrono::milliseconds duration_max) noexcept -{ - return std::chrono::milliseconds{GetRand(duration_max.count())}; -} - int GetRandInt(int nMax) noexcept { return GetRand(nMax); diff --git a/src/random.h b/src/random.h index 690125079b..0c6dc24983 100644 --- a/src/random.h +++ b/src/random.h @@ -67,9 +67,21 @@ * Thread-safe. */ void GetRandBytes(unsigned char* buf, int num) noexcept; +/** Generate a uniform random integer in the range [0..range). Precondition: range > 0 */ uint64_t GetRand(uint64_t nMax) noexcept; -std::chrono::microseconds GetRandMicros(std::chrono::microseconds duration_max) noexcept; -std::chrono::milliseconds GetRandMillis(std::chrono::milliseconds duration_max) noexcept; +/** Generate a uniform random duration in the range [0..max). Precondition: max.count() > 0 */ +template <typename D> +D GetRandomDuration(typename std::common_type<D>::type max) noexcept +// Having the compiler infer the template argument from the function argument +// is dangerous, because the desired return value generally has a different +// type than the function argument. So std::common_type is used to force the +// call site to specify the type of the return value. +{ + assert(max.count() > 0); + return D{GetRand(max.count())}; +}; +constexpr auto GetRandMicros = GetRandomDuration<std::chrono::microseconds>; +constexpr auto GetRandMillis = GetRandomDuration<std::chrono::milliseconds>; int GetRandInt(int nMax) noexcept; uint256 GetRandHash() noexcept; diff --git a/src/rest.cpp b/src/rest.cpp index 5f99e26bad..cde8b472d3 100644 --- a/src/rest.cpp +++ b/src/rest.cpp @@ -18,6 +18,7 @@ #include <sync.h> #include <txmempool.h> #include <util/check.h> +#include <util/ref.h> #include <util/strencodings.h> #include <validation.h> #include <version.h> @@ -49,18 +50,13 @@ struct CCoin { uint32_t nHeight; CTxOut out; - ADD_SERIALIZE_METHODS; - CCoin() : nHeight(0) {} explicit CCoin(Coin&& in) : nHeight(in.nHeight), out(std::move(in.out)) {} - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) + SERIALIZE_METHODS(CCoin, obj) { uint32_t nTxVerDummy = 0; - READWRITE(nTxVerDummy); - READWRITE(nHeight); - READWRITE(out); + READWRITE(nTxVerDummy, obj.nHeight, obj.out); } }; @@ -80,13 +76,14 @@ static bool RESTERR(HTTPRequest* req, enum HTTPStatusCode status, std::string me * @param[in] req the HTTP request * return pointer to the mempool or nullptr if no mempool found */ -static CTxMemPool* GetMemPool(HTTPRequest* req) +static CTxMemPool* GetMemPool(const util::Ref& context, HTTPRequest* req) { - if (!g_rpc_node || !g_rpc_node->mempool) { + NodeContext* node = context.Has<NodeContext>() ? &context.Get<NodeContext>() : nullptr; + if (!node || !node->mempool) { RESTERR(req, HTTP_NOT_FOUND, "Mempool disabled or instance not found"); return nullptr; } - return g_rpc_node->mempool; + return node->mempool; } static RetFormat ParseDataFormat(std::string& param, const std::string& strReq) @@ -134,7 +131,8 @@ static bool CheckWarmup(HTTPRequest* req) return true; } -static bool rest_headers(HTTPRequest* req, +static bool rest_headers(const util::Ref& context, + HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) @@ -275,12 +273,12 @@ static bool rest_block(HTTPRequest* req, } } -static bool rest_block_extended(HTTPRequest* req, const std::string& strURIPart) +static bool rest_block_extended(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { return rest_block(req, strURIPart, true); } -static bool rest_block_notxdetails(HTTPRequest* req, const std::string& strURIPart) +static bool rest_block_notxdetails(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { return rest_block(req, strURIPart, false); } @@ -288,7 +286,7 @@ static bool rest_block_notxdetails(HTTPRequest* req, const std::string& strURIPa // A bit of a hack - dependency on a function defined in rpc/blockchain.cpp UniValue getblockchaininfo(const JSONRPCRequest& request); -static bool rest_chaininfo(HTTPRequest* req, const std::string& strURIPart) +static bool rest_chaininfo(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) return false; @@ -297,7 +295,7 @@ static bool rest_chaininfo(HTTPRequest* req, const std::string& strURIPart) switch (rf) { case RetFormat::JSON: { - JSONRPCRequest jsonRequest; + JSONRPCRequest jsonRequest(context); jsonRequest.params = UniValue(UniValue::VARR); UniValue chainInfoObject = getblockchaininfo(jsonRequest); std::string strJSON = chainInfoObject.write() + "\n"; @@ -311,11 +309,11 @@ static bool rest_chaininfo(HTTPRequest* req, const std::string& strURIPart) } } -static bool rest_mempool_info(HTTPRequest* req, const std::string& strURIPart) +static bool rest_mempool_info(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) return false; - const CTxMemPool* mempool = GetMemPool(req); + const CTxMemPool* mempool = GetMemPool(context, req); if (!mempool) return false; std::string param; const RetFormat rf = ParseDataFormat(param, strURIPart); @@ -335,10 +333,10 @@ static bool rest_mempool_info(HTTPRequest* req, const std::string& strURIPart) } } -static bool rest_mempool_contents(HTTPRequest* req, const std::string& strURIPart) +static bool rest_mempool_contents(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) return false; - const CTxMemPool* mempool = GetMemPool(req); + const CTxMemPool* mempool = GetMemPool(context, req); if (!mempool) return false; std::string param; const RetFormat rf = ParseDataFormat(param, strURIPart); @@ -358,7 +356,7 @@ static bool rest_mempool_contents(HTTPRequest* req, const std::string& strURIPar } } -static bool rest_tx(HTTPRequest* req, const std::string& strURIPart) +static bool rest_tx(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) return false; @@ -414,7 +412,7 @@ static bool rest_tx(HTTPRequest* req, const std::string& strURIPart) } } -static bool rest_getutxos(HTTPRequest* req, const std::string& strURIPart) +static bool rest_getutxos(const util::Ref& context, HTTPRequest* req, const std::string& strURIPart) { if (!CheckWarmup(req)) return false; @@ -523,7 +521,7 @@ static bool rest_getutxos(HTTPRequest* req, const std::string& strURIPart) }; if (fCheckMemPool) { - const CTxMemPool* mempool = GetMemPool(req); + const CTxMemPool* mempool = GetMemPool(context, req); if (!mempool) return false; // use db+mempool as cache backend in case user likes to query mempool LOCK2(cs_main, mempool->cs); @@ -600,7 +598,7 @@ static bool rest_getutxos(HTTPRequest* req, const std::string& strURIPart) } } -static bool rest_blockhash_by_height(HTTPRequest* req, +static bool rest_blockhash_by_height(const util::Ref& context, HTTPRequest* req, const std::string& str_uri_part) { if (!CheckWarmup(req)) return false; @@ -648,7 +646,7 @@ static bool rest_blockhash_by_height(HTTPRequest* req, static const struct { const char* prefix; - bool (*handler)(HTTPRequest* req, const std::string& strReq); + bool (*handler)(const util::Ref& context, HTTPRequest* req, const std::string& strReq); } uri_prefixes[] = { {"/rest/tx/", rest_tx}, {"/rest/block/notxdetails/", rest_block_notxdetails}, @@ -661,10 +659,12 @@ static const struct { {"/rest/blockhashbyheight/", rest_blockhash_by_height}, }; -void StartREST() +void StartREST(const util::Ref& context) { - for (unsigned int i = 0; i < ARRAYLEN(uri_prefixes); i++) - RegisterHTTPHandler(uri_prefixes[i].prefix, false, uri_prefixes[i].handler); + for (const auto& up : uri_prefixes) { + auto handler = [&context, up](HTTPRequest* req, const std::string& prefix) { return up.handler(context, req, prefix); }; + RegisterHTTPHandler(up.prefix, false, handler); + } } void InterruptREST() diff --git a/src/rpc/blockchain.cpp b/src/rpc/blockchain.cpp index f7ccbae706..4eb47d7b15 100644 --- a/src/rpc/blockchain.cpp +++ b/src/rpc/blockchain.cpp @@ -29,6 +29,7 @@ #include <txdb.h> #include <txmempool.h> #include <undo.h> +#include <util/ref.h> #include <util/strencodings.h> #include <util/system.h> #include <validation.h> @@ -51,15 +52,29 @@ struct CUpdatedBlock static Mutex cs_blockchange; static std::condition_variable cond_blockchange; -static CUpdatedBlock latestblock; +static CUpdatedBlock latestblock GUARDED_BY(cs_blockchange); -CTxMemPool& EnsureMemPool() +NodeContext& EnsureNodeContext(const util::Ref& context) { - CHECK_NONFATAL(g_rpc_node); - if (!g_rpc_node->mempool) { + if (!context.Has<NodeContext>()) { + throw JSONRPCError(RPC_INTERNAL_ERROR, "Node context not found"); + } + return context.Get<NodeContext>(); +} + +CTxMemPool& EnsureMemPool(const util::Ref& context) +{ + NodeContext& node = EnsureNodeContext(context); + if (!node.mempool) { throw JSONRPCError(RPC_CLIENT_MEMPOOL_DISABLED, "Mempool disabled or instance not found"); } - return *g_rpc_node->mempool; + return *node.mempool; +} + +ChainstateManager& EnsureChainman(const util::Ref& context) +{ + NodeContext& node = EnsureNodeContext(context); + return EnsureChainman(node); } /* Calculate the difficulty for a given block index. @@ -205,10 +220,10 @@ static UniValue getbestblockhash(const JSONRPCRequest& request) return ::ChainActive().Tip()->GetBlockHash().GetHex(); } -void RPCNotifyBlockChange(bool ibd, const CBlockIndex * pindex) +void RPCNotifyBlockChange(const CBlockIndex* pindex) { if(pindex) { - std::lock_guard<std::mutex> lock(cs_blockchange); + LOCK(cs_blockchange); latestblock.hash = pindex->GetBlockHash(); latestblock.height = pindex->nHeight; } @@ -243,9 +258,9 @@ static UniValue waitfornewblock(const JSONRPCRequest& request) WAIT_LOCK(cs_blockchange, lock); block = latestblock; if(timeout) - cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&block]{return latestblock.height != block.height || latestblock.hash != block.hash || !IsRPCRunning(); }); + cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&block]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.height != block.height || latestblock.hash != block.hash || !IsRPCRunning(); }); else - cond_blockchange.wait(lock, [&block]{return latestblock.height != block.height || latestblock.hash != block.hash || !IsRPCRunning(); }); + cond_blockchange.wait(lock, [&block]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.height != block.height || latestblock.hash != block.hash || !IsRPCRunning(); }); block = latestblock; } UniValue ret(UniValue::VOBJ); @@ -285,9 +300,9 @@ static UniValue waitforblock(const JSONRPCRequest& request) { WAIT_LOCK(cs_blockchange, lock); if(timeout) - cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&hash]{return latestblock.hash == hash || !IsRPCRunning();}); + cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&hash]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.hash == hash || !IsRPCRunning();}); else - cond_blockchange.wait(lock, [&hash]{return latestblock.hash == hash || !IsRPCRunning(); }); + cond_blockchange.wait(lock, [&hash]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.hash == hash || !IsRPCRunning(); }); block = latestblock; } @@ -329,9 +344,9 @@ static UniValue waitforblockheight(const JSONRPCRequest& request) { WAIT_LOCK(cs_blockchange, lock); if(timeout) - cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&height]{return latestblock.height >= height || !IsRPCRunning();}); + cond_blockchange.wait_for(lock, std::chrono::milliseconds(timeout), [&height]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.height >= height || !IsRPCRunning();}); else - cond_blockchange.wait(lock, [&height]{return latestblock.height >= height || !IsRPCRunning(); }); + cond_blockchange.wait(lock, [&height]() EXCLUSIVE_LOCKS_REQUIRED(cs_blockchange) {return latestblock.height >= height || !IsRPCRunning(); }); block = latestblock; } UniValue ret(UniValue::VOBJ); @@ -399,6 +414,7 @@ static std::vector<RPCResult> MempoolEntryDescription() { return { RPCResult{RPCResult::Type::ARR, "spentby", "unconfirmed transactions spending outputs from this transaction", {RPCResult{RPCResult::Type::STR_HEX, "transactionid", "child transaction id"}}}, RPCResult{RPCResult::Type::BOOL, "bip125-replaceable", "Whether this transaction could be replaced due to BIP125 (replace-by-fee)"}, + RPCResult{RPCResult::Type::BOOL, "unbroadcast", "Whether this transaction is currently unbroadcast (initial broadcast not yet confirmed)"}, };} static void entryToJSON(const CTxMemPool& pool, UniValue& info, const CTxMemPoolEntry& e) EXCLUSIVE_LOCKS_REQUIRED(pool.cs) @@ -460,6 +476,7 @@ static void entryToJSON(const CTxMemPool& pool, UniValue& info, const CTxMemPool } info.pushKV("bip125-replaceable", rbfStatus); + info.pushKV("unbroadcast", pool.IsUnbroadcastTx(tx.GetHash())); } UniValue MempoolToJSON(const CTxMemPool& pool, bool verbose) @@ -519,7 +536,7 @@ static UniValue getrawmempool(const JSONRPCRequest& request) if (!request.params[0].isNull()) fVerbose = request.params[0].get_bool(); - return MempoolToJSON(EnsureMemPool(), fVerbose); + return MempoolToJSON(EnsureMemPool(request.context), fVerbose); } static UniValue getmempoolancestors(const JSONRPCRequest& request) @@ -549,7 +566,7 @@ static UniValue getmempoolancestors(const JSONRPCRequest& request) uint256 hash = ParseHashV(request.params[0], "parameter 1"); - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK(mempool.cs); CTxMemPool::txiter it = mempool.mapTx.find(hash); @@ -612,7 +629,7 @@ static UniValue getmempooldescendants(const JSONRPCRequest& request) uint256 hash = ParseHashV(request.params[0], "parameter 1"); - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK(mempool.cs); CTxMemPool::txiter it = mempool.mapTx.find(hash); @@ -662,7 +679,7 @@ static UniValue getmempoolentry(const JSONRPCRequest& request) uint256 hash = ParseHashV(request.params[0], "parameter 1"); - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK(mempool.cs); CTxMemPool::txiter it = mempool.mapTx.find(hash); @@ -979,7 +996,7 @@ static UniValue gettxoutsetinfo(const JSONRPCRequest& request) ::ChainstateActive().ForceFlushStateToDisk(); CCoinsView* coins_view = WITH_LOCK(cs_main, return &ChainstateActive().CoinsDB()); - if (GetUTXOStats(coins_view, stats)) { + if (GetUTXOStats(coins_view, stats, RpcInterruptionPoint)) { ret.pushKV("height", (int64_t)stats.nHeight); ret.pushKV("bestblock", stats.hashBlock.GetHex()); ret.pushKV("transactions", (int64_t)stats.nTransactions); @@ -1045,7 +1062,7 @@ UniValue gettxout(const JSONRPCRequest& request) CCoinsViewCache* coins_view = &::ChainstateActive().CoinsTip(); if (fMempool) { - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK(mempool.cs); CCoinsViewMemPool view(coins_view, mempool); if (!view.GetCoin(out, coin) || mempool.isSpent(out)) { @@ -1389,7 +1406,7 @@ UniValue MempoolInfoToJSON(const CTxMemPool& pool) ret.pushKV("maxmempool", (int64_t) maxmempool); ret.pushKV("mempoolminfee", ValueFromAmount(std::max(pool.GetMinFee(maxmempool), ::minRelayTxFee).GetFeePerK())); ret.pushKV("minrelaytxfee", ValueFromAmount(::minRelayTxFee.GetFeePerK())); - + ret.pushKV("unbroadcastcount", uint64_t{pool.GetUnbroadcastTxs().size()}); return ret; } @@ -1408,6 +1425,7 @@ static UniValue getmempoolinfo(const JSONRPCRequest& request) {RPCResult::Type::NUM, "maxmempool", "Maximum memory usage for the mempool"}, {RPCResult::Type::STR_AMOUNT, "mempoolminfee", "Minimum fee rate in " + CURRENCY_UNIT + "/kB for tx to be accepted. Is the maximum of minrelaytxfee and minimum mempool fee"}, {RPCResult::Type::STR_AMOUNT, "minrelaytxfee", "Current minimum relay fee for transactions"}, + {RPCResult::Type::NUM, "unbroadcastcount", "Current number of transactions that haven't passed initial broadcast yet"} }}, RPCExamples{ HelpExampleCli("getmempoolinfo", "") @@ -1415,7 +1433,7 @@ static UniValue getmempoolinfo(const JSONRPCRequest& request) }, }.Check(request); - return MempoolInfoToJSON(EnsureMemPool()); + return MempoolInfoToJSON(EnsureMemPool(request.context)); } static UniValue preciousblock(const JSONRPCRequest& request) @@ -1934,7 +1952,7 @@ static UniValue savemempool(const JSONRPCRequest& request) }, }.Check(request); - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); if (!mempool.IsLoaded()) { throw JSONRPCError(RPC_MISC_ERROR, "The mempool was not loaded yet"); @@ -1956,6 +1974,7 @@ bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& Coin coin; if (!cursor->GetKey(key) || !cursor->GetValue(coin)) return false; if (++count % 8192 == 0) { + RpcInterruptionPoint(); if (should_abort) { // allow to abort the scan via the abort reference return false; @@ -1976,7 +1995,6 @@ bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& } /** RAII object to prevent concurrency issue when scanning the txout set */ -static std::mutex g_utxosetscan; static std::atomic<int> g_scan_progress; static std::atomic<bool> g_scan_in_progress; static std::atomic<bool> g_should_abort_scan; @@ -1989,18 +2007,15 @@ public: bool reserve() { CHECK_NONFATAL(!m_could_reserve); - std::lock_guard<std::mutex> lock(g_utxosetscan); - if (g_scan_in_progress) { + if (g_scan_in_progress.exchange(true)) { return false; } - g_scan_in_progress = true; m_could_reserve = true; return true; } ~CoinsViewScanReserver() { if (m_could_reserve) { - std::lock_guard<std::mutex> lock(g_utxosetscan); g_scan_in_progress = false; } } @@ -2299,7 +2314,7 @@ UniValue dumptxoutset(const JSONRPCRequest& request) ::ChainstateActive().ForceFlushStateToDisk(); - if (!GetUTXOStats(&::ChainstateActive().CoinsDB(), stats)) { + if (!GetUTXOStats(&::ChainstateActive().CoinsDB(), stats, RpcInterruptionPoint)) { throw JSONRPCError(RPC_INTERNAL_ERROR, "Unable to read UTXO set"); } @@ -2317,9 +2332,7 @@ UniValue dumptxoutset(const JSONRPCRequest& request) unsigned int iter{0}; while (pcursor->Valid()) { - if (iter % 5000 == 0 && !IsRPCRunning()) { - throw JSONRPCError(RPC_CLIENT_NOT_CONNECTED, "Shutting down"); - } + if (iter % 5000 == 0) RpcInterruptionPoint(); ++iter; if (pcursor->GetKey(key) && pcursor->GetValue(coin)) { afile << key; @@ -2385,5 +2398,3 @@ static const CRPCCommand commands[] = for (unsigned int vcidx = 0; vcidx < ARRAYLEN(commands); vcidx++) t.appendCommand(commands[vcidx].name, &commands[vcidx]); } - -NodeContext* g_rpc_node = nullptr; diff --git a/src/rpc/blockchain.h b/src/rpc/blockchain.h index a02e5fae0e..5c9a43b13e 100644 --- a/src/rpc/blockchain.h +++ b/src/rpc/blockchain.h @@ -16,8 +16,12 @@ extern RecursiveMutex cs_main; class CBlock; class CBlockIndex; class CTxMemPool; +class ChainstateManager; class UniValue; struct NodeContext; +namespace util { +class Ref; +} // namespace util static constexpr int NUM_GETBLOCKSTATS_PERCENTILES = 5; @@ -30,7 +34,7 @@ static constexpr int NUM_GETBLOCKSTATS_PERCENTILES = 5; double GetDifficulty(const CBlockIndex* blockindex); /** Callback for when block tip changed. */ -void RPCNotifyBlockChange(bool ibd, const CBlockIndex *); +void RPCNotifyBlockChange(const CBlockIndex*); /** Block description to JSON */ UniValue blockToJSON(const CBlock& block, const CBlockIndex* tip, const CBlockIndex* blockindex, bool txDetails = false) LOCKS_EXCLUDED(cs_main); @@ -47,11 +51,8 @@ UniValue blockheaderToJSON(const CBlockIndex* tip, const CBlockIndex* blockindex /** Used by getblockstats to get feerates at different percentiles by weight */ void CalculatePercentilesByWeight(CAmount result[NUM_GETBLOCKSTATS_PERCENTILES], std::vector<std::pair<CAmount, int64_t>>& scores, int64_t total_weight); -//! Pointer to node state that needs to be declared as a global to be accessible -//! RPC methods. Due to limitations of the RPC framework, there's currently no -//! direct way to pass in state to RPC methods without globals. -extern NodeContext* g_rpc_node; - -CTxMemPool& EnsureMemPool(); +NodeContext& EnsureNodeContext(const util::Ref& context); +CTxMemPool& EnsureMemPool(const util::Ref& context); +ChainstateManager& EnsureChainman(const util::Ref& context); #endif diff --git a/src/rpc/mining.cpp b/src/rpc/mining.cpp index 05d3fd6afb..3612f14bbf 100644 --- a/src/rpc/mining.cpp +++ b/src/rpc/mining.cpp @@ -101,7 +101,7 @@ static UniValue getnetworkhashps(const JSONRPCRequest& request) return GetNetworkHashPS(!request.params[0].isNull() ? request.params[0].get_int() : 120, !request.params[1].isNull() ? request.params[1].get_int() : -1); } -static bool GenerateBlock(CBlock& block, uint64_t& max_tries, unsigned int& extra_nonce, uint256& block_hash) +static bool GenerateBlock(ChainstateManager& chainman, CBlock& block, uint64_t& max_tries, unsigned int& extra_nonce, uint256& block_hash) { block_hash.SetNull(); @@ -124,14 +124,15 @@ static bool GenerateBlock(CBlock& block, uint64_t& max_tries, unsigned int& extr } std::shared_ptr<const CBlock> shared_pblock = std::make_shared<const CBlock>(block); - if (!ProcessNewBlock(chainparams, shared_pblock, true, nullptr)) + if (!chainman.ProcessNewBlock(chainparams, shared_pblock, true, nullptr)) { throw JSONRPCError(RPC_INTERNAL_ERROR, "ProcessNewBlock, block not accepted"); + } block_hash = block.GetHash(); return true; } -static UniValue generateBlocks(const CTxMemPool& mempool, const CScript& coinbase_script, int nGenerate, uint64_t nMaxTries) +static UniValue generateBlocks(ChainstateManager& chainman, const CTxMemPool& mempool, const CScript& coinbase_script, int nGenerate, uint64_t nMaxTries) { int nHeightEnd = 0; int nHeight = 0; @@ -151,7 +152,7 @@ static UniValue generateBlocks(const CTxMemPool& mempool, const CScript& coinbas CBlock *pblock = &pblocktemplate->block; uint256 block_hash; - if (!GenerateBlock(*pblock, nMaxTries, nExtraNonce, block_hash)) { + if (!GenerateBlock(chainman, *pblock, nMaxTries, nExtraNonce, block_hash)) { break; } @@ -227,9 +228,10 @@ static UniValue generatetodescriptor(const JSONRPCRequest& request) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, error); } - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); + ChainstateManager& chainman = EnsureChainman(request.context); - return generateBlocks(mempool, coinbase_script, num_blocks, max_tries); + return generateBlocks(chainman, mempool, coinbase_script, num_blocks, max_tries); } static UniValue generatetoaddress(const JSONRPCRequest& request) @@ -265,11 +267,12 @@ static UniValue generatetoaddress(const JSONRPCRequest& request) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Error: Invalid address"); } - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); + ChainstateManager& chainman = EnsureChainman(request.context); CScript coinbase_script = GetScriptForDestination(destination); - return generateBlocks(mempool, coinbase_script, nGenerate, nMaxTries); + return generateBlocks(chainman, mempool, coinbase_script, nGenerate, nMaxTries); } static UniValue generateblock(const JSONRPCRequest& request) @@ -311,7 +314,7 @@ static UniValue generateblock(const JSONRPCRequest& request) coinbase_script = GetScriptForDestination(destination); } - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); std::vector<CTransactionRef> txs; const auto raw_txs_or_txids = request.params[1].get_array(); @@ -370,7 +373,7 @@ static UniValue generateblock(const JSONRPCRequest& request) uint64_t max_tries{1000000}; unsigned int extra_nonce{0}; - if (!GenerateBlock(block, max_tries, extra_nonce, block_hash) || block_hash.IsNull()) { + if (!GenerateBlock(EnsureChainman(request.context), block, max_tries, extra_nonce, block_hash) || block_hash.IsNull()) { throw JSONRPCError(RPC_MISC_ERROR, "Failed to make block."); } @@ -403,7 +406,7 @@ static UniValue getmininginfo(const JSONRPCRequest& request) }.Check(request); LOCK(cs_main); - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); UniValue obj(UniValue::VOBJ); obj.pushKV("blocks", (int)::ChainActive().Height()); @@ -449,7 +452,7 @@ static UniValue prioritisetransaction(const JSONRPCRequest& request) throw JSONRPCError(RPC_INVALID_PARAMETER, "Priority is no longer supported, dummy argument to prioritisetransaction must be 0."); } - EnsureMemPool().PrioritiseTransaction(hash, nAmount); + EnsureMemPool(request.context).PrioritiseTransaction(hash, nAmount); return true; } @@ -635,17 +638,18 @@ static UniValue getblocktemplate(const JSONRPCRequest& request) if (strMode != "template") throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid mode"); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); - if (g_rpc_node->connman->GetNodeCount(CConnman::CONNECTIONS_ALL) == 0) + if (node.connman->GetNodeCount(CConnman::CONNECTIONS_ALL) == 0) throw JSONRPCError(RPC_CLIENT_NOT_CONNECTED, PACKAGE_NAME " is not connected!"); if (::ChainstateActive().IsInitialBlockDownload()) throw JSONRPCError(RPC_CLIENT_IN_INITIAL_DOWNLOAD, PACKAGE_NAME " is in initial sync and waiting for blocks..."); static unsigned int nTransactionsUpdatedLast; - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); if (!lpval.isNull()) { @@ -787,6 +791,8 @@ static UniValue getblocktemplate(const JSONRPCRequest& request) result.pushKV("capabilities", aCaps); UniValue aRules(UniValue::VARR); + aRules.push_back("csv"); + if (!fPreSegWit) aRules.push_back("!segwit"); UniValue vbavailable(UniValue::VOBJ); for (int j = 0; j < (int)Consensus::MAX_VERSION_BITS_DEPLOYMENTS; ++j) { Consensus::DeploymentPos pos = Consensus::DeploymentPos(j); @@ -874,7 +880,7 @@ static UniValue getblocktemplate(const JSONRPCRequest& request) return result; } -class submitblock_StateCatcher : public CValidationInterface +class submitblock_StateCatcher final : public CValidationInterface { public: uint256 hash; @@ -942,17 +948,17 @@ static UniValue submitblock(const JSONRPCRequest& request) } bool new_block; - submitblock_StateCatcher sc(block.GetHash()); - RegisterValidationInterface(&sc); - bool accepted = ProcessNewBlock(Params(), blockptr, /* fForceProcessing */ true, /* fNewBlock */ &new_block); - UnregisterValidationInterface(&sc); + auto sc = std::make_shared<submitblock_StateCatcher>(block.GetHash()); + RegisterSharedValidationInterface(sc); + bool accepted = EnsureChainman(request.context).ProcessNewBlock(Params(), blockptr, /* fForceProcessing */ true, /* fNewBlock */ &new_block); + UnregisterSharedValidationInterface(sc); if (!new_block && accepted) { return "duplicate"; } - if (!sc.found) { + if (!sc->found) { return "inconclusive"; } - return BIP22ValidationResult(sc.state); + return BIP22ValidationResult(sc->state); } static UniValue submitheader(const JSONRPCRequest& request) @@ -983,7 +989,7 @@ static UniValue submitheader(const JSONRPCRequest& request) } BlockValidationState state; - ProcessNewBlockHeaders({h}, state, Params()); + EnsureChainman(request.context).ProcessNewBlockHeaders({h}, state, Params()); if (state.IsValid()) return NullUniValue; if (state.IsError()) { throw JSONRPCError(RPC_VERIFY_ERROR, state.ToString()); diff --git a/src/rpc/misc.cpp b/src/rpc/misc.cpp index f3c5fed858..ce98a7c937 100644 --- a/src/rpc/misc.cpp +++ b/src/rpc/misc.cpp @@ -15,6 +15,7 @@ #include <script/descriptor.h> #include <util/check.h> #include <util/message.h> // For MessageSign(), MessageVerify() +#include <util/ref.h> #include <util/strencodings.h> #include <util/system.h> @@ -366,8 +367,8 @@ static UniValue setmocktime(const JSONRPCRequest& request) RPCTypeCheck(request.params, {UniValue::VNUM}); int64_t time = request.params[0].get_int64(); SetMockTime(time); - if (g_rpc_node) { - for (const auto& chain_client : g_rpc_node->chain_clients) { + if (request.context.Has<NodeContext>()) { + for (const auto& chain_client : request.context.Get<NodeContext>().chain_clients) { chain_client->setMockTime(time); } } @@ -398,9 +399,10 @@ static UniValue mockscheduler(const JSONRPCRequest& request) } // protect against null pointer dereference - CHECK_NONFATAL(g_rpc_node); - CHECK_NONFATAL(g_rpc_node->scheduler); - g_rpc_node->scheduler->MockForward(std::chrono::seconds(delta_seconds)); + CHECK_NONFATAL(request.context.Has<NodeContext>()); + NodeContext& node = request.context.Get<NodeContext>(); + CHECK_NONFATAL(node.scheduler); + node.scheduler->MockForward(std::chrono::seconds(delta_seconds)); return NullUniValue; } diff --git a/src/rpc/net.cpp b/src/rpc/net.cpp index d6d15f8b56..e29aa03695 100644 --- a/src/rpc/net.cpp +++ b/src/rpc/net.cpp @@ -42,10 +42,11 @@ static UniValue getconnectioncount(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); - return (int)g_rpc_node->connman->GetNodeCount(CConnman::CONNECTIONS_ALL); + return (int)node.connman->GetNodeCount(CConnman::CONNECTIONS_ALL); } static UniValue ping(const JSONRPCRequest& request) @@ -62,11 +63,12 @@ static UniValue ping(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); // Request that each node send a ping during next message processing pass - g_rpc_node->connman->ForEachNode([](CNode* pnode) { + node.connman->ForEachNode([](CNode* pnode) { pnode->fPingQueued = true; }); return NullUniValue; @@ -139,11 +141,12 @@ static UniValue getpeerinfo(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); std::vector<CNodeStats> vstats; - g_rpc_node->connman->GetNodeStats(vstats); + node.connman->GetNodeStats(vstats); UniValue ret(UniValue::VARR); @@ -248,7 +251,8 @@ static UniValue addnode(const JSONRPCRequest& request) }, }.ToString()); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); std::string strNode = request.params[0].get_str(); @@ -256,18 +260,18 @@ static UniValue addnode(const JSONRPCRequest& request) if (strCommand == "onetry") { CAddress addr; - g_rpc_node->connman->OpenNetworkConnection(addr, false, nullptr, strNode.c_str(), false, false, true); + node.connman->OpenNetworkConnection(addr, false, nullptr, strNode.c_str(), false, false, true); return NullUniValue; } if (strCommand == "add") { - if(!g_rpc_node->connman->AddNode(strNode)) + if(!node.connman->AddNode(strNode)) throw JSONRPCError(RPC_CLIENT_NODE_ALREADY_ADDED, "Error: Node already added"); } else if(strCommand == "remove") { - if(!g_rpc_node->connman->RemoveAddedNode(strNode)) + if(!node.connman->RemoveAddedNode(strNode)) throw JSONRPCError(RPC_CLIENT_NODE_NOT_ADDED, "Error: Node has not been added."); } @@ -293,7 +297,8 @@ static UniValue disconnectnode(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); bool success; @@ -302,11 +307,11 @@ static UniValue disconnectnode(const JSONRPCRequest& request) if (!address_arg.isNull() && id_arg.isNull()) { /* handle disconnect-by-address */ - success = g_rpc_node->connman->DisconnectNode(address_arg.get_str()); + success = node.connman->DisconnectNode(address_arg.get_str()); } else if (!id_arg.isNull() && (address_arg.isNull() || (address_arg.isStr() && address_arg.get_str().empty()))) { /* handle disconnect-by-id */ NodeId nodeid = (NodeId) id_arg.get_int64(); - success = g_rpc_node->connman->DisconnectNode(nodeid); + success = node.connman->DisconnectNode(nodeid); } else { throw JSONRPCError(RPC_INVALID_PARAMS, "Only one of address and nodeid should be provided."); } @@ -350,10 +355,11 @@ static UniValue getaddednodeinfo(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); - std::vector<AddedNodeInfo> vInfo = g_rpc_node->connman->GetAddedNodeInfo(); + std::vector<AddedNodeInfo> vInfo = node.connman->GetAddedNodeInfo(); if (!request.params[0].isNull()) { bool found = false; @@ -417,21 +423,22 @@ static UniValue getnettotals(const JSONRPCRequest& request) + HelpExampleRpc("getnettotals", "") }, }.Check(request); - if(!g_rpc_node->connman) + NodeContext& node = EnsureNodeContext(request.context); + if(!node.connman) throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); UniValue obj(UniValue::VOBJ); - obj.pushKV("totalbytesrecv", g_rpc_node->connman->GetTotalBytesRecv()); - obj.pushKV("totalbytessent", g_rpc_node->connman->GetTotalBytesSent()); + obj.pushKV("totalbytesrecv", node.connman->GetTotalBytesRecv()); + obj.pushKV("totalbytessent", node.connman->GetTotalBytesSent()); obj.pushKV("timemillis", GetTimeMillis()); UniValue outboundLimit(UniValue::VOBJ); - outboundLimit.pushKV("timeframe", g_rpc_node->connman->GetMaxOutboundTimeframe()); - outboundLimit.pushKV("target", g_rpc_node->connman->GetMaxOutboundTarget()); - outboundLimit.pushKV("target_reached", g_rpc_node->connman->OutboundTargetReached(false)); - outboundLimit.pushKV("serve_historical_blocks", !g_rpc_node->connman->OutboundTargetReached(true)); - outboundLimit.pushKV("bytes_left_in_cycle", g_rpc_node->connman->GetOutboundTargetBytesLeft()); - outboundLimit.pushKV("time_left_in_cycle", g_rpc_node->connman->GetMaxOutboundTimeLeftInCycle()); + outboundLimit.pushKV("timeframe", node.connman->GetMaxOutboundTimeframe()); + outboundLimit.pushKV("target", node.connman->GetMaxOutboundTarget()); + outboundLimit.pushKV("target_reached", node.connman->OutboundTargetReached(false)); + outboundLimit.pushKV("serve_historical_blocks", !node.connman->OutboundTargetReached(true)); + outboundLimit.pushKV("bytes_left_in_cycle", node.connman->GetOutboundTargetBytesLeft()); + outboundLimit.pushKV("time_left_in_cycle", node.connman->GetMaxOutboundTimeLeftInCycle()); obj.pushKV("uploadtarget", outboundLimit); return obj; } @@ -513,16 +520,17 @@ static UniValue getnetworkinfo(const JSONRPCRequest& request) obj.pushKV("version", CLIENT_VERSION); obj.pushKV("subversion", strSubVersion); obj.pushKV("protocolversion",PROTOCOL_VERSION); - if (g_rpc_node->connman) { - ServiceFlags services = g_rpc_node->connman->GetLocalServices(); + NodeContext& node = EnsureNodeContext(request.context); + if (node.connman) { + ServiceFlags services = node.connman->GetLocalServices(); obj.pushKV("localservices", strprintf("%016x", services)); obj.pushKV("localservicesnames", GetServicesNames(services)); } obj.pushKV("localrelay", g_relay_txes); obj.pushKV("timeoffset", GetTimeOffset()); - if (g_rpc_node->connman) { - obj.pushKV("networkactive", g_rpc_node->connman->GetNetworkActive()); - obj.pushKV("connections", (int)g_rpc_node->connman->GetNodeCount(CConnman::CONNECTIONS_ALL)); + if (node.connman) { + obj.pushKV("networkactive", node.connman->GetNetworkActive()); + obj.pushKV("connections", (int)node.connman->GetNodeCount(CConnman::CONNECTIONS_ALL)); } obj.pushKV("networks", GetNetworksInfo()); obj.pushKV("relayfee", ValueFromAmount(::minRelayTxFee.GetFeePerK())); @@ -567,7 +575,8 @@ static UniValue setban(const JSONRPCRequest& request) if (request.fHelp || !help.IsValidNumArgs(request.params.size()) || (strCommand != "add" && strCommand != "remove")) { throw std::runtime_error(help.ToString()); } - if (!g_rpc_node->banman) { + NodeContext& node = EnsureNodeContext(request.context); + if (!node.banman) { throw JSONRPCError(RPC_DATABASE_ERROR, "Error: Ban database not loaded"); } @@ -591,7 +600,7 @@ static UniValue setban(const JSONRPCRequest& request) if (strCommand == "add") { - if (isSubnet ? g_rpc_node->banman->IsBanned(subNet) : g_rpc_node->banman->IsBanned(netAddr)) { + if (isSubnet ? node.banman->IsBanned(subNet) : node.banman->IsBanned(netAddr)) { throw JSONRPCError(RPC_CLIENT_NODE_ALREADY_ADDED, "Error: IP/Subnet already banned"); } @@ -604,20 +613,20 @@ static UniValue setban(const JSONRPCRequest& request) absolute = true; if (isSubnet) { - g_rpc_node->banman->Ban(subNet, BanReasonManuallyAdded, banTime, absolute); - if (g_rpc_node->connman) { - g_rpc_node->connman->DisconnectNode(subNet); + node.banman->Ban(subNet, BanReasonManuallyAdded, banTime, absolute); + if (node.connman) { + node.connman->DisconnectNode(subNet); } } else { - g_rpc_node->banman->Ban(netAddr, BanReasonManuallyAdded, banTime, absolute); - if (g_rpc_node->connman) { - g_rpc_node->connman->DisconnectNode(netAddr); + node.banman->Ban(netAddr, BanReasonManuallyAdded, banTime, absolute); + if (node.connman) { + node.connman->DisconnectNode(netAddr); } } } else if(strCommand == "remove") { - if (!( isSubnet ? g_rpc_node->banman->Unban(subNet) : g_rpc_node->banman->Unban(netAddr) )) { + if (!( isSubnet ? node.banman->Unban(subNet) : node.banman->Unban(netAddr) )) { throw JSONRPCError(RPC_CLIENT_INVALID_IP_OR_SUBNET, "Error: Unban failed. Requested address/subnet was not previously banned."); } } @@ -645,12 +654,13 @@ static UniValue listbanned(const JSONRPCRequest& request) }, }.Check(request); - if(!g_rpc_node->banman) { + NodeContext& node = EnsureNodeContext(request.context); + if(!node.banman) { throw JSONRPCError(RPC_DATABASE_ERROR, "Error: Ban database not loaded"); } banmap_t banMap; - g_rpc_node->banman->GetBanned(banMap); + node.banman->GetBanned(banMap); UniValue bannedAddresses(UniValue::VARR); for (const auto& entry : banMap) @@ -679,11 +689,12 @@ static UniValue clearbanned(const JSONRPCRequest& request) + HelpExampleRpc("clearbanned", "") }, }.Check(request); - if (!g_rpc_node->banman) { + NodeContext& node = EnsureNodeContext(request.context); + if (!node.banman) { throw JSONRPCError(RPC_DATABASE_ERROR, "Error: Ban database not loaded"); } - g_rpc_node->banman->ClearBanned(); + node.banman->ClearBanned(); return NullUniValue; } @@ -699,13 +710,14 @@ static UniValue setnetworkactive(const JSONRPCRequest& request) RPCExamples{""}, }.Check(request); - if (!g_rpc_node->connman) { + NodeContext& node = EnsureNodeContext(request.context); + if (!node.connman) { throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); } - g_rpc_node->connman->SetNetworkActive(request.params[0].get_bool()); + node.connman->SetNetworkActive(request.params[0].get_bool()); - return g_rpc_node->connman->GetNetworkActive(); + return node.connman->GetNetworkActive(); } static UniValue getnodeaddresses(const JSONRPCRequest& request) @@ -732,7 +744,8 @@ static UniValue getnodeaddresses(const JSONRPCRequest& request) + HelpExampleRpc("getnodeaddresses", "8") }, }.Check(request); - if (!g_rpc_node->connman) { + NodeContext& node = EnsureNodeContext(request.context); + if (!node.connman) { throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); } @@ -744,7 +757,7 @@ static UniValue getnodeaddresses(const JSONRPCRequest& request) } } // returns a shuffled list of CAddress - std::vector<CAddress> vAddr = g_rpc_node->connman->GetAddresses(); + std::vector<CAddress> vAddr = node.connman->GetAddresses(); UniValue ret(UniValue::VARR); int address_return_count = std::min<int>(count, vAddr.size()); diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index 063ee1697c..e14217c307 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -653,7 +653,7 @@ static UniValue combinerawtransaction(const JSONRPCRequest& request) CCoinsView viewDummy; CCoinsViewCache view(&viewDummy); { - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK(cs_main); LOCK(mempool.cs); CCoinsViewCache &viewChain = ::ChainstateActive().CoinsTip(); @@ -778,7 +778,8 @@ static UniValue signrawtransactionwithkey(const JSONRPCRequest& request) for (const CTxIn& txin : mtx.vin) { coins[txin.prevout]; // Create empty map entry keyed by prevout. } - FindCoins(*g_rpc_node, coins); + NodeContext& node = EnsureNodeContext(request.context); + FindCoins(node, coins); // Parse the prevtxs array ParsePrevouts(request.params[2], &keystore, coins); @@ -837,7 +838,8 @@ static UniValue sendrawtransaction(const JSONRPCRequest& request) std::string err_string; AssertLockNotHeld(cs_main); - const TransactionError err = BroadcastTransaction(*g_rpc_node, tx, err_string, max_raw_tx_fee, /*relay*/ true, /*wait_callback*/ true); + NodeContext& node = EnsureNodeContext(request.context); + const TransactionError err = BroadcastTransaction(node, tx, err_string, max_raw_tx_fee, /*relay*/ true, /*wait_callback*/ true); if (TransactionError::OK != err) { throw JSONRPCTransactionError(err, err_string); } @@ -904,7 +906,7 @@ static UniValue testmempoolaccept(const JSONRPCRequest& request) DEFAULT_MAX_RAW_TX_FEE_RATE : CFeeRate(AmountFromValue(request.params[1])); - CTxMemPool& mempool = EnsureMemPool(); + CTxMemPool& mempool = EnsureMemPool(request.context); int64_t virtual_size = GetVirtualTransactionSize(*tx); CAmount max_raw_tx_fee = max_raw_tx_fee_rate.GetFee(virtual_size); @@ -1555,7 +1557,7 @@ UniValue utxoupdatepsbt(const JSONRPCRequest& request) CCoinsView viewDummy; CCoinsViewCache view(&viewDummy); { - const CTxMemPool& mempool = EnsureMemPool(); + const CTxMemPool& mempool = EnsureMemPool(request.context); LOCK2(cs_main, mempool.cs); CCoinsViewCache &viewChain = ::ChainstateActive().CoinsTip(); CCoinsViewMemPool viewMempool(&viewChain, mempool); diff --git a/src/rpc/request.cpp b/src/rpc/request.cpp index 56cac6661e..7fef45f50e 100644 --- a/src/rpc/request.cpp +++ b/src/rpc/request.cpp @@ -130,20 +130,20 @@ void DeleteAuthCookie() } } -std::vector<UniValue> JSONRPCProcessBatchReply(const UniValue &in, size_t num) +std::vector<UniValue> JSONRPCProcessBatchReply(const UniValue& in) { if (!in.isArray()) { throw std::runtime_error("Batch must be an array"); } + const size_t num {in.size()}; std::vector<UniValue> batch(num); - for (size_t i=0; i<in.size(); ++i) { - const UniValue &rec = in[i]; + for (const UniValue& rec : in.getValues()) { if (!rec.isObject()) { - throw std::runtime_error("Batch member must be object"); + throw std::runtime_error("Batch member must be an object"); } size_t id = rec["id"].get_int(); if (id >= num) { - throw std::runtime_error("Batch member id larger than size"); + throw std::runtime_error("Batch member id is larger than batch size"); } batch[id] = rec; } diff --git a/src/rpc/request.h b/src/rpc/request.h index 99eb4f9354..02ec5393a7 100644 --- a/src/rpc/request.h +++ b/src/rpc/request.h @@ -10,6 +10,10 @@ #include <univalue.h> +namespace util { +class Ref; +} // namespace util + UniValue JSONRPCRequestObj(const std::string& strMethod, const UniValue& params, const UniValue& id); UniValue JSONRPCReplyObj(const UniValue& result, const UniValue& error, const UniValue& id); std::string JSONRPCReply(const UniValue& result, const UniValue& error, const UniValue& id); @@ -22,7 +26,7 @@ bool GetAuthCookie(std::string *cookie_out); /** Delete RPC authentication cookie from disk */ void DeleteAuthCookie(); /** Parse JSON-RPC batch reply into a vector */ -std::vector<UniValue> JSONRPCProcessBatchReply(const UniValue &in, size_t num); +std::vector<UniValue> JSONRPCProcessBatchReply(const UniValue& in); class JSONRPCRequest { @@ -34,8 +38,9 @@ public: std::string URI; std::string authUser; std::string peerAddr; + const util::Ref& context; - JSONRPCRequest() : id(NullUniValue), params(NullUniValue), fHelp(false) {} + JSONRPCRequest(const util::Ref& context) : id(NullUniValue), params(NullUniValue), fHelp(false), context(context) {} void parse(const UniValue& valRequest); }; diff --git a/src/rpc/server.cpp b/src/rpc/server.cpp index 219979f095..99c649d15a 100644 --- a/src/rpc/server.cpp +++ b/src/rpc/server.cpp @@ -11,9 +11,9 @@ #include <util/strencodings.h> #include <util/system.h> -#include <boost/signals2/signal.hpp> #include <boost/algorithm/string/classification.hpp> #include <boost/algorithm/string/split.hpp> +#include <boost/signals2/signal.hpp> #include <memory> // for unique_ptr #include <unordered_map> @@ -309,6 +309,11 @@ bool IsRPCRunning() return g_rpc_running; } +void RpcInterruptionPoint() +{ + if (!IsRPCRunning()) throw JSONRPCError(RPC_CLIENT_NOT_CONNECTED, "Shutting down"); +} + void SetRPCWarmupStatus(const std::string& newStatus) { LOCK(cs_rpcWarmup); diff --git a/src/rpc/server.h b/src/rpc/server.h index c91bf1f613..d7a04ff6e8 100644 --- a/src/rpc/server.h +++ b/src/rpc/server.h @@ -9,10 +9,10 @@ #include <amount.h> #include <rpc/request.h> +#include <functional> #include <map> #include <stdint.h> #include <string> -#include <functional> #include <univalue.h> @@ -29,6 +29,9 @@ namespace RPCServer /** Query whether RPC is running */ bool IsRPCRunning(); +/** Throw JSONRPCError if RPC is not running */ +void RpcInterruptionPoint(); + /** * Set the RPC warmup status. When this is done, all RPC calls will error out * immediately with RPC_IN_WARMUP. diff --git a/src/script/keyorigin.h b/src/script/keyorigin.h index 467605ce46..a318ff0f9d 100644 --- a/src/script/keyorigin.h +++ b/src/script/keyorigin.h @@ -18,13 +18,7 @@ struct KeyOriginInfo return std::equal(std::begin(a.fingerprint), std::end(a.fingerprint), std::begin(b.fingerprint)) && a.path == b.path; } - ADD_SERIALIZE_METHODS; - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) - { - READWRITE(fingerprint); - READWRITE(path); - } + SERIALIZE_METHODS(KeyOriginInfo, obj) { READWRITE(obj.fingerprint, obj.path); } void clear() { diff --git a/src/script/script.cpp b/src/script/script.cpp index ae0de1d24e..92c6fe7785 100644 --- a/src/script/script.cpp +++ b/src/script/script.cpp @@ -7,7 +7,9 @@ #include <util/strencodings.h> -const char* GetOpName(opcodetype opcode) +#include <string> + +std::string GetOpName(opcodetype opcode) { switch (opcode) { diff --git a/src/script/script.h b/src/script/script.h index 773ffbb985..c1f2b66921 100644 --- a/src/script/script.h +++ b/src/script/script.h @@ -193,7 +193,7 @@ enum opcodetype // Maximum value that an opcode can be static const unsigned int MAX_OPCODE = OP_NOP10; -const char* GetOpName(opcodetype opcode); +std::string GetOpName(opcodetype opcode); class scriptnum_error : public std::runtime_error { @@ -412,12 +412,7 @@ public: CScript(std::vector<unsigned char>::const_iterator pbegin, std::vector<unsigned char>::const_iterator pend) : CScriptBase(pbegin, pend) { } CScript(const unsigned char* pbegin, const unsigned char* pend) : CScriptBase(pbegin, pend) { } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITEAS(CScriptBase, *this); - } + SERIALIZE_METHODS(CScript, obj) { READWRITEAS(CScriptBase, obj); } explicit CScript(int64_t b) { operator<<(b); } explicit CScript(opcodetype b) { operator<<(b); } diff --git a/src/script/script_error.cpp b/src/script/script_error.cpp index 57e8fee539..69e14803f1 100644 --- a/src/script/script_error.cpp +++ b/src/script/script_error.cpp @@ -5,7 +5,9 @@ #include <script/script_error.h> -const char* ScriptErrorString(const ScriptError serror) +#include <string> + +std::string ScriptErrorString(const ScriptError serror) { switch (serror) { diff --git a/src/script/script_error.h b/src/script/script_error.h index 400f63ff0f..2978c147e1 100644 --- a/src/script/script_error.h +++ b/src/script/script_error.h @@ -6,6 +6,8 @@ #ifndef BITCOIN_SCRIPT_SCRIPT_ERROR_H #define BITCOIN_SCRIPT_SCRIPT_ERROR_H +#include <string> + typedef enum ScriptError_t { SCRIPT_ERR_OK = 0, @@ -73,6 +75,6 @@ typedef enum ScriptError_t #define SCRIPT_ERR_LAST SCRIPT_ERR_ERROR_COUNT -const char* ScriptErrorString(const ScriptError error); +std::string ScriptErrorString(const ScriptError error); #endif // BITCOIN_SCRIPT_SCRIPT_ERROR_H diff --git a/src/script/standard.cpp b/src/script/standard.cpp index 7d89a336fb..c90c2c24a0 100644 --- a/src/script/standard.cpp +++ b/src/script/standard.cpp @@ -9,6 +9,8 @@ #include <pubkey.h> #include <script/script.h> +#include <string> + typedef std::vector<unsigned char> valtype; bool fAcceptDatacarrier = DEFAULT_ACCEPT_DATACARRIER; @@ -25,7 +27,7 @@ WitnessV0ScriptHash::WitnessV0ScriptHash(const CScript& in) CSHA256().Write(in.data(), in.size()).Finalize(begin()); } -const char* GetTxnOutputType(txnouttype t) +std::string GetTxnOutputType(txnouttype t) { switch (t) { @@ -39,7 +41,7 @@ const char* GetTxnOutputType(txnouttype t) case TX_WITNESS_V0_SCRIPTHASH: return "witness_v0_scripthash"; case TX_WITNESS_UNKNOWN: return "witness_unknown"; } - return nullptr; + assert(false); } static bool MatchPayToPubkey(const CScript& script, valtype& pubkey) diff --git a/src/script/standard.h b/src/script/standard.h index 49a45f3eba..2929425670 100644 --- a/src/script/standard.h +++ b/src/script/standard.h @@ -11,6 +11,8 @@ #include <boost/variant.hpp> +#include <string> + static const bool DEFAULT_ACCEPT_DATACARRIER = true; @@ -44,8 +46,7 @@ extern unsigned nMaxDatacarrierBytes; /** * Mandatory script verification flags that all new blocks must comply with for * them to be valid. (but old blocks may not comply with) Currently just P2SH, - * but in the future other flags may be added, such as a soft-fork to enforce - * strict DER encoding. + * but in the future other flags may be added. * * Failing one of these tests may trigger a DoS ban - see CheckInputScripts() for * details. @@ -146,7 +147,7 @@ typedef boost::variant<CNoDestination, PKHash, ScriptHash, WitnessV0ScriptHash, bool IsValidDestination(const CTxDestination& dest); /** Get the name of a txnouttype as a C string, or nullptr if unknown. */ -const char* GetTxnOutputType(txnouttype t); +std::string GetTxnOutputType(txnouttype t); /** * Parse a scriptPubKey and identify script type for standard scripts. If diff --git a/src/serialize.h b/src/serialize.h index fe53eeed31..71c2cfa164 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -43,26 +43,6 @@ static const unsigned int MAX_VECTOR_ALLOCATE = 5000000; struct deserialize_type {}; constexpr deserialize_type deserialize {}; -/** - * Used to bypass the rule against non-const reference to temporary - * where it makes sense with wrappers. - */ -template<typename T> -inline T& REF(const T& val) -{ - return const_cast<T&>(val); -} - -/** - * Used to acquire a non-const pointer "this" to generate bodies - * of const serialization operations from a template - */ -template<typename T> -inline T* NCONST_PTR(const T* val) -{ - return const_cast<T*>(val); -} - //! Safely convert odd char pointer types to standard ones. inline char* CharCast(char* c) { return c; } inline char* CharCast(unsigned char* c) { return (char*)c; } @@ -190,22 +170,8 @@ template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; } #define READWRITE(...) (::SerReadWriteMany(s, ser_action, __VA_ARGS__)) #define READWRITEAS(type, obj) (::SerReadWriteMany(s, ser_action, ReadWriteAsHelper<type>(obj))) - -/** - * Implement three methods for serializable objects. These are actually wrappers over - * "SerializationOp" template, which implements the body of each class' serialization - * code. Adding "ADD_SERIALIZE_METHODS" in the body of the class causes these wrappers to be - * added as members. - */ -#define ADD_SERIALIZE_METHODS \ - template<typename Stream> \ - void Serialize(Stream& s) const { \ - NCONST_PTR(this)->SerializationOp(s, CSerActionSerialize()); \ - } \ - template<typename Stream> \ - void Unserialize(Stream& s) { \ - SerializationOp(s, CSerActionUnserialize()); \ - } +#define SER_READ(obj, code) ::SerRead(s, ser_action, obj, [&](Stream& s, typename std::remove_const<Type>::type& obj) { code; }) +#define SER_WRITE(obj, code) ::SerWrite(s, ser_action, obj, [&](Stream& s, const Type& obj) { code; }) /** * Implement the Ser and Unser methods needed for implementing a formatter (see Using below). @@ -501,7 +467,7 @@ static inline Wrapper<Formatter, T&> Using(T&& t) { return Wrapper<Formatter, T& #define VARINT_MODE(obj, mode) Using<VarIntFormatter<mode>>(obj) #define VARINT(obj) Using<VarIntFormatter<VarIntMode::DEFAULT>>(obj) #define COMPACTSIZE(obj) Using<CompactSizeFormatter>(obj) -#define LIMITED_STRING(obj,n) LimitedString< n >(REF(obj)) +#define LIMITED_STRING(obj,n) Using<LimitedStringFormatter<n>>(obj) /** Serialization wrapper class for integers in VarInt format. */ template<VarIntMode Mode> @@ -518,7 +484,16 @@ struct VarIntFormatter } }; -template<int Bytes> +/** Serialization wrapper class for custom integers and enums. + * + * It permits specifying the serialized size (1 to 8 bytes) and endianness. + * + * Use the big endian mode for values that are stored in memory in native + * byte order, but serialized in big endian notation. This is only intended + * to implement serializers that are compatible with existing formats, and + * its use is not recommended for new data structures. + */ +template<int Bytes, bool BigEndian = false> struct CustomUintFormatter { static_assert(Bytes > 0 && Bytes <= 8, "CustomUintFormatter Bytes out of range"); @@ -527,52 +502,31 @@ struct CustomUintFormatter template <typename Stream, typename I> void Ser(Stream& s, I v) { if (v < 0 || v > MAX) throw std::ios_base::failure("CustomUintFormatter value out of range"); - uint64_t raw = htole64(v); - s.write((const char*)&raw, Bytes); + if (BigEndian) { + uint64_t raw = htobe64(v); + s.write(((const char*)&raw) + 8 - Bytes, Bytes); + } else { + uint64_t raw = htole64(v); + s.write((const char*)&raw, Bytes); + } } template <typename Stream, typename I> void Unser(Stream& s, I& v) { - static_assert(std::numeric_limits<I>::max() >= MAX && std::numeric_limits<I>::min() <= 0, "CustomUintFormatter type too small"); + using U = typename std::conditional<std::is_enum<I>::value, std::underlying_type<I>, std::common_type<I>>::type::type; + static_assert(std::numeric_limits<U>::max() >= MAX && std::numeric_limits<U>::min() <= 0, "Assigned type too small"); uint64_t raw = 0; - s.read((char*)&raw, Bytes); - v = le64toh(raw); + if (BigEndian) { + s.read(((char*)&raw) + 8 - Bytes, Bytes); + v = static_cast<I>(be64toh(raw)); + } else { + s.read((char*)&raw, Bytes); + v = static_cast<I>(le64toh(raw)); + } } }; -/** Serialization wrapper class for big-endian integers. - * - * Use this wrapper around integer types that are stored in memory in native - * byte order, but serialized in big endian notation. This is only intended - * to implement serializers that are compatible with existing formats, and - * its use is not recommended for new data structures. - * - * Only 16-bit types are supported for now. - */ -template<typename I> -class BigEndian -{ -protected: - I& m_val; -public: - explicit BigEndian(I& val) : m_val(val) - { - static_assert(std::is_unsigned<I>::value, "BigEndian type must be unsigned integer"); - static_assert(sizeof(I) == 2 && std::numeric_limits<I>::min() == 0 && std::numeric_limits<I>::max() == std::numeric_limits<uint16_t>::max(), "Unsupported BigEndian size"); - } - - template<typename Stream> - void Serialize(Stream& s) const - { - ser_writedata16be(s, m_val); - } - - template<typename Stream> - void Unserialize(Stream& s) - { - m_val = ser_readdata16be(s); - } -}; +template<int Bytes> using BigEndianFormatter = CustomUintFormatter<Bytes, true>; /** Formatter for integers in CompactSize format. */ struct CompactSizeFormatter @@ -598,37 +552,26 @@ struct CompactSizeFormatter }; template<size_t Limit> -class LimitedString +struct LimitedStringFormatter { -protected: - std::string& string; -public: - explicit LimitedString(std::string& _string) : string(_string) {} - template<typename Stream> - void Unserialize(Stream& s) + void Unser(Stream& s, std::string& v) { size_t size = ReadCompactSize(s); if (size > Limit) { throw std::ios_base::failure("String length limit exceeded"); } - string.resize(size); - if (size != 0) - s.read((char*)string.data(), size); + v.resize(size); + if (size != 0) s.read((char*)v.data(), size); } template<typename Stream> - void Serialize(Stream& s) const + void Ser(Stream& s, const std::string& v) { - WriteCompactSize(s, string.size()); - if (!string.empty()) - s.write((char*)string.data(), string.size()); + s << v; } }; -template<typename I> -BigEndian<I> WrapBigEndian(I& n) { return BigEndian<I>(n); } - /** Formatter to serialize/deserialize vector elements using another formatter * * Example: @@ -1025,7 +968,7 @@ void Unserialize(Stream& is, std::shared_ptr<const T>& p) /** - * Support for ADD_SERIALIZE_METHODS and READWRITE macro + * Support for SERIALIZE_METHODS and READWRITE macro. */ struct CSerActionSerialize { @@ -1124,6 +1067,28 @@ inline void SerReadWriteMany(Stream& s, CSerActionUnserialize ser_action, Args&& ::UnserializeMany(s, args...); } +template<typename Stream, typename Type, typename Fn> +inline void SerRead(Stream& s, CSerActionSerialize ser_action, Type&&, Fn&&) +{ +} + +template<typename Stream, typename Type, typename Fn> +inline void SerRead(Stream& s, CSerActionUnserialize ser_action, Type&& obj, Fn&& fn) +{ + fn(s, std::forward<Type>(obj)); +} + +template<typename Stream, typename Type, typename Fn> +inline void SerWrite(Stream& s, CSerActionSerialize ser_action, Type&& obj, Fn&& fn) +{ + fn(s, std::forward<Type>(obj)); +} + +template<typename Stream, typename Type, typename Fn> +inline void SerWrite(Stream& s, CSerActionUnserialize ser_action, Type&&, Fn&&) +{ +} + template<typename I> inline void WriteVarInt(CSizeComputer &s, I n) { diff --git a/src/sync.cpp b/src/sync.cpp index b86c57e498..9abdedbed4 100644 --- a/src/sync.cpp +++ b/src/sync.cpp @@ -7,15 +7,19 @@ #endif #include <sync.h> -#include <tinyformat.h> #include <logging.h> +#include <tinyformat.h> #include <util/strencodings.h> #include <util/threadnames.h> #include <map> #include <set> #include <system_error> +#include <thread> +#include <unordered_map> +#include <utility> +#include <vector> #ifdef DEBUG_LOCKCONTENTION #if !defined(HAVE_THREAD_LOCAL) @@ -73,35 +77,35 @@ private: int sourceLine; }; -typedef std::vector<std::pair<void*, CLockLocation> > LockStack; -typedef std::map<std::pair<void*, void*>, LockStack> LockOrders; -typedef std::set<std::pair<void*, void*> > InvLockOrders; +using LockStackItem = std::pair<void*, CLockLocation>; +using LockStack = std::vector<LockStackItem>; +using LockStacks = std::unordered_map<std::thread::id, LockStack>; -struct LockData { - // Very ugly hack: as the global constructs and destructors run single - // threaded, we use this boolean to know whether LockData still exists, - // as DeleteLock can get called by global RecursiveMutex destructors - // after LockData disappears. - bool available; - LockData() : available(true) {} - ~LockData() { available = false; } +using LockPair = std::pair<void*, void*>; +using LockOrders = std::map<LockPair, LockStack>; +using InvLockOrders = std::set<LockPair>; +struct LockData { + LockStacks m_lock_stacks; LockOrders lockorders; InvLockOrders invlockorders; std::mutex dd_mutex; }; + LockData& GetLockData() { - static LockData lockdata; - return lockdata; + // This approach guarantees that the object is not destroyed until after its last use. + // The operating system automatically reclaims all the memory in a program's heap when that program exits. + // Since the ~LockData() destructor is never called, the LockData class and all + // its subclasses must have implicitly-defined destructors. + static LockData& lock_data = *new LockData(); + return lock_data; } -static thread_local LockStack g_lockstack; - -static void potential_deadlock_detected(const std::pair<void*, void*>& mismatch, const LockStack& s1, const LockStack& s2) +static void potential_deadlock_detected(const LockPair& mismatch, const LockStack& s1, const LockStack& s2) { LogPrintf("POTENTIAL DEADLOCK DETECTED\n"); LogPrintf("Previous lock order was:\n"); - for (const std::pair<void*, CLockLocation> & i : s2) { + for (const LockStackItem& i : s2) { if (i.first == mismatch.first) { LogPrintf(" (1)"); /* Continued */ } @@ -111,7 +115,7 @@ static void potential_deadlock_detected(const std::pair<void*, void*>& mismatch, LogPrintf(" %s\n", i.second.ToString()); } LogPrintf("Current lock order is:\n"); - for (const std::pair<void*, CLockLocation> & i : s1) { + for (const LockStackItem& i : s1) { if (i.first == mismatch.first) { LogPrintf(" (1)"); /* Continued */ } @@ -132,18 +136,18 @@ static void push_lock(void* c, const CLockLocation& locklocation) LockData& lockdata = GetLockData(); std::lock_guard<std::mutex> lock(lockdata.dd_mutex); - g_lockstack.push_back(std::make_pair(c, locklocation)); - - for (const std::pair<void*, CLockLocation>& i : g_lockstack) { + LockStack& lock_stack = lockdata.m_lock_stacks[std::this_thread::get_id()]; + lock_stack.emplace_back(c, locklocation); + for (const LockStackItem& i : lock_stack) { if (i.first == c) break; - std::pair<void*, void*> p1 = std::make_pair(i.first, c); + const LockPair p1 = std::make_pair(i.first, c); if (lockdata.lockorders.count(p1)) continue; - lockdata.lockorders.emplace(p1, g_lockstack); + lockdata.lockorders.emplace(p1, lock_stack); - std::pair<void*, void*> p2 = std::make_pair(c, i.first); + const LockPair p2 = std::make_pair(c, i.first); lockdata.invlockorders.insert(p2); if (lockdata.lockorders.count(p2)) potential_deadlock_detected(p1, lockdata.lockorders[p2], lockdata.lockorders[p1]); @@ -152,7 +156,14 @@ static void push_lock(void* c, const CLockLocation& locklocation) static void pop_lock() { - g_lockstack.pop_back(); + LockData& lockdata = GetLockData(); + std::lock_guard<std::mutex> lock(lockdata.dd_mutex); + + LockStack& lock_stack = lockdata.m_lock_stacks[std::this_thread::get_id()]; + lock_stack.pop_back(); + if (lock_stack.empty()) { + lockdata.m_lock_stacks.erase(std::this_thread::get_id()); + } } void EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs, bool fTry) @@ -162,11 +173,17 @@ void EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs void CheckLastCritical(void* cs, std::string& lockname, const char* guardname, const char* file, int line) { - if (!g_lockstack.empty()) { - const auto& lastlock = g_lockstack.back(); - if (lastlock.first == cs) { - lockname = lastlock.second.Name(); - return; + { + LockData& lockdata = GetLockData(); + std::lock_guard<std::mutex> lock(lockdata.dd_mutex); + + const LockStack& lock_stack = lockdata.m_lock_stacks[std::this_thread::get_id()]; + if (!lock_stack.empty()) { + const auto& lastlock = lock_stack.back(); + if (lastlock.first == cs) { + lockname = lastlock.second.Name(); + return; + } } } throw std::system_error(EPERM, std::generic_category(), strprintf("%s:%s %s was not most recent critical section locked", file, line, guardname)); @@ -179,49 +196,60 @@ void LeaveCritical() std::string LocksHeld() { + LockData& lockdata = GetLockData(); + std::lock_guard<std::mutex> lock(lockdata.dd_mutex); + + const LockStack& lock_stack = lockdata.m_lock_stacks[std::this_thread::get_id()]; std::string result; - for (const std::pair<void*, CLockLocation>& i : g_lockstack) + for (const LockStackItem& i : lock_stack) result += i.second.ToString() + std::string("\n"); return result; } -void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) +static bool LockHeld(void* mutex) +{ + LockData& lockdata = GetLockData(); + std::lock_guard<std::mutex> lock(lockdata.dd_mutex); + + const LockStack& lock_stack = lockdata.m_lock_stacks[std::this_thread::get_id()]; + for (const LockStackItem& i : lock_stack) { + if (i.first == mutex) return true; + } + + return false; +} + +template <typename MutexType> +void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) { - for (const std::pair<void*, CLockLocation>& i : g_lockstack) - if (i.first == cs) - return; + if (LockHeld(cs)) return; tfm::format(std::cerr, "Assertion failed: lock %s not held in %s:%i; locks held:\n%s", pszName, pszFile, nLine, LocksHeld()); abort(); } +template void AssertLockHeldInternal(const char*, const char*, int, Mutex*); +template void AssertLockHeldInternal(const char*, const char*, int, RecursiveMutex*); void AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) { - for (const std::pair<void*, CLockLocation>& i : g_lockstack) { - if (i.first == cs) { - tfm::format(std::cerr, "Assertion failed: lock %s held in %s:%i; locks held:\n%s", pszName, pszFile, nLine, LocksHeld()); - abort(); - } - } + if (!LockHeld(cs)) return; + tfm::format(std::cerr, "Assertion failed: lock %s held in %s:%i; locks held:\n%s", pszName, pszFile, nLine, LocksHeld()); + abort(); } void DeleteLock(void* cs) { LockData& lockdata = GetLockData(); - if (!lockdata.available) { - // We're already shutting down. - return; - } std::lock_guard<std::mutex> lock(lockdata.dd_mutex); - std::pair<void*, void*> item = std::make_pair(cs, nullptr); + const LockPair item = std::make_pair(cs, nullptr); LockOrders::iterator it = lockdata.lockorders.lower_bound(item); while (it != lockdata.lockorders.end() && it->first.first == cs) { - std::pair<void*, void*> invitem = std::make_pair(it->first.second, it->first.first); + const LockPair invitem = std::make_pair(it->first.second, it->first.first); lockdata.invlockorders.erase(invitem); lockdata.lockorders.erase(it++); } InvLockOrders::iterator invit = lockdata.invlockorders.lower_bound(item); while (invit != lockdata.invlockorders.end() && invit->first == cs) { - std::pair<void*, void*> invinvitem = std::make_pair(invit->second, invit->first); + const LockPair invinvitem = std::make_pair(invit->second, invit->first); lockdata.lockorders.erase(invinvitem); lockdata.invlockorders.erase(invit++); } diff --git a/src/sync.h b/src/sync.h index 0c6f0ef0a7..60e5a87aec 100644 --- a/src/sync.h +++ b/src/sync.h @@ -52,7 +52,8 @@ void EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs void LeaveCritical(); void CheckLastCritical(void* cs, std::string& lockname, const char* guardname, const char* file, int line); std::string LocksHeld(); -void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) ASSERT_EXCLUSIVE_LOCK(cs); +template <typename MutexType> +void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) ASSERT_EXCLUSIVE_LOCK(cs); void AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs); void DeleteLock(void* cs); @@ -66,7 +67,8 @@ extern bool g_debug_lockorder_abort; void static inline EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs, bool fTry = false) {} void static inline LeaveCritical() {} void static inline CheckLastCritical(void* cs, std::string& lockname, const char* guardname, const char* file, int line) {} -void static inline AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) ASSERT_EXCLUSIVE_LOCK(cs) {} +template <typename MutexType> +void static inline AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) ASSERT_EXCLUSIVE_LOCK(cs) {} void static inline AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) {} void static inline DeleteLock(void* cs) {} #endif diff --git a/src/test/blockencodings_tests.cpp b/src/test/blockencodings_tests.cpp index 8694891a51..14cf1a4a76 100644 --- a/src/test/blockencodings_tests.cpp +++ b/src/test/blockencodings_tests.cpp @@ -132,24 +132,7 @@ public: return base.GetShortID(txhash); } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(header); - READWRITE(nonce); - size_t shorttxids_size = shorttxids.size(); - READWRITE(VARINT(shorttxids_size)); - shorttxids.resize(shorttxids_size); - for (size_t i = 0; i < shorttxids.size(); i++) { - uint32_t lsb = shorttxids[i] & 0xffffffff; - uint16_t msb = (shorttxids[i] >> 32) & 0xffff; - READWRITE(lsb); - READWRITE(msb); - shorttxids[i] = (uint64_t(msb) << 32) | uint64_t(lsb); - } - READWRITE(prefilledtxn); - } + SERIALIZE_METHODS(TestHeaderAndShortIDs, obj) { READWRITE(obj.header, obj.nonce, Using<VectorFormatter<CustomUintFormatter<CBlockHeaderAndShortTxIDs::SHORTTXIDS_LENGTH>>>(obj.shorttxids), obj.prefilledtxn); } }; BOOST_AUTO_TEST_CASE(NonCoinbasePreforwardRTTest) diff --git a/src/test/blockfilter_index_tests.cpp b/src/test/blockfilter_index_tests.cpp index e5043f6816..7dff2e6e86 100644 --- a/src/test/blockfilter_index_tests.cpp +++ b/src/test/blockfilter_index_tests.cpp @@ -94,7 +94,7 @@ bool BuildChainTestingSetup::BuildChain(const CBlockIndex* pindex, CBlockHeader header = block->GetBlockHeader(); BlockValidationState state; - if (!ProcessNewBlockHeaders({header}, state, Params(), &pindex)) { + if (!EnsureChainman(m_node).ProcessNewBlockHeaders({header}, state, Params(), &pindex)) { return false; } } @@ -171,7 +171,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) uint256 chainA_last_header = last_header; for (size_t i = 0; i < 2; i++) { const auto& block = chainA[i]; - BOOST_REQUIRE(ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); } for (size_t i = 0; i < 2; i++) { const auto& block = chainA[i]; @@ -189,7 +189,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) uint256 chainB_last_header = last_header; for (size_t i = 0; i < 3; i++) { const auto& block = chainB[i]; - BOOST_REQUIRE(ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); } for (size_t i = 0; i < 3; i++) { const auto& block = chainB[i]; @@ -220,7 +220,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) // Reorg back to chain A. for (size_t i = 2; i < 4; i++) { const auto& block = chainA[i]; - BOOST_REQUIRE(ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); } // Check that chain A and B blocks can be retrieved. diff --git a/src/test/checkqueue_tests.cpp b/src/test/checkqueue_tests.cpp index 0565982215..35750b2ebc 100644 --- a/src/test/checkqueue_tests.cpp +++ b/src/test/checkqueue_tests.cpp @@ -3,6 +3,7 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <checkqueue.h> +#include <sync.h> #include <test/util/setup_common.h> #include <util/memory.h> #include <util/system.h> @@ -57,14 +58,14 @@ struct FailingCheck { }; struct UniqueCheck { - static std::mutex m; - static std::unordered_multiset<size_t> results; + static Mutex m; + static std::unordered_multiset<size_t> results GUARDED_BY(m); size_t check_id; UniqueCheck(size_t check_id_in) : check_id(check_id_in){}; UniqueCheck() : check_id(0){}; bool operator()() { - std::lock_guard<std::mutex> l(m); + LOCK(m); results.insert(check_id); return true; } @@ -127,7 +128,7 @@ struct FrozenCleanupCheck { std::mutex FrozenCleanupCheck::m{}; std::atomic<uint64_t> FrozenCleanupCheck::nFrozen{0}; std::condition_variable FrozenCleanupCheck::cv{}; -std::mutex UniqueCheck::m; +Mutex UniqueCheck::m; std::unordered_multiset<size_t> UniqueCheck::results; std::atomic<size_t> FakeCheckCheckCompletion::n_calls{0}; std::atomic<size_t> MemoryCheck::fake_allocated_memory{0}; @@ -290,11 +291,15 @@ BOOST_AUTO_TEST_CASE(test_CheckQueue_UniqueCheck) control.Add(vChecks); } } - bool r = true; - BOOST_REQUIRE_EQUAL(UniqueCheck::results.size(), COUNT); - for (size_t i = 0; i < COUNT; ++i) - r = r && UniqueCheck::results.count(i) == 1; - BOOST_REQUIRE(r); + { + LOCK(UniqueCheck::m); + bool r = true; + BOOST_REQUIRE_EQUAL(UniqueCheck::results.size(), COUNT); + for (size_t i = 0; i < COUNT; ++i) { + r = r && UniqueCheck::results.count(i) == 1; + } + BOOST_REQUIRE(r); + } tg.interrupt_all(); tg.join_all(); } diff --git a/src/test/dbwrapper_tests.cpp b/src/test/dbwrapper_tests.cpp index c378546e8b..3d802cbeb3 100644 --- a/src/test/dbwrapper_tests.cpp +++ b/src/test/dbwrapper_tests.cpp @@ -331,24 +331,26 @@ struct StringContentsSerializer { } StringContentsSerializer& operator+=(const StringContentsSerializer& s) { return *this += s.str; } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - if (ser_action.ForRead()) { - str.clear(); - char c = 0; - while (true) { - try { - READWRITE(c); - str.push_back(c); - } catch (const std::ios_base::failure&) { - break; - } + template<typename Stream> + void Serialize(Stream& s) const + { + for (size_t i = 0; i < str.size(); i++) { + s << str[i]; + } + } + + template<typename Stream> + void Unserialize(Stream& s) + { + str.clear(); + char c = 0; + while (true) { + try { + s >> c; + str.push_back(c); + } catch (const std::ios_base::failure&) { + break; } - } else { - for (size_t i = 0; i < str.size(); i++) - READWRITE(str[i]); } } }; diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index 75b38670c9..348b170536 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -78,7 +78,7 @@ BOOST_FIXTURE_TEST_SUITE(denialofservice_tests, TestingSetup) BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) { auto connman = MakeUnique<CConnman>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), nullptr, *m_node.scheduler, *m_node.mempool); + auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), nullptr, *m_node.scheduler, *m_node.chainman, *m_node.mempool); // Mock an outbound peer CAddress addr1(ip(0xa0b0c001), NODE_NONE); @@ -148,7 +148,7 @@ static void AddRandomOutboundPeer(std::vector<CNode *> &vNodes, PeerLogicValidat BOOST_AUTO_TEST_CASE(stale_tip_peer_management) { auto connman = MakeUnique<CConnmanTest>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), nullptr, *m_node.scheduler, *m_node.mempool); + auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), nullptr, *m_node.scheduler, *m_node.chainman, *m_node.mempool); const Consensus::Params& consensusParams = Params().GetConsensus(); constexpr int max_outbound_full_relay = MAX_OUTBOUND_FULL_RELAY_CONNECTIONS; @@ -221,7 +221,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) { auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); auto connman = MakeUnique<CConnman>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.mempool); + auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.chainman, *m_node.mempool); banman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); @@ -276,7 +276,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) { auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); auto connman = MakeUnique<CConnman>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.mempool); + auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.chainman, *m_node.mempool); banman->ClearBanned(); gArgs.ForceSetArg("-banscore", "111"); // because 11 is my favorite number @@ -323,7 +323,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) { auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); auto connman = MakeUnique<CConnman>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.mempool); + auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.chainman, *m_node.mempool); banman->ClearBanned(); int64_t nStartTime = GetTime(); diff --git a/src/test/fuzz/addrdb.cpp b/src/test/fuzz/addrdb.cpp index f21ff3fac3..524cea83fe 100644 --- a/src/test/fuzz/addrdb.cpp +++ b/src/test/fuzz/addrdb.cpp @@ -3,13 +3,13 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <addrdb.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cassert> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -30,7 +30,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) })}; break; case 2: { - const Optional<CBanEntry> ban_entry = ConsumeDeserializable<CBanEntry>(fuzzed_data_provider); + const std::optional<CBanEntry> ban_entry = ConsumeDeserializable<CBanEntry>(fuzzed_data_provider); if (ban_entry) { return *ban_entry; } diff --git a/src/test/fuzz/asmap.cpp b/src/test/fuzz/asmap.cpp index ea56277eac..40ca01bd9f 100644 --- a/src/test/fuzz/asmap.cpp +++ b/src/test/fuzz/asmap.cpp @@ -23,8 +23,8 @@ static const std::vector<bool> IPV4_PREFIX_ASMAP = { true, true, false, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, // Match 0x00 true, true, false, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, // Match 0x00 true, true, false, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, // Match 0x00 - true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, // Match 0xFF - true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true // Match 0xFF + true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, // Match 0xFF + true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true // Match 0xFF }; void test_one_input(const std::vector<uint8_t>& buffer) diff --git a/src/test/fuzz/asmap_direct.cpp b/src/test/fuzz/asmap_direct.cpp index 6d8a65f5ab..2d21eff9d6 100644 --- a/src/test/fuzz/asmap_direct.cpp +++ b/src/test/fuzz/asmap_direct.cpp @@ -2,8 +2,8 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <util/asmap.h> #include <test/fuzz/fuzz.h> +#include <util/asmap.h> #include <cstdint> #include <optional> @@ -34,7 +34,9 @@ void test_one_input(const std::vector<uint8_t>& buffer) if (SanityCheckASMap(asmap, buffer.size() - 1 - sep_pos)) { // Verify that for valid asmaps, no prefix (except up to 7 zero padding bits) is valid. std::vector<bool> asmap_prefix = asmap; - while (!asmap_prefix.empty() && asmap_prefix.size() + 7 > asmap.size() && asmap_prefix.back() == false) asmap_prefix.pop_back(); + while (!asmap_prefix.empty() && asmap_prefix.size() + 7 > asmap.size() && asmap_prefix.back() == false) { + asmap_prefix.pop_back(); + } while (!asmap_prefix.empty()) { asmap_prefix.pop_back(); assert(!SanityCheckASMap(asmap_prefix, buffer.size() - 1 - sep_pos)); diff --git a/src/test/fuzz/block.cpp b/src/test/fuzz/block.cpp index f30fa03e0b..91bd34a251 100644 --- a/src/test/fuzz/block.cpp +++ b/src/test/fuzz/block.cpp @@ -38,12 +38,17 @@ void test_one_input(const std::vector<uint8_t>& buffer) const Consensus::Params& consensus_params = Params().GetConsensus(); BlockValidationState validation_state_pow_and_merkle; const bool valid_incl_pow_and_merkle = CheckBlock(block, validation_state_pow_and_merkle, consensus_params, /* fCheckPOW= */ true, /* fCheckMerkleRoot= */ true); + assert(validation_state_pow_and_merkle.IsValid() || validation_state_pow_and_merkle.IsInvalid() || validation_state_pow_and_merkle.IsError()); + (void)validation_state_pow_and_merkle.Error(""); BlockValidationState validation_state_pow; const bool valid_incl_pow = CheckBlock(block, validation_state_pow, consensus_params, /* fCheckPOW= */ true, /* fCheckMerkleRoot= */ false); + assert(validation_state_pow.IsValid() || validation_state_pow.IsInvalid() || validation_state_pow.IsError()); BlockValidationState validation_state_merkle; const bool valid_incl_merkle = CheckBlock(block, validation_state_merkle, consensus_params, /* fCheckPOW= */ false, /* fCheckMerkleRoot= */ true); + assert(validation_state_merkle.IsValid() || validation_state_merkle.IsInvalid() || validation_state_merkle.IsError()); BlockValidationState validation_state_none; const bool valid_incl_none = CheckBlock(block, validation_state_none, consensus_params, /* fCheckPOW= */ false, /* fCheckMerkleRoot= */ false); + assert(validation_state_none.IsValid() || validation_state_none.IsInvalid() || validation_state_none.IsError()); if (valid_incl_pow_and_merkle) { assert(valid_incl_pow && valid_incl_merkle && valid_incl_none); } else if (valid_incl_merkle || valid_incl_pow) { diff --git a/src/test/fuzz/block_header.cpp b/src/test/fuzz/block_header.cpp index 92dcccc0e1..09c2b4a951 100644 --- a/src/test/fuzz/block_header.cpp +++ b/src/test/fuzz/block_header.cpp @@ -2,7 +2,6 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <optional.h> #include <primitives/block.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> @@ -11,13 +10,14 @@ #include <cassert> #include <cstdint> +#include <optional> #include <string> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - const Optional<CBlockHeader> block_header = ConsumeDeserializable<CBlockHeader>(fuzzed_data_provider); + const std::optional<CBlockHeader> block_header = ConsumeDeserializable<CBlockHeader>(fuzzed_data_provider); if (!block_header) { return; } @@ -38,4 +38,12 @@ void test_one_input(const std::vector<uint8_t>& buffer) block.SetNull(); assert(block.GetBlockHeader().GetHash() == mut_block_header.GetHash()); } + { + std::optional<CBlockLocator> block_locator = ConsumeDeserializable<CBlockLocator>(fuzzed_data_provider); + if (block_locator) { + (void)block_locator->IsNull(); + block_locator->SetNull(); + assert(block_locator->IsNull()); + } + } } diff --git a/src/test/fuzz/blockfilter.cpp b/src/test/fuzz/blockfilter.cpp index be9320dcbf..7232325a20 100644 --- a/src/test/fuzz/blockfilter.cpp +++ b/src/test/fuzz/blockfilter.cpp @@ -3,19 +3,19 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <blockfilter.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cstdint> +#include <optional> #include <string> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - const Optional<BlockFilter> block_filter = ConsumeDeserializable<BlockFilter>(fuzzed_data_provider); + const std::optional<BlockFilter> block_filter = ConsumeDeserializable<BlockFilter>(fuzzed_data_provider); if (!block_filter) { return; } diff --git a/src/test/fuzz/bloom_filter.cpp b/src/test/fuzz/bloom_filter.cpp index 7039bf16c1..d955c71bc9 100644 --- a/src/test/fuzz/bloom_filter.cpp +++ b/src/test/fuzz/bloom_filter.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <bloom.h> -#include <optional.h> #include <primitives/transaction.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> @@ -12,6 +11,7 @@ #include <cassert> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -35,7 +35,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) break; } case 1: { - const Optional<COutPoint> out_point = ConsumeDeserializable<COutPoint>(fuzzed_data_provider); + const std::optional<COutPoint> out_point = ConsumeDeserializable<COutPoint>(fuzzed_data_provider); if (!out_point) { break; } @@ -46,7 +46,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) break; } case 2: { - const Optional<uint256> u256 = ConsumeDeserializable<uint256>(fuzzed_data_provider); + const std::optional<uint256> u256 = ConsumeDeserializable<uint256>(fuzzed_data_provider); if (!u256) { break; } @@ -57,7 +57,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) break; } case 3: { - const Optional<CMutableTransaction> mut_tx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> mut_tx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (!mut_tx) { break; } diff --git a/src/test/fuzz/chain.cpp b/src/test/fuzz/chain.cpp index b322516cc7..47c71850ce 100644 --- a/src/test/fuzz/chain.cpp +++ b/src/test/fuzz/chain.cpp @@ -3,18 +3,18 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <chain.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cstdint> +#include <optional> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - Optional<CDiskBlockIndex> disk_block_index = ConsumeDeserializable<CDiskBlockIndex>(fuzzed_data_provider); + std::optional<CDiskBlockIndex> disk_block_index = ConsumeDeserializable<CDiskBlockIndex>(fuzzed_data_provider); if (!disk_block_index) { return; } diff --git a/src/test/fuzz/checkqueue.cpp b/src/test/fuzz/checkqueue.cpp index 2ed097b827..c69043bb6b 100644 --- a/src/test/fuzz/checkqueue.cpp +++ b/src/test/fuzz/checkqueue.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <checkqueue.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> diff --git a/src/test/fuzz/coins_view.cpp b/src/test/fuzz/coins_view.cpp new file mode 100644 index 0000000000..52dd62a145 --- /dev/null +++ b/src/test/fuzz/coins_view.cpp @@ -0,0 +1,294 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <amount.h> +#include <chainparams.h> +#include <chainparamsbase.h> +#include <coins.h> +#include <consensus/tx_verify.h> +#include <consensus/validation.h> +#include <key.h> +#include <node/coinstats.h> +#include <policy/policy.h> +#include <primitives/transaction.h> +#include <pubkey.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> +#include <validation.h> + +#include <cstdint> +#include <limits> +#include <optional> +#include <string> +#include <vector> + +namespace { +const Coin EMPTY_COIN{}; + +bool operator==(const Coin& a, const Coin& b) +{ + if (a.IsSpent() && b.IsSpent()) return true; + return a.fCoinBase == b.fCoinBase && a.nHeight == b.nHeight && a.out == b.out; +} +} // namespace + +void initialize() +{ + static const ECCVerifyHandle ecc_verify_handle; + ECC_Start(); + SelectParams(CBaseChainParams::REGTEST); +} + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + CCoinsView backend_coins_view; + CCoinsViewCache coins_view_cache{&backend_coins_view}; + COutPoint random_out_point; + Coin random_coin; + CMutableTransaction random_mutable_transaction; + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 9)) { + case 0: { + if (random_coin.IsSpent()) { + break; + } + Coin coin = random_coin; + bool expected_code_path = false; + const bool possible_overwrite = fuzzed_data_provider.ConsumeBool(); + try { + coins_view_cache.AddCoin(random_out_point, std::move(coin), possible_overwrite); + expected_code_path = true; + } catch (const std::logic_error& e) { + if (e.what() == std::string{"Attempted to overwrite an unspent coin (when possible_overwrite is false)"}) { + assert(!possible_overwrite); + expected_code_path = true; + } + } + assert(expected_code_path); + break; + } + case 1: { + (void)coins_view_cache.Flush(); + break; + } + case 2: { + coins_view_cache.SetBestBlock(ConsumeUInt256(fuzzed_data_provider)); + break; + } + case 3: { + Coin move_to; + (void)coins_view_cache.SpendCoin(random_out_point, fuzzed_data_provider.ConsumeBool() ? &move_to : nullptr); + break; + } + case 4: { + coins_view_cache.Uncache(random_out_point); + break; + } + case 5: { + if (fuzzed_data_provider.ConsumeBool()) { + backend_coins_view = CCoinsView{}; + } + coins_view_cache.SetBackend(backend_coins_view); + break; + } + case 6: { + const std::optional<COutPoint> opt_out_point = ConsumeDeserializable<COutPoint>(fuzzed_data_provider); + if (!opt_out_point) { + break; + } + random_out_point = *opt_out_point; + break; + } + case 7: { + const std::optional<Coin> opt_coin = ConsumeDeserializable<Coin>(fuzzed_data_provider); + if (!opt_coin) { + break; + } + random_coin = *opt_coin; + break; + } + case 8: { + const std::optional<CMutableTransaction> opt_mutable_transaction = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + if (!opt_mutable_transaction) { + break; + } + random_mutable_transaction = *opt_mutable_transaction; + break; + } + case 9: { + CCoinsMap coins_map; + while (fuzzed_data_provider.ConsumeBool()) { + CCoinsCacheEntry coins_cache_entry; + coins_cache_entry.flags = fuzzed_data_provider.ConsumeIntegral<unsigned char>(); + if (fuzzed_data_provider.ConsumeBool()) { + coins_cache_entry.coin = random_coin; + } else { + const std::optional<Coin> opt_coin = ConsumeDeserializable<Coin>(fuzzed_data_provider); + if (!opt_coin) { + break; + } + coins_cache_entry.coin = *opt_coin; + } + coins_map.emplace(random_out_point, std::move(coins_cache_entry)); + } + bool expected_code_path = false; + try { + coins_view_cache.BatchWrite(coins_map, fuzzed_data_provider.ConsumeBool() ? ConsumeUInt256(fuzzed_data_provider) : coins_view_cache.GetBestBlock()); + expected_code_path = true; + } catch (const std::logic_error& e) { + if (e.what() == std::string{"FRESH flag misapplied to coin that exists in parent cache"}) { + expected_code_path = true; + } + } + assert(expected_code_path); + break; + } + } + } + + { + const Coin& coin_using_access_coin = coins_view_cache.AccessCoin(random_out_point); + const bool exists_using_access_coin = !(coin_using_access_coin == EMPTY_COIN); + const bool exists_using_have_coin = coins_view_cache.HaveCoin(random_out_point); + const bool exists_using_have_coin_in_cache = coins_view_cache.HaveCoinInCache(random_out_point); + Coin coin_using_get_coin; + const bool exists_using_get_coin = coins_view_cache.GetCoin(random_out_point, coin_using_get_coin); + if (exists_using_get_coin) { + assert(coin_using_get_coin == coin_using_access_coin); + } + assert((exists_using_access_coin && exists_using_have_coin_in_cache && exists_using_have_coin && exists_using_get_coin) || + (!exists_using_access_coin && !exists_using_have_coin_in_cache && !exists_using_have_coin && !exists_using_get_coin)); + const bool exists_using_have_coin_in_backend = backend_coins_view.HaveCoin(random_out_point); + if (exists_using_have_coin_in_backend) { + assert(exists_using_have_coin); + } + Coin coin_using_backend_get_coin; + if (backend_coins_view.GetCoin(random_out_point, coin_using_backend_get_coin)) { + assert(exists_using_have_coin_in_backend); + assert(coin_using_get_coin == coin_using_backend_get_coin); + } else { + assert(!exists_using_have_coin_in_backend); + } + } + + { + bool expected_code_path = false; + try { + (void)coins_view_cache.Cursor(); + } catch (const std::logic_error&) { + expected_code_path = true; + } + assert(expected_code_path); + (void)coins_view_cache.DynamicMemoryUsage(); + (void)coins_view_cache.EstimateSize(); + (void)coins_view_cache.GetBestBlock(); + (void)coins_view_cache.GetCacheSize(); + (void)coins_view_cache.GetHeadBlocks(); + (void)coins_view_cache.HaveInputs(CTransaction{random_mutable_transaction}); + } + + { + const CCoinsViewCursor* coins_view_cursor = backend_coins_view.Cursor(); + assert(coins_view_cursor == nullptr); + (void)backend_coins_view.EstimateSize(); + (void)backend_coins_view.GetBestBlock(); + (void)backend_coins_view.GetHeadBlocks(); + } + + if (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 6)) { + case 0: { + const CTransaction transaction{random_mutable_transaction}; + bool is_spent = false; + for (const CTxOut& tx_out : transaction.vout) { + if (Coin{tx_out, 0, transaction.IsCoinBase()}.IsSpent()) { + is_spent = true; + } + } + if (is_spent) { + // Avoid: + // coins.cpp:69: void CCoinsViewCache::AddCoin(const COutPoint &, Coin &&, bool): Assertion `!coin.IsSpent()' failed. + break; + } + bool expected_code_path = false; + const int height = fuzzed_data_provider.ConsumeIntegral<int>(); + const bool possible_overwrite = fuzzed_data_provider.ConsumeBool(); + try { + AddCoins(coins_view_cache, transaction, height, possible_overwrite); + expected_code_path = true; + } catch (const std::logic_error& e) { + if (e.what() == std::string{"Attempted to overwrite an unspent coin (when possible_overwrite is false)"}) { + assert(!possible_overwrite); + expected_code_path = true; + } + } + assert(expected_code_path); + break; + } + case 1: { + (void)AreInputsStandard(CTransaction{random_mutable_transaction}, coins_view_cache); + break; + } + case 2: { + TxValidationState state; + CAmount tx_fee_out; + const CTransaction transaction{random_mutable_transaction}; + if (ContainsSpentInput(transaction, coins_view_cache)) { + // Avoid: + // consensus/tx_verify.cpp:171: bool Consensus::CheckTxInputs(const CTransaction &, TxValidationState &, const CCoinsViewCache &, int, CAmount &): Assertion `!coin.IsSpent()' failed. + break; + } + try { + (void)Consensus::CheckTxInputs(transaction, state, coins_view_cache, fuzzed_data_provider.ConsumeIntegralInRange<int>(0, std::numeric_limits<int>::max()), tx_fee_out); + assert(MoneyRange(tx_fee_out)); + } catch (const std::runtime_error&) { + } + break; + } + case 3: { + const CTransaction transaction{random_mutable_transaction}; + if (ContainsSpentInput(transaction, coins_view_cache)) { + // Avoid: + // consensus/tx_verify.cpp:130: unsigned int GetP2SHSigOpCount(const CTransaction &, const CCoinsViewCache &): Assertion `!coin.IsSpent()' failed. + break; + } + (void)GetP2SHSigOpCount(transaction, coins_view_cache); + break; + } + case 4: { + const CTransaction transaction{random_mutable_transaction}; + if (ContainsSpentInput(transaction, coins_view_cache)) { + // Avoid: + // consensus/tx_verify.cpp:130: unsigned int GetP2SHSigOpCount(const CTransaction &, const CCoinsViewCache &): Assertion `!coin.IsSpent()' failed. + break; + } + const int flags = fuzzed_data_provider.ConsumeIntegral<int>(); + if (!transaction.vin.empty() && (flags & SCRIPT_VERIFY_WITNESS) != 0 && (flags & SCRIPT_VERIFY_P2SH) == 0) { + // Avoid: + // script/interpreter.cpp:1705: size_t CountWitnessSigOps(const CScript &, const CScript &, const CScriptWitness *, unsigned int): Assertion `(flags & SCRIPT_VERIFY_P2SH) != 0' failed. + break; + } + (void)GetTransactionSigOpCost(transaction, coins_view_cache, flags); + break; + } + case 5: { + CCoinsStats stats; + bool expected_code_path = false; + try { + (void)GetUTXOStats(&coins_view_cache, stats); + } catch (const std::logic_error&) { + expected_code_path = true; + } + assert(expected_code_path); + break; + } + case 6: { + (void)IsWitnessStandard(CTransaction{random_mutable_transaction}, coins_view_cache); + break; + } + } + } +} diff --git a/src/test/fuzz/cuckoocache.cpp b/src/test/fuzz/cuckoocache.cpp index f674efe1b1..5b45aa79d8 100644 --- a/src/test/fuzz/cuckoocache.cpp +++ b/src/test/fuzz/cuckoocache.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <cuckoocache.h> -#include <optional.h> #include <script/sigcache.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> diff --git a/src/test/fuzz/fees.cpp b/src/test/fuzz/fees.cpp index f29acace23..ce8700befa 100644 --- a/src/test/fuzz/fees.cpp +++ b/src/test/fuzz/fees.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <amount.h> -#include <optional.h> #include <policy/fees.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> diff --git a/src/test/fuzz/flatfile.cpp b/src/test/fuzz/flatfile.cpp index a55de77df7..95dabb8bab 100644 --- a/src/test/fuzz/flatfile.cpp +++ b/src/test/fuzz/flatfile.cpp @@ -3,24 +3,24 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <flatfile.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cassert> #include <cstdint> +#include <optional> #include <string> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - Optional<FlatFilePos> flat_file_pos = ConsumeDeserializable<FlatFilePos>(fuzzed_data_provider); + std::optional<FlatFilePos> flat_file_pos = ConsumeDeserializable<FlatFilePos>(fuzzed_data_provider); if (!flat_file_pos) { return; } - Optional<FlatFilePos> another_flat_file_pos = ConsumeDeserializable<FlatFilePos>(fuzzed_data_provider); + std::optional<FlatFilePos> another_flat_file_pos = ConsumeDeserializable<FlatFilePos>(fuzzed_data_provider); if (another_flat_file_pos) { assert((*flat_file_pos == *another_flat_file_pos) != (*flat_file_pos != *another_flat_file_pos)); } diff --git a/src/test/fuzz/fuzz.cpp b/src/test/fuzz/fuzz.cpp index 6e2188fe86..82e1d55c0b 100644 --- a/src/test/fuzz/fuzz.cpp +++ b/src/test/fuzz/fuzz.cpp @@ -19,8 +19,6 @@ static bool read_stdin(std::vector<uint8_t>& data) ssize_t length = 0; while ((length = read(STDIN_FILENO, buffer, 1024)) > 0) { data.insert(data.end(), buffer, buffer + length); - - if (data.size() > (1 << 20)) return false; } return length == 0; } diff --git a/src/test/fuzz/golomb_rice.cpp b/src/test/fuzz/golomb_rice.cpp index 3e20416116..a9f450b0c4 100644 --- a/src/test/fuzz/golomb_rice.cpp +++ b/src/test/fuzz/golomb_rice.cpp @@ -5,8 +5,8 @@ #include <blockfilter.h> #include <serialize.h> #include <streams.h> -#include <test/fuzz/fuzz.h> #include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <util/bytevectorhash.h> #include <util/golombrice.h> diff --git a/src/test/fuzz/hex.cpp b/src/test/fuzz/hex.cpp index 5fed17c17c..6a8699fd0f 100644 --- a/src/test/fuzz/hex.cpp +++ b/src/test/fuzz/hex.cpp @@ -16,7 +16,8 @@ #include <string> #include <vector> -void initialize() { +void initialize() +{ static const ECCVerifyHandle verify_handle; } diff --git a/src/test/fuzz/merkleblock.cpp b/src/test/fuzz/merkleblock.cpp index eb8fa1d421..c44e334272 100644 --- a/src/test/fuzz/merkleblock.cpp +++ b/src/test/fuzz/merkleblock.cpp @@ -3,20 +3,20 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <merkleblock.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <uint256.h> #include <cstdint> +#include <optional> #include <string> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - Optional<CPartialMerkleTree> partial_merkle_tree = ConsumeDeserializable<CPartialMerkleTree>(fuzzed_data_provider); + std::optional<CPartialMerkleTree> partial_merkle_tree = ConsumeDeserializable<CPartialMerkleTree>(fuzzed_data_provider); if (!partial_merkle_tree) { return; } diff --git a/src/test/fuzz/message.cpp b/src/test/fuzz/message.cpp index dfa98a812b..fa0322a391 100644 --- a/src/test/fuzz/message.cpp +++ b/src/test/fuzz/message.cpp @@ -4,7 +4,6 @@ #include <chainparams.h> #include <key_io.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> diff --git a/src/test/fuzz/net_permissions.cpp b/src/test/fuzz/net_permissions.cpp index bfc5d21427..c071283467 100644 --- a/src/test/fuzz/net_permissions.cpp +++ b/src/test/fuzz/net_permissions.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <net_permissions.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> diff --git a/src/test/fuzz/policy_estimator.cpp b/src/test/fuzz/policy_estimator.cpp index 201f49c87b..1cbf9b347f 100644 --- a/src/test/fuzz/policy_estimator.cpp +++ b/src/test/fuzz/policy_estimator.cpp @@ -2,7 +2,6 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <optional.h> #include <policy/fees.h> #include <primitives/transaction.h> #include <test/fuzz/FuzzedDataProvider.h> @@ -11,6 +10,7 @@ #include <txmempool.h> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -21,7 +21,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) while (fuzzed_data_provider.ConsumeBool()) { switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 3)) { case 0: { - const Optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (!mtx) { break; } @@ -35,7 +35,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) case 1: { std::vector<CTxMemPoolEntry> mempool_entries; while (fuzzed_data_provider.ConsumeBool()) { - const Optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (!mtx) { break; } diff --git a/src/test/fuzz/pow.cpp b/src/test/fuzz/pow.cpp index 0343d33401..b7fc72373d 100644 --- a/src/test/fuzz/pow.cpp +++ b/src/test/fuzz/pow.cpp @@ -4,7 +4,6 @@ #include <chain.h> #include <chainparams.h> -#include <optional.h> #include <pow.h> #include <primitives/block.h> #include <test/fuzz/FuzzedDataProvider.h> @@ -12,6 +11,7 @@ #include <test/fuzz/util.h> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -28,7 +28,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) const uint32_t fixed_time = fuzzed_data_provider.ConsumeIntegral<uint32_t>(); const uint32_t fixed_bits = fuzzed_data_provider.ConsumeIntegral<uint32_t>(); while (fuzzed_data_provider.remaining_bytes() > 0) { - const Optional<CBlockHeader> block_header = ConsumeDeserializable<CBlockHeader>(fuzzed_data_provider); + const std::optional<CBlockHeader> block_header = ConsumeDeserializable<CBlockHeader>(fuzzed_data_provider); if (!block_header) { continue; } @@ -72,7 +72,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } } { - const Optional<uint256> hash = ConsumeDeserializable<uint256>(fuzzed_data_provider); + const std::optional<uint256> hash = ConsumeDeserializable<uint256>(fuzzed_data_provider); if (hash) { (void)CheckProofOfWork(*hash, fuzzed_data_provider.ConsumeIntegral<unsigned int>(), consensus_params); } diff --git a/src/test/fuzz/prevector.cpp b/src/test/fuzz/prevector.cpp index 64920f4af5..626e187cbd 100644 --- a/src/test/fuzz/prevector.cpp +++ b/src/test/fuzz/prevector.cpp @@ -14,8 +14,9 @@ namespace { -template<unsigned int N, typename T> -class prevector_tester { +template <unsigned int N, typename T> +class prevector_tester +{ typedef std::vector<T> realtype; realtype real_vector; realtype real_vector_alt; @@ -27,35 +28,36 @@ class prevector_tester { typedef typename pretype::size_type Size; public: - void test() const { + void test() const + { const pretype& const_pre_vector = pre_vector; assert(real_vector.size() == pre_vector.size()); assert(real_vector.empty() == pre_vector.empty()); for (Size s = 0; s < real_vector.size(); s++) { - assert(real_vector[s] == pre_vector[s]); - assert(&(pre_vector[s]) == &(pre_vector.begin()[s])); - assert(&(pre_vector[s]) == &*(pre_vector.begin() + s)); - assert(&(pre_vector[s]) == &*((pre_vector.end() + s) - real_vector.size())); + assert(real_vector[s] == pre_vector[s]); + assert(&(pre_vector[s]) == &(pre_vector.begin()[s])); + assert(&(pre_vector[s]) == &*(pre_vector.begin() + s)); + assert(&(pre_vector[s]) == &*((pre_vector.end() + s) - real_vector.size())); } // assert(realtype(pre_vector) == real_vector); assert(pretype(real_vector.begin(), real_vector.end()) == pre_vector); assert(pretype(pre_vector.begin(), pre_vector.end()) == pre_vector); size_t pos = 0; for (const T& v : pre_vector) { - assert(v == real_vector[pos]); - ++pos; + assert(v == real_vector[pos]); + ++pos; } for (const T& v : reverse_iterate(pre_vector)) { - --pos; - assert(v == real_vector[pos]); + --pos; + assert(v == real_vector[pos]); } for (const T& v : const_pre_vector) { - assert(v == real_vector[pos]); - ++pos; + assert(v == real_vector[pos]); + ++pos; } for (const T& v : reverse_iterate(const_pre_vector)) { - --pos; - assert(v == real_vector[pos]); + --pos; + assert(v == real_vector[pos]); } CDataStream ss1(SER_DISK, 0); CDataStream ss2(SER_DISK, 0); @@ -67,101 +69,120 @@ public: } } - void resize(Size s) { + void resize(Size s) + { real_vector.resize(s); assert(real_vector.size() == s); pre_vector.resize(s); assert(pre_vector.size() == s); } - void reserve(Size s) { + void reserve(Size s) + { real_vector.reserve(s); assert(real_vector.capacity() >= s); pre_vector.reserve(s); assert(pre_vector.capacity() >= s); } - void insert(Size position, const T& value) { + void insert(Size position, const T& value) + { real_vector.insert(real_vector.begin() + position, value); pre_vector.insert(pre_vector.begin() + position, value); } - void insert(Size position, Size count, const T& value) { + void insert(Size position, Size count, const T& value) + { real_vector.insert(real_vector.begin() + position, count, value); pre_vector.insert(pre_vector.begin() + position, count, value); } - template<typename I> - void insert_range(Size position, I first, I last) { + template <typename I> + void insert_range(Size position, I first, I last) + { real_vector.insert(real_vector.begin() + position, first, last); pre_vector.insert(pre_vector.begin() + position, first, last); } - void erase(Size position) { + void erase(Size position) + { real_vector.erase(real_vector.begin() + position); pre_vector.erase(pre_vector.begin() + position); } - void erase(Size first, Size last) { + void erase(Size first, Size last) + { real_vector.erase(real_vector.begin() + first, real_vector.begin() + last); pre_vector.erase(pre_vector.begin() + first, pre_vector.begin() + last); } - void update(Size pos, const T& value) { + void update(Size pos, const T& value) + { real_vector[pos] = value; pre_vector[pos] = value; } - void push_back(const T& value) { + void push_back(const T& value) + { real_vector.push_back(value); pre_vector.push_back(value); } - void pop_back() { + void pop_back() + { real_vector.pop_back(); pre_vector.pop_back(); } - void clear() { + void clear() + { real_vector.clear(); pre_vector.clear(); } - void assign(Size n, const T& value) { + void assign(Size n, const T& value) + { real_vector.assign(n, value); pre_vector.assign(n, value); } - Size size() const { + Size size() const + { return real_vector.size(); } - Size capacity() const { + Size capacity() const + { return pre_vector.capacity(); } - void shrink_to_fit() { + void shrink_to_fit() + { pre_vector.shrink_to_fit(); } - void swap() { + void swap() + { real_vector.swap(real_vector_alt); pre_vector.swap(pre_vector_alt); } - void move() { + void move() + { real_vector = std::move(real_vector_alt); real_vector_alt.clear(); pre_vector = std::move(pre_vector_alt); pre_vector_alt.clear(); } - void copy() { + void copy() + { real_vector = real_vector_alt; pre_vector = pre_vector_alt; } - void resize_uninitialized(realtype values) { + void resize_uninitialized(realtype values) + { size_t r = values.size(); size_t s = real_vector.size() / 2; if (real_vector.capacity() < s + r) { @@ -181,7 +202,7 @@ public: } }; -} +} // namespace void test_one_input(const std::vector<uint8_t>& buffer) { diff --git a/src/test/fuzz/primitives_transaction.cpp b/src/test/fuzz/primitives_transaction.cpp index 2e5ba6bdb0..4a0f920f58 100644 --- a/src/test/fuzz/primitives_transaction.cpp +++ b/src/test/fuzz/primitives_transaction.cpp @@ -2,13 +2,13 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <optional.h> #include <primitives/transaction.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -16,7 +16,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); const CScript script = ConsumeScript(fuzzed_data_provider); - const Optional<COutPoint> out_point = ConsumeDeserializable<COutPoint>(fuzzed_data_provider); + const std::optional<COutPoint> out_point = ConsumeDeserializable<COutPoint>(fuzzed_data_provider); if (out_point) { const CTxIn tx_in{*out_point, script, fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; (void)tx_in; @@ -24,8 +24,8 @@ void test_one_input(const std::vector<uint8_t>& buffer) const CTxOut tx_out_1{ConsumeMoney(fuzzed_data_provider), script}; const CTxOut tx_out_2{ConsumeMoney(fuzzed_data_provider), ConsumeScript(fuzzed_data_provider)}; assert((tx_out_1 == tx_out_2) != (tx_out_1 != tx_out_2)); - const Optional<CMutableTransaction> mutable_tx_1 = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); - const Optional<CMutableTransaction> mutable_tx_2 = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> mutable_tx_1 = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> mutable_tx_2 = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (mutable_tx_1 && mutable_tx_2) { const CTransaction tx_1{*mutable_tx_1}; const CTransaction tx_2{*mutable_tx_2}; diff --git a/src/test/fuzz/process_message.cpp b/src/test/fuzz/process_message.cpp index c03365199a..665a6224b4 100644 --- a/src/test/fuzz/process_message.cpp +++ b/src/test/fuzz/process_message.cpp @@ -29,7 +29,7 @@ #include <string> #include <vector> -bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc); +bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, ChainstateManager& chainman, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc); namespace { @@ -74,7 +74,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) p2p_node.SetSendVersion(PROTOCOL_VERSION); g_setup->m_node.peer_logic->InitializeNode(&p2p_node); try { - (void)ProcessMessage(&p2p_node, random_message_type, random_bytes_data_stream, GetTimeMillis(), Params(), *g_setup->m_node.mempool, g_setup->m_node.connman.get(), g_setup->m_node.banman.get(), std::atomic<bool>{false}); + (void)ProcessMessage(&p2p_node, random_message_type, random_bytes_data_stream, GetTimeMillis(), Params(), *g_setup->m_node.chainman, *g_setup->m_node.mempool, g_setup->m_node.connman.get(), g_setup->m_node.banman.get(), std::atomic<bool>{false}); } catch (const std::ios_base::failure&) { } SyncWithValidationInterfaceQueue(); diff --git a/src/test/fuzz/protocol.cpp b/src/test/fuzz/protocol.cpp index 954471de6c..78df0f89e7 100644 --- a/src/test/fuzz/protocol.cpp +++ b/src/test/fuzz/protocol.cpp @@ -2,20 +2,20 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <optional.h> #include <protocol.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <cstdint> +#include <optional> #include <stdexcept> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - const Optional<CInv> inv = ConsumeDeserializable<CInv>(fuzzed_data_provider); + const std::optional<CInv> inv = ConsumeDeserializable<CInv>(fuzzed_data_provider); if (!inv) { return; } @@ -24,7 +24,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } catch (const std::out_of_range&) { } (void)inv->ToString(); - const Optional<CInv> another_inv = ConsumeDeserializable<CInv>(fuzzed_data_provider); + const std::optional<CInv> another_inv = ConsumeDeserializable<CInv>(fuzzed_data_provider); if (!another_inv) { return; } diff --git a/src/test/fuzz/rbf.cpp b/src/test/fuzz/rbf.cpp index eb54b05df9..1fd88a5f7b 100644 --- a/src/test/fuzz/rbf.cpp +++ b/src/test/fuzz/rbf.cpp @@ -2,7 +2,6 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <optional.h> #include <policy/rbf.h> #include <primitives/transaction.h> #include <sync.h> @@ -12,19 +11,20 @@ #include <txmempool.h> #include <cstdint> +#include <optional> #include <string> #include <vector> void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - Optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + std::optional<CMutableTransaction> mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (!mtx) { return; } CTxMemPool pool; while (fuzzed_data_provider.ConsumeBool()) { - const Optional<CMutableTransaction> another_mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); + const std::optional<CMutableTransaction> another_mtx = ConsumeDeserializable<CMutableTransaction>(fuzzed_data_provider); if (!another_mtx) { break; } diff --git a/src/test/fuzz/rolling_bloom_filter.cpp b/src/test/fuzz/rolling_bloom_filter.cpp index 3b37321977..623b8cff3a 100644 --- a/src/test/fuzz/rolling_bloom_filter.cpp +++ b/src/test/fuzz/rolling_bloom_filter.cpp @@ -3,7 +3,6 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <bloom.h> -#include <optional.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> @@ -11,6 +10,7 @@ #include <cassert> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -32,7 +32,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) break; } case 1: { - const Optional<uint256> u256 = ConsumeDeserializable<uint256>(fuzzed_data_provider); + const std::optional<uint256> u256 = ConsumeDeserializable<uint256>(fuzzed_data_provider); if (!u256) { break; } diff --git a/src/test/fuzz/script.cpp b/src/test/fuzz/script.cpp index de82122dd6..e0c4ad7eb7 100644 --- a/src/test/fuzz/script.cpp +++ b/src/test/fuzz/script.cpp @@ -21,6 +21,11 @@ #include <univalue.h> #include <util/memory.h> +#include <cstdint> +#include <optional> +#include <string> +#include <vector> + void initialize() { // Fuzzers using pubkey must hold an ECCVerifyHandle. @@ -32,7 +37,7 @@ void initialize() void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); - const Optional<CScript> script_opt = ConsumeDeserializable<CScript>(fuzzed_data_provider); + const std::optional<CScript> script_opt = ConsumeDeserializable<CScript>(fuzzed_data_provider); if (!script_opt) return; const CScript script{*script_opt}; @@ -101,7 +106,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } } - const Optional<CScript> other_script = ConsumeDeserializable<CScript>(fuzzed_data_provider); + const std::optional<CScript> other_script = ConsumeDeserializable<CScript>(fuzzed_data_provider); if (other_script) { { CScript script_mut{script}; diff --git a/src/test/fuzz/string.cpp b/src/test/fuzz/string.cpp index 49bee0e81f..50984b1aef 100644 --- a/src/test/fuzz/string.cpp +++ b/src/test/fuzz/string.cpp @@ -93,7 +93,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) { CDataStream data_stream{SER_NETWORK, INIT_PROTO_VERSION}; std::string s; - LimitedString<10> limited_string = LIMITED_STRING(s, 10); + auto limited_string = LIMITED_STRING(s, 10); data_stream << random_string_1; try { data_stream >> limited_string; @@ -108,7 +108,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } { CDataStream data_stream{SER_NETWORK, INIT_PROTO_VERSION}; - const LimitedString<10> limited_string = LIMITED_STRING(random_string_1, 10); + const auto limited_string = LIMITED_STRING(random_string_1, 10); data_stream << limited_string; std::string deserialized_string; data_stream >> deserialized_string; @@ -119,4 +119,10 @@ void test_one_input(const std::vector<uint8_t>& buffer) int64_t amount_out; (void)ParseFixedPoint(random_string_1, fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 1024), &amount_out); } + { + (void)Untranslated(random_string_1); + const bilingual_str bs1{random_string_1, random_string_2}; + const bilingual_str bs2{random_string_2, random_string_1}; + (void)(bs1 + bs2); + } } diff --git a/src/test/fuzz/strprintf.cpp b/src/test/fuzz/strprintf.cpp index d5be1070bd..29064bc45c 100644 --- a/src/test/fuzz/strprintf.cpp +++ b/src/test/fuzz/strprintf.cpp @@ -6,6 +6,7 @@ #include <test/fuzz/fuzz.h> #include <tinyformat.h> #include <util/strencodings.h> +#include <util/translation.h> #include <algorithm> #include <cstdint> @@ -16,6 +17,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); const std::string format_string = fuzzed_data_provider.ConsumeRandomLengthString(64); + const bilingual_str bilingual_string{format_string, format_string}; const int digits_in_format_specifier = std::count_if(format_string.begin(), format_string.end(), IsDigit); @@ -47,50 +49,62 @@ void test_one_input(const std::vector<uint8_t>& buffer) try { (void)strprintf(format_string, (signed char*)nullptr); + (void)tinyformat::format(bilingual_string, (signed char*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (unsigned char*)nullptr); + (void)tinyformat::format(bilingual_string, (unsigned char*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (void*)nullptr); + (void)tinyformat::format(bilingual_string, (void*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (bool*)nullptr); + (void)tinyformat::format(bilingual_string, (bool*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (float*)nullptr); + (void)tinyformat::format(bilingual_string, (float*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (double*)nullptr); + (void)tinyformat::format(bilingual_string, (double*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (int16_t*)nullptr); + (void)tinyformat::format(bilingual_string, (int16_t*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (uint16_t*)nullptr); + (void)tinyformat::format(bilingual_string, (uint16_t*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (int32_t*)nullptr); + (void)tinyformat::format(bilingual_string, (int32_t*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (uint32_t*)nullptr); + (void)tinyformat::format(bilingual_string, (uint32_t*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (int64_t*)nullptr); + (void)tinyformat::format(bilingual_string, (int64_t*)nullptr); } catch (const tinyformat::format_error&) { } try { (void)strprintf(format_string, (uint64_t*)nullptr); + (void)tinyformat::format(bilingual_string, (uint64_t*)nullptr); } catch (const tinyformat::format_error&) { } @@ -98,21 +112,27 @@ void test_one_input(const std::vector<uint8_t>& buffer) switch (fuzzed_data_provider.ConsumeIntegralInRange(0, 5)) { case 0: (void)strprintf(format_string, fuzzed_data_provider.ConsumeRandomLengthString(32)); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeRandomLengthString(32)); break; case 1: (void)strprintf(format_string, fuzzed_data_provider.ConsumeRandomLengthString(32).c_str()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeRandomLengthString(32).c_str()); break; case 2: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<signed char>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<signed char>()); break; case 3: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<unsigned char>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<unsigned char>()); break; case 4: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<char>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<char>()); break; case 5: (void)strprintf(format_string, fuzzed_data_provider.ConsumeBool()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeBool()); break; } } catch (const tinyformat::format_error&) { @@ -138,27 +158,35 @@ void test_one_input(const std::vector<uint8_t>& buffer) switch (fuzzed_data_provider.ConsumeIntegralInRange(0, 7)) { case 0: (void)strprintf(format_string, fuzzed_data_provider.ConsumeFloatingPoint<float>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeFloatingPoint<float>()); break; case 1: (void)strprintf(format_string, fuzzed_data_provider.ConsumeFloatingPoint<double>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeFloatingPoint<double>()); break; case 2: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<int16_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<int16_t>()); break; case 3: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<uint16_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<uint16_t>()); break; case 4: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<int32_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<int32_t>()); break; case 5: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<uint32_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<uint32_t>()); break; case 6: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<int64_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<int64_t>()); break; case 7: (void)strprintf(format_string, fuzzed_data_provider.ConsumeIntegral<uint64_t>()); + (void)tinyformat::format(bilingual_string, fuzzed_data_provider.ConsumeIntegral<uint64_t>()); break; } } catch (const tinyformat::format_error&) { diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 501bb1de5a..9d0fb02128 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -8,8 +8,8 @@ #include <amount.h> #include <arith_uint256.h> #include <attributes.h> +#include <coins.h> #include <consensus/consensus.h> -#include <optional.h> #include <primitives/transaction.h> #include <script/script.h> #include <serialize.h> @@ -21,6 +21,7 @@ #include <version.h> #include <cstdint> +#include <optional> #include <string> #include <vector> @@ -52,7 +53,7 @@ NODISCARD inline std::vector<T> ConsumeRandomLengthIntegralVector(FuzzedDataProv } template <typename T> -NODISCARD inline Optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const size_t max_length = 4096) noexcept +NODISCARD inline std::optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const size_t max_length = 4096) noexcept { const std::vector<uint8_t> buffer = ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length); CDataStream ds{buffer, SER_NETWORK, INIT_PROTO_VERSION}; @@ -60,7 +61,7 @@ NODISCARD inline Optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_da try { ds >> obj; } catch (const std::ios_base::failure&) { - return nullopt; + return std::nullopt; } return obj; } @@ -149,4 +150,15 @@ NODISCARD bool AdditionOverflow(const T i, const T j) noexcept return std::numeric_limits<T>::max() - i < j; } +NODISCARD inline bool ContainsSpentInput(const CTransaction& tx, const CCoinsViewCache& inputs) noexcept +{ + for (const CTxIn& tx_in : tx.vin) { + const Coin& coin = inputs.AccessCoin(tx_in.prevout); + if (coin.IsSpent()) { + return true; + } + } + return false; +} + #endif // BITCOIN_TEST_FUZZ_UTIL_H diff --git a/src/test/miner_tests.cpp b/src/test/miner_tests.cpp index 9f3ca87206..57eee94330 100644 --- a/src/test/miner_tests.cpp +++ b/src/test/miner_tests.cpp @@ -253,7 +253,7 @@ BOOST_AUTO_TEST_CASE(CreateNewBlock_validity) pblock->nNonce = blockinfo[i].nonce; } std::shared_ptr<const CBlock> shared_pblock = std::make_shared<const CBlock>(*pblock); - BOOST_CHECK(ProcessNewBlock(chainparams, shared_pblock, true, nullptr)); + BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlock(chainparams, shared_pblock, true, nullptr)); pblock->hashPrevBlock = pblock->GetHash(); } diff --git a/src/test/random_tests.cpp b/src/test/random_tests.cpp index d1f60e8972..978a7bee4d 100644 --- a/src/test/random_tests.cpp +++ b/src/test/random_tests.cpp @@ -28,6 +28,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) for (int i = 10; i > 0; --i) { BOOST_CHECK_EQUAL(GetRand(std::numeric_limits<uint64_t>::max()), uint64_t{10393729187455219830U}); BOOST_CHECK_EQUAL(GetRandInt(std::numeric_limits<int>::max()), int{769702006}); + BOOST_CHECK_EQUAL(GetRandMicros(std::chrono::hours{1}).count(), 2917185654); + BOOST_CHECK_EQUAL(GetRandMillis(std::chrono::hours{1}).count(), 2144374); } BOOST_CHECK_EQUAL(ctx1.rand32(), ctx2.rand32()); BOOST_CHECK_EQUAL(ctx1.rand32(), ctx2.rand32()); @@ -47,6 +49,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) for (int i = 10; i > 0; --i) { BOOST_CHECK(GetRand(std::numeric_limits<uint64_t>::max()) != uint64_t{10393729187455219830U}); BOOST_CHECK(GetRandInt(std::numeric_limits<int>::max()) != int{769702006}); + BOOST_CHECK(GetRandMicros(std::chrono::hours{1}) != std::chrono::microseconds{2917185654}); + BOOST_CHECK(GetRandMillis(std::chrono::hours{1}) != std::chrono::milliseconds{2144374}); } { FastRandomContext ctx3, ctx4; @@ -87,7 +91,7 @@ BOOST_AUTO_TEST_CASE(stdrandom_test) BOOST_CHECK(x >= 3); BOOST_CHECK(x <= 9); - std::vector<int> test{1,2,3,4,5,6,7,8,9,10}; + std::vector<int> test{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; std::shuffle(test.begin(), test.end(), ctx); for (int j = 1; j <= 10; ++j) { BOOST_CHECK(std::find(test.begin(), test.end(), j) != test.end()); @@ -97,7 +101,6 @@ BOOST_AUTO_TEST_CASE(stdrandom_test) BOOST_CHECK(std::find(test.begin(), test.end(), j) != test.end()); } } - } /** Test that Shuffle reaches every permutation with equal probability. */ diff --git a/src/test/ref_tests.cpp b/src/test/ref_tests.cpp new file mode 100644 index 0000000000..0ec0799fbc --- /dev/null +++ b/src/test/ref_tests.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <util/ref.h> + +#include <boost/test/unit_test.hpp> + +BOOST_AUTO_TEST_SUITE(ref_tests) + +BOOST_AUTO_TEST_CASE(ref_test) +{ + util::Ref ref; + BOOST_CHECK(!ref.Has<int>()); + BOOST_CHECK_THROW(ref.Get<int>(), NonFatalCheckError); + int value = 5; + ref.Set(value); + BOOST_CHECK(ref.Has<int>()); + BOOST_CHECK_EQUAL(ref.Get<int>(), 5); + ++ref.Get<int>(); + BOOST_CHECK_EQUAL(ref.Get<int>(), 6); + BOOST_CHECK_EQUAL(value, 6); + ++value; + BOOST_CHECK_EQUAL(value, 7); + BOOST_CHECK_EQUAL(ref.Get<int>(), 7); + BOOST_CHECK(!ref.Has<bool>()); + BOOST_CHECK_THROW(ref.Get<bool>(), NonFatalCheckError); + ref.Clear(); + BOOST_CHECK(!ref.Has<int>()); + BOOST_CHECK_THROW(ref.Get<int>(), NonFatalCheckError); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/rpc_tests.cpp b/src/test/rpc_tests.cpp index d9c66f1c19..b54cbb3f00 100644 --- a/src/test/rpc_tests.cpp +++ b/src/test/rpc_tests.cpp @@ -10,6 +10,7 @@ #include <interfaces/chain.h> #include <node/context.h> #include <test/util/setup_common.h> +#include <util/ref.h> #include <util/time.h> #include <boost/algorithm/string.hpp> @@ -19,13 +20,20 @@ #include <rpc/blockchain.h> -UniValue CallRPC(std::string args) +class RPCTestingSetup : public TestingSetup +{ +public: + UniValue CallRPC(std::string args); +}; + +UniValue RPCTestingSetup::CallRPC(std::string args) { std::vector<std::string> vArgs; boost::split(vArgs, args, boost::is_any_of(" \t")); std::string strMethod = vArgs[0]; vArgs.erase(vArgs.begin()); - JSONRPCRequest request; + util::Ref context{m_node}; + JSONRPCRequest request(context); request.strMethod = strMethod; request.params = RPCConvertValues(strMethod, vArgs); request.fHelp = false; @@ -40,7 +48,7 @@ UniValue CallRPC(std::string args) } -BOOST_FIXTURE_TEST_SUITE(rpc_tests, TestingSetup) +BOOST_FIXTURE_TEST_SUITE(rpc_tests, RPCTestingSetup) BOOST_AUTO_TEST_CASE(rpc_rawparams) { diff --git a/src/test/script_tests.cpp b/src/test/script_tests.cpp index 56454f61f3..cb3ae290d1 100644 --- a/src/test/script_tests.cpp +++ b/src/test/script_tests.cpp @@ -102,7 +102,7 @@ static ScriptErrorDesc script_errors[]={ {SCRIPT_ERR_SIG_FINDANDDELETE, "SIG_FINDANDDELETE"}, }; -static const char *FormatScriptError(ScriptError_t err) +static std::string FormatScriptError(ScriptError_t err) { for (unsigned int i=0; i<ARRAYLEN(script_errors); ++i) if (script_errors[i].err == err) @@ -134,7 +134,7 @@ void DoTest(const CScript& scriptPubKey, const CScript& scriptSig, const CScript CMutableTransaction tx = BuildSpendingTransaction(scriptSig, scriptWitness, txCredit); CMutableTransaction tx2 = tx; BOOST_CHECK_MESSAGE(VerifyScript(scriptSig, scriptPubKey, &scriptWitness, flags, MutableTransactionSignatureChecker(&tx, 0, txCredit.vout[0].nValue), &err) == expect, message); - BOOST_CHECK_MESSAGE(err == scriptError, std::string(FormatScriptError(err)) + " where " + std::string(FormatScriptError((ScriptError_t)scriptError)) + " expected: " + message); + BOOST_CHECK_MESSAGE(err == scriptError, FormatScriptError(err) + " where " + FormatScriptError((ScriptError_t)scriptError) + " expected: " + message); // Verify that removing flags from a passing test or adding flags to a failing test does not change the result. for (int i = 0; i < 16; ++i) { diff --git a/src/test/serialize_tests.cpp b/src/test/serialize_tests.cpp index 9a6c721ab8..c2328f931c 100644 --- a/src/test/serialize_tests.cpp +++ b/src/test/serialize_tests.cpp @@ -29,15 +29,13 @@ public: memcpy(charstrval, charstrvalin, sizeof(charstrval)); } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(intval); - READWRITE(boolval); - READWRITE(stringval); - READWRITE(charstrval); - READWRITE(txval); + SERIALIZE_METHODS(CSerializeMethodsTestSingle, obj) + { + READWRITE(obj.intval); + READWRITE(obj.boolval); + READWRITE(obj.stringval); + READWRITE(obj.charstrval); + READWRITE(obj.txval); } bool operator==(const CSerializeMethodsTestSingle& rhs) @@ -54,11 +52,10 @@ class CSerializeMethodsTestMany : public CSerializeMethodsTestSingle { public: using CSerializeMethodsTestSingle::CSerializeMethodsTestSingle; - ADD_SERIALIZE_METHODS; - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(intval, boolval, stringval, charstrval, txval); + SERIALIZE_METHODS(CSerializeMethodsTestMany, obj) + { + READWRITE(obj.intval, obj.boolval, obj.stringval, obj.charstrval, obj.txval); } }; diff --git a/src/test/util/mining.cpp b/src/test/util/mining.cpp index 1df6844062..dac7f1a07b 100644 --- a/src/test/util/mining.cpp +++ b/src/test/util/mining.cpp @@ -31,7 +31,7 @@ CTxIn MineBlock(const NodeContext& node, const CScript& coinbase_scriptPubKey) assert(block->nNonce); } - bool processed{ProcessNewBlock(Params(), block, true, nullptr)}; + bool processed{EnsureChainman(node).ProcessNewBlock(Params(), block, true, nullptr)}; assert(processed); return CTxIn{block->vtx[0]->GetHash(), 0}; diff --git a/src/test/util/setup_common.cpp b/src/test/util/setup_common.cpp index bf0afc4171..3b7a7c8d12 100644 --- a/src/test/util/setup_common.cpp +++ b/src/test/util/setup_common.cpp @@ -123,7 +123,6 @@ TestingSetup::TestingSetup(const std::string& chainName, const std::vector<const const CChainParams& chainparams = Params(); // Ideally we'd move all the RPC tests to the functional testing framework // instead of unit tests, but for now we need these here. - g_rpc_node = &m_node; RegisterAllCoreRPCCommands(tableRPC); m_node.scheduler = MakeUnique<CScheduler>(); @@ -131,11 +130,12 @@ TestingSetup::TestingSetup(const std::string& chainName, const std::vector<const // We have to run a scheduler thread to prevent ActivateBestChain // from blocking due to queue overrun. threadGroup.create_thread([&]{ m_node.scheduler->serviceQueue(); }); - GetMainSignals().RegisterBackgroundSignalScheduler(*g_rpc_node->scheduler); + GetMainSignals().RegisterBackgroundSignalScheduler(*m_node.scheduler); pblocktree.reset(new CBlockTreeDB(1 << 20, true)); - g_chainman.InitializeChainstate(); + m_node.chainman = &::g_chainman; + m_node.chainman->InitializeChainstate(); ::ChainstateActive().InitCoinsDB( /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); assert(!::ChainstateActive().CanFlushToDisk()); @@ -161,7 +161,7 @@ TestingSetup::TestingSetup(const std::string& chainName, const std::vector<const m_node.mempool->setSanityCheck(1.0); m_node.banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); m_node.connman = MakeUnique<CConnman>(0x1337, 0x1337); // Deterministic randomness for tests. - m_node.peer_logic = MakeUnique<PeerLogicValidation>(m_node.connman.get(), m_node.banman.get(), *m_node.scheduler, *m_node.mempool); + m_node.peer_logic = MakeUnique<PeerLogicValidation>(m_node.connman.get(), m_node.banman.get(), *m_node.scheduler, *m_node.chainman, *m_node.mempool); { CConnman::Options options; options.m_msgproc = m_node.peer_logic.get(); @@ -176,14 +176,14 @@ TestingSetup::~TestingSetup() threadGroup.join_all(); GetMainSignals().FlushBackgroundCallbacks(); GetMainSignals().UnregisterBackgroundSignalScheduler(); - g_rpc_node = nullptr; m_node.connman.reset(); m_node.banman.reset(); m_node.args = nullptr; m_node.mempool = nullptr; m_node.scheduler.reset(); UnloadBlockIndex(); - g_chainman.Reset(); + m_node.chainman->Reset(); + m_node.chainman = nullptr; pblocktree.reset(); } @@ -228,7 +228,7 @@ CBlock TestChain100Setup::CreateAndProcessBlock(const std::vector<CMutableTransa while (!CheckProofOfWork(block.GetHash(), block.nBits, chainparams.GetConsensus())) ++block.nNonce; std::shared_ptr<const CBlock> shared_pblock = std::make_shared<const CBlock>(block); - ProcessNewBlock(chainparams, shared_pblock, true, nullptr); + EnsureChainman(m_node).ProcessNewBlock(chainparams, shared_pblock, true, nullptr); CBlock result = block; return result; diff --git a/src/test/validation_block_tests.cpp b/src/test/validation_block_tests.cpp index c345f1eafb..45e0c5484e 100644 --- a/src/test/validation_block_tests.cpp +++ b/src/test/validation_block_tests.cpp @@ -32,7 +32,7 @@ struct MinerTestingSetup : public RegTestingSetup { BOOST_FIXTURE_TEST_SUITE(validation_block_tests, MinerTestingSetup) -struct TestSubscriber : public CValidationInterface { +struct TestSubscriber final : public CValidationInterface { uint256 m_expected_tip; explicit TestSubscriber(uint256 tip) : m_expected_tip(tip) {} @@ -163,10 +163,10 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) std::transform(blocks.begin(), blocks.end(), std::back_inserter(headers), [](std::shared_ptr<const CBlock> b) { return b->GetBlockHeader(); }); // Process all the headers so we understand the toplogy of the chain - BOOST_CHECK(ProcessNewBlockHeaders(headers, state, Params())); + BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlockHeaders(headers, state, Params())); // Connect the genesis block and drain any outstanding events - BOOST_CHECK(ProcessNewBlock(Params(), std::make_shared<CBlock>(Params().GenesisBlock()), true, &ignored)); + BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlock(Params(), std::make_shared<CBlock>(Params().GenesisBlock()), true, &ignored)); SyncWithValidationInterfaceQueue(); // subscribe to events (this subscriber will validate event ordering) @@ -175,26 +175,26 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) LOCK(cs_main); initial_tip = ::ChainActive().Tip(); } - TestSubscriber sub(initial_tip->GetBlockHash()); - RegisterValidationInterface(&sub); + auto sub = std::make_shared<TestSubscriber>(initial_tip->GetBlockHash()); + RegisterSharedValidationInterface(sub); // create a bunch of threads that repeatedly process a block generated above at random // this will create parallelism and randomness inside validation - the ValidationInterface // will subscribe to events generated during block validation and assert on ordering invariance std::vector<std::thread> threads; for (int i = 0; i < 10; i++) { - threads.emplace_back([&blocks]() { + threads.emplace_back([&]() { bool ignored; FastRandomContext insecure; for (int i = 0; i < 1000; i++) { auto block = blocks[insecure.randrange(blocks.size() - 1)]; - ProcessNewBlock(Params(), block, true, &ignored); + EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, &ignored); } // to make sure that eventually we process the full chain - do it here for (auto block : blocks) { if (block->vtx.size() == 1) { - bool processed = ProcessNewBlock(Params(), block, true, &ignored); + bool processed = EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, &ignored); assert(processed); } } @@ -204,14 +204,12 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) for (auto& t : threads) { t.join(); } - while (GetMainSignals().CallbacksPending() > 0) { - UninterruptibleSleep(std::chrono::milliseconds{100}); - } + SyncWithValidationInterfaceQueue(); - UnregisterValidationInterface(&sub); + UnregisterSharedValidationInterface(sub); LOCK(cs_main); - BOOST_CHECK_EQUAL(sub.m_expected_tip, ::ChainActive().Tip()->GetBlockHash()); + BOOST_CHECK_EQUAL(sub->m_expected_tip, ::ChainActive().Tip()->GetBlockHash()); } /** @@ -234,8 +232,8 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) BOOST_AUTO_TEST_CASE(mempool_locks_reorg) { bool ignored; - auto ProcessBlock = [&ignored](std::shared_ptr<const CBlock> block) -> bool { - return ProcessNewBlock(Params(), block, /* fForceProcessing */ true, /* fNewBlock */ &ignored); + auto ProcessBlock = [&](std::shared_ptr<const CBlock> block) -> bool { + return EnsureChainman(m_node).ProcessNewBlock(Params(), block, /* fForceProcessing */ true, /* fNewBlock */ &ignored); }; // Process all mined blocks diff --git a/src/test/validationinterface_tests.cpp b/src/test/validationinterface_tests.cpp index 208be92852..ceba689e52 100644 --- a/src/test/validationinterface_tests.cpp +++ b/src/test/validationinterface_tests.cpp @@ -12,6 +12,40 @@ BOOST_FIXTURE_TEST_SUITE(validationinterface_tests, TestingSetup) +struct TestSubscriberNoop final : public CValidationInterface { + void BlockChecked(const CBlock&, const BlockValidationState&) override {} +}; + +BOOST_AUTO_TEST_CASE(unregister_validation_interface_race) +{ + std::atomic<bool> generate{true}; + + // Start thread to generate notifications + std::thread gen{[&] { + const CBlock block_dummy; + BlockValidationState state_dummy; + while (generate) { + GetMainSignals().BlockChecked(block_dummy, state_dummy); + } + }}; + + // Start thread to consume notifications + std::thread sub{[&] { + // keep going for about 1 sec, which is 250k iterations + for (int i = 0; i < 250000; i++) { + auto sub = std::make_shared<TestSubscriberNoop>(); + RegisterSharedValidationInterface(sub); + UnregisterSharedValidationInterface(sub); + } + // tell the other thread we are done + generate = false; + }}; + + gen.join(); + sub.join(); + BOOST_CHECK(!generate); +} + class TestInterface : public CValidationInterface { public: diff --git a/src/threadsafety.h b/src/threadsafety.h index bb988dfdfd..942aa3fdcd 100644 --- a/src/threadsafety.h +++ b/src/threadsafety.h @@ -6,6 +6,8 @@ #ifndef BITCOIN_THREADSAFETY_H #define BITCOIN_THREADSAFETY_H +#include <mutex> + #ifdef __clang__ // TL;DR Add GUARDED_BY(mutex) to member variables. The others are // rarely necessary. Ex: int nFoo GUARDED_BY(cs_foo); @@ -54,4 +56,19 @@ #define ASSERT_EXCLUSIVE_LOCK(...) #endif // __GNUC__ +// StdMutex provides an annotated version of std::mutex for us, +// and should only be used when sync.h Mutex/LOCK/etc are not usable. +class LOCKABLE StdMutex : public std::mutex +{ +}; + +// StdLockGuard provides an annotated version of std::lock_guard for us, +// and should only be used when sync.h Mutex/LOCK/etc are not usable. +class SCOPED_LOCKABLE StdLockGuard : public std::lock_guard<StdMutex> +{ +public: + explicit StdLockGuard(StdMutex& cs) EXCLUSIVE_LOCK_FUNCTION(cs) : std::lock_guard<StdMutex>(cs) {} + ~StdLockGuard() UNLOCK_FUNCTION() {} +}; + #endif // BITCOIN_THREADSAFETY_H diff --git a/src/txdb.cpp b/src/txdb.cpp index 071aa1336b..129697f0e7 100644 --- a/src/txdb.cpp +++ b/src/txdb.cpp @@ -36,19 +36,7 @@ struct CoinEntry { char key; explicit CoinEntry(const COutPoint* ptr) : outpoint(const_cast<COutPoint*>(ptr)), key(DB_COIN) {} - template<typename Stream> - void Serialize(Stream &s) const { - s << key; - s << outpoint->hash; - s << VARINT(outpoint->n); - } - - template<typename Stream> - void Unserialize(Stream& s) { - s >> key; - s >> outpoint->hash; - s >> VARINT(outpoint->n); - } + SERIALIZE_METHODS(CoinEntry, obj) { READWRITE(obj.key, obj.outpoint->hash, VARINT(obj.outpoint->n)); } }; } diff --git a/src/txmempool.h b/src/txmempool.h index 4bee78b8d6..4568eb928d 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -704,7 +704,10 @@ public: /** Adds a transaction to the unbroadcast set */ void AddUnbroadcastTx(const uint256& txid) { LOCK(cs); - m_unbroadcast_txids.insert(txid); + /** Sanity Check: the transaction should also be in the mempool */ + if (exists(txid)) { + m_unbroadcast_txids.insert(txid); + } } /** Removes a transaction from the unbroadcast set */ @@ -716,6 +719,12 @@ public: return m_unbroadcast_txids; } + // Returns if a txid is in the unbroadcast set + bool IsUnbroadcastTx(const uint256& txid) const { + LOCK(cs); + return (m_unbroadcast_txids.count(txid) != 0); + } + private: /** UpdateForDescendants is used by UpdateTransactionsFromBlock to update * the descendants for a single transaction that has been added to the diff --git a/src/ui_interface.cpp b/src/ui_interface.cpp index 9cfde9502d..15795bd67f 100644 --- a/src/ui_interface.cpp +++ b/src/ui_interface.cpp @@ -49,8 +49,8 @@ void CClientUIInterface::NotifyNumConnectionsChanged(int newNumConnections) { re void CClientUIInterface::NotifyNetworkActiveChanged(bool networkActive) { return g_ui_signals.NotifyNetworkActiveChanged(networkActive); } void CClientUIInterface::NotifyAlertChanged() { return g_ui_signals.NotifyAlertChanged(); } void CClientUIInterface::ShowProgress(const std::string& title, int nProgress, bool resume_possible) { return g_ui_signals.ShowProgress(title, nProgress, resume_possible); } -void CClientUIInterface::NotifyBlockTip(bool b, const CBlockIndex* i) { return g_ui_signals.NotifyBlockTip(b, i); } -void CClientUIInterface::NotifyHeaderTip(bool b, const CBlockIndex* i) { return g_ui_signals.NotifyHeaderTip(b, i); } +void CClientUIInterface::NotifyBlockTip(SynchronizationState s, const CBlockIndex* i) { return g_ui_signals.NotifyBlockTip(s, i); } +void CClientUIInterface::NotifyHeaderTip(SynchronizationState s, const CBlockIndex* i) { return g_ui_signals.NotifyHeaderTip(s, i); } void CClientUIInterface::BannedListChanged() { return g_ui_signals.BannedListChanged(); } bool InitError(const bilingual_str& str) @@ -59,7 +59,7 @@ bool InitError(const bilingual_str& str) return false; } -void InitWarning(const std::string& str) +void InitWarning(const bilingual_str& str) { - uiInterface.ThreadSafeMessageBox(Untranslated(str), "", CClientUIInterface::MSG_WARNING); + uiInterface.ThreadSafeMessageBox(str, "", CClientUIInterface::MSG_WARNING); } diff --git a/src/ui_interface.h b/src/ui_interface.h index 132866cc5a..d45811178f 100644 --- a/src/ui_interface.h +++ b/src/ui_interface.h @@ -11,6 +11,7 @@ #include <string> class CBlockIndex; +enum class SynchronizationState; struct bilingual_str; namespace boost { @@ -110,18 +111,17 @@ public: ADD_SIGNALS_DECL_WRAPPER(ShowProgress, void, const std::string& title, int nProgress, bool resume_possible); /** New block has been accepted */ - ADD_SIGNALS_DECL_WRAPPER(NotifyBlockTip, void, bool, const CBlockIndex*); + ADD_SIGNALS_DECL_WRAPPER(NotifyBlockTip, void, SynchronizationState, const CBlockIndex*); /** Best header has changed */ - ADD_SIGNALS_DECL_WRAPPER(NotifyHeaderTip, void, bool, const CBlockIndex*); + ADD_SIGNALS_DECL_WRAPPER(NotifyHeaderTip, void, SynchronizationState, const CBlockIndex*); /** Banlist did change. */ ADD_SIGNALS_DECL_WRAPPER(BannedListChanged, void, void); }; /** Show warning message **/ -// TODO: InitWarning() should take a bilingual_str parameter. -void InitWarning(const std::string& str); +void InitWarning(const bilingual_str& str); /** Show error message **/ bool InitError(const bilingual_str& str); diff --git a/src/util/check.h b/src/util/check.h index d18887ae95..5c0f32cf51 100644 --- a/src/util/check.h +++ b/src/util/check.h @@ -5,6 +5,10 @@ #ifndef BITCOIN_UTIL_CHECK_H #define BITCOIN_UTIL_CHECK_H +#if defined(HAVE_CONFIG_H) +#include <config/bitcoin-config.h> +#endif + #include <tinyformat.h> #include <stdexcept> diff --git a/src/util/ref.h b/src/util/ref.h new file mode 100644 index 0000000000..9685ea9fec --- /dev/null +++ b/src/util/ref.h @@ -0,0 +1,38 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_UTIL_REF_H +#define BITCOIN_UTIL_REF_H + +#include <util/check.h> + +#include <typeindex> + +namespace util { + +/** + * Type-safe dynamic reference. + * + * This implements a small subset of the functionality in C++17's std::any + * class, and can be dropped when the project updates to C++17 + * (https://github.com/bitcoin/bitcoin/issues/16684) + */ +class Ref +{ +public: + Ref() = default; + template<typename T> Ref(T& value) { Set(value); } + template<typename T> T& Get() const { CHECK_NONFATAL(Has<T>()); return *static_cast<T*>(m_value); } + template<typename T> void Set(T& value) { m_value = &value; m_type = std::type_index(typeid(T)); } + template<typename T> bool Has() const { return m_value && m_type == std::type_index(typeid(T)); } + void Clear() { m_value = nullptr; m_type = std::type_index(typeid(void)); } + +private: + void* m_value = nullptr; + std::type_index m_type = std::type_index(typeid(void)); +}; + +} // namespace util + +#endif // BITCOIN_UTIL_REF_H diff --git a/src/util/string.h b/src/util/string.h index b8e2a06235..cdb41630c6 100644 --- a/src/util/string.h +++ b/src/util/string.h @@ -30,10 +30,11 @@ NODISCARD inline std::string TrimString(const std::string& str, const std::strin * @param separator The separator * @param unary_op Apply this operator to each item in the list */ -template <typename T, typename UnaryOp> -std::string Join(const std::vector<T>& list, const std::string& separator, UnaryOp unary_op) +template <typename T, typename BaseType, typename UnaryOp> +auto Join(const std::vector<T>& list, const BaseType& separator, UnaryOp unary_op) + -> decltype(unary_op(list.at(0))) { - std::string ret; + decltype(unary_op(list.at(0))) ret; for (size_t i = 0; i < list.size(); ++i) { if (i > 0) ret += separator; ret += unary_op(list.at(i)); @@ -41,9 +42,16 @@ std::string Join(const std::vector<T>& list, const std::string& separator, Unary return ret; } +template <typename T> +T Join(const std::vector<T>& list, const T& separator) +{ + return Join(list, separator, [](const T& i) { return i; }); +} + +// Explicit overload needed for c_str arguments, which would otherwise cause a substitution failure in the template above. inline std::string Join(const std::vector<std::string>& list, const std::string& separator) { - return Join(list, separator, [](const std::string& i) { return i; }); + return Join<std::string>(list, separator); } /** diff --git a/src/util/system.cpp b/src/util/system.cpp index 2013b416db..bde0f097be 100644 --- a/src/util/system.cpp +++ b/src/util/system.cpp @@ -3,6 +3,7 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include <sync.h> #include <util/system.h> #include <chainparamsbase.h> @@ -75,18 +76,18 @@ const char * const BITCOIN_CONF_FILENAME = "bitcoin.conf"; ArgsManager gArgs; +/** Mutex to protect dir_locks. */ +static Mutex cs_dir_locks; /** A map that contains all the currently held directory locks. After * successful locking, these will be held here until the global destructor * cleans them up and thus automatically unlocks them, or ReleaseDirectoryLocks * is called. */ -static std::map<std::string, std::unique_ptr<fsbridge::FileLock>> dir_locks; -/** Mutex to protect dir_locks. */ -static std::mutex cs_dir_locks; +static std::map<std::string, std::unique_ptr<fsbridge::FileLock>> dir_locks GUARDED_BY(cs_dir_locks); bool LockDirectory(const fs::path& directory, const std::string lockfile_name, bool probe_only) { - std::lock_guard<std::mutex> ulock(cs_dir_locks); + LOCK(cs_dir_locks); fs::path pathLockFile = directory / lockfile_name; // If a lock for this directory already exists in the map, don't try to re-lock it @@ -110,13 +111,13 @@ bool LockDirectory(const fs::path& directory, const std::string lockfile_name, b void UnlockDirectory(const fs::path& directory, const std::string& lockfile_name) { - std::lock_guard<std::mutex> lock(cs_dir_locks); + LOCK(cs_dir_locks); dir_locks.erase((directory / lockfile_name).string()); } void ReleaseDirectoryLocks() { - std::lock_guard<std::mutex> ulock(cs_dir_locks); + LOCK(cs_dir_locks); dir_locks.clear(); } diff --git a/src/util/translation.h b/src/util/translation.h index 45595405e7..268bcf30a7 100644 --- a/src/util/translation.h +++ b/src/util/translation.h @@ -16,21 +16,23 @@ struct bilingual_str { std::string original; std::string translated; + + bilingual_str& operator+=(const bilingual_str& rhs) + { + original += rhs.original; + translated += rhs.translated; + return *this; + } }; -inline bilingual_str operator+(const bilingual_str& lhs, const bilingual_str& rhs) +inline bilingual_str operator+(bilingual_str lhs, const bilingual_str& rhs) { - return bilingual_str{ - lhs.original + rhs.original, - lhs.translated + rhs.translated}; + lhs += rhs; + return lhs; } /** Mark a bilingual_str as untranslated */ inline bilingual_str Untranslated(std::string original) { return {original, original}; } -/** Unary operator to return the original */ -inline std::string OpOriginal(const bilingual_str& b) { return b.original; } -/** Unary operator to return the translation */ -inline std::string OpTranslated(const bilingual_str& b) { return b.translated; } namespace tinyformat { template <typename... Args> diff --git a/src/validation.cpp b/src/validation.cpp index 8a454c8d1b..dbdf5028fd 100644 --- a/src/validation.cpp +++ b/src/validation.cpp @@ -196,8 +196,8 @@ CBlockIndex* FindForkInGlobalIndex(const CChain& chain, const CBlockLocator& loc std::unique_ptr<CBlockTreeDB> pblocktree; // See definition for documentation -static void FindFilesToPruneManual(std::set<int>& setFilesToPrune, int nManualPruneHeight); -static void FindFilesToPrune(std::set<int>& setFilesToPrune, uint64_t nPruneAfterHeight); +static void FindFilesToPruneManual(ChainstateManager& chainman, std::set<int>& setFilesToPrune, int nManualPruneHeight); +static void FindFilesToPrune(ChainstateManager& chainman, std::set<int>& setFilesToPrune, uint64_t nPruneAfterHeight); bool CheckInputScripts(const CTransaction& tx, TxValidationState &state, const CCoinsViewCache &inputs, unsigned int flags, bool cacheSigStore, bool cacheFullScriptStore, PrecomputedTransactionData& txdata, std::vector<CScriptCheck> *pvChecks = nullptr); static FILE* OpenUndoFile(const FlatFilePos &pos, bool fReadOnly = false); static FlatFileSeq BlockFileSeq(); @@ -2282,11 +2282,11 @@ bool CChainState::FlushStateToDisk( if (nManualPruneHeight > 0) { LOG_TIME_MILLIS_WITH_CATEGORY("find files to prune (manual)", BCLog::BENCH); - FindFilesToPruneManual(setFilesToPrune, nManualPruneHeight); + FindFilesToPruneManual(g_chainman, setFilesToPrune, nManualPruneHeight); } else { LOG_TIME_MILLIS_WITH_CATEGORY("find files to prune", BCLog::BENCH); - FindFilesToPrune(setFilesToPrune, chainparams.PruneAfterHeight()); + FindFilesToPrune(g_chainman, setFilesToPrune, chainparams.PruneAfterHeight()); fCheckForPruning = false; } if (!setFilesToPrune.empty()) { @@ -2800,6 +2800,13 @@ bool CChainState::ActivateBestChainStep(BlockValidationState& state, const CChai return true; } +static SynchronizationState GetSynchronizationState(bool init) +{ + if (!init) return SynchronizationState::POST_INIT; + if (::fReindex) return SynchronizationState::INIT_REINDEX; + return SynchronizationState::INIT_DOWNLOAD; +} + static bool NotifyHeaderTip() LOCKS_EXCLUDED(cs_main) { bool fNotify = false; bool fInitialBlockDownload = false; @@ -2817,7 +2824,7 @@ static bool NotifyHeaderTip() LOCKS_EXCLUDED(cs_main) { } // Send block tip changed notifications without cs_main if (fNotify) { - uiInterface.NotifyHeaderTip(fInitialBlockDownload, pindexHeader); + uiInterface.NotifyHeaderTip(GetSynchronizationState(fInitialBlockDownload), pindexHeader); } return fNotify; } @@ -2906,7 +2913,7 @@ bool CChainState::ActivateBestChain(BlockValidationState &state, const CChainPar GetMainSignals().UpdatedBlockTip(pindexNewTip, pindexFork, fInitialDownload); // Always notify the UI if a new block tip was connected - uiInterface.NotifyBlockTip(fInitialDownload, pindexNewTip); + uiInterface.NotifyBlockTip(GetSynchronizationState(fInitialDownload), pindexNewTip); } } // When we reach this point, we switched to a new tip (stored in pindexNewTip). @@ -3097,7 +3104,7 @@ bool CChainState::InvalidateBlock(BlockValidationState& state, const CChainParam // Only notify about a new block tip if the active chain was modified. if (pindex_was_in_chain) { - uiInterface.NotifyBlockTip(IsInitialBlockDownload(), to_mark_failed->pprev); + uiInterface.NotifyBlockTip(GetSynchronizationState(IsInitialBlockDownload()), to_mark_failed->pprev); } return true; } @@ -3684,13 +3691,14 @@ bool BlockManager::AcceptBlockHeader(const CBlockHeader& block, BlockValidationS } // Exposed wrapper for AcceptBlockHeader -bool ProcessNewBlockHeaders(const std::vector<CBlockHeader>& headers, BlockValidationState& state, const CChainParams& chainparams, const CBlockIndex** ppindex) +bool ChainstateManager::ProcessNewBlockHeaders(const std::vector<CBlockHeader>& headers, BlockValidationState& state, const CChainParams& chainparams, const CBlockIndex** ppindex) { + AssertLockNotHeld(cs_main); { LOCK(cs_main); for (const CBlockHeader& header : headers) { CBlockIndex *pindex = nullptr; // Use a temp pindex instead of ppindex to avoid a const_cast - bool accepted = g_chainman.m_blockman.AcceptBlockHeader( + bool accepted = m_blockman.AcceptBlockHeader( header, state, chainparams, &pindex); ::ChainstateActive().CheckBlockIndex(chainparams.GetConsensus()); @@ -3812,7 +3820,7 @@ bool CChainState::AcceptBlock(const std::shared_ptr<const CBlock>& pblock, Block return true; } -bool ProcessNewBlock(const CChainParams& chainparams, const std::shared_ptr<const CBlock> pblock, bool fForceProcessing, bool *fNewBlock) +bool ChainstateManager::ProcessNewBlock(const CChainParams& chainparams, const std::shared_ptr<const CBlock> pblock, bool fForceProcessing, bool* fNewBlock) { AssertLockNotHeld(cs_main); @@ -3888,12 +3896,12 @@ uint64_t CalculateCurrentUsage() return retval; } -/* Prune a block file (modify associated database entries)*/ -void PruneOneBlockFile(const int fileNumber) +void ChainstateManager::PruneOneBlockFile(const int fileNumber) { + AssertLockHeld(cs_main); LOCK(cs_LastBlockFile); - for (const auto& entry : g_chainman.BlockIndex()) { + for (const auto& entry : m_blockman.m_block_index) { CBlockIndex* pindex = entry.second; if (pindex->nFile == fileNumber) { pindex->nStatus &= ~BLOCK_HAVE_DATA; @@ -3907,12 +3915,12 @@ void PruneOneBlockFile(const int fileNumber) // to be downloaded again in order to consider its chain, at which // point it would be considered as a candidate for // m_blocks_unlinked or setBlockIndexCandidates. - auto range = g_chainman.m_blockman.m_blocks_unlinked.equal_range(pindex->pprev); + auto range = m_blockman.m_blocks_unlinked.equal_range(pindex->pprev); while (range.first != range.second) { std::multimap<CBlockIndex *, CBlockIndex *>::iterator _it = range.first; range.first++; if (_it->second == pindex) { - g_chainman.m_blockman.m_blocks_unlinked.erase(_it); + m_blockman.m_blocks_unlinked.erase(_it); } } } @@ -3934,7 +3942,7 @@ void UnlinkPrunedFiles(const std::set<int>& setFilesToPrune) } /* Calculate the block/rev files to delete based on height specified by user with RPC command pruneblockchain */ -static void FindFilesToPruneManual(std::set<int>& setFilesToPrune, int nManualPruneHeight) +static void FindFilesToPruneManual(ChainstateManager& chainman, std::set<int>& setFilesToPrune, int nManualPruneHeight) { assert(fPruneMode && nManualPruneHeight > 0); @@ -3948,7 +3956,7 @@ static void FindFilesToPruneManual(std::set<int>& setFilesToPrune, int nManualPr for (int fileNumber = 0; fileNumber < nLastBlockFile; fileNumber++) { if (vinfoBlockFile[fileNumber].nSize == 0 || vinfoBlockFile[fileNumber].nHeightLast > nLastBlockWeCanPrune) continue; - PruneOneBlockFile(fileNumber); + chainman.PruneOneBlockFile(fileNumber); setFilesToPrune.insert(fileNumber); count++; } @@ -3981,7 +3989,7 @@ void PruneBlockFilesManual(int nManualPruneHeight) * * @param[out] setFilesToPrune The set of file indices that can be unlinked will be returned */ -static void FindFilesToPrune(std::set<int>& setFilesToPrune, uint64_t nPruneAfterHeight) +static void FindFilesToPrune(ChainstateManager& chainman, std::set<int>& setFilesToPrune, uint64_t nPruneAfterHeight) { LOCK2(cs_main, cs_LastBlockFile); if (::ChainActive().Tip() == nullptr || nPruneTarget == 0) { @@ -4023,7 +4031,7 @@ static void FindFilesToPrune(std::set<int>& setFilesToPrune, uint64_t nPruneAfte if (vinfoBlockFile[fileNumber].nHeightLast > nLastBlockWeCanPrune) continue; - PruneOneBlockFile(fileNumber); + chainman.PruneOneBlockFile(fileNumber); // Queue up the files for removal setFilesToPrune.insert(fileNumber); nCurrentUsage -= nBytesToPrune; @@ -4147,9 +4155,9 @@ void BlockManager::Unload() { m_block_index.clear(); } -bool static LoadBlockIndexDB(const CChainParams& chainparams) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +bool static LoadBlockIndexDB(ChainstateManager& chainman, const CChainParams& chainparams) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { - if (!g_chainman.m_blockman.LoadBlockIndex( + if (!chainman.m_blockman.LoadBlockIndex( chainparams.GetConsensus(), *pblocktree, ::ChainstateActive().setBlockIndexCandidates)) { return false; @@ -4175,8 +4183,7 @@ bool static LoadBlockIndexDB(const CChainParams& chainparams) EXCLUSIVE_LOCKS_RE // Check presence of blk files LogPrintf("Checking all blk files are present...\n"); std::set<int> setBlkDataFiles; - for (const std::pair<const uint256, CBlockIndex*>& item : g_chainman.BlockIndex()) - { + for (const std::pair<const uint256, CBlockIndex*>& item : chainman.BlockIndex()) { CBlockIndex* pindex = item.second; if (pindex->nStatus & BLOCK_HAVE_DATA) { setBlkDataFiles.insert(pindex->nFile); @@ -4593,14 +4600,15 @@ void UnloadBlockIndex() fHavePruned = false; } -bool LoadBlockIndex(const CChainParams& chainparams) +bool ChainstateManager::LoadBlockIndex(const CChainParams& chainparams) { + AssertLockHeld(cs_main); // Load block index from databases bool needs_init = fReindex; if (!fReindex) { - bool ret = LoadBlockIndexDB(chainparams); + bool ret = LoadBlockIndexDB(*this, chainparams); if (!ret) return false; - needs_init = g_chainman.m_blockman.m_block_index.empty(); + needs_init = m_blockman.m_block_index.empty(); } if (needs_init) { diff --git a/src/validation.h b/src/validation.h index c4a5cc4593..8112e38704 100644 --- a/src/validation.h +++ b/src/validation.h @@ -43,6 +43,7 @@ class CConnman; class CScriptCheck; class CBlockPolicyEstimator; class CTxMemPool; +class ChainstateManager; class TxValidationState; struct ChainTxData; @@ -103,6 +104,13 @@ struct BlockHasher size_t operator()(const uint256& hash) const { return ReadLE64(hash.begin()); } }; +/** Current sync state passed to tip changed callbacks. */ +enum class SynchronizationState { + INIT_REINDEX, + INIT_DOWNLOAD, + POST_INIT +}; + extern RecursiveMutex cs_main; extern CBlockPolicyEstimator feeEstimator; extern CTxMemPool mempool; @@ -142,41 +150,6 @@ extern bool fPruneMode; /** Number of MiB of block files that we're trying to stay below. */ extern uint64_t nPruneTarget; -/** - * Process an incoming block. This only returns after the best known valid - * block is made active. Note that it does not, however, guarantee that the - * specific block passed to it has been checked for validity! - * - * If you want to *possibly* get feedback on whether pblock is valid, you must - * install a CValidationInterface (see validationinterface.h) - this will have - * its BlockChecked method called whenever *any* block completes validation. - * - * Note that we guarantee that either the proof-of-work is valid on pblock, or - * (and possibly also) BlockChecked will have been called. - * - * May not be called in a - * validationinterface callback. - * - * @param[in] pblock The block we want to process. - * @param[in] fForceProcessing Process this block even if unrequested; used for non-network block sources and whitelisted peers. - * @param[out] fNewBlock A boolean which is set to indicate if the block was first received via this call - * @returns If the block was processed, independently of block validity - */ -bool ProcessNewBlock(const CChainParams& chainparams, const std::shared_ptr<const CBlock> pblock, bool fForceProcessing, bool* fNewBlock) LOCKS_EXCLUDED(cs_main); - -/** - * Process incoming block headers. - * - * May not be called in a - * validationinterface callback. - * - * @param[in] block The block headers themselves - * @param[out] state This may be set to an Error state if any error occurred processing them - * @param[in] chainparams The params for the chain we want to connect to - * @param[out] ppindex If set, the pointer will be set to point to the last new block index object for the given headers - */ -bool ProcessNewBlockHeaders(const std::vector<CBlockHeader>& block, BlockValidationState& state, const CChainParams& chainparams, const CBlockIndex** ppindex = nullptr) LOCKS_EXCLUDED(cs_main); - /** Open a block file (blk?????.dat) */ FILE* OpenBlockFile(const FlatFilePos &pos, bool fReadOnly = false); /** Translation to a filesystem path */ @@ -185,9 +158,6 @@ fs::path GetBlockPosFilename(const FlatFilePos &pos); void LoadExternalBlockFile(const CChainParams& chainparams, FILE* fileIn, FlatFilePos* dbp = nullptr); /** Ensures we have a genesis block in the block tree, possibly writing one to disk. */ bool LoadGenesisBlock(const CChainParams& chainparams); -/** Load the block tree and coins database from disk, - * initializing state if we're running with -reindex. */ -bool LoadBlockIndex(const CChainParams& chainparams) EXCLUSIVE_LOCKS_REQUIRED(cs_main); /** Unload database information */ void UnloadBlockIndex(); /** Run an instance of the script checking thread */ @@ -210,11 +180,6 @@ double GuessVerificationProgress(const ChainTxData& data, const CBlockIndex* pin uint64_t CalculateCurrentUsage(); /** - * Mark one block file as pruned. - */ -void PruneOneBlockFile(const int fileNumber) EXCLUSIVE_LOCKS_REQUIRED(cs_main); - -/** * Actually unlink the specified files */ void UnlinkPrunedFiles(const std::set<int>& setFilesToPrune); @@ -486,9 +451,6 @@ enum class CoinsCacheSizeState OK = 0 }; -// Defined below, but needed for `friend` usage in CChainState. -class ChainstateManager; - /** * CChainState stores and provides an API to update our local knowledge of the * current best chain. @@ -863,6 +825,47 @@ public: CChain& ValidatedChain() const { return ValidatedChainstate().m_chain; } CBlockIndex* ValidatedTip() const { return ValidatedChain().Tip(); } + /** + * Process an incoming block. This only returns after the best known valid + * block is made active. Note that it does not, however, guarantee that the + * specific block passed to it has been checked for validity! + * + * If you want to *possibly* get feedback on whether pblock is valid, you must + * install a CValidationInterface (see validationinterface.h) - this will have + * its BlockChecked method called whenever *any* block completes validation. + * + * Note that we guarantee that either the proof-of-work is valid on pblock, or + * (and possibly also) BlockChecked will have been called. + * + * May not be called in a + * validationinterface callback. + * + * @param[in] pblock The block we want to process. + * @param[in] fForceProcessing Process this block even if unrequested; used for non-network block sources and whitelisted peers. + * @param[out] fNewBlock A boolean which is set to indicate if the block was first received via this call + * @returns If the block was processed, independently of block validity + */ + bool ProcessNewBlock(const CChainParams& chainparams, const std::shared_ptr<const CBlock> pblock, bool fForceProcessing, bool* fNewBlock) LOCKS_EXCLUDED(cs_main); + + /** + * Process incoming block headers. + * + * May not be called in a + * validationinterface callback. + * + * @param[in] block The block headers themselves + * @param[out] state This may be set to an Error state if any error occurred processing them + * @param[in] chainparams The params for the chain we want to connect to + * @param[out] ppindex If set, the pointer will be set to point to the last new block index object for the given headers + */ + bool ProcessNewBlockHeaders(const std::vector<CBlockHeader>& block, BlockValidationState& state, const CChainParams& chainparams, const CBlockIndex** ppindex = nullptr) LOCKS_EXCLUDED(cs_main); + + //! Mark one block file as pruned (modify associated database entries) + void PruneOneBlockFile(const int fileNumber) EXCLUSIVE_LOCKS_REQUIRED(cs_main); + + //! Load the block tree and coins database from disk, initializing state if we're running with -reindex + bool LoadBlockIndex(const CChainParams& chainparams) EXCLUSIVE_LOCKS_REQUIRED(cs_main); + //! Unload block index and chain data before shutdown. void Unload() EXCLUSIVE_LOCKS_REQUIRED(::cs_main); @@ -870,6 +873,7 @@ public: void Reset(); }; +/** DEPRECATED! Please use node.chainman instead. May only be used in validation.cpp internally */ extern ChainstateManager g_chainman GUARDED_BY(::cs_main); /** @returns the most-work valid chainstate. */ diff --git a/src/validationinterface.cpp b/src/validationinterface.cpp index 11000774c0..9437f9c817 100644 --- a/src/validationinterface.cpp +++ b/src/validationinterface.cpp @@ -89,22 +89,26 @@ public: static CMainSignals g_signals; -void CMainSignals::RegisterBackgroundSignalScheduler(CScheduler& scheduler) { +void CMainSignals::RegisterBackgroundSignalScheduler(CScheduler& scheduler) +{ assert(!m_internals); m_internals.reset(new MainSignalsInstance(&scheduler)); } -void CMainSignals::UnregisterBackgroundSignalScheduler() { +void CMainSignals::UnregisterBackgroundSignalScheduler() +{ m_internals.reset(nullptr); } -void CMainSignals::FlushBackgroundCallbacks() { +void CMainSignals::FlushBackgroundCallbacks() +{ if (m_internals) { m_internals->m_schedulerClient.EmptyQueue(); } } -size_t CMainSignals::CallbacksPending() { +size_t CMainSignals::CallbacksPending() +{ if (!m_internals) return 0; return m_internals->m_schedulerClient.CallbacksPending(); } @@ -114,10 +118,11 @@ CMainSignals& GetMainSignals() return g_signals; } -void RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface> pwalletIn) { - // Each connection captures pwalletIn to ensure that each callback is - // executed before pwalletIn is destroyed. For more details see #18338. - g_signals.m_internals->Register(std::move(pwalletIn)); +void RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks) +{ + // Each connection captures the shared_ptr to ensure that each callback is + // executed before the subscriber is destroyed. For more details see #18338. + g_signals.m_internals->Register(std::move(callbacks)); } void RegisterValidationInterface(CValidationInterface* callbacks) @@ -132,24 +137,28 @@ void UnregisterSharedValidationInterface(std::shared_ptr<CValidationInterface> c UnregisterValidationInterface(callbacks.get()); } -void UnregisterValidationInterface(CValidationInterface* pwalletIn) { +void UnregisterValidationInterface(CValidationInterface* callbacks) +{ if (g_signals.m_internals) { - g_signals.m_internals->Unregister(pwalletIn); + g_signals.m_internals->Unregister(callbacks); } } -void UnregisterAllValidationInterfaces() { +void UnregisterAllValidationInterfaces() +{ if (!g_signals.m_internals) { return; } g_signals.m_internals->Clear(); } -void CallFunctionInValidationInterfaceQueue(std::function<void ()> func) { +void CallFunctionInValidationInterfaceQueue(std::function<void()> func) +{ g_signals.m_internals->m_schedulerClient.AddToProcessQueue(std::move(func)); } -void SyncWithValidationInterfaceQueue() { +void SyncWithValidationInterfaceQueue() +{ AssertLockNotHeld(cs_main); // Block until the validation queue drains std::promise<void> promise; diff --git a/src/validationinterface.h b/src/validationinterface.h index cb0204a555..9c23965bc1 100644 --- a/src/validationinterface.h +++ b/src/validationinterface.h @@ -22,20 +22,20 @@ class CValidationInterface; class uint256; class CScheduler; -// These functions dispatch to one or all registered wallets - -/** Register a wallet to receive updates from core */ -void RegisterValidationInterface(CValidationInterface* pwalletIn); -/** Unregister a wallet from core */ -void UnregisterValidationInterface(CValidationInterface* pwalletIn); -/** Unregister all wallets from core */ +/** Register subscriber */ +void RegisterValidationInterface(CValidationInterface* callbacks); +/** Unregister subscriber. DEPRECATED. This is not safe to use when the RPC server or main message handler thread is running. */ +void UnregisterValidationInterface(CValidationInterface* callbacks); +/** Unregister all subscribers */ void UnregisterAllValidationInterfaces(); // Alternate registration functions that release a shared_ptr after the last // notification is sent. These are useful for race-free cleanup, since // unregistration is nonblocking and can return before the last notification is // processed. +/** Register subscriber */ void RegisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks); +/** Unregister subscriber */ void UnregisterSharedValidationInterface(std::shared_ptr<CValidationInterface> callbacks); /** diff --git a/src/wallet/crypter.h b/src/wallet/crypter.h index f59c63260e..f2df786e2e 100644 --- a/src/wallet/crypter.h +++ b/src/wallet/crypter.h @@ -43,15 +43,9 @@ public: //! such as the various parameters to scrypt std::vector<unsigned char> vchOtherDerivationParameters; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(vchCryptedKey); - READWRITE(vchSalt); - READWRITE(nDerivationMethod); - READWRITE(nDeriveIterations); - READWRITE(vchOtherDerivationParameters); + SERIALIZE_METHODS(CMasterKey, obj) + { + READWRITE(obj.vchCryptedKey, obj.vchSalt, obj.nDerivationMethod, obj.nDeriveIterations, obj.vchOtherDerivationParameters); } CMasterKey() diff --git a/src/wallet/db.cpp b/src/wallet/db.cpp index 1b2bd83a4c..4ed28b0623 100644 --- a/src/wallet/db.cpp +++ b/src/wallet/db.cpp @@ -268,21 +268,14 @@ BerkeleyEnvironment::BerkeleyEnvironment() fMockDb = true; } -BerkeleyEnvironment::VerifyResult BerkeleyEnvironment::Verify(const std::string& strFile, recoverFunc_type recoverFunc, std::string& out_backup_filename) +bool BerkeleyEnvironment::Verify(const std::string& strFile) { LOCK(cs_db); assert(mapFileUseCount.count(strFile) == 0); Db db(dbenv.get(), 0); int result = db.verify(strFile.c_str(), nullptr, nullptr, 0); - if (result == 0) - return VerifyResult::VERIFY_OK; - else if (recoverFunc == nullptr) - return VerifyResult::RECOVER_FAIL; - - // Try to recover: - bool fRecovered = (*recoverFunc)(fs::path(strPath) / strFile, out_backup_filename); - return (fRecovered ? VerifyResult::RECOVER_OK : VerifyResult::RECOVER_FAIL); + return result == 0; } BerkeleyBatch::SafeDbt::SafeDbt() @@ -324,75 +317,6 @@ BerkeleyBatch::SafeDbt::operator Dbt*() return &m_dbt; } -bool BerkeleyBatch::Recover(const fs::path& file_path, void *callbackDataIn, bool (*recoverKVcallback)(void* callbackData, CDataStream ssKey, CDataStream ssValue), std::string& newFilename) -{ - std::string filename; - std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, filename); - - // Recovery procedure: - // move wallet file to walletfilename.timestamp.bak - // Call Salvage with fAggressive=true to - // get as much data as possible. - // Rewrite salvaged data to fresh wallet file - // Set -rescan so any missing transactions will be - // found. - int64_t now = GetTime(); - newFilename = strprintf("%s.%d.bak", filename, now); - - int result = env->dbenv->dbrename(nullptr, filename.c_str(), nullptr, - newFilename.c_str(), DB_AUTO_COMMIT); - if (result == 0) - LogPrintf("Renamed %s to %s\n", filename, newFilename); - else - { - LogPrintf("Failed to rename %s to %s\n", filename, newFilename); - return false; - } - - std::vector<BerkeleyEnvironment::KeyValPair> salvagedData; - bool fSuccess = env->Salvage(newFilename, true, salvagedData); - if (salvagedData.empty()) - { - LogPrintf("Salvage(aggressive) found no records in %s.\n", newFilename); - return false; - } - LogPrintf("Salvage(aggressive) found %u records\n", salvagedData.size()); - - std::unique_ptr<Db> pdbCopy = MakeUnique<Db>(env->dbenv.get(), 0); - int ret = pdbCopy->open(nullptr, // Txn pointer - filename.c_str(), // Filename - "main", // Logical db name - DB_BTREE, // Database type - DB_CREATE, // Flags - 0); - if (ret > 0) { - LogPrintf("Cannot create database file %s\n", filename); - pdbCopy->close(0); - return false; - } - - DbTxn* ptxn = env->TxnBegin(); - for (BerkeleyEnvironment::KeyValPair& row : salvagedData) - { - if (recoverKVcallback) - { - CDataStream ssKey(row.first, SER_DISK, CLIENT_VERSION); - CDataStream ssValue(row.second, SER_DISK, CLIENT_VERSION); - if (!(*recoverKVcallback)(callbackDataIn, ssKey, ssValue)) - continue; - } - Dbt datKey(&row.first[0], row.first.size()); - Dbt datValue(&row.second[0], row.second.size()); - int ret2 = pdbCopy->put(ptxn, &datKey, &datValue, DB_NOOVERWRITE); - if (ret2 > 0) - fSuccess = false; - } - ptxn->commit(0); - pdbCopy->close(0); - - return fSuccess; -} - bool BerkeleyBatch::VerifyEnvironment(const fs::path& file_path, bilingual_str& errorStr) { std::string walletFile; @@ -410,7 +334,7 @@ bool BerkeleyBatch::VerifyEnvironment(const fs::path& file_path, bilingual_str& return true; } -bool BerkeleyBatch::VerifyDatabaseFile(const fs::path& file_path, std::vector<bilingual_str>& warnings, bilingual_str& errorStr, BerkeleyEnvironment::recoverFunc_type recoverFunc) +bool BerkeleyBatch::VerifyDatabaseFile(const fs::path& file_path, bilingual_str& errorStr) { std::string walletFile; std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, walletFile); @@ -418,19 +342,8 @@ bool BerkeleyBatch::VerifyDatabaseFile(const fs::path& file_path, std::vector<bi if (fs::exists(walletDir / walletFile)) { - std::string backup_filename; - BerkeleyEnvironment::VerifyResult r = env->Verify(walletFile, recoverFunc, backup_filename); - if (r == BerkeleyEnvironment::VerifyResult::RECOVER_OK) - { - warnings.push_back(strprintf(_("Warning: Wallet file corrupt, data salvaged!" - " Original %s saved as %s in %s; if" - " your balance or transactions are incorrect you should" - " restore from a backup."), - walletFile, backup_filename, walletDir)); - } - if (r == BerkeleyEnvironment::VerifyResult::RECOVER_FAIL) - { - errorStr = strprintf(_("%s corrupt, salvage failed"), walletFile); + if (!env->Verify(walletFile)) { + errorStr = strprintf(_("%s corrupt. Try using the wallet tool bitcoin-wallet to salvage or restoring a backup."), walletFile); return false; } } @@ -438,72 +351,6 @@ bool BerkeleyBatch::VerifyDatabaseFile(const fs::path& file_path, std::vector<bi return true; } -/* End of headers, beginning of key/value data */ -static const char *HEADER_END = "HEADER=END"; -/* End of key/value data */ -static const char *DATA_END = "DATA=END"; - -bool BerkeleyEnvironment::Salvage(const std::string& strFile, bool fAggressive, std::vector<BerkeleyEnvironment::KeyValPair>& vResult) -{ - LOCK(cs_db); - assert(mapFileUseCount.count(strFile) == 0); - - u_int32_t flags = DB_SALVAGE; - if (fAggressive) - flags |= DB_AGGRESSIVE; - - std::stringstream strDump; - - Db db(dbenv.get(), 0); - int result = db.verify(strFile.c_str(), nullptr, &strDump, flags); - if (result == DB_VERIFY_BAD) { - LogPrintf("BerkeleyEnvironment::Salvage: Database salvage found errors, all data may not be recoverable.\n"); - if (!fAggressive) { - LogPrintf("BerkeleyEnvironment::Salvage: Rerun with aggressive mode to ignore errors and continue.\n"); - return false; - } - } - if (result != 0 && result != DB_VERIFY_BAD) { - LogPrintf("BerkeleyEnvironment::Salvage: Database salvage failed with result %d.\n", result); - return false; - } - - // Format of bdb dump is ascii lines: - // header lines... - // HEADER=END - // hexadecimal key - // hexadecimal value - // ... repeated - // DATA=END - - std::string strLine; - while (!strDump.eof() && strLine != HEADER_END) - getline(strDump, strLine); // Skip past header - - std::string keyHex, valueHex; - while (!strDump.eof() && keyHex != DATA_END) { - getline(strDump, keyHex); - if (keyHex != DATA_END) { - if (strDump.eof()) - break; - getline(strDump, valueHex); - if (valueHex == DATA_END) { - LogPrintf("BerkeleyEnvironment::Salvage: WARNING: Number of keys in data does not match number of values.\n"); - break; - } - vResult.push_back(make_pair(ParseHex(keyHex), ParseHex(valueHex))); - } - } - - if (keyHex != DATA_END) { - LogPrintf("BerkeleyEnvironment::Salvage: WARNING: Unexpected end of file while reading salvage output.\n"); - return false; - } - - return (result == 0); -} - - void BerkeleyEnvironment::CheckpointLSN(const std::string& strFile) { dbenv->txn_checkpoint(0, 0, 0); diff --git a/src/wallet/db.h b/src/wallet/db.h index 37f96a1a96..54ce144ffc 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -66,26 +66,7 @@ public: bool IsDatabaseLoaded(const std::string& db_filename) const { return m_databases.find(db_filename) != m_databases.end(); } fs::path Directory() const { return strPath; } - /** - * Verify that database file strFile is OK. If it is not, - * call the callback to try to recover. - * This must be called BEFORE strFile is opened. - * Returns true if strFile is OK. - */ - enum class VerifyResult { VERIFY_OK, - RECOVER_OK, - RECOVER_FAIL }; - typedef bool (*recoverFunc_type)(const fs::path& file_path, std::string& out_backup_filename); - VerifyResult Verify(const std::string& strFile, recoverFunc_type recoverFunc, std::string& out_backup_filename); - /** - * Salvage data from a file that Verify says is bad. - * fAggressive sets the DB_AGGRESSIVE flag (see berkeley DB->verify() method documentation). - * Appends binary key/value pairs to vResult, returns true if successful. - * NOTE: reads the entire database into memory, so cannot be used - * for huge databases. - */ - typedef std::pair<std::vector<unsigned char>, std::vector<unsigned char> > KeyValPair; - bool Salvage(const std::string& strFile, bool fAggressive, std::vector<KeyValPair>& vResult); + bool Verify(const std::string& strFile); bool Open(bool retry); void Close(); @@ -245,7 +226,6 @@ public: void Flush(); void Close(); - static bool Recover(const fs::path& file_path, void *callbackDataIn, bool (*recoverKVcallback)(void* callbackData, CDataStream ssKey, CDataStream ssValue), std::string& out_backup_filename); /* flush the wallet passively (TRY_LOCK) ideal to be called periodically */ @@ -253,7 +233,7 @@ public: /* verifies the database environment */ static bool VerifyEnvironment(const fs::path& file_path, bilingual_str& errorStr); /* verifies the database file */ - static bool VerifyDatabaseFile(const fs::path& file_path, std::vector<bilingual_str>& warnings, bilingual_str& errorStr, BerkeleyEnvironment::recoverFunc_type recoverFunc); + static bool VerifyDatabaseFile(const fs::path& file_path, bilingual_str& errorStr); template <typename K, typename T> bool Read(const K& key, T& value) diff --git a/src/wallet/init.cpp b/src/wallet/init.cpp index 6f973aab1c..3885eb6185 100644 --- a/src/wallet/init.cpp +++ b/src/wallet/init.cpp @@ -54,7 +54,6 @@ void WalletInit::AddWalletOptions() const gArgs.AddArg("-paytxfee=<amt>", strprintf("Fee (in %s/kB) to add to transactions you send (default: %s)", CURRENCY_UNIT, FormatMoney(CFeeRate{DEFAULT_PAY_TX_FEE}.GetFeePerK())), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); gArgs.AddArg("-rescan", "Rescan the block chain for missing wallet transactions on startup", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-salvagewallet", "Attempt to recover private keys from a corrupt wallet on startup", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); gArgs.AddArg("-spendzeroconfchange", strprintf("Spend unconfirmed change when sending transactions (default: %u)", DEFAULT_SPEND_ZEROCONF_CHANGE), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); gArgs.AddArg("-txconfirmtarget=<n>", strprintf("If paytxfee is not set, include enough fee so transactions begin confirmation on average within n blocks (default: %u)", DEFAULT_TX_CONFIRM_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); gArgs.AddArg("-wallet=<path>", "Specify wallet database path. Can be specified multiple times to load multiple wallets. Path is interpreted relative to <walletdir> if it is not absolute, and will be created if it does not exist (as a directory containing a wallet.dat file and log files). For backwards compatibility this will also accept names of existing data files in <walletdir>.)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::WALLET); @@ -89,16 +88,6 @@ bool WalletInit::ParameterInteraction() const LogPrintf("%s: parameter interaction: -blocksonly=1 -> setting -walletbroadcast=0\n", __func__); } - if (gArgs.GetBoolArg("-salvagewallet", false)) { - if (is_multiwallet) { - return InitError(strprintf(Untranslated("%s is only allowed with a single wallet file"), "-salvagewallet")); - } - // Rewrite just private keys: rescan to find transactions - if (gArgs.SoftSetBoolArg("-rescan", true)) { - LogPrintf("%s: parameter interaction: -salvagewallet=1 -> setting -rescan=1\n", __func__); - } - } - bool zapwallettxes = gArgs.GetBoolArg("-zapwallettxes", false); // -zapwallettxes implies dropping the mempool on startup if (zapwallettxes && gArgs.SoftSetBoolArg("-persistmempool", false)) { diff --git a/src/wallet/load.cpp b/src/wallet/load.cpp index 45841b2ae1..8df3e78215 100644 --- a/src/wallet/load.cpp +++ b/src/wallet/load.cpp @@ -37,11 +37,6 @@ bool VerifyWallets(interfaces::Chain& chain, const std::vector<std::string>& wal chain.initMessage(_("Verifying wallet(s)...").translated); - // Parameter interaction code should have thrown an error if -salvagewallet - // was enabled with more than wallet file, so the wallet_files size check - // here should have no effect. - bool salvage_wallet = gArgs.GetBoolArg("-salvagewallet", false) && wallet_files.size() <= 1; - // Keep track of each wallet absolute path to detect duplicates. std::set<fs::path> wallet_paths; @@ -55,8 +50,8 @@ bool VerifyWallets(interfaces::Chain& chain, const std::vector<std::string>& wal bilingual_str error_string; std::vector<bilingual_str> warnings; - bool verify_success = CWallet::Verify(chain, location, salvage_wallet, error_string, warnings); - if (!warnings.empty()) chain.initWarning(Join(warnings, "\n", OpTranslated)); + bool verify_success = CWallet::Verify(chain, location, error_string, warnings); + if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n"))); if (!verify_success) { chain.initError(error_string); return false; @@ -73,7 +68,7 @@ bool LoadWallets(interfaces::Chain& chain, const std::vector<std::string>& walle bilingual_str error; std::vector<bilingual_str> warnings; std::shared_ptr<CWallet> pwallet = CWallet::CreateWalletFromFile(chain, WalletLocation(walletFile), error, warnings); - if (!warnings.empty()) chain.initWarning(Join(warnings, "\n", OpTranslated)); + if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n"))); if (!pwallet) { chain.initError(error); return false; diff --git a/src/wallet/load.h b/src/wallet/load.h index 5a62e29303..e24b1f2e69 100644 --- a/src/wallet/load.h +++ b/src/wallet/load.h @@ -16,8 +16,6 @@ class Chain; } // namespace interfaces //! Responsible for reading and validating the -wallet arguments and verifying the wallet database. -//! This function will perform salvage on the wallet if requested, as long as only one wallet is -//! being loaded (WalletInit::ParameterInteraction() forbids -salvagewallet, -zapwallettxes or -upgradewallet with multiwallet). bool VerifyWallets(interfaces::Chain& chain, const std::vector<std::string>& wallet_files); //! Load wallet databases. diff --git a/src/wallet/rpcdump.cpp b/src/wallet/rpcdump.cpp index 7bf3d169c3..d5f6d63a46 100644 --- a/src/wallet/rpcdump.cpp +++ b/src/wallet/rpcdump.cpp @@ -746,7 +746,7 @@ UniValue dumpwallet(const JSONRPCRequest& request) // the user could have gotten from another RPC command prior to now wallet.BlockUntilSyncedToCurrentChain(); - LOCK2(pwallet->cs_wallet, spk_man.cs_KeyStore); + LOCK2(wallet.cs_wallet, spk_man.cs_KeyStore); EnsureWalletIsUnlocked(&wallet); @@ -769,7 +769,7 @@ UniValue dumpwallet(const JSONRPCRequest& request) std::map<CKeyID, int64_t> mapKeyBirth; const std::map<CKeyID, int64_t>& mapKeyPool = spk_man.GetAllReserveKeys(); - pwallet->GetKeyBirthTimes(mapKeyBirth); + wallet.GetKeyBirthTimes(mapKeyBirth); std::set<CScriptID> scripts = spk_man.GetCScripts(); diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index c2d314140c..2a9ac189ea 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -1482,10 +1482,9 @@ UniValue listtransactions(const JSONRPCRequest& request) static UniValue listsinceblock(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); + std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { + if (!EnsureWalletIsAvailable(pwallet.get(), request.fHelp)) { return NullUniValue; } @@ -1542,11 +1541,12 @@ static UniValue listsinceblock(const JSONRPCRequest& request) }, }.Check(request); + const CWallet& wallet = *pwallet; // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now - pwallet->BlockUntilSyncedToCurrentChain(); + wallet.BlockUntilSyncedToCurrentChain(); - LOCK(pwallet->cs_wallet); + LOCK(wallet.cs_wallet); // The way the 'height' is initialized is just a workaround for the gcc bug #47679 since version 4.6.0. Optional<int> height = MakeOptional(false, int()); // Height of the specified block or the common ancestor, if the block provided was in a deactivated chain. @@ -1557,9 +1557,9 @@ static UniValue listsinceblock(const JSONRPCRequest& request) uint256 blockId; if (!request.params[0].isNull() && !request.params[0].get_str().empty()) { blockId = ParseHashV(request.params[0], "blockhash"); - height.emplace(); - altheight.emplace(); - if (!pwallet->chain().findCommonAncestor(blockId, pwallet->GetLastBlockHash(), /* ancestor out */ FoundBlock().height(*height), /* blockId out */ FoundBlock().height(*altheight))) { + height = int{}; + altheight = int{}; + if (!wallet.chain().findCommonAncestor(blockId, wallet.GetLastBlockHash(), /* ancestor out */ FoundBlock().height(*height), /* blockId out */ FoundBlock().height(*altheight))) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Block not found"); } } @@ -1572,21 +1572,21 @@ static UniValue listsinceblock(const JSONRPCRequest& request) } } - if (ParseIncludeWatchonly(request.params[2], *pwallet)) { + if (ParseIncludeWatchonly(request.params[2], wallet)) { filter |= ISMINE_WATCH_ONLY; } bool include_removed = (request.params[3].isNull() || request.params[3].get_bool()); - int depth = height ? pwallet->GetLastBlockHeight() + 1 - *height : -1; + int depth = height ? wallet.GetLastBlockHeight() + 1 - *height : -1; UniValue transactions(UniValue::VARR); - for (const std::pair<const uint256, CWalletTx>& pairWtx : pwallet->mapWallet) { + for (const std::pair<const uint256, CWalletTx>& pairWtx : wallet.mapWallet) { const CWalletTx& tx = pairWtx.second; if (depth == -1 || abs(tx.GetDepthInMainChain()) < depth) { - ListTransactions(pwallet, tx, 0, true, transactions, filter, nullptr /* filter_label */); + ListTransactions(&wallet, tx, 0, true, transactions, filter, nullptr /* filter_label */); } } @@ -1595,15 +1595,15 @@ static UniValue listsinceblock(const JSONRPCRequest& request) UniValue removed(UniValue::VARR); while (include_removed && altheight && *altheight > *height) { CBlock block; - if (!pwallet->chain().findBlock(blockId, FoundBlock().data(block)) || block.IsNull()) { + if (!wallet.chain().findBlock(blockId, FoundBlock().data(block)) || block.IsNull()) { throw JSONRPCError(RPC_INTERNAL_ERROR, "Can't read block from disk"); } for (const CTransactionRef& tx : block.vtx) { - auto it = pwallet->mapWallet.find(tx->GetHash()); - if (it != pwallet->mapWallet.end()) { + auto it = wallet.mapWallet.find(tx->GetHash()); + if (it != wallet.mapWallet.end()) { // We want all transactions regardless of confirmation count to appear here, // even negative confirmation ones, hence the big negative. - ListTransactions(pwallet, it->second, -100000000, true, removed, filter, nullptr /* filter_label */); + ListTransactions(&wallet, it->second, -100000000, true, removed, filter, nullptr /* filter_label */); } } blockId = block.hashPrevBlock; @@ -1611,7 +1611,7 @@ static UniValue listsinceblock(const JSONRPCRequest& request) } uint256 lastblock; - CHECK_NONFATAL(pwallet->chain().findAncestorByHeight(pwallet->GetLastBlockHash(), pwallet->GetLastBlockHeight() + 1 - target_confirms, FoundBlock().hash(lastblock))); + CHECK_NONFATAL(wallet.chain().findAncestorByHeight(wallet.GetLastBlockHash(), wallet.GetLastBlockHeight() + 1 - target_confirms, FoundBlock().hash(lastblock))); UniValue ret(UniValue::VOBJ); ret.pushKV("transactions", transactions); @@ -2603,7 +2603,7 @@ static UniValue loadwallet(const JSONRPCRequest& request) UniValue obj(UniValue::VOBJ); obj.pushKV("name", wallet->GetName()); - obj.pushKV("warning", Join(warnings, "\n", OpOriginal)); + obj.pushKV("warning", Join(warnings, Untranslated("\n")).original); return obj; } @@ -2726,6 +2726,7 @@ static UniValue createwallet(const JSONRPCRequest& request) } if (!request.params[5].isNull() && request.params[5].get_bool()) { flags |= WALLET_FLAG_DESCRIPTORS; + warnings.emplace_back(Untranslated("Wallet is an experimental descriptor wallet")); } bilingual_str error; @@ -2743,7 +2744,7 @@ static UniValue createwallet(const JSONRPCRequest& request) UniValue obj(UniValue::VOBJ); obj.pushKV("name", wallet->GetName()); - obj.pushKV("warning", Join(warnings, "\n", OpOriginal)); + obj.pushKV("warning", Join(warnings, Untranslated("\n")).original); return obj; } @@ -3979,10 +3980,6 @@ UniValue sethdseed(const JSONRPCRequest& request) LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(*pwallet, true); - if (pwallet->chain().isInitialBlockDownload()) { - throw JSONRPCError(RPC_CLIENT_IN_INITIAL_DOWNLOAD, "Cannot set a new HD seed while still in Initial Block Download"); - } - if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { throw JSONRPCError(RPC_WALLET_ERROR, "Cannot set a HD seed to a wallet with private keys disabled"); } diff --git a/src/wallet/salvage.cpp b/src/wallet/salvage.cpp new file mode 100644 index 0000000000..70067ebef0 --- /dev/null +++ b/src/wallet/salvage.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2009-2010 Satoshi Nakamoto +// Copyright (c) 2009-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <fs.h> +#include <streams.h> +#include <wallet/salvage.h> +#include <wallet/wallet.h> +#include <wallet/walletdb.h> + +/* End of headers, beginning of key/value data */ +static const char *HEADER_END = "HEADER=END"; +/* End of key/value data */ +static const char *DATA_END = "DATA=END"; +typedef std::pair<std::vector<unsigned char>, std::vector<unsigned char> > KeyValPair; + +bool RecoverDatabaseFile(const fs::path& file_path) +{ + std::string filename; + std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, filename); + + // Recovery procedure: + // move wallet file to walletfilename.timestamp.bak + // Call Salvage with fAggressive=true to + // get as much data as possible. + // Rewrite salvaged data to fresh wallet file + // Set -rescan so any missing transactions will be + // found. + int64_t now = GetTime(); + std::string newFilename = strprintf("%s.%d.bak", filename, now); + + int result = env->dbenv->dbrename(nullptr, filename.c_str(), nullptr, + newFilename.c_str(), DB_AUTO_COMMIT); + if (result == 0) + LogPrintf("Renamed %s to %s\n", filename, newFilename); + else + { + LogPrintf("Failed to rename %s to %s\n", filename, newFilename); + return false; + } + + /** + * Salvage data from a file. The DB_AGGRESSIVE flag is being used (see berkeley DB->verify() method documentation). + * key/value pairs are appended to salvagedData which are then written out to a new wallet file. + * NOTE: reads the entire database into memory, so cannot be used + * for huge databases. + */ + std::vector<KeyValPair> salvagedData; + + std::stringstream strDump; + + Db db(env->dbenv.get(), 0); + result = db.verify(newFilename.c_str(), nullptr, &strDump, DB_SALVAGE | DB_AGGRESSIVE); + if (result == DB_VERIFY_BAD) { + LogPrintf("Salvage: Database salvage found errors, all data may not be recoverable.\n"); + } + if (result != 0 && result != DB_VERIFY_BAD) { + LogPrintf("Salvage: Database salvage failed with result %d.\n", result); + return false; + } + + // Format of bdb dump is ascii lines: + // header lines... + // HEADER=END + // hexadecimal key + // hexadecimal value + // ... repeated + // DATA=END + + std::string strLine; + while (!strDump.eof() && strLine != HEADER_END) + getline(strDump, strLine); // Skip past header + + std::string keyHex, valueHex; + while (!strDump.eof() && keyHex != DATA_END) { + getline(strDump, keyHex); + if (keyHex != DATA_END) { + if (strDump.eof()) + break; + getline(strDump, valueHex); + if (valueHex == DATA_END) { + LogPrintf("Salvage: WARNING: Number of keys in data does not match number of values.\n"); + break; + } + salvagedData.push_back(make_pair(ParseHex(keyHex), ParseHex(valueHex))); + } + } + + bool fSuccess; + if (keyHex != DATA_END) { + LogPrintf("Salvage: WARNING: Unexpected end of file while reading salvage output.\n"); + fSuccess = false; + } else { + fSuccess = (result == 0); + } + + if (salvagedData.empty()) + { + LogPrintf("Salvage(aggressive) found no records in %s.\n", newFilename); + return false; + } + LogPrintf("Salvage(aggressive) found %u records\n", salvagedData.size()); + + std::unique_ptr<Db> pdbCopy = MakeUnique<Db>(env->dbenv.get(), 0); + int ret = pdbCopy->open(nullptr, // Txn pointer + filename.c_str(), // Filename + "main", // Logical db name + DB_BTREE, // Database type + DB_CREATE, // Flags + 0); + if (ret > 0) { + LogPrintf("Cannot create database file %s\n", filename); + pdbCopy->close(0); + return false; + } + + DbTxn* ptxn = env->TxnBegin(); + CWallet dummyWallet(nullptr, WalletLocation(), WalletDatabase::CreateDummy()); + for (KeyValPair& row : salvagedData) + { + /* Filter for only private key type KV pairs to be added to the salvaged wallet */ + CDataStream ssKey(row.first, SER_DISK, CLIENT_VERSION); + CDataStream ssValue(row.second, SER_DISK, CLIENT_VERSION); + std::string strType, strErr; + bool fReadOK; + { + // Required in LoadKeyMetadata(): + LOCK(dummyWallet.cs_wallet); + fReadOK = ReadKeyValue(&dummyWallet, ssKey, ssValue, strType, strErr); + } + if (!WalletBatch::IsKeyType(strType) && strType != DBKeys::HDCHAIN) { + continue; + } + if (!fReadOK) + { + LogPrintf("WARNING: WalletBatch::Recover skipping %s: %s\n", strType, strErr); + continue; + } + Dbt datKey(&row.first[0], row.first.size()); + Dbt datValue(&row.second[0], row.second.size()); + int ret2 = pdbCopy->put(ptxn, &datKey, &datValue, DB_NOOVERWRITE); + if (ret2 > 0) + fSuccess = false; + } + ptxn->commit(0); + pdbCopy->close(0); + + return fSuccess; +} diff --git a/src/wallet/salvage.h b/src/wallet/salvage.h new file mode 100644 index 0000000000..e361930f5e --- /dev/null +++ b/src/wallet/salvage.h @@ -0,0 +1,14 @@ +// Copyright (c) 2009-2010 Satoshi Nakamoto +// Copyright (c) 2009-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_WALLET_SALVAGE_H +#define BITCOIN_WALLET_SALVAGE_H + +#include <fs.h> +#include <streams.h> + +bool RecoverDatabaseFile(const fs::path& file_path); + +#endif // BITCOIN_WALLET_SALVAGE_H diff --git a/src/wallet/scriptpubkeyman.cpp b/src/wallet/scriptpubkeyman.cpp index e4be5045e1..8a2a798644 100644 --- a/src/wallet/scriptpubkeyman.cpp +++ b/src/wallet/scriptpubkeyman.cpp @@ -12,6 +12,9 @@ #include <util/translation.h> #include <wallet/scriptpubkeyman.h> +//! Value for the first BIP 32 hardened derivation. Can be used as a bit mask and as a value. See BIP 32 for more details. +const uint32_t BIP32_HARDENED_KEY_LIMIT = 0x80000000; + bool LegacyScriptPubKeyMan::GetNewDestination(const OutputType type, CTxDestination& dest, std::string& error) { LOCK(cs_KeyStore); @@ -220,6 +223,7 @@ bool LegacyScriptPubKeyMan::CheckDecryptionKey(const CKeyingMaterial& master_key bool keyPass = mapCryptedKeys.empty(); // Always pass when there are no encrypted keys bool keyFail = false; CryptedKeyMap::const_iterator mi = mapCryptedKeys.begin(); + WalletBatch batch(m_storage.GetDatabase()); for (; mi != mapCryptedKeys.end(); ++mi) { const CPubKey &vchPubKey = (*mi).second.first; @@ -233,6 +237,10 @@ bool LegacyScriptPubKeyMan::CheckDecryptionKey(const CKeyingMaterial& master_key keyPass = true; if (fDecryptionThoroughlyChecked) break; + else { + // Rewrite these encrypted keys with checksums + batch.WriteCryptedKey(vchPubKey, vchCryptedSecret, mapKeyMetadata[vchPubKey.GetID()]); + } } if (keyPass && keyFail) { @@ -290,6 +298,43 @@ bool LegacyScriptPubKeyMan::GetReservedDestination(const OutputType type, bool i return true; } +bool LegacyScriptPubKeyMan::TopUpInactiveHDChain(const CKeyID seed_id, int64_t index, bool internal) +{ + LOCK(cs_KeyStore); + + if (m_storage.IsLocked()) return false; + + auto it = m_inactive_hd_chains.find(seed_id); + if (it == m_inactive_hd_chains.end()) { + return false; + } + + CHDChain& chain = it->second; + + // Top up key pool + int64_t target_size = std::max(gArgs.GetArg("-keypool", DEFAULT_KEYPOOL_SIZE), (int64_t) 1); + + // "size" of the keypools. Not really the size, actually the difference between index and the chain counter + // Since chain counter is 1 based and index is 0 based, one of them needs to be offset by 1. + int64_t kp_size = (internal ? chain.nInternalChainCounter : chain.nExternalChainCounter) - (index + 1); + + // make sure the keypool fits the user-selected target (-keypool) + int64_t missing = std::max(target_size - kp_size, (int64_t) 0); + + if (missing > 0) { + WalletBatch batch(m_storage.GetDatabase()); + for (int64_t i = missing; i > 0; --i) { + GenerateNewKey(batch, chain, internal); + } + if (internal) { + WalletLogPrintf("inactive seed with id %s added %d internal keys\n", HexStr(seed_id), missing); + } else { + WalletLogPrintf("inactive seed with id %s added %d keys\n", HexStr(seed_id), missing); + } + } + return true; +} + void LegacyScriptPubKeyMan::MarkUnusedAddresses(const CScript& script) { LOCK(cs_KeyStore); @@ -297,13 +342,28 @@ void LegacyScriptPubKeyMan::MarkUnusedAddresses(const CScript& script) for (const auto& keyid : GetAffectedKeys(script, *this)) { std::map<CKeyID, int64_t>::const_iterator mi = m_pool_key_to_index.find(keyid); if (mi != m_pool_key_to_index.end()) { - WalletLogPrintf("%s: Detected a used keypool key, mark all keypool key up to this key as used\n", __func__); + WalletLogPrintf("%s: Detected a used keypool key, mark all keypool keys up to this key as used\n", __func__); MarkReserveKeysAsUsed(mi->second); if (!TopUp()) { WalletLogPrintf("%s: Topping up keypool failed (locked wallet)\n", __func__); } } + + // Find the key's metadata and check if it's seed id (if it has one) is inactive, i.e. it is not the current m_hd_chain seed id. + // If so, TopUp the inactive hd chain + auto it = mapKeyMetadata.find(keyid); + if (it != mapKeyMetadata.end()){ + CKeyMetadata meta = it->second; + if (!meta.hd_seed_id.IsNull() && meta.hd_seed_id != m_hd_chain.seed_id) { + bool internal = (meta.key_origin.path[1] & ~BIP32_HARDENED_KEY_LIMIT) != 0; + int64_t index = meta.key_origin.path[2] & ~BIP32_HARDENED_KEY_LIMIT; + + if (!TopUpInactiveHDChain(meta.hd_seed_id, index, internal)) { + WalletLogPrintf("%s: Adding inactive seed keys failed\n", __func__); + } + } + } } } @@ -357,7 +417,7 @@ bool LegacyScriptPubKeyMan::SetupGeneration(bool force) bool LegacyScriptPubKeyMan::IsHDEnabled() const { - return !hdChain.seed_id.IsNull(); + return !m_hd_chain.seed_id.IsNull(); } bool LegacyScriptPubKeyMan::CanGetAddresses(bool internal) const @@ -713,8 +773,13 @@ bool LegacyScriptPubKeyMan::AddKeyPubKeyInner(const CKey& key, const CPubKey &pu return true; } -bool LegacyScriptPubKeyMan::LoadCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret) +bool LegacyScriptPubKeyMan::LoadCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret, bool checksum_valid) { + // Set fDecryptionThoroughlyChecked to false when the checksum is invalid + if (!checksum_valid) { + fDecryptionThoroughlyChecked = false; + } + return AddCryptedKeyInner(vchPubKey, vchCryptedSecret); } @@ -838,10 +903,27 @@ bool LegacyScriptPubKeyMan::AddWatchOnly(const CScript& dest, int64_t nCreateTim void LegacyScriptPubKeyMan::SetHDChain(const CHDChain& chain, bool memonly) { LOCK(cs_KeyStore); - if (!memonly && !WalletBatch(m_storage.GetDatabase()).WriteHDChain(chain)) - throw std::runtime_error(std::string(__func__) + ": writing chain failed"); + // memonly == true means we are loading the wallet file + // memonly == false means that the chain is actually being changed + if (!memonly) { + // Store the new chain + if (!WalletBatch(m_storage.GetDatabase()).WriteHDChain(chain)) { + throw std::runtime_error(std::string(__func__) + ": writing chain failed"); + } + // When there's an old chain, add it as an inactive chain as we are now rotating hd chains + if (!m_hd_chain.seed_id.IsNull()) { + AddInactiveHDChain(m_hd_chain); + } + } + + m_hd_chain = chain; +} - hdChain = chain; +void LegacyScriptPubKeyMan::AddInactiveHDChain(const CHDChain& chain) +{ + LOCK(cs_KeyStore); + assert(!chain.seed_id.IsNull()); + m_inactive_hd_chains[chain.seed_id] = chain; } bool LegacyScriptPubKeyMan::HaveKey(const CKeyID &address) const @@ -920,7 +1002,7 @@ bool LegacyScriptPubKeyMan::GetPubKey(const CKeyID &address, CPubKey& vchPubKeyO return GetWatchPubKey(address, vchPubKeyOut); } -CPubKey LegacyScriptPubKeyMan::GenerateNewKey(WalletBatch &batch, bool internal) +CPubKey LegacyScriptPubKeyMan::GenerateNewKey(WalletBatch &batch, CHDChain& hd_chain, bool internal) { assert(!m_storage.IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)); assert(!m_storage.IsWalletFlagSet(WALLET_FLAG_BLANK_WALLET)); @@ -935,7 +1017,7 @@ CPubKey LegacyScriptPubKeyMan::GenerateNewKey(WalletBatch &batch, bool internal) // use HD key derivation if HD was enabled during wallet creation and a seed is present if (IsHDEnabled()) { - DeriveNewChildKey(batch, metadata, secret, (m_storage.CanSupportFeature(FEATURE_HD_SPLIT) ? internal : false)); + DeriveNewChildKey(batch, metadata, secret, hd_chain, (m_storage.CanSupportFeature(FEATURE_HD_SPLIT) ? internal : false)); } else { secret.MakeNewKey(fCompressed); } @@ -957,9 +1039,7 @@ CPubKey LegacyScriptPubKeyMan::GenerateNewKey(WalletBatch &batch, bool internal) return pubkey; } -const uint32_t BIP32_HARDENED_KEY_LIMIT = 0x80000000; - -void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& metadata, CKey& secret, bool internal) +void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& metadata, CKey& secret, CHDChain& hd_chain, bool internal) { // for now we use a fixed keypath scheme of m/0'/0'/k CKey seed; //seed (256bit) @@ -969,7 +1049,7 @@ void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& CExtKey childKey; //key at m/0'/0'/<n>' // try to get the seed - if (!GetKey(hdChain.seed_id, seed)) + if (!GetKey(hd_chain.seed_id, seed)) throw std::runtime_error(std::string(__func__) + ": seed not found"); masterKey.SetSeed(seed.begin(), seed.size()); @@ -988,30 +1068,30 @@ void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& // childIndex | BIP32_HARDENED_KEY_LIMIT = derive childIndex in hardened child-index-range // example: 1 | BIP32_HARDENED_KEY_LIMIT == 0x80000001 == 2147483649 if (internal) { - chainChildKey.Derive(childKey, hdChain.nInternalChainCounter | BIP32_HARDENED_KEY_LIMIT); - metadata.hdKeypath = "m/0'/1'/" + ToString(hdChain.nInternalChainCounter) + "'"; + chainChildKey.Derive(childKey, hd_chain.nInternalChainCounter | BIP32_HARDENED_KEY_LIMIT); + metadata.hdKeypath = "m/0'/1'/" + ToString(hd_chain.nInternalChainCounter) + "'"; metadata.key_origin.path.push_back(0 | BIP32_HARDENED_KEY_LIMIT); metadata.key_origin.path.push_back(1 | BIP32_HARDENED_KEY_LIMIT); - metadata.key_origin.path.push_back(hdChain.nInternalChainCounter | BIP32_HARDENED_KEY_LIMIT); - hdChain.nInternalChainCounter++; + metadata.key_origin.path.push_back(hd_chain.nInternalChainCounter | BIP32_HARDENED_KEY_LIMIT); + hd_chain.nInternalChainCounter++; } else { - chainChildKey.Derive(childKey, hdChain.nExternalChainCounter | BIP32_HARDENED_KEY_LIMIT); - metadata.hdKeypath = "m/0'/0'/" + ToString(hdChain.nExternalChainCounter) + "'"; + chainChildKey.Derive(childKey, hd_chain.nExternalChainCounter | BIP32_HARDENED_KEY_LIMIT); + metadata.hdKeypath = "m/0'/0'/" + ToString(hd_chain.nExternalChainCounter) + "'"; metadata.key_origin.path.push_back(0 | BIP32_HARDENED_KEY_LIMIT); metadata.key_origin.path.push_back(0 | BIP32_HARDENED_KEY_LIMIT); - metadata.key_origin.path.push_back(hdChain.nExternalChainCounter | BIP32_HARDENED_KEY_LIMIT); - hdChain.nExternalChainCounter++; + metadata.key_origin.path.push_back(hd_chain.nExternalChainCounter | BIP32_HARDENED_KEY_LIMIT); + hd_chain.nExternalChainCounter++; } } while (HaveKey(childKey.key.GetPubKey().GetID())); secret = childKey.key; - metadata.hd_seed_id = hdChain.seed_id; + metadata.hd_seed_id = hd_chain.seed_id; CKeyID master_id = masterKey.key.GetPubKey().GetID(); std::copy(master_id.begin(), master_id.begin() + 4, metadata.key_origin.fingerprint); metadata.has_key_origin = true; // update the chain model in the database - if (!batch.WriteHDChain(hdChain)) - throw std::runtime_error(std::string(__func__) + ": Writing HD chain model failed"); + if (hd_chain.seed_id == m_hd_chain.seed_id && !batch.WriteHDChain(hd_chain)) + throw std::runtime_error(std::string(__func__) + ": writing HD chain model failed"); } void LegacyScriptPubKeyMan::LoadKeyPool(int64_t nIndex, const CKeyPool &keypool) @@ -1166,7 +1246,7 @@ bool LegacyScriptPubKeyMan::TopUp(unsigned int kpSize) internal = true; } - CPubKey pubkey(GenerateNewKey(batch, internal)); + CPubKey pubkey(GenerateNewKey(batch, m_hd_chain, internal)); AddKeypoolPubkeyWithDB(pubkey, internal, batch); } if (missingInternal + missingExternal > 0) { @@ -1239,7 +1319,7 @@ bool LegacyScriptPubKeyMan::GetKeyFromPool(CPubKey& result, const OutputType typ if (!ReserveKeyFromKeyPool(nIndex, keypool, internal) && !m_storage.IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { if (m_storage.IsLocked()) return false; WalletBatch batch(m_storage.GetDatabase()); - result = GenerateNewKey(batch, internal); + result = GenerateNewKey(batch, m_hd_chain, internal); return true; } KeepDestination(nIndex, type); @@ -1497,7 +1577,7 @@ std::set<CKeyID> LegacyScriptPubKeyMan::GetKeys() const return set_address; } -void LegacyScriptPubKeyMan::SetType(OutputType type, bool internal) {} +void LegacyScriptPubKeyMan::SetInternal(bool internal) {} bool DescriptorScriptPubKeyMan::GetNewDestination(const OutputType type, CTxDestination& dest, std::string& error) { @@ -1509,7 +1589,9 @@ bool DescriptorScriptPubKeyMan::GetNewDestination(const OutputType type, CTxDest { LOCK(cs_desc_man); assert(m_wallet_descriptor.descriptor->IsSingleType()); // This is a combo descriptor which should not be an active descriptor - if (type != m_address_type) { + Optional<OutputType> desc_addr_type = m_wallet_descriptor.descriptor->GetOutputType(); + assert(desc_addr_type); + if (type != *desc_addr_type) { throw std::runtime_error(std::string(__func__) + ": Types are inconsistent"); } @@ -1777,7 +1859,7 @@ bool DescriptorScriptPubKeyMan::AddDescriptorKeyWithDB(WalletBatch& batch, const } } -bool DescriptorScriptPubKeyMan::SetupDescriptorGeneration(const CExtKey& master_key) +bool DescriptorScriptPubKeyMan::SetupDescriptorGeneration(const CExtKey& master_key, OutputType addr_type) { LOCK(cs_desc_man); assert(m_storage.IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)); @@ -1794,7 +1876,7 @@ bool DescriptorScriptPubKeyMan::SetupDescriptorGeneration(const CExtKey& master_ // Build descriptor string std::string desc_prefix; std::string desc_suffix = "/*)"; - switch (m_address_type) { + switch (addr_type) { case OutputType::LEGACY: { desc_prefix = "pkh(" + xpub + "/44'"; break; @@ -2076,9 +2158,8 @@ uint256 DescriptorScriptPubKeyMan::GetID() const return id; } -void DescriptorScriptPubKeyMan::SetType(OutputType type, bool internal) +void DescriptorScriptPubKeyMan::SetInternal(bool internal) { - this->m_address_type = type; this->m_internal = internal; } diff --git a/src/wallet/scriptpubkeyman.h b/src/wallet/scriptpubkeyman.h index 4c002edf2d..d62d30f339 100644 --- a/src/wallet/scriptpubkeyman.h +++ b/src/wallet/scriptpubkeyman.h @@ -18,6 +18,8 @@ #include <boost/signals2/signal.hpp> +#include <unordered_map> + enum class OutputType; struct bilingual_str; @@ -110,40 +112,52 @@ public: CKeyPool(); CKeyPool(const CPubKey& vchPubKeyIn, bool internalIn); - ADD_SERIALIZE_METHODS; + template<typename Stream> + void Serialize(Stream& s) const + { + int nVersion = s.GetVersion(); + if (!(s.GetType() & SER_GETHASH)) { + s << nVersion; + } + s << nTime << vchPubKey << fInternal << m_pre_split; + } - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { + template<typename Stream> + void Unserialize(Stream& s) + { int nVersion = s.GetVersion(); - if (!(s.GetType() & SER_GETHASH)) - READWRITE(nVersion); - READWRITE(nTime); - READWRITE(vchPubKey); - if (ser_action.ForRead()) { - try { - READWRITE(fInternal); - } - catch (std::ios_base::failure&) { - /* flag as external address if we can't read the internal boolean - (this will be the case for any wallet before the HD chain split version) */ - fInternal = false; - } - try { - READWRITE(m_pre_split); - } - catch (std::ios_base::failure&) { - /* flag as postsplit address if we can't read the m_pre_split boolean - (this will be the case for any wallet that upgrades to HD chain split)*/ - m_pre_split = false; - } + if (!(s.GetType() & SER_GETHASH)) { + s >> nVersion; + } + s >> nTime >> vchPubKey; + try { + s >> fInternal; + } catch (std::ios_base::failure&) { + /* flag as external address if we can't read the internal boolean + (this will be the case for any wallet before the HD chain split version) */ + fInternal = false; } - else { - READWRITE(fInternal); - READWRITE(m_pre_split); + try { + s >> m_pre_split; + } catch (std::ios_base::failure&) { + /* flag as postsplit address if we can't read the m_pre_split boolean + (this will be the case for any wallet that upgrades to HD chain split) */ + m_pre_split = false; } } }; +class KeyIDHasher +{ +public: + KeyIDHasher() {} + + size_t operator()(const CKeyID& id) const + { + return id.GetUint64(0); + } +}; + /* * A class implementing ScriptPubKeyMan manages some (or all) scriptPubKeys used in a wallet. * It contains the scripts and keys related to the scriptPubKeys it manages. @@ -224,7 +238,7 @@ public: virtual uint256 GetID() const { return uint256(); } - virtual void SetType(OutputType type, bool internal) {} + virtual void SetInternal(bool internal) {} /** Prepends the wallet name in logging output to ease debugging in multi-wallet use cases */ template<typename... Params> @@ -243,7 +257,7 @@ class LegacyScriptPubKeyMan : public ScriptPubKeyMan, public FillableSigningProv { private: //! keeps track of whether Unlock has run a thorough check before - bool fDecryptionThoroughlyChecked = false; + bool fDecryptionThoroughlyChecked = true; using WatchOnlySet = std::set<CScript>; using WatchKeyMap = std::map<CKeyID, CPubKey>; @@ -288,10 +302,11 @@ private: bool AddKeyOriginWithDB(WalletBatch& batch, const CPubKey& pubkey, const KeyOriginInfo& info); /* the HD chain data model (external chain counters) */ - CHDChain hdChain; + CHDChain m_hd_chain; + std::unordered_map<CKeyID, CHDChain, KeyIDHasher> m_inactive_hd_chains; /* HD derive new child key (on internal or external chain) */ - void DeriveNewChildKey(WalletBatch& batch, CKeyMetadata& metadata, CKey& secret, bool internal = false) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); + void DeriveNewChildKey(WalletBatch& batch, CKeyMetadata& metadata, CKey& secret, CHDChain& hd_chain, bool internal = false) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); std::set<int64_t> setInternalKeyPool GUARDED_BY(cs_KeyStore); std::set<int64_t> setExternalKeyPool GUARDED_BY(cs_KeyStore); @@ -320,6 +335,18 @@ private: */ bool ReserveKeyFromKeyPool(int64_t& nIndex, CKeyPool& keypool, bool fRequestedInternal); + /** + * Like TopUp() but adds keys for inactive HD chains. + * Ensures that there are at least -keypool number of keys derived after the given index. + * + * @param seed_id the CKeyID for the HD seed. + * @param index the index to start generating keys from + * @param internal whether the internal chain should be used. true for internal chain, false for external chain. + * + * @return true if seed was found and keys were derived. false if unable to derive seeds + */ + bool TopUpInactiveHDChain(const CKeyID seed_id, int64_t index, bool internal); + public: using ScriptPubKeyMan::ScriptPubKeyMan; @@ -370,7 +397,7 @@ public: uint256 GetID() const override; - void SetType(OutputType type, bool internal) override; + void SetInternal(bool internal) override; // Map from Key ID to key metadata. std::map<CKeyID, CKeyMetadata> mapKeyMetadata GUARDED_BY(cs_KeyStore); @@ -385,7 +412,7 @@ public: //! Adds an encrypted key to the store, and saves it to disk. bool AddCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret); //! Adds an encrypted key to the store, without saving it to disk (used by LoadWallet) - bool LoadCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret); + bool LoadCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret, bool checksum_valid); void UpdateTimeFirstKey(int64_t nCreateTime) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); //! Adds a CScript to the store bool LoadCScript(const CScript& redeemScript); @@ -393,11 +420,12 @@ public: void LoadKeyMetadata(const CKeyID& keyID, const CKeyMetadata &metadata); void LoadScriptMetadata(const CScriptID& script_id, const CKeyMetadata &metadata); //! Generate a new key - CPubKey GenerateNewKey(WalletBatch& batch, bool internal = false) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); + CPubKey GenerateNewKey(WalletBatch& batch, CHDChain& hd_chain, bool internal = false) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); /* Set the HD chain model (chain child index counters) */ void SetHDChain(const CHDChain& chain, bool memonly); - const CHDChain& GetHDChain() const { return hdChain; } + const CHDChain& GetHDChain() const { return m_hd_chain; } + void AddInactiveHDChain(const CHDChain& chain); //! Adds a watch-only address to the store, without saving it to disk (used by LoadWallet) bool LoadWatchOnly(const CScript &dest); @@ -497,14 +525,11 @@ private: PubKeyMap m_map_pubkeys GUARDED_BY(cs_desc_man); int32_t m_max_cached_index = -1; - OutputType m_address_type; bool m_internal = false; KeyMap m_map_keys GUARDED_BY(cs_desc_man); CryptedKeyMap m_map_crypted_keys GUARDED_BY(cs_desc_man); - bool SetCrypted(); - //! keeps track of whether Unlock has run a thorough check before bool m_decryption_thoroughly_checked = false; @@ -524,9 +549,9 @@ public: : ScriptPubKeyMan(storage), m_wallet_descriptor(descriptor) {} - DescriptorScriptPubKeyMan(WalletStorage& storage, OutputType address_type, bool internal) + DescriptorScriptPubKeyMan(WalletStorage& storage, bool internal) : ScriptPubKeyMan(storage), - m_address_type(address_type), m_internal(internal) + m_internal(internal) {} mutable RecursiveMutex cs_desc_man; @@ -551,7 +576,7 @@ public: bool IsHDEnabled() const override; //! Setup descriptors based on the given CExtkey - bool SetupDescriptorGeneration(const CExtKey& master_key); + bool SetupDescriptorGeneration(const CExtKey& master_key, OutputType addr_type); bool HavePrivateKeys() const override; @@ -575,7 +600,7 @@ public: uint256 GetID() const override; - void SetType(OutputType type, bool internal) override; + void SetInternal(bool internal) override; void SetCache(const DescriptorCache& cache); diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index d888b8f842..3654420eb2 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -15,6 +15,7 @@ #include <rpc/server.h> #include <test/util/logging.h> #include <test/util/setup_common.h> +#include <util/ref.h> #include <util/translation.h> #include <validation.h> #include <wallet/coincontrol.h> @@ -117,7 +118,7 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Prune the older block file. { LOCK(cs_main); - PruneOneBlockFile(oldTip->GetBlockPos().nFile); + EnsureChainman(m_node).PruneOneBlockFile(oldTip->GetBlockPos().nFile); } UnlinkPrunedFiles({oldTip->GetBlockPos().nFile}); @@ -143,7 +144,7 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Prune the remaining block file. { LOCK(cs_main); - PruneOneBlockFile(newTip->GetBlockPos().nFile); + EnsureChainman(m_node).PruneOneBlockFile(newTip->GetBlockPos().nFile); } UnlinkPrunedFiles({newTip->GetBlockPos().nFile}); @@ -180,7 +181,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) // Prune the older block file. { LOCK(cs_main); - PruneOneBlockFile(oldTip->GetBlockPos().nFile); + EnsureChainman(m_node).PruneOneBlockFile(oldTip->GetBlockPos().nFile); } UnlinkPrunedFiles({oldTip->GetBlockPos().nFile}); @@ -208,7 +209,8 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) key.pushKV("timestamp", newTip->GetBlockTimeMax() + TIMESTAMP_WINDOW + 1); key.pushKV("internal", UniValue(true)); keys.push_back(key); - JSONRPCRequest request; + util::Ref context; + JSONRPCRequest request(context); request.params.setArray(); request.params.push_back(keys); @@ -262,7 +264,8 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) AddWallet(wallet); wallet->SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); } - JSONRPCRequest request; + util::Ref context; + JSONRPCRequest request(context); request.params.setArray(); request.params.push_back(backup_file); @@ -277,7 +280,8 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) LOCK(wallet->cs_wallet); wallet->SetupLegacyScriptPubKeyMan(); - JSONRPCRequest request; + util::Ref context; + JSONRPCRequest request(context); request.params.setArray(); request.params.push_back(backup_file); AddWallet(wallet); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 2b45c6a536..7824563254 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -153,7 +153,7 @@ void UnloadWallet(std::shared_ptr<CWallet>&& wallet) std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings) { try { - if (!CWallet::Verify(chain, location, false, error, warnings)) { + if (!CWallet::Verify(chain, location, error, warnings)) { error = Untranslated("Wallet file verification failed.") + Untranslated(" ") + error; return nullptr; } @@ -195,7 +195,7 @@ WalletCreationStatus CreateWallet(interfaces::Chain& chain, const SecureString& } // Wallet::Verify will check if we're trying to create a wallet with a duplicate name. - if (!CWallet::Verify(chain, location, false, error, warnings)) { + if (!CWallet::Verify(chain, location, error, warnings)) { error = Untranslated("Wallet file verification failed.") + Untranslated(" ") + error; return WalletCreationStatus::CREATION_FAILED; } @@ -1982,10 +1982,6 @@ void CWallet::ResendWalletTransactions() nNextResend = GetTime() + (12 * 60 * 60) + GetRand(24 * 60 * 60); if (fFirst) return; - // Only do it if there's been a new block since last time - if (m_best_block_time < nLastResend) return; - nLastResend = GetTime(); - int submitted_tx_count = 0; { // cs_wallet scope @@ -3654,7 +3650,7 @@ std::vector<std::string> CWallet::GetDestValues(const std::string& prefix) const return values; } -bool CWallet::Verify(interfaces::Chain& chain, const WalletLocation& location, bool salvage_wallet, bilingual_str& error_string, std::vector<bilingual_str>& warnings) +bool CWallet::Verify(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error_string, std::vector<bilingual_str>& warnings) { // Do some checking on wallet path. It should be either a: // @@ -3694,16 +3690,7 @@ bool CWallet::Verify(interfaces::Chain& chain, const WalletLocation& location, b return false; } - if (salvage_wallet) { - // Recover readable keypairs: - CWallet dummyWallet(&chain, WalletLocation(), WalletDatabase::CreateDummy()); - std::string backup_filename; - if (!WalletBatch::Recover(wallet_path, (void *)&dummyWallet, WalletBatch::RecoverKeysOnlyFilter, backup_filename)) { - return false; - } - } - - return WalletBatch::VerifyDatabaseFile(wallet_path, warnings, error_string); + return WalletBatch::VerifyDatabaseFile(wallet_path, error_string); } std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings, uint64_t wallet_creation_flags) @@ -4366,7 +4353,7 @@ void CWallet::SetupDescriptorScriptPubKeyMans() for (bool internal : {false, true}) { for (OutputType t : OUTPUT_TYPES) { - auto spk_manager = std::unique_ptr<DescriptorScriptPubKeyMan>(new DescriptorScriptPubKeyMan(*this, t, internal)); + auto spk_manager = std::unique_ptr<DescriptorScriptPubKeyMan>(new DescriptorScriptPubKeyMan(*this, internal)); if (IsCrypted()) { if (IsLocked()) { throw std::runtime_error(std::string(__func__) + ": Wallet is locked, cannot setup new descriptors"); @@ -4375,7 +4362,7 @@ void CWallet::SetupDescriptorScriptPubKeyMans() throw std::runtime_error(std::string(__func__) + ": Could not encrypt new descriptors"); } } - spk_manager->SetupDescriptorGeneration(master_key); + spk_manager->SetupDescriptorGeneration(master_key, t); uint256 id = spk_manager->GetID(); m_spk_managers[id] = std::move(spk_manager); SetActiveScriptPubKeyMan(id, t, internal); @@ -4388,7 +4375,7 @@ void CWallet::SetActiveScriptPubKeyMan(uint256 id, OutputType type, bool interna WalletLogPrintf("Setting spkMan to active: id = %s, type = %d, internal = %d\n", id.ToString(), static_cast<int>(type), static_cast<int>(internal)); auto& spk_mans = internal ? m_internal_spk_managers : m_external_spk_managers; auto spk_man = m_spk_managers.at(id).get(); - spk_man->SetType(type, internal); + spk_man->SetInternal(internal); spk_mans[type] = spk_man; if (!memonly) { diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index a29fa22207..e3141baef0 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -631,7 +631,6 @@ private: std::atomic<bool> fScanningWallet{false}; // controlled by WalletRescanReserver std::atomic<int64_t> m_scanning_start{0}; std::atomic<double> m_scanning_progress{0}; - std::mutex mutexScanning; friend class WalletRescanReserver; //! the current wallet version: clients below this version are not able to load the wallet @@ -641,7 +640,6 @@ private: int nWalletMaxVersion GUARDED_BY(cs_wallet) = FEATURE_BASE; int64_t nNextResend = 0; - int64_t nLastResend = 0; bool fBroadcastTransactions = false; // Local time that the tip block was received. Used to schedule wallet rebroadcasts. std::atomic<int64_t> m_best_block_time {0}; @@ -1137,7 +1135,7 @@ public: bool MarkReplaced(const uint256& originalHash, const uint256& newHash); //! Verify wallet naming and perform salvage on the wallet if required - static bool Verify(interfaces::Chain& chain, const WalletLocation& location, bool salvage_wallet, bilingual_str& error_string, std::vector<bilingual_str>& warnings); + static bool Verify(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error_string, std::vector<bilingual_str>& warnings); /* Initializes the wallet, returns a new CWallet instance or a null pointer in case of an error */ static std::shared_ptr<CWallet> CreateWalletFromFile(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings, uint64_t wallet_creation_flags = 0); @@ -1288,13 +1286,11 @@ public: bool reserve() { assert(!m_could_reserve); - std::lock_guard<std::mutex> lock(m_wallet.mutexScanning); - if (m_wallet.fScanningWallet) { + if (m_wallet.fScanningWallet.exchange(true)) { return false; } m_wallet.m_scanning_start = GetTimeMillis(); m_wallet.m_scanning_progress = 0; - m_wallet.fScanningWallet = true; m_could_reserve = true; return true; } @@ -1306,7 +1302,6 @@ public: ~WalletRescanReserver() { - std::lock_guard<std::mutex> lock(m_wallet.mutexScanning); if (m_could_reserve) { m_wallet.fScanningWallet = false; } diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 98597bdb0f..e7adbfea77 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -10,6 +10,7 @@ #include <protocol.h> #include <serialize.h> #include <sync.h> +#include <util/bip32.h> #include <util/system.h> #include <util/time.h> #include <wallet/wallet.h> @@ -115,8 +116,19 @@ bool WalletBatch::WriteCryptedKey(const CPubKey& vchPubKey, return false; } - if (!WriteIC(std::make_pair(DBKeys::CRYPTED_KEY, vchPubKey), vchCryptedSecret, false)) { - return false; + // Compute a checksum of the encrypted key + uint256 checksum = Hash(vchCryptedSecret.begin(), vchCryptedSecret.end()); + + const auto key = std::make_pair(DBKeys::CRYPTED_KEY, vchPubKey); + if (!WriteIC(key, std::make_pair(vchCryptedSecret, checksum), false)) { + // It may already exist, so try writing just the checksum + std::vector<unsigned char> val; + if (!m_batch.Read(key, val)) { + return false; + } + if (!WriteIC(key, std::make_pair(val, checksum), true)) { + return false; + } } EraseIC(std::make_pair(DBKeys::KEY, vchPubKey)); return true; @@ -245,6 +257,7 @@ public: std::map<uint256, DescriptorCache> m_descriptor_caches; std::map<std::pair<uint256, CKeyID>, CKey> m_descriptor_keys; std::map<std::pair<uint256, CKeyID>, std::pair<CPubKey, std::vector<unsigned char>>> m_descriptor_crypt_keys; + std::map<uint160, CHDChain> m_hd_chains; CWalletScanState() { } @@ -397,9 +410,21 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, } std::vector<unsigned char> vchPrivKey; ssValue >> vchPrivKey; + + // Get the checksum and check it + bool checksum_valid = false; + if (!ssValue.eof()) { + uint256 checksum; + ssValue >> checksum; + if ((checksum_valid = Hash(vchPrivKey.begin(), vchPrivKey.end()) != checksum)) { + strErr = "Error reading wallet database: Crypted key corrupt"; + return false; + } + } + wss.nCKeys++; - if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadCryptedKey(vchPubKey, vchPrivKey)) + if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadCryptedKey(vchPubKey, vchPrivKey, checksum_valid)) { strErr = "Error reading wallet database: LegacyScriptPubKeyMan::LoadCryptedKey failed"; return false; @@ -412,6 +437,65 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, ssValue >> keyMeta; wss.nKeyMeta++; pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadKeyMetadata(vchPubKey.GetID(), keyMeta); + + // Extract some CHDChain info from this metadata if it has any + if (keyMeta.nVersion >= CKeyMetadata::VERSION_WITH_HDDATA && !keyMeta.hd_seed_id.IsNull() && keyMeta.hdKeypath.size() > 0) { + // Get the path from the key origin or from the path string + // Not applicable when path is "s" as that indicates a seed + bool internal = false; + uint32_t index = 0; + if (keyMeta.hdKeypath != "s") { + std::vector<uint32_t> path; + if (keyMeta.has_key_origin) { + // We have a key origin, so pull it from its path vector + path = keyMeta.key_origin.path; + } else { + // No key origin, have to parse the string + if (!ParseHDKeypath(keyMeta.hdKeypath, path)) { + strErr = "Error reading wallet database: keymeta with invalid HD keypath"; + return false; + } + } + + // Extract the index and internal from the path + // Path string is m/0'/k'/i' + // Path vector is [0', k', i'] (but as ints OR'd with the hardened bit + // k == 0 for external, 1 for internal. i is the index + if (path.size() != 3) { + strErr = "Error reading wallet database: keymeta found with unexpected path"; + return false; + } + if (path[0] != 0x80000000) { + strErr = strprintf("Unexpected path index of 0x%08x (expected 0x80000000) for the element at index 0", path[0]); + return false; + } + if (path[1] != 0x80000000 && path[1] != (1 | 0x80000000)) { + strErr = strprintf("Unexpected path index of 0x%08x (expected 0x80000000 or 0x80000001) for the element at index 1", path[1]); + return false; + } + if ((path[2] & 0x80000000) == 0) { + strErr = strprintf("Unexpected path index of 0x%08x (expected to be greater than or equal to 0x80000000)", path[2]); + return false; + } + internal = path[1] == (1 | 0x80000000); + index = path[2] & ~0x80000000; + } + + // Insert a new CHDChain, or get the one that already exists + auto ins = wss.m_hd_chains.emplace(keyMeta.hd_seed_id, CHDChain()); + CHDChain& chain = ins.first->second; + if (ins.second) { + // For new chains, we want to default to VERSION_HD_BASE until we see an internal + chain.nVersion = CHDChain::VERSION_HD_BASE; + chain.seed_id = keyMeta.hd_seed_id; + } + if (internal) { + chain.nVersion = CHDChain::VERSION_HD_CHAIN_SPLIT; + chain.nInternalChainCounter = std::max(chain.nInternalChainCounter, index); + } else { + chain.nExternalChainCounter = std::max(chain.nExternalChainCounter, index); + } + } } else if (strType == DBKeys::WATCHMETA) { CScript script; ssKey >> script; @@ -588,6 +672,13 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, return true; } +bool ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, std::string& strType, std::string& strErr) +{ + CWalletScanState dummy_wss; + LOCK(pwallet->cs_wallet); + return ReadKeyValue(pwallet, ssKey, ssValue, dummy_wss, strType, strErr); +} + bool WalletBatch::IsKeyType(const std::string& strType) { return (strType == DBKeys::KEY || @@ -735,6 +826,20 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) result = DBErrors::CORRUPT; } + // Set the inactive chain + if (wss.m_hd_chains.size() > 0) { + LegacyScriptPubKeyMan* legacy_spkm = pwallet->GetLegacyScriptPubKeyMan(); + if (!legacy_spkm) { + pwallet->WalletLogPrintf("Inactive HD Chains found but no Legacy ScriptPubKeyMan\n"); + return DBErrors::CORRUPT; + } + for (const auto& chain_pair : wss.m_hd_chains) { + if (chain_pair.first != pwallet->GetLegacyScriptPubKeyMan()->GetHDChain().seed_id) { + pwallet->GetLegacyScriptPubKeyMan()->AddInactiveHDChain(chain_pair.second); + } + } + } + return result; } @@ -878,53 +983,14 @@ void MaybeCompactWalletDB() fOneThread = false; } -// -// Try to (very carefully!) recover wallet file if there is a problem. -// -bool WalletBatch::Recover(const fs::path& wallet_path, void *callbackDataIn, bool (*recoverKVcallback)(void* callbackData, CDataStream ssKey, CDataStream ssValue), std::string& out_backup_filename) -{ - return BerkeleyBatch::Recover(wallet_path, callbackDataIn, recoverKVcallback, out_backup_filename); -} - -bool WalletBatch::Recover(const fs::path& wallet_path, std::string& out_backup_filename) -{ - // recover without a key filter callback - // results in recovering all record types - return WalletBatch::Recover(wallet_path, nullptr, nullptr, out_backup_filename); -} - -bool WalletBatch::RecoverKeysOnlyFilter(void *callbackData, CDataStream ssKey, CDataStream ssValue) -{ - CWallet *dummyWallet = reinterpret_cast<CWallet*>(callbackData); - CWalletScanState dummyWss; - std::string strType, strErr; - bool fReadOK; - { - // Required in LoadKeyMetadata(): - LOCK(dummyWallet->cs_wallet); - fReadOK = ReadKeyValue(dummyWallet, ssKey, ssValue, - dummyWss, strType, strErr); - } - if (!IsKeyType(strType) && strType != DBKeys::HDCHAIN) { - return false; - } - if (!fReadOK) - { - LogPrintf("WARNING: WalletBatch::Recover skipping %s: %s\n", strType, strErr); - return false; - } - - return true; -} - bool WalletBatch::VerifyEnvironment(const fs::path& wallet_path, bilingual_str& errorStr) { return BerkeleyBatch::VerifyEnvironment(wallet_path, errorStr); } -bool WalletBatch::VerifyDatabaseFile(const fs::path& wallet_path, std::vector<bilingual_str>& warnings, bilingual_str& errorStr) +bool WalletBatch::VerifyDatabaseFile(const fs::path& wallet_path, bilingual_str& errorStr) { - return BerkeleyBatch::VerifyDatabaseFile(wallet_path, warnings, errorStr, WalletBatch::Recover); + return BerkeleyBatch::VerifyDatabaseFile(wallet_path, errorStr); } bool WalletBatch::WriteDestData(const std::string &address, const std::string &key, const std::string &value) diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index e2bf229c68..b95ed24d12 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -98,15 +98,13 @@ public: int nVersion; CHDChain() { SetNull(); } - ADD_SERIALIZE_METHODS; - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) + + SERIALIZE_METHODS(CHDChain, obj) { - READWRITE(this->nVersion); - READWRITE(nExternalChainCounter); - READWRITE(seed_id); - if (this->nVersion >= VERSION_HD_CHAIN_SPLIT) - READWRITE(nInternalChainCounter); + READWRITE(obj.nVersion, obj.nExternalChainCounter, obj.seed_id); + if (obj.nVersion >= VERSION_HD_CHAIN_SPLIT) { + READWRITE(obj.nInternalChainCounter); + } } void SetNull() @@ -116,6 +114,11 @@ public: nInternalChainCounter = 0; seed_id.SetNull(); } + + bool operator==(const CHDChain& chain) const + { + return seed_id == chain.seed_id; + } }; class CKeyMetadata @@ -142,21 +145,16 @@ public: nCreateTime = nCreateTime_; } - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - READWRITE(this->nVersion); - READWRITE(nCreateTime); - if (this->nVersion >= VERSION_WITH_HDDATA) - { - READWRITE(hdKeypath); - READWRITE(hd_seed_id); + SERIALIZE_METHODS(CKeyMetadata, obj) + { + READWRITE(obj.nVersion, obj.nCreateTime); + if (obj.nVersion >= VERSION_WITH_HDDATA) { + READWRITE(obj.hdKeypath, obj.hd_seed_id); } - if (this->nVersion >= VERSION_WITH_KEY_ORIGIN) + if (obj.nVersion >= VERSION_WITH_KEY_ORIGIN) { - READWRITE(key_origin); - READWRITE(has_key_origin); + READWRITE(obj.key_origin); + READWRITE(obj.has_key_origin); } } @@ -263,18 +261,12 @@ public: DBErrors FindWalletTx(std::vector<uint256>& vTxHash, std::list<CWalletTx>& vWtx); DBErrors ZapWalletTx(std::list<CWalletTx>& vWtx); DBErrors ZapSelectTx(std::vector<uint256>& vHashIn, std::vector<uint256>& vHashOut); - /* Try to (very carefully!) recover wallet database (with a possible key type filter) */ - static bool Recover(const fs::path& wallet_path, void *callbackDataIn, bool (*recoverKVcallback)(void* callbackData, CDataStream ssKey, CDataStream ssValue), std::string& out_backup_filename); - /* Recover convenience-function to bypass the key filter callback, called when verify fails, recovers everything */ - static bool Recover(const fs::path& wallet_path, std::string& out_backup_filename); - /* Recover filter (used as callback), will only let keys (cryptographical keys) as KV/key-type pass through */ - static bool RecoverKeysOnlyFilter(void *callbackData, CDataStream ssKey, CDataStream ssValue); /* Function to determine if a certain KV/key-type is a key (cryptographical key) type */ static bool IsKeyType(const std::string& strType); /* verifies the database environment */ static bool VerifyEnvironment(const fs::path& wallet_path, bilingual_str& errorStr); /* verifies the database file */ - static bool VerifyDatabaseFile(const fs::path& wallet_path, std::vector<bilingual_str>& warnings, bilingual_str& errorStr); + static bool VerifyDatabaseFile(const fs::path& wallet_path, bilingual_str& errorStr); //! write the hdchain model (external chain child index counter) bool WriteHDChain(const CHDChain& chain); @@ -294,4 +286,7 @@ private: //! Compacts BDB state so that wallet.dat is self-contained (if there are changes) void MaybeCompactWalletDB(); +//! Unserialize a given Key-Value pair and load it into the wallet +bool ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, std::string& strType, std::string& strErr); + #endif // BITCOIN_WALLET_WALLETDB_H diff --git a/src/wallet/wallettool.cpp b/src/wallet/wallettool.cpp index 522efaa884..be07c28503 100644 --- a/src/wallet/wallettool.cpp +++ b/src/wallet/wallettool.cpp @@ -5,6 +5,7 @@ #include <fs.h> #include <util/system.h> #include <util/translation.h> +#include <wallet/salvage.h> #include <wallet/wallet.h> #include <wallet/walletutil.h> @@ -103,6 +104,27 @@ static void WalletShowInfo(CWallet* wallet_instance) tfm::format(std::cout, "Address Book: %zu\n", wallet_instance->m_address_book.size()); } +static bool SalvageWallet(const fs::path& path) +{ + // Create a Database handle to allow for the db to be initialized before recovery + std::unique_ptr<WalletDatabase> database = WalletDatabase::Create(path); + + // Initialize the environment before recovery + bilingual_str error_string; + try { + WalletBatch::VerifyEnvironment(path, error_string); + } catch (const fs::filesystem_error& e) { + error_string = Untranslated(strprintf("Error loading wallet. %s", fsbridge::get_filesystem_error_message(e))); + } + if (!error_string.original.empty()) { + tfm::format(std::cerr, "Failed to open wallet for salvage :%s\n", error_string.original); + return false; + } + + // Perform the recovery + return RecoverDatabaseFile(path); +} + bool ExecuteWalletToolFunc(const std::string& command, const std::string& name) { fs::path path = fs::absolute(name, GetWalletDir()); @@ -113,7 +135,7 @@ bool ExecuteWalletToolFunc(const std::string& command, const std::string& name) WalletShowInfo(wallet_instance.get()); wallet_instance->Flush(true); } - } else if (command == "info") { + } else if (command == "info" || command == "salvage") { if (!fs::exists(path)) { tfm::format(std::cerr, "Error: no wallet file at %s\n", name); return false; @@ -123,10 +145,15 @@ bool ExecuteWalletToolFunc(const std::string& command, const std::string& name) tfm::format(std::cerr, "%s\nError loading %s. Is wallet being used by other process?\n", error.original, name); return false; } - std::shared_ptr<CWallet> wallet_instance = LoadWallet(name, path); - if (!wallet_instance) return false; - WalletShowInfo(wallet_instance.get()); - wallet_instance->Flush(true); + + if (command == "info") { + std::shared_ptr<CWallet> wallet_instance = LoadWallet(name, path); + if (!wallet_instance) return false; + WalletShowInfo(wallet_instance.get()); + wallet_instance->Flush(true); + } else if (command == "salvage") { + return SalvageWallet(path); + } } else { tfm::format(std::cerr, "Invalid command: %s\n", command); return false; diff --git a/src/wallet/walletutil.h b/src/wallet/walletutil.h index 599b1a9f5a..a4e4fda8a1 100644 --- a/src/wallet/walletutil.h +++ b/src/wallet/walletutil.h @@ -98,26 +98,22 @@ public: int32_t next_index = 0; // Position of the next item to generate DescriptorCache cache; - ADD_SERIALIZE_METHODS; - - template <typename Stream, typename Operation> - inline void SerializationOp(Stream& s, Operation ser_action) { - if (ser_action.ForRead()) { - std::string desc; - std::string error; - READWRITE(desc); - FlatSigningProvider keys; - descriptor = Parse(desc, keys, error, true); - if (!descriptor) { - throw std::ios_base::failure("Invalid descriptor: " + error); - } - } else { - READWRITE(descriptor->ToString()); + void DeserializeDescriptor(const std::string& str) + { + std::string error; + FlatSigningProvider keys; + descriptor = Parse(str, keys, error, true); + if (!descriptor) { + throw std::ios_base::failure("Invalid descriptor: " + error); } - READWRITE(creation_time); - READWRITE(next_index); - READWRITE(range_start); - READWRITE(range_end); + } + + SERIALIZE_METHODS(WalletDescriptor, obj) + { + std::string descriptor_str; + SER_WRITE(obj, descriptor_str = obj.descriptor->ToString()); + READWRITE(descriptor_str, obj.creation_time, obj.next_index, obj.range_start, obj.range_end); + SER_READ(obj, obj.DeserializeDescriptor(descriptor_str)); } WalletDescriptor() {} |